Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from torch import nn | |
| from torch.nn.functional import pad | |
| from . import vb_const as const | |
| from . import vb_layers_initialize as init | |
| from .vb_layers_confidence_utils import ( | |
| compute_aggregated_metric, | |
| compute_ptms, | |
| ) | |
| from .vb_layers_pairformer import PairformerModule | |
| from .vb_modules_encodersv2 import RelativePositionEncoder | |
| from .vb_modules_trunkv2 import ( | |
| ContactConditioning, | |
| ) | |
| from .vb_modules_utils import LinearNoBias | |
class ConfidenceModule(nn.Module):
    """Algorithm 31: confidence trunk followed by the confidence heads.

    The module layer-normalizes the single and pair inputs, optionally adds
    extra pair features, adds an embedding of the binned pairwise distances
    between predicted representative atoms, runs a ``PairformerModule`` on the
    result and feeds the pairformer outputs to ``ConfidenceHeads``.
    """

    def __init__(
        self,
        token_s,
        token_z,
        pairformer_args: dict,
        num_dist_bins=64,
        token_level_confidence=True,
        max_dist=22,
        add_s_to_z_prod=False,
        add_s_input_to_s=False,
        add_z_input_to_z=False,
        maximum_bond_distance=0,
        bond_type_feature=False,
        confidence_args: dict = None,
        compile_pairformer=False,
        fix_sym_check=False,
        cyclic_pos_enc=False,
        return_latent_feats=False,
        conditioning_cutoff_min=None,
        conditioning_cutoff_max=None,
        **kwargs,
    ):
        """Build the confidence trunk and its heads.

        Args:
            token_s: Width of the single (per-token) representation.
            token_z: Width of the pair representation.
            pairformer_args: Keyword arguments forwarded to
                ``PairformerModule``; ``v2=True`` is always added. An optional
                ``"no_update_s"`` entry disables the layer norm on ``s``.
                The mapping passed in is not modified.
            num_dist_bins: Number of distance bins that get an embedding.
            token_level_confidence: Forwarded to ``ConfidenceHeads``.
            max_dist: Largest bin boundary; the smallest is fixed at 2.
            add_s_to_z_prod: Add a projected outer product of ``s_inputs``
                projections to ``z``.
            add_s_input_to_s: Add a linear projection of ``s_inputs`` to ``s``.
            add_z_input_to_z: Add relative-position, token-bond, optional
                bond-type and contact-conditioning features to ``z``.
            maximum_bond_distance: Selects the input width of the token-bond
                projection (1 when 0, otherwise ``maximum_bond_distance + 2``).
            bond_type_feature: Also embed ``feats["type_bonds"]``.
            confidence_args: Extra keyword arguments for ``ConfidenceHeads``;
                ``None`` means no extra arguments.
            compile_pairformer: Accepted but not used by this class.
            fix_sym_check: Forwarded to ``RelativePositionEncoder``.
            cyclic_pos_enc: Forwarded to ``RelativePositionEncoder``.
            return_latent_feats: Also return the pairformer outputs.
            conditioning_cutoff_min: Forwarded to ``ContactConditioning``.
            conditioning_cutoff_max: Forwarded to ``ContactConditioning``.
            **kwargs: Ignored.
        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.no_update_s = pairformer_args.get("no_update_s", False)

        # num_dist_bins - 1 evenly spaced thresholds split distances into
        # num_dist_bins bins (indices 0 .. num_dist_bins - 1).
        boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
        self.register_buffer("boundaries", boundaries)
        self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
        init.gating_init_(self.dist_bin_pairwise_embed.weight)
        self.token_level_confidence = token_level_confidence

        self.s_to_z = LinearNoBias(token_s, token_z)
        self.s_to_z_transpose = LinearNoBias(token_s, token_z)
        init.gating_init_(self.s_to_z.weight)
        init.gating_init_(self.s_to_z_transpose.weight)

        self.add_s_to_z_prod = add_s_to_z_prod
        if add_s_to_z_prod:
            self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
            self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
            init.gating_init_(self.s_to_z_prod_out.weight)

        self.s_inputs_norm = nn.LayerNorm(token_s)
        if not self.no_update_s:
            self.s_norm = nn.LayerNorm(token_s)
        self.z_norm = nn.LayerNorm(token_z)

        self.add_s_input_to_s = add_s_input_to_s
        if add_s_input_to_s:
            self.s_input_to_s = LinearNoBias(token_s, token_s)
            init.gating_init_(self.s_input_to_s.weight)

        self.add_z_input_to_z = add_z_input_to_z
        # Set unconditionally so the attribute always exists; it is only read
        # when add_z_input_to_z is enabled.
        self.bond_type_feature = bond_type_feature
        # NOTE(review): the nesting of this block was reconstructed to mirror
        # the `if self.add_z_input_to_z:` guard in forward() -- confirm that
        # the resulting parameter set matches the published checkpoints.
        if add_z_input_to_z:
            self.rel_pos = RelativePositionEncoder(
                token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
            )
            self.token_bonds = nn.Linear(
                1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
                token_z,
                bias=False,
            )
            if bond_type_feature:
                self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)
            self.contact_conditioning = ContactConditioning(
                token_z=token_z,
                cutoff_min=conditioning_cutoff_min,
                cutoff_max=conditioning_cutoff_max,
            )

        # Work on a copy so the caller's configuration is left untouched.
        pairformer_args = {**pairformer_args, "v2": True}
        self.pairformer_stack = PairformerModule(
            token_s,
            token_z,
            **pairformer_args,
        )
        self.return_latent_feats = return_latent_feats

        # The signature defaults confidence_args to None, which cannot be
        # **-unpacked; treat it as "no extra arguments".
        if confidence_args is None:
            confidence_args = {}
        self.confidence_heads = ConfidenceHeads(
            token_s,
            token_z,
            token_level_confidence=token_level_confidence,
            **confidence_args,
        )

    @staticmethod
    def _concat_sample_outputs(out_dicts):
        """Concatenate per-sample output dicts along the batch dimension."""
        merged = {}
        for key, first_value in out_dicts[0].items():
            if key == "pair_chains_iptm" and isinstance(first_value, dict):
                # Two-level dict keyed by chain index; concatenate the leaves.
                merged[key] = {
                    chain_idx1: {
                        chain_idx2: torch.cat(
                            [out[key][chain_idx1][chain_idx2] for out in out_dicts],
                            dim=0,
                        )
                        for chain_idx2 in inner
                    }
                    for chain_idx1, inner in first_value.items()
                }
            else:
                # Plain tensors, including a tensor stored under
                # "pair_chains_iptm" when the heads fell back to zeros.
                merged[key] = torch.cat([out[key] for out in out_dicts], dim=0)
        return merged

    def forward(
        self,
        s_inputs,  # Float['b n ts']
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        feats,
        pred_distogram_logits,
        multiplicity=1,
        run_sequentially=False,
        use_kernels: bool = False,
    ):
        """Run the confidence trunk and heads.

        Args:
            s_inputs: Input single features, ``Float['b n ts']``.
            s: Single representation, ``Float['b n ts']``.
            z: Pair representation, ``Float['b n n tz']``.
            x_pred: Predicted coordinates, ``Float['bm m 3']``; a 4-D tensor
                is flattened over its first two dimensions.
            feats: Feature dict. Reads ``"token_to_rep_atom"`` and
                ``"token_pad_mask"``, plus ``"token_bonds"`` (and
                ``"type_bonds"``) when ``add_z_input_to_z`` is set; it is also
                passed on to the sub-modules.
            pred_distogram_logits: Passed through to the confidence heads.
            multiplicity: Number of samples per batch element; per-token
                tensors are repeated this many times along dim 0.
            run_sequentially: With ``multiplicity > 1``, process one sample at
                a time and concatenate the results (batch size 1 only).
            use_kernels: Forwarded to the pairformer stack.

        Returns:
            The dict produced by ``ConfidenceHeads``, plus ``"s_conf"`` and
            ``"z_conf"`` (the pairformer outputs) when ``return_latent_feats``
            is set.
        """
        if run_sequentially and multiplicity > 1:
            assert z.shape[0] == 1, "Not supported with batch size > 1"
            out_dicts = [
                self.forward(
                    s_inputs,
                    s,
                    z,
                    x_pred[sample_idx : sample_idx + 1],
                    feats,
                    pred_distogram_logits,
                    multiplicity=1,
                    run_sequentially=False,
                    use_kernels=use_kernels,
                )
                for sample_idx in range(multiplicity)
            ]
            return self._concat_sample_outputs(out_dicts)

        s_inputs = self.s_inputs_norm(s_inputs)
        if not self.no_update_s:
            s = self.s_norm(s)

        if self.add_s_input_to_s:
            s = s + self.s_input_to_s(s_inputs)

        z = self.z_norm(z)

        if self.add_z_input_to_z:
            z = z + self.rel_pos(feats)
            z = z + self.token_bonds(feats["token_bonds"].float())
            if self.bond_type_feature:
                z = z + self.token_bonds_type(feats["type_bonds"].long())
            z = z + self.contact_conditioning(feats)

        s = s.repeat_interleave(multiplicity, 0)

        # Broadcast per-token projections over rows and columns of the pair map.
        z = (
            z
            + self.s_to_z(s_inputs)[:, :, None, :]
            + self.s_to_z_transpose(s_inputs)[:, None, :, :]
        )
        if self.add_s_to_z_prod:
            z = z + self.s_to_z_prod_out(
                self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
                * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
            )

        # Expand to one copy per sample only after the sample-independent work.
        z = z.repeat_interleave(multiplicity, 0)
        s_inputs = s_inputs.repeat_interleave(multiplicity, 0)

        token_to_rep_atom = feats["token_to_rep_atom"].repeat_interleave(
            multiplicity, 0
        )
        if x_pred.dim() == 4:
            batch, mult, num_atoms, _ = x_pred.shape
            x_pred = x_pred.reshape(batch * mult, num_atoms, -1)

        # Coordinates of each token's representative atom, then pairwise
        # distances between them.
        x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
        d = torch.cdist(x_pred_repr, x_pred_repr)
        # Bin index = number of boundaries the distance exceeds.
        dist_bins = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
        z = z + self.dist_bin_pairwise_embed(dist_bins)

        mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        pair_mask = mask[:, :, None] * mask[:, None, :]

        # AF3 has residual connections, we remove them
        s, z = self.pairformer_stack(
            s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
        )

        out_dict = {}
        if self.return_latent_feats:
            out_dict["s_conf"] = s
            out_dict["z_conf"] = z

        # confidence heads
        out_dict.update(
            self.confidence_heads(
                s=s,
                z=z,
                x_pred=x_pred,
                d=d,
                feats=feats,
                multiplicity=multiplicity,
                pred_distogram_logits=pred_distogram_logits,
            )
        )
        return out_dict
class ConfidenceHeads(nn.Module):
    """Linear heads producing confidence logits and aggregated scores.

    PAE and PDE logits are projected from the pair representation, pLDDT and
    "resolved" logits from the single representation. The forward pass also
    derives complex-level summaries (pLDDT, interface-weighted pLDDT, PDE,
    interface PDE) and the pTM-family scores from ``compute_ptms``.
    """

    def __init__(
        self,
        token_s,
        token_z,
        num_plddt_bins=50,
        num_pde_bins=64,
        num_pae_bins=64,
        token_level_confidence=True,
        use_separate_heads: bool = False,
        **kwargs,
    ):
        """Create the projection heads.

        Args:
            token_s: Width of the single representation.
            token_z: Width of the pair representation.
            num_plddt_bins: Number of pLDDT bins.
            num_pde_bins: Number of PDE bins.
            num_pae_bins: Number of PAE bins.
            token_level_confidence: If true, pLDDT/resolved logits are emitted
                per token; otherwise per atom slot, with
                ``max_num_atoms_per_token`` slots per token.
            use_separate_heads: Use separate PAE/PDE projections for
                same-chain and different-chain token pairs.
            **kwargs: Ignored.
        """
        super().__init__()
        self.max_num_atoms_per_token = 23
        self.token_level_confidence = token_level_confidence
        self.use_separate_heads = use_separate_heads

        if self.use_separate_heads:
            self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
            self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
        else:
            self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
            self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)

        if self.token_level_confidence:
            self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
            self.to_resolved_logits = LinearNoBias(token_s, 2)
        else:
            self.to_plddt_logits = LinearNoBias(
                token_s, num_plddt_bins * self.max_num_atoms_per_token
            )
            self.to_resolved_logits = LinearNoBias(
                token_s, 2 * self.max_num_atoms_per_token
            )

    def forward(
        self,
        s,  # Float['b n ts']
        z,  # Float['b n n tz']
        x_pred,  # Float['bm m 3']
        d,
        feats,
        pred_distogram_logits,
        multiplicity=1,
    ):
        """Compute confidence logits and summary metrics.

        Args:
            s: Single representation, ``Float['b n ts']``.
            z: Pair representation, ``Float['b n n tz']``.
            x_pred: Predicted coordinates, ``Float['bm m 3']``.
            d: Pairwise token distances; pairs with ``d < 8`` count as
                contacts for the interface weighting.
            feats: Feature dict. Reads ``"asym_id"``, ``"mol_type"`` and
                ``"token_pad_mask"``, plus ``"atom_to_token"`` and
                ``"atom_pad_mask"`` in atom-level mode.
            pred_distogram_logits: Distogram logits used to weight token pairs
                in the complex-level PDE scores.
            multiplicity: Number of samples per batch element; per-token
                features are repeated this many times along dim 0.

        Returns:
            Dict with keys ``pde_logits``, ``plddt_logits``,
            ``resolved_logits``, ``pde``, ``plddt``, ``complex_plddt``,
            ``complex_iplddt``, ``complex_pde``, ``complex_ipde``,
            ``pae_logits``, ``pae``, ``ptm``, ``iptm``, ``ligand_iptm``,
            ``protein_iptm`` and ``pair_chains_iptm``. If ``compute_ptms``
            raises, the four pTM scores are zeros and ``pair_chains_iptm`` is
            an empty dict.
        """
        # Features shared by several sections, expanded to one row per sample.
        token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
        asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
        cross_chain = asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)

        # PDE is projected from the symmetrized pair representation.
        z_sym = z + z.transpose(1, 2)

        if self.use_separate_heads:
            # Masks are built from the repeated asym_id so they line up with
            # z along the batch dimension for any batch size.
            intra_mask = (~cross_chain).float().unsqueeze(-1)
            inter_mask = cross_chain.float().unsqueeze(-1)
            pae_logits = (
                self.to_pae_inter_logits(z) * inter_mask
                + self.to_pae_intra_logits(z) * intra_mask
            )
            pde_logits = (
                self.to_pde_inter_logits(z_sym) * inter_mask
                + self.to_pde_intra_logits(z_sym) * intra_mask
            )
        else:
            pae_logits = self.to_pae_logits(z)
            pde_logits = self.to_pde_logits(z_sym)

        resolved_logits = self.to_resolved_logits(s)
        plddt_logits = self.to_plddt_logits(s)

        # Relative weights used for the interface-weighted pLDDT.
        ligand_weight = 20
        non_interface_weight = 1
        interface_weight = 10

        token_type = feats["mol_type"].repeat_interleave(multiplicity, 0)
        is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()

        if self.token_level_confidence:
            plddt = compute_aggregated_metric(plddt_logits)
            complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
                dim=-1
            )

            # A non-ligand token is "interface" if it is in contact with any
            # token of a different chain.
            is_contact = (d < 8).float()
            token_interface_mask = torch.max(
                is_contact * cross_chain.float() * (1 - is_ligand_token).unsqueeze(-1),
                dim=-1,
            ).values
            token_non_interface_mask = (1 - token_interface_mask) * (
                1 - is_ligand_token
            )
            iplddt_weight = (
                is_ligand_token * ligand_weight
                + token_interface_mask * interface_weight
                + token_non_interface_mask * non_interface_weight
            )
            complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)
        else:
            # token to atom conversion for resolved logits
            # NOTE(review): the `.squeeze(0)` on the slot mask and the
            # un-repeated feats used below suggest this branch assumes a
            # feature batch size of 1 -- confirm with callers.
            B, N, _ = resolved_logits.shape
            resolved_logits = resolved_logits.reshape(
                B, N, self.max_num_atoms_per_token, 2
            )
            arange_max_num_atoms = torch.arange(
                self.max_num_atoms_per_token, device=resolved_logits.device
            ).reshape(1, 1, -1)
            # Keep slot k of a token when the summed atom_to_token entry for
            # that token is greater than k.
            max_num_atoms_mask = (
                feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
            )
            atom_slot_mask = max_num_atoms_mask.squeeze(0)

            # Rows appended after the selected slots so the atom dimension
            # reaches the length of atom_pad_mask (computed once; `.item()`
            # forces a device sync).
            num_pad_atoms = int(
                feats["atom_pad_mask"].shape[1] - feats["atom_pad_mask"].sum().item()
            )

            resolved_logits = resolved_logits[:, atom_slot_mask]
            resolved_logits = pad(resolved_logits, (0, 0, 0, num_pad_atoms), value=0)

            plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
            plddt_logits = plddt_logits[:, atom_slot_mask]
            plddt_logits = pad(plddt_logits, (0, 0, 0, num_pad_atoms), value=0)

            atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
            plddt = compute_aggregated_metric(plddt_logits)
            complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
                dim=-1
            )

            # Map token-level type and chain id onto atoms.
            atom_to_token = feats["atom_to_token"].float()
            atom_type = torch.bmm(
                atom_to_token, feats["mol_type"].float().unsqueeze(-1)
            ).squeeze(-1)
            is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
            chain_id_atom = torch.bmm(
                atom_to_token, feats["asym_id"].float().unsqueeze(-1)
            ).squeeze(-1)

            d_atom = torch.cdist(x_pred, x_pred)
            is_contact = (d_atom < 8).float()
            atom_cross_chain = (
                chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
            ).float()
            atom_interface_mask = torch.max(
                is_contact * atom_cross_chain * (1 - is_ligand_atom).unsqueeze(-1),
                dim=-1,
            ).values
            atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
            iplddt_weight = (
                is_ligand_atom * ligand_weight
                + atom_interface_mask * interface_weight
                + atom_non_interface_mask * non_interface_weight
            )
            # Use the repeated mask, as complex_plddt above does.
            complex_iplddt = (plddt * atom_pad_mask * iplddt_weight).sum(
                dim=-1
            ) / torch.sum(atom_pad_mask * iplddt_weight, dim=-1)

        # Compute the gPDE and giPDE
        pde = compute_aggregated_metric(pde_logits, end=32)
        pred_distogram_prob = nn.functional.softmax(
            pred_distogram_logits, dim=-1
        ).repeat_interleave(multiplicity, 0)
        # Probability mass in the first 20 distogram bins is used as a
        # per-pair contact probability. The mask is sized from the input
        # rather than assuming a fixed number of bins.
        num_contact_bins = 20
        contacts = torch.zeros(
            (1, 1, 1, pred_distogram_prob.shape[-1]),
            dtype=pred_distogram_prob.dtype,
            device=pred_distogram_prob.device,
        )
        contacts[:, :, :, :num_contact_bins] = 1.0
        prob_contact = (pred_distogram_prob * contacts).sum(-1)

        # Valid (non-padded) token pairs, excluding the diagonal.
        token_pad_pair_mask = (
            token_pad_mask.unsqueeze(-1)
            * token_pad_mask.unsqueeze(-2)
            * (
                1
                - torch.eye(
                    token_pad_mask.shape[1], device=token_pad_mask.device
                ).unsqueeze(0)
            )
        )
        token_pair_mask = token_pad_pair_mask * prob_contact
        complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
            dim=(1, 2)
        )
        token_interface_pair_mask = token_pair_mask * cross_chain
        # The epsilon keeps the ratio finite when no pair spans two chains.
        complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
            token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
        )

        out_dict = dict(
            pde_logits=pde_logits,
            plddt_logits=plddt_logits,
            resolved_logits=resolved_logits,
            pde=pde,
            plddt=plddt,
            complex_plddt=complex_plddt,
            complex_iplddt=complex_iplddt,
            complex_pde=complex_pde,
            complex_ipde=complex_ipde,
        )
        out_dict["pae_logits"] = pae_logits
        out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)

        # Best effort: a failure in the pTM computation must not discard the
        # other confidence outputs, so fall back to neutral values.
        try:
            ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
                pae_logits, x_pred, feats, multiplicity
            )
        except Exception as e:
            print(f"Error in compute_ptms: {e}")
            ptm = torch.zeros_like(complex_plddt)
            iptm = torch.zeros_like(complex_plddt)
            ligand_iptm = torch.zeros_like(complex_plddt)
            protein_iptm = torch.zeros_like(complex_plddt)
            # The success path yields a nested dict keyed by chain index;
            # keep the container type the same so consumers can iterate it.
            pair_chains_iptm = {}

        out_dict["ptm"] = ptm
        out_dict["iptm"] = iptm
        out_dict["ligand_iptm"] = ligand_iptm
        out_dict["protein_iptm"] = protein_iptm
        out_dict["pair_chains_iptm"] = pair_chains_iptm
        return out_dict