Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| from typing import Any, List, Optional | |
| from argparse import Namespace | |
| import torch | |
| from cog import BasePredictor, Input, Path, BaseModel | |
| import data_loaders.humanml.utils.paramUtil as paramUtil | |
| from data_loaders.get_data import get_dataset_loader | |
| from data_loaders.humanml.scripts.motion_process import recover_from_ric | |
| from data_loaders.humanml.utils.plot_script import plot_3d_motion | |
| from data_loaders.tensors import collate | |
| from utils.sampler_util import ClassifierFreeSampleModel | |
| from utils import dist_util | |
| from utils.model_util import create_model_and_diffusion, load_model_wo_clip | |
| from visualize.motions2hik import motions2hik | |
| from sample.generate import construct_template_variables | |
| """ | |
| If you run into matplotlib issues, you may need to delete lines 89~92 of model/data_loaders/humanml/utils/plot_script.py, as | |
| suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6 | |
| """ | |
class ModelOutput(BaseModel):
    """Result returned by ``Predictor.predict``.

    ``predict`` fills in exactly one of the two fields per call, so both
    default to ``None`` explicitly instead of relying on the validation
    library treating a bare ``Optional[...]`` annotation as optional.
    """

    # Value returned by motions2hik(); set when output_format == 'json_file'.
    json_file: Optional[Any] = None
    # One path per repetition, pointing at the files passed to plot_3d_motion();
    # set when output_format == 'animation'.
    animation: Optional[List[Path]] = None
def get_args():
    """Build the fixed settings namespace used by the predictor.

    Returns:
        A fresh ``argparse.Namespace`` on every call, populated with the
        hard-coded checkpoint path, dataset name, architecture settings and
        loss weights listed below.
    """
    return Namespace(
        # Playback / sampling settings.
        fps=20,
        model_path='./save/humanml_trans_enc_512/model000200000.pt',
        guidance_param=2.5,
        unconstrained=False,
        dataset='humanml',
        # Model architecture settings.
        cond_mask_prob=1,
        emb_trans_dec=False,
        latent_dim=512,
        layers=8,
        arch='trans_enc',
        # Diffusion settings.
        noise_schedule='cosine',
        sigma_small=True,
        # Loss weights (all zero here).
        lambda_vel=0.0,
        lambda_rcxyz=0.0,
        lambda_fc=0.0,
    )
class Predictor(BasePredictor):
    """Cog predictor that samples motions for a text prompt.

    ``setup`` builds the dataset loader, the model and the diffusion object
    once. ``predict`` draws ``num_repetitions`` samples for a prompt and
    returns them either as files written by ``plot_3d_motion`` or as the
    data structure produced by ``motions2hik``.
    """

    def setup(self):
        """Stage the CLIP weights, then load the dataset, model and checkpoint."""
        # Copy the bundled CLIP weights into /root/.cache/clip (presumably so
        # CLIP picks them up from its cache instead of downloading -- confirm).
        # "-p" creates missing parent directories and succeeds when the
        # directory already exists; both commands stay best-effort (no check).
        subprocess.run(["mkdir", "-p", "/root/.cache/clip"])
        subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])

        self.args = get_args()
        # Six seconds of motion at the configured frame rate.
        self.num_frames = self.args.fps * 6

        print('Loading dataset...')
        # temporary data
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=1,
                                       num_frames=196,
                                       split='test',
                                       hml_mode='text_only')
        self.data.fixed_length = float(self.num_frames)

        print("Creating model and diffusion...")
        self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)

        print(f"Loading checkpoints from [{self.args.model_path}]...")
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        load_model_wo_clip(self.model, state_dict)

        if self.args.guidance_param != 1:
            # wrapping model with the classifier-free sampler
            self.model = ClassifierFreeSampleModel(self.model)
        self.model.to(dist_util.dev())
        self.model.eval()  # disable random masking

    def predict(
        self,
        prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
        num_repetitions: int = Input(default=3, description="How many"),
        output_format: str = Input(
            # Implicit string concatenation keeps source indentation out of
            # the user-facing description text.
            description=(
                'Choose the format of the output, either an animation or a json file of the animation data. '
                'The json format is: {"thetas": [...], "root_translation": [...], "joint_map": [...]}, where "thetas" '
                'is an [nframes x njoints x 3] array of joint rotations in degrees, "root_translation" is an [nframes x 3] '
                'array of (X, Y, Z) positions of the root, and "joint_map" is a list mapping the SMPL joint index to the '
                'corresponding HumanIK joint name'
            ),
            default="animation",
            choices=["animation", "json_file"],
        ),
    ) -> ModelOutput:
        """Sample motions for ``prompt`` and return them in the requested format.

        Args:
            prompt: Text description of the motion to generate.
            num_repetitions: Number of samples to draw for the prompt.
            output_format: ``"animation"`` to return one file path per sample,
                or ``"json_file"`` to return the result of ``motions2hik``.

        Returns:
            A ``ModelOutput`` with either ``json_file`` or ``animation`` set.
        """
        args = self.args
        # NOTE: this writes to the shared self.args namespace.
        args.num_repetitions = int(num_repetitions)

        # The loader is rebuilt on every call with the requested batch size;
        # below it is only used for its inv_transform.
        self.data = get_dataset_loader(name=self.args.dataset,
                                       batch_size=args.num_repetitions,
                                       num_frames=self.num_frames,
                                       split='test',
                                       hml_mode='text_only')

        # A single conditioning entry is collated for the whole batch.
        collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': str(prompt)}]
        _, model_kwargs = collate(collate_args)

        # add CFG scale to batch
        if args.guidance_param != 1:
            model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param

        sample_fn = self.diffusion.p_sample_loop
        sample = sample_fn(
            self.model,
            (args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Recover XYZ *positions* from HumanML3D vector representation
        if self.model.data_rep == 'hml_vec':
            n_joints = 22 if sample.shape[1] == 263 else 21
            sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
            sample = recover_from_ric(sample, n_joints)
            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
        # The mask is only needed when the representation is not already XYZ.
        rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
                                                                                                self.num_frames).bool()
        sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
                                    jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                                    get_rotations_back=False)

        all_motions = sample.cpu().numpy()

        if output_format == 'json_file':
            data_dict = motions2hik(all_motions)
            return ModelOutput(json_file=data_dict)

        caption = str(prompt)
        skeleton = paramUtil.t2m_kinematic_chain

        # Only the per-sample print and file templates are used here.
        sample_print_template, _, _, sample_file_template, _, _ = construct_template_variables(
            args.unconstrained)

        replicate_fnames = []
        for rep_i in range(args.num_repetitions):
            motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
            save_file = sample_file_template.format(1, rep_i)
            print(sample_print_template.format(caption, 1, rep_i, save_file))
            plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
            # Credit for visualization: https://github.com/EricGuo5513/text-to-motion
            replicate_fnames.append(Path(save_file))

        return ModelOutput(animation=replicate_fnames)