File size: 2,052 Bytes
d2a17a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os 
import torch
from utils import eval_trans
from humanml3d_272 import dataset_eval_tae
from options import option_transformer as option_trans
import warnings
import sys
# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings('ignore')
# Switch into the evaluator directory. Every relative path used after this point
# (e.g. the checkpoint path 'epoch=99.ckpt' loaded below, and presumably the
# dataset files read by DATALoader — TODO confirm) resolves against it.
# NOTE(review): assumes the script is launched from the directory that contains
# 'Evaluator_272'; otherwise this raises FileNotFoundError.
os.chdir('Evaluator_272')
# Put the evaluator directory first on the import path so that the `mld`
# package imported in the evaluator-loading section can be found. This must stay
# after the chdir above (it reads the new cwd) and before those imports.
sys.path.insert(0, os.getcwd()) 
# All evaluator modules and evaluation tensors go to the default CUDA device;
# there is no CPU fallback.
comp_device = torch.device('cuda')
# Command-line options for this run; used below via args.seed and args.dataname.
args = option_trans.get_args_parser()
# Seed torch's RNG only; other RNGs (numpy, random) are not seeded here.
torch.manual_seed(args.seed)
# Ground-truth evaluation loader. The positional True / 32 are passed straight
# through — presumably an eval/split flag and the batch size; TODO confirm
# against dataset_eval_tae.DATALoader's signature.
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)


# load evaluator:--------------------------------
# These imports rely on the evaluator directory having been pushed onto sys.path
# during setup, so they must stay here rather than at the top of the file.
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder


def _strip_state_dict_prefix(state_dict, prefix):
    """Return the entries of ``state_dict`` that belong to submodule ``prefix``.

    Only keys starting with ``prefix + "."`` are kept, and exactly that leading
    prefix is removed, e.g. ``"textencoder.foo.weight"`` -> ``"foo.weight"``.
    Later occurrences of the prefix inside a key are left untouched (a plain
    ``str.replace`` would have rewritten those too).

    Args:
        state_dict: Mapping of parameter/buffer names to tensors, as stored
            under ``ckpt['state_dict']``.
        prefix: Submodule name without the trailing dot.

    Returns:
        A new dict suitable for ``module.load_state_dict``.
    """
    head = prefix + "."
    return {key[len(head):]: value for key, value in state_dict.items() if key.startswith(head)}


def _load_evaluator_component(module, state_dict, prefix, device):
    """Load ``prefix``'s weights into ``module``, put it in eval mode and move it to ``device``.

    ``strict=True`` makes any missing or unexpected key raise instead of
    silently leaving parameters at their freshly-initialised values.

    Returns:
        ``module`` itself, for convenience.
    """
    module.load_state_dict(_strip_state_dict_prefix(state_dict, prefix), strict=True)
    module.eval()
    module.to(device)
    return module


modelpath = 'distilbert-base-uncased'

textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae = True, num_layers=4, max_len=300)

ckpt_path = 'epoch=99.ckpt'
print(f'Loading evaluator checkpoint from {ckpt_path}')
# Load onto CPU. The modules are still on CPU when their weights are copied in,
# and they are moved to comp_device afterwards, so the final state is unchanged.
# This also works when the checkpoint was saved from a GPU index that does not
# exist on this machine, and avoids holding the whole checkpoint in GPU memory.
ckpt = torch.load(ckpt_path, map_location='cpu')
# The checkpoint stores both encoders in one state dict, namespaced by prefix.
_load_evaluator_component(textencoder, ckpt['state_dict'], 'textencoder', comp_device)
_load_evaluator_component(motionencoder, ckpt['state_dict'], 'motionencoder', comp_device)
#--------------------------------

# Order matters to the consumer: [text encoder, motion encoder].
evaluator = [textencoder, motionencoder]

# Score the ground-truth motions against themselves with the loaded evaluator
# pair, yielding six metrics: FID, diversity, top-1/2/3 R-precision and the
# matching score.
gt_fid, gt_div, gt_top1, gt_top2, gt_top3, gt_matching = eval_trans.evaluation_gt(
    val_loader, evaluator, device=comp_device
)

# Emit a header, then one "<label>: <value>" line per metric.
print('final result:')
print(
    *(
        f'{label}: {value}'
        for label, value in (
            ('gt_fid', gt_fid),
            ('gt_div', gt_div),
            ('gt_top1', gt_top1),
            ('gt_top2', gt_top2),
            ('gt_top3', gt_top3),
            ('gt_MM-dist (matching score)', gt_matching),
        )
    ),
    sep='\n',
)