Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from itertools import groupby | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL, MOL_TYPE_NONPOLYMER | |
| from .esmfold2_molecular_complex import ( | |
| MolecularComplex, | |
| MolecularComplexMetadata, | |
| ) | |
def get_element_symbol(atomic_num: int) -> str:
    """Map an atomic number to its element symbol.

    Atomic numbers that are not keys of ``ELEMENT_NUMBER_TO_SYMBOL`` yield
    the placeholder symbol ``"X"``.
    """
    fallback_symbol = "X"
    return ELEMENT_NUMBER_TO_SYMBOL.get(atomic_num, fallback_symbol)
def _decode_atom_name(encoded_chars) -> str:
    """Decode one row of encoded atom-name characters into a string.

    Every non-zero code is shifted by 32 to obtain the character ordinal;
    zero codes are skipped. Surrounding whitespace is stripped from the
    decoded name.
    """
    return "".join(
        chr(int(code) + 32) for code in encoded_chars if int(code) != 0
    ).strip()


def _residue_groups(chain_info: Any, is_nonpolymer: bool) -> list:
    """Split a chain's tokens into ``(residue_name, tokens)`` output residues.

    A non-polymer chain yields exactly one group holding all of its tokens
    (named ``"LIG"`` when the chain has no tokens). Any other chain yields
    one group per run of consecutive tokens sharing a ``residue_index``.
    """
    tokens = chain_info.tokens
    if is_nonpolymer:
        return [(tokens[0].residue_name if tokens else "LIG", tokens)]
    # Atom-tokenized modified residues (HYP, MSE, ...) span multiple
    # tokens per residue; collapse them back to one mmCIF residue.
    runs = (
        list(members)
        for _residue_index, members in groupby(
            tokens, key=lambda t: t.residue_index
        )
    )
    return [(members[0].residue_name, members) for members in runs]


def build_molecular_complex_from_features(
    coords: torch.Tensor,
    plddt: torch.Tensor,
    atom_mask: torch.Tensor,
    ref_element: torch.Tensor,
    ref_atom_name_chars: torch.Tensor,
    chain_infos: list,
    complex_id: str,
) -> MolecularComplex:
    """Construct a MolecularComplex from feature-dict tensors and chain metadata.

    Non-polymer chains (ligands) collapse all per-atom tokens into a single
    residue token whose pLDDT is the per-token average and whose hetero flag
    is True. Polymer chains emit one residue token per run of tokens that
    share a ``residue_index``; its pLDDT is the average over that run.

    Args:
        coords: Per-atom coordinates, indexed by atom index. The collected
            rows are reshaped to ``(-1, 3)`` in the output.
        plddt: Confidence values, indexed by each token's ``token_index``.
        atom_mask: Per-atom mask; atoms whose entry is falsy are dropped.
        ref_element: Per-atom atomic numbers, converted to element symbols
            with ``get_element_symbol``.
        ref_atom_name_chars: Per-atom encoded atom-name characters
            (code + 32 is the character ordinal, 0 is skipped).
        chain_infos: Chain descriptors exposing ``asym_id``, ``chain_id``,
            ``entity_id``, ``mol_type`` and ``tokens``. Each token exposes
            ``residue_name``, ``residue_index``, ``token_index``,
            ``atom_start`` and ``atom_count``.
        complex_id: Identifier stored on the resulting complex.

    Returns:
        A MolecularComplex whose ``token_to_atoms`` rows are half-open
        ``[start, end)`` ranges into the flattened (mask-filtered) atom arrays.
    """
    # Move everything to host memory once; the loops below index per atom.
    mask_np = atom_mask.bool().cpu().numpy()
    coords_np = coords.float().cpu().numpy()
    name_chars_np = ref_atom_name_chars.cpu().numpy()
    elements_np = ref_element.cpu().numpy()
    plddt_np = plddt.float().cpu().numpy()

    sequence_tokens: list[str] = []
    chain_ids_per_token: list[int] = []
    token_to_atoms: list[list[int]] = []
    confidence: list[float] = []
    flat_positions: list[list[float]] = []
    flat_elements: list[str] = []
    flat_names: list[str] = []
    flat_hetero: list[bool] = []
    chain_lookup: dict[int, str] = {}
    entity_info: dict[int, str] = {}

    for ci in chain_infos:
        chain_lookup[ci.asym_id] = ci.chain_id
        is_nonpolymer = ci.mol_type == MOL_TYPE_NONPOLYMER
        entity_info[ci.entity_id] = "non-polymer" if is_nonpolymer else "polymer"
        is_hetero = bool(is_nonpolymer)

        for residue_name, tokens in _residue_groups(ci, is_nonpolymer):
            sequence_tokens.append(residue_name)
            chain_ids_per_token.append(ci.asym_id)
            # An empty group only occurs for a token-less ligand chain.
            confidence.append(
                float(np.mean([plddt_np[ti.token_index] for ti in tokens]))
                if tokens
                else 0.0
            )

            # The number of atoms emitted so far is the next output index.
            first_atom = len(flat_positions)
            for ti in tokens:
                for atom_idx in range(ti.atom_start, ti.atom_start + ti.atom_count):
                    if not mask_np[atom_idx]:
                        continue
                    flat_positions.append(coords_np[atom_idx].tolist())
                    flat_elements.append(
                        get_element_symbol(int(elements_np[atom_idx]))
                    )
                    flat_names.append(_decode_atom_name(name_chars_np[atom_idx]))
                    flat_hetero.append(is_hetero)
            token_to_atoms.append([first_atom, len(flat_positions)])

    return MolecularComplex(
        id=complex_id,
        sequence=sequence_tokens,
        # reshape keeps a well-formed (0, 3) / (0, 2) shape for empty input.
        atom_positions=np.array(flat_positions, dtype=np.float32).reshape(-1, 3),
        atom_elements=np.array(flat_elements, dtype=object),
        token_to_atoms=np.array(token_to_atoms, dtype=np.int32).reshape(-1, 2),
        chain_id=np.array(chain_ids_per_token, dtype=np.int64),
        plddt=np.array(confidence, dtype=np.float32),
        atom_names=np.array(flat_names, dtype=object),
        atom_hetero=np.array(flat_hetero, dtype=bool),
        metadata=MolecularComplexMetadata(
            entity_lookup=entity_info,
            chain_lookup=chain_lookup,
            assembly_composition=None,
        ),
    )
def build_molecular_complex(
    structure: Any, coords: torch.Tensor, plddt: torch.Tensor, complex_id: str
) -> MolecularComplex:
    """Directly constructs a MolecularComplex from model outputs without intermediate files.

    Chains, residues and atoms are walked in the order given by ``structure``.
    Atoms whose ``is_present`` field is falsy are skipped and do not consume
    a row of ``coords``.

    Args:
        structure: Object with .chains, .residues, .atoms numpy structured arrays.
        coords: [N_atoms, 3] predicted atom coordinates.
            NOTE(review): rows are read with a counter that advances only for
            present atoms, so ``coords`` appears to hold present atoms only --
            confirm against the caller.
        plddt: [N_residues] per-residue confidence scores.
        complex_id: Identifier string for the resulting complex.

    Returns:
        A MolecularComplex whose ``token_to_atoms`` rows are half-open
        ``[start, end)`` ranges into the flattened atom arrays. For an empty
        structure the position and range arrays keep shapes ``(0, 3)`` and
        ``(0, 2)``.
    """
    flat_positions = []
    flat_elements = []
    flat_names = []
    flat_hetero = []
    sequence_tokens = []
    token_to_atoms = []
    chain_ids_per_token = []
    confidence_scores = []
    chain_lookup = {}
    entity_info = {}

    # One counter serves as both the read index into ``coords`` and the write
    # index into the flattened output: both advance once per present atom.
    atom_cursor = 0
    # Running residue index across all chains; indexes ``plddt``.
    res_cursor = 0

    for chain in structure.chains:
        chain_idx_numeric = chain["asym_id"]
        chain_name_str = str(chain["name"])
        mol_type = chain["mol_type"]
        # Every atom of a non-polymer chain is flagged as hetero.
        is_hetero = mol_type == MOL_TYPE_NONPOLYMER

        chain_lookup[chain_idx_numeric] = chain_name_str
        entity_info[chain["entity_id"]] = (
            "polymer" if mol_type != MOL_TYPE_NONPOLYMER else "non-polymer"
        )

        res_start = chain["res_idx"]
        res_end = chain["res_idx"] + chain["res_num"]
        for residue in structure.residues[res_start:res_end]:
            sequence_tokens.append(str(residue["name"]))
            chain_ids_per_token.append(chain_idx_numeric)
            confidence_scores.append(plddt[res_cursor].item())

            token_start_idx = atom_cursor
            atom_start = residue["atom_idx"]
            atom_end = residue["atom_idx"] + residue["atom_num"]
            for atom in structure.atoms[atom_start:atom_end]:
                if not atom["is_present"]:
                    continue

                flat_positions.append(coords[atom_cursor].tolist())
                flat_elements.append(get_element_symbol(atom["element"].item()))

                # Atom names are stored as character codes offset by 32,
                # with 0 skipped as padding.
                raw_name = atom["name"]
                if hasattr(raw_name, "tolist"):
                    raw_name = raw_name.tolist()
                flat_names.append("".join([chr(c + 32) for c in raw_name if c != 0]))

                flat_hetero.append(is_hetero)
                atom_cursor += 1

            token_to_atoms.append([token_start_idx, atom_cursor])
            res_cursor += 1

    return MolecularComplex(
        id=complex_id,
        sequence=sequence_tokens,
        # reshape keeps a well-formed (0, 3) / (0, 2) shape for empty input,
        # matching build_molecular_complex_from_features.
        atom_positions=np.array(flat_positions, dtype=np.float32).reshape(-1, 3),
        atom_elements=np.array(flat_elements, dtype=object),
        token_to_atoms=np.array(token_to_atoms, dtype=np.int32).reshape(-1, 2),
        chain_id=np.array(chain_ids_per_token, dtype=np.int64),
        plddt=np.array(confidence_scores, dtype=np.float32),
        atom_names=np.array(flat_names, dtype=object),
        atom_hetero=np.array(flat_hetero, dtype=bool),
        metadata=MolecularComplexMetadata(
            entity_lookup=entity_info,
            chain_lookup=chain_lookup,
            assembly_composition=None,
        ),
    )