import argparse
import logging
import math
import os
import random
import time
import warnings
from datetime import datetime, timedelta

import diffusers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, LoggerType
from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from musetalk.loss.basic_loss import set_requires_grad
from musetalk.loss.syncnet import get_sync_loss
from musetalk.utils.training_utils import (
    initialize_dataloaders,
    initialize_loss_functions,
    initialize_models_and_optimizers,
    initialize_syncnet,
    initialize_vgg,
    validation,
)
from musetalk.utils.utils import (
    delete_additional_ckpt,
    get_mouth_region,
    process_audio_features,
    save_models,
    seed_everything,
)

logger = get_logger(__name__, log_level="INFO")
warnings.filterwarnings("ignore")
check_min_version("0.10.0.dev0")
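

# main() drives one full training run: it sets up Accelerate, builds the
# models, dataloaders and losses, then alternates generator updates with
# optional VGG, GAN, feature-matching and lip-sync objectives against
# separate full-frame and mouth-crop discriminators.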
def main(cfg):
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    os.makedirs(save_dir, exist_ok=True)

    # A generous NCCL timeout keeps slow validation / checkpointing steps
    # from tearing down the process group.
    kwargs = DistributedDataParallelKwargs()
    process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        log_with=LoggerType.TENSORBOARD,
        project_dir=os.path.join(save_dir, "tensorboard"),
        kwargs_handlers=[kwargs, process_group_kwargs],
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Offset the seed by process rank so each worker samples differently.
    if cfg.seed is not None:
        print('cfg.seed', cfg.seed, accelerator.process_index)
        seed_everything(cfg.seed + accelerator.process_index)

    weight_dtype = torch.float32

    model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
    dataloader_dict = initialize_dataloaders(cfg)
    loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
    syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
    vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)

    (
        model_dict['net'],
        model_dict['optimizer'],
        model_dict['lr_scheduler'],
        dataloader_dict['train_dataloader'],
        dataloader_dict['val_dataloader'],
    ) = accelerator.prepare(
        model_dict['net'],
        model_dict['optimizer'],
        model_dict['lr_scheduler'],
        dataloader_dict['train_dataloader'],
        dataloader_dict['val_dataloader'],
    )
    print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))
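
    # One "update step" is one optimizer step, i.e. gradient_accumulation_steps
    # dataloader batches; the epoch count is derived from max_train_steps.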
    num_update_steps_per_epoch = math.ceil(
        len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            cfg.exp_name,
            # The mlflow kwargs are consumed only if an MLflow tracker is
            # configured; the TensorBoard tracker ignores them.
            init_kwargs={"mlflow": {"run_name": run_time}},
        )

    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"Num Epochs = {num_train_epochs}")
    logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
    logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")

    global_step = 0
    first_epoch = 0
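
    # Optionally resume from the newest `checkpoint-<global_step>` directory
    # in save_dir; Accelerate restores model, optimizer and scheduler state.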
    if cfg.resume_from_checkpoint:
        resume_dir = save_dir
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        if len(dirs) > 0:
            path = dirs[-1]
            accelerator.load_state(os.path.join(resume_dir, path))
            accelerator.print(f"Resuming from checkpoint {path}")
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = global_step % num_update_steps_per_epoch

    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    print("log type of models")
    print("unet", model_dict['unet'].dtype)
    print("vae", model_dict['vae'].dtype)
    print("wav2vec", model_dict['wav2vec'].dtype)
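
    # The adversarial terms are gated off until `start_gan`, giving the
    # generator a reconstruction-only warm-up.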
    def get_ganloss_weight(step):
        """Calculate GAN loss weight based on training step"""
        if step < cfg.discriminator_train_params.start_gan:
            return 0.0
        return 1.0

    for epoch in range(first_epoch, num_train_epochs):
        model_dict['unet'].train()
        if cfg.loss_params.gan_loss > 0:
            loss_dict['discriminator'].train()
        if cfg.loss_params.mouth_gan_loss > 0:
            loss_dict['mouth_discriminator'].train()

        # Running loss sums for the tracker; reset after every logging step.
        train_loss = 0.0
        train_loss_D = 0.0
        train_loss_D_mouth = 0.0
        l1_loss_accum = 0.0
        vgg_loss_accum = 0.0
        gan_loss_accum = 0.0
        gan_loss_accum_mouth = 0.0
        fm_loss_accum = 0.0
        sync_loss_accum = 0.0
        adapted_weight_accum = 0.0

        t_data_start = time.time()

        for step, batch in enumerate(dataloader_dict['train_dataloader']):
            t_data = time.time() - t_data_start
            t_model_start = time.time()

            # Everything up to the network forward pass runs without
            # gradients: data movement, audio feature extraction and the
            # frozen-VAE encodings below.
            with torch.no_grad():
                pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True,
                )
                bsz, num_frames, c, h, w = pixel_values.shape

                ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
                    accelerator.device,
                    non_blocking=True,
                )

                pixel_values_face_mask = batch['pixel_values_face_mask']

                audio_prompts = process_audio_features(
                    cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype
                )

                adapted_weight = 1
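
                # adapted_weight scales the losses per sample: clips whose
                # ground-truth audio-visual SyncNet similarity is an outlier
                # are down-weighted (or dropped entirely) so misaligned data
                # does not dominate training.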
                if cfg.loss_params.sync_loss > 0:
                    mels = batch['mel']

                    gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')

                    # SyncNet scores only the mouth half of the frames.
                    height = gt_frames.shape[2]
                    gt_frames = gt_frames[:, :, height // 2:, :]

                    audio_embed = syncnet.get_audio_embed(mels)

                    if cfg.use_adapted_weight:
                        vision_embed_gt = syncnet.get_vision_embed(gt_frames)
                        image_audio_sim_gt = F.cosine_similarity(
                            audio_embed,
                            vision_embed_gt,
                            dim=1,
                        )[0]

                        if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
                            if cfg.adapted_weight_type == "cut_off":
                                adapted_weight = 0.0
                                print(
                                    f"\nThe i-a similarity in step {global_step} is "
                                    f"{image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
                            elif cfg.adapted_weight_type == "linear":
                                adapted_weight = image_audio_sim_gt
                            else:
                                print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
                                adapted_weight = 1

                    # Backpropagate through a random contiguous window of
                    # num_backward_frames frames out of the 16 sampled ones.
                    max_start = 16 - cfg.num_backward_frames
                    frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
                    frames_right_index = frames_left_index + cfg.num_backward_frames
                else:
                    frames_left_index = 0
                    frames_right_index = cfg.data.n_sample_frames

                pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
                ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
                pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
                audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
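
                # Encode three latent sets with the frozen VAE: the targets,
                # the lower-half-masked inputs, and the reference frames that
                # provide identity and pose.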
                frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
                latents = model_dict['vae'].encode(frames).latent_dist.mode()
                latents = latents * model_dict['vae'].config.scaling_factor
                latents = latents.float()

                # Mask the mouth half of the input frames (-1 is the lower
                # bound of the normalized pixel range).
                masked_pixel_values = pixel_values_backward.clone()
                masked_pixel_values[:, :, :, h // 2:, :] = -1
                masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
                masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
                masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
                masked_latents = masked_latents.float()

                ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
                ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
                ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
                ref_latents = ref_latents.float()

                pixel_values_face_mask_backward = rearrange(
                    pixel_values_face_mask_backward,
                    "b f c h w -> (b f) c h w",
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    'b f c h w-> (b f) c h w',
                )
                audio_prompts_backward = rearrange(
                    audio_prompts_backward,
                    '(b f) c h w -> (b f) (c h) w',
                    b=bsz,
                )

                # Randomly drop the reference latents so the model also learns
                # to inpaint without reference conditioning.
                dropout = nn.Dropout(p=cfg.ref_dropout_rate)
                ref_latents = dropout(ref_latents)
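
                # The network acts as a single-step latent inpainter: masked
                # and reference latents are concatenated along channels and
                # the diffusion timestep is fixed at 0.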
                input_latents = torch.cat([masked_latents, ref_latents], dim=1)
                input_latents = input_latents.to(weight_dtype)
                timesteps = torch.tensor([0], device=input_latents.device)

            # Forward pass (with gradients) and decode back to pixel space.
            latents_pred = model_dict['net'](
                input_latents,
                timesteps,
                audio_prompts_backward,
            )
            latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
            image_pred = model_dict['vae'].decode(latents_pred).sample

            image_pred = image_pred.float()
            frames = frames.float()

            # Reconstruction loss, scaled by the per-sample adapted weight.
            l1_loss = loss_dict['L1_loss'](frames, image_pred)
            l1_loss_accum += l1_loss.item()
            loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
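
            # Perceptual and adversarial terms operate on multi-scale image
            # pyramids; the mouth variants reuse crops from the face mask.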
            if cfg.loss_params.mouth_gan_loss > 0:
                frames_mouth, image_pred_mouth = get_mouth_region(
                    frames,
                    image_pred,
                    pixel_values_face_mask_backward,
                )
                pyramide_real_mouth = pyramid(downsampler(frames_mouth))
                pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))

            if cfg.loss_params.vgg_loss > 0 or cfg.loss_params.gan_loss > 0:
                # The full-frame pyramids are shared by the VGG and GAN losses.
                pyramide_real = pyramid(downsampler(frames))
                pyramide_generated = pyramid(downsampler(image_pred))

            if cfg.loss_params.vgg_loss > 0:
                loss_IN = 0
                for scale in cfg.loss_params.pyramid_scale:
                    x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
                    y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
                    for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
                        value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                        loss_IN += weight * value
                loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
                loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
                vgg_loss_accum += loss_IN.item()

            if cfg.loss_params.gan_loss > 0:
                # Generator pass: freeze the discriminator and apply a
                # least-squares GAN objective to its prediction maps.
                set_requires_grad(loss_dict['discriminator'], False)
                loss_G = 0.
                discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
                discriminator_maps_real = loss_dict['discriminator'](pyramide_real)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum += loss_G.item()

                loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight

                if cfg.loss_params.fm_loss[0] > 0:
                    # Feature matching: L1 distance between real and generated
                    # discriminator feature maps, weighted per layer.
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(
                                zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()
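
            # The mouth-crop discriminator mirrors the full-frame GAN and
            # feature-matching terms on the cropped mouth region.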
            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], False)
                loss_G = 0.
                mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
                mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)

                for scale in loss_dict['disc_scales']:
                    key = 'prediction_map_%s' % scale
                    value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
                    loss_G += value
                gan_loss_accum_mouth += loss_G.item()

                loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight

                if cfg.loss_params.fm_loss[0] > 0:
                    L_feature_matching = 0.
                    for scale in loss_dict['disc_scales']:
                        key = 'feature_maps_%s' % scale
                        for i, (a, b) in enumerate(
                                zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
                            value = torch.abs(a - b).mean()
                            L_feature_matching += value * cfg.loss_params.fm_loss[i]
                    loss += L_feature_matching * adapted_weight
                    fm_loss_accum += L_feature_matching.item()

            if cfg.loss_params.sync_loss > 0:
                # SyncNet loss on the mouth half of the predicted frames,
                # against the audio embedding computed above.
                pred_frames = rearrange(
                    image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
                pred_frames = pred_frames[:, :, height // 2:, :]
                sync_loss, image_audio_sim_pred = get_sync_loss(
                    audio_embed,
                    gt_frames,
                    pred_frames,
                    syncnet,
                    adapted_weight,
                    frames_left_index=frames_left_index,
                    frames_right_index=frames_right_index,
                )
                sync_loss_accum += sync_loss.item()
                loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight

            # Track the adapted weight actually applied at this step.
            adapted_weight_accum += float(adapted_weight)

            avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
            train_loss += avg_loss.item()
            accelerator.backward(loss)
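
            # Discriminator passes: unfreeze, score real frames against the
            # detached fakes, and step the discriminator optimizers on
            # accumulation boundaries.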
            if cfg.loss_params.gan_loss > 0:
                set_requires_grad(loss_dict['discriminator'], True)
                loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
                avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D += avg_loss_D.item()
                loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['optimizer_D'].step()
                    loss_dict['scheduler_D'].step()
                    loss_dict['optimizer_D'].zero_grad()

            if cfg.loss_params.mouth_gan_loss > 0:
                set_requires_grad(loss_dict['mouth_discriminator'], True)
                mouth_loss_D = loss_dict['mouth_discriminator_full'](
                    frames_mouth, image_pred_mouth.detach())
                avg_mouth_loss_D = accelerator.gather(
                    mouth_loss_D.repeat(cfg.data.train_bs)).mean()
                train_loss_D_mouth += avg_mouth_loss_D.item()
                mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
                accelerator.backward(mouth_loss_D)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
                if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                    loss_dict['mouth_optimizer_D'].step()
                    loss_dict['mouth_scheduler_D'].step()
                    loss_dict['mouth_optimizer_D'].zero_grad()

            # Generator update on accumulation boundaries.
            if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model_dict['trainable_params'],
                        cfg.solver.max_grad_norm,
                    )
                model_dict['optimizer'].step()
                model_dict['lr_scheduler'].step()
                model_dict['optimizer'].zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "train_loss_D": train_loss_D,
                    "train_loss_D_mouth": train_loss_D_mouth,
                    "l1_loss": l1_loss_accum,
                    "vgg_loss": vgg_loss_accum,
                    "gan_loss": gan_loss_accum,
                    "gan_loss_mouth": gan_loss_accum_mouth,
                    "fm_loss": fm_loss_accum,
                    "sync_loss": sync_loss_accum,
                    "adapted_weight": adapted_weight_accum,
                    "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                }, step=global_step)

                train_loss = 0.0
                l1_loss_accum = 0.0
                vgg_loss_accum = 0.0
                gan_loss_accum = 0.0
                gan_loss_accum_mouth = 0.0
                fm_loss_accum = 0.0
                sync_loss_accum = 0.0
                adapted_weight_accum = 0.0
                train_loss_D = 0.0
                train_loss_D_mouth = 0.0
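
                # Periodic validation; `global_step == 10` doubles as an early
                # smoke test shortly after start-up.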
                if global_step % cfg.val_freq == 0 or global_step == 10:
                    try:
                        validation(
                            cfg,
                            dataloader_dict['val_dataloader'],
                            model_dict['net'],
                            model_dict['vae'],
                            model_dict['wav2vec'],
                            accelerator,
                            save_dir,
                            global_step,
                            weight_dtype,
                            syncnet_score=adapted_weight,
                        )
                    except Exception as e:
                        print(f"An error occurred during validation: {e}")

                if global_step % cfg.checkpointing_steps == 0:
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    try:
                        start_time = time.time()
                        if accelerator.is_main_process:
                            save_models(
                                accelerator,
                                model_dict['net'],
                                save_dir,
                                global_step,
                                cfg,
                                logger=logger,
                            )
                            delete_additional_ckpt(save_dir, cfg.total_limit)
                        elapsed_time = time.time() - start_time
                        if elapsed_time > 300:
                            print(f"Skipping storage as it took too long in step {global_step}.")
                        else:
                            print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
                    except Exception as e:
                        print(f"Error when saving model in step {global_step}:", e)

            t_model = time.time() - t_model_start
            logs = {
                "step_loss": loss.detach().item(),
                "lr": model_dict['lr_scheduler'].get_last_lr()[0],
                "td": f"{t_data:.2f}s",
                "tm": f"{t_model:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

        # End-of-epoch snapshot, including the full resume state.
        if (epoch + 1) % cfg.save_model_epoch_interval == 0:
            try:
                start_time = time.time()
                if accelerator.is_main_process:
                    save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
                    # Define the state path locally: `save_path` from the
                    # step-level checkpoint block may never have been set if
                    # no checkpointing step has run yet.
                    save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                elapsed_time = time.time() - start_time
                if elapsed_time > 120:
                    print(f"Skipping storage as it took too long in step {global_step}.")
                else:
                    print(f"Model saved successfully in {elapsed_time}s.")
            except Exception as e:
                print(f"Error when saving model in step {global_step}:", e)
        accelerator.wait_for_everyone()

    accelerator.end_training()
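

# A typical launch, assuming an Accelerate environment is configured (the
# script filename here is illustrative):
#   accelerate launch train.py --config ./configs/training/stage2.yaml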

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)