File size: 8,028 Bytes
45950ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
r"""TMR: Text-to-Motion Retrieval
Using Contrastive 3D Human Motion Synthesis
Find more information about the model on the following website:
https://mathis.petrovich.fr/tmr

Args:
    motion_encoder: a module to encode the input motion features in the latent space (required).
    text_encoder: a module to encode the text embeddings in the latent space (required).
    motion_decoder: a module to decode the latent vector into motion features (required).
    vae: a boolean to make the model probabilistic (required).
    fact: a scaling factor for sampling the VAE (optional).
    sample_mean: sample the mean vector instead of random sampling (optional).
    lmd: dictionary of losses weights (optional).
    lr: learning rate for the optimizer (optional).
    temperature: temperature of the softmax in the contrastive loss (optional).
    threshold_selfsim: threshold used to filter wrong negatives for the contrastive loss (optional).
    threshold_selfsim_metrics: threshold used to filter wrong negatives for the metrics (optional).
"""
import torch, os
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from typing import Optional

@MODELS.register_module()
class TMR(BaseModel):
    """Text-to-motion retrieval model with a shared text/motion latent space.

    A motion encoder and a text encoder each produce a ``(mu, logvar)`` pair;
    a latent vector is sampled from it with the reparameterization trick and a
    shared motion decoder maps latents back to motion features. Text tokens
    are embedded by a frozen pretrained transformer loaded from
    ``text_model_path``.

    Args:
        motion_encoder_cfg: config passed to ``MODELS.build`` for the motion encoder.
        text_encoder_cfg: config passed to ``MODELS.build`` for the text encoder.
        motion_decoder_cfg: config passed to ``MODELS.build`` for the motion decoder.
        temperature: temperature of the softmax in the contrastive loss.
        threshold_selfsim: threshold used to filter wrong negatives in the
            contrastive loss.
        text_model_path: path or identifier of the pretrained text model and
            tokenizer (keyword-only).
        **kwargs: forwarded unchanged to ``BaseModel.__init__``.
    """

    def __init__(self, motion_encoder_cfg, text_encoder_cfg, motion_decoder_cfg,
                 temperature: float = 0.7, threshold_selfsim: float = 0.80,
                 *, text_model_path: str = 'ckpts/distilbert-base-uncased',
                 **kwargs):
        super().__init__(**kwargs)
        self.motion_encoder = MODELS.build(motion_encoder_cfg)
        self.text_encoder = MODELS.build(text_encoder_cfg)
        self.motion_decoder = MODELS.build(motion_decoder_cfg)

        # Frozen pretrained text embedding model (DistilBERT by default).
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_path)
        self.text_model = AutoModel.from_pretrained(text_model_path)
        self.text_model.eval()
        for param in self.text_model.parameters():
            param.requires_grad = False

        # losses
        self.reconstruction_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.latent_loss_fn = torch.nn.SmoothL1Loss(reduction="mean")
        self.kl_loss_fn = KLLoss()

        # adding the contrastive loss
        self.contrastive_loss_fn = InfoNCE_with_filtering(
            temperature=temperature, threshold_selfsim=threshold_selfsim)

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen text model in eval.

        ``nn.Module.train`` recurses into every child module, which would
        otherwise undo the ``eval()`` call made in ``__init__`` and re-enable
        dropout inside the frozen text model during training.
        """
        super().train(mode)
        self.text_model.eval()
        return self

    @staticmethod
    def _reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar) ** 0.5``."""
        std = logvar.exp().pow(0.5)
        return mu + torch.randn_like(std) * std

    @staticmethod
    def _lengths_to_mask(lengths: torch.Tensor, device) -> torch.Tensor:
        """Build a boolean mask that is True for positions ``< lengths[i]`` in row ``i``."""
        max_len = lengths.max()
        positions = torch.arange(max_len, device=device).expand(lengths.shape[0], max_len)
        return positions < lengths.unsqueeze(1)

    def encode_text(self, captions):
        """Encode a list of captions.

        Args:
            captions: list of strings (asserted).

        Returns:
            Tuple ``(latent_vectors, dists, sentence_embeddings)`` where
            ``dists`` is the ``(mu, logvar)`` pair from the text encoder,
            ``latent_vectors`` is a sample drawn from it, and
            ``sentence_embeddings`` is the L2-normalised, attention-masked mean
            of the text model's last hidden state.
        """
        assert isinstance(captions, list)
        with torch.no_grad():
            text_tokens = self.tokenizer(captions, return_tensors="pt", padding=True)
            output = self.text_model(**text_tokens.to(self.text_model.device))
        mask = text_tokens.attention_mask.to(dtype=bool)
        text_embeddings = output.last_hidden_state  # B, max_length, C

        # Masked mean pooling over tokens; the clamp guards against an all-zero mask.
        mask_expanded = mask.unsqueeze(-1).expand(text_embeddings.shape).float()
        sentence_embeddings = torch.sum(text_embeddings * mask_expanded, 1) / torch.clamp(
            mask_expanded.sum(1), min=1e-9)
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        # The text encoder output is split along dim 1 into exactly (mu, logvar).
        dists = self.text_encoder(text_embeddings, mask).unbind(1)
        mu, logvar = dists

        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists, sentence_embeddings

    def encode_motion(self, motion: torch.Tensor,
                      mask: Optional[torch.Tensor]=None,
                      motion_length: Optional[torch.Tensor]=None):
        """Encode motion features.

        Args:
            motion: motion features passed to the motion encoder.
            mask: boolean validity mask; built from ``motion_length`` when omitted.
            motion_length: per-sample lengths, used only if ``mask`` is None.
                At least one of ``mask`` / ``motion_length`` must be given.

        Returns:
            Tuple ``(latent_vectors, dists)`` where ``dists`` is the
            ``(mu, logvar)`` pair from the motion encoder and
            ``latent_vectors`` is a sample drawn from it.
        """
        assert mask is not None or motion_length is not None
        if mask is None:
            mask = self._lengths_to_mask(motion_length, motion.device)  # type: ignore
        # The motion encoder output is split along dim 1 into exactly (mu, logvar).
        dists = self.motion_encoder(motion, mask).unbind(1)
        mu, logvar = dists

        latent_vectors = self._reparameterize(mu, logvar)
        return latent_vectors, dists

    def forward(self, motion, motion_length, text_data, mode='loss', **kwargs): # type: ignore
        '''
        motion: Tensor [B, T, C]
        motion_length: Tensor [B, ]
        text_data: List[str]
        mode: 'loss' returns a dict of weighted losses; any other value
            returns the tuple (t_latents, m_latents, sent_emb).
        '''
        # generate motion mask
        mask = self._lengths_to_mask(motion_length, motion.device)
        m_latents, m_dists = self.encode_motion(motion, mask=mask)

        if mode != 'loss':
            # The decoded motions are only consumed by the reconstruction
            # loss, so the decoder is skipped entirely outside of loss mode.
            t_latents, _, sent_emb = self.encode_text(text_data)
            return t_latents, m_latents, sent_emb

        m_motions = self.motion_decoder(m_latents, mask)
        t_latents, t_dists, sent_emb = self.encode_text(text_data)
        t_motions = self.motion_decoder(t_latents, mask)

        losses = dict()

        # Reconstructions losses
        losses["recons_loss"] = (
            self.reconstruction_loss_fn(t_motions, motion) # text -> motion
            + self.reconstruction_loss_fn(m_motions, motion) # motion -> motion
        ) * 1.0

        # VAE losses
        # Create a centred normal distribution to compare with logvar = 0 -> std = 1
        ref_mus = torch.zeros_like(m_dists[0])
        ref_logvar = torch.zeros_like(m_dists[1])
        ref_dists = (ref_mus, ref_logvar)

        losses["kl_loss"] = (
            self.kl_loss_fn(t_dists, m_dists)  # text_to_motion
            + self.kl_loss_fn(m_dists, t_dists)  # motion_to_text
            + self.kl_loss_fn(m_dists, ref_dists)  # motion
            + self.kl_loss_fn(t_dists, ref_dists)  # text
        ) * 1.0e-5

        # Latent manifold loss
        losses["latent_loss"] = self.latent_loss_fn(t_latents, m_latents) * 1.0e-5

        # TMR: adding the contrastive loss
        losses["contrastive_loss"] = self.contrastive_loss_fn(t_latents, m_latents, sent_emb) * 0.1
        return losses

# For reference
# https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
# https://pytorch.org/docs/stable/_modules/torch/distributions/kl.html#kl_divergence
class KLLoss:
    """Mean element-wise KL divergence between two diagonal Gaussians.

    Each distribution is given as a ``(mu, logvar)`` pair of tensors.
    """

    def __call__(self, q, p):
        """Return ``KL(q || p)`` averaged over all elements.

        Args:
            q: ``(mu, logvar)`` of the first distribution.
            p: ``(mu, logvar)`` of the second distribution.
        """
        (q_mean, q_logvar), (p_mean, p_logvar) = q, p

        # log(var_q / var_p)
        logvar_gap = q_logvar - p_logvar
        # Squared mean difference, scaled by the variance of p.
        mean_gap = p_mean - q_mean
        scaled_sq_mean_gap = mean_gap.pow(2) / p_logvar.exp()

        elementwise_kl = (logvar_gap.exp() + scaled_sq_mean_gap - 1 - logvar_gap) * 0.5
        return elementwise_kl.mean()


class InfoNCE_with_filtering:
    """Symmetric InfoNCE loss with optional filtering of false negatives.

    Pairs ``(x[i], y[i])`` are the positives. When sentence embeddings are
    supplied, off-diagonal pairs whose sentence similarity exceeds the
    threshold are removed from the softmax by setting their logit to ``-inf``.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
        threshold_selfsim: filtering threshold expressed in ``[0, 1]``; it is
            rescaled to ``[-1, 1]`` before being compared with the sentence
            similarities. A falsy value (``0`` / ``None``) disables filtering.
    """

    def __init__(self, temperature=0.7, threshold_selfsim=0.8):
        self.temperature = temperature
        self.threshold_selfsim = threshold_selfsim

    def get_sim_matrix(self, x, y):
        """Return the cosine-similarity matrix between rows of ``x`` and ``y``."""
        x_logits = torch.nn.functional.normalize(x, dim=-1)
        y_logits = torch.nn.functional.normalize(y, dim=-1)
        sim_matrix = x_logits @ y_logits.T
        return sim_matrix

    def __call__(self, x, y, sent_emb=None):
        """Compute the loss.

        Args:
            x: first set of embeddings, one row per sample.
            y: second set of embeddings, row-aligned with ``x``.
            sent_emb: optional sentence embeddings, row-aligned with ``x``,
                whose pairwise dot products drive the filtering.

        Returns:
            Mean of the ``x -> y`` and ``y -> x`` cross-entropy losses.
        """
        bs, device = len(x), x.device
        sim_matrix = self.get_sim_matrix(x, y) / self.temperature

        if sent_emb is not None and self.threshold_selfsim:
            # put the threshold value between -1 and 1
            real_threshold_selfsim = 2 * self.threshold_selfsim - 1
            selfsim = sent_emb @ sent_emb.T
            # Exclude the diagonal explicitly: zeroing it and then comparing
            # against the threshold would mask the positives themselves
            # whenever the rescaled threshold is negative.
            off_diagonal = ~torch.eye(bs, dtype=torch.bool, device=selfsim.device)
            false_negatives = (selfsim > real_threshold_selfsim) & off_diagonal
            # Out-of-place fill: filtered pairs get a -inf logit.
            sim_matrix = sim_matrix.masked_fill(
                false_negatives.to(device), float("-inf"))

        labels = torch.arange(bs, device=device)

        total_loss = (
            F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
        ) / 2

        return total_loss