from __future__ import annotations import os import pdb import random import time from typing import Literal from dataclasses import dataclass, asdict, make_dataclass import numpy as np import torch import torch.nn as nn import torch.optim as optim import tyro import yaml from torch.distributions.normal import Normal from torch.utils.tensorboard import SummaryWriter from pathlib import Path from tqdm import tqdm import pickle import json import copy from model.mld_denoiser import DenoiserMLP, DenoiserTransformer from model.mld_vae import AutoMldVae from data_loaders.humanml.data.dataset import WeightedPrimitiveSequenceDataset, SinglePrimitiveDataset from utils.smpl_utils import * from utils.misc_util import encode_text, compose_texts_with_and from pytorch3d import transforms from diffusion import gaussian_diffusion as gd from diffusion.respace import SpacedDiffusion, space_timesteps from diffusion.resample import create_named_schedule_sampler from mld.train_mvae import Args as MVAEArgs from mld.train_mvae import DataArgs, TrainArgs from mld.train_mld import DenoiserArgs, MLDArgs, create_gaussian_diffusion, DenoiserMLPArgs, DenoiserTransformerArgs from mld.rollout_mld import load_mld, ClassifierFreeWrapper debug = 0 @dataclass class OptimArgs: seed: int = 0 torch_deterministic: bool = True device: str = "cuda" save_dir = None denoiser_checkpoint: str = '' optim_input: str = '' text_prompt: str = None respacing: str = 'ddim10' guidance_param: float = 5.0 export_smpl: int = 0 zero_noise: int = 0 use_predicted_joints: int = 0 batch_size: int = 1 result_dir: str = 'inbetween' seed_type: str= 'history' optim_lr: float = 0.01 optim_steps: int = 300 optim_unit_grad: int = 1 optim_anneal_lr: int = 1 weight_jerk: float = 0.0 weight_floor: float = 0.0 init_noise_scale: float = 1.0 def calc_jerk(joints): vel = joints[:, 1:] - joints[:, :-1] # --> B x T-1 x 22 x 3 acc = vel[:, 1:] - vel[:, :-1] # --> B x T-2 x 22 x 3 jerk = acc[:, 1:] - acc[:, :-1] # --> B x T-3 x 22 x 3 jerk = torch.sqrt((jerk ** 2).sum(dim=-1)) # --> B x T-3 x 22, compute L1 norm of jerk jerk = jerk.amax(dim=[1, 2]) # --> B, Get the max of the jerk across all joints and frames return jerk.mean() def optimize(text_prompt, canonicalized_primitive_dict, goal_joints, joints_mask, denoiser_args, denoiser_model, vae_args, vae_model, diffusion, dataset, optim_args): device = optim_args.device batch_size = optim_args.batch_size future_length = dataset.future_length history_length = dataset.history_length primitive_length = history_length + future_length start_idx = history_length - 1 if optim_args.seed_type == 'repeat' else 0 end_idx = start_idx + seq_length - 1 assert 'ddim' in optim_args.respacing sample_fn = diffusion.ddim_sample_loop_full_chain texts = [] if ',' in text_prompt: # contain a time line of multipel actions num_rollout = 0 for segment in text_prompt.split(','): action, num_mp = segment.split('*') action = compose_texts_with_and(action.split(' and ')) texts = texts + [action] * int(num_mp) num_rollout += int(num_mp) else: action, num_rollout = text_prompt.split('*') action = compose_texts_with_and(action.split(' and ')) num_rollout = int(num_rollout) for _ in range(num_rollout): texts.append(action) all_text_embedding = encode_text(dataset.clip_model, texts, force_empty_zero=True).to(dtype=torch.float32, device=device) primitive_utility = dataset.primitive_utility out_path = optim_args.save_dir filename = f'guidance{optim_args.guidance_param}_seed{optim_args.seed}' if text_prompt != '': filename = text_prompt[:40].replace(' ', '_').replace('.', '') + '_' + filename if optim_args.respacing != '': filename = f'{optim_args.respacing}_{filename}' # if optim_args.smooth: # filename = f'smooth_{filename}' if optim_args.zero_noise: filename = f'zero_noise_{filename}' if optim_args.use_predicted_joints: filename = f'use_pred_joints_{filename}' filename = f'scale{optim_args.init_noise_scale}_floor{optim_args.weight_floor}_jerk{optim_args.weight_jerk}_{filename}' out_path = out_path / optim_args.result_dir / f'{optim_args.seed_type}seed' / filename out_path.mkdir(parents=True, exist_ok=True) batch = dataset.get_batch(batch_size=optim_args.batch_size) input_motions, model_kwargs = batch[0]['motion_tensor_normalized'], {'y': batch[0]} del model_kwargs['y']['motion_tensor_normalized'] gender = model_kwargs['y']['gender'][0] betas = model_kwargs['y']['betas'][:, :primitive_length, :].to(device) # [B, H+F, 10] pelvis_delta = primitive_utility.calc_calibrate_offset({ 'betas': betas[:, 0, :], 'gender': gender, }) # print(input_motions, model_kwargs) input_motions = input_motions.to(device) # [B, D, 1, T] motion_tensor = input_motions.squeeze(2).permute(0, 2, 1) # [B, T, D] history_motion_gt = motion_tensor[:, :history_length, :] # [B, H, D] if text_prompt == '': optim_args.guidance_param = 0. # Force unconditioned generation def rollout(noise): motion_sequences = None history_motion = history_motion_gt transf_rotmat = torch.eye(3, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1, 1) transf_transl = torch.zeros(3, device=device, dtype=torch.float32).reshape(1, 1, 3).repeat(batch_size, 1, 1) for segment_id in range(num_rollout): text_embedding = all_text_embedding[segment_id].expand(batch_size, -1) # [B, 512] guidance_param = torch.ones(batch_size, *denoiser_args.model_args.noise_shape).to(device=device) * optim_args.guidance_param y = { 'text_embedding': text_embedding, 'history_motion_normalized': history_motion, 'scale': guidance_param, } x_start_pred = sample_fn( denoiser_model, (batch_size, *denoiser_args.model_args.noise_shape), clip_denoised=False, model_kwargs={'y': y}, skip_timesteps=0, # 0 is the default value - i.e. don't skip any step init_image=None, progress=False, noise=noise[segment_id], ) # [B, T=1, D] # x_start_pred = x_start_pred.clamp(min=-3, max=3) # print('x_start_pred:', x_start_pred.mean(), x_start_pred.std(), x_start_pred.min(), x_start_pred.max()) latent_pred = x_start_pred.permute(1, 0, 2) # [T=1, B, D] future_motion_pred = vae_model.decode(latent_pred, history_motion, nfuture=future_length, scale_latent=denoiser_args.rescale_latent) # [B, F, D], normalized future_frames = dataset.denormalize(future_motion_pred) new_history_frames = future_frames[:, -history_length:, :] """transform primitive to world coordinate, prepare for serialization""" if segment_id == 0: # add init history motion future_frames = torch.cat([dataset.denormalize(history_motion), future_frames], dim=1) future_feature_dict = primitive_utility.tensor_to_dict(future_frames) future_feature_dict.update( { 'transf_rotmat': transf_rotmat, 'transf_transl': transf_transl, 'gender': gender, 'betas': betas[:, :future_length, :] if segment_id > 0 else betas[:, :primitive_length, :], 'pelvis_delta': pelvis_delta, } ) future_primitive_dict = primitive_utility.feature_dict_to_smpl_dict(future_feature_dict) future_primitive_dict = primitive_utility.transform_primitive_to_world(future_primitive_dict) if motion_sequences is None: motion_sequences = future_primitive_dict else: for key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']: motion_sequences[key] = torch.cat([motion_sequences[key], future_primitive_dict[key]], dim=1) # [B, T, ...] """update history motion seed, update global transform""" history_feature_dict = primitive_utility.tensor_to_dict(new_history_frames) history_feature_dict.update( { 'transf_rotmat': transf_rotmat, 'transf_transl': transf_transl, 'gender': gender, 'betas': betas[:, :history_length, :], 'pelvis_delta': pelvis_delta, } ) canonicalized_history_primitive_dict, blended_feature_dict = primitive_utility.get_blended_feature( history_feature_dict, use_predicted_joints=optim_args.use_predicted_joints) transf_rotmat, transf_transl = canonicalized_history_primitive_dict['transf_rotmat'], \ canonicalized_history_primitive_dict['transf_transl'] history_motion = primitive_utility.dict_to_tensor(blended_feature_dict) history_motion = dataset.normalize(history_motion) # [B, T, D] motion_sequences['texts'] = texts return motion_sequences optim_steps = optim_args.optim_steps lr = optim_args.optim_lr noise = torch.randn(num_rollout, batch_size, *denoiser_args.model_args.noise_shape, device=device, dtype=torch.float32) # noise = noise.clip(min=-1, max=1) noise = noise * optim_args.init_noise_scale noise.requires_grad_(True) reduction_dims = list(range(1, len(noise.shape))) criterion = torch.nn.HuberLoss(reduction='mean', delta=1.0) optimizer = torch.optim.Adam([noise], lr=lr) for i in tqdm(range(optim_steps)): optimizer.zero_grad() if optim_args.optim_anneal_lr: frac = 1.0 - i / optim_steps lrnow = frac * lr optimizer.param_groups[0]["lr"] = lrnow motion_sequences = rollout(noise) # joints_diff = (motion_sequences['joints'][:, seq_length - 1, joints_mask] - goal_joints[:, joints_mask]) ** 2 # joints_diff = torch.sqrt(joints_diff.sum(dim=-1)).mean(dim=1).mean(dim=0) # loss_joints = joints_diff # print('joints shape:', motion_sequences['joints'].shape, goal_joints.shape, joints_mask.shape) loss_joints = criterion(motion_sequences['joints'][:, end_idx, joints_mask], goal_joints[:, joints_mask]) loss_jerk = calc_jerk(motion_sequences['joints'][:, start_idx:end_idx + 1]) floor_height = motion_sequences['joints'][:, 0, FOOT_JOINTS_IDX, 2].amin(dim=-1) # [B], assuming first frame on floor foot_height = motion_sequences['joints'][:, start_idx:end_idx + 1, FOOT_JOINTS_IDX, 2].amin(dim=-1) # [B, T] loss_floor = -(foot_height - floor_height.unsqueeze(1)).clamp(max=0).mean() loss = loss_joints + optim_args.weight_jerk * loss_jerk + optim_args.weight_floor * loss_floor loss.backward() if optim_args.optim_unit_grad: noise.grad.data /= noise.grad.norm(p=2, dim=reduction_dims, keepdim=True).clamp(min=1e-6) optimizer.step() # print(f'[{i}/{optim_steps}] loss: {loss.item()} joints_diff: {loss_joints.item()} jerk: {loss_jerk.item()} floor: {loss_floor.item()}') print(f'[{i}/{optim_steps}] loss: {loss.item()} joints_diff: {loss_joints.item()} jerk: {loss_jerk.item()} floor: {loss_floor.item()}') motion_sequences = rollout(noise) # export input sequence sequence = { 'texts': texts, 'gender': canonicalized_primitive_dict['gender'], 'betas': canonicalized_primitive_dict['betas'][0], 'transl': canonicalized_primitive_dict['transl'][0], 'global_orient': canonicalized_primitive_dict['global_orient'][0], 'body_pose': canonicalized_primitive_dict['body_pose'][0], 'joints': canonicalized_primitive_dict['joints'][0], 'history_length': history_length, 'future_length': future_length, 'mocap_framerate': dataset.target_fps, } if optim_args.seed_type == 'history': for key in ['betas', 'transl', 'global_orient', 'body_pose', 'joints']: sequence[key][history_length:-1] = sequence[key][history_length] tensor_dict_to_device(sequence, 'cpu') with open(os.path.join(out_path, f'input.pkl'), 'wb') as f: pickle.dump(sequence, f) for idx in range(optim_args.batch_size): sequence = { 'texts': texts, 'gender': motion_sequences['gender'], 'betas': motion_sequences['betas'][idx, start_idx:end_idx + 1], 'transl': motion_sequences['transl'][idx, start_idx:end_idx + 1], 'global_orient': motion_sequences['global_orient'][idx, start_idx:end_idx + 1], 'body_pose': motion_sequences['body_pose'][idx, start_idx:end_idx + 1], 'joints': motion_sequences['joints'][idx, start_idx:end_idx + 1], 'history_length': history_length, 'future_length': future_length, 'mocap_framerate': dataset.target_fps, } tensor_dict_to_device(sequence, 'cpu') with open(out_path / f'sample_{idx}.pkl', 'wb') as f: pickle.dump(sequence, f) # export smplx sequences for blender if optim_args.export_smpl: poses = transforms.matrix_to_axis_angle( torch.cat([sequence['global_orient'].reshape(-1, 1, 3, 3), sequence['body_pose']], dim=1) ).reshape(-1, 22 * 3) poses = torch.cat([poses, torch.zeros(poses.shape[0], 99).to(dtype=poses.dtype, device=poses.device)], dim=1) data_dict = { 'mocap_framerate': dataset.target_fps, # 30 'gender': sequence['gender'], 'betas': sequence['betas'][0, :10].detach().cpu().numpy(), 'poses': poses.detach().cpu().numpy(), 'trans': sequence['transl'].detach().cpu().numpy(), } with open(out_path / f'sample_{idx}_smplx.npz', 'wb') as f: np.savez(f, **data_dict) abs_path = out_path.absolute() print(f'[Done] Results are at [{abs_path}]') if __name__ == '__main__': optim_args = tyro.cli(OptimArgs) # TRY NOT TO MODIFY: seeding random.seed(optim_args.seed) np.random.seed(optim_args.seed) torch.manual_seed(optim_args.seed) torch.set_default_dtype(torch.float32) torch.backends.cudnn.deterministic = optim_args.torch_deterministic device = torch.device(optim_args.device if torch.cuda.is_available() else "cpu") optim_args.device = device denoiser_args, denoiser_model, vae_args, vae_model = load_mld(optim_args.denoiser_checkpoint, device) denoiser_checkpoint = Path(optim_args.denoiser_checkpoint) save_dir = denoiser_checkpoint.parent / denoiser_checkpoint.name.split('.')[0] / 'optim' save_dir.mkdir(parents=True, exist_ok=True) optim_args.save_dir = save_dir diffusion_args = denoiser_args.diffusion_args diffusion_args.respacing = optim_args.respacing print('diffusion_args:', asdict(diffusion_args)) diffusion = create_gaussian_diffusion(diffusion_args) # load initial seed dataset seq_path = Path(optim_args.optim_input) dataset = SinglePrimitiveDataset(cfg_path=vae_args.data_args.cfg_path, # cfg path from model checkpoint dataset_path=vae_args.data_args.data_dir, # dataset path from model checkpoint sequence_path=seq_path, body_type=vae_args.data_args.body_type, batch_size=optim_args.batch_size, device=device, enforce_gender='male', enforce_zero_beta=1, ) future_length = dataset.future_length history_length = dataset.history_length primitive_length = history_length + future_length primitive_utility = dataset.primitive_utility print('body type:', primitive_utility.body_type) with open(seq_path, 'rb') as f: input_sequence = pickle.load(f) seq_length = input_sequence['transl'].shape[0] num_rollout = int(np.ceil((seq_length - 1) / future_length)) if optim_args.seed_type == 'repeat' else int(np.ceil((seq_length - history_length) / future_length)) print(f'seq_length: {seq_length}, num_rollout: {num_rollout}') text_prompt = input_sequence['texts'][0] if optim_args.text_prompt is None else optim_args.text_prompt text_prompt = f"{text_prompt}*{num_rollout}" body_pose = torch.tensor(input_sequence['body_pose'], dtype=torch.float32) body_pose = transforms.axis_angle_to_matrix(body_pose.reshape(-1, 3)).reshape(-1, 21, 3, 3).unsqueeze( 0) # [1, T, 21, 3, 3] global_orient = torch.tensor(input_sequence['global_orient'], dtype=torch.float32) global_orient = transforms.axis_angle_to_matrix(global_orient.reshape(-1, 3)).reshape(-1, 3, 3).unsqueeze( 0) # [1, T, 3, 3] transl = torch.tensor(input_sequence['transl'], dtype=torch.float32).unsqueeze(0) # [1, T, 3] betas = torch.tensor(input_sequence['betas'], dtype=torch.float32) if not dataset.enforce_zero_beta else torch.zeros(10, dtype=torch.float32) betas = betas.expand(1, seq_length, 10) # [1, T, 10] # transl[:, 1:-1] = transl[:, 0] # body_pose[:, 1:-1] = body_pose[:, 0] # global_orient[:, 1:-1] = global_orient[:, 0] seq_dict = { 'gender': dataset.enforce_gender, 'betas': betas, 'transl': transl, 'body_pose': body_pose, 'global_orient': global_orient, 'transf_rotmat': torch.eye(3).unsqueeze(0), 'transf_transl': torch.zeros(1, 1, 3), } seq_dict = tensor_dict_to_device(seq_dict, device) _, _, canonicalized_primitive_dict = primitive_utility.canonicalize(seq_dict) body_model = primitive_utility.get_smpl_model(dataset.enforce_gender) joints = body_model(return_verts=False, betas=canonicalized_primitive_dict['betas'][0], body_pose=canonicalized_primitive_dict['body_pose'][0], global_orient=canonicalized_primitive_dict['global_orient'][0], transl=canonicalized_primitive_dict['transl'][0] ).joints[:, :22, :] # [T, 22, 3] canonicalized_primitive_dict['joints'] = joints.unsqueeze(0) # [1, T, 22, 3] goal_joints = joints[[-1]].expand(optim_args.batch_size, -1, -1) # [B, 22, 3] joints_mask = torch.ones(22, dtype=torch.bool, device=device) optimize(text_prompt, canonicalized_primitive_dict, goal_joints, joints_mask, denoiser_args, denoiser_model, vae_args, vae_model, diffusion, dataset, optim_args)