# NMR/tools/data_process/src/tmr_modules.py (uploaded by Xxx999; "upload", 45950ff)
'''
Copied from https://github.com/Mathux/TMR/blob/master/src/model/actor.py
'''
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from torch import Tensor
import numpy as np
from einops import repeat
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    The table is precomputed for ``max_len`` positions and stored as the
    non-persistent buffer ``pe`` with shape ``(max_len, 1, d_model)``, so it is
    not saved in ``state_dict``.

    Args:
        d_model: Feature size of the inputs. Odd values are supported: the
            last, unpaired channel then only receives a sine term.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Maximum sequence length that can be encoded.
        batch_first: If True, ``forward`` expects ``(batch, seq, d_model)``;
            otherwise ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even channel: ceil(d_model / 2) of them.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd channels. Trim the frequencies
        # so the shapes agree when d_model is odd (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe)`` for the first ``seq`` positions.

        Args:
            x: ``(batch, seq, d_model)`` if ``batch_first`` else
                ``(seq, batch, d_model)``; ``seq`` must not exceed ``max_len``.
        """
        if self.batch_first:
            # (max_len, 1, d) -> (1, max_len, d), then keep the first seq positions.
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
@MODELS.register_module()
class ACTORStyleEncoder(BaseModel):
    """Transformer encoder that summarises a padded sequence into learned tokens.

    Similar to ACTOR but "action agnostic" and more general. ``nbtokens``
    learned tokens are prepended to the projected input frames, the result is
    run through a transformer encoder, and the outputs at the token positions
    are returned.

    Args:
        nfeats: Number of input features per frame.
        latent_dim: Transformer model dimension.
        ff_size: Hidden size of the transformer feed-forward layers.
        num_layers: Number of transformer encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability (positional encoding and transformer).
        activation: Activation name for the transformer feed-forward layers.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        self.projection = nn.Linear(nfeats, latent_dim)

        # Learned summary tokens prepended to every sequence; their transformer
        # outputs are what `forward` returns.
        self.nbtokens = 2
        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout=dropout, batch_first=True)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer, num_layers=num_layers)

    def forward(self, x, mask, **kwargs) -> Tensor:  # type: ignore
        """Encode a padded batch of sequences.

        Args:
            x: Input frames indexed as ``(batch, frames, nfeats)``.
            mask: ``(batch, frames)`` validity mask; truthy entries mark real
                frames and falsy entries mark padding. Non-boolean masks
                (e.g. 0/1 integers) are converted to bool.
            **kwargs: Ignored.

        Returns:
            The transformer outputs at the token positions, with shape
            ``(batch, nbtokens, latent_dim)``.
        """
        # A non-bool mask would make `~aug_mask` a bitwise NOT of numbers
        # (or a TypeError for floats) instead of a logical inversion.
        mask = mask.bool()

        x = self.projection(x)
        device = x.device
        bs = len(x)

        tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
        xseq = torch.cat((tokens, x), 1)

        # The prepended tokens are always attended to.
        token_mask = torch.ones((bs, self.nbtokens), dtype=torch.bool, device=device)
        aug_mask = torch.cat((token_mask, mask), 1)

        xseq = self.sequence_pos_encoding(xseq)
        # src_key_padding_mask expects True at positions to IGNORE, hence the inversion.
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
        return final[:, : self.nbtokens]
@MODELS.register_module()
class ACTORStyleDecoder(BaseModel):
    """Transformer decoder that expands a latent vector into a padded sequence.

    Similar to the ACTOR decoder. Each output frame starts as a zero query plus
    a sinusoidal positional encoding. The queries cross-attend to the latent
    vector, which is used as a length-1 memory sequence. The results are
    projected to ``nfeats`` features, and padded frames are set to zero.

    Args:
        nfeats: Number of output features per frame.
        latent_dim: Transformer model dimension (must match the latent size).
        ff_size: Hidden size of the transformer feed-forward layers.
        num_layers: Number of transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability (positional encoding and transformer).
        activation: Activation name for the transformer feed-forward layers.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout, batch_first=True)

        seq_trans_decoder_layer = nn.TransformerDecoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)
        self.seqTransDecoder = nn.TransformerDecoder(
            seq_trans_decoder_layer, num_layers=num_layers)

        self.final_layer = nn.Linear(latent_dim, nfeats)

    def forward(self, z, mask, **kwargs) -> Tensor:  # type: ignore
        """Decode one latent vector per sample into a padded sequence.

        Args:
            z: Latent vectors indexed as ``(batch, latent_dim)``.
            mask: ``(batch, frames)`` validity mask; truthy entries mark frames
                to produce and falsy entries mark padding. Non-boolean masks
                (e.g. 0/1 integers) are converted to bool.
            **kwargs: Ignored.

        Returns:
            ``(batch, frames, nfeats)`` tensor with zeros at padded frames.
        """
        # A non-bool mask would make `~mask` a bitwise NOT of numbers
        # (or a TypeError for floats) instead of a logical inversion.
        mask = mask.bool()

        latent_dim = z.shape[1]
        bs, nframes = mask.shape
        # (batch, latent_dim) -> (batch, 1, latent_dim): a memory sequence of one element.
        z = z[:, None]

        # Time queries: zeros plus positional encoding. new_zeros inherits z's
        # dtype and device, so half/bfloat16 models do not receive float32 queries.
        time_queries = z.new_zeros(bs, nframes, latent_dim)
        time_queries = self.sequence_pos_encoding(time_queries)

        # tgt_key_padding_mask expects True at positions to IGNORE, hence the inversion.
        output = self.seqTransDecoder(
            tgt=time_queries, memory=z, tgt_key_padding_mask=~mask)
        output = self.final_layer(output)

        # Zero out the padded frames.
        return output.masked_fill(~mask.unsqueeze(-1), 0.0)