Amshaker committed on
Commit
a124d29
·
verified ·
1 Parent(s): 389d93d

Upload train_mamba_stage_b.sh with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_mamba_stage_b.sh +129 -0
train_mamba_stage_b.sh ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env bash
#SBATCH --job-name=worldmem
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --time=72:00:00
#SBATCH --output=logs/slurmm1.0-%j.out
#SBATCH --error=logs/slurmm1.0-%j.out
#SBATCH --account=berzelius-2025-436
#SBATCH --gres=gpu:8
##SBATCH --gres=gpu:A100-SXM4-80GB:8

# Stage-B training launcher for the WorldMem Mamba memory pipeline.
# The script chains three sequential srun jobs: an initial run that loads a
# stage-A checkpoint, followed by two resumed continuations with increasing
# max_steps (30k -> 60k -> 175k).

# Abort the whole chain if any step fails: the later runs resume from the
# checkpoints the earlier ones write, so continuing after a failure would
# resume stale or missing state.
set -eo pipefail

module load buildenv-gcccuda/12.1.1-gcc12.3.0

# Quote the command substitution so an unexpected space in the conda prefix
# cannot word-split the sourced path.
source "$(conda info --base)/etc/profile.d/conda.sh"

# Guard with :- so an unset PYTHONPATH does not trip strict mode elsewhere.
export PYTHONPATH="./:${PYTHONPATH:-}"

# All tool caches share a single project-scratch directory.
readonly CACHE_DIR=/proj/cvl/users/x_fahkh2/caches
export HF_HOME="$CACHE_DIR"
export TORCH_HOME="$CACHE_DIR"
export PIP_CACHE_DIR="$CACHE_DIR"
export TMPDIR="$CACHE_DIR"
export TRITON_CACHE_DIR="$CACHE_DIR"

# CUDA_ROOT is expected to be set by the buildenv module loaded above.
export CUDA_HOME=$CUDA_ROOT
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

export WANDB_DISABLED=true

#export NCCL_P2P_DISABLE=1

#export WANDB_MODE=offline

# Show full stack traces from Hydra on configuration errors.
export HYDRA_FULL_ERROR=1
# Run 1: start stage-B diffusion training with the frozen memory pipeline,
# initialized from the stage-A checkpoint; trains up to 30k steps.
# Overrides are collected in an array so each flag reads as one list entry;
# the resulting argv is identical to the original backslash-continued form.
# (Key names such as "+seperate_load" are spelled as the config expects.)
run1_overrides=(
  +name=train_stage_b_mamba
  algorithm=df_video_mamba3stage
  +customized_load=true
  +seperate_load=false
  experiment.num_nodes=1
  load=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/bimamba_stage_a_128/checkpoints/epoch0_step2000.ckpt
  dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft
  dataset.n_frames=200
  +dataset.n_frames_valid=200
  +dataset.angle_range=110
  +dataset.pos_range=2
  +dataset.wo_updown=false
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.use_explicit_memory_frames=false
  algorithm.training_stage=stage_b_diffusion_frozen_memory
  algorithm.use_mamba_memory_pipeline=true
  algorithm.use_oracle_pose_eval=false
  algorithm.enable_memory_noise_curriculum=false
  +algorithm.require_pose_prediction=false
  +algorithm.use_memory_attention=false
  +algorithm.relative_embedding=false
  +algorithm.memory_retrieval_topk=32
  algorithm.diff_window_size=8
  algorithm.memory_condition_length=0
  algorithm.context_frames=100
  +algorithm.n_tokens=8
  experiment.training.batch_size=8
  experiment.training.checkpointing.every_n_train_steps=2500
  experiment.training.max_steps=30000
  experiment.validation.val_every_n_step=2500
  +output_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/bimamba_stage_b/
)
srun python -m main "${run1_overrides[@]}"
# Run 2: resume stage-B training offline (resume=stage_b_offline) and extend
# max_steps to 60k. Differs from run 1: no checkpoint-load flags (it resumes
# instead) and +dataset.pos_range=8 rather than 2.
run2_overrides=(
  +name=train_stage_b_mamba
  algorithm=df_video_mamba3stage
  experiment.num_nodes=1
  dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft
  dataset.n_frames=200
  +dataset.n_frames_valid=200
  +dataset.angle_range=110
  +dataset.pos_range=8
  +dataset.wo_updown=false
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.use_explicit_memory_frames=false
  algorithm.training_stage=stage_b_diffusion_frozen_memory
  algorithm.use_mamba_memory_pipeline=true
  algorithm.use_oracle_pose_eval=false
  algorithm.enable_memory_noise_curriculum=false
  +algorithm.require_pose_prediction=false
  +algorithm.use_memory_attention=false
  +algorithm.relative_embedding=false
  +algorithm.memory_retrieval_topk=32
  algorithm.diff_window_size=8
  algorithm.memory_condition_length=0
  algorithm.context_frames=100
  +algorithm.n_tokens=8
  experiment.training.batch_size=8
  experiment.training.checkpointing.every_n_train_steps=2500
  experiment.training.max_steps=60000
  experiment.validation.val_every_n_step=2500
  resume=stage_b_offline
  +output_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/bimamba_stage_b/
)
srun python -m main "${run2_overrides[@]}"
# Run 3: final resumed continuation, extending max_steps to 175k. Identical
# overrides to the previous resumed run except for the step budget.
run3_overrides=(
  +name=train_stage_b_mamba
  algorithm=df_video_mamba3stage
  experiment.num_nodes=1
  dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft
  dataset.n_frames=200
  +dataset.n_frames_valid=200
  +dataset.angle_range=110
  +dataset.pos_range=8
  +dataset.wo_updown=false
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.use_explicit_memory_frames=false
  algorithm.training_stage=stage_b_diffusion_frozen_memory
  algorithm.use_mamba_memory_pipeline=true
  algorithm.use_oracle_pose_eval=false
  algorithm.enable_memory_noise_curriculum=false
  +algorithm.require_pose_prediction=false
  +algorithm.use_memory_attention=false
  +algorithm.relative_embedding=false
  +algorithm.memory_retrieval_topk=32
  algorithm.diff_window_size=8
  algorithm.memory_condition_length=0
  algorithm.context_frames=100
  +algorithm.n_tokens=8
  experiment.training.batch_size=8
  experiment.training.checkpointing.every_n_train_steps=2500
  experiment.training.max_steps=175000
  experiment.validation.val_every_n_step=2500
  resume=stage_b_offline
  +output_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/bimamba_stage_b/
)
srun python -m main "${run3_overrides[@]}"