File size: 4,520 Bytes
45950ff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | '''
Copied from https://github.com/Mathux/TMR/blob/master/src/model/actor.py
'''
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from torch import Tensor
import numpy as np
from einops import repeat
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then apply dropout.

    Even feature columns hold ``sin(pos * 10000^(-2i/d_model))`` and odd
    columns hold the matching ``cos``; the table is precomputed for
    ``max_len`` positions at construction time.

    Args:
        d_model: Feature size of the inputs. Odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions in the precomputed table; inputs must
            not be longer than this along their sequence axis.
        batch_first: If True, ``forward`` treats dim 1 of ``x`` as the
            sequence axis (``(batch, seq, d_model)``); otherwise dim 0
            (``(seq, batch, d_model)``).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: 10000^(-2i/d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies feed the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis
        # of sequence-first inputs.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Non-persistent: rebuilt on construction, not written to state_dict.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``dropout(x + pe)`` for the first ``seq`` positions.

        Args:
            x: Input of shape ``(batch, seq, d_model)`` when ``batch_first``
                is True, else ``(seq, batch, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.batch_first:
            # (max_len, 1, d) -> (1, max_len, d), truncated to the sequence length.
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)
@MODELS.register_module()
class ACTORStyleEncoder(BaseModel):
    """Transformer encoder that summarizes a feature sequence into learned tokens.

    Similar to ACTOR but action-agnostic and more general: ``nbtokens`` (2)
    learnable tokens are prepended to the linearly projected input, the
    combined sequence gets sinusoidal positional encodings and runs through a
    Transformer encoder, and the outputs at the token positions are returned.

    Args:
        nfeats: Size of the last dimension of the input features.
        latent_dim: Transformer model width.
        ff_size: Hidden size of each layer's feed-forward block.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability for the positional encoding and layers.
        activation: Feed-forward activation name passed to
            ``nn.TransformerEncoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        self.nfeats = nfeats
        self.projection = nn.Linear(nfeats, latent_dim)

        # Learnable summary tokens prepended to every input sequence.
        self.nbtokens = 2
        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))

        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout=dropout, batch_first=True)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers)

    def forward(self, x, mask, **kwargs) -> Tensor:  # type: ignore
        """Encode a padded batch and return the outputs at the token positions.

        Args:
            x: Input features of shape ``(batch, frames, nfeats)``.
            mask: Validity mask of shape ``(batch, frames)``; True marks frames
                to attend to. Presumably boolean (it is inverted with ``~``)
                — confirm with callers.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(batch, nbtokens, latent_dim)``.
        """
        embedded = self.projection(x)
        batch_size = embedded.shape[0]

        summary = repeat(self.tokens, "nbtoken dim -> bs nbtoken dim", bs=batch_size)
        sequence = torch.cat((summary, embedded), 1)

        # The learned tokens are always visible, so their mask entries are True.
        always_on = torch.ones(
            (batch_size, self.nbtokens), dtype=torch.bool, device=embedded.device)
        valid = torch.cat((always_on, mask), 1)

        sequence = self.sequence_pos_encoding(sequence)
        # PyTorch padding masks use True for positions to ignore, hence the inversion.
        encoded = self.seqTransEncoder(sequence, src_key_padding_mask=~valid)
        return encoded[:, : self.nbtokens]
@MODELS.register_module()
class ACTORStyleDecoder(BaseModel):
    """Transformer decoder that expands one latent vector into a feature sequence.

    Similar to the ACTOR decoder: each output frame gets a query made only of
    sinusoidal positional encoding, the queries cross-attend to the latent
    vector (a memory of length 1), and a final linear layer maps each frame to
    ``nfeats`` features. Padded frames are zeroed in the output.

    Args:
        nfeats: Number of output features per frame.
        latent_dim: Transformer model width; must match the size of ``z``.
        ff_size: Hidden size of each layer's feed-forward block.
        num_layers: Number of stacked decoder layers.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability for the positional encoding and layers.
        activation: Feed-forward activation name passed to
            ``nn.TransformerDecoderLayer``.
    """

    def __init__(
        self,
        nfeats: int,
        latent_dim: int = 256,
        ff_size: int = 1024,
        num_layers: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
    ) -> None:
        super().__init__()
        output_feats = nfeats
        self.nfeats = nfeats
        self.sequence_pos_encoding = PositionalEncoding(
            latent_dim, dropout, batch_first=True)
        seq_trans_decoder_layer = nn.TransformerDecoderLayer(
            d_model=latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )
        self.seqTransDecoder = nn.TransformerDecoder(
            seq_trans_decoder_layer, num_layers=num_layers)
        self.final_layer = nn.Linear(latent_dim, output_feats)

    def forward(self, z, mask, **kwargs) -> Tensor:  # type: ignore
        """Decode latent vectors into per-frame features.

        Args:
            z: Latent vectors of shape ``(batch, latent_dim)``.
            mask: Validity mask of shape ``(batch, nframes)``; True marks real
                frames. Presumably boolean (it is inverted with ``~`` and
                used as an index) — confirm with callers.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(batch, nframes, nfeats)`` in ``z``'s dtype,
            with frames where ``mask`` is False set to 0.
        """
        latent_dim = z.shape[1]
        bs, nframes = mask.shape
        z = z[:, None]  # sequence of 1 element for the memory

        # Time queries carry no content, only position. They are created in
        # z's dtype so the decoder also runs when the model is cast to
        # half/bfloat16. Otherwise float32 queries would hit reduced-precision
        # weights and fail.
        time_queries = torch.zeros(
            bs, nframes, latent_dim, device=z.device, dtype=z.dtype)
        time_queries = self.sequence_pos_encoding(time_queries)

        # Every query cross-attends to the single latent vector. Padding masks
        # use True for positions to ignore, hence the inversion.
        output = self.seqTransDecoder(
            tgt=time_queries, memory=z, tgt_key_padding_mask=~mask)
        output = self.final_layer(output)

        # Zero out the padded frames.
        output[~mask] = 0
        return output
|