| ''' |
| Copied from https://github.com/Mathux/TMR/blob/master/src/model/actor.py |
| ''' |
|
|
| import torch |
| import torch.nn as nn |
| from mmengine.registry import MODELS |
| from mmengine.model import BaseModel |
| from torch import Tensor |
| import numpy as np |
|
|
| from einops import repeat |
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    The encoding table is precomputed once for ``max_len`` positions and stored
    as a non-persistent buffer named ``pe`` of shape ``(max_len, 1, d_model)``
    (it follows the module across devices but is not saved in the state dict).

    Args:
        d_model: Size of the feature (last) dimension of the inputs.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions for which encodings are precomputed.
        batch_first: If True, inputs are ``(batch, seq, d_model)``; otherwise
            they are ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
        super().__init__()
        self.batch_first = batch_first

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        # (max_len, ceil(d_model / 2)) matrix of position * frequency products.
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns,
        # so the cosine half must be truncated to d_model // 2 frequencies.
        # For an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Reshape (max_len, d_model) -> (max_len, 1, d_model) so the table
        # broadcasts over the batch axis in the sequence-first layout.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + positional_encoding)``.

        Args:
            x: ``(batch, seq, d_model)`` when ``batch_first`` is True, else
                ``(seq, batch, d_model)``. The sequence length must not exceed
                the ``max_len`` given at construction, otherwise the addition
                fails to broadcast.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.batch_first:
            # View the table as (1, max_len, d_model) and keep the first seq rows.
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
|
|
@MODELS.register_module()
class ACTORStyleEncoder(BaseModel):
    """Transformer encoder that summarizes a feature sequence into learned tokens.

    ``nbtokens`` (2) learnable tokens are prepended to the projected input
    sequence; after the transformer encoder, only the outputs at those token
    positions are returned.

    Args:
        nfeats: Size of the last dimension of the input features.
        latent_dim: Transformer model width (``d_model``).
        ff_size: Hidden size of each layer's feed-forward sub-network.
        num_layers: Number of stacked transformer encoder layers.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability for positional encoding and layers.
        activation: Activation name passed to ``nn.TransformerEncoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()

        self.nfeats = nfeats
        self.projection = nn.Linear(nfeats, latent_dim)

        # Learnable summary tokens prepended to every sequence.
        # NOTE(review): in ACTOR/TMR-style models these two outputs are
        # presumably used as distribution parameters — confirm with the caller.
        self.nbtokens = 2
        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout=dropout, batch_first=True)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)

        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer, num_layers=num_layers)

    def forward(self, x, mask, **kwargs) -> Tensor:
        """Encode a batch of sequences into ``nbtokens`` latent vectors each.

        Args:
            x: Input features of shape ``(bs, nframes, nfeats)``.
            mask: Validity mask of shape ``(bs, nframes)``; truthy entries mark
                real frames, falsy entries mark padding. Any dtype convertible
                to bool is accepted.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(bs, nbtokens, latent_dim)`` holding the encoder
            outputs at the prepended token positions.
        """
        # `~` is only a logical NOT on bool tensors; on integer masks it is a
        # bitwise NOT and on float masks it raises, so normalize the dtype first.
        mask = mask.to(dtype=torch.bool)

        x = self.projection(x)
        bs = len(x)

        # Broadcast the learned tokens over the batch and prepend them.
        tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
        xseq = torch.cat((tokens, x), 1)

        # The prepended tokens are always valid (never treated as padding).
        token_mask = torch.ones(
            (bs, self.nbtokens), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        xseq = self.sequence_pos_encoding(xseq)
        # src_key_padding_mask is passed inverted: aug_mask marks valid
        # positions, while the padding mask must mark the ones to ignore.
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
        return final[:, : self.nbtokens]
|
|
|
|
@MODELS.register_module()
class ACTORStyleDecoder(BaseModel):
    """Transformer decoder that expands one latent vector into a feature sequence.

    Positional encodings added to an all-zero tensor serve as per-frame
    queries; they attend to the latent vector (used as a length-1 memory) and
    the result is projected to ``nfeats`` features per frame.

    Args:
        nfeats: Size of the last dimension of the output features.
        latent_dim: Transformer model width (``d_model``).
        ff_size: Hidden size of each layer's feed-forward sub-network.
        num_layers: Number of stacked transformer decoder layers.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability for positional encoding and layers.
        activation: Activation name passed to ``nn.TransformerDecoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        self.nfeats = nfeats

        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout, batch_first=True)

        seq_trans_decoder_layer = nn.TransformerDecoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)

        self.seqTransDecoder = nn.TransformerDecoder(
            seq_trans_decoder_layer, num_layers=num_layers)

        self.final_layer = nn.Linear(latent_dim, nfeats)

    def forward(self, z, mask, **kwargs) -> Tensor:
        """Decode latent vectors into feature sequences.

        Args:
            z: Latent vectors of shape ``(bs, latent_dim)``.
            mask: Validity mask of shape ``(bs, nframes)``; truthy entries mark
                frames to generate, falsy entries mark padding. Any dtype
                convertible to bool is accepted.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(bs, nframes, nfeats)`` with padded frames set
            to zero.
        """
        # `~mask` and `output[~mask]` are only correct for bool tensors: with
        # an integer mask `~` is bitwise and the indexing becomes integer
        # indexing, silently zeroing the wrong rows.
        mask = mask.to(dtype=torch.bool)

        latent_dim = z.shape[1]
        bs, nframes = mask.shape

        # Use the latent vector as a memory sequence of length 1.
        z = z[:, None]

        # Queries carry only positional information. Match z's dtype so the
        # module also works after being cast with .half() or .double().
        time_queries = torch.zeros(
            bs, nframes, latent_dim, dtype=z.dtype, device=z.device)
        time_queries = self.sequence_pos_encoding(time_queries)

        # The padding mask is passed inverted: mask marks valid frames, while
        # tgt_key_padding_mask must mark the ones to ignore.
        output = self.seqTransDecoder(
            tgt=time_queries, memory=z, tgt_key_padding_mask=~mask)

        output = self.final_layer(output)
        # Zero out the features of padded frames.
        output[~mask] = 0
        return output
|
|