File size: 2,254 Bytes
0e267a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os 
import torch
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from os.path import join as pjoin
import json
import models.tae as tae
import options.option_tae as option_tae
import utils.utils_model as utils_model
from humanml3d_272 import dataset_tae_tokenizer
import warnings
from tqdm import tqdm
warnings.filterwarnings('ignore')

##### ---- Exp dirs ---- #####
# Parse CLI options and pin the RNG seed so runs are reproducible.
args = option_tae.get_args_parser()
torch.manual_seed(args.seed)

# Every experiment writes into its own sub-directory of out_dir.
args.out_dir = pjoin(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

##### ---- Logger ---- #####
# File logger and TensorBoard writer share the experiment directory;
# dumping the full config makes the run self-describing.
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

##### ---- Dataloader ---- #####
train_loader = dataset_tae_tokenizer.DATALoader(args.dataname)

# NOTE(review): presumably the range latents are clipped to inside the TAE —
# confirm against models/tae.Causal_HumanTAE.
clip_range = [-30, 20]

# Build the causal temporal autoencoder from the CLI hyper-parameters.
net = tae.Causal_HumanTAE(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)

# Restore the pretrained weights, then switch to inference mode on the GPU.
logger.info('loading checkpoint from {}'.format(args.resume_pth))
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.cuda()


##### ---- get reference end latent ---- #####
# An all-zero pose is far outside the data distribution, so its latent serves
# as a distinctive "end of sequence" marker appended to every encoded clip.
# NOTE(review): shape (1, 4, 272) presumably means (batch, frames, 272-dim
# humanml3d feature) — confirm against the dataset code.
reference_end_pose = torch.zeros(1, 4, 272).cuda()   # impossible pose prior
with torch.no_grad():  # inference only: skip autograd bookkeeping
    reference_end_latent, _, _ = net.encode(reference_end_pose)
reference_end_latent = reference_end_latent.permute(1, 0)
np.save(f'reference_end_latent_{args.dataname}.npy',
        reference_end_latent.detach().cpu().numpy())

os.makedirs(args.latent_dir, exist_ok=True)

# Encode every training motion, append the reference end latent as a
# terminator, and cache the result to <latent_dir>/<name>.npy.
# NOTE(review): saving only name[0] assumes the loader yields batches of
# size 1 — confirm in dataset_tae_tokenizer.DATALoader.
for batch in tqdm(train_loader):
    pose, name = batch
    pose = pose.cuda().float()
    with torch.no_grad():  # no gradients needed; avoids graph buildup per clip
        latent, _, _ = net.encode(pose)
    latent = latent.permute(1, 0)
    latent = torch.cat([latent, reference_end_latent], dim=0)
    np.save(pjoin(args.latent_dir, name[0] + '.npy'),
            latent.detach().cpu().numpy())