| import os |
| import subprocess |
| from typing import Any, List, Optional |
| from argparse import Namespace |
|
|
| import torch |
| from cog import BasePredictor, Input, Path, BaseModel |
|
|
| import data_loaders.humanml.utils.paramUtil as paramUtil |
| from data_loaders.get_data import get_dataset_loader |
| from data_loaders.humanml.scripts.motion_process import recover_from_ric |
| from data_loaders.humanml.utils.plot_script import plot_3d_motion |
| from data_loaders.tensors import collate |
| from utils.sampler_util import ClassifierFreeSampleModel |
| from utils import dist_util |
| from utils.model_util import create_model_and_diffusion, load_model_wo_clip |
| from visualize.motions2hik import motions2hik |
| from sample.generate import construct_template_variables |
|
|
| """ |
| In case of matplotlib issues, it may be necessary to delete lines 89~92 of model/data_loaders/humanml/utils/plot_script.py, as |
| suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6 |
| """ |
|
|
|
|
class ModelOutput(BaseModel):
    """Result object returned by ``Predictor.predict``.

    ``predict`` fills in exactly one of the two fields per call, so both
    default to ``None`` explicitly. Relying on the implicit default of a
    bare ``Optional[...]`` annotation only works with pydantic v1; with
    pydantic v2 such a field is required and constructing the object with
    a single field would fail validation.
    """

    # Motion data produced by motions2hik when output_format == "json_file".
    json_file: Optional[Any] = None
    # Rendered animation files, one per repetition, when output_format == "animation".
    animation: Optional[List[Path]] = None
|
|
|
|
def get_args():
    """Return the fixed inference configuration as an ``argparse.Namespace``.

    The values are hard-coded rather than parsed from the command line.
    The same namespace is later handed to ``create_model_and_diffusion``
    and read by the predictor (fps, model_path, guidance_param, dataset,
    unconstrained).
    """
    return Namespace(
        # Generation / checkpoint settings.
        fps=20,
        model_path='./save/humanml_trans_enc_512/model000200000.pt',
        guidance_param=2.5,
        unconstrained=False,
        dataset='humanml',
        # Model settings -- presumably must match the checkpoint above; confirm.
        cond_mask_prob=1,
        emb_trans_dec=False,
        latent_dim=512,
        layers=8,
        arch='trans_enc',
        # Diffusion settings.
        noise_schedule='cosine',
        sigma_small=True,
        lambda_vel=0.0,
        lambda_rcxyz=0.0,
        lambda_fc=0.0,
    )
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor that turns a text prompt into generated motion.

    ``setup`` loads the dataset, model and checkpoint once; ``predict``
    samples ``num_repetitions`` motions for a prompt and returns them either
    as rendered animation files or as a HumanIK-style data dictionary.
    """

    def setup(self):
        """Load everything that is shared across predictions."""
        # Copy the bundled CLIP weights into the CLIP cache directory
        # (presumably so they are found locally instead of downloaded --
        # confirm). `-p` creates missing parent directories and does not
        # fail when the directory already exists. Both commands stay
        # best-effort, as before: a failure here is not fatal.
        subprocess.run(["mkdir", "-p", "/root/.cache/clip"])
        subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])

        self.args = get_args()
        # Every generated motion is 6 seconds long at the configured fps.
        self.num_frames = self.args.fps * 6

        print('Loading dataset...')
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=1,
                                       num_frames=196,
                                       split='test',
                                       hml_mode='text_only')
        self.data.fixed_length = float(self.num_frames)

        print("Creating model and diffusion...")
        self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)

        # The original message was an f-string without a placeholder, so the
        # checkpoint path never appeared in the log.
        print(f"Loading checkpoints from [{self.args.model_path}]...")
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        load_model_wo_clip(self.model, state_dict)

        # A guidance scale of 1 means no classifier-free guidance wrapper.
        if self.args.guidance_param != 1:
            self.model = ClassifierFreeSampleModel(self.model)
        self.model.to(dist_util.dev())
        self.model.eval()

    def predict(
        self,
        prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
        num_repetitions: int = Input(default=3, description="How many"),
        output_format: str = Input(
            # Adjacent literals instead of backslash continuations, so the
            # source indentation does not leak into the user-facing text.
            description=(
                'Choose the format of the output, either an animation or a json file of the animation data. '
                'The json format is: {"thetas": [...], "root_translation": [...], "joint_map": [...]}, '
                'where "thetas" is an [nframes x njoints x 3] array of joint rotations in degrees, '
                '"root_translation" is an [nframes x 3] array of (X, Y, Z) positions of the root, '
                'and "joint_map" is a list mapping the SMPL joint index to the '
                'corresponding HumanIK joint name'
            ),
            default="animation",
            choices=["animation", "json_file"],
        ),
    ) -> ModelOutput:
        """Generate motions for ``prompt``.

        Args:
            prompt: Text describing the motion to generate.
            num_repetitions: Number of motions sampled for the prompt; used
                as the sampling batch size.
            output_format: ``"animation"`` to render one file per repetition,
                ``"json_file"`` to return the output of ``motions2hik``.

        Returns:
            A ``ModelOutput`` with ``json_file`` set when ``output_format``
            is ``"json_file"``, otherwise with ``animation`` set to the list
            of rendered file paths.
        """
        args = self.args
        args.num_repetitions = int(num_repetitions)

        # NOTE(review): the loader is rebuilt on every request although only
        # `dataset.t2m_dataset.inv_transform` is used below; it looks
        # redundant with the loader created in setup() -- confirm before
        # removing.
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=args.num_repetitions,
                                       num_frames=self.num_frames,
                                       split='test',
                                       hml_mode='text_only')

        # One conditioning entry per repetition, so the collated text list
        # and mask have the same batch size as the sampled tensor. The
        # original passed a single entry, which left a batch-of-one
        # condition against a batch of `num_repetitions` samples and cannot
        # satisfy the mask reshape further down.
        caption = str(prompt)
        collate_args = [
            {'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': caption}
            for _ in range(args.num_repetitions)
        ]
        _, model_kwargs = collate(collate_args)

        # The guidance scale is only consumed by the classifier-free wrapper.
        if args.guidance_param != 1:
            model_kwargs['y']['scale'] = torch.ones(args.num_repetitions,
                                                    device=dist_util.dev()) * args.guidance_param

        sample = self.diffusion.p_sample_loop(
            self.model,
            (args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        if self.model.data_rep == 'hml_vec':
            # Undo the dataset transform, then recover joint positions from
            # the feature vectors; 263 channels selects 22 joints, anything
            # else 21.
            n_joints = 22 if sample.shape[1] == 263 else 21
            sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
            sample = recover_from_ric(sample, n_joints)
            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        # 'xyz' and 'hml_vec' data are treated as xyz and need no mask.
        rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
        if rot2xyz_pose_rep == 'xyz':
            rot2xyz_mask = None
        else:
            rot2xyz_mask = model_kwargs['y']['mask'].reshape(args.num_repetitions, self.num_frames).bool()
        sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True,
                                    translation=True, jointstype='smpl', vertstrans=True, betas=None,
                                    beta=0, glob_rot=None, get_rotations_back=False)

        all_motions = sample.cpu().numpy()

        if output_format == 'json_file':
            data_dict = motions2hik(all_motions)
            return ModelOutput(json_file=data_dict)

        skeleton = paramUtil.t2m_kinematic_chain

        # Only the per-sample print and file templates are needed here.
        sample_print_template, _, _, sample_file_template, _, _ = construct_template_variables(
            args.unconstrained)

        replicate_fnames = []
        for rep_i in range(args.num_repetitions):
            # Move the last axis to the front and cap it at num_frames
            # (assumes the last axis indexes frames, as in the sampled
            # shape above -- confirm for rot2xyz output).
            motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
            # The sample index is fixed to 1: there is a single prompt.
            save_file = sample_file_template.format(1, rep_i)
            print(sample_print_template.format(caption, 1, rep_i, save_file))
            plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
            replicate_fnames.append(Path(save_file))

        return ModelOutput(animation=replicate_fnames)
|
|