# eval/eval_humanml.py
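#
# Example invocation (flag names inferred from the args used below; a sketch,
# not the full CLI):
#   python -m eval.eval_humanml --model_path ./save/my_model/model000500000.pt --eval_mode wo_mm
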
import os

import numpy as np
import torch

from utils.parser_util import evaluation_parser
from utils.fixseed import fixseed
from datetime import datetime
from data_loaders.humanml.motion_loaders.model_motion_loaders import get_mdm_loader # get_motion_loader
from data_loaders.humanml.utils.metrics import *
from data_loaders.humanml.networks.evaluator_wrapper import EvaluatorMDMWrapper
from collections import OrderedDict
from data_loaders.humanml.scripts.motion_process import *
from data_loaders.humanml.utils.utils import *
from utils.model_util import create_model_and_diffusion, load_saved_model
from diffusion import logger
from utils import dist_util
from data_loaders.get_data import get_dataset_loader
from utils.sampler_util import ClassifierFreeSampleModel
from train.train_platforms import ClearmlPlatform, TensorboardPlatform, NoPlatform, WandBPlatform # required for the eval operation
torch.multiprocessing.set_sharing_strategy('file_system')


def evaluate_matching_score(eval_wrapper, motion_loaders, file):
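    """Compute the text-to-motion Matching Score and R-precision for every loader.

    Matching Score is the mean Euclidean distance between co-embedded text and
    motion features of matched pairs (the diagonal of the distance matrix);
    R-precision is the top-1/2/3 retrieval accuracy of the matched motion for
    each text. Motion embeddings are also collected for later FID/Diversity.
    """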
match_score_dict = OrderedDict({})
R_precision_dict = OrderedDict({})
activation_dict = OrderedDict({})
print('========== Evaluating Matching Score ==========')
for motion_loader_name, motion_loader in motion_loaders.items():
all_motion_embeddings = []
score_list = []
all_size = 0
matching_score_sum = 0
top_k_count = 0
# print(motion_loader_name)
with torch.no_grad():
for idx, batch in enumerate(motion_loader):
word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
word_embs=word_embeddings,
pos_ohot=pos_one_hots,
cap_lens=sent_lens,
motions=motions,
m_lens=m_lens
)
dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
motion_embeddings.cpu().numpy())
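                    # dist_mat[i, j] is the distance between text i and motion j;
                    # the diagonal holds matched pairs, so the trace sums their distances.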
matching_score_sum += dist_mat.trace()
                argsort_indices = np.argsort(dist_mat, axis=1)
                top_k_mat = calculate_top_k(argsort_indices, top_k=3)
top_k_count += top_k_mat.sum(axis=0)
all_size += text_embeddings.shape[0]
all_motion_embeddings.append(motion_embeddings.cpu().numpy())
all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
matching_score = matching_score_sum / all_size
R_precision = top_k_count / all_size
match_score_dict[motion_loader_name] = matching_score
R_precision_dict[motion_loader_name] = R_precision
activation_dict[motion_loader_name] = all_motion_embeddings
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
line = f'---> [{motion_loader_name}] R_precision: '
for i in range(len(R_precision)):
line += '(top %d): %.4f ' % (i+1, R_precision[i])
print(line)
print(line, file=file, flush=True)
return match_score_dict, R_precision_dict, activation_dict


def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file):
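    """Compute FID against the ground-truth loader for each model's embeddings.

    Fits a Gaussian (mean, covariance) to each set of motion embeddings and
    returns the Frechet distance to the ground-truth statistics.
    """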
eval_dict = OrderedDict({})
gt_motion_embeddings = []
print('========== Evaluating FID ==========')
with torch.no_grad():
for idx, batch in enumerate(groundtruth_loader):
_, _, _, sent_lens, motions, m_lens, _ = batch
motion_embeddings = eval_wrapper.get_motion_embeddings(
motions=motions,
m_lens=m_lens
)
gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
# print(gt_mu)
for model_name, motion_embeddings in activation_dict.items():
mu, cov = calculate_activation_statistics(motion_embeddings)
# print(mu)
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
print(f'---> [{model_name}] FID: {fid:.4f}')
print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
eval_dict[model_name] = fid
return eval_dict


def evaluate_diversity(activation_dict, file, diversity_times):
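    """Compute Diversity: the mean distance between `diversity_times` randomly
    sampled pairs of motion embeddings, per model."""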
eval_dict = OrderedDict({})
print('========== Evaluating Diversity ==========')
for model_name, motion_embeddings in activation_dict.items():
diversity = calculate_diversity(motion_embeddings, diversity_times)
eval_dict[model_name] = diversity
print(f'---> [{model_name}] Diversity: {diversity:.4f}')
print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
return eval_dict


def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
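    """Compute MultiModality: embedding variation across repeated generations
    from the same text, averaged over `mm_num_times` sampled pairs per prompt."""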
eval_dict = OrderedDict({})
print('========== Evaluating MultiModality ==========')
for model_name, mm_motion_loader in mm_motion_loaders.items():
mm_motion_embeddings = []
with torch.no_grad():
for idx, batch in enumerate(mm_motion_loader):
# (1, mm_replications, dim_pos)
motions, m_lens = batch
                motion_embeddings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
                mm_motion_embeddings.append(motion_embeddings.unsqueeze(0))
if len(mm_motion_embeddings) == 0:
multimodality = 0
else:
mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
eval_dict[model_name] = multimodality
return eval_dict


def get_metric_statistics(values, replication_times):
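    """Return the per-metric mean and 95% confidence interval (z = 1.96) over replications."""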
mean = np.mean(values, axis=0)
std = np.std(values, axis=0)
conf_interval = 1.96 * std / np.sqrt(replication_times)
return mean, conf_interval


def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times,
diversity_times, mm_num_times, run_mm=False, eval_platform=None):
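    """Run the full protocol for `replication_times` replications.

    Each replication re-samples motions from every model loader and evaluates
    Matching Score, R-precision, FID, Diversity and (optionally) MultiModality;
    results are aggregated as mean and confidence interval, written to
    `log_file`, and reported to `eval_platform` when one is given.
    """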
with open(log_file, 'w') as f:
all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
'R_precision': OrderedDict({}),
'FID': OrderedDict({}),
'Diversity': OrderedDict({}),
'MultiModality': OrderedDict({})})
for replication in range(replication_times):
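            # Re-create the model loaders each replication so motions are freshly sampled.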
motion_loaders = {}
mm_motion_loaders = {}
motion_loaders['ground truth'] = gt_loader
for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
motion_loader, mm_motion_loader = motion_loader_getter()
motion_loaders[motion_loader_name] = motion_loader
mm_motion_loaders[motion_loader_name] = mm_motion_loader
print(f'==================== Replication {replication} ====================')
print(f'==================== Replication {replication} ====================', file=f, flush=True)
print(f'Time: {datetime.now()}')
print(f'Time: {datetime.now()}', file=f, flush=True)
mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
print(f'Time: {datetime.now()}')
print(f'Time: {datetime.now()}', file=f, flush=True)
fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
print(f'Time: {datetime.now()}')
print(f'Time: {datetime.now()}', file=f, flush=True)
div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
if run_mm:
print(f'Time: {datetime.now()}')
print(f'Time: {datetime.now()}', file=f, flush=True)
mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
            print('!!! DONE !!!')
            print('!!! DONE !!!', file=f, flush=True)
for key, item in mat_score_dict.items():
if key not in all_metrics['Matching Score']:
all_metrics['Matching Score'][key] = [item]
else:
all_metrics['Matching Score'][key] += [item]
for key, item in R_precision_dict.items():
if key not in all_metrics['R_precision']:
all_metrics['R_precision'][key] = [item]
else:
all_metrics['R_precision'][key] += [item]
for key, item in fid_score_dict.items():
if key not in all_metrics['FID']:
all_metrics['FID'][key] = [item]
else:
all_metrics['FID'][key] += [item]
for key, item in div_score_dict.items():
if key not in all_metrics['Diversity']:
all_metrics['Diversity'][key] = [item]
else:
all_metrics['Diversity'][key] += [item]
if run_mm:
for key, item in mm_score_dict.items():
if key not in all_metrics['MultiModality']:
all_metrics['MultiModality'][key] = [item]
else:
all_metrics['MultiModality'][key] += [item]
# print(all_metrics['Diversity'])
mean_dict = {}
for metric_name, metric_dict in all_metrics.items():
print('========== %s Summary ==========' % metric_name)
print('========== %s Summary ==========' % metric_name, file=f, flush=True)
for model_name, values in metric_dict.items():
# print(metric_name, model_name)
mean, conf_interval = get_metric_statistics(np.array(values), replication_times)
mean_dict[metric_name + '_' + model_name] = mean
# print(mean, mean.dtype)
                if isinstance(mean, (np.float32, np.float64)):
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
elif isinstance(mean, np.ndarray):
line = f'---> [{model_name}]'
for i in range(len(mean)):
line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
print(line)
print(line, file=f, flush=True)
# log results
if eval_platform is not None:
for k, v in mean_dict.items():
if k.startswith('R_precision'):
for i in range(len(v)):
eval_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i],
iteration=1, group_name='Eval')
else:
eval_platform.report_scalar(name=k, value=v, iteration=1, group_name='Eval')
return mean_dict


if __name__ == '__main__':
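    # Entry point: parse args, build loaders and the model, then run the protocol above.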
args = evaluation_parser()
fixseed(args.seed)
    args.batch_size = 32  # This must be 32! Changing it will break the R-precision calculation.
name = os.path.basename(os.path.dirname(args.model_path))
niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
log_name = 'eval_humanml_{}_{}'.format(name, niter)
if args.guidance_param != 1.:
log_name += f'_gscale{args.guidance_param}'
log_name += f'_{args.eval_mode}'
log_file = os.path.join(os.path.dirname(args.model_path), log_name + '.log')
save_dir = os.path.dirname(log_file) # has not been tested with WandB
print(f'Will save to log file [{log_file}]')
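    # args.train_platform_type names one of the platform classes imported above
    # (ClearmlPlatform / TensorboardPlatform / NoPlatform / WandBPlatform);
    # eval() resolves that string to the class.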
eval_platform_type = eval(args.train_platform_type)
eval_platform = eval_platform_type(save_dir, name=log_name)
eval_platform.report_args(args, name='Args')
print(f'Eval mode [{args.eval_mode}]')
if args.eval_mode == 'debug':
        num_samples_limit = 1000  # None means no limit (evaluate over the full test set)
run_mm = False
mm_num_samples = 0
mm_num_repeats = 0
mm_num_times = 0
diversity_times = 300
replication_times = 5 # about 3 Hrs
elif args.eval_mode == 'wo_mm':
num_samples_limit = 1000
run_mm = False
mm_num_samples = 0
mm_num_repeats = 0
mm_num_times = 0
diversity_times = 300
replication_times = 20 # about 12 Hrs
elif args.eval_mode == 'mm_short':
num_samples_limit = 1000
run_mm = True
mm_num_samples = 100
mm_num_repeats = 30
mm_num_times = 10
diversity_times = 300
replication_times = 5 # about 15 Hrs
else:
        raise ValueError(f'Unsupported eval_mode [{args.eval_mode}]')
dist_util.setup_dist(args.device)
logger.configure()
logger.log("creating data loader...")
split = 'test'
gt_loader = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=None, split=split, hml_mode='gt')
# gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=None, split=split, hml_mode='eval')
    # Generation loader with added features and prefix-completion support (fixed context + prediction lengths):
gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=None, split=split, hml_mode='eval',
fixed_len=args.context_len+args.pred_len, pred_len=args.pred_len, device=dist_util.dev(),
autoregressive=args.autoregressive)
num_actions = gen_loader.dataset.num_actions
logger.log("Creating model and diffusion...")
model, diffusion = create_model_and_diffusion(args, gen_loader)
logger.log(f"Loading checkpoints from [{args.model_path}]...")
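    # use_avg selects the EMA (averaged) weights when --use_ema is set.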
load_saved_model(model, args.model_path, use_avg=args.use_ema)
if args.guidance_param != 1:
model = ClassifierFreeSampleModel(model) # wrapping model with the classifier-free sampler
model.to(dist_util.dev())
model.eval() # disable random masking
eval_motion_loaders = {
        ################
        ## HumanML3D Dataset ##
        ################
'vald': lambda: get_mdm_loader(args,
model=model, diffusion=diffusion, batch_size=args.batch_size,
ground_truth_loader=gen_loader, mm_num_samples=mm_num_samples, mm_num_repeats=mm_num_repeats,
max_motion_length=gt_loader.dataset.opt.max_motion_length, num_samples_limit=num_samples_limit,
scale=args.guidance_param
)
}
eval_wrapper = EvaluatorMDMWrapper(args.dataset, dist_util.dev())
evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times,
diversity_times, mm_num_times, run_mm=run_mm, eval_platform=eval_platform)
eval_platform.close()