NMR / tools /data_process /src /tmr.py
Xxx999's picture
upload
45950ff
r"""TMR: Text-to-Motion Retrieval
Using Contrastive 3D Human Motion Synthesis
Find more information about the model on the following website:
https://mathis.petrovich.fr/tmr
Args:
motion_encoder: a module to encode the input motion features in the latent space (required).
text_encoder: a module to encode the text embeddings in the latent space (required).
motion_decoder: a module to decode the latent vector into motion features (required).
vae: a boolean to make the model probabilistic (required).
fact: a scaling factor for sampling the VAE (optional).
sample_mean: sample the mean vector instead of random sampling (optional).
lmd: dictionary of losses weights (optional).
lr: learning rate for the optimizer (optional).
temperature: temperature of the softmax in the contrastive loss (optional).
threshold_selfsim: threshold used to filter wrong negatives for the contrastive loss (optional).
threshold_selfsim_metrics: threshold used to filter wrong negatives for the metrics (optional).
"""
import torch, os
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from typing import Optional
@MODELS.register_module()
class TMR(BaseModel):
    """Joint text/motion latent model with reconstruction, KL, latent and contrastive losses.

    Motions and captions are each encoded to a Gaussian ``(mu, logvar)`` pair,
    a latent vector is sampled from it, and a shared decoder maps latents back
    to motion features.

    Args:
        motion_encoder_cfg: registry config passed to ``MODELS.build``; the
            resulting module is called as ``motion_encoder(motion, mask)`` and
            its output is split along dim 1 into ``(mu, logvar)``.
        text_encoder_cfg: registry config passed to ``MODELS.build``; the
            resulting module is called as ``text_encoder(token_embeddings, mask)``
            and its output is split along dim 1 into ``(mu, logvar)``.
        motion_decoder_cfg: registry config passed to ``MODELS.build``; the
            resulting module is called as ``motion_decoder(latents, mask)``.
        temperature: temperature forwarded to the contrastive loss.
        threshold_selfsim: caption self-similarity threshold forwarded to the
            contrastive loss.
        text_model_path: location handed to ``AutoTokenizer.from_pretrained`` /
            ``AutoModel.from_pretrained`` for the frozen text backbone.
        **kwargs: forwarded to ``BaseModel.__init__``.
    """

    def __init__(self, motion_encoder_cfg, text_encoder_cfg, motion_decoder_cfg,
                 temperature: float = 0.7, threshold_selfsim: float = 0.80,
                 text_model_path: str = 'ckpts/distilbert-base-uncased',
                 **kwargs):
        super().__init__(**kwargs)
        self.motion_encoder = MODELS.build(motion_encoder_cfg)
        self.text_encoder = MODELS.build(text_encoder_cfg)
        self.motion_decoder = MODELS.build(motion_decoder_cfg)

        # Frozen pretrained language model used only as a feature extractor.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        self.text_model = AutoModel.from_pretrained(text_model_path)
        self.text_model.eval()
        for param in self.text_model.parameters():
            param.requires_grad = False

        # losses
        self.reconstruction_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.latent_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.kl_loss_fn = KLLoss()
        # adding the contrastive loss
        self.contrastive_loss_fn = InfoNCE_with_filtering(
            temperature=temperature, threshold_selfsim=threshold_selfsim)

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen text model in eval.

        ``nn.Module.train`` recurses into every sub-module, which would turn
        the text backbone's dropout back on; re-assert eval mode afterwards.
        """
        super().train(mode)
        self.text_model.eval()
        return self

    @staticmethod
    def _lengths_to_mask(lengths: torch.Tensor, device) -> torch.Tensor:
        """Build a boolean mask that is True for positions below each length."""
        lengths = lengths.to(device)
        max_len = int(lengths.max())
        positions = torch.arange(max_len, device=device).expand(lengths.shape[0], max_len)
        return positions < lengths.unsqueeze(1)

    @staticmethod
    def _reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)``."""
        # exp(0.5 * logvar) is the same std as exp(logvar) ** 0.5 but overflows later.
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def encode_text(self, captions):
        """Encode a list of captions.

        Args:
            captions: list of strings.

        Returns:
            Tuple ``(latent_vectors, (mu, logvar), sentence_embeddings)`` where
            ``sentence_embeddings`` are L2-normalised, attention-masked mean
            pools of the text model's last hidden state.
        """
        assert isinstance(captions, list), "captions must be a list of strings"
        with torch.no_grad():
            # truncation guards against captions longer than the model's maximum length
            text_tokens = self.tokenizer(
                captions, return_tensors="pt", padding=True, truncation=True,
            ).to(self.text_model.device)
            output = self.text_model(**text_tokens)
            mask = text_tokens.attention_mask.to(dtype=bool)
            text_embeddings = output.last_hidden_state  # B, max_length, C
            # Mean-pool over real (non-padding) tokens only.
            mask_expanded = mask.unsqueeze(-1).expand(text_embeddings.shape).float()
            sentence_embeddings = torch.sum(text_embeddings * mask_expanded, 1) / torch.clamp(
                mask_expanded.sum(1), min=1e-9)
            sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        # The trainable text encoder runs outside no_grad so it receives gradients.
        dists = self.text_encoder(text_embeddings, mask).unbind(1)
        mu, logvar = dists
        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists, sentence_embeddings

    def encode_motion(self, motion: torch.Tensor,
                      mask: Optional[torch.Tensor] = None,
                      motion_length: Optional[torch.Tensor] = None):
        """Encode motion features.

        Args:
            motion: motion feature tensor passed to the motion encoder.
            mask: boolean validity mask; built from ``motion_length`` when omitted.
            motion_length: per-sample lengths, required when ``mask`` is None.

        Returns:
            Tuple ``(latent_vectors, (mu, logvar))``.
        """
        assert mask is not None or motion_length is not None, \
            "either mask or motion_length must be provided"
        if mask is None:
            mask = self._lengths_to_mask(motion_length, motion.device)  # type: ignore
        dists = self.motion_encoder(motion, mask).unbind(1)
        mu, logvar = dists
        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists

    def forward(self, motion, motion_length, text_data, mode='loss', **kwargs):  # type: ignore
        '''
        motion: Tensor [B, T, C]
        motion_length: Tensor [B, ]
        text_data: List[str]
        mode: 'loss' returns a dict of weighted losses; any other value returns
            the tuple (t_latents, m_latents, sent_emb).
        '''
        # generate motion mask
        mask = self._lengths_to_mask(motion_length, motion.device)

        m_latents, m_dists = self.encode_motion(motion, mask=mask)
        t_latents, t_dists, sent_emb = self.encode_text(text_data)

        if mode != 'loss':
            # Decoding is only needed for the reconstruction loss.
            return t_latents, m_latents, sent_emb

        m_motions = self.motion_decoder(m_latents, mask)
        t_motions = self.motion_decoder(t_latents, mask)

        losses = dict()
        # Reconstructions losses
        losses["recons_loss"] = (
            self.reconstruction_loss_fn(t_motions, motion)  # text -> motion
            + self.reconstruction_loss_fn(m_motions, motion)  # motion -> motion
        )

        # VAE losses
        # Create a centred normal distribution to compare with logvar = 0 -> std = 1
        ref_mus = torch.zeros_like(m_dists[0])
        ref_logvar = torch.zeros_like(m_dists[1])
        ref_dists = (ref_mus, ref_logvar)
        losses["kl_loss"] = (
            self.kl_loss_fn(t_dists, m_dists)  # text_to_motion
            + self.kl_loss_fn(m_dists, t_dists)  # motion_to_text
            + self.kl_loss_fn(m_dists, ref_dists)  # motion
            + self.kl_loss_fn(t_dists, ref_dists)  # text
        ) * 1.0e-5

        # Latent manifold loss
        losses["latent_loss"] = self.latent_loss_fn(t_latents, m_latents) * 1.0e-5

        # TMR: adding the contrastive loss
        losses["contrastive_loss"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb) * 0.1
        return losses
# For reference
# https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
# https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
class KLLoss:
    """Mean KL divergence KL(q || p) between two diagonal Gaussians.

    Each distribution is given as a ``(mean, log_variance)`` pair of tensors.
    """

    def __call__(self, q, p):
        (q_mean, q_logvar), (p_mean, p_logvar) = q, p

        # log(var_q / var_p)
        logvar_gap = q_logvar - p_logvar
        # squared mean shift measured in units of p's variance
        scaled_sq_shift = (p_mean - q_mean).pow(2) / p_logvar.exp()

        elementwise = 0.5 * (logvar_gap.exp() + scaled_sq_shift - 1 - logvar_gap)
        return elementwise.mean()
class InfoNCE_with_filtering:
    """Symmetric InfoNCE loss that drops negatives with near-duplicate captions.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
        threshold_selfsim: caption self-similarity threshold expressed in
            [0, 1]; it is rescaled to the cosine range [-1, 1] before use.
            A falsy value (``0`` / ``None``) disables the filtering.
    """

    def __init__(self, temperature=0.7, threshold_selfsim=0.8):
        self.temperature = temperature
        self.threshold_selfsim = threshold_selfsim

    def get_sim_matrix(self, x, y):
        """Return the pairwise cosine-similarity matrix between rows of x and y."""
        x_unit = F.normalize(x, dim=-1)
        y_unit = F.normalize(y, dim=-1)
        return x_unit @ y_unit.T

    def __call__(self, x, y, sent_emb=None):
        """Compute the loss for paired rows of ``x`` and ``y``.

        Args:
            x: first set of embeddings, one row per sample.
            y: second set of embeddings, row ``i`` being the positive of ``x[i]``.
            sent_emb: optional sentence embeddings used to detect captions that
                are too similar to serve as negatives; the code compares their
                dot products to a cosine-range threshold, so they are presumably
                expected to be L2-normalised.

        Returns:
            Scalar tensor: mean of the x->y and y->x cross-entropy losses.
        """
        bs, device = len(x), x.device
        sim_matrix = self.get_sim_matrix(x, y) / self.temperature

        if sent_emb is not None and self.threshold_selfsim:
            # put the threshold value between -1 and 1
            real_threshold_selfsim = 2 * self.threshold_selfsim - 1

            # Filtering too close values: mask them by putting -inf in the sim_matrix
            selfsim = sent_emb @ sent_emb.T
            too_similar = selfsim > real_threshold_selfsim
            # Never mask the positives: with a negative rescaled threshold
            # (threshold_selfsim < 0.5) a zeroed diagonal would otherwise be
            # selected and the loss would become NaN/inf.
            too_similar.fill_diagonal_(False)
            sim_matrix = sim_matrix.masked_fill(too_similar, float("-inf"))

        labels = torch.arange(bs, device=device)
        total_loss = (
            F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
        ) / 2
        return total_loss