| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |

import argparse
import copy
import json
import logging
import os
import sys
import warnings
from typing import Any, Dict

import model.vqvae as vqvae
import numpy as np
import torch
import torch.optim as optim
from data_loaders.get_data import get_dataset_loader, load_local_data
from diffusion.nn import sum_flat
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.vq_parser_utils import train_args

warnings.filterwarnings("ignore")


def cycle(iterable):
    while True:
        for x in iterable:
            yield x


def get_logger(out_dir: str):
    logger = logging.getLogger("Exp")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    file_path = os.path.join(out_dir, "run.log")
    file_hdlr = logging.FileHandler(file_path)
    file_hdlr.setFormatter(formatter)

    strm_hdlr = logging.StreamHandler(sys.stdout)
    strm_hdlr.setFormatter(formatter)

    logger.addHandler(file_hdlr)
    logger.addHandler(strm_hdlr)
    return logger
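

# ModelTrainer bundles the VQ-VAE codec, its AdamW optimizer, and a MultiStepLR
# schedule, and tracks the best commitment/reconstruction/perplexity values seen
# during validation so improved checkpoints can be written to out_dir.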
class ModelTrainer:
    def __init__(self, args, net: vqvae.TemporalVertexCodec, logger, writer):
        self.net = net
        self.warm_up_iter = args.warm_up_iter
        self.print_iter = args.print_iter
        self.lr = args.lr
        self.optimizer = optim.AdamW(
            self.net.parameters(),
            lr=args.lr,
            betas=(0.9, 0.99),
            weight_decay=args.weight_decay,
        )
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=args.lr_scheduler, gamma=args.gamma
        )
        self.data_format = args.data_format
        self.loss = torch.nn.SmoothL1Loss()
        self.loss_vel = args.loss_vel
        self.commit = args.commit
        self.logger = logger
        self.writer = writer
        self.best_commit = float("inf")
        self.best_recons = float("inf")
        self.best_perplexity = float("inf")
        self.best_iter = 0
        self.out_dir = args.out_dir

    def _masked_l2(self, a, b, mask):
        loss = self._l2_loss(a, b)
        loss = sum_flat(loss * mask.float())
        n_entries = a.shape[1] * a.shape[2]
        non_zero_elements = sum_flat(mask) * n_entries
        mse_loss_val = loss / non_zero_elements
        return mse_loss_val

    def _l2_loss(self, motion_pred, motion_gt, mask=None):
        if mask is not None:
            return self._masked_l2(motion_pred, motion_gt, mask)
        else:
            return self.loss(motion_pred, motion_gt)

    def _vel_loss(self, motion_pred, motion_gt):
        model_results_vel = motion_pred[..., :-1] - motion_pred[..., 1:]
        model_targets_vel = motion_gt[..., :-1] - motion_gt[..., 1:]
        return self.loss(model_results_vel, model_targets_vel)

    def _update_lr_warm_up(self, nb_iter):
        current_lr = self.lr * (nb_iter + 1) / (self.warm_up_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = current_lr
        return current_lr
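
    # Linear learning-rate warmup: runs warm_up_iter optimizer steps while ramping
    # the learning rate toward its target, logging running averages of the
    # reconstruction, commitment, and perplexity terms every print_iter steps.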
    def run_warmup_steps(self, train_loader_iter, skip_step, logger):
        avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0
        for nb_iter in tqdm(range(1, self.warm_up_iter)):
            current_lr = self._update_lr_warm_up(nb_iter)
            gt_motion, cond = next(train_loader_iter)
            loss_dict = self.run_train_step(gt_motion, cond, skip_step)
            avg_recons += loss_dict["loss_motion"]
            avg_perplexity += loss_dict["perplexity"]
            avg_commit += loss_dict["loss_commit"]
            if nb_iter % self.print_iter == 0:
                avg_recons /= self.print_iter
                avg_perplexity /= self.print_iter
                avg_commit /= self.print_iter
                logger.info(
                    f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}"
                )
                avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0
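
    # Single optimization step: reshapes the batch to (batch, seq, dim), subsamples
    # the sequence by skip_step, runs the codec, and combines the reconstruction,
    # commitment, and (optional) velocity losses before backpropagating.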
    def run_train_step(
        self, gt_motion: torch.Tensor, cond: torch.Tensor, skip_step: int
    ) -> Dict[str, Any]:
        self.net.train()
        loss_dict = {}

        # run model
        gt_motion = gt_motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float()
        cond["y"] = {
            key: val.to(gt_motion.device) if torch.is_tensor(val) else val
            for key, val in cond["y"].items()
        }
        gt_motion = gt_motion[:, ::skip_step, :]
        pred_motion, loss_commit, perplexity = self.net(gt_motion, mask=None)
        loss_motion = self._l2_loss(pred_motion, gt_motion).mean()
        loss_vel = 0.0
        if self.loss_vel > 0:
            loss_vel = self._vel_loss(pred_motion, gt_motion)
        loss = loss_motion + self.commit * loss_commit + self.loss_vel * loss_vel

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # record losses
        if self.loss_vel > 0:
            loss_dict["vel"] = loss_vel.item()
        loss_dict["loss"] = loss.item()
        loss_dict["loss_motion"] = loss_motion.item()
        loss_dict["loss_commit"] = loss_commit.item()
        loss_dict["perplexity"] = perplexity.item()
        return loss_dict
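
    # Checkpoints store the network weights alongside the optimizer and scheduler
    # state; _load_checkpoint below restores only the "net" entry.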
    def save_model(self, save_path):
        torch.save(
            {
                "net": self.net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
            },
            save_path,
        )

    def _save_predictions(self, name, unstd_pose, unstd_pred):
        curr_name = os.path.basename(name)
        path = os.path.join(self.out_dir, curr_name)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        np.save(os.path.join(self.out_dir, curr_name + "_gt.npy"), unstd_pose)
        np.save(os.path.join(self.out_dir, curr_name + "_pred.npy"), unstd_pred)

    def _log_losses(
        self,
        commit_loss: float,
        recons_loss: float,
        total_perplexity: float,
        nb_iter: int,
        nb_sample: int,
        draw: bool,
        save: bool,
    ) -> None:
        avg_commit = commit_loss / nb_sample
        avg_recons = recons_loss / nb_sample
        avg_perplexity = total_perplexity / nb_sample
        self.logger.info(
            f"Eval. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}"
        )
        if draw:
            self.writer.add_scalar("./Val/Perplexity", avg_perplexity, nb_iter)
            self.writer.add_scalar("./Val/Commit", avg_commit, nb_iter)
            self.writer.add_scalar("./Val/Recons", avg_recons, nb_iter)
        if avg_perplexity < self.best_perplexity:
            msg = f"--> --> \t Perplexity Improved from {self.best_perplexity:.5f} to {avg_perplexity:.5f} !!!"
            self.logger.info(msg)
            self.best_perplexity = avg_perplexity
            if save:
                print("saving checkpoint net_best.pth")
                self.save_model(os.path.join(self.out_dir, "net_best.pth"))
        if avg_commit < self.best_commit:
            msg = f"--> --> \t Commit Improved from {self.best_commit:.5f} to {avg_commit:.5f} !!!"
            self.logger.info(msg)
            self.best_commit = avg_commit
        if avg_recons < self.best_recons:
            msg = f"--> --> \t Recons Improved from {self.best_recons:.5f} to {avg_recons:.5f} !!!"
            self.logger.info(msg)
            self.best_recons = avg_recons
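
    # Validation pass: reconstructs each sequence up to its true length, accumulates
    # commitment/reconstruction/perplexity statistics, optionally dumps ground-truth
    # and predicted motion as .npy files, and writes periodic checkpoints.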
    def evaluation_vqvae(
        self,
        val_loader,
        nb_iter: int,
        draw: bool = True,
        save: bool = True,
        savenpy: bool = False,
    ) -> None:
        self.net.eval()
        nb_sample = 0
        commit_loss = 0
        recons_loss = 0
        total_perplexity = 0

        for _, batch in enumerate(val_loader):
            motion, cond = batch
            m_length = cond["y"]["lengths"]
            motion = motion.permute(0, 3, 1, 2).squeeze(-1).cuda().float()
            cond["y"] = {
                key: val.to(motion.device) if torch.is_tensor(val) else val
                for key, val in cond["y"].items()
            }
            motion = motion[:, :: val_loader.dataset.step, :].cuda().float()
            bs, seq = motion.shape[0], motion.shape[1]
            pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda()
            for i in range(bs):
                curr_gt = motion[i : i + 1, : m_length[i]]
                pred, loss_commit, perplexity = self.net(curr_gt)
                l2_loss = self._l2_loss(pred, curr_gt)
                recons_loss += l2_loss.mean().item()
                commit_loss += loss_commit.item()
                total_perplexity += perplexity.item()
                unstd_pred = val_loader.dataset.inv_transform(
                    pred.detach().cpu().numpy(), self.data_format
                )
                unstd_pose = val_loader.dataset.inv_transform(
                    curr_gt.detach().cpu().numpy(), self.data_format
                )
                if savenpy:
                    self._save_predictions(
                        f"b{i:04d}", unstd_pose[:, : m_length[i]], unstd_pred
                    )
                pred_pose_eval[i : i + 1, : m_length[i], :] = pred
            nb_sample += bs

        self._log_losses(
            commit_loss, recons_loss, total_perplexity, nb_iter, nb_sample, draw, save
        )
        if save:
            print("saving checkpoint net_last.pth")
            self.save_model(os.path.join(self.out_dir, "net_last.pth"))
        if nb_iter % 100000 == 0:
            print(f"saving checkpoint net_iter{nb_iter}.pth")
            self.save_model(
                os.path.join(self.out_dir, "net_iter" + str(nb_iter) + ".pth")
            )
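

# Builds the train/val dataloaders from the local dataset and returns an infinite
# training iterator plus the temporal subsampling step used by the dataset.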
def _load_data_info(args, logger):
    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    train_loader = get_dataset_loader(
        args=args, data_dict=data_dict, split="train", add_padding=False
    )
    val_loader = get_dataset_loader(
        args=args, data_dict=data_dict, split="val", add_padding=False
    )
    logger.info(
        f"Training on {args.dataname}, motions are with {args.nb_joints} joints"
    )
    train_loader_iter = cycle(train_loader)
    skip_step = train_loader.dataset.step
    return train_loader_iter, val_loader, skip_step
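

# Restores model weights from args.resume_pth after checking that the checkpoint
# was produced with the same data_root as the current run.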
def _load_checkpoint(args, net, logger):
    cp_dir = os.path.dirname(args.resume_pth)
    with open(f"{cp_dir}/args.json") as f:
        trans_args = json.load(f)
    assert trans_args["data_root"] == args.data_root, "data_root doesn't match"
    logger.info("loading checkpoint from {}".format(args.resume_pth))
    ckpt = torch.load(args.resume_pth, map_location="cpu")
    net.load_state_dict(ckpt["net"], strict=True)
    return net
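

# Training entry point: sets up logging/TensorBoard, builds the codec and trainer,
# runs warmup, then alternates training steps with periodic logging and validation.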
def main(args):
    torch.manual_seed(args.seed)
    os.makedirs(args.out_dir, exist_ok=True)
    logger = get_logger(args.out_dir)
    writer = SummaryWriter(args.out_dir)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    if args.data_format == "pose":
        args.nb_joints = 104
    elif args.data_format == "face":
        args.nb_joints = 256
    args_path = os.path.join(args.out_dir, "args.json")
    with open(args_path, "w") as fw:
        json.dump(vars(args), fw, indent=4, sort_keys=True)

    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")
    train_loader_iter, val_loader, skip_step = _load_data_info(args, logger)

    net = vqvae.TemporalVertexCodec(
        n_vertices=args.nb_joints,
        latent_dim=args.output_emb_width,
        categories=args.code_dim,
        residual_depth=args.depth,
    )
    if args.resume_pth:
        net = _load_checkpoint(args, net, logger)
    net.train()
    net.cuda()

    trainer = ModelTrainer(args, net, logger, writer)
    trainer.run_warmup_steps(train_loader_iter, skip_step, logger)

    avg_recons, avg_perplexity, avg_commit = 0.0, 0.0, 0.0
    with torch.no_grad():
        trainer.evaluation_vqvae(
            val_loader, 0, save=(args.total_iter > 0), savenpy=True
        )
    for nb_iter in range(1, args.total_iter + 1):
        gt_motion, cond = next(train_loader_iter)
        loss_dict = trainer.run_train_step(gt_motion, cond, skip_step)
        trainer.scheduler.step()
        avg_recons += loss_dict["loss_motion"]
        avg_perplexity += loss_dict["perplexity"]
        avg_commit += loss_dict["loss_commit"]
        if nb_iter % args.print_iter == 0:
            avg_recons /= args.print_iter
            avg_perplexity /= args.print_iter
            avg_commit /= args.print_iter
            writer.add_scalar("./Train/L1", avg_recons, nb_iter)
            writer.add_scalar("./Train/PPL", avg_perplexity, nb_iter)
            writer.add_scalar("./Train/Commit", avg_commit, nb_iter)
            logger.info(
                f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}"
            )
            avg_recons, avg_perplexity, avg_commit = (0.0, 0.0, 0.0)
        if nb_iter % args.eval_iter == 0:
            trainer.evaluation_vqvae(
                val_loader, nb_iter, save=(args.total_iter > 0), savenpy=True
            )


if __name__ == "__main__":
    args = train_args()
    main(args)
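
# Example invocation (a sketch only: the real flag spellings are defined by
# train_args() in utils/vq_parser_utils.py; the flag names and values below are
# assumptions, not verified against that parser):
#   python <this_script>.py --out_dir ./runs/vq_pose --data_root <dataset_dir> --data_format pose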