move-it / eval_gt.py
zirobtc's picture
Upload folder using huggingface_hub
d2a17a9 verified
import os
import torch
from utils import eval_trans
from humanml3d_272 import dataset_eval_tae
from options import option_transformer as option_trans
import warnings
import sys
# Suppress every warning category for the remainder of the process.
warnings.filterwarnings('ignore')
# Switch the working directory to the evaluator folder. Relative paths used
# after this point (e.g. the 'epoch=99.ckpt' checkpoint) are resolved against
# it. Raises if 'Evaluator_272' is not present under the launch directory.
os.chdir('Evaluator_272')
# Put the new working directory first on the import path. The `mld` imports
# further down execute after this line - presumably the package lives in this
# folder; verify against the repository layout.
sys.path.insert(0, os.getcwd())
# The compute device is hard-coded to CUDA; both encoders are moved onto it
# and it is handed to the evaluation routine.
comp_device = torch.device('cuda')
# Fetch the run options; this script only reads `seed` and `dataname` from them.
args = option_trans.get_args_parser()
# Seed PyTorch's random number generator with the configured seed.
torch.manual_seed(args.seed)
# NOTE(review): the positional `True` and `32` are not explained here -
# presumably an evaluation/test-split flag and the batch size; confirm against
# dataset_eval_tae.DATALoader.
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)
# load evaluator:--------------------------------
# These imports are deliberately placed here rather than at the top of the
# file: they run after the working directory has been added to sys.path.
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder

# Build both encoders with freshly initialised weights; the trained weights
# are copied in from the checkpoint afterwards.
modelpath = 'distilbert-base-uncased'
textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, max_len=300)

ckpt_path = 'epoch=99.ckpt'
print(f'Loading evaluator checkpoint from {ckpt_path}')
# Deserialize onto the CPU. Without map_location, torch.load restores every
# tensor onto the device it was saved from, which fails when that device index
# does not exist on this machine and otherwise keeps a second copy of the
# weights in GPU memory. The encoders are moved to the compute device after
# their weights have been loaded, so the end result is unchanged.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True; if this checkpoint stores non-tensor objects it may need
# weights_only=False (only for a trusted file) - confirm against the
# installed torch version.
ckpt = torch.load(ckpt_path, map_location='cpu')
def _extract_submodule_state(state_dict, submodule):
    """Return the entries of *state_dict* that belong to *submodule*.

    A key belongs to the submodule when it starts with ``"<submodule>."``.
    Only that leading prefix is removed from the returned keys; any later
    occurrence of the same text inside a key is left untouched (a plain
    ``str.replace`` would strip those too and corrupt the parameter name).

    Args:
        state_dict: Mapping of fully qualified parameter names to values.
        submodule: Name of the top-level submodule, without a trailing dot.

    Returns:
        A new dict keyed by the parameter names relative to the submodule.
    """
    prefix = submodule + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# load textencoder
textencoder_ckpt = _extract_submodule_state(ckpt['state_dict'], "textencoder")
# strict=True: any missing or unexpected parameter name raises instead of
# silently evaluating with partially initialised weights.
textencoder.load_state_dict(textencoder_ckpt, strict=True)
textencoder.eval()
textencoder.to(comp_device)
# load motionencoder
motionencoder_ckpt = _extract_submodule_state(ckpt['state_dict'], "motionencoder")
motionencoder.load_state_dict(motionencoder_ckpt, strict=True)
motionencoder.eval()
motionencoder.to(comp_device)
#--------------------------------
# evaluation_gt receives the validation loader, the [text, motion] encoder
# pair and the compute device; its six return values are unpacked in the
# order shown below.
evaluator = [textencoder, motionencoder]
gt_metrics = eval_trans.evaluation_gt(val_loader, evaluator, device=comp_device)
gt_fid, gt_div, gt_top1, gt_top2, gt_top3, gt_matching = gt_metrics

# Assemble the report first and emit it with a single print; the text written
# to stdout is line-for-line the same as printing each entry separately.
report_lines = [
    'final result:',
    f'gt_fid: {gt_fid}',
    f'gt_div: {gt_div}',
    f'gt_top1: {gt_top1}',
    f'gt_top2: {gt_top2}',
    f'gt_top3: {gt_top3}',
    f'gt_MM-dist (matching score): {gt_matching}',
]
print('\n'.join(report_lines))