# lipsync-docker / scripts / train_unet.py
# naicoi's picture
# model-dirs (#2)
# f5651ba
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import math
import argparse
import shutil
import datetime
import logging
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import MODELS_DIR
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from accelerate.utils import set_seed
from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.stable_syncnet import StableSyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
init_dist,
cosine_loss,
one_step_sampling,
)
from latentsync.utils.util import plot_loss_chart
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa.loss import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips
logger = get_logger(__name__)
def main(config):
    """Train the lip-sync UNet under torch.distributed (one process per device).

    The function builds every component it needs (VAE, audio encoder, UNet,
    optional SyncNet / LPIPS / TREPA losses, optimizer, LR scheduler, data
    loader), runs the training loop, and on the main process periodically
    saves a checkpoint, renders a validation video and plots its sync
    confidence.

    Args:
        config: OmegaConf configuration loaded from the UNet YAML file. The
            ``run``, ``data``, ``model``, ``optimizer`` and ``ckpt`` sections
            are read, as is ``unet_config_path``. ``config.optimizer.lr`` and
            ``config.run.max_train_steps`` may be overwritten in place.

    Raises:
        NotImplementedError: if ``model.cross_attention_dim`` is neither 768
            nor 384, or the scheduler uses ``v_prediction``.
        ValueError: if the SyncNet checkpoint path is empty, the scheduler's
            prediction type is unknown, or both ``run.max_train_steps`` and
            ``run.max_train_epochs`` are -1.
    """
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    # Offset the seed by the rank so that each process gets a different seed.
    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime("-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation (main process only) and keep a copy of
    # the configs next to the run's artifacts.
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/sync_conf_results", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    # Frozen fp16 VAE used to move between pixel space and latent space.
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16, cache_dir=MODELS_DIR
    )
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)

    if config.run.pixel_space_supervise:
        vae.enable_gradient_checkpointing()

    # Models used only to score the validation videos.
    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(
        device=device, detect_results_dir="detect_results"
    )

    # The Whisper variant is selected from the UNet's cross-attention width.
    if config.model.cross_attention_dim == 768:
        whisper_model_path = "small"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "tiny"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,
        device=device,
    )

    # Optional frozen fp16 SyncNet that supplies the sync loss.
    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = StableSyncNet(
            OmegaConf.to_container(syncnet_config.model), gradient_checkpointing=True
        ).to(device=device, dtype=torch.float16)
        syncnet_checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path,
            map_location=device,
            weights_only=True,
        )
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

        # Drop the loaded checkpoint dict before the training allocations start.
        del syncnet_checkpoint
        torch.cuda.empty_cache()

    # Select the trainable parameters: with the motion module enabled, only
    # parameters whose name contains one of `run.trainable_modules` are trained.
    if config.model.use_motion_module:
        unet.requires_grad_(False)
        for name, param in unet.named_parameters():
            for trainable_module_name in config.run.trainable_modules:
                if trainable_module_name in name:
                    param.requires_grad = True
                    break
        trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    else:
        unet.requires_grad_(True)
        trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(
            f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M"
        )

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    # drop_last=True keeps every batch at exactly `config.data.batch_size`,
    # which the SyncNet target tensor in the training loop relies on.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # Get the training iteration
    if config.run.max_train_steps == -1:
        # Raise instead of assert: asserts are stripped under `python -O`, which
        # would silently produce a negative step count here.
        if config.run.max_train_epochs == -1:
            raise ValueError(
                "Either run.max_train_steps or run.max_train_epochs must be set (not -1)"
            )
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    # Pixel-space losses are only built when they can contribute to the loss.
    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device, with_cp=True)

    # Validation pipeline (receives the UNet before it is wrapped in DDP)
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP wrapper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        config.run.max_train_steps / num_update_steps_per_epoch
    )

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(
            f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
        )
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")
    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # The progress bar is shown on the main process only.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    # Paired series for the sync-confidence chart; they are appended together
    # so that they always have the same length.
    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = (
        torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None
    )

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                # NOTE(review): `mel` stays unbound when the batch carries an
                # empty `mel`, yet the SyncNet branch below reads it; confirm
                # the dataset always provides mel when `run.use_syncnet` is on.
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                # Pre-bind so the handler below can always format its message.
                video_path = None
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(
                            audio_feat, start_idx
                        )
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    # Best effort: a sample whose audio cannot be processed
                    # causes the whole batch to be skipped.
                    # NOTE(review): only this rank skips, so ranks may run a
                    # different number of backward passes per epoch; confirm
                    # this cannot stall DDP gradient synchronization.
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue
                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_pixel_values = batch["gt_pixel_values"].to(device, dtype=torch.float16)
            masked_pixel_values = batch["masked_pixel_values"].to(
                device, dtype=torch.float16
            )
            masks = batch["masks"].to(device, dtype=torch.float16)
            ref_pixel_values = batch["ref_pixel_values"].to(device, dtype=torch.float16)

            # Fold the frame axis into the batch axis for the per-image VAE.
            gt_pixel_values = rearrange(gt_pixel_values, "b f c h w -> (b f) c h w")
            masked_pixel_values = rearrange(
                masked_pixel_values, "b f c h w -> (b f) c h w"
            )
            masks = rearrange(masks, "b f c h w -> (b f) c h w")
            ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_pixel_values).latent_dist.sample()
                masked_latents = vae.encode(masked_pixel_values).latent_dist.sample()
                ref_latents = vae.encode(ref_pixel_values).latent_dist.sample()

            # Resize the masks to the latent resolution.
            masks = torch.nn.functional.interpolate(
                masks, size=config.data.resolution // vae_scale_factor
            )

            # Unfold the frame axis again and apply the VAE shift / scale.
            gt_latents = (
                rearrange(
                    gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames
                )
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            masked_latents = (
                rearrange(
                    masked_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames
                )
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_latents = (
                rearrange(
                    ref_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames
                )
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            masks = rearrange(
                masks, "(b f) c h w -> b c f h w", f=config.data.num_frames
            )

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (
                    config.run.mixed_noise_alpha**2
                    / (1 + config.run.mixed_noise_alpha**2)
                ) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                # The shared component is the first frame's noise repeated over all frames.
                noise_shared = noise_shared[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
                noise = noise[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=gt_latents.device,
            )
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_gt_latents = noise_scheduler.add_noise(gt_latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                )

            # The UNet input is the channel-wise concatenation of all conditions.
            unet_input = torch.cat(
                [noisy_gt_latents, masks, masked_latents, ref_latents], dim=1
            )

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(
                device_type="cuda",
                dtype=torch.float16,
                enabled=config.run.mixed_precision_training,
            ):
                pred_noise = unet(
                    unet_input, timesteps, encoder_hidden_states=audio_embeds
                ).sample

            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(
                    pred_noise.float(), target.float(), reduction="mean"
                )
            else:
                recon_loss = 0

            pred_latents = one_step_sampling(
                noise_scheduler, pred_noise, timesteps, noisy_gt_latents
            )

            # Decode the prediction back to pixels for the pixel-space losses.
            if config.run.pixel_space_supervise:
                pred_pixel_values = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w")
                    / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if (
                config.run.perceptual_loss_weight != 0
                and config.run.pixel_space_supervise
            ):
                # LPIPS is computed on the lower half of each frame only.
                pred_pixel_values_perceptual = pred_pixel_values[
                    :, :, pred_pixel_values.shape[2] // 2 :, :
                ]
                gt_pixel_values_perceptual = gt_pixel_values[
                    :, :, gt_pixel_values.shape[2] // 2 :, :
                ]
                lpips_loss = lpips_loss_func(
                    pred_pixel_values_perceptual.float(),
                    gt_pixel_values_perceptual.float(),
                ).mean()
            else:
                lpips_loss = 0

            if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                trepa_pred_pixel_values = rearrange(
                    pred_pixel_values,
                    "(b f) c h w -> b c f h w",
                    f=config.data.num_frames,
                )
                trepa_gt_pixel_values = rearrange(
                    gt_pixel_values,
                    "(b f) c h w -> b c f h w",
                    f=config.data.num_frames,
                )
                trepa_loss = trepa_loss_func(
                    trepa_pred_pixel_values, trepa_gt_pixel_values
                )
            else:
                trepa_loss = 0

            if config.model.add_audio_layer and config.run.use_syncnet:
                if config.run.pixel_space_supervise:
                    # Match the resolution the SyncNet was configured for.
                    if config.data.resolution != syncnet_config.data.resolution:
                        pred_pixel_values = F.interpolate(
                            pred_pixel_values,
                            size=(
                                syncnet_config.data.resolution,
                                syncnet_config.data.resolution,
                            ),
                            mode="bicubic",
                        )
                    syncnet_input = rearrange(
                        pred_pixel_values,
                        "(b f) c h w -> b (f c) h w",
                        f=config.data.num_frames,
                    )
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                # All-ones target passed to cosine_loss for every sample in the batch.
                ones_tensor = (
                    torch.ones((config.data.batch_size, 1)).float().to(device=device)
                )
                vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(
                    vision_embeds.float(), audio_embeds.float(), ones_tensor
                ).mean()
            else:
                sync_loss = 0

            # Weighted sum of all enabled loss terms.
            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                # >>> gradient clipping >>>
                # Gradients are unscaled first so the clip threshold applies to
                # their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    trainable_params, config.optimizer.max_grad_norm
                )
                # <<< gradient clipping <<<
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                # >>> gradient clipping >>>
                torch.nn.utils.clip_grad_norm_(
                    trainable_params, config.optimizer.max_grad_norm
                )
                # <<< gradient clipping <<<
                optimizer.step()

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].attn2.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation (main process only)
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                model_save_path = os.path.join(
                    output_dir, f"checkpoints/checkpoint-{global_step}.pt"
                )
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),
                }
                # A failed save is logged but does not abort training.
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(
                    output_dir, f"val_videos/val_video_{global_step}.mp4"
                )

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask_image_path=config.data.mask_image_path,
                    )

                logger.info(
                    f"Saved validation video output to {validation_video_out_path}"
                )

                if config.model.add_audio_layer and os.path.exists(
                    validation_video_out_path
                ):
                    # A failed evaluation is recorded as a confidence of 0.
                    try:
                        _, conf = syncnet_eval(
                            syncnet_eval_model,
                            syncnet_detector,
                            validation_video_out_path,
                            "temp",
                        )
                    except Exception as e:
                        logger.info(e)
                        conf = 0
                    # Record the step only together with its confidence: if a
                    # validation video is missing, the x and y series handed to
                    # the chart must not drift out of alignment.
                    val_step_list.append(global_step)
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(
                            output_dir,
                            f"sync_conf_results/sync_conf_chart-{global_step}.png",
                        ),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "epoch": epoch}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()
if __name__ == "__main__":
    # Command line: a single option pointing at the UNet training config.
    cli = argparse.ArgumentParser()
    cli.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
    unet_config_path = cli.parse_args().unet_config_path

    config = OmegaConf.load(unet_config_path)
    # Store the path on the config itself; main() reads config.unet_config_path.
    config.unet_config_path = unet_config_path

    main(config)