Spaces:
Runtime error
Runtime error
| import os | |
| from os.path import join as pjoin | |
| import utils.paramUtil as paramUtil | |
| from options.train_options import TrainCompOptions | |
| from utils.plot_script import * | |
| from models import MotionTransformer | |
| from trainers import DDPMTrainer | |
| from datasets import Text2MotionDataset | |
| from mmcv.runner import get_dist_info, init_dist | |
| from mmcv.parallel import MMDistributedDataParallel | |
| import torch | |
| import torch.distributed as dist | |
def build_models(opt, dim_pose):
    """Construct the MotionTransformer model used by this training script.

    Args:
        opt: parsed options object; reads ``max_motion_length``, ``num_layers``,
            ``latent_dim``, ``no_clip`` and ``no_eff``.
        dim_pose: per-frame feature size, forwarded as ``input_feats``.

    Returns:
        The newly built ``MotionTransformer`` instance (not yet moved to a device).
    """
    # NOTE(review): the meaning of `no_clip` / `no_eff` is defined by
    # MotionTransformer, not here; they are passed through unchanged.
    model_kwargs = dict(
        input_feats=dim_pose,
        num_frames=opt.max_motion_length,
        num_layers=opt.num_layers,
        latent_dim=opt.latent_dim,
        no_clip=opt.no_clip,
        no_eff=opt.no_eff,
    )
    return MotionTransformer(**model_kwargs)
| if __name__ == '__main__': | |
| parser = TrainCompOptions() | |
| opt = parser.parse() | |
| rank, world_size = get_dist_info() | |
| opt.device = torch.device("cuda") | |
| torch.autograd.set_detect_anomaly(True) | |
| opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) | |
| opt.model_dir = pjoin(opt.save_root, 'model') | |
| opt.meta_dir = pjoin(opt.save_root, 'meta') | |
| if rank == 0: | |
| os.makedirs(opt.model_dir, exist_ok=True) | |
| os.makedirs(opt.meta_dir, exist_ok=True) | |
| if world_size > 1: | |
| dist.barrier() | |
| if opt.dataset_name == 't2m': | |
| opt.data_root = './data/HumanML3D' | |
| opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') | |
| opt.text_dir = pjoin(opt.data_root, 'texts') | |
| opt.joints_num = 22 | |
| radius = 4 | |
| fps = 20 | |
| opt.max_motion_length = 196 | |
| dim_pose = 263 | |
| kinematic_chain = paramUtil.t2m_kinematic_chain | |
| elif opt.dataset_name == 'kit': | |
| opt.data_root = './data/KIT-ML' | |
| opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs') | |
| opt.text_dir = pjoin(opt.data_root, 'texts') | |
| opt.joints_num = 21 | |
| radius = 240 * 8 | |
| fps = 12.5 | |
| dim_pose = 251 | |
| opt.max_motion_length = 196 | |
| kinematic_chain = paramUtil.kit_kinematic_chain | |
| else: | |
| raise KeyError('Dataset Does Not Exist') | |
| dim_word = 300 | |
| mean = np.load(pjoin(opt.data_root, 'Mean.npy')) | |
| std = np.load(pjoin(opt.data_root, 'Std.npy')) | |
| train_split_file = pjoin(opt.data_root, 'train.txt') | |
| encoder = build_models(opt, dim_pose) | |
| if world_size > 1: | |
| encoder = MMDistributedDataParallel( | |
| encoder.cuda(), | |
| device_ids=[torch.cuda.current_device()], | |
| broadcast_buffers=False, | |
| find_unused_parameters=True) | |
| else: | |
| encoder = encoder.cuda() | |
| trainer = DDPMTrainer(opt, encoder) | |
| train_dataset = Text2MotionDataset(opt, mean, std, train_split_file, opt.times) | |
| trainer.train(train_dataset) | |