DeMemWM / scripts /dememwm_full_train.slurm
BonanDing's picture
Clean DeMemWM deterministic memory slot handling
93d7b0a
#!/usr/bin/env bash
#
# Slurm batch script for the "dememwm-full-train" job.
#
# Resources requested by the directives below: the "gpu" partition, one node,
# two tasks on that node, 16 CPUs per task, 256G of memory, two GPUs, and a
# three-day time limit.
#
# The job starts in the DeMemWM checkout given by --chdir. Stdout and stderr
# are written to separate per-job files under slurm_logs/ in that checkout
# (%j in the file name is replaced by the job ID).
# NOTE(review): Slurm is not expected to create slurm_logs/ on its own --
# confirm the directory exists before submitting.
#
# Keep all scheduler directives ahead of the first executable command; sbatch
# stops reading directives once it reaches one.
#SBATCH --job-name=dememwm-full-train
#SBATCH --partition=gpu
#SBATCH --time=3-00:00:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=16
#SBATCH --mem=256G
#SBATCH --gres=gpu:2
#SBATCH --chdir=/share_1/users/bonan_ding/DeMemWM
#SBATCH --output=/share_1/users/bonan_ding/DeMemWM/slurm_logs/dememwm-full-train-%j.out
#SBATCH --error=/share_1/users/bonan_ding/DeMemWM/slurm_logs/dememwm-full-train-%j.err
# Abort on the first failing command or failing pipeline stage.
# NOTE(review): nounset (-u) is intentionally left off; conda activation
# scripts frequently reference unset variables -- confirm before tightening.
set -eo pipefail

# Bring the `conda` shell function into this non-interactive shell, then
# switch to the environment used for training.
source /share_0/conda/etc/profile.d/conda.sh
conda activate worldmem

# Put the job's working directory first on the module search path. The
# ${VAR:+...} form avoids a dangling ':' (an empty path entry) when
# PYTHONPATH was not set beforehand.
export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"
export HYDRA_FULL_ERROR=1
export PYTHONWARNINGS=ignore

# Follow the scheduler's per-task CPU allocation instead of repeating the
# number by hand; 16 is kept as the fallback when the script runs outside
# Slurm.
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
# Since Slurm 22.05, srun no longer inherits --cpus-per-task from sbatch;
# exporting SRUN_CPUS_PER_TASK hands the same CPU count to the job step.
export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK:-16}"

export WANDB_MODE=online
# Turns off NCCL's direct GPU peer-to-peer transport.
# NOTE(review): the reason for disabling it is not recorded here -- confirm it
# is still needed on the target nodes.
export NCCL_P2P_DISABLE=1

# Best-effort: a missing or failing wandb CLI must not abort the job, so its
# output and exit status are deliberately discarded.
wandb online >/dev/null 2>&1 || true
# Command-line overrides handed to `python -m main`, in the order they are
# passed. They are collected in an array so that groups can be commented and
# so that every element reaches the program as exactly one argument.
#
# The syntax appears to be Hydra's override grammar: `key=value` overrides an
# existing key, `+key=value` adds a new key, `++key=value` adds or overrides.
#
# Values containing `[...]` are quoted: unquoted, bash treats the brackets as
# a glob bracket expression and could replace the argument with a matching
# file name from the working directory.
hydra_overrides=(
  # Run identity and output location. The run name spells out the GPU count,
  # batch size and step budget, so keep it in sync with the values below.
  +name=train_dememwm_full_h200_2gpu_bs32_350k
  +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/dememwm_full_h200_2gpu_bs32_350k/
  wandb.mode=online
  auto_resume=true
  "experiment.tasks=[training]"

  # Algorithm selection and checkpoint loading.
  algorithm=dememwm_memory_dit
  +customized_load=true
  +seperate_load=true
  +diffusion_model_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/oasis500m.safetensors
  +vae_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/vit-l-20.safetensors
  +only_tune_memory=false

  # Dataset.
  dataset=video_minecraft_latent
  dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft
  dataset.precomputed_feature_dir=/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features
  dataset.n_frames=1000
  +dataset.n_frames_valid=1100
  +dataset.customized_validation=true
  +dataset.memory_condition_length=0
  +dataset.wo_updown=false
  +dataset.angle_range=180
  +dataset.pos_range=8

  # General algorithm settings.
  ++algorithm.n_tokens=4
  "algorithm.x_shape=[16,18,32]"
  ++algorithm.context_frames=100
  ++algorithm.log_video=true
  ++algorithm.diffusion.sampling_timesteps=20
  ++algorithm.dememwm.debug_force_all_streams=false

  # algorithm.dememwm.generated_history_proxy
  ++algorithm.dememwm.generated_history_proxy.enabled=true
  ++algorithm.dememwm.generated_history_proxy.start_step=40000
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000
  ++algorithm.dememwm.generated_history_proxy.max_prob=0.25
  ++algorithm.dememwm.generated_history_proxy.noise_std=0.25
  ++algorithm.dememwm.generated_history_proxy.dropout_prob=0.0

  # algorithm.dememwm.anchor
  ++algorithm.dememwm.anchor.enabled=true
  "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
  ++algorithm.dememwm.anchor.diverse_selection=true
  ++algorithm.dememwm.anchor.compress.downsample_ratio=3
  ++algorithm.dememwm.anchor.allow_generated_as_anchor=false

  # algorithm.dememwm.dynamic
  ++algorithm.dememwm.dynamic.enabled=true
  ++algorithm.dememwm.dynamic.exclude_latest_local_frames=4
  ++algorithm.dememwm.dynamic.recent_frames=4

  # algorithm.dememwm.revisit
  ++algorithm.dememwm.revisit.enabled=true
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true
  ++algorithm.dememwm.revisit.fov_overlap_threshold=0.60
  ++algorithm.dememwm.revisit.pose_preselect_topk=64
  ++algorithm.dememwm.revisit.fov_yaw_samples=25
  ++algorithm.dememwm.revisit.fov_pitch_samples=20
  ++algorithm.dememwm.revisit.fov_depth_samples=20
  ++algorithm.dememwm.revisit.plucker_weight=0.10
  ++algorithm.dememwm.revisit.max_frames=2
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3

  # algorithm.dememwm.stage_policy
  ++algorithm.dememwm.stage_policy.noise_bucket_logging=true

  # algorithm.dememwm.cache
  ++algorithm.dememwm.cache.enabled=true
  ++algorithm.dememwm.cache.device=cpu
  ++algorithm.dememwm.cache.keep_raw_latents=all
  ++algorithm.dememwm.cache.keep_compressed_records=true
  ++algorithm.dememwm.cache.eviction_policy=none
  ++algorithm.dememwm.cache.no_evict=true
  ++algorithm.dememwm.cache.clear_between_videos=true
  ++algorithm.dememwm.cache.max_records=null
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn

  # algorithm.dememwm.curriculum
  ++algorithm.dememwm.curriculum.enabled=true
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000
  ++algorithm.dememwm.curriculum.freeze_vae=true
  ++algorithm.dememwm.curriculum.dit_freeze.enabled=true
  ++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5
  ++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5
  ++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5

  # Training / validation schedule.
  experiment.training.batch_size=32
  experiment.training.optim.accumulate_grad_batches=1
  experiment.validation.batch_size=1
  experiment.validation.limit_batch=16
  experiment.training.checkpointing.every_n_train_steps=2000
  experiment.validation.val_every_n_step=2000
  experiment.training.max_steps=350000
)

# Start the job step. srun is the last command of the script, so its exit
# status becomes the exit status of the batch job.
srun python -m main "${hydra_overrides[@]}"