Commit 1a6af5d (0 parents)

upload dataset, checkpoint, and training script

Files changed:
- .gitattributes +3 -0
- checkpoint-40000 +1 -0
- image_v5_0_128 +1 -0
- showo2_1.5b_downstream_mixed_modality_simple.yaml +146 -0
- train_inter_final.json +1 -0
- train_mixed_modality_simple.py +856 -0

.gitattributes
ADDED
@@ -0,0 +1,3 @@
+image_v5_0_128/** filter=lfs diff=lfs merge=lfs -text
+checkpoint-40000/** filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text

checkpoint-40000
ADDED
@@ -0,0 +1 @@
+/home/a6000/bk-project/multimodal-showo/show-o2/1st_show-o2-1.5b-downstream-mixed-modality-432x432/checkpoint-40000

image_v5_0_128
ADDED
@@ -0,0 +1 @@
+../AvaMERG_img_inter/image_v5_0_128

showo2_1.5b_downstream_mixed_modality_simple.yaml
ADDED
@@ -0,0 +1,146 @@
+wandb:
+    entity: null
+    resume: 'auto'
+
+# 410k -- 512 res mixedmodal
+
+experiment:
+    project: "showo2-2b-stage-1"
+    name: "showo2-1.5b-downstream-mixed-modality-432x432"
+    output_dir: "1st_show-o2-1.5b-downstream-mixed-modality-432x432"
+    output_dataloader_state_dir: null
+    max_train_examples_t2i: 60000000  # 10M HQ generation data
+    max_train_examples_mmu: null
+    save_every: 500
+    generate_every: 1000
+    log_every: 1
+    log_grad_norm_every: 500
+    resume_from_checkpoint: 'latest'
+
+model:
+    vae_model:
+        type: "wan21"
+        pretrained_model_path: "Wan_VAE_model/Wan2.1_VAE.pth"  # our local path
+
+    showo:
+        model_name: "Showo2"
+        load_from_showo: True
+        # load_from_showo: False
+        pretrained_model_path: "showlab/show-o2-1.5B"
+        # pretrained_model_path: "/home/a6000/bk-project/multimodal-showo/show-o2/3rd_show-o2-1.5b-downstream-mixed-modality-432x432/checkpoint-2000/unwrapped_model"  # our stage-1 weight path
+        llm_model_path: "Qwen/Qwen2.5-1.5B-Instruct"
+        llm_vocab_size: null  # will be updated when setting the tokenizer in other code files
+        hidden_size: 1536
+        image_latent_dim: 16
+        image_latent_height: 27
+        image_latent_width: 27
+        hq_image_latent_height: 64
+        hq_image_latent_width: 64
+        mixed_modal_latent_height: 27
+        mixed_modal_latent_width: 27
+        patch_size: 2
+        num_diffusion_layers: 10
+        clip_latent_dim: 1152
+        add_qk_norm: True
+        add_time_embeds: True
+        # frozen_params: [ 'image_embedder_und', 'und_trans', 'showo', 'position_embedding']
+        params_not_load: null
+
+    clip:
+        pretrained_model_path: "google/siglip-so400m-patch14-384"
+
+    gradient_checkpointing: True
+
+dataset:
+    samp_probs: null
+    accumulation: 1
+    mixed_loader_mode: "sequential_max_size_cycle"
+    params:
+        train_mixed_modal_shards_path_or_url: "./AvaMERG_img_inter/image_v5_0_128"  # our dataset
+        annotation_path: "./AvaMERG_img_inter/train_inter_final.json"  # our dataset
+        is_clip_encoder: False
+        default_system_prompt: ""
+        add_caption_prompt: True
+        validation_prompts_file: "prompts/t2i_prompts.txt"
+        shuffle_buffer_size: 1000
+        num_workers: 0
+        pin_memory: True
+        persistent_workers: True
+
+    preprocessing:
+        max_seq_length: 1280
+        max_hq_seq_length: 4352
+        max_mixed_modal_seq_length: 4352
+        max_video_seq_length: 4352
+        resolution: 432
+        mixed_modal_resolution: 432
+        video_resolution: 432
+        hq_resolution: 1024
+        num_t2i_image_tokens: 729
+        num_mmu_image_tokens: 729
+        num_hq_image_tokens: 4096
+        num_mixed_modal_tokens: 729
+        num_video_tokens: 3645
+        latent_height: ${model.showo.image_latent_height}
+        latent_width: ${model.showo.image_latent_width}
+        video_latent_height: ${model.showo.image_latent_height}
+        video_latent_width: ${model.showo.image_latent_width}
+        hq_latent_height: ${model.showo.hq_image_latent_height}
+        hq_latent_width: ${model.showo.hq_image_latent_width}
+        mixed_modal_latent_height: ${model.showo.hq_image_latent_height}
+        mixed_modal_latent_width: ${model.showo.hq_image_latent_width}
+        min_res: [ 256, 256 ]
+        random_und_or_gen: 0.0
+        max_num_images: 4
+        max_num_videos: 4  # only for video training, not used in this case
+        num_frames: 2  # only for video training, not used in this case
+
+optimizer:
+    name: adamw
+    params:  # default adamw params
+        learning_rate: 0.0001
+        scale_lr: False  # scale learning rate by total batch size
+        beta1: 0.9
+        beta2: 0.999
+        weight_decay: 0.0
+        epsilon: 1e-8
+
+lr_scheduler:
+    scheduler: "constant_with_warmup"  # "polynomial"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 0
+        # min_lr: 1e-6
+        # power: 0.5  # for polynomial
+
+transport:
+    path_type: "Linear"
+    prediction: "velocity"
+    loss_weight: null
+    train_eps: null
+    sample_eps: null
+    snr_type: "lognorm"
+    sampling_method: "euler"
+    guidance_scale: 5.0
+    num_inference_steps: 50
+    atol: 1e-6
+    rtol: 1e-3
+    reverse: False
+    do_shift: True
+    time_shifting_factor: 3.0
+
+training:
+    gradient_accumulation_steps: 1
+    batch_size: 1
+    batch_size_mixed_modal: 1
+    batch_size_video: 0
+    mixed_precision: "bf16"
+    enable_tf32: True
+    seed: 10000
+    max_train_steps: 50000
+    cond_dropout_prob: 0.1
+    label_smoothing: 0.0
+    max_grad_norm: 1.0
+    ntp_coeff: 0.2
+    flow_coeff: 1.0
+    und_max_t0: 1.0

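For reference, a minimal sketch (not part of the committed files) of how the ${...} interpolations in this config resolve once the YAML is loaded with OmegaConf; the training script below appears to consume the file this way through utils.get_config(), and the file name and the printed values here are simply taken from the config above.

    from omegaconf import OmegaConf

    # Hypothetical stand-alone load of the committed config file.
    config = OmegaConf.load("showo2_1.5b_downstream_mixed_modality_simple.yaml")

    # Interpolations resolve against other keys in the same file:
    # ${model.showo.image_latent_height} -> 27, ${optimizer.params.learning_rate} -> 0.0001
    print(config.dataset.preprocessing.latent_height)   # 27
    print(config.lr_scheduler.params.learning_rate)      # 0.0001

The training script reads these resolved values (for example preproc_config.mixed_modal_latent_height and preproc_config.num_mixed_modal_tokens) when it builds the VISTDataset and the transport.
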
train_inter_final.json
ADDED
@@ -0,0 +1 @@
+../AvaMERG_img_inter/train_inter_final.json

train_mixed_modality_simple.py
ADDED
@@ -0,0 +1,856 @@
+# coding=utf-8
+# Copyright 2025 NUS Show Lab, HuggingFace.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import logging
+import math
+import shutil
+import time
+from pathlib import Path
+from typing import Union
+import numpy as np
+from PIL import Image
+from omegaconf import OmegaConf
+import wandb
+import random
+import torch
+from torch.optim import AdamW
+from einops import rearrange
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedType, set_seed
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from models import Showo2Qwen2_5, omni_attn_mask_naive, omni_attn_mask
+from training.omni_attention import create_block_mask
+from models.lr_schedulers import get_scheduler
+from models.my_logging import set_verbosity_info, set_verbosity_error
+from models.misc import prepare_gen_input, get_text_tokenizer, get_weight_type
+from torch.nn.attention.flex_attention import flex_attention
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+if torch.cuda.is_available():
+    flex_attention = torch.compile(flex_attention)
+
+from datasets import create_imagetext_dataloader, MixedDataLoader, VISTDataset
+from utils import get_config, flatten_omega_conf, AverageMeter, denorm, denorm_vid, get_hyper_params, \
+    path_to_llm_name, _freeze_params
+
+from transport import Sampler, create_transport
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def main():
+    #########################
+    # SETUP Accelerator #
+    #########################
+    config = get_config()
+
+    # Enable TF32 on Ampere GPUs
+    if config.training.enable_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.deterministic = False
+
+    config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs")
+    accelerator = Accelerator(
+        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
+        mixed_precision=config.training.mixed_precision,
+        log_with="wandb",
+        project_dir=config.experiment.logging_dir,
+        split_batches=True,
+    )
+
+    bs_mixed_modal = config.training.batch_size_mixed_modal
+
+    if "concat" in config.dataset.mixed_loader_mode:
+        raise NotImplementedError
+    else:
+        total_batch_size_per_gpu = bs_mixed_modal * config.dataset.accumulation
+    total_batch_size_without_accum = total_batch_size_per_gpu * accelerator.num_processes
+    total_batch_size = total_batch_size_without_accum * config.training.gradient_accumulation_steps
+
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = (
+            total_batch_size_per_gpu
+        )
+    print("[DEBUG] CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
+    print("[DEBUG] torch.cuda.device_count():", torch.cuda.device_count())
+    print("[DEBUG] Accelerator processes:", accelerator.num_processes)
+
+    #####################################
+    # SETUP LOGGING, SEED and CONFIG #
+    #####################################
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        set_verbosity_info()
+    else:
+        set_verbosity_error()
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        resume_wandb_run = config.wandb.resume
+        run_id = config.wandb.get("run_id", None)
+        if run_id is None:
+            resume_wandb_run = False
+            run_id = wandb.util.generate_id()
+            config.wandb.run_id = run_id
+
+        wandb_init_kwargs = dict(
+            name=config.experiment.name,
+            id=run_id,
+            resume=resume_wandb_run,
+            entity=config.wandb.get("entity", None),
+            config_exclude_keys=[],
+        )
+        wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
+        wandb_config.pop("experiment.resume_from_checkpoint")
+
+        accelerator.init_trackers(
+            config.experiment.project,
+            config=wandb_config,
+            init_kwargs={"wandb": wandb_init_kwargs},
+        )
+
+    if accelerator.is_main_process:
+        os.makedirs(config.experiment.output_dir, exist_ok=True)
+        config_path = Path(config.experiment.output_dir) / "config.yaml"
+        logging.info(f"Saving config to {config_path}")
+        OmegaConf.save(config, config_path)
+
+    # If passed along, set the training seed now.
+    if config.training.seed is not None:
+        set_seed(config.training.seed)
+
+    #########################
+    # MODELS and OPTIMIZER #
+    #########################
+    logger.info("Loading models and optimizer")
+
+    weight_type = get_weight_type(config)
+
+    # VAE model for encoding images into continuous latents
+    if config.model.vae_model.type == 'wan21':
+        from models import WanVAE
+        vae_model = WanVAE(vae_pth=config.model.vae_model.pretrained_model_path, dtype=weight_type,
+                           device=accelerator.device)
+    else:
+        raise NotImplementedError
+
+    # Initialize Show-o model
+    text_tokenizer, showo_token_ids = get_text_tokenizer(config.model.showo.llm_model_path, add_showo_tokens=True,
+                                                         return_showo_token_ids=True,
+                                                         llm_name=path_to_llm_name[config.model.showo.llm_model_path])
+    config.model.showo.llm_vocab_size = len(text_tokenizer)
+
+    if config.model.showo.load_from_showo:
+        model = Showo2Qwen2_5.from_pretrained(config.model.showo.pretrained_model_path, use_safetensors=False).to(accelerator.device)
+    else:
+        model = Showo2Qwen2_5(**config.model.showo).to(accelerator.device)
+
+    # Choose layers to freeze
+    _freeze_params(model, config.model.showo.frozen_params)
+
+    preproc_config = config.dataset.preprocessing
+    dataset_config = config.dataset.params
+
+    # for time embedding
+    if config.model.showo.add_time_embeds:
+        # we prepend the time embedding to vision tokens
+        config.dataset.preprocessing.num_mmu_image_tokens += 1
+        config.dataset.preprocessing.num_t2i_image_tokens += 1
+        config.dataset.preprocessing.num_hq_image_tokens += 1
+        config.dataset.preprocessing.num_video_tokens += 1
+        config.dataset.preprocessing.num_mixed_modal_tokens += 1
+
+    ##################################
+    # Optimizer and LR scheduler #
+    #################################
+    optimizer_config = config.optimizer.params
+    optimizer_type = config.optimizer.name
+
+    if optimizer_type == "adamw":
+        optimizer = AdamW(
+            model.parameters(),
+            lr=optimizer_config.learning_rate,
+            betas=(optimizer_config.beta1, optimizer_config.beta2),
+            weight_decay=optimizer_config.weight_decay,
+            eps=optimizer_config.epsilon,
+        )
+    else:
+        raise ValueError(f"Optimizer {optimizer_type} not supported")
+
+    ##################################
+    # DATALOADER #
+    #################################
+    logger.info("Creating dataloaders and lr_scheduler")
+
+    # DataLoaders creation:
+    # This script uses a map-style VISTDataset with a standard PyTorch DataLoader
+    # (plus a DistributedSampler when more than two processes are used).
+    # Dataloader state is not resumed across restarts, so resumption is not sample-exact.
+
+    def create_dataloader(dataset, batch_size, collate_fn):
+        generator = torch.Generator(device='cuda')
+        if accelerator.num_processes > 2:
+            sampler = DistributedSampler(dataset,
+                                         num_replicas=accelerator.num_processes,
+                                         rank=accelerator.process_index,
+                                         shuffle=True,
+                                         drop_last=True,
+                                         # generator=generator
+                                         )
+            shuffle = False
+        else:
+            sampler = None
+            shuffle = True
+
+        dataloader = DataLoader(dataset, batch_size=batch_size,
+                                sampler=sampler, collate_fn=collate_fn,
+                                shuffle=shuffle, num_workers=dataset_config.num_workers,
+                                drop_last=True, generator=generator)
+        return dataloader
+
+    dataset = VISTDataset(
+        dataset_config.train_mixed_modal_shards_path_or_url,
+        anno_path=dataset_config.annotation_path,
+        text_tokenizer=text_tokenizer,
+        image_size=preproc_config.mixed_modal_resolution,
+        max_seq_len=preproc_config.max_mixed_modal_seq_length,
+        num_image_tokens=preproc_config.num_mixed_modal_tokens,
+        latent_width=preproc_config.mixed_modal_latent_width,
+        latent_height=preproc_config.mixed_modal_latent_height,
+        cond_dropout_prob=config.training.cond_dropout_prob,
+        min_res=preproc_config.min_res,
+        showo_token_ids=showo_token_ids,
+        system=("", "", ""),
+        max_num_images=preproc_config.max_num_images,
+    )
+    print("Dataset length:", len(dataset))
+    train_dataloader_mixed_modal = create_dataloader(dataset,
+                                                     config.training.batch_size_mixed_modal,  # 1
+                                                     dataset.collate_fn)
+
+    num_update_steps_per_epoch = len(train_dataloader_mixed_modal)
+    print('[DEBUG] num_update_steps_per_epoch:', num_update_steps_per_epoch)
+    num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)
+
+    ##################################
+    # MODEL RESUME #
+    #################################
+    global_step = 0
+    first_epoch = 0
+
+    if config.experiment.resume_from_checkpoint:
+        dirs = os.listdir(config.experiment.output_dir)
+        # dirs = [d for d in dirs if d.startswith("checkpoint")]
+        dirs = [d for d in dirs if
+                d.startswith("checkpoint-") and d.split("-")[1].isdigit()]  # modified 250804: handle the case where a checkpoint-final directory exists
+        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+        path = dirs[-1] if len(dirs) > 0 else None
+        if path is not None:
+            path = os.path.join(config.experiment.output_dir, path)
+
+            global_step = int(os.path.basename(path).split("-")[1])
+            first_epoch = global_step // num_update_steps_per_epoch
+
+            accelerator.print(f"Resuming from checkpoint {path}/unwrapped_model/pytorch_model.bin")
+            state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu")
+
+            # optionally skip loading some parameters
+            if config.model.showo.params_not_load is not None:
+                params_to_delete = []
+                for k in state_dict:
+                    for n in config.model.showo.params_not_load:
+                        if n in k:
+                            params_to_delete.append(k)
+                for k in params_to_delete:
+                    del state_dict[k]
+
+            model.load_state_dict(state_dict, strict=False if config.model.showo.params_not_load is not None else True)
+            del state_dict
+
+    # Combine these dataloaders into a single iterable
+    mixed_loader = MixedDataLoader(
+        loader_list=[train_dataloader_mixed_modal],
+        samp_probs=config.dataset.samp_probs,
+        accumulation=config.dataset.accumulation,
+        mode=config.dataset.mixed_loader_mode
+    )
+
+    lr_scheduler = get_scheduler(
+        config.lr_scheduler.scheduler,
+        optimizer=optimizer,
+        num_training_steps=config.training.max_train_steps - global_step,
+        num_warmup_steps=config.lr_scheduler.params.warmup_steps,
+        # power=config.lr_scheduler.params.power,
+    )
+
+    ##################################
+    # Prepare accelerator #
+    #################################
+    logger.info("Preparing model, optimizer and dataloaders")
+    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+
+    ##################################
+    # Training #
+    #################################
+    logger.info("***** Running training *****")
+    logger.info(f" Num training steps = {config.training.max_train_steps}")
+    logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}")
+    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
+
+    # default: 1000 steps, linear noise schedule
+    transport = create_transport(
+        path_type=config.transport.path_type,
+        prediction=config.transport.prediction,
+        loss_weight=config.transport.loss_weight,
+        train_eps=config.transport.train_eps,
+        sample_eps=config.transport.sample_eps,
+        snr_type=config.transport.snr_type,
+        do_shift=config.transport.do_shift,
+        seq_len=preproc_config.num_t2i_image_tokens,
+    )  # default: velocity
+
+    sampler = Sampler(transport)
+
+    @torch.no_grad()
+    def prepare_latents_and_labels(
+            pixel_values: Union[torch.FloatTensor, torch.LongTensor],
+            data_type,
+            shape,
+            image_masks,
+            modality_positions
+    ):
+
+        if config.model.vae_model.type == 'wan21':
+            if len(pixel_values.shape) == 4:
+                pixel_values = pixel_values.unsqueeze(2)
+            image_latents = vae_model.sample(pixel_values)
+            recons_images = vae_model.batch_decode(image_latents)
+            if pixel_values.shape[2] == 1:
+                image_latents = image_latents.squeeze(2)
+                recons_images = recons_images.squeeze(2)
+        else:
+            raise NotImplementedError
+
+        c, h, w = image_latents.shape[1:]
+        # timesteps, noise, original image
+        # each loop iteration takes around 0.002 s, which is affordable
+        t_list, xt_list, ut_list, masks = [], [], [], []
+        for i, tp in enumerate(data_type):
+            # x0 -> noise, x1 -> image
+            t, x0, x1 = transport.sample(image_latents[i][None],
+                                         config.training.und_max_t0 if tp in ['mmu', 'mmu_vid'] else None)
+            # timesteps, noised image, velocity
+            t, xt, ut = transport.path_sampler.plan(t, x0, x1)
+            t_list.append(t)
+            xt_list.append(xt)
+            ut_list.append(ut)
+            if data_type[0] != 'interleaved_data':
+                if tp in ['mmu', 'mmu_vid'] and config.training.und_max_t0 == 1.0:
+                    masks.append(image_masks[i][None] * 0.0)
+                else:
+                    masks.append(image_masks[i][None])
+
+        t = torch.stack(t_list, dim=0).squeeze(-1)
+        xt = torch.cat(xt_list, dim=0)
+        ut = torch.cat(ut_list, dim=0)
+
+        if len(masks) != 0:
+            masks = torch.cat(masks, dim=0)
+        else:
+            masks = image_masks
+
+        if data_type[0] == 'interleaved_data':
+            b, n = shape
+            image_latents = image_latents.reshape(b, n, c, h, w)
+            ut = ut.reshape(b, n, c, h, w)
+            xt = xt.reshape(b, n, c, h, w)
+            t = t.reshape(b, n)
+
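+            # Prefix conditioning for interleaved samples: with probability 0.7, the first
+            # `idx` images in the sequence are kept clean (xt is replaced by the clean latents,
+            # t is set to 1.0, and their entries in the image mask are zeroed), so earlier
+            # images act as fixed context while the later ones remain flow targets.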
+            for i in range(b):
+                if random.random() < 0.7:
+                    non_zero_max_idx = max([_ for _, pos in enumerate(modality_positions[i]) if pos[1] != 0])
+                    idx = random.randint(1, non_zero_max_idx) if non_zero_max_idx != 0 else 0
+                    xt[i, :idx] = image_latents[i][None][:, :idx].clone()
+                    # ut[i, :idx] = torch.zeros_like(image_latents[i][None][:, :idx])
+                    t[i, :idx] = t[i, :idx] * 0.0 + 1.0
+
+                    for j in range(idx):
+                        img_sid, length = modality_positions[i, j]
+                        masks[i, img_sid: img_sid + length] = 0
+
+            ut = ut.reshape(b * n, c, h, w)
+            xt = xt.reshape(b * n, c, h, w)
+            t = t.reshape(b * n)
+
+        return xt, t, ut, recons_images, masks
+
+    batch_time_m = AverageMeter()
+    data_time_m = AverageMeter()
+    end = time.time()
+
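+    # Training loop: each mixed-modal batch carries tokenized interleaved text plus its images;
+    # prepare_latents_and_labels encodes the images with the VAE, samples flow-matching timesteps
+    # and noise, and returns the noised latents (xt), timesteps (t) and velocity targets (ut)
+    # that are fed to the model together with the text tokens.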
+    for epoch in range(first_epoch, num_train_epochs):
+        model.train()
+        for batch in mixed_loader:
+
+            text_tokens = batch['text_tokens'].to(accelerator.device)
+            text_labels = batch['text_labels'].to(accelerator.device)
+            pixel_values = batch['images'].to(accelerator.device).to(weight_type)
+            if batch['data_type'][0] == 'interleaved_data':
+                b, n = pixel_values.shape[:2]
+                pixel_values = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+                batch['data_type'] = batch['data_type'] * n
+            else:
+                b, n = 0, 0
+
+            text_masks = batch['text_masks'].to(accelerator.device)
+            image_masks = batch['image_masks'].to(accelerator.device)
+            modality_positions = batch['modality_positions'].to(accelerator.device)
+            # prepare image latents and labels
+            image_latents, t, image_labels, recons_images, image_masks = prepare_latents_and_labels(pixel_values,
+                                                                                                    batch['data_type'],
+                                                                                                    (b, n),
+                                                                                                    image_masks,
+                                                                                                    modality_positions)
+            # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
+            # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
+            omni_mask_fn = omni_attn_mask(modality_positions)  # all of the attention-mask information is prepared here
+            # block_mask = create_block_mask(omni_mask_fn, B=text_tokens.shape[0], H=None,
+            #                                Q_LEN=preproc_config.max_mixed_modal_seq_length,
+            #                                KV_LEN=preproc_config.max_mixed_modal_seq_length, device=accelerator.device)
+            # or use the naive omni attention mask, which is more stable
+            block_mask = omni_attn_mask_naive(text_tokens.size(0),
+                                              text_tokens.size(1),
+                                              modality_positions,
+                                              accelerator.device).to(weight_type)
+
+            logits, loss_ntp, loss_flow = model(text_tokens=text_tokens,
+                                                image_latents=image_latents,
+                                                t=t.to(weight_type),
+                                                attention_mask=block_mask,
+                                                text_masks=text_masks,
+                                                image_masks=image_masks,
+                                                text_labels=text_labels,
+                                                image_labels=image_labels,
+                                                modality_positions=modality_positions,
+                                                output_hidden_states=True,
+                                                max_seq_len=text_tokens.size(1),
+                                                device=accelerator.device,
+                                                )
+
+            # Gather the losses across all processes for logging (if we use distributed training).
+            avg_loss_ntp = accelerator.gather(loss_ntp.repeat(total_batch_size_per_gpu)).mean()
+            avg_loss_flow = accelerator.gather(loss_flow.repeat(total_batch_size_per_gpu)).mean()
+            loss = config.training.ntp_coeff * loss_ntp + config.training.flow_coeff * loss_flow
+
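+            # Joint objective: weighted sum of the next-token-prediction loss on text and the
+            # flow-matching loss on image latents (ntp_coeff=0.2, flow_coeff=1.0 in the config above).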
+            accelerator.backward(loss.to(weight_type) / config.training.gradient_accumulation_steps)
+
+            if config.training.max_grad_norm is not None and accelerator.sync_gradients:
+                accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
+
+            if (global_step + 1) % config.training.gradient_accumulation_steps == 0:
+                optimizer.step()
+                lr_scheduler.step()
+
+            # log gradient norm before zeroing it
+            if (
+                    accelerator.sync_gradients
+                    and (global_step + 1) % config.experiment.log_grad_norm_every == 0
+                    and accelerator.is_main_process
+            ):
+                log_grad_norm(model, accelerator, global_step + 1)
+
+            if (global_step + 1) % config.training.gradient_accumulation_steps == 0:
+                optimizer.zero_grad(set_to_none=True)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+
+                batch_time_m.update(time.time() - end)
+                end = time.time()
+
+                # Log metrics
+                if (global_step + 1) % config.experiment.log_every == 0:
+                    samples_per_second_per_gpu = (
+                        config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val
+                    )
+                    lr = [group["lr"] for group in optimizer.param_groups]
+                    if len(lr) == 3:
+                        logs = {
+                            "step_loss_ntp": avg_loss_ntp.item(),
+                            "step_loss_flow": avg_loss_flow.item(),
+                            "lr_ve": lr[0],
+                            "lr_proj": lr[1],
+                            "lr_showo": lr[2],
+                            "samples/sec/gpu": samples_per_second_per_gpu,
+                            "data_time": data_time_m.val,
+                            "batch_time": batch_time_m.val,
+                        }
+                        accelerator.log(logs, step=global_step + 1)
+                        logger.info(
+                            f"Epoch: {epoch} "
+                            f"Step: {global_step + 1} "
+                            f"Loss_NTP: {avg_loss_ntp.item():0.4f} "
+                            f"Loss_FLOW: {avg_loss_flow.item():0.4f} "
+                            f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
+                            f"Batch (t): {batch_time_m.val:0.4f} "
+                            f"LR_ve: {lr[0]:0.6f} "
+                            f"LR_proj: {lr[1]:0.6f} "
+                            f"LR_showo: {lr[2]:0.6f}"
+                        )
+                    else:
+                        logs = {
+                            "step_loss_ntp": avg_loss_ntp.item(),
+                            "step_loss_flow": avg_loss_flow.item(),
+                            "lr": lr[0],
+                            "samples/sec/gpu": samples_per_second_per_gpu,
+                            "data_time": data_time_m.val,
+                            "batch_time": batch_time_m.val,
+                        }
+                        accelerator.log(logs, step=global_step + 1)
+                        logger.info(
+                            f"Epoch: {epoch} "
+                            f"Step: {global_step + 1} "
+                            f"Loss_NTP: {avg_loss_ntp.item():0.4f} "
+                            f"Loss_FLOW: {avg_loss_flow.item():0.4f} "
+                            f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
+                            f"Batch (t): {batch_time_m.val:0.4f} "
+                            f"LR: {lr[0]:0.6f}"
+                        )
+                    # resetting batch / data time meters per log window
+                    batch_time_m.reset()
+                    data_time_m.reset()
+
+                # Save model checkpoint
+                if (global_step + 1) % config.experiment.save_every == 0:
+                    save_checkpoint(model, config, accelerator, global_step + 1)
+
+                global_step += 1
+
+            # Stop training if max steps is reached
+            if global_step >= config.training.max_train_steps:
+                break
+        # End for
+
+    accelerator.wait_for_everyone()
+
+    # Evaluate and save checkpoint at the end of training
+    save_checkpoint(model, config, accelerator, "final")
+
+    # Save the final trained checkpoint
+    if accelerator.is_main_process:
+        model = accelerator.unwrap_model(model)
+        model.save_pretrained(config.experiment.output_dir, safe_serialization=False)
+
+    accelerator.end_training()
+
+
+@torch.no_grad()
+def generate_images(
+        model,
+        vae_model,
+        text_tokenizer,
+        config,
+        global_step,
+        device,
+        weight_type,
+        sampler,
+        showo_token_ids,
+):
+    logger.info("Generating images...")
+    model.eval()
+
+    # read validation prompts from file
+    with open(config.dataset.params.validation_prompts_file, "r") as f:
+        prompts = f.read().splitlines()[:config.training.batch_size_t2i]
+
+    num_t2i_image_tokens, num_mmu_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \
+        latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \
+        = get_hyper_params(config, text_tokenizer, showo_token_ids)
+
+    batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \
+        prepare_gen_input(
+            prompts, text_tokenizer, num_t2i_image_tokens, bos_id, eos_id, boi_id, eoi_id, pad_id, image_pad_id,
+            max_text_len, device
+        )
+
+    z = torch.randn((len(prompts),
+                     image_latent_dim, latent_height * patch_size,
+                     latent_width * patch_size)).to(weight_type).to(device)
+
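+    # Classifier-free guidance: when guidance_scale > 0 (transport.guidance_scale is 5.0 in the
+    # config above), conditional and null-prompt inputs are concatenated into a single batch and
+    # only the conditional half of the generated samples is kept afterwards.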
+    if guidance_scale > 0:
+        z = torch.cat([z, z], dim=0)
+        text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0)
+        modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0)
+        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
+        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
+        # omni_mask_fn = omni_attn_mask(modality_positions)
+        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
+        #                                KV_LEN=max_seq_len, device=device)
+        # or use the naive omni attention mask, which is more stable
+        block_mask = omni_attn_mask_naive(text_tokens.size(0),
+                                          max_seq_len,
+                                          modality_positions,
+                                          device).to(weight_type)
+    else:
+        text_tokens = batch_text_tokens
+        modality_positions = batch_modality_positions
+        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
+        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
+        # omni_mask_fn = omni_attn_mask(modality_positions)
+        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
+        #                                KV_LEN=max_seq_len, device=device)
+        block_mask = omni_attn_mask_naive(text_tokens.size(0),
+                                          max_seq_len,
+                                          modality_positions,
+                                          device).to(weight_type)
+
+    model_kwargs = dict(
+        text_tokens=torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0),
+        attention_mask=block_mask,
+        modality_positions=torch.cat([batch_modality_positions,
+                                      batch_modality_positions_null], dim=0),
+        output_hidden_states=True,
+        max_seq_len=max_seq_len,
+        guidance_scale=guidance_scale
+    )
+
+    sample_fn = sampler.sample_ode(
+        sampling_method=config.transport.sampling_method,
+        num_steps=config.transport.num_inference_steps,
+        atol=config.transport.atol,
+        rtol=config.transport.rtol,
+        reverse=config.transport.reverse,
+        time_shifting_factor=config.transport.time_shifting_factor
+    )
+    samples = sample_fn(z, model.t2i_generate, **model_kwargs)[-1]
+    samples = torch.chunk(samples, 2)[0]
+
+    if config.model.vae_model.type == 'wan21':
+        samples = samples.unsqueeze(2)
+        images = vae_model.batch_decode(samples)
+        images = images.squeeze(2)
+    else:
+        raise NotImplementedError
+
+    model.train()
+
+    # Convert to PIL images
+    images = denorm(images)
+    pil_images = [Image.fromarray(image) for image in images]
+
+    # Log images
+    wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
+    wandb.log({"Generated images": wandb_images}, step=global_step)
+
+
+@torch.no_grad()
+def visualize_reconstruction(
+        pixel_values,
+        recons_images,
+        captions,
+        global_step
+):
+    logger.info("Visualizing images...")
+
+    # Convert to PIL images
+    images = denorm(pixel_values)
+    recons_images = denorm(recons_images)
+    visualized_images = np.concatenate((images, recons_images), 2)
+    pil_images = [Image.fromarray(image) for image in visualized_images]
+
+    # Log images
+    wandb_images = [wandb.Image(image, caption=captions[i]) for i, image in enumerate(pil_images)]
+    wandb.log({"Original images vs. Reconstructed": wandb_images}, step=global_step)
+
+
+@torch.no_grad()
+def generate_videos(
+        model,
+        vae_model,
+        text_tokenizer,
+        config,
+        global_step,
+        device,
+        weight_type,
+        sampler,
+        showo_token_ids
+):
+    logger.info("Generating videos...")
+    model.eval()
+
+    # read validation prompts from file
+    with open(config.dataset.params.validation_prompts_file, "r") as f:
+        prompts = f.read().splitlines()[:config.training.batch_size_t2i]
+
+    num_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \
+        latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \
+        = get_hyper_params(config, text_tokenizer, showo_token_ids, is_video=True)
+
+    batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \
+        prepare_gen_input(
+            prompts, text_tokenizer, num_video_tokens, bos_id, eos_id, bov_id, eov_id, pad_id, video_pad_id,
+            max_text_len, device
+        )
+
+    T = 5
+    z = torch.randn((len(prompts), image_latent_dim, T, latent_height * patch_size, latent_width * patch_size)).to(
+        device).to(weight_type)
+
+    if guidance_scale > 0:
+        z = torch.cat([z, z], dim=0)
+        text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0)
+        modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0)
+        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
+        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
+        # omni_mask_fn = omni_attn_mask(modality_positions)
+        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
+        #                                KV_LEN=max_seq_len, device=device)
+        # or use the naive omni attention mask, which is more stable
+        block_mask = omni_attn_mask_naive(text_tokens.size(0),
+                                          max_seq_len,
+                                          modality_positions,
+                                          device).to(weight_type)
+    else:
+        text_tokens = batch_text_tokens
+        modality_positions = batch_modality_positions
+        # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
+        # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
+        # omni_mask_fn = omni_attn_mask(modality_positions)
+        # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
+        #                                KV_LEN=max_seq_len, device=device)
+        block_mask = omni_attn_mask_naive(text_tokens.size(0),
+                                          max_seq_len,
+                                          modality_positions,
+                                          device).to(weight_type)
+
+    model_kwargs = dict(
+        text_tokens=text_tokens,
+        attention_mask=block_mask,
+        modality_positions=modality_positions,
+        output_hidden_states=True,
+        max_seq_len=max_seq_len,
+        guidance_scale=guidance_scale
+    )
+
+    sample_fn = sampler.sample_ode(
+        sampling_method=config.transport.sampling_method,
+        num_steps=config.transport.num_inference_steps,
+        atol=config.transport.atol,
+        rtol=config.transport.rtol,
+        reverse=config.transport.reverse,
+        time_shifting_factor=config.transport.time_shifting_factor
+    )
+    samples = sample_fn(z, model.t2i_generate, **model_kwargs)[-1]
+    samples = torch.chunk(samples, 2)[0]
+
+    if config.model.vae_model.type == 'wan21':
+        images = vae_model.batch_decode(samples)
+    else:
+        raise NotImplementedError
+
+    model.train()
+
+    # Convert to video arrays for logging
+    images = denorm_vid(images)
+
+    # Log videos
+    wandb_images = [wandb.Video(image, caption=prompts[i], fps=8, format="mp4") for i, image in enumerate(images)]
+    wandb.log({"Generated videos": wandb_images}, step=global_step)
+
+
+@torch.no_grad()
+def visualize_reconstruction_video(
+        pixel_values,
+        recons_images,
+        captions,
+        global_step
+):
+    logger.info("Visualizing videos...")
+
+    # Convert to video arrays for logging
+    images = denorm_vid(pixel_values)
+    recons_images = denorm_vid(recons_images)
+    visualized_images = np.concatenate((images, recons_images), 4)
+
+    # Log videos
+    wandb_images = [wandb.Video(image, caption=captions[i], fps=8, format="mp4") for i, image in
+                    enumerate(visualized_images)]
+    wandb.log({"Original videos vs. Reconstructed": wandb_images}, step=global_step)
+
+
+def save_checkpoint(model, config, accelerator, global_step):
+    output_dir = config.experiment.output_dir
+    checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None)
+
+    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+    if accelerator.is_main_process and checkpoints_total_limit is not None:
+        checkpoints = os.listdir(output_dir)
+        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+        if len(checkpoints) >= checkpoints_total_limit:
+            num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
+            removing_checkpoints = checkpoints[0:num_to_remove]
+
+            logger.info(
+                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+            )
+            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+            for removing_checkpoint in removing_checkpoints:
+                removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
+                shutil.rmtree(removing_checkpoint)
+
+    save_path = Path(output_dir) / f"checkpoint-{global_step}"
+
+    # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet)
+    # XXX: could also make this conditional on deepspeed
+    state_dict = accelerator.get_state_dict(model)
+    if accelerator.is_main_process:
+        unwrapped_model = accelerator.unwrap_model(model)
+        unwrapped_model.save_pretrained(
+            save_path / "unwrapped_model",
+            save_function=accelerator.save,
+            state_dict=state_dict,
+            safe_serialization=False
+        )
+        json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
+        logger.info(f"Saved state to {save_path}")
+
+
+def log_grad_norm(model, accelerator, global_step):
+    for name, param in model.named_parameters():
+        if param.grad is not None:
+            grads = param.grad.detach().data
+            grad_norm = (grads.norm(p=2) / grads.numel()).item()
+            accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
+
+
+if __name__ == "__main__":
+    main()