megalado
Add local model code; tidy requirements
f87d582
# This code is based on https://github.com/openai/guided-diffusion
"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""
from utils.fixseed import fixseed
import os
import numpy as np
import torch
from utils.parser_util import generate_args
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils import dist_util
from utils.sampler_util import ClassifierFreeSampleModel, AutoRegressiveSampler
from data_loaders.get_data import get_dataset_loader
from data_loaders.humanml.scripts.motion_process import recover_from_ric, get_target_location, sample_goal
import data_loaders.humanml.utils.paramUtil as paramUtil
from data_loaders.humanml.utils.plot_script import plot_3d_motion
import shutil
from data_loaders.tensors import collate
from moviepy.editor import clips_array
def main(args=None):
    """Sample motions from a trained diffusion model and save arrays, text files and videos.

    The function resolves the output directory, collects the conditioning
    (text prompts, per-prediction "dynamic" text, or action names), loads the
    test dataset, builds the model and loads the checkpoint, runs
    ``args.num_repetitions`` sampling passes, writes ``results.npy`` together
    with ``results.txt`` / ``results_len.txt``, and finally renders one
    animation per (sample, repetition) plus the combined grid videos.

    Args:
        args: Generation options. When ``None`` they are obtained from
            ``generate_args()``. ``args.num_samples`` and ``args.batch_size``
            are overwritten in place according to the chosen conditioning.

    Returns:
        The output directory path. Any directory already existing at that
        path is deleted before the results are written.

    Raises:
        AssertionError: If a referenced text/action file does not exist, or if
            dynamic text is requested without autoregressive sampling.
        NotImplementedError: If dynamic text is used with a text encoder other
            than ``'bert'``.
    """
    if args is None:
        # args is None unless this method is called from another function (e.g. during training)
        args = generate_args()
    fixseed(args.seed)
    out_path = args.output_dir
    name = os.path.basename(os.path.dirname(args.model_path))
    niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
    max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
    fps = 12.5 if args.dataset == 'kit' else 20
    n_frames = min(max_frames, int(args.motion_length * fps))
    is_using_data = not any([args.input_text, args.text_prompt, args.action_file, args.action_name])
    if args.context_len > 0:
        is_using_data = True  # For prefix completion, we need to sample a prefix
    dist_util.setup_dist(args.device)
    if out_path == '':
        # Derive a default output directory next to the checkpoint.
        out_path = os.path.join(os.path.dirname(args.model_path),
                                'samples_{}_{}_seed{}'.format(name, niter, args.seed))
        if args.text_prompt != '':
            out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
        elif args.input_text != '':
            out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')
        elif args.dynamic_text_path != '':
            out_path += '_' + os.path.basename(args.dynamic_text_path).replace('.txt', '').replace(' ', '_').replace('.', '')

    # this block must be called BEFORE the dataset is loaded
    # (it fixes args.num_samples, and therefore args.batch_size, used by the loader)
    texts = None
    if args.text_prompt != '':
        texts = [args.text_prompt] * args.num_samples
    elif args.input_text != '':
        assert os.path.exists(args.input_text)
        with open(args.input_text, 'r') as fr:
            texts = fr.readlines()
        texts = [s.replace('\n', '') for s in texts]
        args.num_samples = len(texts)
    elif args.dynamic_text_path != '':
        assert os.path.exists(args.dynamic_text_path)
        assert args.autoregressive, "Dynamic text sampling is only supported with autoregressive sampling."
        with open(args.dynamic_text_path, 'r') as fr:
            texts = fr.readlines()
        texts = [s.replace('\n', '') for s in texts]
        n_frames = len(texts) * args.pred_len  # each text prompt is for a single prediction
    elif args.action_name:
        action_text = [args.action_name]
        args.num_samples = 1
    elif args.action_file != '':
        assert os.path.exists(args.action_file)
        with open(args.action_file, 'r') as fr:
            action_text = fr.readlines()
        action_text = [s.replace('\n', '') for s in action_text]
        args.num_samples = len(action_text)

    args.batch_size = args.num_samples  # Sampling a single batch from the testset, with exactly args.num_samples

    print('Loading dataset...')
    data = load_dataset(args, max_frames, n_frames)
    total_num_samples = args.num_samples * args.num_repetitions

    print("Creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, data)

    sample_fn = diffusion.p_sample_loop
    if args.autoregressive:
        sample_cls = AutoRegressiveSampler(args, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    print(f"Loading checkpoints from [{args.model_path}]...")
    load_saved_model(model, args.model_path, use_avg=args.use_ema)

    if args.guidance_param != 1:
        model = ClassifierFreeSampleModel(model)  # wrapping model with the classifier-free sampler
    model.to(dist_util.dev())
    model.eval()  # disable random masking

    motion_shape = (args.batch_size, model.njoints, model.nfeats, n_frames)

    if is_using_data:
        # Take the conditioning from one dataset batch; the motion itself is not used here.
        _, model_kwargs = next(iter(data))
        if texts is not None:
            model_kwargs['y']['text'] = texts
    else:
        collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames}] * args.num_samples
        is_t2m = any([args.input_text, args.text_prompt])
        if is_t2m:
            # t2m
            collate_args = [dict(arg, text=txt) for arg, txt in zip(collate_args, texts)]
        else:
            # a2m
            action = data.dataset.action_name_to_action(action_text)
            collate_args = [dict(arg, action=one_action, action_text=one_action_text) for
                            arg, one_action, one_action_text in zip(collate_args, action, action_text)]
        _, model_kwargs = collate(collate_args)

    # Move every tensor-valued conditioning entry to the sampling device.
    model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val
                         for key, val in model_kwargs['y'].items()}
    init_image = None

    all_motions = []
    all_lengths = []
    all_text = []

    # add CFG scale to batch
    if args.guidance_param != 1:
        model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param

    if 'text' in model_kwargs['y']:
        # encoding once instead of each iteration saves lots of time
        model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text'])

    if args.dynamic_text_path != '':
        # Rearrange the text to match the autoregressive sampling - each prompt fits to a single prediction
        # Which is 2 seconds of motion by default
        model_kwargs['y']['text'] = [model_kwargs['y']['text']] * args.num_samples
        if args.text_encoder_type == 'bert':
            model_kwargs['y']['text_embed'] = (
                model_kwargs['y']['text_embed'][0].unsqueeze(0).repeat(args.num_samples, 1, 1, 1),
                model_kwargs['y']['text_embed'][1].unsqueeze(0).repeat(args.num_samples, 1, 1),
            )
        else:
            raise NotImplementedError('DiP model only supports BERT text encoder at the moment. If you implement this, please send a PR!')

    for rep_i in range(args.num_repetitions):
        print(f'### Sampling [repetitions #{rep_i}]')

        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
            init_image=init_image,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

        # Recover XYZ *positions* from HumanML3D vector representation
        if model.data_rep == 'hml_vec':
            n_joints = 22 if sample.shape[1] == 263 else 21
            sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
            sample = recover_from_ric(sample, n_joints)
            sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

        rot2xyz_pose_rep = 'xyz' if model.data_rep in ['xyz', 'hml_vec'] else model.data_rep
        rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool()
        sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
                               jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
                               get_rotations_back=False)

        if args.unconstrained:
            all_text += ['unconstrained'] * args.num_samples
        else:
            text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text'
            all_text += model_kwargs['y'][text_key]

        all_motions.append(sample.cpu().numpy())
        _len = model_kwargs['y']['lengths'].cpu().numpy()
        if 'prefix' in model_kwargs['y']:
            # With a prefix in the conditioning, every sample is reported at the full generated length.
            _len[:] = sample.shape[-1]
        all_lengths.append(_len)

        print(f"created {len(all_motions) * args.batch_size} samples")

    all_motions = np.concatenate(all_motions, axis=0)
    all_motions = all_motions[:total_num_samples]  # original note: [bs, njoints, 6, seqlen]
    all_text = all_text[:total_num_samples]
    all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples]

    # Start from a clean output directory.
    if os.path.exists(out_path):
        shutil.rmtree(out_path)
    os.makedirs(out_path)

    npy_path = os.path.join(out_path, 'results.npy')
    print(f"saving results file to [{npy_path}]")
    np.save(npy_path,
            {'motion': all_motions, 'text': all_text, 'lengths': all_lengths,
             'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
    if args.dynamic_text_path != '':
        # Each entry is a list of prompts; join them with '#' on one line per sample.
        text_file_content = '\n'.join(['#'.join(s) for s in all_text])
    else:
        text_file_content = '\n'.join(all_text)
    with open(npy_path.replace('.npy', '.txt'), 'w') as fw:
        fw.write(text_file_content)
    with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw:
        fw.write('\n'.join([str(l) for l in all_lengths]))

    print(f"saving visualizations to [{out_path}]...")
    skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain

    # Only the per-sample and the combined file-name templates are needed here.
    _, _, _, sample_file_template, _, all_file_template = construct_template_variables(args.unconstrained)
    animations = np.empty(shape=(args.num_samples, args.num_repetitions), dtype=object)
    max_length = max(all_lengths)

    for sample_i in range(args.num_samples):
        for rep_i in range(args.num_repetitions):
            flat_idx = rep_i * args.batch_size + sample_i
            caption = all_text[flat_idx]
            if args.dynamic_text_path != '':  # caption per frame
                assert isinstance(caption, list)
                caption_per_frame = []
                for c in caption:
                    caption_per_frame += [c] * args.pred_len
                caption = caption_per_frame

            # Trim / freeze motion if needed
            length = all_lengths[flat_idx]
            motion = all_motions[flat_idx].transpose(2, 0, 1)[:max_length]
            if motion.shape[0] > length:
                # Duplicate the last valid frame through the END of the clip (inclusive), so all
                # motions have equal length. Slicing with `[length:-1]` would leave the final
                # frame un-frozen.
                motion[length:] = motion[length - 1]

            save_file = sample_file_template.format(sample_i, rep_i)
            animation_save_path = os.path.join(out_path, save_file)
            gt_frames = np.arange(args.context_len) if args.context_len > 0 and not args.autoregressive else []
            animations[sample_i, rep_i] = plot_3d_motion(animation_save_path,
                                                         skeleton, motion, dataset=args.dataset, title=caption,
                                                         fps=fps, gt_frames=gt_frames)

    save_multiple_samples(out_path, {'all': all_file_template}, animations, fps, max(list(all_lengths) + [n_frames]))

    abs_path = os.path.abspath(out_path)
    print(f'[Done] Results are at [{abs_path}]')

    return out_path
def save_multiple_samples(out_path, file_templates, animations, fps, max_frames, no_dir=False):
    """Write combined videos, stacking up to three rows of animations per output file.

    Args:
        out_path: Output directory. When ``no_dir`` is true and all rows fit in
            a single file, this value is used as the video file path itself.
        file_templates: Mapping whose ``'all'`` entry is a format string that
            receives the first and last row index contained in the file.
        animations: Array of clips; it is sliced along its first axis and each
            slice is handed to ``clips_array`` (presumably a 2-D object array
            of shape ``(samples, repetitions)`` - confirm against the caller).
        fps: Frames per second of the written video.
        max_frames: Frame count used to set the clip duration
            (``max_frames / fps`` seconds).
        no_dir: See ``out_path``.
    """
    num_samples_in_out_file = 3
    n_samples = animations.shape[0]

    for sample_i in range(0, n_samples, num_samples_in_out_file):
        last_sample_i = min(sample_i + num_samples_in_out_file, n_samples)
        all_sample_save_file = file_templates['all'].format(sample_i, last_sample_i - 1)
        if no_dir and n_samples <= num_samples_in_out_file:
            all_sample_save_path = out_path
        else:
            all_sample_save_path = os.path.join(out_path, all_sample_save_file)
            print(f'saving {os.path.split(out_path)[1]}/{all_sample_save_file}')

        clips = clips_array(animations[sample_i:last_sample_i])
        try:
            clips.duration = max_frames / fps
            clips.write_videofile(all_sample_save_path, fps=fps, threads=4, logger=None)
        finally:
            # Release the clips even when writing fails, so a failed export does not leak them.
            for clip in clips.clips:
                # close internal clips. Does nothing but better use in case one day it will do something
                clip.close()
            clips.close()  # important
def construct_template_variables(unconstrained):
    """Return the print and file-name format templates used when saving samples.

    Args:
        unconstrained: Truthy to use the "row/column" naming, falsy to use
            the "sample/repetition" naming.

    Returns:
        A 6-tuple ``(sample_print, row_print, all_print, sample_file,
        row_file, all_file)`` of ``str.format`` templates.
    """
    base_row_file = 'sample{:02d}.mp4'
    base_all_file = 'samples_{:02d}_to_{:02d}.mp4'

    if not unconstrained:
        # Conditioned sampling: files are named by sample and repetition.
        return (
            '["{}" ({:02d}) | Rep #{:02d} | -> {}]',
            '[ "{}" ({:02d}) | all repetitions | -> {}]',
            '[samples {:02d} to {:02d} | all repetitions | -> {}]',
            'sample{:02d}_rep{:02d}.mp4',
            base_row_file,
            base_all_file,
        )

    # Unconstrained sampling: the same grid is described as rows and columns.
    return (
        '[{} row #{:02d} column #{:02d} | -> {}]',
        '[{} row #{:02d} | all columns | -> {}]',
        '[rows {:02d} to {:02d} | -> {}]',
        'row{:02d}_col{:02d}.mp4',
        base_row_file.replace('sample', 'row'),
        base_all_file.replace('samples', 'rows'),
    )
def load_dataset(args, max_frames, n_frames):
    """Build the test-split data loader used for sampling.

    Args:
        args: Options object; reads ``dataset``, ``batch_size``, ``pred_len``
            and ``context_len``.
        max_frames: Passed to the loader as ``num_frames``.
        n_frames: Stored on the returned loader as ``fixed_length``.

    Returns:
        The loader returned by ``get_dataset_loader`` with ``fixed_length`` set.
    """
    # We need to sample a prefix from the dataset
    if args.pred_len > 0:
        hml_mode = 'train'
    else:
        hml_mode = 'text_only'

    loader = get_dataset_loader(
        name=args.dataset,
        batch_size=args.batch_size,
        num_frames=max_frames,
        split='test',
        hml_mode=hml_mode,
        fixed_len=args.pred_len + args.context_len,
        pred_len=args.pred_len,
        device=dist_util.dev(),
    )
    loader.fixed_length = n_frames
    return loader
def is_substr_in_list(substr, list_of_strs):
    """Return an element-wise boolean mask that is True where `substr` occurs.

    Args:
        substr: Substring to look for.
        list_of_strs: Strings to search; handed to ``np.char.find`` as-is.

    Returns:
        Boolean NumPy array, one entry per string in `list_of_strs`.
    """
    first_match_index = np.char.find(list_of_strs, substr)
    # np.char.find yields -1 where the substring is absent, so this is
    # equivalent to [substr in string for string in list_of_strs].
    found_mask = first_match_index != -1
    return found_mask
# Script entry point: main() is called without arguments, so it obtains them via generate_args().
if __name__ == "__main__":
    main()