Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from torch import Tensor, nn | |
| from . import vb_layers_initialize as init | |
class PairWeightedAveraging(nn.Module):
    """Pair weighted averaging layer.

    Updates the sequence tensor ``m`` by averaging its projected values over
    the token axis, using per-head softmax weights computed from the pairwise
    tensor ``z``. The result is gated by a sigmoid projection of ``m`` and
    projected back to ``c_m`` channels.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_h: int,
        num_heads: int,
        inf: float = 1e6,
    ) -> None:
        """Initialize the pair weighted averaging layer.

        Parameters
        ----------
        c_m: int
            The dimension of the input sequence.
        c_z: int
            The dimension of the input pairwise tensor.
        c_h: int
            The dimension of the hidden.
        num_heads: int
            The number of heads.
        inf: float
            The value to use for masking, default 1e6.

        """
        super().__init__()
        self.c_m = c_m
        self.c_z = c_z
        self.c_h = c_h
        self.num_heads = num_heads
        self.inf = inf

        self.norm_m = nn.LayerNorm(c_m)
        self.norm_z = nn.LayerNorm(c_z)

        # All projections are bias-free; heads are laid out contiguously in
        # blocks of ``c_h`` channels, which the chunked path in ``forward``
        # relies on when slicing the weights.
        self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False)
        self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False)
        self.proj_z = nn.Linear(c_z, num_heads, bias=False)
        self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False)
        init.final_init_(self.proj_o.weight)

    def _gated_average(
        self,
        v: Tensor,
        b: Tensor,
        g: Tensor,
        mask_bias: Tensor,
        num_heads: int,
    ) -> Tensor:
        """Average ``v`` with masked softmax weights from ``b`` and gate with ``g``.

        Parameters
        ----------
        v : torch.Tensor
            Projected values (B, S, N, num_heads * c_h)
        b : torch.Tensor
            Projected pair logits (B, N, N, num_heads)
        g : torch.Tensor
            Pre-sigmoid gate (B, S, N, num_heads * c_h)
        mask_bias : torch.Tensor
            Additive masking bias (B, 1, N, N)
        num_heads : int
            Number of heads contained in ``v``, ``b`` and ``g``.

        Returns
        -------
        torch.Tensor
            The gated average (B, S, N, num_heads * c_h)

        """
        # (B, S, N, H * c_h) -> (B, H, S, N, c_h)
        v = v.reshape(*v.shape[:3], num_heads, self.c_h)
        v = v.permute(0, 3, 1, 2, 4)

        # (B, N, N, H) -> (B, H, N, N); normalise over the last token axis.
        w = torch.softmax(b.permute(0, 3, 1, 2) + mask_bias, dim=-1)

        o = torch.einsum("bhij,bhsjd->bhsid", w, v)
        # (B, H, S, N, c_h) -> (B, S, N, H * c_h)
        o = o.permute(0, 2, 3, 1, 4)
        o = o.reshape(*o.shape[:3], num_heads * self.c_h)
        return g.sigmoid() * o

    def forward(
        self, m: Tensor, z: Tensor, mask: Tensor, chunk_heads: bool = False
    ) -> Tensor:
        """Forward pass.

        Parameters
        ----------
        m : torch.Tensor
            The input sequence tensor (B, S, N, D)
        z : torch.Tensor
            The input pairwise tensor (B, N, N, D)
        mask : torch.Tensor
            The pairwise mask tensor (B, N, N)
        chunk_heads : bool
            If True and the module is not in training mode, compute the heads
            one at a time and accumulate their contributions to the output
            instead of materialising all heads at once. Ignored in training
            mode. Default False.

        Returns
        -------
        torch.Tensor
            The output sequence tensor (B, S, N, D)

        """
        # Compute layer norms
        m = self.norm_m(m)
        z = self.norm_z(z)

        # Masked positions receive a large negative logit; this does not
        # depend on the head, so it is computed once.
        mask_bias = (1 - mask[:, None]) * -self.inf

        if chunk_heads and not self.training:
            # Compute heads sequentially by slicing the projection weights.
            o_out = None
            for head_idx in range(self.num_heads):
                head = slice(head_idx * self.c_h, (head_idx + 1) * self.c_h)
                weight_m = self.proj_m.weight[head, :]
                weight_g = self.proj_g.weight[head, :]
                weight_z = self.proj_z.weight[head_idx : head_idx + 1, :]
                weight_o = self.proj_o.weight[:, head]

                gated = self._gated_average(
                    m @ weight_m.T,
                    z @ weight_z.T,
                    m @ weight_g.T,
                    mask_bias,
                    1,
                )

                # The output projection is linear, so summing the per-head
                # products equals projecting the concatenated heads.
                contribution = gated @ weight_o.T
                if o_out is None:
                    o_out = contribution
                else:
                    o_out += contribution
            return o_out

        # Compute all heads at once.
        gated = self._gated_average(
            self.proj_m(m),
            self.proj_z(z),
            self.proj_g(m),
            mask_bias,
            self.num_heads,
        )
        return self.proj_o(gated)