| | import random |
| | import math |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import smplx |
| | import copy |
| | from .motion_encoder import * |
| |
|
| | |
class VAEConvZero(nn.Module):
    """Autoencoder pairing VQEncoderV5 with VQDecoderV5, with no bottleneck.

    There is no quantizer and no variational head: the encoder output is fed
    directly to the decoder.

    NOTE(review): this module contains a second, identical ``VAEConvZero``
    class further down; that later definition is the one bound to the name
    once the module has finished importing. One of the two can be removed.
    """

    def __init__(self, args):
        super().__init__()
        # Submodules are registered encoder-first, matching the original order.
        self.encoder = VQEncoderV5(args)
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs):
        """Encode ``inputs`` and decode the resulting latent.

        Returns:
            dict with the single key ``"rec_pose"`` holding the decoder output.
        """
        latent = self.encoder(inputs)
        reconstruction = self.decoder(latent)
        return {"rec_pose": reconstruction}
| | |
class VAEConv(nn.Module):
    """Autoencoder over VQEncoderV3 / VQDecoderV3 with an optional variational head.

    When ``args.variational`` is truthy, the encoder output is projected by two
    linear layers to ``mu`` and ``logvar``, and the latent is drawn with
    ``reparameterize(mu, logvar)``. Otherwise the encoder output is used as-is.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV3(args)
        self.decoder = VQDecoderV3(args)
        width = args.vae_length
        # Both heads are always created (and therefore appear in the state
        # dict), even when the model runs non-variationally.
        self.fc_mu = nn.Linear(width, width)
        self.fc_logvar = nn.Linear(width, width)
        self.variational = args.variational

    def _encode(self, inputs):
        """Run the encoder and, if enabled, the variational head.

        Returns:
            ``(latent, mu, logvar)``; ``mu`` and ``logvar`` are ``None`` when
            the model is not variational.
        """
        latent = self.encoder(inputs)
        if not self.variational:
            return latent, None, None
        mu = self.fc_mu(latent)
        logvar = self.fc_logvar(latent)
        return reparameterize(mu, logvar), mu, logvar

    def forward(self, inputs):
        """Encode then decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (latent fed to the decoder), ``rec_pose``
            (decoder output), and ``pose_mu`` / ``pose_logvar`` (``None``
            unless variational).
        """
        latent, mu, logvar = self._encode(inputs)
        return {
            "poses_feat": latent,
            "rec_pose": self.decoder(latent),
            "pose_mu": mu,
            "pose_logvar": logvar,
        }

    def map2latent(self, inputs):
        """Return the latent the decoder would receive for ``inputs``."""
        latent, _, _ = self._encode(inputs)
        return latent

    def decode(self, pre_latent):
        """Decode a latent produced by ``map2latent`` (or equivalent)."""
        return self.decoder(pre_latent)
| |
|
class VAESKConv(VAEConv):
    """VAEConv variant whose encoder is a ``LocalEncoder`` over the SMPL-X skeleton.

    The ``kintree_table`` array is read from the SMPL-X neutral model file under
    ``args.data_path_1``. Its first row is turned into an edge topology with
    ``build_edge_topology`` and passed to the encoder.
    """

    def __init__(self, args):
        super().__init__(args)
        smpl_fname = args.data_path_1 + 'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the underlying file open; the with-block closes it deterministically.
        with np.load(smpl_fname, encoding='latin1') as smpl_data:
            # astype() always returns a new array, so `parents` remains valid
            # after the archive is closed.
            parents = smpl_data['kintree_table'][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
        # Replaces the decoder created by VAEConv.__init__ with a fresh instance.
        self.decoder = VQDecoderV3(args)
| | |
class VAEConvMLP(VAEConv):
    """VAEConv variant that swaps in PoseEncoderConv / PoseDecoderConv."""

    def __init__(self, args):
        super().__init__(args)
        seq_len = args.vae_test_len
        pose_dim = args.vae_test_dim
        latent_dim = args.vae_length
        self.encoder = PoseEncoderConv(seq_len, pose_dim, feature_length=latent_dim)
        self.decoder = PoseDecoderConv(seq_len, pose_dim, feature_length=latent_dim)
| | |
class VAELSTM(VAEConv):
    """VAEConv variant using PoseEncoderLSTM_Resnet / PoseDecoderLSTM."""

    def __init__(self, args):
        super().__init__(args)
        dim_pose, dim_latent = args.vae_test_dim, args.vae_length
        self.encoder = PoseEncoderLSTM_Resnet(dim_pose, feature_length=dim_latent)
        self.decoder = PoseDecoderLSTM(dim_pose, feature_length=dim_latent)
| |
|
class VAETransformer(VAEConv):
    """VAEConv variant using Encoder_TRANSFORMER / Decoder_TRANSFORMER."""

    def __init__(self, args):
        super().__init__(args)
        # The encoder/decoder built by VAEConv.__init__ are replaced here.
        self.encoder = Encoder_TRANSFORMER(args)
        self.decoder = Decoder_TRANSFORMER(args)
| |
|
| | |
class VQVAEConv(nn.Module):
    """Vector-quantized autoencoder: VQEncoderV3 -> Quantizer -> VQDecoderV3."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV3(args)
        self.quantizer = Quantizer(
            args.vae_codebook_size,
            args.vae_length,
            args.vae_quantizer_lambda,
        )
        self.decoder = VQDecoderV3(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (quantized latent), ``embedding_loss`` and
            ``perplexity`` (from the quantizer), and ``rec_pose``.
        """
        z_e = self.encoder(inputs)
        emb_loss, z_q, _, perp = self.quantizer(z_e)
        return {
            "poses_feat": z_q,
            "embedding_loss": emb_loss,
            "perplexity": perp,
            "rec_pose": self.decoder(z_q),
        }

    def map2index(self, inputs):
        """Return the codebook indices the quantizer selects for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        return self.quantizer.get_codebook_entry(self.map2index(inputs))

    def decode(self, index):
        """Decode poses from codebook ``index`` values."""
        return self.decoder(self.quantizer.get_codebook_entry(index))
| |
|
class VQVAESKConv(VQVAEConv):
    """VQVAEConv variant whose encoder is a ``LocalEncoder`` over the SMPL-X skeleton.

    The ``kintree_table`` array is read from the SMPL-X neutral model file under
    ``args.data_path_1``. Its first row is turned into an edge topology with
    ``build_edge_topology``. The quantizer and decoder built by
    ``VQVAEConv.__init__`` are kept.
    """

    def __init__(self, args):
        super().__init__(args)
        smpl_fname = args.data_path_1 + 'smplx_models/smplx/SMPLX_NEUTRAL_2020.npz'
        # np.load on an .npz archive returns a lazily-read NpzFile that keeps
        # the underlying file open; the with-block closes it deterministically.
        with np.load(smpl_fname, encoding='latin1') as smpl_data:
            # astype() always returns a new array, so `parents` remains valid
            # after the archive is closed.
            parents = smpl_data['kintree_table'][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
| |
|
| |
|
class VQVAEConvStride(nn.Module):
    """Vector-quantized autoencoder: VQEncoderV4 -> Quantizer -> VQDecoderV4."""

    def __init__(self, args):
        super().__init__()
        codebook_size = args.vae_codebook_size
        code_dim = args.vae_length
        commit_weight = args.vae_quantizer_lambda
        self.encoder = VQEncoderV4(args)
        self.quantizer = Quantizer(codebook_size, code_dim, commit_weight)
        self.decoder = VQDecoderV4(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (quantized latent), ``embedding_loss`` and
            ``perplexity`` (from the quantizer), and ``rec_pose``.
        """
        encoded = self.encoder(inputs)
        loss, quantized, _, perplexity = self.quantizer(encoded)
        reconstruction = self.decoder(quantized)
        return dict(
            poses_feat=quantized,
            embedding_loss=loss,
            perplexity=perplexity,
            rec_pose=reconstruction,
        )

    def map2index(self, inputs):
        """Return the codebook indices the quantizer selects for ``inputs``."""
        encoded = self.encoder(inputs)
        return self.quantizer.map2index(encoded)

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        indices = self.map2index(inputs)
        return self.quantizer.get_codebook_entry(indices)

    def decode(self, index):
        """Decode poses from codebook ``index`` values."""
        codes = self.quantizer.get_codebook_entry(index)
        return self.decoder(codes)
| |
|
class VQVAEConvZero(nn.Module):
    """Vector-quantized autoencoder: VQEncoderV5 -> Quantizer -> VQDecoderV5."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(
            args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda
        )
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (quantized latent), ``embedding_loss`` and
            ``perplexity`` (from the quantizer), and ``rec_pose``.
        """
        embedding_loss, vq_feat, _, perplexity = self.quantizer(self.encoder(inputs))
        out = {"poses_feat": vq_feat, "embedding_loss": embedding_loss}
        out["perplexity"] = perplexity
        out["rec_pose"] = self.decoder(vq_feat)
        return out

    def map2index(self, inputs):
        """Return the codebook indices the quantizer selects for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        return self.quantizer.get_codebook_entry(self.map2index(inputs))

    def decode(self, index):
        """Decode poses from codebook ``index`` values."""
        return self.decoder(self.quantizer.get_codebook_entry(index))
| | |
| |
|
class VAEConvZero(nn.Module):
    """Autoencoder pairing VQEncoderV5 with VQDecoderV5, with no bottleneck.

    There is no quantizer and no variational head: the encoder output is fed
    directly to the decoder.

    NOTE(review): an identical ``VAEConvZero`` class is defined earlier in this
    module; this later definition rebinds the name at import time, so the
    earlier one is dead code. One of the two can be removed.
    """

    def __init__(self, args):
        super().__init__()
        # Submodules are registered encoder-first, matching the original order.
        self.encoder = VQEncoderV5(args)
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs):
        """Encode ``inputs`` and decode the resulting latent.

        Returns:
            dict with the single key ``"rec_pose"`` holding the decoder output.
        """
        return {"rec_pose": self.decoder(self.encoder(inputs))}
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class VQVAEConvZero3(nn.Module):
    """Vector-quantized autoencoder: VQEncoderV5 -> Quantizer -> VQDecoderV5."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(
            args.vae_codebook_size,
            args.vae_length,
            args.vae_quantizer_lambda,
        )
        self.decoder = VQDecoderV5(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (quantized latent), ``embedding_loss`` and
            ``perplexity`` (from the quantizer), and ``rec_pose``.
        """
        pre_quant = self.encoder(inputs)
        q_loss, q_feat, _, q_perplexity = self.quantizer(pre_quant)
        return {
            "poses_feat": q_feat,
            "embedding_loss": q_loss,
            "perplexity": q_perplexity,
            "rec_pose": self.decoder(q_feat),
        }

    def map2index(self, inputs):
        """Return the codebook indices the quantizer selects for ``inputs``."""
        pre_quant = self.encoder(inputs)
        return self.quantizer.map2index(pre_quant)

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        return self.quantizer.get_codebook_entry(self.map2index(inputs))

    def decode(self, index):
        """Decode poses from codebook ``index`` values."""
        q_feat = self.quantizer.get_codebook_entry(index)
        return self.decoder(q_feat)
| |
|
class VQVAEConvZero2(nn.Module):
    """Vector-quantized autoencoder: VQEncoderV5 -> Quantizer -> VQDecoderV7."""

    def __init__(self, args):
        super().__init__()
        self.encoder = VQEncoderV5(args)
        self.quantizer = Quantizer(
            args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda
        )
        self.decoder = VQDecoderV7(args)

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``.

        Returns:
            dict with ``poses_feat`` (quantized latent), ``embedding_loss`` and
            ``perplexity`` (from the quantizer), and ``rec_pose``.
        """
        latent = self.encoder(inputs)
        vq_loss, vq_feat, _, vq_perplexity = self.quantizer(latent)
        decoded = self.decoder(vq_feat)
        return {
            "poses_feat": vq_feat,
            "embedding_loss": vq_loss,
            "perplexity": vq_perplexity,
            "rec_pose": decoded,
        }

    def map2index(self, inputs):
        """Return the codebook indices the quantizer selects for ``inputs``."""
        return self.quantizer.map2index(self.encoder(inputs))

    def map2latent(self, inputs):
        """Return the codebook vectors selected for ``inputs``."""
        codes = self.map2index(inputs)
        return self.quantizer.get_codebook_entry(codes)

    def decode(self, index):
        """Decode poses from codebook ``index`` values."""
        return self.decoder(self.quantizer.get_codebook_entry(index))
| |
|
class VQVAE2(nn.Module):
    """Two-level (top / bottom) vector-quantized autoencoder.

    Data flow (see ``forward``):

    * The bottom encoder (VQEncoderV6) encodes the input. The top encoder
      (VQEncoderV3) encodes that bottom feature again.
    * The top feature is quantized and decoded back by ``top_decoder``.
    * The decoded top feature is concatenated with the bottom feature,
      projected by a 1x1 ``Conv1d`` and quantized by ``bottom_quantizer``.
    * The quantized top latent is upsampled by ``upsample_t``. This is three
      stages of nearest x2 upsampling + Conv1d + LeakyReLU, i.e. x8 along the
      time axis.
    * The upsampled top latent is concatenated with the quantized bottom latent
      and decoded by ``bottom_decoder`` (VQDecoderV6).

    Feature tensors are handled as (batch, time, channels) and permuted to
    (batch, channels, time) around each ``Conv1d``.
    """

    def __init__(self, args):
        super().__init__()

        # Bottom level.
        args_bottom = copy.deepcopy(args)
        args_bottom.vae_layer = 2
        self.bottom_encoder = VQEncoderV6(args_bottom)
        self.bottom_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)
        # Re-assert the input dim on the copy before building the decoder.
        # NOTE(review): presumably guards against VQEncoderV6 mutating
        # args_bottom; kept as-is — TODO confirm whether it is needed.
        args_bottom.vae_test_dim = args.vae_test_dim
        self.bottom_decoder = VQDecoderV6(args_bottom)

        # Top level: operates on bottom features, so its "input dim" is vae_length.
        args_top = copy.deepcopy(args)
        args_top.vae_layer = 3
        args_top.vae_test_dim = args.vae_length
        self.top_encoder = VQEncoderV3(args_top)
        # 1x1 conv fusing [decoded top ; bottom] (2 * vae_length channels) back to vae_length.
        self.quantize_conv_t = nn.Conv1d(args.vae_length + args.vae_length, args.vae_length, 1)
        self.top_quantizer = Quantizer(args.vae_codebook_size, args.vae_length, args.vae_quantizer_lambda)

        layers = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv1d(args.vae_length, args.vae_length, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.upsample_t = nn.Sequential(*layers)
        self.top_decoder = VQDecoderV3(args_top)

    def _fuse_bottom(self, dec_t, enc_b):
        """Concatenate decoded-top and bottom features and project with the 1x1 conv."""
        fused = torch.cat([dec_t, enc_b], dim=2).permute(0, 2, 1)
        return self.quantize_conv_t(fused).permute(0, 2, 1)

    def _upsample_top(self, quant_t):
        """Apply ``upsample_t`` along time, keeping the (batch, time, channels) layout."""
        return self.upsample_t(quant_t.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, inputs):
        """Run the full two-level encode / quantize / decode pipeline.

        Returns:
            dict with ``poses_feat_top`` (quantized top latent),
            ``pose_feat_bottom`` (quantized bottom latent), ``embedding_loss``
            (sum of both quantizers' losses) and ``rec_pose``.
        """
        enc_b = self.bottom_encoder(inputs)
        enc_t = self.top_encoder(enc_b)

        top_embedding_loss, quant_t, _, _ = self.top_quantizer(enc_t)
        dec_t = self.top_decoder(quant_t)

        pre_quant_b = self._fuse_bottom(dec_t, enc_b)
        bottom_embedding_loss, quant_b, _, _ = self.bottom_quantizer(pre_quant_b)

        quant = torch.cat([self._upsample_top(quant_t), quant_b], 2)
        rec_pose = self.bottom_decoder(quant)

        return {
            "poses_feat_top": quant_t,
            "pose_feat_bottom": quant_b,
            "embedding_loss": top_embedding_loss + bottom_embedding_loss,
            "rec_pose": rec_pose,
        }

    def map2index(self, inputs):
        """Return ``(top_index, bottom_index)`` codebook indices for ``inputs``."""
        enc_b = self.bottom_encoder(inputs)
        enc_t = self.top_encoder(enc_b)

        # The quantized top latent (not just its indices) feeds the top decoder.
        _, quant_t, _, _ = self.top_quantizer(enc_t)
        top_index = self.top_quantizer.map2index(enc_t)
        dec_t = self.top_decoder(quant_t)

        bottom_index = self.bottom_quantizer.map2index(self._fuse_bottom(dec_t, enc_b))
        return top_index, bottom_index

    def get_top_latent(self, top_index):
        """Look up top-level codebook vectors for ``top_index``."""
        return self.top_quantizer.get_codebook_entry(top_index)

    # Backward-compatible alias for the original, misspelled method name.
    get_top_laent = get_top_latent

    def map2latent(self, inputs):
        """Return ``(z_q_top, z_q_bottom)`` codebook vectors selected for ``inputs``."""
        top_index, bottom_index = self.map2index(inputs)
        z_q_top = self.top_quantizer.get_codebook_entry(top_index)
        z_q_bottom = self.bottom_quantizer.get_codebook_entry(bottom_index)
        return z_q_top, z_q_bottom

    def map2latent_top(self, inputs):
        """Return only the top-level codebook vectors selected for ``inputs``."""
        enc_t = self.top_encoder(self.bottom_encoder(inputs))
        top_index = self.top_quantizer.map2index(enc_t)
        return self.get_top_latent(top_index)

    def decode(self, top_index, bottom_index):
        """Decode poses from top- and bottom-level codebook indices."""
        quant_t = self.top_quantizer.get_codebook_entry(top_index)
        quant_b = self.bottom_quantizer.get_codebook_entry(bottom_index)
        quant = torch.cat([self._upsample_top(quant_t), quant_b], 2)
        return self.bottom_decoder(quant)