| """ | |
| Implementation of time conditioned Transformer. | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class PositionalEncoding(nn.Module): | |
| def __init__(self, d_hid, n_position=200): | |
| super(PositionalEncoding, self).__init__() | |
| # Not a parameter | |
| self.register_buffer( | |
| "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid) | |
| ) | |
| def _get_sinusoid_encoding_table(self, n_position, d_hid): | |
| """Sinusoid position encoding table""" | |
| # TODO: make it with torch instead of numpy | |
| def get_position_angle_vec(position): | |
| return [ | |
| position / np.power(10000, 2 * (hid_j // 2) / d_hid) | |
| for hid_j in range(d_hid) | |
| ] | |
| sinusoid_table = np.array( | |
| [get_position_angle_vec(pos_i) for pos_i in range(n_position)] | |
| ) | |
| sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |
| sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |
| return torch.FloatTensor(sinusoid_table).unsqueeze(0) | |
| def forward(self, x): | |
| """ | |
| Input: | |
| x: [B,N,D] | |
| """ | |
| return x + self.pos_table[:, : x.size(1)].clone().detach() | |
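

# Usage sketch (shapes are illustrative, not from the original source): the
# table is built once for n_position positions and sliced to the sequence
# length at call time, so N must not exceed n_position.
#
#   pe = PositionalEncoding(d_hid=128, n_position=200)
#   x = torch.randn(4, 50, 128)  # [B, N, D]
#   x = pe(x)                    # adds pos_table[:, :50, :] to x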


class ConcatSquashLinear(nn.Module):
    def __init__(self, dim_in, dim_out, dim_ctx):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False)
        self._hyper_gate = nn.Linear(dim_ctx, dim_out)

    def forward(self, ctx, x):
        assert ctx.dim() == x.dim()
        # The context modulates the linear output with a sigmoid gate and a
        # bias: out = Linear(x) * sigmoid(W_g ctx + b_g) + W_b ctx
        gate = torch.sigmoid(self._hyper_gate(ctx))
        bias = self._hyper_bias(ctx)
        ret = self._layer(x) * gate + bias
        return ret
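

# Usage sketch (hypothetical shapes): ctx is typically a per-sample time
# embedding broadcast against the token dimension, e.g. [B, 1, dim_ctx]
# against x of shape [B, N, dim_in].
#
#   csl = ConcatSquashLinear(dim_in=128, dim_out=128, dim_ctx=64)
#   x = torch.randn(4, 50, 128)
#   t = torch.randn(4, 1, 64)
#   y = csl(ctx=t, x=x)  # [4, 50, 128]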


class TimeMLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_h,
        dim_out,
        dim_ctx=None,
        act=F.relu,
        dropout=0.0,
        use_time=False,
    ):
        super().__init__()
        self.act = act
        self.use_time = use_time
        dim_h = int(dim_h)
        if use_time:
            # Time-conditioned variant: both layers are modulated by ctx.
            self.fc1 = ConcatSquashLinear(dim_in, dim_h, dim_ctx)
            self.fc2 = ConcatSquashLinear(dim_h, dim_out, dim_ctx)
        else:
            self.fc1 = nn.Linear(dim_in, dim_h)
            self.fc2 = nn.Linear(dim_h, dim_out)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, ctx=None):
        x = self.fc1(x=x, ctx=ctx) if self.use_time else self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x=x, ctx=ctx) if self.use_time else self.fc2(x)
        x = self.dropout(x)
        return x
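

# Usage sketch (illustrative): with use_time=True the MLP expects a context
# tensor; without it, it reduces to a plain two-layer MLP.
#
#   mlp = TimeMLP(dim_in=128, dim_h=256, dim_out=128, dim_ctx=64, use_time=True)
#   x = torch.randn(4, 50, 128)
#   t = torch.randn(4, 1, 64)
#   y = mlp(x, ctx=t)  # [4, 50, 128]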


class MultiHeadAttention(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim**-0.5
        self.to_queries = nn.Linear(dim_self, dim_self)
        # Keys and values come from the reference sequence via one projection,
        # split into the two tensors below.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        y=None,
        mask=None,
        alpha=None,
    ):
        # Self-attention when y is None, cross-attention otherwise.
        y = y if y is not None else x
        b_a, n, c = x.shape
        b, m, d = y.shape
        # [b, n, h, dh]
        queries = self.to_queries(x).reshape(
            b_a, n, self.num_heads, c // self.num_heads
        )
        # [b, m, 2, h, dh]
        keys_values = self.to_keys_values(y).reshape(
            b, m, 2, self.num_heads, c // self.num_heads
        )
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        if alpha is not None:
            # Interpolated-attention path; forward_interpolation is referenced
            # here but not defined in this module.
            out, attention = self.forward_interpolation(
                queries, keys, values, alpha, mask
            )
        else:
            # Scaled dot-product attention, shape [b, n, m, h].
            attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask.unsqueeze(1)  # [b, 1, m], broadcast over queries
                attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
            attention = attention.softmax(dim=2)  # normalize over keys
            attention = self.dropout(attention)
            out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention
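

# Usage sketch (illustrative shapes): the same module covers self-attention
# (y omitted) and cross-attention (y is a separate reference sequence).
#
#   mha = MultiHeadAttention(dim_self=128, dim_ref=128, num_heads=4)
#   x = torch.randn(2, 10, 128)
#   out, attn = mha(x)       # self-attention: attn is [2, 10, 10, 4]
#   y = torch.randn(2, 30, 128)
#   out, attn = mha(x, y)    # cross-attention: attn is [2, 10, 30, 4]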


class TimeTransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        dim_self,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
    ):
        super().__init__()
        self.use_time = use_time
        self.act = act
        self.attn = MultiHeadAttention(dim_self, dim_self, num_heads, dropout)
        self.attn_norm = nn.LayerNorm(dim_self)
        mlp_ratio = int(mlp_ratio)
        self.mlp = TimeMLP(
            dim_self, dim_self * mlp_ratio, dim_self, dim_ctx, use_time=use_time
        )
        self.norm = nn.LayerNorm(dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, ctx=None):
        # Post-norm residual blocks: self-attention, then time-conditioned MLP.
        res = x
        x, attn = self.attn(x)
        x = self.attn_norm(x + res)
        res = x
        x = self.mlp(x, ctx=ctx)
        x = self.norm(x + res)
        return x, attn


class TimeTransformerDecoderLayer(TimeTransformerEncoderLayer):
    def __init__(
        self,
        dim_self,
        dim_ref,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
    ):
        super().__init__(
            dim_self=dim_self,
            dim_ctx=dim_ctx,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            act=act,
            dropout=dropout,
            use_time=use_time,
        )
        self.cross_attn = MultiHeadAttention(dim_self, dim_ref, num_heads, dropout)
        self.cross_attn_norm = nn.LayerNorm(dim_self)

    def forward(self, x, y, ctx=None):
        # Self-attention, cross-attention to y, then time-conditioned MLP,
        # each as a post-norm residual block. The returned attention map is
        # the cross-attention one (it overwrites the self-attention map).
        res = x
        x, attn = self.attn(x)
        x = self.attn_norm(x + res)
        res = x
        x, attn = self.cross_attn(x, y)
        x = self.cross_attn_norm(x + res)
        res = x
        x = self.mlp(x, ctx=ctx)
        x = self.norm(x + res)
        return x, attn


class TimeTransformerEncoder(nn.Module):
    def __init__(
        self,
        dim_self,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
        num_layers=3,
        last_fc=False,
        last_fc_dim_out=None,
    ):
        super().__init__()
        self.last_fc = last_fc
        if last_fc:
            self.fc = nn.Linear(dim_self, last_fc_dim_out)
        self.layers = nn.ModuleList(
            [
                TimeTransformerEncoderLayer(
                    dim_self,
                    dim_ctx=dim_ctx,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    act=act,
                    dropout=dropout,
                    use_time=use_time,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, ctx=None):
        for layer in self.layers:
            x, attn = layer(x, ctx=ctx)
        if self.last_fc:
            x = self.fc(x)
        return x


class TimeTransformerDecoder(nn.Module):
    def __init__(
        self,
        dim_self,
        dim_ref,
        dim_ctx=None,
        num_heads=1,
        mlp_ratio=2.0,
        act=F.leaky_relu,
        dropout=0.0,
        use_time=True,
        num_layers=3,
        last_fc=True,
        last_fc_dim_out=None,
    ):
        super().__init__()
        self.last_fc = last_fc
        if last_fc:
            self.fc = nn.Linear(dim_self, last_fc_dim_out)
        self.layers = nn.ModuleList(
            [
                TimeTransformerDecoderLayer(
                    dim_self,
                    dim_ref,
                    dim_ctx=dim_ctx,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    act=act,
                    dropout=dropout,
                    use_time=use_time,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, y, ctx=None):
        for layer in self.layers:
            x, attn = layer(x, y=y, ctx=ctx)
        if self.last_fc:
            x = self.fc(x)
        return x
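

# Minimal smoke test (a sketch; the shapes and the [B, 1, dim_ctx] context
# layout are assumptions based on how ConcatSquashLinear broadcasts ctx
# against the token dimension, not something specified by the original source).
if __name__ == "__main__":
    B, N, M, D, D_ref, D_ctx = 2, 16, 8, 128, 64, 32
    t_emb = torch.randn(B, 1, D_ctx)  # e.g. an embedding of a diffusion timestep

    encoder = TimeTransformerEncoder(dim_self=D, dim_ctx=D_ctx, num_heads=4)
    x = torch.randn(B, N, D)
    h = encoder(x, ctx=t_emb)
    print(h.shape)  # torch.Size([2, 16, 128])

    decoder = TimeTransformerDecoder(
        dim_self=D, dim_ref=D_ref, dim_ctx=D_ctx, num_heads=4, last_fc_dim_out=3
    )
    y = torch.randn(B, M, D_ref)
    out = decoder(x, y, ctx=t_emb)
    print(out.shape)  # torch.Size([2, 16, 3])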