"""ACTOR-style transformer encoder/decoder for sequences of per-frame features.

Copied from https://github.com/Mathux/TMR/blob/master/src/model/actor.py,
with small robustness fixes: the positional encoding accepts an odd
``d_model``, the decoder builds its queries in the latent's dtype, and masks
are coerced to ``bool`` before use.
"""
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from torch import Tensor
import numpy as np
from einops import repeat


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an input sequence.

    Args:
        d_model: Feature size of the inputs (last dimension). May be odd.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of precomputed positions; sequences longer than this
            cannot be encoded (the addition fails to broadcast).
        batch_first: If True, inputs are ``(batch, seq, d_model)``;
            otherwise ``(seq, batch, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000,
                 batch_first=False) -> None:
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns -- one fewer than the even
        # columns when d_model is odd -- so trim div_term to match.
        # For even d_model this slice is the whole tensor (no change).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model). Not persisted in the state_dict
        # because it is fully determined by the constructor arguments.
        pe = pe.unsqueeze(1)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe)`` using the first ``seq`` positions.

        Args:
            x: ``(batch, seq, d_model)`` if ``batch_first`` else
                ``(seq, batch, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.batch_first:
            # (seq, 1, d_model) -> (1, seq, d_model) to broadcast over batch.
            x = x + self.pe[: x.shape[1]].transpose(0, 1)
        else:
            x = x + self.pe[: x.shape[0]]
        return self.dropout(x)


@MODELS.register_module()
class ACTORStyleEncoder(BaseModel):
    """Sequence encoder in the style of ACTOR, but action-agnostic.

    Per-frame features are projected to ``latent_dim``. ``nbtokens`` (= 2)
    learnable tokens are prepended and sinusoidal positional encodings are
    added. A transformer encoder then runs over the whole sequence, and its
    outputs at the token positions are returned.

    Args:
        nfeats: Size of each input frame's feature vector.
        latent_dim: Transformer model width.
        ff_size: Hidden size of each layer's feed-forward block.
        num_layers: Number of transformer encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability (transformer and positional encoding).
        activation: Feed-forward activation name passed to
            ``nn.TransformerEncoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        # Submodule/parameter names below are state_dict keys; renaming them
        # would break loading of previously saved weights.
        self.projection = nn.Linear(nfeats, latent_dim)
        self.nbtokens = 2
        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))
        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout=dropout, batch_first=True)
        seq_trans_encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer, num_layers=num_layers)

    def forward(self, x: Tensor, mask: Tensor, **kwargs) -> Tensor:  # type: ignore
        """Encode a padded batch of sequences.

        Args:
            x: Frame features of shape ``(bs, nframes, nfeats)``.
            mask: ``(bs, nframes)``. True (non-zero) marks valid frames and
                False marks padding. Non-bool masks are converted with
                ``mask.to(torch.bool)``.
            **kwargs: Ignored.

        Returns:
            Encoder outputs at the token positions,
            shape ``(bs, nbtokens, latent_dim)``.
        """
        # `~` below must be a logical NOT; on int masks it would be bitwise.
        mask = mask.to(torch.bool)
        x = self.projection(x)
        bs = x.shape[0]

        tokens = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=bs)
        xseq = torch.cat((tokens, x), dim=1)

        # The prepended tokens are always attended to.
        token_mask = torch.ones(
            (bs, self.nbtokens), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), dim=1)

        xseq = self.sequence_pos_encoding(xseq)
        # PyTorch's key-padding mask uses True = "ignore", hence the inversion.
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)
        return final[:, : self.nbtokens]


@MODELS.register_module()
class ACTORStyleDecoder(BaseModel):
    """Sequence decoder in the style of ACTOR.

    A single latent vector is used as a length-1 transformer-decoder memory.
    The queries are zero vectors plus sinusoidal positional encodings, one per
    output frame. The decoder output is projected back to ``nfeats``, and
    padded frames are zeroed.

    Args:
        nfeats: Size of each output frame's feature vector.
        latent_dim: Transformer model width (must equal the latent size).
        ff_size: Hidden size of each layer's feed-forward block.
        num_layers: Number of transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability (transformer and positional encoding).
        activation: Feed-forward activation name passed to
            ``nn.TransformerDecoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        output_feats = nfeats
        self.nfeats = nfeats
        # Submodule names below are state_dict keys; keep them stable.
        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout, batch_first=True)
        seq_trans_decoder_layer = nn.TransformerDecoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True)
        self.seqTransDecoder = nn.TransformerDecoder(
            seq_trans_decoder_layer, num_layers=num_layers)
        self.final_layer = nn.Linear(latent_dim, output_feats)

    def forward(self, z: Tensor, mask: Tensor, **kwargs) -> Tensor:  # type: ignore
        """Decode one latent vector per sample into a padded frame sequence.

        Args:
            z: Latent vectors of shape ``(bs, latent_dim)``.
            mask: ``(bs, nframes)``. True (non-zero) marks frames to produce
                and False marks padding. Non-bool masks are converted with
                ``mask.to(torch.bool)``.
            **kwargs: Ignored.

        Returns:
            ``(bs, nframes, nfeats)`` with padded frames set to 0.
        """
        # `~` below must be a logical NOT; on int masks it would be bitwise.
        mask = mask.to(torch.bool)
        bs, nframes = mask.shape
        latent_dim = z.shape[1]
        memory = z[:, None]  # (bs, 1, latent_dim): a length-1 memory sequence

        # Build queries in the latent's dtype so half/double models do not
        # get float32 queries (which mismatch the layer weights).
        time_queries = torch.zeros(
            bs, nframes, latent_dim, device=z.device, dtype=z.dtype)
        time_queries = self.sequence_pos_encoding(time_queries)

        # PyTorch's key-padding mask uses True = "ignore", hence the inversion.
        output = self.seqTransDecoder(
            tgt=time_queries, memory=memory, tgt_key_padding_mask=~mask)
        output = self.final_layer(output)

        # Zero the padded frames (out-of-place; same values as indexing).
        return output.masked_fill(~mask.unsqueeze(-1), 0)