File size: 2,898 Bytes
681f346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env bash
#
# Three-stage WorldMem training launcher. Each stage below is a separate
# `python -m main` invocation; the stages run back to back.

# Strict mode goes first so that a failure in any setup command below
# (including the wandb CLI call) aborts the script instead of being ignored.
set -euo pipefail  # exit on error, on unset variables, and on pipe failures

# Turn on wandb syncing for this environment.
wandb enabled

# Train on GPUs 0-3.
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Disable NCCL's peer-to-peer transport between GPUs.
# NOTE(review): presumably a workaround for P2P hangs on this machine — confirm.
export NCCL_P2P_DISABLE=1
# Uncomment to get full Hydra stack traces on configuration errors.
# export HYDRA_FULL_ERROR=1

# ---------------------------------------------------------------------------
# Stage 1: fresh run (no `resume=`) initialised from the diffusion-only and
# VAE-only checkpoints listed below; trains to 120k steps.
# Overridable paths: WORLDMEM_DATA_DIR (dataset), WORLDMEM_OUTPUT_DIR (outputs).
# Overrides are collected in an array so there is no line-continuation `\`
# to get wrong, and each override can carry its own comment.
# ---------------------------------------------------------------------------
stage1_args=(
  +name=train
  +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt
  +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt
  +customized_load=true
  +seperate_load=true  # key is spelled "seperate"; keep in sync with main's config
  +zero_init_gate=true
  dataset.n_frames=8
  dataset.save_dir="${WORLDMEM_DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft}"
  +dataset.n_frames_valid=700
  +dataset.angle_range=110
  +dataset.pos_range=2
  +dataset.memory_condition_length=8
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.wo_updown=true
  +algorithm.n_tokens=8
  +algorithm.memory_condition_length=8
  algorithm.context_frames=600
  +algorithm.relative_embedding=true
  +algorithm.log_video=true
  +algorithm.add_timestamp_embedding=true
  # Quoted: left bare, `[lpips,psnr]` is a shell glob bracket expression.
  '+algorithm.metrics=[lpips,psnr]'
  experiment.training.checkpointing.every_n_train_steps=2500
  experiment.training.max_steps=120000
  +output_dir="${WORLDMEM_OUTPUT_DIR:-/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set}"
)
python -m main "${stage1_args[@]}"

# ---------------------------------------------------------------------------
# Stage 2: resume an existing run and continue training to 240k steps with
# +dataset.pos_range=8.
# Overridable: WORLDMEM_DATA_DIR, WORLDMEM_OUTPUT_DIR, WORLDMEM_RESUME_ID.
# NOTE(review): the resume id defaults to the hard-coded `ot7jqmgn`. Stage 1
# starts a fresh run, so on an end-to-end run this does not point at the run
# Stage 1 just produced; export WORLDMEM_RESUME_ID with that run's id
# (presumably the wandb run id — confirm against main's config).
# ---------------------------------------------------------------------------
stage2_args=(
  +name=train
  dataset.n_frames=8
  # Same absolute dataset directory as Stage 1, so the result does not depend
  # on the directory the script is launched from.
  dataset.save_dir="${WORLDMEM_DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft}"
  +dataset.n_frames_valid=700
  +dataset.angle_range=110
  +dataset.pos_range=8
  +dataset.memory_condition_length=8
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.wo_updown=true
  +algorithm.n_tokens=8
  +algorithm.memory_condition_length=8
  algorithm.context_frames=600
  +algorithm.relative_embedding=true
  +algorithm.log_video=true
  +algorithm.add_timestamp_embedding=true
  # Quoted: left bare, `[lpips,psnr]` is a shell glob bracket expression.
  '+algorithm.metrics=[lpips,psnr]'
  experiment.training.checkpointing.every_n_train_steps=2500
  resume="${WORLDMEM_RESUME_ID:-ot7jqmgn}"
  +output_dir="${WORLDMEM_OUTPUT_DIR:-/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set}"
  experiment.training.max_steps=240000
)
python -m main "${stage2_args[@]}"

# ---------------------------------------------------------------------------
# Stage 3: resume the same run id as Stage 2 and continue training to 700k
# steps with +dataset.wo_updown=false.
# Overridable: WORLDMEM_DATA_DIR, WORLDMEM_OUTPUT_DIR, WORLDMEM_RESUME_ID.
# NOTE(review): the resume id defaults to the hard-coded `ot7jqmgn`. Stage 1
# starts a fresh run, so on an end-to-end run this does not point at the run
# Stage 1 just produced; export WORLDMEM_RESUME_ID with that run's id
# (presumably the wandb run id — confirm against main's config).
# ---------------------------------------------------------------------------
stage3_args=(
  +name=train
  dataset.n_frames=8
  # Same absolute dataset directory as Stage 1, so the result does not depend
  # on the directory the script is launched from.
  dataset.save_dir="${WORLDMEM_DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft}"
  +dataset.n_frames_valid=700
  +dataset.angle_range=110
  +dataset.pos_range=8
  +dataset.memory_condition_length=8
  +dataset.customized_validation=true
  +dataset.add_timestamp_embedding=true
  +dataset.wo_updown=false
  +algorithm.n_tokens=8
  +algorithm.memory_condition_length=8
  algorithm.context_frames=600
  +algorithm.relative_embedding=true
  +algorithm.log_video=true
  +algorithm.add_timestamp_embedding=true
  # Quoted: left bare, `[lpips,psnr]` is a shell glob bracket expression.
  '+algorithm.metrics=[lpips,psnr]'
  experiment.training.checkpointing.every_n_train_steps=2500
  resume="${WORLDMEM_RESUME_ID:-ot7jqmgn}"
  +output_dir="${WORLDMEM_OUTPUT_DIR:-/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set}"
  experiment.training.max_steps=700000
)
python -m main "${stage3_args[@]}"