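"""
Text-conditioned Motion Diffusion Model (MDM) denoiser.

The model embeds a CLIP-encoded text prompt together with the diffusion
timestep, prepends that token to the projected motion sequence, runs a
transformer encoder over it, and projects the result back to motion space
([bs, num_joints, num_features, num_frames] in, same shape out).
"""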
import clip
import numpy as np
import torch
import torch.nn as nn

from model.config import MDMConfig
from model.rotation2xyz import Rotation2xyz

class MDM(nn.Module):

    @classmethod
    def from_config(cls, config: MDMConfig):
        """
        Instantiate MDM from an MDMConfig object.
        """
        return cls(
            modeltype=config.modeltype,
            njoints=config.njoints,
            nfeats=config.nfeats,
            num_actions=config.num_actions,
            translation=config.translation,
            pose_rep=config.pose_rep,
            glob=config.glob,
            glob_rot=config.glob_rot,
            latent_dim=config.latent_dim,
            ff_size=config.ff_size,
            num_layers=config.layers,
            num_heads=config.num_heads,
            dropout=config.dropout,
            ablation=config.ablation,
            activation=config.activation,
            legacy=config.legacy,
            data_rep=config.data_rep,
            dataset=config.dataset,
            clip_dim=config.clip_dim,
            arch=config.arch,
            emb_trans_dec=config.emb_trans_dec,
            clip_version=config.clip_version,
            action_emb=config.action_emb,
            normalize_encoder_output=config.normalize_encoder_output,
            cond_mask_prob=config.cond_mask_prob,
            mask_frames=config.mask_frames,
            emb_policy=config.emb_policy,
            pos_embed_max_len=config.pos_embed_max_len,
            pred_len=config.pred_len,
            context_len=config.context_len,
            all_goal_joint_names=config.all_goal_joint_names,
            multi_target_cond=config.multi_target_cond,
            multi_encoder_type=config.multi_encoder_type,
            target_enc_layers=config.target_enc_layers,
        )

    def __init__(
        self,
        modeltype,
        njoints,
        nfeats,
        num_actions,
        translation,
        pose_rep,
        glob,
        glob_rot,
        latent_dim=256,
        ff_size=1024,
        num_layers=8,
        num_heads=4,
        dropout=0.1,
        ablation=None,
        activation="gelu",
        legacy=False,
        data_rep="rot6d",
        dataset="amass",
        clip_dim=512,
        arch="trans_enc",
        emb_trans_dec=False,
        clip_version=None,
        **kargs,
    ):
        super().__init__()
        self.legacy = legacy
        self.modeltype = modeltype
        self.njoints = njoints
        self.nfeats = nfeats
        self.num_actions = num_actions
        self.data_rep = data_rep
        self.dataset = dataset
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.translation = translation
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.clip_dim = clip_dim
        self.action_emb = kargs.get("action_emb", None)
        self.input_feats = self.njoints * self.nfeats
        self.normalize_output = kargs.get("normalize_encoder_output", False)
        self.cond_mode = kargs.get("cond_mode", "no_cond")
        self.cond_mask_prob = kargs.get("cond_mask_prob", 0.0)
        self.mask_frames = kargs.get("mask_frames", False)
        self.arch = arch
        self.emb_policy = kargs.get("emb_policy", "add")
        self.pred_len = kargs.get("pred_len", 0)
        self.context_len = kargs.get("context_len", 0)
        self.total_len = self.pred_len + self.context_len
        self.is_prefix_comp = self.total_len > 0
        self.all_goal_joint_names = kargs.get("all_goal_joint_names", [])
        self.multi_target_cond = kargs.get("multi_target_cond", False)
        self.text_encoder_type = kargs.get("text_encoder_type", "clip")
        # Assert the simplifying assumptions this implementation relies on:
        # a transformer-encoder backbone, CLIP text conditioning, additive
        # embedding, and the HumanML3D vector representation.
        assert self.arch == "trans_enc"
        assert self.cond_mode == "text"
        assert self.text_encoder_type == "clip"
        assert not self.multi_target_cond
        assert not self.is_prefix_comp
        assert self.emb_policy == "add"
        assert self.data_rep == "hml_vec"
        # Using the Encoder architecture
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            transformer_encoder_layer, num_layers=self.num_layers
        )
        self.sequence_pos_encoder = PositionalEncoding(
            self.latent_dim, self.dropout, max_len=kargs.get("pos_embed_max_len", 5000)
        )
        self.embed_timestep = TimestepEmbedder(
            self.latent_dim, self.sequence_pos_encoder
        )

        # We'll use CLIP for now
        self.clip_version = clip_version
        self.clip_model = load_and_freeze_clip(clip_version)
        self.encode_text = self.clip_encode_text
        self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)

        # Linear input and output layers
        self.input_process = InputProcess(self.input_feats, self.latent_dim)
        self.output_process = OutputProcess(
            self.input_feats, self.latent_dim, self.njoints, self.nfeats
        )
        self.rot2xyz = Rotation2xyz(device="cpu", dataset=self.dataset)

    def parameters_wo_clip(self):
        return [
            p
            for name, p in self.named_parameters()
            if not name.startswith("clip_model.")
        ]

    def mask_cond(self, cond, force_mask=False):
        # Randomly zero out the condition during training (classifier-free
        # guidance style dropout); force_mask=True yields the null condition.
        bs = cond.shape[-2]
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.0:
            mask = torch.bernoulli(
                torch.ones(bs, device=cond.device) * self.cond_mask_prob
            ).view(1, bs, 1)  # 1 -> use null cond, 0 -> use real cond
            return cond * (1.0 - mask)
        else:
            return cond

    def clip_encode_text(self, raw_text):
        # raw_text - list (batch_size length) of strings with input text prompts
        device = next(self.parameters()).device
        max_text_len = (
            20 if self.dataset in ["kit", "humanml", "humanml_with_images"] else None
        )  # specific hardcoding for the HumanML/KIT datasets
        if max_text_len is not None:
            default_context_length = 77
            context_length = max_text_len + 2  # start_token + 20 + end_token
            assert context_length < default_context_length
            # [bs, context_length]; prompts longer than context_length are truncated
            texts = clip.tokenize(
                raw_text, context_length=context_length, truncate=True
            ).to(device)
            # pad back to CLIP's default context length of 77 tokens
            zero_pad = torch.zeros(
                [texts.shape[0], default_context_length - context_length],
                dtype=texts.dtype,
                device=texts.device,
            )
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            # [bs, 77]; prompts longer than 77 tokens are truncated
            texts = clip.tokenize(raw_text, truncate=True).to(device)
        return self.clip_model.encode_text(texts).float().unsqueeze(0)  # [1, bs, clip_dim]

    def motion_to_sequence(self, motion, timesteps, y):
        if "text_embed" not in y:
            clip_encoded_text = self.encode_text(y["text"])
        else:
            clip_encoded_text = y["text_embed"]
        # force_mask=True requests the unconditional (null-text) branch,
        # e.g. for classifier-free guidance sampling
        force_mask = y.get("uncond", False)
        # [1, bs, latent_dim]
        text_embedding = self.embed_text(
            self.mask_cond(clip_encoded_text, force_mask=force_mask)
        )
        # compute the embedding of the timestep + text, z_tk in the paper
        time_embedding = self.embed_timestep(timesteps)  # [1, bs, latent_dim]
        embedding = text_embedding + time_embedding  # [1, bs, latent_dim]
        # get the motion into latent space
        sequence = self.input_process(motion)  # [num_frames, bs, latent_dim]
        sequence_plus_emb = torch.cat(
            (embedding, sequence), dim=0
        )  # [num_frames + 1, bs, latent_dim]
        return sequence_plus_emb

    def sequence_to_motion(self, sequence_plus_emb):
        # drop the conditioning token (z_tk in the paper) from the sequence
        sequence = sequence_plus_emb[1:]  # [num_frames, bs, latent_dim]
        # map the latent sequence back to motion space
        motion = self.output_process(
            sequence
        )  # [bs, num_joints, num_features, num_frames]
        return motion

    def prepare_mask(self, sequence, device, y, bs):
        # Don't use mask with the generate script
        is_valid_mask = y["mask"].shape[-1] > 1
        if self.mask_frames and is_valid_mask:
            # invert the validity mask: True marks padded frames to be ignored
            frames_mask = torch.logical_not(
                y["mask"][..., : sequence.shape[0]].squeeze(1).squeeze(1)
            ).to(device=device)
            # the conditioning token is never masked
            step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=device)
            return torch.cat([step_mask, frames_mask], dim=1)
        else:
            return None

    def forward(self, motion, timesteps, y=None):
        """
        motion: [bs, num_joints, num_features, num_frames]
        timesteps: [bs]
        y: conditioning dict with at least "text" (or "text_embed") and "mask"
        returns: [bs, num_joints, num_features, num_frames]
        """
        sequence = self.motion_to_sequence(motion, timesteps, y)
        # apply positional encoding
        sequence = self.sequence_pos_encoder(
            sequence
        )  # [num_frames + 1, bs, latent_dim]
        frames_mask = self.prepare_mask(sequence, motion.device, y, motion.shape[0])
        # actual transformer magic
        sequence = self.seqTransEncoder(sequence, src_key_padding_mask=frames_mask)
        motion = self.sequence_to_motion(sequence)
        return motion

    def _apply(self, fn):
        super()._apply(fn)
        # keep the SMPL body model in sync with device/dtype changes
        self.rot2xyz.smpl_model._apply(fn)
        return self

    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
        return self


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # standard sinusoidal position table, stored as a [max_len, 1, d_model] buffer
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: [num_frames, bs, d_model]
        x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
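
# Shape sketch (illustrative only, not part of the original module): the `pe`
# buffer is [max_len, 1, d_model], so it broadcasts over the batch dimension of
# a [num_frames, bs, d_model] input:
#
#     pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
#     x = torch.zeros(60, 4, 256)        # [num_frames, bs, latent_dim]
#     assert pos_enc(x).shape == (60, 4, 256)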


class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.sequence_pos_encoder = sequence_pos_encoder
        self.time_embed = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
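
# Shape sketch (illustrative only): the timestep embedder indexes the shared
# sinusoidal table, so integer timesteps of shape [bs] become embeddings of
# shape [1, bs, latent_dim] after the MLP and permute:
#
#     embedder = TimestepEmbedder(256, PositionalEncoding(256))
#     t = torch.randint(0, 1000, (4,))   # [bs]
#     assert embedder(t).shape == (1, 4, 256)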


class InputProcess(nn.Module):
    """
    Applies the linear layer on the motion sequence at the beginning of the MDM.
    Also changes the shape:
    [bs, num_joints, num_features, num_frames] -> [num_frames, bs, latent_dim]
    """

    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.poseEmbedding = nn.Linear(input_feats, latent_dim)

    def forward(self, sequence):
        bs, num_joints, num_features, num_frames = sequence.shape
        sequence = sequence.permute((3, 0, 1, 2)).reshape(
            num_frames, bs, num_joints * num_features
        )
        sequence = self.poseEmbedding(sequence)
        return sequence


class OutputProcess(nn.Module):
    """
    Applies the linear layer on the motion sequence at the end of the MDM.
    Also changes the shape:
    [num_frames, bs, latent_dim] -> [bs, num_joints, num_features, num_frames]
    """

    def __init__(self, input_feats, latent_dim, num_joints, num_features):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.num_joints = num_joints
        self.num_features = num_features
        self.poseFinal = nn.Linear(latent_dim, input_feats)

    def forward(self, sequence):
        num_frames, bs, _ = sequence.shape
        sequence = self.poseFinal(sequence)
        sequence = sequence.reshape(
            num_frames, bs, self.num_joints, self.num_features
        ).permute(1, 2, 3, 0)
        return sequence


def load_and_freeze_clip(clip_version):
    # Must set jit=False for training
    clip_model, clip_preprocess = clip.load(clip_version, device="cpu", jit=False)

    # Freeze CLIP weights
    clip_model.eval()
    clip_model.requires_grad_(False)
    return clip_model
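

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. Assumptions:
    # the CLIP weights for "ViT-B/32" can be downloaded, the SMPL assets needed
    # by Rotation2xyz are available locally, and njoints=263 / nfeats=1 matches
    # the HumanML3D "hml_vec" layout used elsewhere in this repo.
    model = MDM(
        modeltype="",
        njoints=263,
        nfeats=1,
        num_actions=1,
        translation=True,
        pose_rep="rot6d",
        glob=True,
        glob_rot=True,
        data_rep="hml_vec",
        dataset="humanml",
        clip_version="ViT-B/32",
        cond_mode="text",
    )
    bs, num_frames = 2, 60
    motion = torch.randn(bs, 263, 1, num_frames)
    timesteps = torch.randint(0, 1000, (bs,))
    y = {
        "text": ["a person walks forward"] * bs,
        "mask": torch.ones(bs, 1, 1, num_frames, dtype=torch.bool),
    }
    out = model(motion, timesteps, y)
    print(out.shape)  # expected: torch.Size([2, 263, 1, 60])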