# File size: 4,618 Bytes
# 0e267a7
import os
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import json
import sys
from models.llama_model import LLaMAHF, LLaMAHFConfig
import options.option_transformer as option_trans
import utils.utils_model as utils_model
import utils.eval_trans as eval_trans
from humanml3d_272 import dataset_eval_t2m
import models.tae as tae
import warnings
# Silence third-party deprecation chatter and HF tokenizer fork warnings.
warnings.filterwarnings('ignore')
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Run from the evaluator directory so its relative checkpoint/data paths
# resolve, and make its modules (e.g. the `mld` package below) importable.
os.chdir('Evaluator_272')
sys.path.insert(0, os.getcwd())
# Prefer CUDA but fall back to CPU so the script does not crash on CPU-only hosts.
comp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
##### ---- Exp dirs ---- #####
args = option_trans.get_args_parser()
torch.manual_seed(args.seed)  # fix RNG for reproducible evaluation
# args.exp_name is already a string; the redundant f-string wrapper is dropped.
args.out_dir = os.path.join(args.out_dir, args.exp_name)
os.makedirs(args.out_dir, exist_ok=True)

##### ---- Logger ---- #####
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
# Record the full option set so a run is reproducible from its log alone.
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

# Evaluation data loader, batch size 32.
# NOTE(review): the bare `True` positional presumably selects the eval/test
# split — confirm against DATALoader's signature.
val_loader = dataset_eval_t2m.DATALoader(args.dataname, True, 32)
##### ---- Network ---- #####
from sentence_transformers import SentenceTransformer

# Frozen sentence-T5 encoder; used only to embed the input text prompts.
t5_model = SentenceTransformer('../sentencet5-xxl/')
t5_model.eval()
for param in t5_model.parameters():
    param.requires_grad_(False)
tokenize_model = t5_model
# Causal TAE: temporal autoencoder between raw motion features and latents.
clip_range = [-30, 20]  # latents are clipped to this range inside the TAE
net = tae.Causal_HumanTAE(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)

# Autoregressive LLaMA-style transformer with a diffusion head over latents.
config = LLaMAHFConfig.from_name('Normal_size')
config.block_size = 78
trans_encoder = LLaMAHF(config, args.num_diffusion_head_layers, args.latent_dim, comp_device)

print(f'loading checkpoint from {args.resume_pth}')
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.to(comp_device)
if args.resume_trans is not None:
    print(f'loading transformer checkpoint from {args.resume_trans}')
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip that first segment so the bare model can load them.
    new_ckpt_trans = {}
    for key, value in ckpt['trans'].items():
        parts = key.split('.')
        stripped = '.'.join(parts[1:]) if parts[0] == 'module' else key
        new_ckpt_trans[stripped] = value
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)
trans_encoder.eval()
trans_encoder.to(comp_device)
# load evaluator:
# (duplicate `import torch` removed — torch is already imported at the top of the file)
from transformers import AutoTokenizer, AutoModel
from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder
from collections import OrderedDict

# Retrieval evaluator (temos architecture from the mld package): a DistilBERT
# text encoder and an ACTOR-style motion encoder sharing a 256-d latent space.
modelpath = 'distilbert-base-uncased'
textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4, latent_dim=256)
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, latent_dim=256, max_len=300)

ckpt_path = 'epoch=99.ckpt'
print(f'Loading evaluator checkpoint from {ckpt_path}')
# map_location='cpu' keeps this consistent with the other checkpoint loads above
# and avoids requiring the exact GPU device the checkpoint was saved from
# (the encoders are moved to comp_device after their weights are loaded).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location='cpu')
# load textencoder: keep only the 'textencoder.*' entries, prefix stripped.
textencoder_ckpt = {
    k.replace("textencoder.", ""): v
    for k, v in ckpt['state_dict'].items()
    if k.split(".")[0] == "textencoder"
}
textencoder.load_state_dict(textencoder_ckpt, strict=True)
textencoder.eval()
textencoder.to(comp_device)
# load motionencoder: keep only the 'motionencoder.*' entries, prefix stripped.
motionencoder_ckpt = {
    k.replace("motionencoder.", ""): v
    for k, v in ckpt['state_dict'].items()
    if k.split(".")[0] == "motionencoder"
}
motionencoder.load_state_dict(motionencoder_ckpt, strict=True)
motionencoder.eval()
motionencoder.to(comp_device)
#--------------------------------
evaluator = [textencoder, motionencoder]

# Metric accumulators; lists so the logging format matches multi-run setups.
fid, div, top1, top2, top3, matching, mpjpe = ([] for _ in range(7))

# Single evaluation pass over the validation loader.
# NOTE(review): 4.0 is presumably the guidance scale — confirm against
# evaluation_transformer_272_single.
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, logger = eval_trans.evaluation_transformer_272_single(
    val_loader, net, trans_encoder, tokenize_model, logger, evaluator, 4.0)

fid.append(best_fid)
div.append(best_div)
top1.append(best_top1)
top2.append(best_top2)
top3.append(best_top3)
matching.append(best_matching)

logger.info('final result:')
logger.info(f'fid: {fid}')
logger.info(f'div: {div}')
logger.info(f'top1: {top1}')
logger.info(f'top2: {top2}')
logger.info(f'top3: {top3}')
logger.info(f'MM-dist (matching score) : {matching}')