SIChoi committed · Commit 1a6af5d · 0 parent(s)

upload dataset, checkpoint, and training script

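For reference, a minimal sketch of how an upload like this can be performed with the huggingface_hub client; the repo id and folder path below are placeholders, not values taken from this commit.

# Hypothetical upload sketch (repo_id and folder_path are placeholders, not from this commit).
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    repo_id="<user>/<repo>",        # placeholder target repository
    folder_path=".",                # local folder holding the files listed below
    repo_type="model",
    commit_message="upload dataset, checkpoint, and training script",
)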
.gitattributes ADDED
@@ -0,0 +1,3 @@
1
+ image_v5_0_128/** filter=lfs diff=lfs merge=lfs -text
2
+ checkpoint-40000/** filter=lfs diff=lfs merge=lfs -text
3
+ *.json filter=lfs diff=lfs merge=lfs -text
checkpoint-40000 ADDED
@@ -0,0 +1 @@
1
+ /home/a6000/bk-project/multimodal-showo/show-o2/1st_show-o2-1.5b-downstream-mixed-modality-432x432/checkpoint-40000
image_v5_0_128 ADDED
@@ -0,0 +1 @@
1
+ ../AvaMERG_img_inter/image_v5_0_128
showo2_1.5b_downstream_mixed_modality_simple.yaml ADDED
@@ -0,0 +1,146 @@
1
+ wandb:
2
+ entity: null
3
+ resume: 'auto'
4
+
5
+ # 410k -- 512 res mixedmodal
6
+
7
+ experiment:
8
+ project: "showo2-2b-stage-1"
9
+ name: "showo2-1.5b-downstream-mixed-modality-432x432"
10
+ output_dir: "1st_show-o2-1.5b-downstream-mixed-modality-432x432"
11
+ output_dataloader_state_dir: null
12
+ max_train_examples_t2i: 60000000 # 10M HQ generation data
13
+ max_train_examples_mmu: null
14
+ save_every: 500
15
+ generate_every: 1000
16
+ log_every: 1
17
+ log_grad_norm_every: 500
18
+ resume_from_checkpoint: 'latest'
19
+
20
+ model:
21
+ vae_model:
22
+ type: "wan21"
23
+ pretrained_model_path: "Wan_VAE_model/Wan2.1_VAE.pth" # our local path
24
+
25
+ showo:
26
+ model_name: "Showo2"
27
+ load_from_showo: True
28
+ # load_from_showo: False
29
+ pretrained_model_path: "showlab/show-o2-1.5B"
30
+ # pretrained_model_path: "/home/a6000/bk-project/multimodal-showo/show-o2/3rd_show-o2-1.5b-downstream-mixed-modality-432x432/checkpoint-2000/unwrapped_model" # our stage-1 weight path
31
+ llm_model_path: "Qwen/Qwen2.5-1.5B-Instruct"
32
+ llm_vocab_size: null # will be updated when setting the tokenizer in other code files
33
+ hidden_size: 1536
34
+ image_latent_dim: 16
35
+ image_latent_height: 27
36
+ image_latent_width: 27
37
+ hq_image_latent_height: 64
38
+ hq_image_latent_width: 64
39
+ mixed_modal_latent_height: 27
40
+ mixed_modal_latent_width: 27
41
+ patch_size: 2
42
+ num_diffusion_layers: 10
43
+ clip_latent_dim: 1152
44
+ add_qk_norm: True
45
+ add_time_embeds: True
46
+ # frozen_params: [ 'image_embedder_und', 'und_trans', 'showo', 'position_embedding']
47
+ params_not_load: null
48
+
49
+ clip:
50
+ pretrained_model_path: "google/siglip-so400m-patch14-384"
51
+
52
+ gradient_checkpointing: True
53
+
54
+ dataset:
55
+ samp_probs: null
56
+ accumulation: 1
57
+ mixed_loader_mode: "sequential_max_size_cycle"
58
+ params:
59
+ train_mixed_modal_shards_path_or_url: "./AvaMERG_img_inter/image_v5_0_128" # our dataset
60
+ annotation_path: "./AvaMERG_img_inter/train_inter_final.json" # our dataset
61
+ is_clip_encoder: False
62
+ default_system_prompt: ""
63
+ add_caption_prompt: True
64
+ validation_prompts_file: "prompts/t2i_prompts.txt"
65
+ shuffle_buffer_size: 1000
66
+ num_workers: 0
67
+ pin_memory: True
68
+ persistent_workers: True
69
+
70
+ preprocessing:
71
+ max_seq_length: 1280
72
+ max_hq_seq_length: 4352
73
+ max_mixed_modal_seq_length: 4352
74
+ max_video_seq_length: 4352
75
+ resolution: 432
76
+ mixed_modal_resolution: 432
77
+ video_resolution: 432
78
+ hq_resolution: 1024
79
+ num_t2i_image_tokens: 729
80
+ num_mmu_image_tokens: 729
81
+ num_hq_image_tokens: 4096
82
+ num_mixed_modal_tokens: 729
83
+ num_video_tokens: 3645
84
+ latent_height: ${model.showo.image_latent_height}
85
+ latent_width: ${model.showo.image_latent_width}
86
+ video_latent_height: ${model.showo.image_latent_height}
87
+ video_latent_width: ${model.showo.image_latent_width}
88
+ hq_latent_height: ${model.showo.hq_image_latent_height}
89
+ hq_latent_width: ${model.showo.hq_image_latent_width}
90
+ mixed_modal_latent_height: ${model.showo.hq_image_latent_height}
91
+ mixed_modal_latent_width: ${model.showo.hq_image_latent_width}
92
+ min_res: [ 256, 256 ]
93
+ random_und_or_gen: 0.0
94
+ max_num_images: 4
95
+ max_num_videos: 4 # only for video training, not used in this case
96
+ num_frames: 2 # only for video training, not used in this case
97
+
98
+ optimizer:
99
+ name: adamw
100
+ params: # default adamw params
101
+ learning_rate: 0.0001
102
+ scale_lr: False # scale learning rate by total batch size
103
+ beta1: 0.9
104
+ beta2: 0.999
105
+ weight_decay: 0.0
106
+ epsilon: 1e-8
107
+
108
+ lr_scheduler:
109
+ scheduler: "constant_with_warmup" # "polynomial"
110
+ params:
111
+ learning_rate: ${optimizer.params.learning_rate}
112
+ warmup_steps: 0
113
+ # min_lr: 1e-6
114
+ # power: 0.5 # for polynomial
115
+
116
+ transport:
117
+ path_type: "Linear"
118
+ prediction: "velocity"
119
+ loss_weight: null
120
+ train_eps: null
121
+ sample_eps: null
122
+ snr_type: "lognorm"
123
+ sampling_method: "euler"
124
+ guidance_scale: 5.0
125
+ num_inference_steps: 50
126
+ atol: 1e-6
127
+ rtol: 1e-3
128
+ reverse: False
129
+ do_shift: True
130
+ time_shifting_factor: 3.0
131
+
132
+ training:
133
+ gradient_accumulation_steps: 1
134
+ batch_size: 1
135
+ batch_size_mixed_modal: 1
136
+ batch_size_video: 0
137
+ mixed_precision: "bf16"
138
+ enable_tf32: True
139
+ seed: 10000
140
+ max_train_steps: 50000
141
+ cond_dropout_prob: 0.1
142
+ label_smoothing: 0.0
143
+ max_grad_norm: 1.0
144
+ ntp_coeff: 0.2
145
+ flow_coeff: 1.0
146
+ und_max_t0: 1.0
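A minimal sketch of how the config above can be loaded and its ${...} interpolations resolved with OmegaConf (the library the training script already imports); the file name matches the YAML above, the rest is illustrative.

# Sketch: load the config and resolve interpolations such as
# dataset.preprocessing.latent_height: ${model.showo.image_latent_height}.
from omegaconf import OmegaConf

cfg = OmegaConf.load("showo2_1.5b_downstream_mixed_modality_simple.yaml")
print(cfg.dataset.preprocessing.latent_height)        # -> 27, taken from model.showo.image_latent_height
resolved = OmegaConf.to_container(cfg, resolve=True)  # plain dict with all ${...} references expanded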
train_inter_final.json ADDED
@@ -0,0 +1 @@
1
+ ../AvaMERG_img_inter/train_inter_final.json
train_mixed_modality_simple.py ADDED
@@ -0,0 +1,856 @@
1
+ # coding=utf-8
2
+ # Copyright 2025 NUS Show Lab, HuggingFace.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import json
18
+ import logging
19
+ import math
20
+ import shutil
21
+ import time
22
+ from pathlib import Path
23
+ from typing import Union
24
+ import numpy as np
25
+ from PIL import Image
26
+ from omegaconf import OmegaConf
27
+ import wandb
28
+ import random
29
+ import torch
30
+ from torch.optim import AdamW
31
+ from einops import rearrange
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import DistributedType, set_seed
35
+ from torch.utils.data import DataLoader
36
+ from torch.utils.data.distributed import DistributedSampler
37
+ from models import Showo2Qwen2_5, omni_attn_mask_naive, omni_attn_mask
38
+ from training.omni_attention import create_block_mask
39
+ from models.lr_schedulers import get_scheduler
40
+ from models.my_logging import set_verbosity_info, set_verbosity_error
41
+ from models.misc import prepare_gen_input, get_text_tokenizer, get_weight_type
42
+ from torch.nn.attention.flex_attention import flex_attention
43
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
44
+
45
+ if torch.cuda.is_available():
46
+ flex_attention = torch.compile(flex_attention)
47
+
48
+ from datasets import create_imagetext_dataloader, MixedDataLoader, VISTDataset
49
+ from utils import get_config, flatten_omega_conf, AverageMeter, denorm, denorm_vid, get_hyper_params, \
50
+ path_to_llm_name, _freeze_params
51
+
52
+ from transport import Sampler, create_transport
53
+
54
+ logger = get_logger(__name__, log_level="INFO")
55
+
56
+
57
+ def main():
58
+ #########################
59
+ # SETUP Accelerator #
60
+ #########################
61
+ config = get_config()
62
+
63
+ # Enable TF32 on Ampere GPUs
64
+ if config.training.enable_tf32:
65
+ torch.backends.cuda.matmul.allow_tf32 = True
66
+ torch.backends.cudnn.benchmark = True
67
+ torch.backends.cudnn.deterministic = False
68
+
69
+ config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs")
70
+ accelerator = Accelerator(
71
+ gradient_accumulation_steps=config.training.gradient_accumulation_steps,
72
+ mixed_precision=config.training.mixed_precision,
73
+ log_with="wandb",
74
+ project_dir=config.experiment.logging_dir,
75
+ split_batches=True,
76
+ )
77
+
78
+ bs_mixed_modal = config.training.batch_size_mixed_modal
79
+
80
+ if "concat" in config.dataset.mixed_loader_mode:
81
+ raise NotImplementedError
82
+ else:
83
+ total_batch_size_per_gpu = bs_mixed_modal * config.dataset.accumulation
84
+ total_batch_size_without_accum = total_batch_size_per_gpu * accelerator.num_processes
85
+ total_batch_size = total_batch_size_without_accum * config.training.gradient_accumulation_steps
86
+
87
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
88
+ accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = (
89
+ total_batch_size_per_gpu
90
+ )
91
+ print("[DEBUG] CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
92
+ print("[DEBUG] torch.cuda.device_count():", torch.cuda.device_count())
93
+ print("[DEBUG] Accelerator processes:", accelerator.num_processes)
94
+
95
+ #####################################
96
+ # SETUP LOGGING, SEED and CONFIG #
97
+ #####################################
98
+ # Make one log on every process with the configuration for debugging.
99
+ logging.basicConfig(
100
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
101
+ datefmt="%m/%d/%Y %H:%M:%S",
102
+ level=logging.INFO,
103
+ )
104
+ logger.info(accelerator.state, main_process_only=False)
105
+ if accelerator.is_local_main_process:
106
+ set_verbosity_info()
107
+ else:
108
+ set_verbosity_error()
109
+
110
+ # We need to initialize the trackers we use, and also store our configuration.
111
+ # The trackers initialize automatically on the main process.
112
+ if accelerator.is_main_process:
113
+ resume_wandb_run = config.wandb.resume
114
+ run_id = config.wandb.get("run_id", None)
115
+ if run_id is None:
116
+ resume_wandb_run = False
117
+ run_id = wandb.util.generate_id()
118
+ config.wandb.run_id = run_id
119
+
120
+ wandb_init_kwargs = dict(
121
+ name=config.experiment.name,
122
+ id=run_id,
123
+ resume=resume_wandb_run,
124
+ entity=config.wandb.get("entity", None),
125
+ config_exclude_keys=[],
126
+ )
127
+ wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)}
128
+ wandb_config.pop("experiment.resume_from_checkpoint")
129
+
130
+ accelerator.init_trackers(
131
+ config.experiment.project,
132
+ config=wandb_config,
133
+ init_kwargs={"wandb": wandb_init_kwargs},
134
+ )
135
+
136
+ if accelerator.is_main_process:
137
+ os.makedirs(config.experiment.output_dir, exist_ok=True)
138
+ config_path = Path(config.experiment.output_dir) / "config.yaml"
139
+ logging.info(f"Saving config to {config_path}")
140
+ OmegaConf.save(config, config_path)
141
+
142
+ # If passed along, set the training seed now.
143
+ if config.training.seed is not None:
144
+ set_seed(config.training.seed)
145
+
146
+ #########################
147
+ # MODELS and OPTIMIZER #
148
+ #########################
149
+ logger.info("Loading models and optimizer")
150
+
151
+ weight_type = get_weight_type(config)
152
+
153
+ # VAE model for encoding images into continuous latents
154
+ if config.model.vae_model.type == 'wan21':
155
+ from models import WanVAE
156
+ vae_model = WanVAE(vae_pth=config.model.vae_model.pretrained_model_path, dtype=weight_type,
157
+ device=accelerator.device)
158
+ else:
159
+ raise NotImplementedError
160
+
161
+ # Initialize Show-o model
162
+ text_tokenizer, showo_token_ids = get_text_tokenizer(config.model.showo.llm_model_path, add_showo_tokens=True,
163
+ return_showo_token_ids=True,
164
+ llm_name=path_to_llm_name[config.model.showo.llm_model_path])
165
+ config.model.showo.llm_vocab_size = len(text_tokenizer)
166
+
167
+ if config.model.showo.load_from_showo:
168
+ model = Showo2Qwen2_5.from_pretrained(config.model.showo.pretrained_model_path, use_safetensors=False).to(accelerator.device)
169
+ else:
170
+ model = Showo2Qwen2_5(**config.model.showo).to(accelerator.device)
171
+
172
+ # Choose layers to freeze
173
+ _freeze_params(model, config.model.showo.frozen_params)
174
+
175
+ preproc_config = config.dataset.preprocessing
176
+ dataset_config = config.dataset.params
177
+
178
+ # for time embedding
179
+ if config.model.showo.add_time_embeds:
180
+ # we prepend the time embedding to vision tokens
181
+ config.dataset.preprocessing.num_mmu_image_tokens += 1
182
+ config.dataset.preprocessing.num_t2i_image_tokens += 1
183
+ config.dataset.preprocessing.num_hq_image_tokens += 1
184
+ config.dataset.preprocessing.num_video_tokens += 1
185
+ config.dataset.preprocessing.num_mixed_modal_tokens += 1
186
+
187
+ ##################################
188
+ # Optimizer and LR scheduler #
189
+ #################################
190
+ optimizer_config = config.optimizer.params
191
+ optimizer_type = config.optimizer.name
192
+
193
+ if optimizer_type == "adamw":
194
+ optimizer = AdamW(
195
+ model.parameters(),
196
+ lr=optimizer_config.learning_rate,
197
+ betas=(optimizer_config.beta1, optimizer_config.beta2),
198
+ weight_decay=optimizer_config.weight_decay,
199
+ eps=optimizer_config.epsilon,
200
+ )
201
+ else:
202
+ raise ValueError(f"Optimizer {optimizer_type} not supported")
203
+
204
+ ##################################
205
+ # DATALOADER #
206
+ #################################
207
+ logger.info("Creating dataloaders and lr_scheduler")
208
+
209
+ # DataLoaders creation:
210
+ # We use webdataset for data loading. The dataloaders are created with sampling with replacement.
211
+ # We don't do dataset resuming here; instead, we resample the shards and buffer each time. The sampling is stochastic.
212
+ # This means that the dataloading is not deterministic, but it's fast and efficient.
213
+
214
+ def create_dataloader(dataset, batch_size, collate_fn):
215
+ generator = torch.Generator(device='cuda')
216
+ if accelerator.num_processes > 2:
217
+ sampler = DistributedSampler(dataset,
218
+ num_replicas=accelerator.num_processes,
219
+ rank=accelerator.process_index,
220
+ shuffle=True,
221
+ drop_last=True,
222
+ # generator=generator
223
+ )
224
+ shuffle = False
225
+ else:
226
+ sampler = None
227
+ shuffle = True
228
+
229
+ dataloader = DataLoader(dataset, batch_size=batch_size,
230
+ sampler=sampler, collate_fn=collate_fn,
231
+ shuffle=shuffle, num_workers=dataset_config.num_workers,
232
+ drop_last=True, generator=generator)
233
+ return dataloader
234
+
235
+ dataset = VISTDataset(
236
+ dataset_config.train_mixed_modal_shards_path_or_url,
237
+ anno_path=dataset_config.annotation_path,
238
+ text_tokenizer=text_tokenizer,
239
+ image_size=preproc_config.mixed_modal_resolution,
240
+ max_seq_len=preproc_config.max_mixed_modal_seq_length,
241
+ num_image_tokens=preproc_config.num_mixed_modal_tokens,
242
+ latent_width=preproc_config.mixed_modal_latent_width,
243
+ latent_height=preproc_config.mixed_modal_latent_height,
244
+ cond_dropout_prob=config.training.cond_dropout_prob,
245
+ min_res=preproc_config.min_res,
246
+ showo_token_ids=showo_token_ids,
247
+ system=("", "", ""),
248
+ max_num_images=preproc_config.max_num_images,
249
+ )
250
+ print("Dataset length:", len(dataset))
251
+ train_dataloader_mixed_modal = create_dataloader(dataset,
252
+ config.training.batch_size_mixed_modal, #1
253
+ dataset.collate_fn)
254
+
255
+ num_update_steps_per_epoch = len(train_dataloader_mixed_modal)
256
+ print('[DEBUG] num_update_steps_per_epoch:', num_update_steps_per_epoch)
257
+ num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)
258
+
259
+ ##################################
260
+ # MODEL RESUME #
261
+ #################################
262
+ global_step = 0
263
+ first_epoch = 0
264
+
265
+ if config.experiment.resume_from_checkpoint:
266
+ dirs = os.listdir(config.experiment.output_dir)
267
+ # dirs = [d for d in dirs if d.startswith("checkpoint")]
268
+ dirs = [d for d in dirs if
269
+ d.startswith("checkpoint-") and d.split("-")[1].isdigit()] # modified 250804; handle the case where a checkpoint-final directory exists
270
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
271
+ path = dirs[-1] if len(dirs) > 0 else None
272
+ if path is not None:
273
+ path = os.path.join(config.experiment.output_dir, path)
274
+
275
+ global_step = int(os.path.basename(path).split("-")[1])
276
+ first_epoch = global_step // num_update_steps_per_epoch
277
+
278
+ accelerator.print(f"Resuming from checkpoint {path}/unwrapped_model/pytorch_model.bin")
279
+ state_dict = torch.load(f'{path}/unwrapped_model/pytorch_model.bin', map_location="cpu")
280
+
281
+ # optionally skip loading some parameters
282
+ if config.model.showo.params_not_load is not None:
283
+ params_to_delete = []
284
+ for k in state_dict:
285
+ for n in config.model.showo.params_not_load:
286
+ if n in k:
287
+ params_to_delete.append(k)
288
+ for k in params_to_delete:
289
+ del state_dict[k]
290
+
291
+ model.load_state_dict(state_dict, strict=False if config.model.showo.params_not_load is not None else True)
292
+ del state_dict
293
+
294
+ # Combine these dataloaders into a single iterable loader
295
+ mixed_loader = MixedDataLoader(
296
+ loader_list=[train_dataloader_mixed_modal],
297
+ samp_probs=config.dataset.samp_probs,
298
+ accumulation=config.dataset.accumulation,
299
+ mode=config.dataset.mixed_loader_mode
300
+ )
301
+
302
+ lr_scheduler = get_scheduler(
303
+ config.lr_scheduler.scheduler,
304
+ optimizer=optimizer,
305
+ num_training_steps=config.training.max_train_steps - global_step,
306
+ num_warmup_steps=config.lr_scheduler.params.warmup_steps,
307
+ # power=config.lr_scheduler.params.power,
308
+ )
309
+
310
+ ##################################
311
+ # Prepare accelerator #
312
+ #################################
313
+ logger.info("Preparing model, optimizer and dataloaders")
314
+ model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
315
+
316
+ ##################################
317
+ # Training #
318
+ #################################
319
+ logger.info("***** Running training *****")
320
+ logger.info(f" Num training steps = {config.training.max_train_steps}")
321
+ logger.info(f" Instantaneous batch size per device = {total_batch_size_per_gpu}")
322
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
323
+ logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}")
324
+
325
+ # default: 1000 steps, linear noise schedule
326
+ transport = create_transport(
327
+ path_type=config.transport.path_type,
328
+ prediction=config.transport.prediction,
329
+ loss_weight=config.transport.loss_weight,
330
+ train_eps=config.transport.train_eps,
331
+ sample_eps=config.transport.sample_eps,
332
+ snr_type=config.transport.snr_type,
333
+ do_shift=config.transport.do_shift,
334
+ seq_len=preproc_config.num_t2i_image_tokens,
335
+ ) # default: velocity;
336
+
337
+ sampler = Sampler(transport)
338
+
339
+ @torch.no_grad()
340
+ def prepare_latents_and_labels(
341
+ pixel_values: Union[torch.FloatTensor, torch.LongTensor],
342
+ data_type,
343
+ shape,
344
+ image_masks,
345
+ modality_positions
346
+ ):
347
+
348
+ if config.model.vae_model.type == 'wan21':
349
+ if len(pixel_values.shape) == 4:
350
+ pixel_values = pixel_values.unsqueeze(2)
351
+ image_latents = vae_model.sample(pixel_values)
352
+ recons_images = vae_model.batch_decode(image_latents)
353
+ if pixel_values.shape[2] == 1:
354
+ image_latents = image_latents.squeeze(2)
355
+ recons_images = recons_images.squeeze(2)
356
+ else:
357
+ raise NotImplementedError
358
+
359
+ c, h, w = image_latents.shape[1:]
360
+ # timesteps, noise, original image
361
+ # each loop iteration takes around 0.002 s, which is affordable
362
+ t_list, xt_list, ut_list, masks = [], [], [], []
363
+ for i, tp in enumerate(data_type):
364
+ # x0->noise x1->image
365
+ t, x0, x1 = transport.sample(image_latents[i][None],
366
+ config.training.und_max_t0 if tp in ['mmu', 'mmu_vid'] else None)
367
+ # timesteps, noised image, velocity
368
+ t, xt, ut = transport.path_sampler.plan(t, x0, x1)
369
+ t_list.append(t)
370
+ xt_list.append(xt)
371
+ ut_list.append(ut)
372
+ if data_type[0] != 'interleaved_data':
373
+ if tp in ['mmu', 'mmu_vid'] and config.training.und_max_t0 == 1.0:
374
+ masks.append(image_masks[i][None] * 0.0)
375
+ else:
376
+ masks.append(image_masks[i][None])
377
+
378
+ t = torch.stack(t_list, dim=0).squeeze(-1)
379
+ xt = torch.cat(xt_list, dim=0)
380
+ ut = torch.cat(ut_list, dim=0)
381
+
382
+ if len(masks) != 0:
383
+ masks = torch.cat(masks, dim=0)
384
+ else:
385
+ masks = image_masks
386
+
387
+ if data_type[0] == 'interleaved_data':
388
+ b, n = shape
389
+ image_latents = image_latents.reshape(b, n, c, h, w)
390
+ ut = ut.reshape(b, n, c, h, w)
391
+ xt = xt.reshape(b, n, c, h, w)
392
+ t = t.reshape(b, n)
393
+
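+ # With probability 0.7, keep a random prefix of the interleaved images clean: their noised latents are
+ # replaced by the ground-truth latents, their timesteps forced to 1.0, and their image masks zeroed, so
+ # the flow loss is computed only on later images while the model conditions on the clean earlier ones.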
394
+ for i in range(b):
395
+ if random.random() < 0.7:
396
+ non_zero_max_idx = max([_ for _, pos in enumerate(modality_positions[i]) if pos[1] != 0])
397
+ idx = random.randint(1, non_zero_max_idx) if non_zero_max_idx != 0 else 0
398
+ xt[i, :idx] = image_latents[i][None][:, :idx].clone()
399
+ # ut[i, :idx] = torch.zeros_like(image_latents[i][None][:, :idx])
400
+ t[i, :idx] = t[i, :idx] * 0.0 + 1.0
401
+
402
+ for j in range(idx):
403
+ img_sid, length = modality_positions[i, j]
404
+ masks[i, img_sid: img_sid + length] = 0
405
+
406
+ ut = ut.reshape(b * n, c, h, w)
407
+ xt = xt.reshape(b * n, c, h, w)
408
+ t = t.reshape(b * n)
409
+
410
+ return xt, t, ut, recons_images, masks
411
+
412
+ batch_time_m = AverageMeter()
413
+ data_time_m = AverageMeter()
414
+ end = time.time()
415
+
416
+ for epoch in range(first_epoch, num_train_epochs):
417
+ model.train()
418
+ for batch in mixed_loader:
419
+
420
+ text_tokens = batch['text_tokens'].to(accelerator.device)
421
+ text_labels = batch['text_labels'].to(accelerator.device)
422
+ pixel_values = batch['images'].to(accelerator.device).to(weight_type)
423
+ if batch['data_type'][0] == 'interleaved_data':
424
+ b, n = pixel_values.shape[:2]
425
+ pixel_values = rearrange(pixel_values, "b n c h w -> (b n) c h w")
426
+ batch['data_type'] = batch['data_type'] * n
427
+ else:
428
+ b, n = 0, 0
429
+
430
+ text_masks = batch['text_masks'].to(accelerator.device)
431
+ image_masks = batch['image_masks'].to(accelerator.device)
432
+ modality_positions = batch['modality_positions'].to(accelerator.device)
433
+ # prepare image latents and labels
434
+ image_latents, t, image_labels, recons_images, image_masks = prepare_latents_and_labels(pixel_values,
435
+ batch['data_type'],
436
+ (b, n),
437
+ image_masks,
438
+ modality_positions)
439
+ # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
440
+ # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
441
+ omni_mask_fn = omni_attn_mask(modality_positions) # all attention-mask information is prepared here
442
+ # block_mask = create_block_mask(omni_mask_fn, B=text_tokens.shape[0], H=None,
443
+ # Q_LEN=preproc_config.max_mixed_modal_seq_length,
444
+ # KV_LEN=preproc_config.max_mixed_modal_seq_length, device=accelerator.device)
445
+ # or use naive omni attention mask, which is more stable
446
+ block_mask = omni_attn_mask_naive(text_tokens.size(0),
447
+ text_tokens.size(1),
448
+ modality_positions,
449
+ accelerator.device).to(weight_type)
450
+
451
+ logits, loss_ntp, loss_flow = model(text_tokens=text_tokens,
452
+ image_latents=image_latents,
453
+ t=t.to(weight_type),
454
+ attention_mask=block_mask,
455
+ text_masks=text_masks,
456
+ image_masks=image_masks,
457
+ text_labels=text_labels,
458
+ image_labels=image_labels,
459
+ modality_positions=modality_positions,
460
+ output_hidden_states=True,
461
+ max_seq_len=text_tokens.size(1),
462
+ device=accelerator.device,
463
+ )
464
+
465
+ # Gather the losses across all processes for logging (if we use distributed training).
466
+ avg_loss_ntp = accelerator.gather(loss_ntp.repeat(total_batch_size_per_gpu)).mean()
467
+ avg_loss_flow = accelerator.gather(loss_flow.repeat(total_batch_size_per_gpu)).mean()
468
+ loss = config.training.ntp_coeff * loss_ntp + config.training.flow_coeff * loss_flow
469
+
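+ # Gradient accumulation is handled manually: the loss is scaled by 1/gradient_accumulation_steps and
+ # optimizer.step() / zero_grad() below only run once every gradient_accumulation_steps micro-batches.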
470
+ accelerator.backward(loss.to(weight_type) / config.training.gradient_accumulation_steps)
471
+
472
+ if config.training.max_grad_norm is not None and accelerator.sync_gradients:
473
+ accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
474
+
475
+ if (global_step + 1) % config.training.gradient_accumulation_steps == 0:
476
+ optimizer.step()
477
+ lr_scheduler.step()
478
+
479
+ # log gradient norm before zeroing it
480
+ if (
481
+ accelerator.sync_gradients
482
+ and (global_step + 1) % config.experiment.log_grad_norm_every == 0
483
+ and accelerator.is_main_process
484
+ ):
485
+ log_grad_norm(model, accelerator, global_step + 1)
486
+
487
+ if (global_step + 1) % config.training.gradient_accumulation_steps == 0:
488
+ optimizer.zero_grad(set_to_none=True)
489
+
490
+ # Checks if the accelerator has performed an optimization step behind the scenes
491
+ if accelerator.sync_gradients:
492
+
493
+ batch_time_m.update(time.time() - end)
494
+ end = time.time()
495
+
496
+ # Log metrics
497
+ if (global_step + 1) % config.experiment.log_every == 0:
498
+ samples_per_second_per_gpu = (
499
+ config.training.gradient_accumulation_steps * total_batch_size_per_gpu / batch_time_m.val
500
+ )
501
+ lr = [group["lr"] for group in optimizer.param_groups]
502
+ if len(lr) == 3:
503
+ logs = {
504
+ "step_loss_ntp": avg_loss_ntp.item(),
505
+ "step_loss_flow": avg_loss_flow.item(),
506
+ "lr_ve": lr[0],
507
+ "lr_proj": lr[1],
508
+ "lr_showo": lr[2],
509
+ "samples/sec/gpu": samples_per_second_per_gpu,
510
+ "data_time": data_time_m.val,
511
+ "batch_time": batch_time_m.val,
512
+ }
513
+ accelerator.log(logs, step=global_step + 1)
514
+ logger.info(
515
+ f"Epoch: {epoch} "
516
+ f"Step: {global_step + 1} "
517
+ f"Loss_NTP: {avg_loss_ntp.item():0.4f} "
518
+ f"Loss_FLOW: {avg_loss_flow.item():0.4f} "
519
+ f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
520
+ f"Batch (t): {batch_time_m.val:0.4f} "
521
+ f"LR_ve: {lr[0]:0.6f} "
522
+ f"LR_proj: {lr[1]:0.6f} "
523
+ f"LR_showo: {lr[2]:0.6f}"
524
+ )
525
+ else:
526
+ logs = {
527
+ "step_loss_ntp": avg_loss_ntp.item(),
528
+ "step_loss_flow": avg_loss_flow.item(),
529
+ "lr": lr[0],
530
+ "samples/sec/gpu": samples_per_second_per_gpu,
531
+ "data_time": data_time_m.val,
532
+ "batch_time": batch_time_m.val,
533
+ }
534
+ accelerator.log(logs, step=global_step + 1)
535
+ logger.info(
536
+ f"Epoch: {epoch} "
537
+ f"Step: {global_step + 1} "
538
+ f"Loss_NTP: {avg_loss_ntp.item():0.4f} "
539
+ f"Loss_FLOW: {avg_loss_flow.item():0.4f} "
540
+ f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
541
+ f"Batch (t): {batch_time_m.val:0.4f} "
542
+ f"LR: {lr[0]:0.6f}"
543
+ )
544
+ # resetting batch / data time meters per log window
545
+ batch_time_m.reset()
546
+ data_time_m.reset()
547
+
548
+ # Save model checkpoint
549
+ if (global_step + 1) % config.experiment.save_every == 0:
550
+ save_checkpoint(model, config, accelerator, global_step + 1)
551
+
552
+ global_step += 1
553
+
554
+ # Stop training if max steps is reached
555
+ if global_step >= config.training.max_train_steps:
556
+ break
557
+ # End for
558
+
559
+ accelerator.wait_for_everyone()
560
+
561
+ # Evaluate and save checkpoint at the end of training
562
+ save_checkpoint(model, config, accelerator, "final")
563
+
564
+ # Save the final trained checkpoint
565
+ if accelerator.is_main_process:
566
+ model = accelerator.unwrap_model(model)
567
+ model.save_pretrained(config.experiment.output_dir, safe_serialization=False)
568
+
569
+ accelerator.end_training()
570
+
571
+
572
+ @torch.no_grad()
573
+ def generate_images(
574
+ model,
575
+ vae_model,
576
+ text_tokenizer,
577
+ config,
578
+ global_step,
579
+ device,
580
+ weight_type,
581
+ sampler,
582
+ showo_token_ids,
583
+ ):
584
+ logger.info("Generating images...")
585
+ model.eval()
586
+
587
+ # read validation prompts from file
588
+ with open(config.dataset.params.validation_prompts_file, "r") as f:
589
+ prompts = f.read().splitlines()[:config.training.batch_size_t2i]
590
+
591
+ num_t2i_image_tokens, num_mmu_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \
592
+ latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \
593
+ = get_hyper_params(config, text_tokenizer, showo_token_ids)
594
+
595
+ batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \
596
+ prepare_gen_input(
597
+ prompts, text_tokenizer, num_t2i_image_tokens, bos_id, eos_id, boi_id, eoi_id, pad_id, image_pad_id,
598
+ max_text_len, device
599
+ )
600
+
601
+ z = torch.randn((len(prompts),
602
+ image_latent_dim, latent_height * patch_size,
603
+ latent_width * patch_size)).to(weight_type).to(device)
604
+
605
+ if guidance_scale > 0:
606
+ z = torch.cat([z, z], dim=0)
607
+ text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0)
608
+ modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0)
609
+ # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
610
+ # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
611
+ # omni_mask_fn = omni_attn_mask(modality_positions)
612
+ # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
613
+ # KV_LEN=max_seq_len, device=device)
614
+ # or use naive omni attention mask, which is more stable
615
+ block_mask = omni_attn_mask_naive(text_tokens.size(0),
616
+ max_seq_len,
617
+ modality_positions,
618
+ device).to(weight_type)
619
+ else:
620
+ text_tokens = batch_text_tokens
621
+ modality_positions = batch_modality_positions
622
+ # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
623
+ # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
624
+ # omni_mask_fn = omni_attn_mask(modality_positions)
625
+ # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
626
+ # KV_LEN=max_seq_len, device=device)
627
+ block_mask = omni_attn_mask_naive(text_tokens.size(0),
628
+ max_seq_len,
629
+ modality_positions,
630
+ device).to(weight_type)
631
+
632
+ model_kwargs = dict(
633
+ text_tokens=text_tokens,
634
+ attention_mask=block_mask,
635
+ modality_positions=modality_positions,
637
+ output_hidden_states=True,
638
+ max_seq_len=max_seq_len,
639
+ guidance_scale=guidance_scale
640
+ )
641
+
642
+ sample_fn = sampler.sample_ode(
643
+ sampling_method=config.transport.sampling_method,
644
+ num_steps=config.transport.num_inference_steps,
645
+ atol=config.transport.atol,
646
+ rtol=config.transport.rtol,
647
+ reverse=config.transport.reverse,
648
+ time_shifting_factor=config.transport.time_shifting_factor
649
+ )
650
+ samples = sample_fn(z, model.t2i_generate, **model_kwargs)[-1]
651
+ samples = torch.chunk(samples, 2)[0]
652
+
653
+ if config.model.vae_model.type == 'wan21':
654
+ samples = samples.unsqueeze(2)
655
+ images = vae_model.batch_decode(samples)
656
+ images = images.squeeze(2)
657
+ else:
658
+ raise NotImplementedError
659
+
660
+ model.train()
661
+
662
+ # Convert to PIL images
663
+ images = denorm(images)
664
+ pil_images = [Image.fromarray(image) for image in images]
665
+
666
+ # Log images
667
+ wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
668
+ wandb.log({"Generated images": wandb_images}, step=global_step)
669
+
670
+
671
+ @torch.no_grad()
672
+ def visualize_reconstruction(
673
+ pixel_values,
674
+ recons_images,
675
+ captions,
676
+ global_step
677
+ ):
678
+ logger.info("Visualizing images...")
679
+
680
+ # Convert to PIL images
681
+ images = denorm(pixel_values)
682
+ recons_images = denorm(recons_images)
683
+ visualized_images = np.concatenate((images, recons_images), 2)
684
+ pil_images = [Image.fromarray(image) for image in visualized_images]
685
+
686
+ # Log images
687
+ wandb_images = [wandb.Image(image, caption=captions[i]) for i, image in enumerate(pil_images)]
688
+ wandb.log({"Original images vs. Reconstructed": wandb_images}, step=global_step)
689
+
690
+
691
+ @torch.no_grad()
692
+ def generate_videos(
693
+ model,
694
+ vae_model,
695
+ text_tokenizer,
696
+ config,
697
+ global_step,
698
+ device,
699
+ weight_type,
700
+ sampler,
701
+ showo_token_ids
702
+ ):
703
+ logger.info("Generating videos...")
704
+ model.eval()
705
+
706
+ # read validation prompts from file
707
+ with open(config.dataset.params.validation_prompts_file, "r") as f:
708
+ prompts = f.read().splitlines()[:config.training.batch_size_t2i]
709
+
710
+ num_image_tokens, num_video_tokens, max_seq_len, max_text_len, image_latent_dim, patch_size, latent_width, \
711
+ latent_height, pad_id, bos_id, eos_id, boi_id, eoi_id, bov_id, eov_id, image_pad_id, video_pad_id, guidance_scale \
712
+ = get_hyper_params(config, text_tokenizer, showo_token_ids, is_video=True)
713
+
714
+ batch_text_tokens, batch_text_tokens_null, batch_modality_positions, batch_modality_positions_null = \
715
+ prepare_gen_input(
716
+ prompts, text_tokenizer, num_video_tokens, bos_id, eos_id, bov_id, eov_id, pad_id, video_pad_id,
717
+ max_text_len, device
718
+ )
719
+
720
+ T = 5
721
+ z = torch.randn((len(prompts), image_latent_dim, T, latent_height * patch_size, latent_width * patch_size)).to(
722
+ device).to(weight_type)
723
+
724
+ if guidance_scale > 0:
725
+ z = torch.cat([z, z], dim=0)
726
+ text_tokens = torch.cat([batch_text_tokens, batch_text_tokens_null], dim=0)
727
+ modality_positions = torch.cat([batch_modality_positions, batch_modality_positions_null], dim=0)
728
+ # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
729
+ # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
730
+ # omni_mask_fn = omni_attn_mask(modality_positions)
731
+ # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
732
+ # KV_LEN=max_seq_len, device=device)
733
+ # or use naive omni attention mask, which is more stable
734
+ block_mask = omni_attn_mask_naive(text_tokens.size(0),
735
+ max_seq_len,
736
+ modality_positions,
737
+ device).to(weight_type)
738
+ else:
739
+ text_tokens = batch_text_tokens
740
+ modality_positions = batch_modality_positions
741
+ # B=None would potentially induce loss spike when there are a lot of ignored labels (-100) in the batch
742
+ # we must set B=text_tokens.shape[0] (loss spike may still happen sometimes)
743
+ # omni_mask_fn = omni_attn_mask(modality_positions)
744
+ # block_mask = create_block_mask(omni_mask_fn, B=z.size(0), H=None, Q_LEN=max_seq_len,
745
+ # KV_LEN=max_seq_len, device=device)
746
+ block_mask = omni_attn_mask_naive(text_tokens.size(0),
747
+ max_seq_len,
748
+ modality_positions,
749
+ device).to(weight_type)
750
+
751
+ model_kwargs = dict(
752
+ text_tokens=text_tokens,
753
+ attention_mask=block_mask,
754
+ modality_positions=modality_positions,
755
+ output_hidden_states=True,
756
+ max_seq_len=max_seq_len,
757
+ guidance_scale=guidance_scale
758
+ )
759
+
760
+ sample_fn = sampler.sample_ode(
761
+ sampling_method=config.transport.sampling_method,
762
+ num_steps=config.transport.num_inference_steps,
763
+ atol=config.transport.atol,
764
+ rtol=config.transport.rtol,
765
+ reverse=config.transport.reverse,
766
+ time_shifting_factor=config.transport.time_shifting_factor
767
+ )
768
+ samples = sample_fn(z, model.t2i_generate, **model_kwargs)[-1]
769
+ samples = torch.chunk(samples, 2)[0]
770
+
771
+ if config.model.vae_model.type == 'wan21':
772
+ images = vae_model.batch_decode(samples)
773
+ else:
774
+ raise NotImplementedError
775
+
776
+ model.train()
777
+
778
+ # Convert to PIL images
779
+ images = denorm_vid(images)
780
+
781
+ # Log images
782
+ wandb_images = [wandb.Video(image, caption=prompts[i], fps=8, format="mp4") for i, image in enumerate(images)]
783
+ wandb.log({"Generated videos": wandb_images}, step=global_step)
784
+
785
+
786
+ @torch.no_grad()
787
+ def visualize_reconstruction_video(
788
+ pixel_values,
789
+ recons_images,
790
+ captions,
791
+ global_step
792
+ ):
793
+ logger.info("Visualizing videos...")
794
+
795
+ # Convert to PIL images
796
+ images = denorm_vid(pixel_values)
797
+ recons_images = denorm_vid(recons_images)
798
+ visualized_images = np.concatenate((images, recons_images), 4)
799
+
800
+ # Log images
801
+ wandb_images = [wandb.Video(image, caption=captions[i], fps=8, format="mp4") for i, image in
802
+ enumerate(visualized_images)]
803
+ wandb.log({"Original videos vs. Reconstructed": wandb_images}, step=global_step)
804
+
805
+
806
+ def save_checkpoint(model, config, accelerator, global_step):
807
+ output_dir = config.experiment.output_dir
808
+ checkpoints_total_limit = config.experiment.get("checkpoints_total_limit", None)
809
+
810
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
811
+ if accelerator.is_main_process and checkpoints_total_limit is not None:
812
+ checkpoints = os.listdir(output_dir)
813
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
814
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
815
+
816
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
817
+ if len(checkpoints) >= checkpoints_total_limit:
818
+ num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
819
+ removing_checkpoints = checkpoints[0:num_to_remove]
820
+
821
+ logger.info(
822
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
823
+ )
824
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
825
+
826
+ for removing_checkpoint in removing_checkpoints:
827
+ removing_checkpoint = os.path.join(output_dir, removing_checkpoint)
828
+ shutil.rmtree(removing_checkpoint)
829
+
830
+ save_path = Path(output_dir) / f"checkpoint-{global_step}"
831
+
832
+ # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet)
833
+ # XXX: could also make this conditional on deepspeed
834
+ state_dict = accelerator.get_state_dict(model)
835
+ if accelerator.is_main_process:
836
+ unwrapped_model = accelerator.unwrap_model(model)
837
+ unwrapped_model.save_pretrained(
838
+ save_path / "unwrapped_model",
839
+ save_function=accelerator.save,
840
+ state_dict=state_dict,
841
+ safe_serialization=False
842
+ )
843
+ json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
844
+ logger.info(f"Saved state to {save_path}")
845
+
846
+
847
+ def log_grad_norm(model, accelerator, global_step):
848
+ for name, param in model.named_parameters():
849
+ if param.grad is not None:
850
+ grads = param.grad.detach().data
851
+ grad_norm = (grads.norm(p=2) / grads.numel()).item()
852
+ accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
853
+
854
+
855
+ if __name__ == "__main__":
856
+ main()
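The transport module itself is not included in this commit. As a rough reference, below is a self-contained sketch of the linear-path, velocity-prediction target that path_type: "Linear" and prediction: "velocity" conventionally denote in flow matching; the function name and shapes are illustrative and do not mirror the actual transport API.

# Illustrative linear flow-matching path with velocity targets
# (a generic sketch, not the transport implementation used above).
import torch

def linear_path_sample(x1: torch.Tensor):
    """x1: clean latents of shape (B, ...). Returns (t, x_t, u_t) on a straight noise->data path."""
    x0 = torch.randn_like(x1)                              # noise endpoint
    t = torch.rand(x1.shape[0], *([1] * (x1.dim() - 1)))   # one timestep per sample, broadcastable
    xt = (1.0 - t) * x0 + t * x1                            # point on the linear path at time t
    ut = x1 - x0                                            # constant velocity along the path (regression target)
    return t, xt, ut

# usage: t, xt, ut = linear_path_sample(torch.randn(2, 16, 27, 27))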