# This code is based on https://github.com/openai/guided-diffusion
"""
Generate a large batch of motion samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
from utils.fixseed import fixseed
import os
import numpy as np
import torch
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils import dist_util
from utils.sampler_util import ClassifierFreeSampleModel, AutoRegressiveSampler
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric, get_target_location, sample_goal
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.humanml.utils.plot_script import plot_3d_motion
import shutil
from data_loaders.tensors import collate
from moviepy.editor import clips_array
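# Example CLI invocation (a sketch - the full flag set is defined by
# utils.parser_util.generate_args, and the checkpoint path here is hypothetical):
#   python generate.py --model_path ./save/my_model/model000200000.pt \
#       --text_prompt "a person walks forward" --num_samples 3 --num_repetitions 2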


def main(args=None):
    if args is None:
        # args is None unless this method is called from another function (e.g. during training)
        args = generate_args()
    fixseed(args.seed)
    out_path = args.output_dir
    n_joints = 22 if args.dataset == 'humanml' else 21
    name = os.path.basename(os.path.dirname(args.model_path))
    niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
    max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
    fps = 12.5 if args.dataset == 'kit' else 20
    n_frames = min(max_frames, int(args.motion_length * fps))
    is_using_data = not any([args.input_text, args.text_prompt, args.action_file, args.action_name])
    if args.context_len > 0:
        is_using_data = True  # For prefix completion, we need to sample a prefix from the dataset
    dist_util.setup_dist(args.device)
    if out_path == '':
        out_path = os.path.join(os.path.dirname(args.model_path),
                                'samples_{}_{}_seed{}'.format(name, niter, args.seed))
        if args.text_prompt != '':
            out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
        elif args.input_text != '':
            out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')
        elif args.dynamic_text_path != '':
            out_path += '_' + os.path.basename(args.dynamic_text_path).replace('.txt', '').replace(' ', '_').replace('.', '')

    # This block must be called BEFORE the dataset is loaded
    texts = None
    if args.text_prompt != '':
        texts = [args.text_prompt] * args.num_samples
    elif args.input_text != '':
        assert os.path.exists(args.input_text)
        with open(args.input_text, 'r') as fr:
            texts = fr.readlines()
        texts = [s.replace('\n', '') for s in texts]
        args.num_samples = len(texts)
    elif args.dynamic_text_path != '':
        assert os.path.exists(args.dynamic_text_path)
        assert args.autoregressive, "Dynamic text sampling is only supported with autoregressive sampling."
        with open(args.dynamic_text_path, 'r') as fr:
            texts = fr.readlines()
        texts = [s.replace('\n', '') for s in texts]
        n_frames = len(texts) * args.pred_len  # each text prompt conditions a single prediction
    elif args.action_name:
        action_text = [args.action_name]
        args.num_samples = 1
    elif args.action_file != '':
        assert os.path.exists(args.action_file)
        with open(args.action_file, 'r') as fr:
            action_text = fr.readlines()
        action_text = [s.replace('\n', '') for s in action_text]
        args.num_samples = len(action_text)

    args.batch_size = args.num_samples  # Sampling a single batch from the testset, with exactly args.num_samples

    print('Loading dataset...')
    data = load_dataset(args, max_frames, n_frames)
    total_num_samples = args.num_samples * args.num_repetitions

    print("Creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, data)

    sample_fn = diffusion.p_sample_loop
    if args.autoregressive:
        sample_cls = AutoRegressiveSampler(args, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    print(f"Loading checkpoints from [{args.model_path}]...")
    load_saved_model(model, args.model_path, use_avg=args.use_ema)

    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)  # wrapping model with the classifier-free sampler
    model.to(dist_util.dev())
    model.eval()  # disable random masking

    motion_shape = (args.batch_size, model.njoints, model.nfeats, n_frames)

    if is_using_data:
        iterator = iter(data)
        input_motion, model_kwargs = next(iterator)
        input_motion = input_motion.to(dist_util.dev())
        if texts is not None:
            model_kwargs['y']['text'] = texts
    else:
        collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames}] * args.num_samples
        is_t2m = any([args.input_text, args.text_prompt])
        if is_t2m:
            # t2m
            collate_args = [dict(arg, text=txt) for arg, txt in zip(collate_args, texts)]
        else:
            # a2m
            action = data.dataset.action_name_to_action(action_text)
            collate_args = [dict(arg, action=one_action, action_text=one_action_text)
                            for arg, one_action, one_action_text in zip(collate_args, action, action_text)]
        _, model_kwargs = collate(collate_args)
        model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val
                             for key, val in model_kwargs['y'].items()}

    init_image = None

    all_motions = []
    all_lengths = []
    all_text = []

    # add CFG scale to batch
    if args.guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param

    if 'text' in model_kwargs['y'].keys():
        # encoding once instead of each iteration saves lots of time
        model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text'])
        if args.dynamic_text_path != '':
            # Rearrange the text to match the autoregressive sampling - each prompt fits a single prediction,
            # which is 2 seconds of motion by default
            model_kwargs['y']['text'] = [model_kwargs['y']['text']] * args.num_samples
            if args.text_encoder_type == 'bert':
                model_kwargs['y']['text_embed'] = (
                    model_kwargs['y']['text_embed'][0].unsqueeze(0).repeat(args.num_samples, 1, 1, 1),
                    model_kwargs['y']['text_embed'][1].unsqueeze(0).repeat(args.num_samples, 1, 1))
            else:
                raise NotImplementedError('DiP model only supports BERT text encoder at the moment. '
                                          'If you implement this, please send a PR!')
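    # A note on the guidance scale attached above: ClassifierFreeSampleModel is expected to
    # blend conditional and unconditional predictions with the standard CFG rule (a sketch):
    #   out = out_uncond + scale * (out_cond - out_uncond)
    # so scale == 1 reproduces the purely conditional model, which is why the wrapper and
    # the 'scale' entry are only added when args.guidance_param != 1.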

    for rep_i in range(args.num_repetitions):
        print(f'### Sampling [repetition #{rep_i}]')

        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any diffusion step
            init_image=init_image,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Recover XYZ *positions* from HumanML3D vector representation
        if model.data_rep == 'hml_vec':
            n_joints = 22 if sample.shape[1] == 263 else 21
            sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
            sample = recover_from_ric(sample, n_joints)
            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        rot2xyz_pose_rep = 'xyz' if model.data_rep in ['xyz', 'hml_vec'] else model.data_rep
        rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' \
            else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool()
        sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
                               jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                               get_rotations_back=False)

        if args.unconstrained:
            all_text += ['unconstrained'] * args.num_samples
        else:
            text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text'
            all_text += model_kwargs['y'][text_key]

        all_motions.append(sample.cpu().numpy())
        _len = model_kwargs['y']['lengths'].cpu().numpy()
        if 'prefix' in model_kwargs['y'].keys():
            _len[:] = sample.shape[-1]
        all_lengths.append(_len)

        print(f"created {len(all_motions) * args.batch_size} samples")

    all_motions = np.concatenate(all_motions, axis=0)
    all_motions = all_motions[:total_num_samples]  # [bs, njoints, 6, seqlen]
    all_text = all_text[:total_num_samples]
    all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples]

    if os.path.exists(out_path):
        shutil.rmtree(out_path)
    os.makedirs(out_path)

    npy_path = os.path.join(out_path, 'results.npy')
    print(f"saving results file to [{npy_path}]")
    np.save(npy_path,
            {'motion': all_motions, 'text': all_text, 'lengths': all_lengths,
             'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})

    if args.dynamic_text_path != '':
        text_file_content = '\n'.join(['#'.join(s) for s in all_text])
    else:
        text_file_content = '\n'.join(all_text)
    with open(npy_path.replace('.npy', '.txt'), 'w') as fw:
        fw.write(text_file_content)
    with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw:
        fw.write('\n'.join([str(l) for l in all_lengths]))
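
    # The saved results can be loaded back for downstream evaluation - a sketch, assuming
    # the dictionary format written above:
    #   results = np.load(npy_path, allow_pickle=True).item()
    #   motions, lengths = results['motion'], results['lengths']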

    print(f"saving visualizations to [{out_path}]...")
    skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain

    sample_print_template, row_print_template, all_print_template, \
        sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained)

    max_vis_samples = 6
    num_vis_samples = min(args.num_samples, max_vis_samples)
    animations = np.empty(shape=(args.num_samples, args.num_repetitions), dtype=object)
    max_length = max(all_lengths)

    for sample_i in range(args.num_samples):
        rep_files = []
        for rep_i in range(args.num_repetitions):
            caption = all_text[rep_i * args.batch_size + sample_i]
            if args.dynamic_text_path != '':
                # caption per frame
                assert type(caption) == list
                caption_per_frame = []
                for c in caption:
                    caption_per_frame += [c] * args.pred_len
                caption = caption_per_frame

            # Trim / freeze motion if needed
            length = all_lengths[rep_i * args.batch_size + sample_i]
            motion = all_motions[rep_i * args.batch_size + sample_i].transpose(2, 0, 1)[:max_length]
            if motion.shape[0] > length:
                # duplicate the last valid frame to the end of the motion, so all motions have equal length
                motion[length:] = motion[length - 1]

            save_file = sample_file_template.format(sample_i, rep_i)
            animation_save_path = os.path.join(out_path, save_file)
            gt_frames = np.arange(args.context_len) if args.context_len > 0 and not args.autoregressive else []
            animations[sample_i, rep_i] = plot_3d_motion(animation_save_path, skeleton, motion,
                                                         dataset=args.dataset, title=caption, fps=fps,
                                                         gt_frames=gt_frames)
            rep_files.append(animation_save_path)

    save_multiple_samples(out_path, {'all': all_file_template}, animations, fps, max(list(all_lengths) + [n_frames]))

    abs_path = os.path.abspath(out_path)
    print(f'[Done] Results are at [{abs_path}]')

    return out_path


def save_multiple_samples(out_path, file_templates, animations, fps, max_frames, no_dir=False):
    num_samples_in_out_file = 3
    n_samples = animations.shape[0]

    for sample_i in range(0, n_samples, num_samples_in_out_file):
        last_sample_i = min(sample_i + num_samples_in_out_file, n_samples)
        all_sample_save_file = file_templates['all'].format(sample_i, last_sample_i - 1)
        if no_dir and n_samples <= num_samples_in_out_file:
            all_sample_save_path = out_path
        else:
            all_sample_save_path = os.path.join(out_path, all_sample_save_file)
            print(f'saving {os.path.split(out_path)[1]}/{all_sample_save_file}')

        clips = clips_array(animations[sample_i:last_sample_i])
        clips.duration = max_frames / fps

        clips.write_videofile(all_sample_save_path, fps=fps, threads=4, logger=None)

        for clip in clips.clips:
            # close the internal clips; currently a no-op, but kept in case moviepy implements it one day
            clip.close()
        clips.close()  # important


def construct_template_variables(unconstrained):
    row_file_template = 'sample{:02d}.mp4'
    all_file_template = 'samples_{:02d}_to_{:02d}.mp4'
    if unconstrained:
        sample_file_template = 'row{:02d}_col{:02d}.mp4'
        sample_print_template = '[{} row #{:02d} column #{:02d} | -> {}]'
        row_file_template = row_file_template.replace('sample', 'row')
        row_print_template = '[{} row #{:02d} | all columns | -> {}]'
        all_file_template = all_file_template.replace('samples', 'rows')
        all_print_template = '[rows {:02d} to {:02d} | -> {}]'
    else:
        sample_file_template = 'sample{:02d}_rep{:02d}.mp4'
        sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
        row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
        all_print_template = '[samples {:02d} to {:02d} | all repetitions | -> {}]'

    return sample_print_template, row_print_template, all_print_template, \
        sample_file_template, row_file_template, all_file_template


def load_dataset(args, max_frames, n_frames):
    data = get_dataset_loader(name=args.dataset,
                              batch_size=args.batch_size,
                              num_frames=max_frames,
                              split='test',
                              hml_mode='train' if args.pred_len > 0 else 'text_only',  # 'train' mode is needed to sample a prefix from the dataset
                              fixed_len=args.pred_len + args.context_len,
                              pred_len=args.pred_len,
                              device=dist_util.dev())
    data.fixed_length = n_frames
    return data


def is_substr_in_list(substr, list_of_strs):
    return np.char.find(list_of_strs, substr) != -1  # [substr in string for string in list_of_strs]


if __name__ == "__main__":
    main()
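
# main() can also be called from another module (e.g. during training) with a pre-built
# args namespace instead of parsing the CLI - a sketch, assuming generate_args() supplies
# the namespace:
#   from utils.parser_util import generate_args
#   args = generate_args()
#   args.num_repetitions = 1
#   out_dir = main(args)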