Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang | |
| from functools import partial | |
| from math import pi | |
| import torch | |
| from einops import rearrange | |
| from torch import nn | |
| from torch.nn import Linear, Module, ModuleList | |
| from torch.nn.functional import one_hot | |
| from . import vb_layers_initialize as init | |
| from .vb_layers_transition import Transition | |
| from .vb_modules_transformersv2 import AtomTransformer | |
| from .vb_modules_utils import LinearNoBias | |
class FourierEmbedding(Module):
    """Random Fourier features of a scalar time value (Algorithm 22).

    A ``Linear(1, dim)`` layer whose weight and bias are re-drawn from a
    standard normal distribution and then frozen projects each scalar; the
    embedding is the cosine of ``2 * pi`` times that projection.
    """

    def __init__(self, dim):
        super().__init__()
        projection = Linear(1, dim)
        # Weight first, then bias, so the random stream is consumed in a
        # fixed order.
        for parameter in (projection.weight, projection.bias):
            nn.init.normal_(parameter, mean=0, std=1)
        # The random projection is never trained.
        projection.requires_grad_(False)
        self.proj = projection

    def forward(
        self,
        times,  # Float[' b'],
    ):  # -> Float['b d']:
        # Turn the 1-D batch of times into a column so Linear(1, dim) applies.
        column = rearrange(times, "b -> b 1")
        return torch.cos(self.proj(column) * (2 * pi))
class RelativePositionEncoder(Module):
    """Relative position encoding for token pairs (Algorithm 3).

    For every ordered token pair ``(i, j)`` the encoder concatenates:

    * a one-hot of the clipped residue-index offset (``2 * r_max + 2`` bins,
      the last bin marking pairs on different chains),
    * a one-hot of the clipped token-index offset (``2 * r_max + 2`` bins,
      the last bin marking pairs not in the same residue of the same chain),
    * a single same-entity flag,
    * a one-hot of the clipped ``sym_id`` offset (``2 * s_max + 2`` bins),

    and projects the result with a bias-free linear layer to ``token_z``.
    """

    def __init__(
        self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
    ):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # Input width: 2 * (2 * r_max + 2) + (2 * s_max + 2) + 1.
        self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
        self.fix_sym_check = fix_sym_check
        self.cyclic_pos_enc = cyclic_pos_enc

    def forward(self, feats):
        """Encode pairwise relative-position features.

        Args:
            feats: Mapping providing ``asym_id``, ``residue_index``,
                ``entity_id``, ``token_index`` and ``sym_id`` (each indexed
                here as a 2-D ``(b, n)`` integer tensor), plus
                ``cyclic_period`` when ``cyclic_pos_enc`` is enabled.

        Returns:
            ``linear_layer`` applied to the concatenated features, one vector
            per ``(b, i, j)`` token pair.
        """
        b_same_chain = torch.eq(
            feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
        )
        b_same_residue = torch.eq(
            feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
        )
        b_same_entity = torch.eq(
            feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
        )

        d_residue = (
            feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
        )

        if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
            # Tokens without a cyclic period get a period far larger than any
            # offset, so the wrap-around below leaves them unchanged.
            period = torch.where(
                feats["cyclic_period"] > 0,
                feats["cyclic_period"],
                torch.zeros_like(feats["cyclic_period"]) + 10000,
            )
            if period.dim() == d_residue.dim() - 1:
                # A per-token (b, n) period needs an explicit pair axis to
                # line up with the (b, n, n) offsets. Relying on trailing-
                # dimension broadcasting only worked for b == 1 (where this
                # is identical) and failed or misaligned for larger batches.
                # NOTE(review): assumes cyclic_period is per-token like the
                # other features - confirm against the featurizer.
                period = period.unsqueeze(1)
            # Wrap the offset into the nearest representative modulo period.
            d_residue = (d_residue - period * torch.round(d_residue / period)).long()

        d_residue = torch.clip(
            d_residue + self.r_max,
            0,
            2 * self.r_max,
        )
        # Pairs on different chains go to the dedicated last bin.
        d_residue = torch.where(
            b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
        )
        a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)

        d_token = torch.clip(
            feats["token_index"][:, :, None]
            - feats["token_index"][:, None, :]
            + self.r_max,
            0,
            2 * self.r_max,
        )
        # Token offsets are only kept within one residue of one chain.
        d_token = torch.where(
            b_same_chain & b_same_residue,
            d_token,
            torch.zeros_like(d_token) + 2 * self.r_max + 1,
        )
        a_rel_token = one_hot(d_token, 2 * self.r_max + 2)

        d_chain = torch.clip(
            feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
            0,
            2 * self.s_max,
        )
        # Note: added | (~b_same_entity) based on observation of ProteinX manuscript
        # With fix_sym_check the last bin is used for different-entity pairs;
        # otherwise it is used for same-chain pairs.
        d_chain = torch.where(
            (~b_same_entity) if self.fix_sym_check else b_same_chain,
            torch.zeros_like(d_chain) + 2 * self.s_max + 1,
            d_chain,
        )
        a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)

        p = self.linear_layer(
            torch.cat(
                [
                    a_rel_pos.float(),
                    a_rel_token.float(),
                    b_same_entity.unsqueeze(-1).float(),
                    a_rel_chain.float(),
                ],
                dim=-1,
            )
        )
        return p
class SingleConditioning(Module):
    """Conditioning of the single representation (Algorithm 21).

    Concatenates the trunk and input single representations, normalises and
    embeds them, optionally adds a Fourier embedding of the diffusion time,
    and refines the result with residual ``Transition`` blocks.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s: int = 384,
        dim_fourier: int = 256,
        num_transitions: int = 2,
        transition_expansion_factor: int = 2,
        eps: float = 1e-20,
        disable_times: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.sigma_data = sigma_data
        self.disable_times = disable_times

        self.norm_single = nn.LayerNorm(2 * token_s)
        self.single_embed = nn.Linear(2 * token_s, 2 * token_s)

        # The time branch only exists when times are enabled.
        if not disable_times:
            self.fourier_embed = FourierEmbedding(dim_fourier)
            self.norm_fourier = nn.LayerNorm(dim_fourier)
            self.fourier_to_single = LinearNoBias(dim_fourier, 2 * token_s)

        self.transitions = ModuleList(
            [
                Transition(
                    dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s
                )
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        times,  # Float[' b'],
        s_trunk,  # Float['b n ts'],
        s_inputs,  # Float['b n ts'],
    ):  # -> Float['b n 2ts']:
        """Return ``(s, normed_fourier)``; the second item is ``None`` when
        times are disabled."""
        joined = torch.cat((s_trunk, s_inputs), dim=-1)
        s = self.single_embed(self.norm_single(joined))

        normed_fourier = None
        if not self.disable_times:
            # note: sigma rescaling done in diffusion module
            normed_fourier = self.norm_fourier(self.fourier_embed(times))
            time_bias = self.fourier_to_single(normed_fourier)
            # One time vector per batch element, shared by all tokens.
            s = rearrange(time_bias, "b d -> b 1 d") + s

        for transition in self.transitions:
            s = transition(s) + s

        return s, normed_fourier
class PairwiseConditioning(Module):
    """Conditioning of the pair representation (labelled Algorithm 21).

    Concatenates the trunk pair representation with the relative-position
    features, normalises and projects back to ``token_z``, then applies
    residual ``Transition`` blocks.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        super().__init__()

        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(token_z + dim_token_rel_pos_feats),
            LinearNoBias(token_z + dim_token_rel_pos_feats, token_z),
        )

        self.transitions = ModuleList(
            [
                Transition(dim=token_z, hidden=transition_expansion_factor * token_z)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        z_trunk,  # Float['b n n tz'],
        token_rel_pos_feats,  # Float['b n n 3'],
    ):  # -> Float['b n n tz']:
        joined = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        z = self.dim_pairwise_init_proj(joined)

        # Residual refinement.
        for transition in self.transitions:
            z = transition(z) + z

        return z
def get_indexing_matrix(K, W, H, device):
    """Build the 0/1 matrix mapping half-windows to key windows.

    With ``h = H // (W // 2)``, the result is a float tensor of shape
    ``(2 * K, h * K)``. Entry ``(j, k * h + c)`` is 1 exactly when
    ``j - 2 * k + h // 2 == c + 1``, i.e. key window ``k`` is assembled
    from the ``h`` half-windows surrounding query window ``k``; half-windows
    that would fall outside ``[0, 2 * K)`` are simply absent (all-zero).

    Requires ``W`` even, ``H`` a multiple of ``W // 2`` and ``h`` even.
    """
    assert W % 2 == 0
    assert H % (W // 2) == 0

    h = H // (W // 2)
    assert h % 2 == 0

    positions = torch.arange(2 * K, device=device)
    # Column minus row, shifted so the valid range becomes 1..h; 0 and h + 1
    # collect everything out of range.
    offsets = positions[None, :] - positions[:, None] + h // 2
    offsets = torch.clamp(offsets, min=0, max=h + 1)
    # Keep only the rows of even half-window index (one per query window).
    per_window = offsets.view(K, 2, 2 * K)[:, 0, :]
    # Drop the two out-of-range classes after one-hot encoding.
    selector = one_hot(per_window, num_classes=h + 2)[..., 1:-1]
    return selector.transpose(1, 0).reshape(2 * K, h * K).float()
def single_to_keys(single, indexing_matrix, W, H):
    """Gather per-atom features into key windows.

    ``single`` of shape ``(B, N, D)`` is viewed as ``2 * K`` half-windows of
    ``W // 2`` atoms (``K = N // W``), the half-window axis is contracted
    with ``indexing_matrix``, and the result is reshaped to ``(B, K, H, D)``.
    """
    B, N, D = single.shape
    K = N // W
    half_windows = single.view(B, 2 * K, W // 2, D)
    # j = 2K, i = W//2, k = h * K
    gathered = torch.einsum(
        "b j i d, j k -> b k i d", half_windows, indexing_matrix
    )
    return gathered.reshape(B, K, H, D)
class AtomEncoder(Module):
    """Embed per-atom reference features and windowed atom-pair features.

    ``forward`` returns the atom embedding before (``q``) and after (``c``)
    trunk conditioning, the windowed pair representation ``p`` and the
    ``to_keys`` callable that gathers per-atom tensors into key windows.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        token_z,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_feature_dim,
        structure_prediction=True,
        use_no_atom_char=False,
        use_atom_backbone_feat=False,
        use_residue_feats_atoms=False,
    ):
        super().__init__()

        # Input embeddings for atoms and atom pairs.
        self.embed_atom_features = Linear(atom_feature_dim, atom_s)
        self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
        self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
        self.embed_atompair_mask = LinearNoBias(1, atom_z)

        # Window sizes and feature switches.
        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.use_no_atom_char = use_no_atom_char
        self.use_atom_backbone_feat = use_atom_backbone_feat
        self.use_residue_feats_atoms = use_residue_feats_atoms
        self.structure_prediction = structure_prediction

        # Trunk-conditioning projections are only needed by the structure
        # model.
        if structure_prediction:
            self.s_to_c_trans = nn.Sequential(
                nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
            )
            init.final_init_(self.s_to_c_trans[-1].weight)

            self.z_to_p_trans = nn.Sequential(
                nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
            )
            init.final_init_(self.z_to_p_trans[-1].weight)

        self.c_to_p_trans_k = nn.Sequential(nn.ReLU(), LinearNoBias(atom_s, atom_z))
        init.final_init_(self.c_to_p_trans_k[-1].weight)

        self.c_to_p_trans_q = nn.Sequential(nn.ReLU(), LinearNoBias(atom_s, atom_z))
        init.final_init_(self.c_to_p_trans_q[-1].weight)

        self.p_mlp = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
        )
        # Only the last linear layer of the MLP gets the "final" init.
        init.final_init_(self.p_mlp[-1].weight)

    def forward(
        self,
        feats,
        s_trunk=None,  # Float['bm n ts'],
        z=None,  # Float['bm n n tz'],
    ):
        # The whole encoder runs with CUDA autocast switched off.
        with torch.autocast("cuda", enabled=False):
            ref_pos = feats["ref_pos"]  # Float['b m 3'],
            B, N, _ = ref_pos.shape
            pad_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],
            space_uid = feats["ref_space_uid"]  # Long['b m'],

            # Per-atom inputs, concatenated along the feature axis.
            inputs = [
                ref_pos,
                feats["ref_charge"].unsqueeze(-1),
                feats["ref_element"],
            ]
            if not self.use_no_atom_char:
                inputs.append(feats["ref_atom_name_chars"].reshape(B, N, 4 * 64))
            if self.use_atom_backbone_feat:
                inputs.append(feats["atom_backbone_feat"])
            if self.use_residue_feats_atoms:
                token_feats = torch.cat(
                    [
                        feats["res_type"],
                        feats["modified"].unsqueeze(-1),
                        one_hot(feats["mol_type"], num_classes=4).float(),
                    ],
                    dim=-1,
                )
                # Spread token-level features to atoms via the atom->token map.
                inputs.append(torch.bmm(feats["atom_to_token"].float(), token_feats))

            c = self.embed_atom_features(torch.cat(inputs, dim=-1))

            # note we are already creating the windows to make it more efficient
            W = self.atoms_per_window_queries
            H = self.atoms_per_window_keys
            B, N = c.shape[:2]
            K = N // W
            to_keys = partial(
                single_to_keys,
                indexing_matrix=get_indexing_matrix(K, W, H, c.device),
                W=W,
                H=H,
            )

            pos_queries = ref_pos.view(B, K, W, 1, 3)
            pos_keys = to_keys(ref_pos).view(B, K, 1, H, 3)

            d = pos_keys - pos_queries  # Float['b k w h 3']
            sq_dist = torch.sum(d * d, dim=-1, keepdim=True)  # Float['b k w h 1']
            # AF3 feeds in the reciprocal of the distance norm
            d_norm = 1 / (1 + sq_dist)

            mask_queries = pad_mask.view(B, K, W, 1)
            mask_keys = to_keys(pad_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
            uid_queries = space_uid.view(B, K, W, 1)
            uid_keys = to_keys(space_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()

            # A pair is valid when both atoms are unpadded and share a uid.
            pair_valid = mask_queries & mask_keys & (uid_queries == uid_keys)
            v = pair_valid.float().unsqueeze(-1)  # Float['b k w h 1']

            p = self.embed_atompair_ref_pos(d) * v
            p = p + self.embed_atompair_ref_dist(d_norm) * v
            p = p + self.embed_atompair_mask(v) * v

            # q keeps the unconditioned embedding; c is rebound below.
            q = c

            if self.structure_prediction:
                # run only in structure model not in initial encoding
                atom_to_token = feats["atom_to_token"].float()  # Long['b m n'],

                token_single = self.s_to_c_trans(s_trunk.float())
                c = c + torch.bmm(atom_to_token, token_single).to(c)

                atom_to_token_queries = atom_to_token.view(
                    B, K, W, atom_to_token.shape[-1]
                )
                atom_to_token_keys = to_keys(atom_to_token)
                token_pair = self.z_to_p_trans(z.float())
                z_to_p = torch.einsum(
                    "bijd,bwki,bwlj->bwkld",
                    token_pair,
                    atom_to_token_queries,
                    atom_to_token_keys,
                )
                p = p + z_to_p.to(p)

            atom_dim = c.shape[-1]
            p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, atom_dim))
            p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, atom_dim))
            p = p + self.p_mlp(p)

        return q, c, p, to_keys
class AtomAttentionEncoder(Module):
    """Run the windowed atom transformer and pool atoms into tokens.

    ``forward`` returns the pooled token activations ``a`` together with the
    updated ``q``, the (repeated) ``c`` and the ``to_keys`` callable it was
    given.
    """

    def __init__(
        self,
        atom_s,
        token_s,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_encoder_depth=3,
        atom_encoder_heads=4,
        structure_prediction=True,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.structure_prediction = structure_prediction
        if structure_prediction:
            # Projects the coordinates r into the atom feature space.
            self.r_to_q_trans = LinearNoBias(3, atom_s)
            init.final_init_(self.r_to_q_trans.weight)

        self.atom_encoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            depth=atom_encoder_depth,
            heads=atom_encoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # The structure model emits twice the token width.
        out_dim = 2 * token_s if structure_prediction else token_s
        self.atom_to_token_trans = nn.Sequential(
            LinearNoBias(atom_s, out_dim),
            nn.ReLU(),
        )

    def forward(
        self,
        feats,
        q,
        c,
        atom_enc_bias,
        to_keys,
        r=None,  # Float['bm m 3'],
        multiplicity=1,
    ):
        # Unpacked only as a shape check; the values are not used below.
        B, N, _ = feats["ref_pos"].shape
        atom_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],

        if self.structure_prediction:
            # only here the multiplicity kicks in because we use the different positions r
            q = q.repeat_interleave(multiplicity, 0)
            q = q + self.r_to_q_trans(r)

        c = c.repeat_interleave(multiplicity, 0)
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_encoder(
            q=q,
            mask=atom_mask,
            c=c,
            bias=atom_enc_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        with torch.autocast("cuda", enabled=False):
            q_to_a = self.atom_to_token_trans(q).float()
            atom_to_token = feats["atom_to_token"].float()
            atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
            # Normalise each token's column so the bmm below averages its
            # atoms; the epsilon guards tokens with no atoms.
            atoms_per_token = atom_to_token.sum(dim=1, keepdim=True) + 1e-6
            atom_to_token_mean = atom_to_token / atoms_per_token
            a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)

        # Back to the dtype/device of q.
        a = a.to(q)

        return a, q, c, to_keys
class AtomAttentionDecoder(Module):
    """Broadcast token activations to atoms and predict position updates
    (Algorithm 6)."""

    def __init__(
        self,
        atom_s,
        token_s,
        attn_window_queries,
        attn_window_keys,
        atom_decoder_depth=3,
        atom_decoder_heads=4,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
        init.final_init_(self.a_to_q_trans.weight)

        self.atom_decoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=attn_window_queries,
            attn_window_keys=attn_window_keys,
            depth=atom_decoder_depth,
            heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # Without a post layer norm in the transformer, normalise here before
        # the final projection to 3 coordinates.
        if not transformer_post_layer_norm:
            self.atom_feat_to_atom_pos_update = nn.Sequential(
                nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
            )
            init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
        else:
            self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3)
            init.final_init_(self.atom_feat_to_atom_pos_update.weight)

    def forward(
        self,
        a,  # Float['bm n 2ts'],
        q,  # Float['bm m as'],
        c,  # Float['bm m as'],
        atom_dec_bias,  # Float['bm m m az'],
        feats,
        to_keys,
        multiplicity=1,
    ):
        with torch.autocast("cuda", enabled=False):
            atom_to_token = feats["atom_to_token"].float()
            atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)

            # Project token activations and hand each atom its token's row.
            token_to_atom = torch.bmm(atom_to_token, self.a_to_q_trans(a.float()))

        q = q + token_to_atom.to(q)

        atom_mask = feats["atom_pad_mask"]  # Bool['b m'],
        atom_mask = atom_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_decoder(
            q=q,
            mask=atom_mask,
            c=c,
            bias=atom_dec_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return self.atom_feat_to_atom_pos_update(q)