#!/usr/bin/env bash
#
# Slurm batch job: launches the DeMemWM "full" training run (python -m main)
# on a single node with 2 tasks / 2 GPUs, with Weights & Biases in online mode.
#
# Usage:  sbatch <this-file>
#
# Layout rules this file relies on:
#   * The shebang must be the very first line, and every SBATCH directive must
#     sit on its own line before the first executable command.
#   * NOTE(review): the slurm_logs/ directory named in --output/--error is
#     expected to exist before submission; nothing in this script creates it.
#
#SBATCH --job-name=dememwm-full-train
#SBATCH --partition=gpu
#SBATCH --time=3-00:00:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=16
#SBATCH --mem=256G
#SBATCH --gres=gpu:2
#SBATCH --chdir=/share_1/users/bonan_ding/DeMemWM
#SBATCH --output=/share_1/users/bonan_ding/DeMemWM/slurm_logs/dememwm-full-train-%j.out
#SBATCH --error=/share_1/users/bonan_ding/DeMemWM/slurm_logs/dememwm-full-train-%j.err

# Abort on the first failing command or failing pipeline stage.
set -eo pipefail

# Environment setup. nounset is switched on only AFTER activation because
# third-party activation hooks are not guaranteed to be nounset-clean.
source /share_0/conda/etc/profile.d/conda.sh
conda activate worldmem
set -u

# Put the working directory (set by --chdir above) first on the import path.
# The ${VAR:+...} form avoids a dangling ':' when PYTHONPATH is unset/empty.
export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"
export HYDRA_FULL_ERROR=1
export PYTHONWARNINGS=ignore
# Follow the Slurm CPU allocation; 16 mirrors --cpus-per-task above and is
# used when the variable is absent (e.g. the script is run outside Slurm).
export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}"
export WANDB_MODE=online
export NCCL_P2P_DISABLE=1

# Best-effort: flip the local wandb setting to online. Output is discarded and
# a failure is deliberately ignored so it cannot abort the job under `set -e`.
wandb online >/dev/null 2>&1 || true

# Hydra command-line overrides, passed to main in exactly this order.
# Prefix convention (Hydra override grammar): `key=` overrides an existing
# key, `+key=` appends a new key, `++key=` adds the key or overrides it.
# List values are quoted so the shell never treats `[...]` as a glob pattern.
hydra_overrides=(
  # --- run identity, logging, resume ---
  +name=train_dememwm_full_h200_2gpu_bs32_350k
  +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/dememwm_full_h200_2gpu_bs32_350k/
  wandb.mode=online
  auto_resume=true
  'experiment.tasks=[training]'

  # --- algorithm selection and checkpoint loading ---
  algorithm=dememwm_memory_dit
  +customized_load=true
  # NOTE(review): "seperate" is spelled as in the original key; it is runtime
  # text consumed by the program, so it is left untouched here.
  +seperate_load=true
  +diffusion_model_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/oasis500m.safetensors
  +vae_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/vit-l-20.safetensors
  +only_tune_memory=false

  # --- dataset ---
  dataset=video_minecraft_latent
  dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft
  dataset.precomputed_feature_dir=/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features
  dataset.n_frames=1000
  +dataset.n_frames_valid=1100
  +dataset.customized_validation=true
  +dataset.memory_condition_length=0
  +dataset.wo_updown=false
  +dataset.angle_range=180
  +dataset.pos_range=8

  # --- core algorithm settings ---
  ++algorithm.n_tokens=4
  'algorithm.x_shape=[16,18,32]'
  ++algorithm.context_frames=100
  ++algorithm.log_video=true
  ++algorithm.diffusion.sampling_timesteps=20
  ++algorithm.dememwm.debug_force_all_streams=false

  # --- dememwm.generated_history_proxy ---
  ++algorithm.dememwm.generated_history_proxy.enabled=true
  ++algorithm.dememwm.generated_history_proxy.start_step=40000
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000
  ++algorithm.dememwm.generated_history_proxy.max_prob=0.25
  ++algorithm.dememwm.generated_history_proxy.noise_std=0.25
  ++algorithm.dememwm.generated_history_proxy.dropout_prob=0.0

  # --- dememwm.anchor ---
  ++algorithm.dememwm.anchor.enabled=true
  '++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]'
  ++algorithm.dememwm.anchor.diverse_selection=true
  ++algorithm.dememwm.anchor.compress.downsample_ratio=3
  ++algorithm.dememwm.anchor.allow_generated_as_anchor=false

  # --- dememwm.dynamic ---
  ++algorithm.dememwm.dynamic.enabled=true
  ++algorithm.dememwm.dynamic.exclude_latest_local_frames=4
  ++algorithm.dememwm.dynamic.recent_frames=4

  # --- dememwm.revisit ---
  ++algorithm.dememwm.revisit.enabled=true
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true
  ++algorithm.dememwm.revisit.fov_overlap_threshold=0.60
  ++algorithm.dememwm.revisit.pose_preselect_topk=64
  ++algorithm.dememwm.revisit.fov_yaw_samples=25
  ++algorithm.dememwm.revisit.fov_pitch_samples=20
  ++algorithm.dememwm.revisit.fov_depth_samples=20
  ++algorithm.dememwm.revisit.plucker_weight=0.10
  ++algorithm.dememwm.revisit.max_frames=2
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3

  # --- dememwm.stage_policy ---
  ++algorithm.dememwm.stage_policy.noise_bucket_logging=true

  # --- dememwm.cache ---
  ++algorithm.dememwm.cache.enabled=true
  ++algorithm.dememwm.cache.device=cpu
  ++algorithm.dememwm.cache.keep_raw_latents=all
  ++algorithm.dememwm.cache.keep_compressed_records=true
  ++algorithm.dememwm.cache.eviction_policy=none
  ++algorithm.dememwm.cache.no_evict=true
  ++algorithm.dememwm.cache.clear_between_videos=true
  ++algorithm.dememwm.cache.max_records=null
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn

  # --- dememwm.curriculum (stage schedule, freezing, per-group LR) ---
  ++algorithm.dememwm.curriculum.enabled=true
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000
  ++algorithm.dememwm.curriculum.freeze_vae=true
  ++algorithm.dememwm.curriculum.dit_freeze.enabled=true
  ++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5
  ++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5
  ++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5

  # --- training / validation loop ---
  experiment.training.batch_size=32
  experiment.training.optim.accumulate_grad_batches=1
  experiment.validation.batch_size=1
  experiment.validation.limit_batch=16
  experiment.training.checkpointing.every_n_train_steps=2000
  experiment.validation.val_every_n_step=2000
  experiment.training.max_steps=350000
)

# srun starts one copy of the command per allocated task (2 on this node).
srun python -m main "${hydra_overrides[@]}"