Fred808 committed
Commit d281d55 · verified · 1 Parent(s): 5eb4117

Upload 5 files
model/BERT/BERT_encoder.py ADDED
@@ -0,0 +1,32 @@
+ import torch.nn as nn
+ import os
+
+ def load_bert(model_path):
+     bert = BERT(model_path)
+     bert.eval()
+     bert.text_model.training = False
+     for p in bert.parameters():
+         p.requires_grad = False
+     return bert
+
+ class BERT(nn.Module):
+     def __init__(self, modelpath: str):
+         super().__init__()
+
+         from transformers import AutoTokenizer, AutoModel
+         from transformers import logging
+         logging.set_verbosity_error()
+         # Silence tokenizer parallelism warnings
+         os.environ["TOKENIZERS_PARALLELISM"] = "false"
+         # Tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
+         # Text model
+         self.text_model = AutoModel.from_pretrained(modelpath)
+
+
+     def forward(self, texts):
+         encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
+         output = self.text_model(**encoded_inputs.to(self.text_model.device)).last_hidden_state
+         mask = encoded_inputs.attention_mask.to(dtype=bool)
+         # output = output * mask.unsqueeze(-1)
+         return output, mask
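For reference (not part of the uploaded file), a minimal usage sketch of this encoder; it assumes the distilbert/distilbert-base-uncased checkpoint can be loaded through transformers:

import torch
from model.BERT.BERT_encoder import load_bert

bert = load_bert('distilbert/distilbert-base-uncased')  # frozen, eval-mode DistilBERT
with torch.no_grad():
    emb, mask = bert(['a person walks forward', 'a person jumps'])
# emb: [batch, n_tokens, 768] per-token embeddings; mask: True for real tokens, False for padding
print(emb.shape, mask.shape)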
model/cfg_sampler.py ADDED
@@ -0,0 +1,33 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from copy import deepcopy
+
+ # A wrapper model for Classifier-free guidance **SAMPLING** only
+ # https://arxiv.org/abs/2207.12598
+ class ClassifierFreeSampleModel(nn.Module):
+
+     def __init__(self, model):
+         super().__init__()
+         self.model = model  # model is the actual model to run
+
+         assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'
+
+         # pointers to inner model
+         self.rot2xyz = self.model.rot2xyz
+         self.translation = self.model.translation
+         self.njoints = self.model.njoints
+         self.nfeats = self.model.nfeats
+         self.data_rep = self.model.data_rep
+         self.cond_mode = self.model.cond_mode
+         self.encode_text = self.model.encode_text
+
+     def forward(self, x, timesteps, y=None):
+         cond_mode = self.model.cond_mode
+         assert cond_mode in ['text', 'action']
+         y_uncond = deepcopy(y)
+         y_uncond['uncond'] = True
+         out = self.model(x, timesteps, y)
+         out_uncond = self.model(x, timesteps, y_uncond)
+         return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))
+
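In words: the wrapper runs the inner model twice per denoising step, once conditioned and once unconditioned, and blends the two outputs as out_uncond + scale * (out_cond - out_uncond). A hedged usage sketch (not part of the uploaded file), assuming `model` is an MDM instance trained with cond_mask_prob > 0 and that x_t, t, y come from the diffusion sampling loop:

import torch
from model.cfg_sampler import ClassifierFreeSampleModel

guided = ClassifierFreeSampleModel(model)
y['scale'] = torch.full((x_t.shape[0],), 2.5, device=x_t.device)  # per-sample guidance scale
out = guided(x_t, t, y)  # = out_uncond + scale * (out_cond - out_uncond)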
model/mdm.py ADDED
@@ -0,0 +1,480 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import clip
+ from model.rotation2xyz import Rotation2xyz
+ from model.BERT.BERT_encoder import load_bert
+ from utils.misc import WeightedSum
+
+
+ class MDM(nn.Module):
+     def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
+                  latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
+                  ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
+                  arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
+         super().__init__()
+
+         self.legacy = legacy
+         self.modeltype = modeltype
+         self.njoints = njoints
+         self.nfeats = nfeats
+         self.num_actions = num_actions
+         self.data_rep = data_rep
+         self.dataset = dataset
+
+         self.pose_rep = pose_rep
+         self.glob = glob
+         self.glob_rot = glob_rot
+         self.translation = translation
+
+         self.latent_dim = latent_dim
+
+         self.ff_size = ff_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.dropout = dropout
+
+         self.ablation = ablation
+         self.activation = activation
+         self.clip_dim = clip_dim
+         self.action_emb = kargs.get('action_emb', None)
+         self.input_feats = self.njoints * self.nfeats
+
+         self.normalize_output = kargs.get('normalize_encoder_output', False)
+
+         self.cond_mode = kargs.get('cond_mode', 'no_cond')
+         self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
+         self.mask_frames = kargs.get('mask_frames', False)
+         self.arch = arch
+         self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0
+         self.input_process = InputProcess(self.data_rep, self.input_feats + self.gru_emb_dim, self.latent_dim)
+
+         self.emb_policy = kargs.get('emb_policy', 'add')
+
+         self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=kargs.get('pos_embed_max_len', 5000))
+         self.emb_trans_dec = emb_trans_dec
+
+         self.pred_len = kargs.get('pred_len', 0)
+         self.context_len = kargs.get('context_len', 0)
+         self.total_len = self.pred_len + self.context_len
+         self.is_prefix_comp = self.total_len > 0
+         self.all_goal_joint_names = kargs.get('all_goal_joint_names', [])
+
+         self.multi_target_cond = kargs.get('multi_target_cond', False)
+         self.multi_encoder_type = kargs.get('multi_encoder_type', 'multi')
+         self.target_enc_layers = kargs.get('target_enc_layers', 1)
+         if self.multi_target_cond:
+             if self.multi_encoder_type == 'multi':
+                 self.embed_target_cond = EmbedTargetLocMulti(self.all_goal_joint_names, self.latent_dim)
+             elif self.multi_encoder_type == 'single':
+                 self.embed_target_cond = EmbedTargetLocSingle(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)
+             elif self.multi_encoder_type == 'split':
+                 self.embed_target_cond = EmbedTargetLocSplit(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)
+
+         if self.arch == 'trans_enc':
+             print("TRANS_ENC init")
+             seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
+                                                               nhead=self.num_heads,
+                                                               dim_feedforward=self.ff_size,
+                                                               dropout=self.dropout,
+                                                               activation=self.activation)
+
+             self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
+                                                          num_layers=self.num_layers)
+         elif self.arch == 'trans_dec':
+             print("TRANS_DEC init")
+             seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
+                                                               nhead=self.num_heads,
+                                                               dim_feedforward=self.ff_size,
+                                                               dropout=self.dropout,
+                                                               activation=activation)
+             self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
+                                                          num_layers=self.num_layers)
+         elif self.arch == 'gru':
+             print("GRU init")
+             self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)
+         else:
+             raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]')
+
+         self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
+
+         if self.cond_mode != 'no_cond':
+             if 'text' in self.cond_mode:
+                 # We support CLIP encoder and DistilBERT
+                 print('EMBED TEXT')
+
+                 self.text_encoder_type = kargs.get('text_encoder_type', 'clip')
+
+                 if self.text_encoder_type == "clip":
+                     print('Loading CLIP...')
+                     self.clip_version = clip_version
+                     self.clip_model = self.load_and_freeze_clip(clip_version)
+                     self.encode_text = self.clip_encode_text
+                 elif self.text_encoder_type == 'bert':
+                     assert self.arch == 'trans_dec'
+                     # assert self.emb_trans_dec == False  # passing just the time embed so it's fine
+                     print("Loading BERT...")
+                     # bert_model_path = 'model/BERT/distilbert-base-uncased'
+                     bert_model_path = 'distilbert/distilbert-base-uncased'
+                     self.clip_model = load_bert(bert_model_path)  # Sorry for that, the naming is for backward compatibility
+                     self.encode_text = self.bert_encode_text
+                     self.clip_dim = 768
+                 else:
+                     raise ValueError('We only support [CLIP, BERT] text encoders')
+
+                 self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
+
+             if 'action' in self.cond_mode:
+                 self.embed_action = EmbedAction(self.num_actions, self.latent_dim)
+                 print('EMBED ACTION')
+
+         self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
+                                             self.nfeats)
+
+         self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)
+
+     def parameters_wo_clip(self):
+         return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
+
+     def load_and_freeze_clip(self, clip_version):
+         clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
+                                                 jit=False)  # Must set jit=False for training
+         clip.model.convert_weights(
+             clip_model)  # Actually this line is unnecessary since clip by default already on float16
+
+         # Freeze CLIP weights
+         clip_model.eval()
+         for p in clip_model.parameters():
+             p.requires_grad = False
+
+         return clip_model
+
+     def mask_cond(self, cond, force_mask=False):
+         bs = cond.shape[-2]
+         if force_mask:
+             return torch.zeros_like(cond)
+         elif self.training and self.cond_mask_prob > 0.:
+             mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(1, bs, 1)  # 1 -> use null_cond, 0 -> use real cond
+             return cond * (1. - mask)
+         else:
+             return cond
+
+     def clip_encode_text(self, raw_text):
+         # raw_text - list (batch_size length) of strings with input text prompts
+         device = next(self.parameters()).device
+         max_text_len = 20 if self.dataset in ['humanml', 'kit'] else None  # Specific hardcoding for humanml dataset
+         if max_text_len is not None:
+             default_context_length = 77
+             context_length = max_text_len + 2  # start_token + 20 + end_token
+             assert context_length < default_context_length
+             texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)  # [bs, context_length]  # if n_tokens > context_length -> will truncate
+             # print('texts', texts.shape)
+             zero_pad = torch.zeros([texts.shape[0], default_context_length - context_length], dtype=texts.dtype, device=texts.device)
+             texts = torch.cat([texts, zero_pad], dim=1)
+             # print('texts after pad', texts.shape, texts)
+         else:
+             texts = clip.tokenize(raw_text, truncate=True).to(device)  # [bs, context_length]  # if n_tokens > 77 -> will truncate
+         return self.clip_model.encode_text(texts).float().unsqueeze(0)
+
+     def bert_encode_text(self, raw_text):
+         # enc_text = self.clip_model(raw_text)
+         # enc_text = enc_text.permute(1, 0, 2)
+         # return enc_text
+         enc_text, mask = self.clip_model(raw_text)  # self.clip_model.get_last_hidden_state(raw_text, return_mask=True)  # mask: False means no token there
+         enc_text = enc_text.permute(1, 0, 2)
+         mask = ~mask  # after inversion, True means no token there; this matches the padding-mask convention of the transformer layers, see https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
+         return enc_text, mask
+
+     def forward(self, x, timesteps, y=None):
+         """
+         x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
+         timesteps: [batch_size] (int)
+         """
+         bs, njoints, nfeats, nframes = x.shape
+         time_emb = self.embed_timestep(timesteps)  # [1, bs, d]
+
+         if 'target_cond' in y.keys():
+             # NOTE: We don't use CFG for joints - but we do want to support uncond sampling for generation and eval!
+             time_emb += self.mask_cond(self.embed_target_cond(y['target_cond'], y['target_joint_names'], y['is_heading'])[None], force_mask=y.get('target_uncond', False))  # For uncond support and CFG
+             # time_emb += self.embed_target_cond(y['target_cond'], y['target_joint_names'], y['is_heading'])[None]
+
+         # Build input for prefix completion
+         if self.is_prefix_comp:
+             x = torch.cat([y['prefix'], x], dim=-1)
+             y['mask'] = torch.cat([torch.ones([bs, 1, 1, self.context_len], dtype=y['mask'].dtype, device=y['mask'].device),
+                                    y['mask']], dim=-1)
+
+         force_mask = y.get('uncond', False)
+         if 'text' in self.cond_mode:
+             if 'text_embed' in y.keys():  # caching option
+                 enc_text = y['text_embed']
+             else:
+                 enc_text = self.encode_text(y['text'])
+             if type(enc_text) == tuple:
+                 enc_text, text_mask = enc_text
+                 if text_mask.shape[0] == 1 and bs > 1:  # casting mask for the single-prompt-for-all case
+                     text_mask = torch.repeat_interleave(text_mask, bs, dim=0)
+             text_emb = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
+             if self.emb_policy == 'add':
+                 emb = text_emb + time_emb
+             else:
+                 emb = torch.cat([time_emb, text_emb], dim=0)
+                 text_mask = torch.cat([torch.zeros_like(text_mask[:, 0:1]), text_mask], dim=1)
+         if 'action' in self.cond_mode:
+             action_emb = self.embed_action(y['action'])
+             emb = time_emb + self.mask_cond(action_emb, force_mask=force_mask)
+         if self.cond_mode == 'no_cond':
+             # unconstrained
+             emb = time_emb
+
+         if self.arch == 'gru':
+             x_reshaped = x.reshape(bs, njoints * nfeats, 1, nframes)
+             emb_gru = emb.repeat(nframes, 1, 1)  # [#frames, bs, d]
+             emb_gru = emb_gru.permute(1, 2, 0)  # [bs, d, #frames]
+             emb_gru = emb_gru.reshape(bs, self.latent_dim, 1, nframes)  # [bs, d, 1, #frames]
+             x = torch.cat((x_reshaped, emb_gru), axis=1)  # [bs, d+joints*feat, 1, #frames]
+
+         x = self.input_process(x)
+
+         # TODO - move to collate
+         frames_mask = None
+         is_valid_mask = y['mask'].shape[-1] > 1  # Don't use mask with the generate script
+         if self.mask_frames and is_valid_mask:
+             frames_mask = torch.logical_not(y['mask'][..., :x.shape[0]].squeeze(1).squeeze(1)).to(device=x.device)
+             if self.emb_trans_dec or self.arch == 'trans_enc':
+                 step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=x.device)
+                 frames_mask = torch.cat([step_mask, frames_mask], dim=1)
+
+         if self.arch == 'trans_enc':
+             # adding the timestep embed
+             xseq = torch.cat((emb, x), axis=0)  # [seqlen+1, bs, d]
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
+             output = self.seqTransEncoder(xseq, src_key_padding_mask=frames_mask)[1:]  # [seqlen, bs, d]
+
+         elif self.arch == 'trans_dec':
+             if self.emb_trans_dec:
+                 xseq = torch.cat((time_emb, x), axis=0)
+             else:
+                 xseq = x
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
+
+             if self.text_encoder_type == 'clip':
+                 output = self.seqTransDecoder(tgt=xseq, memory=emb, tgt_key_padding_mask=frames_mask)
+             elif self.text_encoder_type == 'bert':
+                 output = self.seqTransDecoder(tgt=xseq, memory=emb, memory_key_padding_mask=text_mask, tgt_key_padding_mask=frames_mask)  # Rotem's bug fix
+             else:
+                 raise ValueError()
+
+             if self.emb_trans_dec:
+                 output = output[1:]  # [seqlen, bs, d]
+
+         elif self.arch == 'gru':
+             xseq = x
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen, bs, d]
+             output, _ = self.gru(xseq)
+
+         # Extract completed suffix
+         if self.is_prefix_comp:
+             output = output[self.context_len:]
+             y['mask'] = y['mask'][..., self.context_len:]
+
+         output = self.output_process(output)  # [bs, njoints, nfeats, nframes]
+         return output
+
+
+     def _apply(self, fn):
+         super()._apply(fn)
+         self.rot2xyz.smpl_model._apply(fn)
+
+
+     def train(self, *args, **kwargs):
+         super().train(*args, **kwargs)
+         self.rot2xyz.smpl_model.train(*args, **kwargs)
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         # not used in the final model
+         x = x + self.pe[:x.shape[0], :]
+         return self.dropout(x)
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, latent_dim, sequence_pos_encoder):
+         super().__init__()
+         self.latent_dim = latent_dim
+         self.sequence_pos_encoder = sequence_pos_encoder
+
+         time_embed_dim = self.latent_dim
+         self.time_embed = nn.Sequential(
+             nn.Linear(self.latent_dim, time_embed_dim),
+             nn.SiLU(),
+             nn.Linear(time_embed_dim, time_embed_dim),
+         )
+
+     def forward(self, timesteps):
+         return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
+
+
+ class InputProcess(nn.Module):
+     def __init__(self, data_rep, input_feats, latent_dim):
+         super().__init__()
+         self.data_rep = data_rep
+         self.input_feats = input_feats
+         self.latent_dim = latent_dim
+         self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
+         if self.data_rep == 'rot_vel':
+             self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)
+
+     def forward(self, x):
+         bs, njoints, nfeats, nframes = x.shape
+         x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats)
+
+         if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
+             x = self.poseEmbedding(x)  # [seqlen, bs, d]
+             return x
+         elif self.data_rep == 'rot_vel':
+             first_pose = x[[0]]  # [1, bs, 150]
+             first_pose = self.poseEmbedding(first_pose)  # [1, bs, d]
+             vel = x[1:]  # [seqlen-1, bs, 150]
+             vel = self.velEmbedding(vel)  # [seqlen-1, bs, d]
+             return torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, d]
+         else:
+             raise ValueError
+
+
+ class OutputProcess(nn.Module):
+     def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
+         super().__init__()
+         self.data_rep = data_rep
+         self.input_feats = input_feats
+         self.latent_dim = latent_dim
+         self.njoints = njoints
+         self.nfeats = nfeats
+         self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
+         if self.data_rep == 'rot_vel':
+             self.velFinal = nn.Linear(self.latent_dim, self.input_feats)
+
+     def forward(self, output):
+         nframes, bs, d = output.shape
+         if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
+             output = self.poseFinal(output)  # [seqlen, bs, 150]
+         elif self.data_rep == 'rot_vel':
+             first_pose = output[[0]]  # [1, bs, d]
+             first_pose = self.poseFinal(first_pose)  # [1, bs, 150]
+             vel = output[1:]  # [seqlen-1, bs, d]
+             vel = self.velFinal(vel)  # [seqlen-1, bs, 150]
+             output = torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, 150]
+         else:
+             raise ValueError
+         output = output.reshape(nframes, bs, self.njoints, self.nfeats)
+         output = output.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
+         return output
+
+
+ class EmbedAction(nn.Module):
+     def __init__(self, num_actions, latent_dim):
+         super().__init__()
+         self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim))
+
+     def forward(self, input):
+         idx = input[:, 0].to(torch.long)  # an index array must be long
+         output = self.action_embedding[idx]
+         return output
+
+
+ class EmbedTargetLocSingle(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
+         super().__init__()
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.target_cond_dim = len(self.extended_goal_joint_names) * 4  # 4 => (x, y, z, is_valid)
+         self.latent_dim = latent_dim
+         _layers = [nn.Linear(self.target_cond_dim, self.latent_dim)]
+         for _ in range(num_layers):
+             _layers += [nn.SiLU(), nn.Linear(self.latent_dim, self.latent_dim)]
+         self.mlp = nn.Sequential(*_layers)
+
+     def forward(self, input, target_joint_names, target_heading):
+         # TODO - generate validity from outside the model
+         validity = torch.zeros_like(input)[..., :1]
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             for j in sample_joint_names_w_heading:
+                 validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.
+
+         mlp_input = torch.cat([input, validity], dim=-1).view(input.shape[0], -1)
+         return self.mlp(mlp_input)
+
+
+ class EmbedTargetLocSplit(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
+         super().__init__()
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.target_cond_dim = 4
+         self.latent_dim = latent_dim
+         self.splited_dim = self.latent_dim // len(self.extended_goal_joint_names)
+         assert self.latent_dim % len(self.extended_goal_joint_names) == 0
+         self.mini_mlps = nn.ModuleList()
+         for _ in self.extended_goal_joint_names:
+             _layers = [nn.Linear(self.target_cond_dim, self.splited_dim)]
+             for _ in range(num_layers):
+                 _layers += [nn.SiLU(), nn.Linear(self.splited_dim, self.splited_dim)]
+             self.mini_mlps.append(nn.Sequential(*_layers))
+
+     def forward(self, input, target_joint_names, target_heading):
+         # TODO - generate validity from outside the model
+         validity = torch.zeros_like(input)[..., :1]
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             for j in sample_joint_names_w_heading:
+                 validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.
+
+         mlp_input = torch.cat([input, validity], dim=-1)
+         mlp_splits = [self.mini_mlps[i](mlp_input[:, i]) for i in range(mlp_input.shape[1])]
+         return torch.cat(mlp_splits, dim=-1)
+
+
+ class EmbedTargetLocMulti(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim):
+         super().__init__()
+
+         # todo: use a tensor of weights per joint, and another one for biases, then apply a selection in one go like we do for actions
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.extended_goal_joint_idx = {joint_name: idx for idx, joint_name in enumerate(self.extended_goal_joint_names)}
+         self.n_extended_goal_joints = len(self.extended_goal_joint_names)
+         self.target_loc_emb = nn.ParameterDict({joint_name:
+                                                     nn.Sequential(
+                                                         nn.Linear(3, latent_dim),
+                                                         nn.SiLU(),
+                                                         nn.Linear(latent_dim, latent_dim))
+                                                 for joint_name in self.extended_goal_joint_names})  # todo: check if 3 works for heading and traj
+         # nn.Linear(3, latent_dim) for joint_name in self.extended_goal_joint_names})  # todo: check if 3 works for heading and traj
+         self.target_all_loc_emb = WeightedSum(self.n_extended_goal_joints)  # nn.Linear(self.n_extended_goal_joints, latent_dim)
+         self.latent_dim = latent_dim
+
+     def forward(self, input, target_joint_names, target_heading):
+         output = torch.zeros((input.shape[0], self.latent_dim), dtype=input.dtype, device=input.device)
+
+         # Iterate over the batch and apply the appropriate filter for each joint
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             output_one_sample = torch.zeros((self.n_extended_goal_joints, self.latent_dim), dtype=input.dtype, device=input.device)
+             for joint_name in sample_joint_names_w_heading:
+                 layer = self.target_loc_emb[joint_name]
+                 output_one_sample[self.extended_goal_joint_idx[joint_name]] = layer(input[sample_idx, self.extended_goal_joint_idx[joint_name]])
+             output[sample_idx] = self.target_all_loc_emb(output_one_sample)
+             # print(torch.where(output_one_sample.sum(axis=1) != 0)[0].cpu().numpy())
+
+         return output
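A hedged construction sketch (not part of the uploaded file). The argument values below follow the HumanML3D-style configuration commonly used with MDM (263-dimensional 'hml_vec' features, CLIP ViT-B/32 text encoder) and are illustrative assumptions only; the real settings come from the training arguments, and instantiation also builds Rotation2xyz, which needs the SMPL body-model files on disk:

from model.mdm import MDM

mdm = MDM(modeltype='', njoints=263, nfeats=1, num_actions=1,
          translation=True, pose_rep='rot6d', glob=True, glob_rot=True,
          latent_dim=512, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
          data_rep='hml_vec', dataset='humanml', clip_dim=512,
          arch='trans_enc', clip_version='ViT-B/32',
          cond_mode='text', cond_mask_prob=0.1)  # cond_* options are consumed via **kargs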
model/rotation2xyz.py ADDED
@@ -0,0 +1,92 @@
+ # This code is based on https://github.com/Mathux/ACTOR.git
+ import torch
+ import utils.rotation_conversions as geometry
+
+
+ from model.smpl import SMPL, JOINTSTYPE_ROOT
+ # from .get_model import JOINTSTYPES
+ JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]
+
+
+ class Rotation2xyz:
+     def __init__(self, device, dataset='amass'):
+         self.device = device
+         self.dataset = dataset
+         self.smpl_model = SMPL().eval().to(device)
+
+     def __call__(self, x, mask, pose_rep, translation, glob,
+                  jointstype, vertstrans, betas=None, beta=0,
+                  glob_rot=None, get_rotations_back=False, **kwargs):
+         if pose_rep == "xyz":
+             return x
+
+         if mask is None:
+             mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)
+
+         if not glob and glob_rot is None:
+             raise TypeError("You must specify global rotation if glob is False")
+
+         if jointstype not in JOINTSTYPES:
+             raise NotImplementedError("This jointstype is not implemented.")
+
+         if translation:
+             x_translations = x[:, -1, :3]
+             x_rotations = x[:, :-1]
+         else:
+             x_rotations = x
+
+         x_rotations = x_rotations.permute(0, 3, 1, 2)
+         nsamples, time, njoints, feats = x_rotations.shape
+
+         # Compute rotations (convert only masked sequences output)
+         if pose_rep == "rotvec":
+             rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
+         elif pose_rep == "rotmat":
+             rotations = x_rotations[mask].view(-1, njoints, 3, 3)
+         elif pose_rep == "rotquat":
+             rotations = geometry.quaternion_to_matrix(x_rotations[mask])
+         elif pose_rep == "rot6d":
+             rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
+         else:
+             raise NotImplementedError("No geometry for this one.")
+
+         if not glob:
+             global_orient = torch.tensor(glob_rot, device=x.device)
+             global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
+             global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
+         else:
+             global_orient = rotations[:, 0]
+             rotations = rotations[:, 1:]
+
+         if betas is None:
+             betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
+                                 dtype=rotations.dtype, device=rotations.device)
+             betas[:, 1] = beta
+         # import ipdb; ipdb.set_trace()
+         out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)
+
+         # get the desirable joints
+         joints = out[jointstype]
+
+         x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
+         x_xyz[~mask] = 0
+         x_xyz[mask] = joints
+
+         x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()
+
+         # the first translation root at the origin on the prediction
+         if jointstype != "vertices":
+             rootindex = JOINTSTYPE_ROOT[jointstype]
+             x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]
+
+         if translation and vertstrans:
+             # the first translation root at the origin
+             x_translations = x_translations - x_translations[:, :, [0]]
+
+             # add the translation to all the joints
+             x_xyz = x_xyz + x_translations[:, None, :, :]
+
+         if get_rotations_back:
+             return x_xyz, rotations, global_orient
+         else:
+             return x_xyz
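A hedged usage sketch (not part of the uploaded file); it assumes the SMPL body-model files referenced by utils.config.SMPL_MODEL_PATH are available. The input layout follows the convention in the code above: per-joint rotations plus one extra row whose first three channels carry the root translation:

import torch
from model.rotation2xyz import Rotation2xyz

rot2xyz = Rotation2xyz(device='cpu', dataset='amass')
x = torch.randn(2, 25, 6, 60)  # [batch, 24 rot6d joints + 1 translation row, 6, frames]
xyz = rot2xyz(x, mask=None, pose_rep='rot6d', translation=True, glob=True,
              jointstype='smpl', vertstrans=False)
print(xyz.shape)  # expected: [2, 24, 3, 60]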
model/smpl.py ADDED
@@ -0,0 +1,97 @@
+ # This code is based on https://github.com/Mathux/ACTOR.git
+ import numpy as np
+ import torch
+
+ import contextlib
+
+ from smplx import SMPLLayer as _SMPLLayer
+ from smplx.lbs import vertices2joints
+
+
+ # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+ # change 0 and 8
+ action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+
+ from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA
+
+ JOINTSTYPE_ROOT = {"a2m": 0,  # action2motion
+                    "smpl": 0,
+                    "a2mpl": 0,  # set(smpl, a2m)
+                    "vibe": 8}  # 0 is the 8 position: OP MidHip below
+
+ JOINT_MAP = {
+     'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
+     'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
+     'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
+     'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
+     'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
+     'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
+     'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
+     'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
+     'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
+     'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
+     'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
+     'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
+     'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
+     'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+     'Spine (H36M)': 51, 'Jaw (H36M)': 52,
+     'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+     'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
+ }
+
+ JOINT_NAMES = [
+     'OP Nose', 'OP Neck', 'OP RShoulder',
+     'OP RElbow', 'OP RWrist', 'OP LShoulder',
+     'OP LElbow', 'OP LWrist', 'OP MidHip',
+     'OP RHip', 'OP RKnee', 'OP RAnkle',
+     'OP LHip', 'OP LKnee', 'OP LAnkle',
+     'OP REye', 'OP LEye', 'OP REar',
+     'OP LEar', 'OP LBigToe', 'OP LSmallToe',
+     'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
+     'Right Ankle', 'Right Knee', 'Right Hip',
+     'Left Hip', 'Left Knee', 'Left Ankle',
+     'Right Wrist', 'Right Elbow', 'Right Shoulder',
+     'Left Shoulder', 'Left Elbow', 'Left Wrist',
+     'Neck (LSP)', 'Top of Head (LSP)',
+     'Pelvis (MPII)', 'Thorax (MPII)',
+     'Spine (H36M)', 'Jaw (H36M)',
+     'Head (H36M)', 'Nose', 'Left Eye',
+     'Right Eye', 'Left Ear', 'Right Ear'
+ ]
+
+
+ # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints
+ class SMPL(_SMPLLayer):
+     """ Extension of the official SMPL implementation to support more joints """
+
+     def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
+         kwargs["model_path"] = model_path
+
+         # remove the verbosity for the 10-shapes beta parameters
+         with contextlib.redirect_stdout(None):
+             super(SMPL, self).__init__(**kwargs)
+
+         J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
+         self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
+         vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
+         a2m_indexes = vibe_indexes[action2motion_joints]
+         smpl_indexes = np.arange(24)
+         a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])
+
+         self.maps = {"vibe": vibe_indexes,
+                      "a2m": a2m_indexes,
+                      "smpl": smpl_indexes,
+                      "a2mpl": a2mpl_indexes}
+
+     def forward(self, *args, **kwargs):
+         smpl_output = super(SMPL, self).forward(*args, **kwargs)
+
+         extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
+         all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
+
+         output = {"vertices": smpl_output.vertices}
+
+         for joinstype, indexes in self.maps.items():
+             output[joinstype] = all_joints[:, indexes]
+
+         return output
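A hedged usage sketch (not part of the uploaded file); it assumes the SMPL model files at utils.config.SMPL_MODEL_PATH and the extra joint regressor at JOINT_REGRESSOR_TRAIN_EXTRA are available, and that smplx accepts rotation-matrix inputs as below:

import torch
from model.smpl import SMPL

smpl = SMPL().eval()
bs = 2
identity = torch.eye(3)
out = smpl(body_pose=identity[None, None].repeat(bs, 23, 1, 1),      # rest pose, rotation matrices
           global_orient=identity[None, None].repeat(bs, 1, 1, 1),
           betas=torch.zeros(bs, smpl.num_betas))
# vertices plus the joint sets selected by self.maps: "smpl" (24), "vibe" (49), "a2m" (18), "a2mpl"
print(out["vertices"].shape, out["smpl"].shape, out["vibe"].shape, out["a2m"].shape)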