import os
import subprocess
from typing import Any, List, Optional
from argparse import Namespace
import torch
from cog import BasePredictor, Input, Path, BaseModel
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric
from data_loaders.humanml.utils.plot_script import plot_3d_motion
from data_loaders.tensors import collate
from utils.sampler_util import ClassifierFreeSampleModel
from utils import dist_util
from utils.model_util import create_model_and_diffusion, load_model_wo_clip
from visualize.motions2hik import motions2hik
from sample.generate import construct_template_variables
"""
In case of matplot lib issues it may be needed to delete model/data_loaders/humanml/utils/plot_script.py" in lines 89~92 as
suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6
"""
class ModelOutput(BaseModel):
json_file: Optional[Any]
animation: Optional[List[Path]]
def get_args():
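    # Fixed architecture and sampling settings matching the pretrained
    # humanml_trans_enc_512 checkpoint loaded in setup().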
args = Namespace()
args.fps = 20
args.model_path = './save/humanml_trans_enc_512/model000200000.pt'
args.guidance_param = 2.5
args.unconstrained = False
args.dataset = 'humanml'
args.cond_mask_prob = 1
args.emb_trans_dec = False
args.latent_dim = 512
args.layers = 8
args.arch = 'trans_enc'
args.noise_schedule = 'cosine'
args.sigma_small = True
args.lambda_vel = 0.0
args.lambda_rcxyz = 0.0
args.lambda_fc = 0.0
return args
class Predictor(BasePredictor):
def setup(self):
subprocess.run(["mkdir", "/root/.cache/clip"])
subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])
self.args = get_args()
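        # 6 seconds of motion at the dataset frame rate (20 fps -> 120 frames).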
self.num_frames = self.args.fps * 6
print('Loading dataset...')
        # Temporary loader used only to configure the model; predict() rebuilds
        # it with the requested batch size.
self.data = get_dataset_loader(name=self.args.dataset,
batch_size=1,
num_frames=196,
split='test',
hml_mode='text_only')
        self.data.fixed_length = self.num_frames  # a frame count, so keep it an int
print("Creating model and diffusion...")
self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)
print(f"Loading checkpoints from...")
state_dict = torch.load(self.args.model_path, map_location='cpu')
load_model_wo_clip(self.model, state_dict)
if self.args.guidance_param != 1:
self.model = ClassifierFreeSampleModel(self.model) # wrapping model with the classifier-free sampler
self.model.to(dist_util.dev())
self.model.eval() # disable random masking
def predict(
self,
prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
        num_repetitions: int = Input(default=3, description="How many independent motion samples to generate for the prompt"),
        output_format: str = Input(
            description='Choose the format of the output, either an animation or a JSON file of the animation data. '
                        'The JSON format is: {"thetas": [...], "root_translation": [...], "joint_map": [...]}, where '
                        '"thetas" is an [nframes x njoints x 3] array of joint rotations in degrees, '
                        '"root_translation" is an [nframes x 3] array of (X, Y, Z) positions of the root, and '
                        '"joint_map" is a list mapping the SMPL joint index to the corresponding HumanIK joint name.',
            default="animation",
            choices=["animation", "json_file"],
        ),
) -> ModelOutput:
args = self.args
args.num_repetitions = int(num_repetitions)
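        # Rebuild the data loader so the batch dimension matches the requested number of repetitions.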
self.data = get_dataset_loader(name=self.args.dataset,
batch_size=args.num_repetitions,
num_frames=self.num_frames,
split='test',
hml_mode='text_only')
        # One conditioning dict per repetition; collate stacks them into a batch
        # whose size matches the sample shape requested below.
        collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None,
                         'lengths': self.num_frames, 'text': str(prompt)}] * args.num_repetitions
        _, model_kwargs = collate(collate_args)
# add CFG scale to batch
if args.guidance_param != 1:
model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param
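        # Ancestral sampling over the full diffusion chain; skip_timesteps=0 below means no steps are skipped.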
sample_fn = self.diffusion.p_sample_loop
sample = sample_fn(
self.model,
(args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
clip_denoised=False,
model_kwargs=model_kwargs,
skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
init_image=None,
progress=True,
dump_steps=None,
noise=None,
const_noise=False,
)
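        # sample: [num_repetitions, njoints, nfeats, num_frames] in the model's internal representation.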
# Recover XYZ *positions* from HumanML3D vector representation
if self.model.data_rep == 'hml_vec':
            n_joints = 22 if sample.shape[1] == 263 else 21  # 263 features -> HumanML3D (22 joints), else KIT (21)
sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
sample = recover_from_ric(sample, n_joints)
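            # Reshape back to [batch, n_joints, 3, nframes] as expected by rot2xyz and the plotter.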
sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
self.num_frames).bool()
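        # With pose_rep='xyz' (the hml_vec path above) rot2xyz returns its input unchanged;
        # the call only does work for rotation-based representations.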
sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
get_rotations_back=False)
all_motions = sample.cpu().numpy()
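        # JSON path: convert the raw joint positions into a HumanIK-style dict instead of rendering videos.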
if output_format == 'json_file':
data_dict = motions2hik(all_motions)
return ModelOutput(json_file=data_dict)
caption = str(prompt)
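        # t2m_kinematic_chain encodes the HumanML3D skeleton connectivity used by the plotter.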
skeleton = paramUtil.t2m_kinematic_chain
sample_print_template, row_print_template, all_print_template, \
sample_file_template, row_file_template, all_file_template = construct_template_variables(
args.unconstrained)
        replicate_fnames = []
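        # Render one animation file per repetition, all sharing the same caption.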
for rep_i in range(args.num_repetitions):
motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
save_file = sample_file_template.format(1, rep_i)
print(sample_print_template.format(caption, 1, rep_i, save_file))
plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
# Credit for visualization: https://github.com/EricGuo5513/text-to-motion
            replicate_fnames.append(Path(save_file))
return ModelOutput(animation=replicate_fnames)