File size: 8,028 Bytes
45950ff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | r"""TMR: Text-to-Motion Retrieval
Using Contrastive 3D Human Motion Synthesis
Find more information about the model on the following website:
https://mathis.petrovich.fr/tmr
Args:
motion_encoder: a module to encode the input motion features in the latent space (required).
text_encoder: a module to encode the text embeddings in the latent space (required).
motion_decoder: a module to decode the latent vector into motion features (required).
vae: a boolean to make the model probabilistic (required).
fact: a scaling factor for sampling the VAE (optional).
sample_mean: sample the mean vector instead of random sampling (optional).
lmd: dictionary of losses weights (optional).
lr: learning rate for the optimizer (optional).
temperature: temperature of the softmax in the contrastive loss (optional).
threshold_selfsim: threshold used to filter wrong negatives for the contrastive loss (optional).
threshold_selfsim_metrics: threshold used to filter wrong negatives for the metrics (optional).
"""
import torch, os
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from typing import Optional
@MODELS.register_module()
class TMR(BaseModel):
    """Text-to-motion retrieval model with a shared text/motion latent space.

    The model owns three trainable sub-modules built from registry configs
    (a motion encoder, a text encoder and a motion decoder) plus a frozen
    pretrained language model used only to produce token embeddings.

    Args:
        motion_encoder_cfg: registry config used to build the motion encoder.
        text_encoder_cfg: registry config used to build the text encoder.
        motion_decoder_cfg: registry config used to build the motion decoder.
        temperature: temperature forwarded to the contrastive loss.
        threshold_selfsim: self-similarity threshold forwarded to the
            contrastive loss to filter wrong negatives.
        text_model_path: path or name of the pretrained tokenizer and language
            model; defaults to the previously hard-coded checkpoint location.
        **kwargs: forwarded unchanged to ``BaseModel.__init__``.
    """

    def __init__(self, motion_encoder_cfg, text_encoder_cfg, motion_decoder_cfg,
                 temperature: float = 0.7, threshold_selfsim: float = 0.80,
                 text_model_path: str = 'ckpts/distilbert-base-uncased',
                 **kwargs):
        super().__init__(**kwargs)
        self.motion_encoder = MODELS.build(motion_encoder_cfg)
        self.text_encoder = MODELS.build(text_encoder_cfg)
        self.motion_decoder = MODELS.build(motion_decoder_cfg)

        # Frozen pretrained language model used to embed the captions.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        self.text_model = AutoModel.from_pretrained(text_model_path)
        self.text_model.eval()
        for param in self.text_model.parameters():
            param.requires_grad = False

        # Losses
        self.reconstruction_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.latent_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.kl_loss_fn = KLLoss()
        # TMR: contrastive loss between text and motion latents
        self.contrastive_loss_fn = InfoNCE_with_filtering(
            temperature=temperature, threshold_selfsim=threshold_selfsim)

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen text model in eval.

        ``nn.Module.train`` recurses into every registered child, which would
        put ``self.text_model`` back into training mode (re-enabling its
        dropout) and undo the ``eval()`` call made in ``__init__``.
        """
        super().train(mode)
        self.text_model.eval()
        return self

    @staticmethod
    def _reparameterize(mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = logvar.exp().pow(0.5)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def _lengths_to_mask(lengths, device):
        """Build a boolean mask that is True at positions ``t < lengths[b]``."""
        max_len = lengths.max()
        return torch.arange(max_len, device=device).expand(
            lengths.shape[0], max_len) < lengths.unsqueeze(1)

    def encode_text(self, captions):
        """Encode a list of captions into latent vectors.

        Args:
            captions: list of caption strings.

        Returns:
            A tuple ``(latent_vectors, dists, sentence_embeddings)`` where
            ``dists`` is the ``(mu, logvar)`` pair produced by the text
            encoder, ``latent_vectors`` is a sample drawn from it, and
            ``sentence_embeddings`` is the L2-normalised masked mean of the
            language-model token embeddings.
        """
        assert isinstance(captions, list)

        # The language model is frozen: no graph is needed for its forward pass.
        with torch.no_grad():
            text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True)
            output = self.text_model(**text_tokens.to(self.text_model.device))

        mask = text_tokens.attention_mask.to(dtype=bool)
        text_embeddings = output.last_hidden_state  # B, max_length, C

        # Masked mean pooling over tokens, then L2 normalisation.
        mask_expanded = mask.unsqueeze(-1).expand(text_embeddings.shape).float()
        sentence_embeddings = torch.sum(text_embeddings * mask_expanded, 1) / torch.clamp(
            mask_expanded.sum(1), min=1e-9)
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        # The encoder output is split along dim 1 into (mu, logvar).
        dists = self.text_encoder(text_embeddings, mask).unbind(1)
        mu, logvar = dists
        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists, sentence_embeddings

    def encode_motion(self, motion: torch.Tensor,
                      mask: Optional[torch.Tensor] = None,
                      motion_length: Optional[torch.Tensor] = None):
        """Encode motion features into latent vectors.

        Args:
            motion: motion features passed to the motion encoder.
            mask: optional mask passed to the motion encoder; built from
                ``motion_length`` when omitted.
            motion_length: per-sample lengths, required when ``mask`` is None.

        Returns:
            A tuple ``(latent_vectors, dists)`` where ``dists`` is the
            ``(mu, logvar)`` pair produced by the motion encoder and
            ``latent_vectors`` is a sample drawn from it.
        """
        assert mask is not None or motion_length is not None
        if mask is None:
            mask = self._lengths_to_mask(motion_length, motion.device)  # type: ignore

        # The encoder output is split along dim 1 into (mu, logvar).
        dists = self.motion_encoder(motion, mask).unbind(1)
        mu, logvar = dists
        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists

    def forward(self, motion, motion_length, text_data, mode='loss', **kwargs):  # type: ignore
        '''Compute the training losses or the text/motion latents.

        motion: Tensor [B, T, C]
        motion_length: Tensor [B, ]
        text_data: List[str]
        mode: 'loss' returns a dict of weighted losses; any other value
            returns the tuple (t_latents, m_latents, sent_emb).
        '''
        mask = self._lengths_to_mask(motion_length, motion.device)

        m_latents, m_dists = self.encode_motion(motion, mask=mask)
        t_latents, t_dists, sent_emb = self.encode_text(text_data)

        if mode != 'loss':
            # Only the latents are returned here, so the decoder is skipped.
            return t_latents, m_latents, sent_emb

        m_motions = self.motion_decoder(m_latents, mask)
        t_motions = self.motion_decoder(t_latents, mask)

        losses = dict()

        # Reconstruction losses
        losses["recons_loss"] = (
            self.reconstruction_loss_fn(t_motions, motion)  # text -> motion
            + self.reconstruction_loss_fn(m_motions, motion)  # motion -> motion
        ) * 1.0

        # VAE losses
        # Centred reference distribution: mu = 0 and logvar = 0 -> std = 1
        ref_mus = torch.zeros_like(m_dists[0])
        ref_logvar = torch.zeros_like(m_dists[1])
        ref_dists = (ref_mus, ref_logvar)
        losses["kl_loss"] = (
            self.kl_loss_fn(t_dists, m_dists)  # text_to_motion
            + self.kl_loss_fn(m_dists, t_dists)  # motion_to_text
            + self.kl_loss_fn(m_dists, ref_dists)  # motion
            + self.kl_loss_fn(t_dists, ref_dists)  # text
        ) * 1.0e-5

        # Latent manifold loss
        losses["latent_loss"] = self.latent_loss_fn(t_latents, m_latents) * 1.0e-5

        # TMR: contrastive loss
        losses["contrastive_loss"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb) * 0.1
        return losses
# For reference
# https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
# https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
class KLLoss:
    """Mean element-wise KL divergence KL(q || p) between diagonal Gaussians.

    Both distributions are passed as ``(mu, logvar)`` pairs.
    """

    def __call__(self, q, p):
        """Return the KL divergence of ``q`` from ``p``, averaged over all elements."""
        q_mean, q_logvar = q
        p_mean, p_logvar = p

        # log(var_q / var_p)
        logvar_gap = q_logvar - p_logvar
        # Squared mean difference, scaled by the variance of p.
        scaled_mean_gap = torch.pow(p_mean - q_mean, 2) / torch.exp(p_logvar)

        elementwise_kl = 0.5 * (torch.exp(logvar_gap) + scaled_mean_gap - 1 - logvar_gap)
        return elementwise_kl.mean()
class InfoNCE_with_filtering:
    """Symmetric InfoNCE loss that can mask out likely false negatives.

    Pairs whose sentence embeddings are more similar than a threshold are
    excluded from the softmax by setting their logit to ``-inf``.
    """

    def __init__(self, temperature=0.7, threshold_selfsim=0.8):
        """Store the softmax temperature and the self-similarity threshold."""
        self.threshold_selfsim = threshold_selfsim
        self.temperature = temperature

    def get_sim_matrix(self, x, y):
        """Return the pairwise cosine similarity between rows of ``x`` and ``y``."""
        x_unit = F.normalize(x, dim=-1)
        y_unit = F.normalize(y, dim=-1)
        return x_unit @ y_unit.T

    def __call__(self, x, y, sent_emb=None):
        """Return the mean of the x->y and y->x cross-entropy losses.

        Filtering is applied only when ``sent_emb`` is given and
        ``threshold_selfsim`` is truthy.
        """
        logits = self.get_sim_matrix(x, y) / self.temperature

        if sent_emb is not None and self.threshold_selfsim:
            # Rescale the threshold so it lies between -1 and 1.
            cutoff = 2 * self.threshold_selfsim - 1
            text_sim = sent_emb @ sent_emb.T
            # Zero the diagonal so a sample never filters itself.
            off_diagonal = text_sim - torch.diag(torch.diag(text_sim))
            too_similar = off_diagonal > cutoff
            # Exclude those pairs from the softmax.
            logits[too_similar] = -torch.inf

        targets = torch.arange(len(x), device=x.device)
        loss_x_to_y = F.cross_entropy(logits, targets)
        loss_y_to_x = F.cross_entropy(logits.T, targets)
        return (loss_x_to_y + loss_y_to_x) / 2
|