File size: 2,288 Bytes
# 714cf46
from typing import Optional
from torch import Tensor, nn
from . import vb_layers_initialize as init
class Transition(nn.Module):
    """Apply a layer-normalized, SiLU-gated two-layer MLP.

    The input is layer-normalized, projected to ``hidden`` features by two
    parallel bias-free linear maps, combined as ``silu(fc1(x)) * fc2(x)``,
    and projected to ``out_dim`` features by a third bias-free linear map.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden layer, default 512
        out_dim: Optional[int]
            The dimension of the output; when None it is set to ``dim``

        """
        super().__init__()
        if out_dim is None:
            out_dim = dim

        self.norm = nn.LayerNorm(dim, eps=1e-5)
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.fc2 = nn.Linear(dim, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)
        self.silu = nn.SiLU()
        # Kept so forward() knows how many hidden features to chunk over.
        self.hidden = hidden

        # Parameter initialization is delegated to the project's init helpers.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)
        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            Number of hidden features processed per step. When given and the
            module is in eval mode, the hidden layer is evaluated in slices of
            this size so the full (..., hidden) activation is never
            materialized. Ignored while ``self.training`` is True.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If chunking is requested (eval mode) with a non-positive
            ``chunk_size``.

        """
        x = self.norm(x)

        use_chunks = chunk_size is not None and not self.training
        if use_chunks and chunk_size <= 0:
            # A negative step would make range() empty and leave the output
            # undefined; fail with a clear message instead.
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size}"
            )

        # A single chunk covering the whole hidden layer saves no memory, so
        # use the dense path (this also covers hidden == 0).
        if not use_chunks or chunk_size >= self.hidden:
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))

        # Compute in chunks: each slice of hidden features contributes an
        # additive term to the output, so the partial products are summed.
        x_out = None
        for start in range(0, self.hidden, chunk_size):
            stop = start + chunk_size
            fc1_slice = self.fc1.weight[start:stop, :]
            fc2_slice = self.fc2.weight[start:stop, :]
            fc3_slice = self.fc3.weight[:, start:stop]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            partial = x_chunk @ fc3_slice.T
            x_out = partial if x_out is None else x_out + partial
        return x_out
|