# directionality_probe / protify / FastPLMs / boltz_fastplms / vb_layers_transition.py
# Uploaded by nikraf using huggingface_hub (commit 714cf46, verified).
from typing import Optional
from torch import Tensor, nn
from . import vb_layers_initialize as init
class Transition(nn.Module):
    """Gated (SwiGLU-style) feed-forward block with an input LayerNorm.

    Computes ``fc3(silu(fc1(norm(x))) * fc2(norm(x)))``. All three linear
    layers are bias-free. In eval mode the hidden dimension can optionally be
    processed in chunks so that the full ``(..., hidden)`` intermediate is
    never materialized at once.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden (gated) layer, default 512
        out_dim: Optional[int]
            The dimension of the output; when None it defaults to ``dim``
        """
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.norm = nn.LayerNorm(dim, eps=1e-5)
        self.fc1 = nn.Linear(dim, hidden, bias=False)  # gate branch (passed through SiLU)
        self.fc2 = nn.Linear(dim, hidden, bias=False)  # linear branch multiplied by the gate
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)  # projection back to out_dim
        self.silu = nn.SiLU()
        self.hidden = hidden

        # Parameter initialization is delegated to the project's
        # vb_layers_initialize helpers.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)
        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            Number of hidden units processed per chunk. Chunking is used only
            when this is not None AND the module is in eval mode
            (``self.training`` is False); otherwise the dense path runs.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If the chunked path is taken and ``chunk_size`` is not positive.
        """
        x = self.norm(x)
        if chunk_size is None or self.training:
            x = self.silu(self.fc1(x)) * self.fc2(x)
            return self.fc3(x)

        if chunk_size <= 0:
            raise ValueError(
                f"chunk_size must be a positive integer, got {chunk_size}"
            )

        # nn.Linear stores its weight as (out_features, in_features): slice
        # fc1/fc2 by rows and fc3 by columns so each chunk covers the same
        # hidden units, then sum the partial projections.
        x_out = None
        for i in range(0, self.hidden, chunk_size):
            fc1_slice = self.fc1.weight[i : i + chunk_size, :]
            fc2_slice = self.fc2.weight[i : i + chunk_size, :]
            fc3_slice = self.fc3.weight[:, i : i + chunk_size]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            partial = x_chunk @ fc3_slice.T
            x_out = partial if x_out is None else x_out + partial

        if x_out is None:
            # hidden == 0: no chunks were produced; the dense path yields the
            # correctly shaped result.
            x_out = self.fc3(self.silu(self.fc1(x)) * self.fc2(x))
        return x_out