# directionality_probe/protify/FastPLMs/boltz_fastplms/vb_layers_pairformer.py
from typing import Optional
import torch
from torch import Tensor, nn
from . import vb_const as const
from .vb_layers_attention import AttentionPairBias
from .vb_layers_attentionv2 import AttentionPairBias as AttentionPairBiasV2
from .vb_layers_dropout import get_dropout_mask
from .vb_layers_transition import Transition
from .vb_tri_attn_attention import (
TriangleAttentionEndingNode,
TriangleAttentionStartingNode,
)
from .vb_layers_triangular_mult import (
TriangleMultiplicationIncoming,
TriangleMultiplicationOutgoing,
)
class PairformerLayer(nn.Module):
"""Pairformer module."""
def __init__(
self,
token_s: int,
token_z: int,
num_heads: int = 16,
dropout: float = 0.25,
pairwise_head_width: int = 32,
pairwise_num_heads: int = 4,
post_layer_norm: bool = False,
v2: bool = False,
) -> None:
super().__init__()
self.token_z = token_z
self.dropout = dropout
self.num_heads = num_heads
self.post_layer_norm = post_layer_norm
self.pre_norm_s = nn.LayerNorm(token_s)
if v2:
self.attention = AttentionPairBiasV2(token_s, token_z, num_heads)
else:
self.attention = AttentionPairBias(token_s, token_z, num_heads)
self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
self.tri_att_start = TriangleAttentionStartingNode(
token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
)
self.tri_att_end = TriangleAttentionEndingNode(
token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
)
self.transition_s = Transition(token_s, token_s * 4)
self.transition_z = Transition(token_z, token_z * 4)
self.s_post_norm = (
nn.LayerNorm(token_s) if self.post_layer_norm else nn.Identity()
)
def forward(
self,
s: Tensor,
z: Tensor,
mask: Tensor,
pair_mask: Tensor,
chunk_size_tri_attn: Optional[int] = None,
use_kernels: bool = False,
use_cuequiv_mul: bool = False,
use_cuequiv_attn: bool = False,
) -> tuple[Tensor, Tensor]:
        """Update the sequence and pairwise stacks for one Pairformer block.

        Parameters
        ----------
        s : Tensor
            Sequence stack of shape (B, N, token_s).
        z : Tensor
            Pairwise stack of shape (B, N, N, token_z).
        mask : Tensor
            Token mask of shape (B, N).
        pair_mask : Tensor
            Pairwise mask of shape (B, N, N).
        chunk_size_tri_attn : Optional[int]
            Chunk size for triangle attention; None disables chunking.
        use_kernels : bool
            Whether to use fused kernels for the triangle operations.
        use_cuequiv_mul : bool
            Whether to use cuEquivariance kernels for triangle multiplication.
        use_cuequiv_attn : bool
            Whether to use cuEquivariance kernels for triangle attention.
        """
        # Compute pairwise stack
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_mul_out(
z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
)
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_mul_in(
z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
)
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_att_start(
z,
mask=pair_mask,
chunk_size=chunk_size_tri_attn,
use_kernels=use_cuequiv_attn or use_kernels,
)
dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
z = z + dropout * self.tri_att_end(
z,
mask=pair_mask,
chunk_size=chunk_size_tri_attn,
use_kernels=use_cuequiv_attn or use_kernels,
)
z = z + self.transition_z(z)
# Compute sequence stack
with torch.autocast("cuda", enabled=False):
s_normed = self.pre_norm_s(s.float())
s = s.float() + self.attention(
s=s_normed, z=z.float(), mask=mask.float(), k_in=s_normed
)
s = s + self.transition_s(s)
s = self.s_post_norm(s)
return s, z
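# Note: each PairformerLayer applies the AF3-style Pairformer update order:
# triangle multiplicative updates (outgoing, then incoming), triangle
# self-attention (starting node, then ending node), a pair transition, and
# finally pair-biased attention plus a transition on the sequence stack.
# Row-wise dropout masks gate the first three pairwise updates, a column-wise
# mask gates the ending-node attention, and the sequence stack is computed in
# float32 with autocast disabled (presumably for numerical stability).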
class PairformerModule(nn.Module):
"""Pairformer module."""
def __init__(
self,
token_s: int,
token_z: int,
num_blocks: int,
num_heads: int = 16,
dropout: float = 0.25,
pairwise_head_width: int = 32,
pairwise_num_heads: int = 4,
post_layer_norm: bool = False,
activation_checkpointing: bool = False,
v2: bool = False,
**kwargs,
) -> None:
super().__init__()
self.token_z = token_z
self.num_blocks = num_blocks
self.dropout = dropout
self.num_heads = num_heads
self.post_layer_norm = post_layer_norm
self.activation_checkpointing = activation_checkpointing
self.layers = nn.ModuleList()
for _ in range(num_blocks):
self.layers.append(
PairformerLayer(
token_s,
token_z,
num_heads,
dropout,
pairwise_head_width,
pairwise_num_heads,
post_layer_norm,
v2,
),
)
def forward(
self,
s: Tensor,
z: Tensor,
mask: Tensor,
pair_mask: Tensor,
use_kernels: bool = False,
) -> tuple[Tensor, Tensor]:
"""Perform the forward pass.
Parameters
----------
s : Tensor
The sequence stack.
z : Tensor
The pairwise stack.
mask : Tensor
The mask.
pair_mask : Tensor
The pairwise mask.
use_kernels : bool
Whether to use kernels.
"""
        # At inference, chunk triangle attention to bound peak memory, using
        # smaller chunks for longer sequences. No chunking during training.
        if not self.training:
if z.shape[1] > const.chunk_size_threshold:
chunk_size_tri_attn = 128
else:
chunk_size_tri_attn = 512
else:
chunk_size_tri_attn = None
for layer in self.layers:
if self.activation_checkpointing and self.training:
s, z = torch.utils.checkpoint.checkpoint(
layer,
s,
z,
mask,
pair_mask,
chunk_size_tri_attn,
use_kernels,
)
else:
s, z = layer(s, z, mask, pair_mask, chunk_size_tri_attn, use_kernels)
return s, z
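# Example usage (illustrative sketch; the dimensions and block count below are
# placeholders, not the model's configured hyperparameters):
#
#     trunk = PairformerModule(token_s=384, token_z=128, num_blocks=4)
#     s = torch.randn(batch, n_tokens, 384)              # sequence stack
#     z = torch.randn(batch, n_tokens, n_tokens, 128)    # pairwise stack
#     mask = torch.ones(batch, n_tokens)                  # token mask
#     pair_mask = torch.ones(batch, n_tokens, n_tokens)   # pairwise mask
#     s, z = trunk(s, z, mask, pair_mask, use_kernels=False)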
class PairformerNoSeqLayer(nn.Module):
"""Pairformer module without sequence track."""
def __init__(
self,
token_z: int,
dropout: float = 0.25,
pairwise_head_width: int = 32,
pairwise_num_heads: int = 4,
post_layer_norm: bool = False,
) -> None:
super().__init__()
self.token_z = token_z
self.dropout = dropout
self.post_layer_norm = post_layer_norm
self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
self.tri_att_start = TriangleAttentionStartingNode(
token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
)
self.tri_att_end = TriangleAttentionEndingNode(
token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
)
self.transition_z = Transition(token_z, token_z * 4)
def forward(
self,
z: Tensor,
pair_mask: Tensor,
chunk_size_tri_attn: Optional[int] = None,
use_kernels: bool = False,
use_cuequiv_mul: bool = False,
use_cuequiv_attn: bool = False,
) -> Tensor:
        """Update the pairwise stack for one sequence-free Pairformer block.

        Parameters
        ----------
        z : Tensor
            Pairwise stack of shape (B, N, N, token_z).
        pair_mask : Tensor
            Pairwise mask of shape (B, N, N).
        chunk_size_tri_attn : Optional[int]
            Chunk size for triangle attention; None disables chunking.
        use_kernels : bool
            Whether to use fused kernels for the triangle operations.
        use_cuequiv_mul : bool
            Whether to use cuEquivariance kernels for triangle multiplication.
        use_cuequiv_attn : bool
            Whether to use cuEquivariance kernels for triangle attention.
        """
        # Compute pairwise stack
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_mul_out(
z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
)
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_mul_in(
z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
)
dropout = get_dropout_mask(self.dropout, z, self.training)
z = z + dropout * self.tri_att_start(
z,
mask=pair_mask,
chunk_size=chunk_size_tri_attn,
use_kernels=use_cuequiv_attn or use_kernels,
)
dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
z = z + dropout * self.tri_att_end(
z,
mask=pair_mask,
chunk_size=chunk_size_tri_attn,
use_kernels=use_cuequiv_attn or use_kernels,
)
z = z + self.transition_z(z)
return z
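# Note: this is the same pairwise (triangle) stack as in PairformerLayer,
# including the row-/column-wise dropout pattern, with the sequence-stack
# update omitted.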
class PairformerNoSeqModule(nn.Module):
"""Pairformer module without sequence track."""
def __init__(
self,
token_z: int,
num_blocks: int,
dropout: float = 0.25,
pairwise_head_width: int = 32,
pairwise_num_heads: int = 4,
post_layer_norm: bool = False,
activation_checkpointing: bool = False,
**kwargs,
) -> None:
super().__init__()
self.token_z = token_z
self.num_blocks = num_blocks
self.dropout = dropout
self.post_layer_norm = post_layer_norm
self.activation_checkpointing = activation_checkpointing
self.layers = nn.ModuleList()
        for _ in range(num_blocks):
self.layers.append(
PairformerNoSeqLayer(
token_z,
dropout,
pairwise_head_width,
pairwise_num_heads,
post_layer_norm,
),
)
def forward(
self,
z: Tensor,
pair_mask: Tensor,
use_kernels: bool = False,
) -> Tensor:
        """Perform the forward pass over all blocks.

        Parameters
        ----------
        z : Tensor
            The pairwise stack of shape (B, N, N, token_z).
        pair_mask : Tensor
            The pairwise mask of shape (B, N, N).
        use_kernels : bool
            Whether to use fused kernels in the triangle operations.

        Returns
        -------
        Tensor
            The updated pairwise stack.
        """
        # At inference, chunk triangle attention to bound peak memory, using
        # smaller chunks for longer sequences. No chunking during training.
        if not self.training:
if z.shape[1] > const.chunk_size_threshold:
chunk_size_tri_attn = 128
else:
chunk_size_tri_attn = 512
else:
chunk_size_tri_attn = None
for layer in self.layers:
if self.activation_checkpointing and self.training:
z = torch.utils.checkpoint.checkpoint(
layer,
z,
pair_mask,
chunk_size_tri_attn,
use_kernels,
)
else:
z = layer(
z,
pair_mask,
chunk_size_tri_attn,
use_kernels,
)
return z
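# Minimal smoke test (illustrative sketch, not part of the original module).
# Assumptions: the sibling vb_* layers run without custom kernels on the
# current device, the pairwise stack has shape (B, N, N, token_z), and the
# pairwise mask has shape (B, N, N). Because this file uses relative imports,
# run it as a module from the package root, e.g.
# `python -m boltz_fastplms.vb_layers_pairformer`.
if __name__ == "__main__":
    batch, n_tokens, token_z = 1, 8, 32  # tiny illustrative sizes
    trunk = PairformerNoSeqModule(token_z=token_z, num_blocks=2)
    trunk.eval()
    z = torch.randn(batch, n_tokens, n_tokens, token_z)
    pair_mask = torch.ones(batch, n_tokens, n_tokens)
    with torch.no_grad():
        z_out = trunk(z, pair_mask, use_kernels=False)
    print(f"pairwise stack: {tuple(z.shape)} -> {tuple(z_out.shape)}")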