Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| sys.path.append(os.getcwd()) | |
| from data_utils import torch_data | |
| from trainer.options import parse_args | |
| from trainer.config import load_JsonConfig | |
| from nets.init_model import init_model | |
| import torch | |
| import torch.utils.data as data | |
| import torch.optim as optim | |
| import numpy as np | |
| import random | |
| import logging | |
| import time | |
| import shutil | |
def prn_obj(obj):
    """Print every attribute of *obj*, one ``name:value`` pair per line."""
    pairs = [f'{name}:{value}' for name, value in obj.__dict__.items()]
    print('\n'.join(pairs))
class Trainer():
    """Training driver for the speech-to-gesture models.

    Parses CLI arguments, loads the JSON config, builds the model and the
    dataloaders appropriate to the selected model family, then runs the
    training loop with periodic checkpointing.

    NOTE(review): a previous revision carried large commented-out wandb sweep
    blocks that embedded a hardcoded API key; that dead code (and the secret)
    has been removed. Re-add wandb support by reading the key from the
    environment, never from source.
    """

    def __init__(self) -> None:
        parser = parse_args()
        self.args = parser.parse_args()
        self.config = load_JsonConfig(self.args.config_file)

        # Downstream SMPL-X utilities read these resource paths from the env.
        os.environ['smplx_npz_path'] = self.config.smplx_npz_path
        os.environ['extra_joint_path'] = self.config.extra_joint_path
        os.environ['j14_regressor_path'] = self.config.j14_regressor_path

        self.device = torch.device(self.args.gpu)
        torch.cuda.set_device(self.device)
        self.setup_seed(self.args.seed)
        self.set_train_dir()
        # Keep a copy of the config next to the logs for reproducibility.
        shutil.copy(self.args.config_file, self.train_dir)

        self.generator = init_model(self.config.Model.model_name, self.args, self.config)
        self.init_dataloader()

        self.start_epoch = 0
        self.global_steps = 0
        if self.args.resume:
            self.resume()

    def setup_seed(self, seed):
        """Seed torch / numpy / random and force deterministic cuDNN kernels."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

    def set_train_dir(self):
        """Create the run directory (``<date>-<exp_name>-<log_name>``) and wire
        logging to both stdout and a ``train.log`` file inside it."""
        time_stamp = time.strftime('%Y-%m-%d', time.localtime(time.time()))
        train_dir = os.path.join(os.getcwd(), self.args.save_dir, os.path.normpath(
            time_stamp + '-' + self.args.exp_name + '-' + self.config.Log.name))
        os.makedirs(train_dir, exist_ok=True)
        log_file = os.path.join(train_dir, 'train.log')
        fmt = "%(asctime)s-%(lineno)d-%(message)s"
        logging.basicConfig(
            stream=sys.stdout, level=logging.INFO, format=fmt, datefmt='%m/%d %I:%M:%S %p'
        )
        fh = logging.FileHandler(log_file)
        fh.setFormatter(logging.Formatter(fmt))
        logging.getLogger().addHandler(fh)
        self.train_dir = train_dir

    def resume(self):
        """Restore generator weights and training counters from ``args.pretrained_pth``."""
        print('resume from a previous ckpt')
        # map_location keeps the load from failing (or eagerly allocating) when
        # the checkpoint was saved on a different GPU than self.device.
        ckpt = torch.load(self.args.pretrained_pth, map_location=self.device)
        self.generator.load_state_dict(ckpt['generator'])
        self.start_epoch = ckpt['epoch']
        self.global_steps = ckpt['global_steps']
        self.generator.global_step = self.global_steps

    def _save_norm_stats(self):
        """Persist the (mean, std) pose statistics so inference can denormalize."""
        self.norm_stats = (self.train_set.data_mean, self.train_set.data_std)
        save_file = os.path.join(self.train_dir, 'norm_stats.npy')
        np.save(save_file, self.norm_stats, allow_pickle=True)

    def _base_dataset_kwargs(self, split_trans_zero):
        """Keyword arguments shared by every torch_data construction below."""
        return dict(
            data_root=self.config.Data.data_root,
            speakers=self.args.speakers,
            split='train',
            limbscaling=self.config.Data.pose.augmentation,
            normalization=self.config.Data.pose.normalization,
            norm_method=self.config.Data.pose.norm_method,
            split_trans_zero=split_trans_zero,
            num_pre_frames=self.config.Data.pose.pre_pose_length,
            num_frames=self.config.Data.pose.generate_length,
            aud_feat_win_size=self.config.Data.aud.aud_feat_win_size,
            aud_feat_dim=self.config.Data.aud.aud_feat_dim,
            feat_method=self.config.Data.aud.feat_method,
            context_info=self.config.Data.aud.context_info,
        )

    def init_dataloader(self):
        """Build ``train_set`` and the loader(s); layout depends on model family."""
        model_name = self.config.Model.model_name
        loader_kwargs = dict(
            batch_size=self.config.DataLoader.batch_size,
            shuffle=True,
            num_workers=self.config.DataLoader.num_workers,
            drop_last=True,
        )
        if 'freeMo' in model_name:
            if self.config.Data.data_root.endswith('.csv'):
                raise NotImplementedError
            # freeMo trains on separate "transition" and "zero" sub-datasets.
            self.train_set = torch_data(**self._base_dataset_kwargs(split_trans_zero=True))
            if self.config.Data.pose.normalization:
                self._save_norm_stats()
            self.train_set.get_dataset()
            self.trans_set = self.train_set.trans_dataset
            self.zero_set = self.train_set.zero_dataset
            self.trans_loader = data.DataLoader(self.trans_set, **loader_kwargs)
            self.zero_loader = data.DataLoader(self.zero_set, **loader_kwargs)
        elif 'smplx' in model_name or 's2g' in model_name:
            kwargs = self._base_dataset_kwargs(split_trans_zero=False)
            kwargs.update(
                num_generate_length=self.config.Data.pose.generate_length,
                smplx=True,
                audio_sr=22000,
                convert_to_6d=self.config.Data.pose.convert_to_6d,
                expression=self.config.Data.pose.expression,
                config=self.config,
            )
            self.train_set = torch_data(**kwargs)
            if self.config.Data.pose.normalization:
                self._save_norm_stats()
            self.train_set.get_dataset()
            self.train_loader = data.DataLoader(self.train_set.all_dataset, **loader_kwargs)
        else:
            self.train_set = torch_data(**self._base_dataset_kwargs(split_trans_zero=False))
            if self.config.Data.pose.normalization:
                self._save_norm_stats()
            self.train_set.get_dataset()
            self.train_loader = data.DataLoader(self.train_set.all_dataset, **loader_kwargs)

    def init_optimizer(self):
        # Optimizers are created inside the model (see init_model); placeholder.
        pass

    def print_func(self, loss_dict, steps):
        """Log the running-average of each loss over the epoch so far."""
        info_str = ['global_steps:%d' % (self.global_steps)]
        info_str += ['%s:%.4f' % (key, loss_dict[key] / steps) for key in list(loss_dict.keys())]
        logging.info(','.join(info_str))

    def save_model(self, epoch):
        """Checkpoint generator weights plus epoch/step counters to ckpt-<epoch>.pth."""
        state_dict = {
            'generator': self.generator.state_dict(),
            'epoch': epoch,
            'global_steps': self.global_steps
        }
        save_name = os.path.join(self.train_dir, 'ckpt-%d.pth' % (epoch))
        torch.save(state_dict, save_name)

    def _accumulate(self, epoch_loss_dict, loss_dict):
        """Fold this step's losses into the per-epoch running totals in place."""
        if epoch_loss_dict:  # non-empty: add onto existing totals
            for key in list(loss_dict.keys()):
                epoch_loss_dict[key] += loss_dict[key]
        else:  # first step of the epoch: initialize the totals
            for key in list(loss_dict.keys()):
                epoch_loss_dict[key] = loss_dict[key]

    def train_epoch(self, epoch):
        """Run one epoch. The generator performs its own optimization step
        inside forward and returns ``(output, loss_dict)``."""
        epoch_loss_dict = {}  # track how each loss evolves over the epoch
        epoch_steps = 0
        if 'freeMo' in self.config.Model.model_name:
            for bat in zip(self.trans_loader, self.zero_loader):
                self.global_steps += 1
                epoch_steps += 1
                _, loss_dict = self.generator(bat)
                self._accumulate(epoch_loss_dict, loss_dict)
                if self.global_steps % self.config.Log.print_every == 0:
                    self.print_func(epoch_loss_dict, epoch_steps)
        else:
            # e.g. self.config.Model.model_name == smplx_S2G
            for bat in self.train_loader:
                self.global_steps += 1
                epoch_steps += 1
                bat['epoch'] = epoch
                _, loss_dict = self.generator(bat)
                self._accumulate(epoch_loss_dict, loss_dict)
                if self.global_steps % self.config.Log.print_every == 0:
                    self.print_func(epoch_loss_dict, epoch_steps)

    def train(self):
        """Full training loop: one train_epoch per epoch, periodic checkpoints."""
        logging.info('start_training')
        self.total_loss_dict = {}
        for epoch in range(self.start_epoch, self.config.Train.epochs):
            logging.info('epoch:%d' % (epoch))
            self.train_epoch(epoch)
            # Also checkpoint at epoch 30 — presumably sweep runs were capped
            # at 30 epochs; TODO confirm and make configurable.
            if (epoch + 1) % self.config.Log.save_every == 0 or (epoch + 1) == 30:
                self.save_model(epoch)