from typing import Optional

import torch
from torch import Tensor, nn

from . import vb_const as const
from .vb_layers_attention import AttentionPairBias
from .vb_layers_attentionv2 import AttentionPairBias as AttentionPairBiasV2
from .vb_layers_dropout import get_dropout_mask
from .vb_layers_transition import Transition
from .vb_layers_triangular_mult import (
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)
from .vb_tri_attn_attention import (
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
)


class PairformerLayer(nn.Module):
    """A single Pairformer block over the sequence and pairwise tracks."""

    def __init__(
        self,
        token_s: int,
        token_z: int,
        num_heads: int = 16,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        v2: bool = False,
    ) -> None:
        super().__init__()
        self.token_z = token_z
        self.dropout = dropout
        self.num_heads = num_heads
        self.post_layer_norm = post_layer_norm

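        # Sequence-track attention with pair bias; the v2 flag selects the
        # alternative AttentionPairBias implementation.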
        self.pre_norm_s = nn.LayerNorm(token_s)
        if v2:
            self.attention = AttentionPairBiasV2(token_s, token_z, num_heads)
        else:
            self.attention = AttentionPairBias(token_s, token_z, num_heads)

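        # Pairwise-track modules: triangle multiplicative updates and
        # triangle attention over the pair representation.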
        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)

        self.tri_att_start = TriangleAttentionStartingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )

        self.transition_s = Transition(token_s, token_s * 4)
        self.transition_z = Transition(token_z, token_z * 4)

        self.s_post_norm = (
            nn.LayerNorm(token_s) if self.post_layer_norm else nn.Identity()
        )

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        pair_mask: Tensor,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
        use_cuequiv_mul: bool = False,
        use_cuequiv_attn: bool = False,
    ) -> tuple[Tensor, Tensor]:
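        # Pairwise track: residual triangle multiplicative update over
        # outgoing edges, scaled by a shared dropout mask.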
        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_out(
            z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
        )

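        # The same update over incoming edges.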
        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_in(
            z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
        )

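        # Triangle attention around the starting node, optionally chunked
        # via chunk_size_tri_attn to limit peak memory at inference.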
        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_att_start(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_cuequiv_attn or use_kernels,
        )

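        # Triangle attention around the ending node; note the column-wise
        # dropout mask.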
        dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
        z = z + dropout * self.tri_att_end(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_cuequiv_attn or use_kernels,
        )

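        # Pairwise transition (feed-forward) block.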
        z = z + self.transition_z(z)

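        # Sequence track: pre-norm, pair-biased attention, and transition,
        # run in fp32 with autocast disabled for numerical stability.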
        with torch.autocast("cuda", enabled=False):
            s_normed = self.pre_norm_s(s.float())
            s = s.float() + self.attention(
                s=s_normed, z=z.float(), mask=mask.float(), k_in=s_normed
            )
            s = s + self.transition_s(s)
            s = self.s_post_norm(s)

        return s, z


class PairformerModule(nn.Module):
    """Pairformer module: a stack of PairformerLayer blocks."""

    def __init__(
        self,
        token_s: int,
        token_z: int,
        num_blocks: int,
        num_heads: int = 16,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        v2: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.token_z = token_z
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.num_heads = num_heads
        self.post_layer_norm = post_layer_norm
        self.activation_checkpointing = activation_checkpointing

        self.layers = nn.ModuleList()
        for _ in range(num_blocks):
            self.layers.append(
                PairformerLayer(
                    token_s,
                    token_z,
                    num_heads,
                    dropout,
                    pairwise_head_width,
                    pairwise_num_heads,
                    post_layer_norm,
                    v2,
                ),
            )

    def forward(
        self,
        s: Tensor,
        z: Tensor,
        mask: Tensor,
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The sequence stack.
        z : Tensor
            The pairwise stack.
        mask : Tensor
            The token-level mask.
        pair_mask : Tensor
            The pairwise mask.
        use_kernels : bool
            Whether to use optimized kernels.

        Returns
        -------
        tuple[Tensor, Tensor]
            The updated sequence and pairwise stacks.

        """
        if not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_size_tri_attn = 128
            else:
                chunk_size_tri_attn = 512
        else:
            chunk_size_tri_attn = None

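        # Optionally trade compute for memory with activation checkpointing
        # during training. Note that recent PyTorch versions warn unless
        # use_reentrant is passed to checkpoint explicitly.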
        for layer in self.layers:
            if self.activation_checkpointing and self.training:
                s, z = torch.utils.checkpoint.checkpoint(
                    layer,
                    s,
                    z,
                    mask,
                    pair_mask,
                    chunk_size_tri_attn,
                    use_kernels,
                )
            else:
                s, z = layer(s, z, mask, pair_mask, chunk_size_tri_attn, use_kernels)
        return s, z
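

# A minimal usage sketch for PairformerModule. All sizes and shapes below are
# hypothetical, chosen only for illustration; consult the calling code for the
# real dimensions:
#
#     model = PairformerModule(token_s=384, token_z=128, num_blocks=4)
#     s = torch.randn(1, 64, 384)        # (batch, tokens, token_s)
#     z = torch.randn(1, 64, 64, 128)    # (batch, tokens, tokens, token_z)
#     mask = torch.ones(1, 64)           # token-level mask
#     pair_mask = torch.ones(1, 64, 64)  # pairwise mask
#     s, z = model(s, z, mask, pair_mask)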


class PairformerNoSeqLayer(nn.Module):
    """A single Pairformer block without the sequence track."""

    def __init__(
        self,
        token_z: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
    ) -> None:
        super().__init__()
        self.token_z = token_z
        self.dropout = dropout
        self.post_layer_norm = post_layer_norm

        self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
        self.tri_mul_in = TriangleMultiplicationIncoming(token_z)

        self.tri_att_start = TriangleAttentionStartingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
        )

        self.transition_z = Transition(token_z, token_z * 4)

    def forward(
        self,
        z: Tensor,
        pair_mask: Tensor,
        chunk_size_tri_attn: Optional[int] = None,
        use_kernels: bool = False,
        use_cuequiv_mul: bool = False,
        use_cuequiv_attn: bool = False,
    ) -> Tensor:
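        # The same pairwise-track updates as PairformerLayer: two triangle
        # multiplicative updates, two triangle attentions, and a transition.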
        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_out(
            z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
        )

        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_mul_in(
            z, mask=pair_mask, use_kernels=use_cuequiv_mul or use_kernels
        )

        dropout = get_dropout_mask(self.dropout, z, self.training)
        z = z + dropout * self.tri_att_start(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_cuequiv_attn or use_kernels,
        )

        dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
        z = z + dropout * self.tri_att_end(
            z,
            mask=pair_mask,
            chunk_size=chunk_size_tri_attn,
            use_kernels=use_cuequiv_attn or use_kernels,
        )

        z = z + self.transition_z(z)
        return z


class PairformerNoSeqModule(nn.Module):
    """Pairformer module without the sequence track."""

    def __init__(
        self,
        token_z: int,
        num_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.token_z = token_z
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.post_layer_norm = post_layer_norm
        self.activation_checkpointing = activation_checkpointing

        self.layers = nn.ModuleList()
        for _ in range(num_blocks):
            self.layers.append(
                PairformerNoSeqLayer(
                    token_z,
                    dropout,
                    pairwise_head_width,
                    pairwise_num_heads,
                    post_layer_norm,
                ),
            )

    def forward(
        self,
        z: Tensor,
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
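        # Same inference-time chunk-size heuristic as PairformerModule.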
        if not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_size_tri_attn = 128
            else:
                chunk_size_tri_attn = 512
        else:
            chunk_size_tri_attn = None

        for layer in self.layers:
            if self.activation_checkpointing and self.training:
                z = torch.utils.checkpoint.checkpoint(
                    layer,
                    z,
                    pair_mask,
                    chunk_size_tri_attn,
                    use_kernels,
                )
            else:
                z = layer(
                    z,
                    pair_mask,
                    chunk_size_tri_attn,
                    use_kernels,
                )
        return z