|
|
import os |
|
|
import torch |
|
|
from utils import eval_trans |
|
|
from humanml3d_272 import dataset_eval_tae |
|
|
from options import option_transformer as option_trans |
|
|
import warnings |
|
|
import sys |
|
|
# Suppress every warning category for the rest of the run.
warnings.filterwarnings('ignore')

# All evaluation tensors and models are placed on the CUDA device.
comp_device = torch.device('cuda')

# Switch into the evaluator directory and put it at the front of the import
# path. NOTE(review): presumably this is what lets the `mld.*` imports further
# down resolve, and why the relative checkpoint path works - confirm.
os.chdir('Evaluator_272')
sys.path.insert(0, os.getcwd())

# Parse the command-line options and seed torch's RNG from them.
args = option_trans.get_args_parser()
torch.manual_seed(args.seed)

# Build the evaluation data loader for the requested dataset.
# NOTE(review): the positional `True` and `32` are passed through unchanged;
# their meaning (likely a test-split flag and a batch size) is defined by
# `dataset_eval_tae.DATALoader` - verify there.
val_loader = dataset_eval_tae.DATALoader(
    args.dataname,
    True,
    32,
)
|
|
|
|
|
|
|
|
|
|
|
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder |
|
|
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder |
|
|
|
|
|
# Identifier handed to the text encoder as its first argument.
# NOTE(review): looks like a Hugging Face model id for DistilBERT - confirm.
modelpath = 'distilbert-base-uncased'

# Text-side encoder of the evaluator (weights are loaded from the checkpoint
# later in this script).
textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4)

# Motion-side encoder of the evaluator, configured for 272 input features.
motionencoder = ActorAgnosticEncoder(
    nfeats=272,
    vae=True,
    num_layers=4,
    max_len=300,
)
|
|
|
|
|
# Relative path: resolved against the current working directory, which this
# script changes to 'Evaluator_272' before reaching this point.
ckpt_path = 'epoch=99.ckpt'
print(f'Loading evaluator checkpoint from {ckpt_path}')

# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from. Without `map_location`, torch restores every tensor to its original
# device, which fails on hosts with a different GPU layout and wastes GPU
# memory: the weights are copied into CPU-resident modules first and only
# moved to the compute device afterwards.
# NOTE(review): newer PyTorch releases default `weights_only=True`; if this
# checkpoint stores non-tensor objects that default may need revisiting.
ckpt = torch.load(ckpt_path, map_location='cpu')
|
|
|
|
|
# Collect the text encoder's parameters from the combined checkpoint.
# Only the *leading* "textencoder." prefix is stripped; using str.replace here
# would also delete any later occurrence of that substring inside a nested
# parameter name and silently corrupt the key.
_text_prefix = "textencoder."
textencoder_ckpt = {
    key[len(_text_prefix):]: tensor
    for key, tensor in ckpt['state_dict'].items()
    if key.startswith(_text_prefix)
}

# strict=True makes missing or unexpected keys raise instead of being ignored.
textencoder.load_state_dict(textencoder_ckpt, strict=True)

# Switch to evaluation mode, then move the module to the compute device.
textencoder.eval()
textencoder.to(comp_device)
|
|
|
|
|
|
|
|
# Collect the motion encoder's parameters from the combined checkpoint.
# Only the *leading* "motionencoder." prefix is stripped; using str.replace
# here would also delete any later occurrence of that substring inside a
# nested parameter name and silently corrupt the key.
_motion_prefix = "motionencoder."
motionencoder_ckpt = {
    key[len(_motion_prefix):]: tensor
    for key, tensor in ckpt['state_dict'].items()
    if key.startswith(_motion_prefix)
}

# strict=True makes missing or unexpected keys raise instead of being ignored.
motionencoder.load_state_dict(motionencoder_ckpt, strict=True)

# Switch to evaluation mode, then move the module to the compute device.
motionencoder.eval()
motionencoder.to(comp_device)
|
|
|
|
|
|
|
|
# The evaluation routine receives the two encoders as a [text, motion] list.
evaluator = [textencoder, motionencoder]

# Run the ground-truth evaluation; the call yields six values, unpacked below
# in the order FID, diversity, top-1/2/3, matching score.
gt_metrics = eval_trans.evaluation_gt(val_loader, evaluator, device=comp_device)
gt_fid, gt_div, gt_top1, gt_top2, gt_top3, gt_matching = gt_metrics

# Report every metric on its own line, labels unchanged.
print('final result:')
for metric_label, metric_value in (
    ('gt_fid', gt_fid),
    ('gt_div', gt_div),
    ('gt_top1', gt_top1),
    ('gt_top2', gt_top2),
    ('gt_top3', gt_top3),
    ('gt_MM-dist (matching score)', gt_matching),
):
    print(f'{metric_label}: {metric_value}')
|
|
|
|
|
|
|
|
|