#!/usr/bin/env bash
# WorldMem_Repro / train_3stages.sh
# Author: BonanDing
# Commit 681f346: Reproduce Training & Fix distributed eval
# Logging: wandb enabled
# Fail fast: abort on any command error, on use of an unset variable, and on a
# failure in any stage of a pipeline. Enabled before anything else runs.
set -euo pipefail

# GPUs used by every stage; the caller's CUDA_VISIBLE_DEVICES wins if set.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
export NCCL_P2P_DISABLE=1
# export HYDRA_FULL_ERROR=1

# Paths and run id, each overridable from the environment. The defaults are the
# values this script originally hard-coded.
#   CKPT_ROOT  - directory containing diffusion_only.ckpt and vae_only.ckpt
#   DATA_DIR   - dataset.save_dir for ALL stages. Stages 2 and 3 used to pass
#                the cwd-relative path data/minecraft, which only matches
#                stage 1's data when launched from a directory where it
#                resolves to the same location.
#   OUTPUT_DIR - +output_dir shared by all stages
#   RESUME_ID  - value passed as resume= to stages 2 and 3. Presumably the wandb
#                run id produced by stage 1; when reproducing from scratch, set
#                it to your own stage-1 run (TODO confirm resume= semantics).
CKPT_ROOT="${CKPT_ROOT:-/share_1/users/bonan_ding/worldmem_ckpt}"
DATA_DIR="${DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft}"
OUTPUT_DIR="${OUTPUT_DIR:-${CKPT_ROOT}/reproduce_official_set}"
RESUME_ID="${RESUME_ID:-ot7jqmgn}"

# Hydra overrides that are identical in all three stages. Keeping them in one
# place prevents the stages from drifting apart. The metrics override is quoted
# so the shell never treats [lpips,psnr] as a glob bracket expression.
readonly -a COMMON_ARGS=(
  +name=train
  dataset.n_frames=8
  "dataset.save_dir=${DATA_DIR}"
  +dataset.n_frames_valid=700
  +dataset.angle_range=110
  +dataset.memory_condition_length=8
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +algorithm.n_tokens=8
  +algorithm.memory_condition_length=8
  algorithm.context_frames=600
  +algorithm.relative_embedding=true
  +algorithm.log_video=true
  +algorithm.add_timestamp_embedding=true
  '+algorithm.metrics=[lpips,psnr]'
  experiment.training.checkpointing.every_n_train_steps=2500
  "+output_dir=${OUTPUT_DIR}"
)

#######################################
# Launch one training stage via `python -m main`.
# Globals:   COMMON_ARGS (read)
# Arguments: $1 - label printed to stderr before the run
#            remaining args - stage-specific Hydra overrides, appended after
#            COMMON_ARGS
# Returns:   the exit status of the training process; under `set -e` a
#            non-zero status aborts the script, so later stages do not run
#######################################
run_stage() {
  local -r label=$1
  shift
  printf '=== %s ===\n' "$label" >&2
  python -m main "${COMMON_ARGS[@]}" "$@"
}

# Stage 1: load diffusion_only.ckpt and vae_only.ckpt; pos_range=2,
# wo_updown=true; train until step 120000.
run_stage "Stage 1" \
  "+diffusion_model_path=${CKPT_ROOT}/diffusion_only.ckpt" \
  "+vae_path=${CKPT_ROOT}/vae_only.ckpt" \
  +customized_load=true \
  +seperate_load=true \
  +zero_init_gate=true \
  +dataset.pos_range=2 \
  +dataset.wo_updown=true \
  experiment.training.max_steps=120000

# Stage 2: resume RESUME_ID; pos_range=8, wo_updown=true; until step 240000.
run_stage "Stage 2" \
  +dataset.pos_range=8 \
  +dataset.wo_updown=true \
  "resume=${RESUME_ID}" \
  experiment.training.max_steps=240000

# Stage 3: resume RESUME_ID; pos_range=8, wo_updown=false; until step 700000.
run_stage "Stage 3" \
  +dataset.pos_range=8 \
  +dataset.wo_updown=false \
  "resume=${RESUME_ID}" \
  experiment.training.max_steps=700000