Upload 5 files
- model/BERT/BERT_encoder.py +32 -0
- model/cfg_sampler.py +33 -0
- model/mdm.py +480 -0
- model/rotation2xyz.py +92 -0
- model/smpl.py +97 -0
model/BERT/BERT_encoder.py
ADDED
@@ -0,0 +1,32 @@
import os

import torch.nn as nn


def load_bert(model_path):
    bert = BERT(model_path)
    bert.eval()
    bert.text_model.training = False
    for p in bert.parameters():
        p.requires_grad = False
    return bert


class BERT(nn.Module):
    def __init__(self, modelpath: str):
        super().__init__()

        from transformers import AutoTokenizer, AutoModel
        from transformers import logging
        logging.set_verbosity_error()
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)

    def forward(self, texts):
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device)).last_hidden_state
        mask = encoded_inputs.attention_mask.to(dtype=bool)
        # output = output * mask.unsqueeze(-1)
        return output, mask
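A minimal usage sketch of this frozen encoder (hypothetical, not part of the upload; the model path is the one hardcoded in model/mdm.py):

# Hypothetical sketch, not part of the upload.
from model.BERT.BERT_encoder import load_bert

bert = load_bert('distilbert/distilbert-base-uncased')  # frozen, eval mode
emb, mask = bert(["a person walks forward", "a person jumps"])
# emb:  [batch, seq_len, 768] last hidden states (DistilBERT width)
# mask: [batch, seq_len] bool, True where a real token exists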
model/cfg_sampler.py
ADDED
@@ -0,0 +1,33 @@
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy


# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run

        assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'

        # pointers to inner model
        self.rot2xyz = self.model.rot2xyz
        self.translation = self.model.translation
        self.njoints = self.model.njoints
        self.nfeats = self.model.nfeats
        self.data_rep = self.model.data_rep
        self.cond_mode = self.model.cond_mode
        self.encode_text = self.model.encode_text

    def forward(self, x, timesteps, y=None):
        cond_mode = self.model.cond_mode
        assert cond_mode in ['text', 'action']
        y_uncond = deepcopy(y)
        y_uncond['uncond'] = True
        out = self.model(x, timesteps, y)
        out_uncond = self.model(x, timesteps, y_uncond)
        return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))
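Classifier-free guidance blends the conditional and unconditional predictions: out = out_uncond + s * (out_cond - out_uncond), so s = 1 recovers the conditional model and s > 1 strengthens the condition. A minimal sampling sketch (hypothetical names; `mdm`, `x_t`, `t`, and `y` are assumed to come from the usual sampling loop):

# Hypothetical sketch, not part of the upload.
import torch

guided = ClassifierFreeSampleModel(mdm)        # mdm must be trained with cond_mask_prob > 0
y['scale'] = torch.full((x_t.shape[0],), 2.5)  # guidance scale s, one value per sample
x0_pred = guided(x_t, t, y)                    # out_uncond + s * (out_cond - out_uncond)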
model/mdm.py
ADDED
@@ -0,0 +1,480 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from model.rotation2xyz import Rotation2xyz
from model.BERT.BERT_encoder import load_bert
from utils.misc import WeightedSum


class MDM(nn.Module):
    def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
                 latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
                 ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
                 arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
        super().__init__()

        self.legacy = legacy
        self.modeltype = modeltype
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_actions = num_actions
        self.data_rep = data_rep
        self.dataset = dataset

        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation

        self.latent_dim = latent_dim

        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout

        self.ablation = ablation
        self.activation = activation
        self.clip_dim = clip_dim
        self.action_emb = kargs.get('action_emb', None)
        self.input_feats = self.njoints * self.nfeats

        self.normalize_output = kargs.get('normalize_encoder_output', False)

        self.cond_mode = kargs.get('cond_mode', 'no_cond')
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
        self.mask_frames = kargs.get('mask_frames', False)
        self.arch = arch
        self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0
        self.input_process = InputProcess(self.data_rep, self.input_feats + self.gru_emb_dim, self.latent_dim)

        self.emb_policy = kargs.get('emb_policy', 'add')

        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=kargs.get('pos_embed_max_len', 5000))
        self.emb_trans_dec = emb_trans_dec

        self.pred_len = kargs.get('pred_len', 0)
        self.context_len = kargs.get('context_len', 0)
        self.total_len = self.pred_len + self.context_len
        self.is_prefix_comp = self.total_len > 0
        self.all_goal_joint_names = kargs.get('all_goal_joint_names', [])

        self.multi_target_cond = kargs.get('multi_target_cond', False)
        self.multi_encoder_type = kargs.get('multi_encoder_type', 'multi')
        self.target_enc_layers = kargs.get('target_enc_layers', 1)
        if self.multi_target_cond:
            if self.multi_encoder_type == 'multi':
                self.embed_target_cond = EmbedTargetLocMulti(self.all_goal_joint_names, self.latent_dim)
            elif self.multi_encoder_type == 'single':
                self.embed_target_cond = EmbedTargetLocSingle(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)
            elif self.multi_encoder_type == 'split':
                self.embed_target_cond = EmbedTargetLocSplit(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)

        if self.arch == 'trans_enc':
            print("TRANS_ENC init")
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=self.activation)

            self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=self.num_layers)
        elif self.arch == 'trans_dec':
            print("TRANS_DEC init")
            seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                              nhead=self.num_heads,
                                                              dim_feedforward=self.ff_size,
                                                              dropout=self.dropout,
                                                              activation=activation)
            self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
                                                         num_layers=self.num_layers)
        elif self.arch == 'gru':
            print("GRU init")
            self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)
        else:
            raise ValueError('Please choose a correct architecture [trans_enc, trans_dec, gru]')

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        if self.cond_mode != 'no_cond':
            if 'text' in self.cond_mode:
                # We support CLIP and DistilBERT text encoders
                print('EMBED TEXT')

                self.text_encoder_type = kargs.get('text_encoder_type', 'clip')

                if self.text_encoder_type == "clip":
                    print('Loading CLIP...')
                    self.clip_version = clip_version
                    self.clip_model = self.load_and_freeze_clip(clip_version)
                    self.encode_text = self.clip_encode_text
                elif self.text_encoder_type == 'bert':
                    assert self.arch == 'trans_dec'
                    # assert self.emb_trans_dec == False  # passing just the time embed, so it's fine
                    print("Loading BERT...")
                    # bert_model_path = 'model/BERT/distilbert-base-uncased'
                    bert_model_path = 'distilbert/distilbert-base-uncased'
                    self.clip_model = load_bert(bert_model_path)  # the attribute name is kept for backward compatibility
                    self.encode_text = self.bert_encode_text
                    self.clip_dim = 768
                else:
                    raise ValueError('We only support [clip, bert] text encoders')

                self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)

            if 'action' in self.cond_mode:
                self.embed_action = EmbedAction(self.num_actions, self.latent_dim)
                print('EMBED ACTION')

        self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
                                            self.nfeats)

        self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
                                                jit=False)  # Must set jit=False for training
        clip.model.convert_weights(
            clip_model)  # Actually this line is unnecessary since CLIP is already in float16 by default

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False

        return clip_model

    def mask_cond(self, cond, force_mask=False):
        bs = cond.shape[-2]
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(1, bs, 1)  # 1 -> use null_cond, 0 -> use real cond
            return cond * (1. - mask)
        else:
            return cond

    def clip_encode_text(self, raw_text):
        # raw_text - list (batch_size length) of strings with input text prompts
        device = next(self.parameters()).device
        max_text_len = 20 if self.dataset in ['humanml', 'kit'] else None  # Specific hardcoding for the humanml dataset
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2  # start_token + 20 + end_token
            assert context_length < default_context_length
            texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)  # [bs, context_length]; truncates if n_tokens > context_length
            zero_pad = torch.zeros([texts.shape[0], default_context_length - context_length], dtype=texts.dtype, device=texts.device)
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            texts = clip.tokenize(raw_text, truncate=True).to(device)  # [bs, context_length]; truncates if n_tokens > 77
        return self.clip_model.encode_text(texts).float().unsqueeze(0)

    def bert_encode_text(self, raw_text):
        enc_text, mask = self.clip_model(raw_text)  # mask: False means no token at that position
        enc_text = enc_text.permute(1, 0, 2)
        # Invert so that True means "no token there", matching the transformer key_padding_mask convention:
        # https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
        mask = ~mask
        return enc_text, mask

    def forward(self, x, timesteps, y=None):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        """
        bs, njoints, nfeats, nframes = x.shape
        time_emb = self.embed_timestep(timesteps)  # [1, bs, d]

        if 'target_cond' in y.keys():
            # NOTE: We don't use CFG for joint targets - but we do want to support uncond sampling for generation and eval!
            time_emb += self.mask_cond(self.embed_target_cond(y['target_cond'], y['target_joint_names'], y['is_heading'])[None],
                                       force_mask=y.get('target_uncond', False))  # for uncond support and CFG

        # Build input for prefix completion
        if self.is_prefix_comp:
            x = torch.cat([y['prefix'], x], dim=-1)
            y['mask'] = torch.cat([torch.ones([bs, 1, 1, self.context_len], dtype=y['mask'].dtype, device=y['mask'].device),
                                   y['mask']], dim=-1)

        force_mask = y.get('uncond', False)
        if 'text' in self.cond_mode:
            if 'text_embed' in y.keys():  # caching option
                enc_text = y['text_embed']
            else:
                enc_text = self.encode_text(y['text'])
            if type(enc_text) == tuple:
                enc_text, text_mask = enc_text
                if text_mask.shape[0] == 1 and bs > 1:  # casting the mask for the single-prompt-for-all case
                    text_mask = torch.repeat_interleave(text_mask, bs, dim=0)
            text_emb = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
            if self.emb_policy == 'add':
                emb = text_emb + time_emb
            else:
                emb = torch.cat([time_emb, text_emb], dim=0)
                text_mask = torch.cat([torch.zeros_like(text_mask[:, 0:1]), text_mask], dim=1)
        if 'action' in self.cond_mode:
            action_emb = self.embed_action(y['action'])
            emb = time_emb + self.mask_cond(action_emb, force_mask=force_mask)
        if self.cond_mode == 'no_cond':
            # unconstrained
            emb = time_emb

        if self.arch == 'gru':
            x_reshaped = x.reshape(bs, njoints * nfeats, 1, nframes)
            emb_gru = emb.repeat(nframes, 1, 1)  # [#frames, bs, d]
            emb_gru = emb_gru.permute(1, 2, 0)  # [bs, d, #frames]
            emb_gru = emb_gru.reshape(bs, self.latent_dim, 1, nframes)  # [bs, d, 1, #frames]
            x = torch.cat((x_reshaped, emb_gru), axis=1)  # [bs, d+joints*feat, 1, #frames]

        x = self.input_process(x)

        # TODO - move to collate
        frames_mask = None
        is_valid_mask = y['mask'].shape[-1] > 1  # Don't use the mask with the generate script
        if self.mask_frames and is_valid_mask:
            frames_mask = torch.logical_not(y['mask'][..., :x.shape[0]].squeeze(1).squeeze(1)).to(device=x.device)
            if self.emb_trans_dec or self.arch == 'trans_enc':
                step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=x.device)
                frames_mask = torch.cat([step_mask, frames_mask], dim=1)

        if self.arch == 'trans_enc':
            # adding the timestep embed
            xseq = torch.cat((emb, x), axis=0)  # [seqlen+1, bs, d]
            xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
            output = self.seqTransEncoder(xseq, src_key_padding_mask=frames_mask)[1:]  # [seqlen, bs, d]

        elif self.arch == 'trans_dec':
            if self.emb_trans_dec:
                xseq = torch.cat((time_emb, x), axis=0)
            else:
                xseq = x
            xseq = self.sequence_pos_encoder(xseq)  # [seqlen(+1), bs, d]

            if self.text_encoder_type == 'clip':
                output = self.seqTransDecoder(tgt=xseq, memory=emb, tgt_key_padding_mask=frames_mask)
            elif self.text_encoder_type == 'bert':
                output = self.seqTransDecoder(tgt=xseq, memory=emb, memory_key_padding_mask=text_mask, tgt_key_padding_mask=frames_mask)  # Rotem's bug fix
            else:
                raise ValueError()

            if self.emb_trans_dec:
                output = output[1:]  # [seqlen, bs, d]

        elif self.arch == 'gru':
            xseq = x
            xseq = self.sequence_pos_encoder(xseq)  # [seqlen, bs, d]
            output, _ = self.gru(xseq)

        # Extract the completed suffix
        if self.is_prefix_comp:
            output = output[self.context_len:]
            y['mask'] = y['mask'][..., self.context_len:]

        output = self.output_process(output)  # [bs, njoints, nfeats, nframes]
        return output

    def _apply(self, fn):
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)

    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)


class InputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.data_rep == 'rot_vel':
            self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats)

        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            x = self.poseEmbedding(x)  # [seqlen, bs, d]
            return x
        elif self.data_rep == 'rot_vel':
            first_pose = x[[0]]  # [1, bs, 150]
            first_pose = self.poseEmbedding(first_pose)  # [1, bs, d]
            vel = x[1:]  # [seqlen-1, bs, 150]
            vel = self.velEmbedding(vel)  # [seqlen-1, bs, d]
            return torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, d]
        else:
            raise ValueError


class OutputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.njoints = njoints
        self.nfeats = nfeats
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
        if self.data_rep == 'rot_vel':
            self.velFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        nframes, bs, d = output.shape
        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            output = self.poseFinal(output)  # [seqlen, bs, 150]
        elif self.data_rep == 'rot_vel':
            first_pose = output[[0]]  # [1, bs, d]
            first_pose = self.poseFinal(first_pose)  # [1, bs, 150]
            vel = output[1:]  # [seqlen-1, bs, d]
            vel = self.velFinal(vel)  # [seqlen-1, bs, 150]
            output = torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, 150]
        else:
            raise ValueError
        output = output.reshape(nframes, bs, self.njoints, self.nfeats)
        output = output.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
        return output


class EmbedAction(nn.Module):
    def __init__(self, num_actions, latent_dim):
        super().__init__()
        self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim))

    def forward(self, input):
        idx = input[:, 0].to(torch.long)  # an index array must be long
        output = self.action_embedding[idx]
        return output


class EmbedTargetLocSingle(nn.Module):
    def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
        super().__init__()
        self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
        self.target_cond_dim = len(self.extended_goal_joint_names) * 4  # 4 => (x, y, z, is_valid)
        self.latent_dim = latent_dim
        _layers = [nn.Linear(self.target_cond_dim, self.latent_dim)]
        for _ in range(num_layers):
            _layers += [nn.SiLU(), nn.Linear(self.latent_dim, self.latent_dim)]
        self.mlp = nn.Sequential(*_layers)

    def forward(self, input, target_joint_names, target_heading):
        # TODO - generate validity from outside the model
        validity = torch.zeros_like(input)[..., :1]
        for sample_idx, sample_joint_names in enumerate(target_joint_names):
            sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
            for j in sample_joint_names_w_heading:
                validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.

        mlp_input = torch.cat([input, validity], dim=-1).view(input.shape[0], -1)
        return self.mlp(mlp_input)


class EmbedTargetLocSplit(nn.Module):
    def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
        super().__init__()
        self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
        self.target_cond_dim = 4
        self.latent_dim = latent_dim
        self.splited_dim = self.latent_dim // len(self.extended_goal_joint_names)
        assert self.latent_dim % len(self.extended_goal_joint_names) == 0
        self.mini_mlps = nn.ModuleList()
        for _ in self.extended_goal_joint_names:
            _layers = [nn.Linear(self.target_cond_dim, self.splited_dim)]
            for _ in range(num_layers):
                _layers += [nn.SiLU(), nn.Linear(self.splited_dim, self.splited_dim)]
            self.mini_mlps.append(nn.Sequential(*_layers))

    def forward(self, input, target_joint_names, target_heading):
        # TODO - generate validity from outside the model
        validity = torch.zeros_like(input)[..., :1]
        for sample_idx, sample_joint_names in enumerate(target_joint_names):
            sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
            for j in sample_joint_names_w_heading:
                validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.

        mlp_input = torch.cat([input, validity], dim=-1)
        mlp_splits = [self.mini_mlps[i](mlp_input[:, i]) for i in range(mlp_input.shape[1])]
        return torch.cat(mlp_splits, dim=-1)


class EmbedTargetLocMulti(nn.Module):
    def __init__(self, all_goal_joint_names, latent_dim):
        super().__init__()

        # todo: use a tensor of weights per joint, and another for biases, then apply the selection in one go, as we do for actions
        self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
        self.extended_goal_joint_idx = {joint_name: idx for idx, joint_name in enumerate(self.extended_goal_joint_names)}
        self.n_extended_goal_joints = len(self.extended_goal_joint_names)
        self.target_loc_emb = nn.ParameterDict({joint_name:
                                                    nn.Sequential(
                                                        nn.Linear(3, latent_dim),
                                                        nn.SiLU(),
                                                        nn.Linear(latent_dim, latent_dim))
                                                for joint_name in self.extended_goal_joint_names})  # todo: check that an input dim of 3 works for heading and traj
        self.target_all_loc_emb = WeightedSum(self.n_extended_goal_joints)
        self.latent_dim = latent_dim

    def forward(self, input, target_joint_names, target_heading):
        output = torch.zeros((input.shape[0], self.latent_dim), dtype=input.dtype, device=input.device)

        # Iterate over the batch and apply the appropriate filter for each joint
        for sample_idx, sample_joint_names in enumerate(target_joint_names):
            sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
            output_one_sample = torch.zeros((self.n_extended_goal_joints, self.latent_dim), dtype=input.dtype, device=input.device)
            for joint_name in sample_joint_names_w_heading:
                layer = self.target_loc_emb[joint_name]
                output_one_sample[self.extended_goal_joint_idx[joint_name]] = layer(input[sample_idx, self.extended_goal_joint_idx[joint_name]])
            output[sample_idx] = self.target_all_loc_emb(output_one_sample)

        return output
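For orientation, a minimal denoising-call sketch (hypothetical and not part of the upload; the constructor values below are illustrative and in practice come from the repo's training configs):

# Hypothetical sketch, not part of the upload; values are illustrative.
import torch
from model.mdm import MDM

model = MDM(modeltype='', njoints=263, nfeats=1, num_actions=1, translation=True,
            pose_rep='rot6d', glob=True, glob_rot=True, data_rep='hml_vec',
            dataset='humanml', cond_mode='text', cond_mask_prob=0.1,
            arch='trans_enc', clip_version='ViT-B/32')  # downloads CLIP weights
model.eval()

x_t = torch.randn(2, 263, 1, 60)  # [bs, njoints, nfeats, nframes], a noisy motion batch
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
y = {'text': ['a person walks forward', 'a person jumps'],
     'mask': torch.ones(2, 1, 1, 60, dtype=torch.bool)}
with torch.no_grad():
    out = model(x_t, t, y)        # [2, 263, 1, 60], the model's clean-motion estimate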
model/rotation2xyz.py
ADDED
@@ -0,0 +1,92 @@
# This code is based on https://github.com/Mathux/ACTOR.git
import torch
import utils.rotation_conversions as geometry

from model.smpl import SMPL, JOINTSTYPE_ROOT

JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]


class Rotation2xyz:
    def __init__(self, device, dataset='amass'):
        self.device = device
        self.dataset = dataset
        self.smpl_model = SMPL().eval().to(device)

    def __call__(self, x, mask, pose_rep, translation, glob,
                 jointstype, vertstrans, betas=None, beta=0,
                 glob_rot=None, get_rotations_back=False, **kwargs):
        if pose_rep == "xyz":
            return x

        if mask is None:
            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)

        if not glob and glob_rot is None:
            raise TypeError("You must specify global rotation if glob is False")

        if jointstype not in JOINTSTYPES:
            raise NotImplementedError("This jointstype is not implemented.")

        if translation:
            x_translations = x[:, -1, :3]
            x_rotations = x[:, :-1]
        else:
            x_rotations = x

        x_rotations = x_rotations.permute(0, 3, 1, 2)
        nsamples, time, njoints, feats = x_rotations.shape

        # Compute rotations (convert only the masked frames)
        if pose_rep == "rotvec":
            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
        elif pose_rep == "rotmat":
            rotations = x_rotations[mask].view(-1, njoints, 3, 3)
        elif pose_rep == "rotquat":
            rotations = geometry.quaternion_to_matrix(x_rotations[mask])
        elif pose_rep == "rot6d":
            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
        else:
            raise NotImplementedError("No geometry for this one.")

        if not glob:
            global_orient = torch.tensor(glob_rot, device=x.device)
            global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
        else:
            global_orient = rotations[:, 0]
            rotations = rotations[:, 1:]

        if betas is None:
            betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
                                dtype=rotations.dtype, device=rotations.device)
            betas[:, 1] = beta
        out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)

        # get the desired joints
        joints = out[jointstype]

        x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
        x_xyz[~mask] = 0
        x_xyz[mask] = joints

        x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()

        # center the prediction: put the root at the origin
        if jointstype != "vertices":
            rootindex = JOINTSTYPE_ROOT[jointstype]
            x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]

        if translation and vertstrans:
            # put the first-frame root translation at the origin
            x_translations = x_translations - x_translations[:, :, [0]]

            # add the translation to all the joints
            x_xyz = x_xyz + x_translations[:, None, :, :]

        if get_rotations_back:
            return x_xyz, rotations, global_orient
        else:
            return x_xyz
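A minimal conversion sketch (hypothetical, not part of the upload; shapes follow the rot6d SMPL representation used by MDM, 24 joint rotations plus one translation row, and SMPL weights must exist at SMPL_MODEL_PATH):

# Hypothetical sketch, not part of the upload.
import torch

rot2xyz = Rotation2xyz(device='cpu', dataset='amass')
sample = torch.randn(2, 25, 6, 60)  # [bs, 24 rot6d joints + 1 translation row, 6, nframes]
xyz = rot2xyz(sample, mask=None, pose_rep='rot6d', translation=True,
              glob=True, jointstype='smpl', vertstrans=True)
# xyz: [bs, 24, 3, nframes] joint positions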
model/smpl.py
ADDED
@@ -0,0 +1,97 @@
# This code is based on https://github.com/Mathux/ACTOR.git
import numpy as np
import torch

import contextlib

from smplx import SMPLLayer as _SMPLLayer
from smplx.lbs import vertices2joints


# action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
# same as above with entries 0 and 8 swapped
action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]

from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA

JOINTSTYPE_ROOT = {"a2m": 0,  # action2motion
                   "smpl": 0,
                   "a2mpl": 0,  # union of smpl and a2m
                   "vibe": 8}  # the root ('OP MidHip') is at index 8 in JOINT_NAMES below

JOINT_MAP = {
    'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
    'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
    'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
    'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
    'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
    'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
    'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
    'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
    'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
    'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
    'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
    'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
    'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
    'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
    'Spine (H36M)': 51, 'Jaw (H36M)': 52,
    'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
    'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
}

JOINT_NAMES = [
    'OP Nose', 'OP Neck', 'OP RShoulder',
    'OP RElbow', 'OP RWrist', 'OP LShoulder',
    'OP LElbow', 'OP LWrist', 'OP MidHip',
    'OP RHip', 'OP RKnee', 'OP RAnkle',
    'OP LHip', 'OP LKnee', 'OP LAnkle',
    'OP REye', 'OP LEye', 'OP REar',
    'OP LEar', 'OP LBigToe', 'OP LSmallToe',
    'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
    'Right Ankle', 'Right Knee', 'Right Hip',
    'Left Hip', 'Left Knee', 'Left Ankle',
    'Right Wrist', 'Right Elbow', 'Right Shoulder',
    'Left Shoulder', 'Left Elbow', 'Left Wrist',
    'Neck (LSP)', 'Top of Head (LSP)',
    'Pelvis (MPII)', 'Thorax (MPII)',
    'Spine (H36M)', 'Jaw (H36M)',
    'Head (H36M)', 'Nose', 'Left Eye',
    'Right Eye', 'Left Ear', 'Right Ear'
]


# adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints
class SMPL(_SMPLLayer):
    """ Extension of the official SMPL implementation to support more joints """

    def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
        kwargs["model_path"] = model_path

        # silence the verbose printout about the 10-shape beta parameters
        with contextlib.redirect_stdout(None):
            super(SMPL, self).__init__(**kwargs)

        J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
        vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
        a2m_indexes = vibe_indexes[action2motion_joints]
        smpl_indexes = np.arange(24)
        a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])

        self.maps = {"vibe": vibe_indexes,
                     "a2m": a2m_indexes,
                     "smpl": smpl_indexes,
                     "a2mpl": a2mpl_indexes}

    def forward(self, *args, **kwargs):
        smpl_output = super(SMPL, self).forward(*args, **kwargs)

        extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
        all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1)

        output = {"vertices": smpl_output.vertices}

        for jointstype, indexes in self.maps.items():
            output[jointstype] = all_joints[:, indexes]

        return output
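A sketch of querying the extended joint sets (hypothetical, not part of the upload; inputs are rotation matrices as expected by smplx's SMPLLayer, and weights must exist at SMPL_MODEL_PATH):

# Hypothetical sketch, not part of the upload.
import torch

smpl = SMPL()                                    # loads weights from SMPL_MODEL_PATH
body_pose = torch.eye(3).expand(1, 23, 3, 3)     # rest pose, 23 body joints
global_orient = torch.eye(3).expand(1, 1, 3, 3)
out = smpl(body_pose=body_pose, global_orient=global_orient)
joints24 = out["smpl"]      # [1, 24, 3]  standard SMPL joints
joints_a2m = out["a2m"]     # [1, 18, 3]  action2motion subset
verts = out["vertices"]     # [1, 6890, 3] mesh vertices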