| from typing import Optional |
|
|
| from torch import Tensor, nn |
|
|
| from . import vb_layers_initialize as init |
|
|
|
|
class Transition(nn.Module):
    """Gated feed-forward block applied to the last dimension of the input.

    The forward pass computes ``fc3(silu(fc1(n)) * fc2(n))`` with
    ``n = norm(x)``. All three linear layers are bias-free.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden, default 512
        out_dim: Optional[int]
            The dimension of the output; ``dim`` is used when None

        """
        super().__init__()
        if out_dim is None:
            out_dim = dim

        self.norm = nn.LayerNorm(dim, eps=1e-5)
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.fc2 = nn.Linear(dim, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)
        self.silu = nn.SiLU()
        # Kept so the chunked forward path knows how many hidden units to cover.
        self.hidden = hidden

        # The initialisers live in the project-local ``vb_layers_initialize``
        # module; their exact behaviour is not visible from this file.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)

        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            Number of hidden units processed per step. Only honoured when
            the module is in eval mode; when None, or while training, the
            hidden dimension is processed in a single pass.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If the chunked path is taken and ``chunk_size`` is not positive.

        """
        x = self.norm(x)

        # Single-pass path: always used in training, and whenever no chunk
        # size was requested.
        if chunk_size is None or self.training:
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))

        # A zero step makes range() fail with an opaque message and a negative
        # step yields an empty loop (no output at all), so reject both here.
        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size}"
            )

        # Chunked path: slice the hidden dimension of the weights so that only
        # a (..., chunk_size) intermediate exists at any time, and sum the
        # per-chunk contributions to the output projection.
        for start in range(0, self.hidden, chunk_size):
            stop = start + chunk_size
            fc1_slice = self.fc1.weight[start:stop, :]
            fc2_slice = self.fc2.weight[start:stop, :]
            fc3_slice = self.fc3.weight[:, start:stop]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            partial = x_chunk @ fc3_slice.T
            if start == 0:
                x_out = partial
            else:
                x_out = x_out + partial
        return x_out
|
|