import argparse
import gc
import math
import os
import shutil
from argparse import ArgumentParser
from datetime import datetime

import einops
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from safetensors.torch import load_model
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# The dataset and model classes below are resolved by name via eval(...) at
# runtime (see --dataset_type, --amd_model_type and the A2M config), so they
# must stay imported even when not referenced directly in this file.
from dataset.dataset import (A2MVideoAudio,
                             A2MVideoAudioPose,
                             A2MVideoAudioPoseRandomRef,
                             A2MVideoAudioPoseMultiSample,
                             A2MVideoAudioPoseRandomRefMultiSample,
                             A2MVideoAudioPoseMultiSampleMultiRef,
                             A2MVideoAudioPoseMultiSampleMultiRefBalance,
                             A2MVideoAudioMultiRefDoubleRef)
from model import set_vis_atten_flag
from model.model_A2M import (A2MModel_MotionrefOnly_LearnableToken,
                             A2MModel_CrossAtten_Audio,
                             A2MModel_CrossAtten_Pose,
                             A2MModel_CrossAtten_Audio_Pose,
                             A2MModel_CrossAtten_Audio_PosePre,
                             A2MModel_CrossAtten_Audio_DoubleRef)
from model.model_AMD import AMDModel, AMDModel_Rec
from model.utils import (save_cfg, vae_encode, _freeze_parameters, vae_decode,
                         save_videos_grid, model_load_pretrain)

set_vis_atten_flag(False)


now = datetime.now()
current_time = f'{now.year}-{now.month:02d}-{now.day:02d}-{now.hour:02d}:{now.minute:02d}'

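
# NOTE(assumption): `sample_timestep` is used below (when --use_sample_timestep
# is set) but is neither defined nor imported in this file. This is a minimal
# sketch assuming a logit-normal timestep sampler (location m, scale s in logit
# space, rescaled to [0, num_steps]); replace it with the project's own helper
# if that differs.
def sample_timestep(batch_size: int, m: float, s: float, num_steps: int) -> np.ndarray:
    u = np.random.randn(batch_size) * s + m   # Gaussian in logit space
    t = 1.0 / (1.0 + np.exp(-u))              # sigmoid maps to (0, 1)
    return t * num_steps                      # rescale to [0, num_steps]
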

def get_cfg():
    parser = ArgumentParser()

    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    # ---- dataset ----
    parser.add_argument('--trainset', type=str, default='/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/dataset/path/train_video_with_audio.pkl', help='trainset index file path')
    parser.add_argument('--evalset', type=str, default='/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/dataset/path/eval_video_with_audio.pkl', help='evalset index file path')

    parser.add_argument('--sample_size', type=str, default="(256,256)", help='sample size as a tuple, e.g. (256,256)')
    parser.add_argument('--sample_stride', type=int, default=1, help='data sample stride')
    parser.add_argument('--sample_n_frames', type=int, default=31, help='number of frames per sample')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size used in training')
    parser.add_argument('--path_type', type=str, default='file', choices=['file', 'dir'], help='path type of the dataset')
    parser.add_argument('--dataset_type', type=str, default='A2MVideoAudioPose')
    parser.add_argument('--max_ref_frame', type=int, default=8)
    parser.add_argument('--num_sample', type=int, default=4)
    parser.add_argument('--random_ref_num', type=str2bool, default=False)

    # ---- experiment / logging ----
    parser.add_argument('--exp_root', default='/mnt/pfs-mc0p4k/cvg/team/didonglin/zqy/exp', required=True, help='root directory for experiments')
    parser.add_argument('--name', default=f'{current_time}', required=True, help='name of the experiment to load')
    parser.add_argument('--log_with', default='tensorboard', choices=['tensorboard', 'wandb'], help='accelerate tracker backend')
    parser.add_argument('--seed', type=int, default=None, help='a seed for reproducible training')

    # ---- training schedule ----
    parser.add_argument('--mp', type=str, default='fp16', choices=['fp16', 'bf16', 'no'], help='mixed precision mode')
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--max_train_epoch', type=int, default=200000000, help='maximum number of training epochs')
    parser.add_argument('--max_train_steps', type=int, default=100000, help='maximum number of training steps')

    # ---- optimization ----
    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay for the optimizer')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='number of steps for gradient accumulation')
    parser.add_argument('--lr_warmup_steps', type=int, default=20, help='number of lr warmup steps')
    parser.add_argument('--max_grad_norm', default=1.0, type=float, help='max gradient norm for clipping')
    parser.add_argument('--eval_interval_step', type=int, default=1000, help='run validation every N steps')
    parser.add_argument('--checkpoint_total_limit', type=int, default=3, help='maximum number of checkpoints to keep')
    parser.add_argument('--save_checkpoint_interval_step', type=int, default=100, help='save a checkpoint every N steps')
    parser.add_argument('--lr_scheduler', type=str, default='constant', help='the scheduler type to use, chosen from ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='input checkpoint path')
    parser.add_argument('--motion_sample_step', type=int, default=4, help='number of sampling steps for motion prediction at validation')
    parser.add_argument('--video_sample_step', type=int, default=4, help='number of sampling steps for video decoding at validation')
    parser.add_argument('--a2m_from_pretrained', type=str, default=None)
    parser.add_argument('--need_amd_loss', type=str2bool, default=False)
    parser.add_argument('--motion_mask_ratio', type=float, default=0.0)

    # ---- models ----
    parser.add_argument('--vae_version', type=str, default='/mnt/pfs-mc0p4k/cvg/team/didonglin/zqy/model-checkpoints/Huggingface-Model/sd-vae-ft-mse')
    parser.add_argument('--amd_model_type', type=str, default='AMDModel', help='AMDModel or AMDModel_Rec')
    parser.add_argument('--amd_config', type=str, default="/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/exp/amd-m-mae-s-1026-linear-final/config.json", help='amd model config path')
    parser.add_argument('--amd_ckpt', type=str, default="/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/ckpt/checkpoint-157000/model_1.safetensors", help='amd model checkpoint path')
    parser.add_argument('--a2m_config', type=str, default="/mnt/pfs-mc0p4k/tts/team/digital_avatar_group/sunwenzhang/qiyuan/code/AMD_linear/config/Audio2Motion.yaml")
    parser.add_argument('--use_sample_timestep', action='store_true')
    parser.add_argument('--sample_timestep_m', type=float, default=0.5)
    parser.add_argument('--sample_timestep_s', type=float, default=1.0)

    parser.add_argument('--model_type', type=str, default='type1', help='model type: type1 or type2')

    args = parser.parse_args()
    return args

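# Note: --dataset_type, --amd_model_type and the A2M config's model_type are
# resolved with eval(), so each value must exactly match one of the class names
# imported at the top of this file.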

def main():
    args = get_cfg()

    proj_dir = os.path.join(args.exp_root, args.name)
    video_save_dir = os.path.join(proj_dir, 'sample')

    if args.seed is not None:
        set_seed(args.seed)

    project_config = ProjectConfiguration(project_dir=proj_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mp,
        log_with=args.log_with,
        project_config=project_config,
    )

    if accelerator.is_main_process:
        save_cfg(proj_dir, args)
        # Validation writes sample videos here; create the directory up front.
        os.makedirs(video_save_dir, exist_ok=True)

    device = accelerator.device

    # Frozen AMD model: provides motion latents and (optionally) video decoding.
    # The class is resolved by name from --amd_model_type.
    amd_model_cls = eval(args.amd_model_type)
    amd_model = amd_model_cls.from_config(amd_model_cls.load_config(args.amd_config)).to(device).requires_grad_(False)
    load_model(amd_model, args.amd_ckpt)

    if not args.need_amd_loss:
        # The diffusion transformer is only needed for the AMD loss and for
        # validation sampling; keep it on CPU otherwise to save GPU memory.
        amd_model.diffusion_transformer.to(torch.device('cpu'))

    _freeze_parameters(amd_model)
    vae = AutoencoderKL.from_pretrained(args.vae_version, subfolder="vae").to(device).requires_grad_(False)

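    # Build the train/eval datasets from the pickled index files; the concrete
    # dataset class is selected by name via --dataset_type.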
    train_dataset = eval(args.dataset_type)(
        video_dir=args.trainset,
        sample_size=eval(args.sample_size),
        sample_stride=args.sample_stride,
        sample_n_frames=args.sample_n_frames,
        num_sample=args.num_sample,
        max_ref_frame=args.max_ref_frame,
        random_ref_num=args.random_ref_num,
    )
    eval_dataset = eval(args.dataset_type)(
        video_dir=args.evalset,
        sample_size=eval(args.sample_size),
        sample_stride=args.sample_stride,
        sample_n_frames=args.sample_n_frames,
        num_sample=args.num_sample,
        max_ref_frame=args.max_ref_frame,
        random_ref_num=False,
    )

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, collate_fn=train_dataset.collate_fn, pin_memory=True)
    # shuffle=True here means validation (which takes a single batch) sees a random eval batch.
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, collate_fn=eval_dataset.collate_fn, pin_memory=True)

    a2m_config = OmegaConf.load(args.a2m_config)
    audio_decoder = eval(a2m_config['model_type'])(**a2m_config['model'])
    if accelerator.is_main_process:
        audio_decoder.save_config(proj_dir)
    if args.a2m_from_pretrained is not None:
        # not_load_keyword is a sentinel that matches nothing, so every weight
        # in the checkpoint is loaded (non-strictly).
        model_load_pretrain(audio_decoder, args.a2m_from_pretrained, not_load_keyword='abcabcacbd', strict=False)
        if accelerator.is_main_process:
            print(f'######### load A2M weight from {args.a2m_from_pretrained} #############')

    # Apply the weight decay from the CLI args to the optimizer.
    optimizer = torch.optim.AdamW(audio_decoder.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    for i, (name, _) in enumerate(audio_decoder.named_parameters()):
        accelerator.print(f"{i}: {name}")

    audio_decoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        audio_decoder, optimizer, train_dataloader, lr_scheduler
    )

    if accelerator.is_main_process:
        accelerator.init_trackers('tracker')

    total_batch_size = args.batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print(f"{accelerator.state}")
    accelerator.print("***** Running training *****")
    accelerator.print(f"  Num examples = {len(train_dataset)}")
    accelerator.print(f"  Num Epochs = {args.max_train_epoch}")
    accelerator.print(f"  Instantaneous batch size per device = {args.batch_size}")
    accelerator.print(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    accelerator.print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    global_step = 0
    first_epoch = 0

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.resume_from_checkpoint is not None:
        model_path = args.resume_from_checkpoint
        accelerator.print(f"Resuming from checkpoint {model_path}")
        accelerator.load_state(model_path)
        # Checkpoint directories are named "checkpoint-<global_step>".
        global_step = int(os.path.basename(model_path).split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    # Steps at which validation has already run (avoids duplicate runs when the
    # same global_step is seen across accumulation micro-steps).
    global_validation_step = []

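    # Validation: take one eval batch, compute the A2M loss at a fixed timestep,
    # sample motion from audio, decode videos from both the ground-truth and
    # the predicted motion, and log images/videos to the tracker and to disk.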
    @torch.inference_mode()
    def log_validation(audio_decoder, amd_model, vae, eval_dataloader, device, accelerator=None, global_step=0):
        accelerator.print("Running validation....\n")

        if accelerator is not None:
            audio_decoder = accelerator.unwrap_model(audio_decoder)
        audio_decoder.eval()
        # Bring the (possibly CPU-offloaded) decoder back for sampling.
        amd_model.diffusion_transformer.to(device)

        data = next(iter(eval_dataloader))

        ref_video = data["ref_video"].to(device)
        gt_video = data["gt_video"].to(device)
        ref_audio = data["ref_audio"].to(device)
        gt_audio = data["gt_audio"].to(device)
        randomref_video = data["randomref_video"].to(device) if "randomref_video" in data.keys() else None
        ref_pose = data["ref_pose"].to(device) if "ref_pose" in data.keys() else None
        gt_pose = data["gt_pose"].to(device) if "gt_pose" in data.keys() else None
        mask = data["mask"].to(device)

        ref_video_z = vae_encode(vae, ref_video)
        gt_video_z = vae_encode(vae, gt_video)
        randomref_video_z = vae_encode(vae, randomref_video) if randomref_video is not None else None
        ref_pose_z = vae_encode(vae, ref_pose) if ref_pose is not None else None
        gt_pose_z = vae_encode(vae, gt_pose) if gt_pose is not None else None

        with torch.no_grad():
            ref_motion = amd_model.extract_motion(ref_video_z, mask_ratio=args.motion_mask_ratio)
            gt_motion = amd_model.extract_motion(gt_video_z)
            if randomref_video_z is not None:
                randomref_motion = amd_model.extract_motion(randomref_video_z)
            else:
                randomref_motion = None

        accelerator.print(f"ref_motion shape : {ref_motion.shape}")
        if randomref_video_z is not None:
            accelerator.print(f"randomref_video_z shape : {randomref_video_z.shape}")

        if args.use_sample_timestep:
            timestep = torch.from_numpy(sample_timestep(gt_motion.shape[0], args.sample_timestep_m, args.sample_timestep_s, 1000)).to(device, ref_video.dtype)
        else:
            # Validation evaluates the loss at the final timestep.
            timestep = torch.ones(gt_motion.shape[0]).to(device, gt_motion.dtype) * 1000

        gt_audio = gt_audio.to(gt_motion.dtype)

        loss_dict, _ = audio_decoder(motion_gt=gt_motion,
                                     ref_motion=ref_motion,
                                     randomref_motion=randomref_motion,
                                     audio=gt_audio,
                                     ref_audio=ref_audio,
                                     pose=gt_pose_z,
                                     ref_pose=ref_pose_z,
                                     timestep=timestep)

        val_loss = loss_dict['loss'].item()
        accelerator.print(f'val loss = {val_loss}')
        accelerator.log({"val_loss": val_loss}, step=global_step)

        # Sample motion from audio, then decode videos from both the ground-truth
        # motion and the predicted motion for side-by-side comparison.
        motion_pre = audio_decoder.sample(ref_motion=ref_motion,
                                          randomref_motion=randomref_motion,
                                          audio=gt_audio,
                                          ref_audio=ref_audio,
                                          pose=gt_pose_z,
                                          ref_pose=ref_pose_z,
                                          sample_step=args.motion_sample_step)
        ref_img = ref_video_z[:, -1, :]
        _, video_pre_motion_gt, _ = amd_model.sample_with_refimg_motion(ref_img,
                                                                        gt_motion,
                                                                        ref_img,
                                                                        sample_step=args.video_sample_step)
        _, video_pre_motion_pre, _ = amd_model.sample_with_refimg_motion(ref_img,
                                                                         motion_pre,
                                                                         ref_img,
                                                                         sample_step=args.video_sample_step)
        video_gt = gt_video_z

        assert video_gt.shape == video_pre_motion_gt.shape, f'video_gt shape :{video_gt.shape} , video_pre_motion_gt shape:{video_pre_motion_gt.shape}'
        assert video_gt.shape == video_pre_motion_pre.shape, f'video_gt shape :{video_gt.shape} , video_pre_motion_pre shape:{video_pre_motion_pre.shape}'

        def transform(x: torch.Tensor):
            # Decode latents and convert to uint8 numpy frames in [0, 255].
            x = vae_decode(vae, x)
            x = ((x / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous().numpy()
            return x

        video_pre_motion_gt_np = transform(video_pre_motion_gt)
        video_pre_motion_pre_np = transform(video_pre_motion_pre)
        video_gt_np = transform(video_gt)

        def log_transform(x, log_b: int, log_f: int):
            # Keep the first log_b clips and log_f frames, flatten to NHWC images.
            x = x[:log_b, :log_f, :]
            x = einops.rearrange(x, 'n t c h w -> (n t) h w c')
            np_x = np.stack([np.asarray(img) for img in x])
            return np_x

        log_b = 4
        log_f = 8

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                video_gt_out = log_transform(video_gt_np, log_b, log_f)
                video_pre_motion_gt_out = log_transform(video_pre_motion_gt_np, log_b, log_f)
                video_pre_motion_pre_out = log_transform(video_pre_motion_pre_np, log_b, log_f)

                tracker.writer.add_images("video_gt", video_gt_out, global_step, dataformats="NHWC")
                tracker.writer.add_images("video_pre_motion_gt", video_pre_motion_gt_out, global_step, dataformats="NHWC")
                tracker.writer.add_images("video_pre_motion_pre", video_pre_motion_pre_out, global_step, dataformats="NHWC")

                # add_video expects uint8 arrays shaped (N, T, C, H, W).
                gt_videos = np.stack([np.asarray(vid) for vid in video_gt_np])
                tracker.writer.add_video("sample_gt_videos", gt_videos, global_step, fps=8)

                videos_gt_motion = np.stack([np.asarray(vid) for vid in video_pre_motion_gt_np])
                tracker.writer.add_video("sample_videos_gt_motion", videos_gt_motion, global_step, fps=8)

                videos_pre_motion = np.stack([np.asarray(vid) for vid in video_pre_motion_pre_np])
                tracker.writer.add_video("sample_videos_pre_motion", videos_pre_motion, global_step, fps=8)

        def save_mp4(latent, suffix='pre'):
            cur_save_path = os.path.join(video_save_dir, f'{global_step}-s{args.motion_sample_step}s{args.video_sample_step}-{suffix}.mp4')
            video = vae_decode(vae, latent)
            video = einops.rearrange(video.cpu(), 'n t c h w -> n c t h w')
            save_videos_grid(video, cur_save_path, rescale=True)

        save_mp4(video_pre_motion_pre, 'motionpre')
        save_mp4(video_pre_motion_gt, 'motiongt')
        save_mp4(video_gt, 'gt')

        # Keep only the most recent sample videos on disk.
        video_limit = 9
        if accelerator.is_main_process:
            videofiles = os.listdir(video_save_dir)
            videofiles = [d for d in videofiles if '.mp4' in d]
            videofiles = sorted(videofiles, key=lambda x: int(x.split("-")[0]))

            if len(videofiles) > video_limit:
                num_to_remove = len(videofiles) - video_limit
                removing_videofiles = videofiles[0:num_to_remove]
                accelerator.print(f"removing videofiles: {', '.join(removing_videofiles)}")

                for removing_videofile in removing_videofiles:
                    removing_videofile = os.path.join(video_save_dir, removing_videofile)
                    os.remove(removing_videofile)

        # Offload the decoder again and release cached memory.
        if not args.need_amd_loss:
            amd_model.diffusion_transformer.to(torch.device('cpu'))

        gc.collect()
        torch.cuda.empty_cache()

    # Run one validation pass before training starts.
    if accelerator.is_main_process:
        log_validation(audio_decoder,
                       amd_model,
                       vae,
                       eval_dataloader,
                       device,
                       accelerator,
                       global_step)

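    # Training loop: encode batches to VAE latents, extract motion with the
    # frozen AMD model, and train the audio decoder to predict that motion,
    # optionally with an additional AMD reconstruction loss.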
    for epoch in range(first_epoch, args.max_train_epoch):
        accelerator.print(f"Epoch {epoch} start!!")
        if global_step >= args.max_train_steps:
            break

        for step, data in enumerate(train_dataloader):
            if global_step >= args.max_train_steps:
                break
            audio_decoder.train()
            with accelerator.accumulate(audio_decoder):
                # Tensors arrive on the right device via accelerator.prepare(dataloader).
                ref_video = data["ref_video"]
                gt_video = data["gt_video"]
                ref_audio = data["ref_audio"]
                gt_audio = data["gt_audio"]
                randomref_video = data["randomref_video"] if "randomref_video" in data.keys() else None
                ref_pose = data["ref_pose"] if "ref_pose" in data.keys() else None
                gt_pose = data["gt_pose"] if "gt_pose" in data.keys() else None
                mask = data["mask"]

                ref_video_z = vae_encode(vae, ref_video)
                gt_video_z = vae_encode(vae, gt_video)
                randomref_video_z = vae_encode(vae, randomref_video) if randomref_video is not None else None
                ref_pose_z = vae_encode(vae, ref_pose) if ref_pose is not None else None
                gt_pose_z = vae_encode(vae, gt_pose) if gt_pose is not None else None

                with torch.no_grad():
                    ref_motion = amd_model.extract_motion(ref_video_z, mask_ratio=args.motion_mask_ratio)
                    gt_motion = amd_model.extract_motion(gt_video_z)
                    if randomref_video_z is not None:
                        randomref_motion = amd_model.extract_motion(randomref_video_z)
                    else:
                        randomref_motion = None

                if args.use_sample_timestep:
                    timestep = torch.from_numpy(sample_timestep(ref_motion.shape[0], args.sample_timestep_m, args.sample_timestep_s, 1000)).to(device, ref_motion.dtype)
                else:
                    # torch.randint's upper bound is exclusive, so this draws
                    # timesteps uniformly from [0, 1000].
                    timestep = torch.randint(0, 1000 + 1, (ref_motion.shape[0],)).to(device, ref_motion.dtype)

                loss_dict, motion_pre_ode = audio_decoder(motion_gt=gt_motion,
                                                          ref_motion=ref_motion,
                                                          randomref_motion=randomref_motion,
                                                          audio=gt_audio,
                                                          ref_audio=ref_audio,
                                                          pose=gt_pose_z,
                                                          ref_pose=ref_pose_z,
                                                          timestep=timestep)

                if args.need_amd_loss:
                    # Also supervise the AMD video decoder on the predicted motion.
                    amd_loss = amd_model.forward_with_refimg_motion(video=gt_video_z,
                                                                    ref_img=ref_video_z[:, -1:, :],
                                                                    motion=motion_pre_ode)
                    loss_dict['amd_loss'] = amd_loss

                loss = (loss_dict['loss'] + loss_dict['amd_loss']) if args.need_amd_loss else loss_dict['loss']

                if accelerator.sync_gradients:
                    global_step += 1

                    # Average each loss term across processes for logging.
                    loss_cache = {}
                    for key in loss_dict.keys():
                        avg_loss = accelerator.gather(loss_dict[key].repeat(args.batch_size)).mean()
                        loss_cache[key] = avg_loss.item()

                    logs = {'loss': loss_cache['loss']}
                    progress_bar.set_postfix(**logs)
                    progress_bar.update(1)

                    txt = ''.join([f"{key:<10} {value:<10.6f}" for key, value in loss_cache.items()])
                    txt = f'Step {global_step:<5} :' + txt
                    accelerator.print(txt)

                    for key, val in loss_cache.items():
                        accelerator.log({key: val}, step=global_step)

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    params_to_clip = audio_decoder.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checkpointing (outside the accumulate context, keyed on global_step).
            if global_step % args.save_checkpoint_interval_step == 0:
                checkpoint_dir = os.path.join(proj_dir, "checkpoints")
                save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

                # Enforce the checkpoint retention limit.
                if accelerator.is_main_process and args.checkpoint_total_limit is not None:
                    checkpoints = os.listdir(checkpoint_dir)
                    checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                    if len(checkpoints) > args.checkpoint_total_limit:
                        num_to_remove = len(checkpoints) - args.checkpoint_total_limit
                        removing_checkpoints = checkpoints[0:num_to_remove]
                        accelerator.print(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                        for removing_checkpoint in removing_checkpoints:
                            removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint)
                            shutil.rmtree(removing_checkpoint)

            if global_step % args.eval_interval_step == 0 and accelerator.is_main_process:
                # Validate at most once per global_step (the same step can recur
                # across gradient-accumulation micro-steps).
                if global_step not in global_validation_step:
                    global_validation_step.append(global_step)
                    log_validation(audio_decoder, amd_model, vae, eval_dataloader,
                                   device, accelerator, global_step)

    accelerator.wait_for_everyone()
    accelerator.end_training()

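# Example launch (assumption: a standard `accelerate launch` invocation and a
# hypothetical script name; adjust paths, dataset type and flags as needed):
#   accelerate launch train_a2m.py \
#       --exp_root /path/to/exp --name my_run \
#       --dataset_type A2MVideoAudioPose --need_amd_loss false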

if __name__ == "__main__":
    main()