# model/mdm.py
import clip
import numpy as np
import torch
import torch.nn as nn
from model.config import MDMConfig
from model.rotation2xyz import Rotation2xyz


class MDM(nn.Module):
@classmethod
def from_config(cls, config: MDMConfig):
"""
Instantiate MDM from an MDMConfig object.
"""
return cls(
modeltype=config.modeltype,
njoints=config.njoints,
nfeats=config.nfeats,
num_actions=config.num_actions,
translation=config.translation,
pose_rep=config.pose_rep,
glob=config.glob,
glob_rot=config.glob_rot,
latent_dim=config.latent_dim,
ff_size=config.ff_size,
num_layers=config.layers,
num_heads=config.num_heads,
dropout=config.dropout,
ablation=config.ablation,
activation=config.activation,
legacy=config.legacy,
data_rep=config.data_rep,
dataset=config.dataset,
clip_dim=config.clip_dim,
arch=config.arch,
emb_trans_dec=config.emb_trans_dec,
clip_version=config.clip_version,
action_emb=config.action_emb,
normalize_encoder_output=config.normalize_encoder_output,
cond_mask_prob=config.cond_mask_prob,
mask_frames=config.mask_frames,
emb_policy=config.emb_policy,
pos_embed_max_len=config.pos_embed_max_len,
pred_len=config.pred_len,
context_len=config.context_len,
all_goal_joint_names=config.all_goal_joint_names,
multi_target_cond=config.multi_target_cond,
multi_encoder_type=config.multi_encoder_type,
target_enc_layers=config.target_enc_layers,
)
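
    # Usage sketch (illustrative, not part of the original file): `from_config`
    # simply unpacks an MDMConfig, e.g.
    #
    #     config = MDMConfig(**saved_args)  # hypothetical: args restored from a checkpoint
    #     model = MDM.from_config(config)
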
def __init__(
self,
modeltype,
njoints,
nfeats,
num_actions,
translation,
pose_rep,
glob,
glob_rot,
latent_dim=256,
ff_size=1024,
num_layers=8,
num_heads=4,
dropout=0.1,
ablation=None,
activation="gelu",
legacy=False,
data_rep="rot6d",
dataset="amass",
clip_dim=512,
arch="trans_enc",
emb_trans_dec=False,
clip_version=None,
**kargs,
):
super().__init__()
self.legacy = legacy
self.modeltype = modeltype
self.njoints = njoints
self.nfeats = nfeats
self.num_actions = num_actions
self.data_rep = data_rep
self.dataset = dataset
self.pose_rep = pose_rep
self.glob = glob
self.glob_rot = glob_rot
self.translation = translation
self.latent_dim = latent_dim
self.ff_size = ff_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.ablation = ablation
self.activation = activation
self.clip_dim = clip_dim
self.action_emb = kargs.get("action_emb", None)
self.input_feats = self.njoints * self.nfeats
self.normalize_output = kargs.get("normalize_encoder_output", False)
self.cond_mode = kargs.get("cond_mode", "no_cond")
self.cond_mask_prob = kargs.get("cond_mask_prob", 0.0)
self.mask_frames = kargs.get("mask_frames", False)
self.arch = arch
self.emb_policy = kargs.get("emb_policy", "add")
self.pred_len = kargs.get("pred_len", 0)
self.context_len = kargs.get("context_len", 0)
self.total_len = self.pred_len + self.context_len
self.is_prefix_comp = self.total_len > 0
self.all_goal_joint_names = kargs.get("all_goal_joint_names", [])
self.multi_target_cond = kargs.get("multi_target_cond", False)
self.text_encoder_type = kargs.get("text_encoder_type", "clip")
        # Assert the simplifying assumptions this implementation relies on
assert self.arch == "trans_enc"
assert self.cond_mode == "text"
assert self.text_encoder_type == "clip"
assert not self.multi_target_cond
assert not self.is_prefix_comp
assert self.emb_policy == "add"
assert self.data_rep == "hml_vec"
# Using the Encoder architecture
transformer_encoder_layer = nn.TransformerEncoderLayer(
d_model=self.latent_dim,
nhead=self.num_heads,
dim_feedforward=self.ff_size,
dropout=self.dropout,
activation=self.activation,
)
self.seqTransEncoder = nn.TransformerEncoder(
transformer_encoder_layer, num_layers=self.num_layers
)
self.sequence_pos_encoder = PositionalEncoding(
self.latent_dim, self.dropout, max_len=kargs.get("pos_embed_max_len", 5000)
)
self.embed_timestep = TimestepEmbedder(
self.latent_dim, self.sequence_pos_encoder
)
# We'll use CLIP for now
self.clip_version = clip_version
self.clip_model = load_and_freeze_clip(clip_version)
self.encode_text = self.clip_encode_text
self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
# Linear input and output layers
self.input_process = InputProcess(self.input_feats, self.latent_dim)
self.output_process = OutputProcess(
self.input_feats, self.latent_dim, self.njoints, self.nfeats
)
self.rot2xyz = Rotation2xyz(device="cpu", dataset=self.dataset)

    def parameters_wo_clip(self):
        # All trainable parameters except the frozen CLIP text encoder
        return [
            p
            for name, p in self.named_parameters()
            if not name.startswith("clip_model.")
        ]

    def mask_cond(self, cond, force_mask=False):
        # Classifier-free guidance dropout: zero the condition for the whole
        # batch (force_mask) or per sample with probability cond_mask_prob
        bs = cond.shape[-2]
if force_mask:
return torch.zeros_like(cond)
elif self.training and self.cond_mask_prob > 0.0:
mask = torch.bernoulli(
torch.ones(bs, device=cond.device) * self.cond_mask_prob
).view(1, bs, 1) # 1-> use null_cond, 0-> use real cond
return cond * (1.0 - mask)
else:
return cond
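
    # Guidance sketch (illustrative): cond has shape [1, bs, clip_dim]; a
    # Bernoulli draw per batch element zeroes that element's text feature so
    # the model also learns the unconditional distribution, e.g.
    #
    #     cond = torch.randn(1, 4, 512)
    #     self.mask_cond(cond)                   # training: rows dropped w.p. cond_mask_prob
    #     self.mask_cond(cond, force_mask=True)  # all rows zeroed (uncond branch)
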
def clip_encode_text(self, raw_text):
# raw_text - list (batch_size length) of strings with input text prompts
device = next(self.parameters()).device
        # Hardcoded 20-token prompt cap for the KIT / HumanML3D datasets
        max_text_len = (
            20 if self.dataset in ["kit", "humanml", "humanml_with_images"] else None
        )
if max_text_len is not None:
default_context_length = 77
context_length = max_text_len + 2 # start_token + 20 + end_token
assert context_length < default_context_length
            texts = clip.tokenize(
                raw_text, context_length=context_length, truncate=True
            ).to(device)  # [bs, context_length]; longer prompts are truncated
            # Zero-pad back to CLIP's fixed context length of 77 tokens
            zero_pad = torch.zeros(
                [texts.shape[0], default_context_length - context_length],
                dtype=texts.dtype,
                device=texts.device,
            )
            texts = torch.cat([texts, zero_pad], dim=1)
        else:
            # [bs, 77]; prompts longer than 77 tokens are truncated
            texts = clip.tokenize(raw_text, truncate=True).to(device)
        # [1, bs, clip_dim]
        return self.clip_model.encode_text(texts).float().unsqueeze(0)
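
    # Tokenization sketch (illustrative): for HumanML3D-style datasets the
    # prompt is tokenized to 22 positions (start token + 20 + end token) and
    # zero-padded back to CLIP's fixed context of 77, so with a ViT-B/32
    # encoder (clip_dim=512):
    #
    #     emb = self.clip_encode_text(["a person walks forward"])
    #     # emb.shape == (1, 1, 512)  -> [1, bs, clip_dim]
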
def motion_to_sequence(self, motion, timesteps, y):
if "text_embed" not in y:
clip_encoded_text = self.encode_text(y["text"])
else:
clip_encoded_text = y["text_embed"]
        # Fully drop the text condition for unconditional (classifier-free) sampling
        force_mask = y.get("uncond", False)
# [1, bs, latent_dim]
text_embedding = self.embed_text(
self.mask_cond(clip_encoded_text, force_mask=force_mask)
)
# compute the embedding of the timestep + text, z_tk in the paper
time_embedding = self.embed_timestep(timesteps) # [1, bs, latent_dim]
embedding = text_embedding + time_embedding # [1, bs, latent_dim]
# get the motion into latent space
sequence = self.input_process(motion) # [num_frames, bs, latent_dim]
sequence_plus_emb = torch.cat(
(embedding, sequence), dim=0
) # [num_frames + 1, bs, latent_dim]
return sequence_plus_emb
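
    # Shape sketch (illustrative): with HumanML3D features (263), bs=2 and 60
    # frames, motion [2, 263, 1, 60] is projected to [60, 2, latent_dim] and
    # the conditioning token is prepended, giving [61, 2, latent_dim].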

    def sequence_to_motion(self, sequence_plus_emb):
        # Drop the conditioning token (z_tk in the paper) from the sequence
        sequence = sequence_plus_emb[1:]  # [num_frames, bs, latent_dim]
# get back the motion from the latent space
motion = self.output_process(
sequence
) # [bs, num_joints, num_features, num_frames]
return motion

    def prepare_mask(self, sequence, device, y, bs):
        # The generate script passes a dummy length-1 mask; skip masking then
        is_valid_mask = y["mask"].shape[-1] > 1
        if self.mask_frames and is_valid_mask:
            # Key-padding mask: True marks frames attention should ignore
            frames_mask = torch.logical_not(
                y["mask"][..., : sequence.shape[0]].squeeze(1).squeeze(1)
            ).to(device=device)
            # Never mask position 0, which holds the conditioning token
            step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=device)
            return torch.cat([step_mask, frames_mask], dim=1)
        else:
            return None
def forward(self, motion, timesteps, y=None):
"""
motion: [bs, num_joints, num_features, num_frames]
timesteps: [bs]
"""
sequence = self.motion_to_sequence(motion, timesteps, y)
# apply positional encoding
sequence = self.sequence_pos_encoder(
sequence
) # [num_frames + 1, bs, latent_dim]
frames_mask = self.prepare_mask(sequence, motion.device, y, motion.shape[0])
        # Run the encoder over [conditioning token, frame_1, ..., frame_T]
        sequence = self.seqTransEncoder(sequence, src_key_padding_mask=frames_mask)
motion = self.sequence_to_motion(sequence)
return motion
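
    # Forward sketch (illustrative; a runnable version is in the __main__
    # block at the bottom of this file): the noised motion x_t goes in and a
    # tensor of identical shape comes out (the original MDM is trained to
    # predict the clean sample x_0 rather than the noise).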

    def _apply(self, fn):
        # rot2xyz's SMPL model is not a registered submodule, so propagate manually
        super()._apply(fn)
        self.rot2xyz.smpl_model._apply(fn)
        return self

    def train(self, *args, **kwargs):
        super().train(*args, **kwargs)
        self.rot2xyz.smpl_model.train(*args, **kwargs)
        return self


class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding, precomputed and registered as a
    buffer of shape [max_len, 1, d_model].
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)

    def forward(self, x):
        # x: [seq_len, bs, d_model]; add the table for the first seq_len positions
        x = x + self.pe[: x.shape[0], :]
return self.dropout(x)
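

# For reference, the table built above follows the standard sinusoidal scheme:
#
#     pe[pos, 2i]     = sin(pos / 10000^(2i / d_model))
#     pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))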


class TimestepEmbedder(nn.Module):
    """
    Embeds the diffusion timestep by indexing the shared sinusoidal table,
    followed by a two-layer MLP. Output shape: [1, bs, latent_dim].
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.sequence_pos_encoder = sequence_pos_encoder
self.time_embed = nn.Sequential(
nn.Linear(latent_dim, latent_dim),
nn.SiLU(),
nn.Linear(latent_dim, latent_dim),
)

    def forward(self, timesteps):
        # pe[timesteps]: [bs, 1, latent_dim] -> [1, bs, latent_dim]
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
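

# Usage sketch (illustrative): reusing the positional table means timestep t
# is embedded exactly like sequence position t before the MLP:
#
#     embedder = TimestepEmbedder(512, PositionalEncoding(512))
#     embedder(torch.tensor([0, 999])).shape  # [1, 2, 512]

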
class InputProcess(nn.Module):
"""
Applies the linear layer on the motion sequence at the beginning of the MDM
Also changes the shape
[bs, num_joints, num_features, num_frames] -> [num_frames, bs, latent_dim]
"""

    def __init__(self, input_feats, latent_dim):
super().__init__()
self.poseEmbedding = nn.Linear(input_feats, latent_dim)

    def forward(self, sequence):
bs, num_joints, num_features, num_frames = sequence.shape
sequence = sequence.permute((3, 0, 1, 2)).reshape(
num_frames, bs, num_joints * num_features
)
sequence = self.poseEmbedding(sequence)
return sequence
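

# Shape sketch (illustrative): frames move to the leading axis, as PyTorch
# transformer modules without batch_first expect:
#
#     proc = InputProcess(263, 512)
#     proc(torch.randn(2, 263, 1, 60)).shape  # [60, 2, 512]

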
class OutputProcess(nn.Module):
"""
Applies the linear layer on the motion sequence at the end of the MDM
Also changes the shape
[num_frames, bs, latent_dim] -> [bs, num_joints, num_features, num_frames]
"""

    def __init__(self, input_feats, latent_dim, num_joints, num_features):
super().__init__()
self.input_feats = input_feats
self.latent_dim = latent_dim
self.num_joints = num_joints
self.num_features = num_features
self.poseFinal = nn.Linear(latent_dim, input_feats)

    def forward(self, sequence):
num_frames, bs, _ = sequence.shape
sequence = self.poseFinal(sequence)
sequence = sequence.reshape(
num_frames, bs, self.num_joints, self.num_features
).permute(1, 2, 3, 0)
return sequence
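

# Shape sketch (illustrative): OutputProcess is the exact inverse reshape of
# InputProcess:
#
#     proc = OutputProcess(263, 512, 263, 1)
#     proc(torch.randn(60, 2, 512)).shape  # [2, 263, 1, 60]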


def load_and_freeze_clip(clip_version):
    # jit=False is required so the text encoder can be called during training
    clip_model, _ = clip.load(clip_version, device="cpu", jit=False)
    clip_model.eval()
    # Freeze CLIP weights; parameters_wo_clip() also excludes them from optimizers
    clip_model.requires_grad_(False)
    return clip_model
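

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the original file). It
    # assumes the CLIP "ViT-B/32" weights can be downloaded and that
    # Rotation2xyz can find its SMPL body-model files; the values below
    # mirror the HumanML3D setup (263-dim "hml_vec" features).
    model = MDM(
        modeltype="",
        njoints=263,
        nfeats=1,
        num_actions=1,
        translation=True,
        pose_rep="rot6d",
        glob=True,
        glob_rot=True,
        data_rep="hml_vec",
        dataset="humanml",
        clip_version="ViT-B/32",
        cond_mode="text",
        cond_mask_prob=0.1,
    )
    bs, num_frames = 2, 60
    motion = torch.randn(bs, 263, 1, num_frames)
    timesteps = torch.randint(0, 1000, (bs,))
    y = {
        "text": ["a person walks forward"] * bs,
        "mask": torch.ones(bs, 1, 1, num_frames, dtype=torch.bool),
    }
    out = model(motion, timesteps, y)
    assert out.shape == motion.shape  # [bs, 263, 1, num_frames]
    print("forward pass OK:", tuple(out.shape))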