"""Prepare ESMFold2 model inputs from sequence-level StructurePredictionInput. This module converts StructurePredictionInput (protein/DNA/RNA/ligand sequences) into the tensor dict expected by the ESMFold2 model forward pass. """ from __future__ import annotations import math import warnings from collections import defaultdict from dataclasses import dataclass, field import numpy as np import torch from .esmfold2_conformers import ( get_ccd_leaving_atoms, get_idealized_atom_pos, get_ligand_ccd_atoms_with_charges, get_ligand_ccd_bonds, get_ligand_idealized_atom_pos, ) from .esmfold2_constants import ( CHARGED_ATOMS, DNA_1TO3, DNA_BACKBONE_ATOMS, DNA_HEAVY_ATOMS, DNA_RESIDUE_TO_RES_TYPE, DNA_RNA_LIGAND_INPUT_ID, DNA_UNK_RES_TYPE, ELEMENT_TO_ATOMIC_NUM, ESM_PROTEIN_VOCAB, MOL_TYPE_DNA, MOL_TYPE_NONPOLYMER, MOL_TYPE_PROTEIN, MOL_TYPE_RNA, MSA_GAP_TOKEN_ID, PROTEIN_1TO3, PROTEIN_3TO1, PROTEIN_HEAVY_ATOMS, PROTEIN_RESIDUE_TO_RES_TYPE, PROTEIN_UNK_RES_TYPE, RNA_1TO3, RNA_BACKBONE_ATOMS, RNA_HEAVY_ATOMS, RNA_RESIDUE_TO_RES_TYPE, RNA_UNK_RES_TYPE, ) from .esmfold2_types import ( MSA, DNAInput, LigandInput, Modification, ProteinInput, RNAInput, StructurePredictionInput, ) # ============================================================================= # Lightweight data model # ============================================================================= _ZERO_POS = np.array([0.0, 0.0, 0.0], dtype=np.float32) @dataclass class AtomInfo: name: str element: str charge: int ref_pos: np.ndarray # Idealized position from CCD [3] pos: np.ndarray # Experimental position [3] (zeros for inference) token_index: int = -1 atom_index: int = -1 space_uid: int = -1 is_valid: bool = True @dataclass class TokenInfo: token_index: int residue_index: int # Within chain (0-based) residue_name: str # 3-letter code mol_type: int # 0=protein, 1=DNA, 2=RNA, 3=nonpolymer res_type: int # Residue type index (2-32) input_id: int # ESM vocab ID asym_id: int sym_id: int entity_id: int atom_start: int # Index into atoms list atom_count: int @dataclass class ChainInfo: chain_id: str asym_id: int entity_id: int sym_id: int mol_type: int tokens: list[TokenInfo] = field(default_factory=list) # ============================================================================= # Helper functions # ============================================================================= # Caches for hot-path functions _ENCODE_ATOM_NAME_CACHE: dict[str, list[int]] = {} _ELEMENT_ATOMIC_NUM_CACHE: dict[str, int] = {} def encode_atom_name(name: str) -> list[int]: """Encode atom name as 4 character indices (offset by 32 from ASCII).""" if name in _ENCODE_ATOM_NAME_CACHE: return _ENCODE_ATOM_NAME_CACHE[name] padded = name.ljust(4)[:4] result = [ord(c) - 32 if c != " " else 0 for c in padded] _ENCODE_ATOM_NAME_CACHE[name] = result return result def get_element_atomic_num(element: str) -> int: """Get atomic number for an element symbol.""" if element in _ELEMENT_ATOMIC_NUM_CACHE: return _ELEMENT_ATOMIC_NUM_CACHE[element] result = ELEMENT_TO_ATOMIC_NUM.get(element.upper(), 0) _ELEMENT_ATOMIC_NUM_CACHE[element] = result return result def _infer_element(atom_name: str) -> str: """Infer element from atom name.""" name = atom_name.strip() if not name: return "C" if name[0].isdigit(): return name[1] if len(name) > 1 else "H" if len(name) == 2 and name in ( "FE", "ZN", "MG", "MN", "CO", "NI", "CU", "SE", "BR", ): return name return name[0] def _compute_res_type(name: str, mol_type: int) -> int: """Compute residue type index from residue name and mol_type.""" if mol_type == MOL_TYPE_PROTEIN: return PROTEIN_RESIDUE_TO_RES_TYPE.get(name, PROTEIN_UNK_RES_TYPE) elif mol_type == MOL_TYPE_DNA: if name in DNA_RESIDUE_TO_RES_TYPE: return DNA_RESIDUE_TO_RES_TYPE[name] if name in RNA_RESIDUE_TO_RES_TYPE: return RNA_RESIDUE_TO_RES_TYPE[name] return DNA_UNK_RES_TYPE elif mol_type == MOL_TYPE_RNA: if name in RNA_RESIDUE_TO_RES_TYPE: return RNA_RESIDUE_TO_RES_TYPE[name] if name in DNA_RESIDUE_TO_RES_TYPE: return DNA_RESIDUE_TO_RES_TYPE[name] return RNA_UNK_RES_TYPE return PROTEIN_UNK_RES_TYPE def _compute_esm_input_id(name: str, mol_type: int) -> int: """Compute ESM vocabulary input ID.""" if mol_type == MOL_TYPE_PROTEIN: letter = PROTEIN_3TO1.get(name) if letter is None: return DNA_RNA_LIGAND_INPUT_ID return ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"]) return DNA_RNA_LIGAND_INPUT_ID # ============================================================================= # Tokenization functions — build tokens and atoms from sequences # ============================================================================= def tokenize_protein( sequence: str, modifications: list[Modification] | None, entity_id: int, asym_id: int, sym_id: int, token_offset: int, atom_offset: int, space_uid_offset: int, ) -> tuple[list[TokenInfo], list[AtomInfo]]: """Tokenize a protein sequence into tokens and atoms. Standard residues produce 1 token with all heavy atoms. Modified residues (from modifications) are atom-tokenized (1 token per atom). """ tokens: list[TokenInfo] = [] atoms: list[AtomInfo] = [] # Build 3-letter sequence, applying modifications seq_3letter = [PROTEIN_1TO3.get(c, "UNK") for c in sequence] modified_positions: set[int] = set() if modifications: for mod in modifications: seq_3letter[mod.position] = mod.ccd modified_positions.add(mod.position) token_idx = token_offset atom_idx = atom_offset space_uid = space_uid_offset for res_idx, res_name in enumerate(seq_3letter): # MSE → MET for atom lookup res_corrected = "MET" if res_name == "MSE" else res_name is_modified = res_idx in modified_positions # Check if standard residue (has predefined atom list) if not is_modified and res_corrected in PROTEIN_HEAVY_ATOMS: # Standard residue: 1 token, multiple atoms atom_names = PROTEIN_HEAVY_ATOMS[res_corrected] res_type = _compute_res_type(res_corrected, MOL_TYPE_PROTEIN) input_id = _compute_esm_input_id(res_corrected, MOL_TYPE_PROTEIN) atom_start = atom_idx for a_name in atom_names: ref_pos = get_idealized_atom_pos(res_type, a_name) atoms.append( AtomInfo( name=a_name, element=_infer_element(a_name), charge=CHARGED_ATOMS.get((res_corrected, a_name), 0), ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) atom_idx += 1 tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=res_corrected, mol_type=MOL_TYPE_PROTEIN, res_type=res_type, input_id=input_id, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_start, atom_count=len(atom_names), ) ) token_idx += 1 space_uid += 1 else: # Modified or unknown residue: atom-tokenized ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name) if ccd_atoms is None: # Fallback: backbone only ccd_atoms = [ (_infer_element(n), _infer_element(n), 0) for n in ["N", "CA", "C", "O"] ] # Filter leaving atoms if not terminal is_terminal = res_idx == len(seq_3letter) - 1 leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name) kept_atoms = [a for a in ccd_atoms if a[0] not in leaving_atoms] # Single-atom residues (e.g. NH2 cap): the local frame is # ill-defined with one atom; place at origin. single_atom_residue = len(kept_atoms) == 1 for a_name, a_element, a_charge in kept_atoms: ref_pos = get_ligand_idealized_atom_pos(res_name, a_name) atoms.append( AtomInfo( name=a_name, element=a_element, charge=a_charge, ref_pos=_ZERO_POS.copy() if single_atom_residue else ( ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy() ), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=res_name, mol_type=MOL_TYPE_PROTEIN, res_type=PROTEIN_UNK_RES_TYPE, input_id=DNA_RNA_LIGAND_INPUT_ID, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_idx, atom_count=1, ) ) token_idx += 1 atom_idx += 1 space_uid += 1 return tokens, atoms def tokenize_nucleotide( sequence: str, modifications: list[Modification] | None, mol_type: int, entity_id: int, asym_id: int, sym_id: int, token_offset: int, atom_offset: int, space_uid_offset: int, ) -> tuple[list[TokenInfo], list[AtomInfo]]: """Tokenize a DNA or RNA sequence into tokens and atoms.""" tokens: list[TokenInfo] = [] atoms: list[AtomInfo] = [] letter_to_3 = DNA_1TO3 if mol_type == MOL_TYPE_DNA else RNA_1TO3 heavy_atoms = DNA_HEAVY_ATOMS if mol_type == MOL_TYPE_DNA else RNA_HEAVY_ATOMS backbone_atoms = ( DNA_BACKBONE_ATOMS if mol_type == MOL_TYPE_DNA else RNA_BACKBONE_ATOMS ) unk_res_type = DNA_UNK_RES_TYPE if mol_type == MOL_TYPE_DNA else RNA_UNK_RES_TYPE seq_3letter = [letter_to_3.get(c, "UNK") for c in sequence] modified_positions: set[int] = set() if modifications: for mod in modifications: seq_3letter[mod.position] = mod.ccd modified_positions.add(mod.position) token_idx = token_offset atom_idx = atom_offset space_uid = space_uid_offset for res_idx, res_name in enumerate(seq_3letter): is_modified = res_idx in modified_positions if not is_modified and res_name in heavy_atoms: # Standard nucleotide atom_names = heavy_atoms[res_name] res_type = _compute_res_type(res_name, mol_type) input_id = DNA_RNA_LIGAND_INPUT_ID atom_start = atom_idx for a_name in atom_names: ref_pos = get_idealized_atom_pos(res_type, a_name) atoms.append( AtomInfo( name=a_name, element=_infer_element(a_name), charge=CHARGED_ATOMS.get((res_name, a_name), 0), ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) atom_idx += 1 tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=res_name, mol_type=mol_type, res_type=res_type, input_id=input_id, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_start, atom_count=len(atom_names), ) ) token_idx += 1 space_uid += 1 elif not is_modified and res_name == "UNK": # Unknown nucleotide: backbone only atom_names = backbone_atoms atom_start = atom_idx for a_name in atom_names: ref_pos = None # No idealized positions for UNK atoms.append( AtomInfo( name=a_name, element=_infer_element(a_name), charge=0, ref_pos=_ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) atom_idx += 1 tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=res_name, mol_type=mol_type, res_type=unk_res_type, input_id=DNA_RNA_LIGAND_INPUT_ID, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_start, atom_count=len(atom_names), ) ) token_idx += 1 space_uid += 1 else: # Modified nucleotide: atom-tokenized ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name) if ccd_atoms is None: ccd_atoms = [ (_infer_element(n), _infer_element(n), 0) for n in backbone_atoms ] is_terminal = res_idx == len(seq_3letter) - 1 leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name) for a_name, a_element, a_charge in ccd_atoms: if a_name in leaving_atoms: continue ref_pos = get_ligand_idealized_atom_pos(res_name, a_name) atoms.append( AtomInfo( name=a_name, element=a_element, charge=a_charge, ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=res_name, mol_type=mol_type, res_type=PROTEIN_UNK_RES_TYPE, input_id=DNA_RNA_LIGAND_INPUT_ID, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_idx, atom_count=1, ) ) token_idx += 1 atom_idx += 1 space_uid += 1 return tokens, atoms def tokenize_ligand_ccd( ccd_codes: list[str], entity_id: int, asym_id: int, sym_id: int, token_offset: int, atom_offset: int, space_uid_offset: int, has_covalent_bond: bool, ) -> tuple[list[TokenInfo], list[AtomInfo]]: """Tokenize a ligand from CCD codes (1 token per atom).""" tokens: list[TokenInfo] = [] atoms: list[AtomInfo] = [] token_idx = token_offset atom_idx = atom_offset space_uid = space_uid_offset for res_idx, code in enumerate(ccd_codes): ccd_atoms = get_ligand_ccd_atoms_with_charges(code) if ccd_atoms is None: raise ValueError(f"CCD component {code} not found") leaving_atoms = get_ccd_leaving_atoms(code) if has_covalent_bond else set() for a_name, a_element, a_charge in ccd_atoms: if a_name in leaving_atoms: continue ref_pos = get_ligand_idealized_atom_pos(code, a_name) atoms.append( AtomInfo( name=a_name, element=a_element, charge=a_charge, ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) tokens.append( TokenInfo( token_index=token_idx, residue_index=res_idx, residue_name=code, mol_type=MOL_TYPE_NONPOLYMER, res_type=PROTEIN_UNK_RES_TYPE, input_id=DNA_RNA_LIGAND_INPUT_ID, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_idx, atom_count=1, ) ) token_idx += 1 atom_idx += 1 space_uid += 1 return tokens, atoms def tokenize_ligand_smiles( smiles: str, entity_id: int, asym_id: int, sym_id: int, token_offset: int, atom_offset: int, space_uid_offset: int, seed: int | None = None, ) -> tuple[list[TokenInfo], list[AtomInfo]]: """Tokenize a ligand from SMILES (1 token per heavy atom).""" from rdkit import Chem from rdkit.Chem import AllChem mol = Chem.MolFromSmiles(smiles) if mol is None: raise ValueError(f"Failed to parse SMILES: {smiles}") mol = Chem.AddHs(mol) # Assign atom names using canonical ranking canonical_order = AllChem.CanonicalRankAtoms(mol) # type: ignore[attr-defined] for atom, can_idx in zip(mol.GetAtoms(), canonical_order): atom_name = atom.GetSymbol().upper() + str(can_idx + 1) if len(atom_name) > 4: raise ValueError( f"SMILES {smiles} has atom name longer than 4 chars: {atom_name}" ) atom.SetProp("name", atom_name) # Generate 3D conformer options = AllChem.ETKDGv3() # type: ignore[attr-defined] options.clearConfs = False if seed is not None: options.randomSeed = seed conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined] if conf_id == -1: options.useRandomCoords = True conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined] if conf_id != -1: try: AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000) # type: ignore[attr-defined] except (RuntimeError, ValueError): pass # Remove hydrogens mol_no_h = Chem.RemoveHs(mol) if mol_no_h.GetNumConformers() == 0: raise ValueError(f"Failed to generate conformer for SMILES: {smiles}") conformer = mol_no_h.GetConformer(0) tokens: list[TokenInfo] = [] atoms_list: list[AtomInfo] = [] token_idx = token_offset atom_idx = atom_offset space_uid = space_uid_offset for atom in mol_no_h.GetAtoms(): a_name = atom.GetProp("name") a_element = atom.GetSymbol() a_charge = atom.GetFormalCharge() pos_3d = conformer.GetAtomPosition(atom.GetIdx()) ref_pos = np.array([pos_3d.x, pos_3d.y, pos_3d.z], dtype=np.float32) atoms_list.append( AtomInfo( name=a_name, element=a_element, charge=a_charge, ref_pos=ref_pos, pos=_ZERO_POS.copy(), token_index=token_idx, atom_index=atom_idx, space_uid=space_uid, ) ) tokens.append( TokenInfo( token_index=token_idx, residue_index=0, residue_name="LIG", mol_type=MOL_TYPE_NONPOLYMER, res_type=PROTEIN_UNK_RES_TYPE, input_id=DNA_RNA_LIGAND_INPUT_ID, asym_id=asym_id, sym_id=sym_id, entity_id=entity_id, atom_start=atom_idx, atom_count=1, ) ) token_idx += 1 atom_idx += 1 return tokens, atoms_list # ============================================================================= # Build chains from StructurePredictionInput # ============================================================================= def _get_sequence_key(item) -> str: """Get a hashable key for entity deduplication.""" if isinstance(item, ProteinInput): return f"PROTEIN:{item.sequence}" elif isinstance(item, DNAInput): return f"DNA:{item.sequence}" elif isinstance(item, RNAInput): return f"RNA:{item.sequence}" elif isinstance(item, LigandInput): if item.ccd: return f"LIGAND_CCD:{','.join(item.ccd)}" return f"LIGAND_SMILES:{item.smiles}" raise ValueError(f"Unknown input type: {type(item)}") def build_chains_from_input( input: StructurePredictionInput, seed: int | None = None ) -> tuple[list[ChainInfo], list[TokenInfo], list[AtomInfo]]: """Build chains, tokens, and atoms from StructurePredictionInput. Handles entity deduplication (identical sequences get same entity_id), sym_id assignment, and delegates to type-specific tokenization functions. """ chains: list[ChainInfo] = [] all_tokens: list[TokenInfo] = [] all_atoms: list[AtomInfo] = [] # Entity deduplication sequence_to_entity: dict[str, int] = {} entity_sym_count: dict[int, int] = {} next_entity_id = 0 # Gather chain IDs involved in covalent bonds covalent_chain_ids: set[str] = set() if input.covalent_bonds: for cb in input.covalent_bonds: covalent_chain_ids.update([cb.chain_id1, cb.chain_id2]) token_offset = 0 atom_offset = 0 space_uid_offset = 0 asym_id = 0 for item in input.sequences: # Entity deduplication seq_key = _get_sequence_key(item) if seq_key in sequence_to_entity: entity_id = sequence_to_entity[seq_key] else: entity_id = next_entity_id sequence_to_entity[seq_key] = entity_id next_entity_id += 1 # Get all chain IDs for this item ids = [item.id] if isinstance(item.id, str) else item.id for chain_id_str in ids: # sym_id is the per-entity copy index; increment per chain so # ProteinInput(id=['A','B']) gives chain A sym_id=0, chain B sym_id=1. sym_id = entity_sym_count.get(entity_id, 0) entity_sym_count[entity_id] = sym_id + 1 if isinstance(item, ProteinInput): if item.msa is None: warnings.warn( f"No MSA provided for {item.id}, using single sequence mode" ) new_tokens, new_atoms = tokenize_protein( sequence=item.sequence, modifications=item.modifications, entity_id=entity_id, asym_id=asym_id, sym_id=sym_id, token_offset=token_offset, atom_offset=atom_offset, space_uid_offset=space_uid_offset, ) elif isinstance(item, (DNAInput, RNAInput)): mol_type = MOL_TYPE_DNA if isinstance(item, DNAInput) else MOL_TYPE_RNA new_tokens, new_atoms = tokenize_nucleotide( sequence=item.sequence, modifications=item.modifications, mol_type=mol_type, entity_id=entity_id, asym_id=asym_id, sym_id=sym_id, token_offset=token_offset, atom_offset=atom_offset, space_uid_offset=space_uid_offset, ) elif isinstance(item, LigandInput): has_cov = chain_id_str in covalent_chain_ids if item.ccd is not None: if item.smiles is not None: warnings.warn("Both ccd and smiles provided, using ccd") new_tokens, new_atoms = tokenize_ligand_ccd( ccd_codes=item.ccd, entity_id=entity_id, asym_id=asym_id, sym_id=sym_id, token_offset=token_offset, atom_offset=atom_offset, space_uid_offset=space_uid_offset, has_covalent_bond=has_cov, ) elif item.smiles is not None: new_tokens, new_atoms = tokenize_ligand_smiles( smiles=item.smiles, entity_id=entity_id, asym_id=asym_id, sym_id=sym_id, token_offset=token_offset, atom_offset=atom_offset, space_uid_offset=space_uid_offset, seed=seed, ) else: raise ValueError("LigandInput must have either ccd or smiles") else: raise ValueError(f"Unknown input type: {type(item)}") chain = ChainInfo( chain_id=chain_id_str, asym_id=asym_id, entity_id=entity_id, sym_id=sym_id, mol_type=new_tokens[0].mol_type if new_tokens else MOL_TYPE_PROTEIN, tokens=new_tokens, ) chains.append(chain) all_tokens.extend(new_tokens) all_atoms.extend(new_atoms) token_offset += len(new_tokens) atom_offset += len(new_atoms) space_uid_offset += len(set(a.space_uid for a in new_atoms)) asym_id += 1 return chains, all_tokens, all_atoms # ============================================================================= # Feature tensor building # ============================================================================= def compute_frame_indices( tokens: list[TokenInfo], atoms: list[AtomInfo] ) -> tuple[np.ndarray, np.ndarray]: """Compute backbone frame indices for each token. Protein: [N, CA, C]; DNA/RNA: [C1', C3', C4']; Ligand: distance-based. """ # Build atom name -> atom_index lookup per token token_atoms: dict[int, dict[str, int]] = defaultdict(dict) for atom in atoms: if atom.is_valid: token_atoms[atom.token_index][atom.name] = atom.atom_index # Ligand-token frames come from CCD reference-conformer geometry, # grouped per residue. For each token, the frame is the 3 atoms nearest # to its own atom in the residue's ref-pos space, ordered # (1st-nearest, self, 2nd-nearest). ligand_token_to_atom: dict[int, int] = {} ligand_tokens_by_res: dict[tuple[int, int], list[int]] = defaultdict(list) for t in tokens: if t.mol_type == MOL_TYPE_NONPOLYMER: ad = token_atoms.get(t.token_index) if ad: ligand_token_to_atom[t.token_index] = next(iter(ad.values())) ligand_tokens_by_res[(t.asym_id, t.residue_index)].append(t.token_index) ligand_token_frames: dict[int, tuple[int, int, int]] = {} for tok_indices in ligand_tokens_by_res.values(): atom_indices = [ ligand_token_to_atom[ti] for ti in tok_indices if ti in ligand_token_to_atom ] if len(atom_indices) < 3: for ti in tok_indices: if ti in ligand_token_to_atom: ai = ligand_token_to_atom[ti] ligand_token_frames[ti] = (ai, ai, ai) continue ref_pos_chain = np.array([atoms[ai].ref_pos for ai in atom_indices]) dist_mat = np.sqrt( ((ref_pos_chain[:, None] - ref_pos_chain[None]) ** 2).sum(-1) ) sort_indices = np.argsort(dist_mat, axis=1) local_frames = np.column_stack( [sort_indices[:, 1], sort_indices[:, 0], sort_indices[:, 2]] ) for ti in tok_indices: if ti not in ligand_token_to_atom: continue ai = ligand_token_to_atom[ti] local_idx = atom_indices.index(ai) fl = local_frames[local_idx] ligand_token_frames[ti] = ( atom_indices[fl[0]], atom_indices[fl[1]], atom_indices[fl[2]], ) # Build frames for all tokens frames_list: list[tuple[int, int, int]] = [] for t in tokens: ad = token_atoms.get(t.token_index, {}) fallback = list(ad.values())[0] if ad else 0 if t.mol_type == MOL_TYPE_PROTEIN: if t.res_type == PROTEIN_UNK_RES_TYPE: frames_list.append((fallback, fallback, fallback)) else: frames_list.append((ad.get("N", 0), ad.get("CA", 0), ad.get("C", 0))) elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA): if t.res_type == PROTEIN_UNK_RES_TYPE: frames_list.append((fallback, fallback, fallback)) else: frames_list.append( (ad.get("C1'", 0), ad.get("C3'", 0), ad.get("C4'", 0)) ) elif t.mol_type == MOL_TYPE_NONPOLYMER: if t.token_index in ligand_token_frames: frames_list.append(ligand_token_frames[t.token_index]) else: frames_list.append((fallback, fallback, fallback)) else: frames_list.append((fallback, fallback, fallback)) frames = np.array(frames_list, dtype=np.int64) # Compute resolved mask (vectorized) n_atoms = len(atoms) atom_positions = ( np.array([a.pos for a in atoms], dtype=np.float32) if atoms else np.zeros((0, 3), dtype=np.float32) ) atom_is_valid = ( np.array([a.is_valid for a in atoms], dtype=bool) if atoms else np.zeros(0, dtype=bool) ) atom_is_resolved = ( atom_is_valid & np.any(atom_positions != 0, axis=1) if n_atoms > 0 else np.zeros(0, dtype=bool) ) n_tokens = len(tokens) if n_tokens == 0: return frames, np.zeros(0, dtype=bool) pos1 = atom_positions[frames[:, 0]] pos2 = atom_positions[frames[:, 1]] pos3 = atom_positions[frames[:, 2]] all_resolved = ( atom_is_resolved[frames[:, 0]] & atom_is_resolved[frames[:, 1]] & atom_is_resolved[frames[:, 2]] ) all_same = (frames[:, 0] == frames[:, 1]) & (frames[:, 1] == frames[:, 2]) v1 = pos1 - pos2 v2 = pos3 - pos2 norm1 = np.linalg.norm(v1, axis=1) norm2 = np.linalg.norm(v2, axis=1) valid_norms = (norm1 >= 1e-6) & (norm2 >= 1e-6) cos_angle = np.zeros(n_tokens, dtype=np.float32) mask = valid_norms if np.any(mask): cos_angle[mask] = np.sum(v1[mask] * v2[mask], axis=1) / ( norm1[mask] * norm2[mask] ) cos_angle = np.clip(cos_angle, -1, 1) angle_deg = np.degrees(np.arccos(np.abs(cos_angle))) not_colinear = angle_deg >= 25 resolved_mask = all_resolved & ~all_same & valid_norms & not_colinear return frames, resolved_mask def compute_token_bonds( tokens: list[TokenInfo], atoms: list[AtomInfo], input: StructurePredictionInput, chains: list[ChainInfo], ) -> torch.Tensor: """Compute dense token bond matrix [L, L, 1]. Includes ligand intra-residue bonds (from CCD) and covalent bonds. """ n_tokens = len(tokens) edge_set: set[tuple[int, int]] = set() def add_bond(i: int, j: int) -> None: if i != j: edge_set.add((min(i, j), max(i, j))) # Build per-residue atom name -> token_index mapping for ligands and modified residues # Key: (asym_id, residue_index, atom_name) -> token_index atom_name_to_token: dict[tuple[int, int, str], int] = {} for atom in atoms: if atom.is_valid: t = tokens[atom.token_index] if atom.token_index < len(tokens) else None if t and ( t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE ): atom_name_to_token[(t.asym_id, t.residue_index, atom.name)] = ( atom.token_index ) # Group atom-tokenized tokens by (asym_id, residue_index) residue_tokens: dict[tuple[int, int], list[tuple[str, int]]] = defaultdict(list) for atom in atoms: if not atom.is_valid: continue t = tokens[atom.token_index] if atom.token_index < len(tokens) else None if t and ( t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE ): residue_tokens[(t.asym_id, t.residue_index)].append( (atom.name, atom.token_index) ) # Add intra-residue bonds from CCD for (asym_id_val, res_idx), atom_list in residue_tokens.items(): if not atom_list: continue res_name = tokens[atom_list[0][1]].residue_name ccd_bonds = get_ligand_ccd_bonds(res_name) atom_to_tok = {name: ti for name, ti in atom_list} if ccd_bonds: for a1, a2 in ccd_bonds: if a1 in atom_to_tok and a2 in atom_to_tok: add_bond(atom_to_tok[a1], atom_to_tok[a2]) else: # Fallback: fully connected within residue tok_indices = [ti for _, ti in atom_list] for i_idx in tok_indices: for j_idx in tok_indices: add_bond(i_idx, j_idx) # Add covalent bonds from input if input.covalent_bonds: # Build chain_id -> chain mapping chain_by_id: dict[str, ChainInfo] = {c.chain_id: c for c in chains} # Build (asym_id, residue_index) -> list of tokens for atom index lookup chain_res_atoms: dict[tuple[int, int], list[AtomInfo]] = defaultdict(list) for atom in atoms: if atom.is_valid and atom.token_index < len(tokens): t = tokens[atom.token_index] chain_res_atoms[(t.asym_id, t.residue_index)].append(atom) for cb in input.covalent_bonds: c1 = chain_by_id.get(cb.chain_id1) c2 = chain_by_id.get(cb.chain_id2) if c1 is None or c2 is None: continue atoms_1 = chain_res_atoms.get((c1.asym_id, cb.res_idx1), []) atoms_2 = chain_res_atoms.get((c2.asym_id, cb.res_idx2), []) if cb.atom_idx1 < len(atoms_1) and cb.atom_idx2 < len(atoms_2): add_bond( atoms_1[cb.atom_idx1].token_index, atoms_2[cb.atom_idx2].token_index ) # Add peptide bonds at modified-residue boundaries: an atom-tokenized # residue's N atom connects to the prev residue's C atom (and same for # the C side to the next residue's N). tokens_by_chain_res: dict[tuple[int, int], list[TokenInfo]] = defaultdict(list) for t in tokens: if t.mol_type == MOL_TYPE_PROTEIN: tokens_by_chain_res[(t.asym_id, t.residue_index)].append(t) def _backbone_token(res_tokens: list[TokenInfo], atom_name: str) -> int | None: # Standard residue (single token wrapping all atoms): return that token. if len(res_tokens) == 1 and res_tokens[0].res_type != PROTEIN_UNK_RES_TYPE: return res_tokens[0].token_index for t in res_tokens: for a_idx in range(t.atom_start, t.atom_start + t.atom_count): if a_idx < len(atoms) and atoms[a_idx].name == atom_name: return t.token_index # Atom-tokenized residue without an atom of that name (e.g. ACE has # no N, NH2 has no C). Fall back to the first atom-tokenized token. return res_tokens[0].token_index if res_tokens else None for (asym_id_val, res_idx), res_tokens in tokens_by_chain_res.items(): is_atom_tokenized = any(t.res_type == PROTEIN_UNK_RES_TYPE for t in res_tokens) if not is_atom_tokenized: continue # Standard residue — no peptide bond added here. n_tok = _backbone_token(res_tokens, "N") c_tok = _backbone_token(res_tokens, "C") prev_tokens = tokens_by_chain_res.get((asym_id_val, res_idx - 1)) if prev_tokens and n_tok is not None: prev_c = _backbone_token(prev_tokens, "C") if prev_c is not None: add_bond(prev_c, n_tok) next_tokens = tokens_by_chain_res.get((asym_id_val, res_idx + 1)) if next_tokens and c_tok is not None: next_n = _backbone_token(next_tokens, "N") if next_n is not None: add_bond(c_tok, next_n) # Expand to dense matrix bonds = torch.zeros(n_tokens, n_tokens, 1, dtype=torch.float32) for i, j in edge_set: bonds[i, j, 0] = 1.0 bonds[j, i, 0] = 1.0 return bonds def compute_representative_atoms( tokens: list[TokenInfo], atoms: list[AtomInfo] ) -> torch.Tensor: """Compute representative atom index per token (for token_to_rep_atom). Returns: distogram_atom_idx: [L] — representative atom per token Protein: CB (or CA for GLY), DNA/RNA: C4/C2/C1', Ligand: first atom. """ n_tokens = len(tokens) # Build atom name -> index lookup per token token_atoms: dict[int, dict[str, int]] = defaultdict(dict) for atom in atoms: if atom.is_valid: token_atoms[atom.token_index][atom.name] = atom.atom_index distogram_atom_idx = torch.zeros(n_tokens, dtype=torch.int64) for t in tokens: ad = token_atoms.get(t.token_index, {}) fallback_idx = list(ad.values())[0] if ad else 0 if t.mol_type == MOL_TYPE_PROTEIN: rep_idx = ad.get("CB", ad.get("CA", fallback_idx)) elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA): if t.res_type in (27, 32): # Unknown nucleotides rep_idx = ad.get("C1'", fallback_idx) elif t.res_type in (23, 24, 28, 29): # Purines (A, G) rep_idx = ad.get("C4", ad.get("C1'", fallback_idx)) else: # Pyrimidines (C, U, T) rep_idx = ad.get("C2", ad.get("C1'", fallback_idx)) else: rep_idx = fallback_idx distogram_atom_idx[t.token_index] = rep_idx return distogram_atom_idx def compute_msa_features( input: StructurePredictionInput, chains: list[ChainInfo], tokens: list[TokenInfo], max_seqs: int = 16384, ) -> dict[str, torch.Tensor]: """Compute MSA features from protein MSAs. Uses taxonomy-based pairing across chains (:func:`paired_msa.construct_paired_msa`): rows whose FASTA header contains ``key=N`` get paired across chains sharing the same ``N``. Output: msa [M, L], deletion_value [M, L], has_deletion [M, L], deletion_mean [L], msa_mask [M, L] """ from .esmfold2_paired_msa import ( construct_paired_msa, protein_letter_to_res_type, ) n_tokens = len(tokens) # A single ProteinInput with id=['A','B','C',...] yields one item but # multiple chains (one per id); broadcast the MSA across all of them. chain_msas: dict[int, MSA | None] = {} item_idx = 0 for item in input.sequences: ids = [item.id] if isinstance(item.id, str) else list(item.id) for _ in ids: chain = chains[item_idx] if isinstance(item, ProteinInput): msa = item.msa if msa is None: msa = MSA.from_sequences([item.sequence]) chain_msas[chain.asym_id] = msa else: chain_msas[chain.asym_id] = None item_idx += 1 letter_to_res_type = protein_letter_to_res_type() # Build per-chain query res_types (used for chains without an MSA). chain_query_res_types: dict[int, np.ndarray] = {} for chain in chains: chain_tokens = [t for t in tokens if t.asym_id == chain.asym_id] chain_query_res_types[chain.asym_id] = np.array( [t.res_type for t in chain_tokens], dtype=np.int64 ) token_asym_ids = np.array([t.asym_id for t in tokens], dtype=np.int64) token_res_ids = np.array([t.residue_index for t in tokens], dtype=np.int64) msa_res, del_counts, paired = construct_paired_msa( chain_msas, chain_query_res_types, token_asym_ids, token_res_ids, letter_to_res_type=letter_to_res_type, max_seqs=max_seqs, ) # Tokens for chains without an MSA get their res_type at row 0 and gap # elsewhere; this mirrors the prior non-protein-token branch. for t in tokens: if chain_msas.get(t.asym_id) is None: msa_res[:, t.token_index] = MSA_GAP_TOKEN_ID msa_res[0, t.token_index] = t.res_type if msa_res.shape[0] == 0: msa_res = np.full((1, n_tokens), MSA_GAP_TOKEN_ID, dtype=np.int64) del_counts = np.zeros((1, n_tokens), dtype=np.float32) msa_data = torch.from_numpy(msa_res) del_data = torch.from_numpy(del_counts) has_deletion = del_data > 0 deletion_value = (np.pi / 2) * torch.arctan(del_data / 3) deletion_mean = deletion_value.mean(dim=0) msa_mask = torch.ones_like(msa_data, dtype=torch.bool) return { "msa": msa_data, "deletion_value": deletion_value, "has_deletion": has_deletion, "deletion_mean": deletion_mean, "msa_attention_mask": msa_mask, } def compute_distogram_conditioning( input: StructurePredictionInput, chains: list[ChainInfo], tokens: list[TokenInfo], disto_center: torch.Tensor, min_dist: float = 2.0, max_dist: float = 22.0, num_bins: int = 64, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute distogram conditioning from user-provided distograms. Returns: disto_cond: [L, L] int64 (bin indices) disto_cond_mask: [L, L] bool """ n_tokens = len(tokens) disto_cond = torch.zeros(n_tokens, n_tokens, dtype=torch.long) disto_cond_mask = torch.zeros(n_tokens, n_tokens, dtype=torch.bool) if not input.distogram_conditioning: return disto_cond, disto_cond_mask # Build chain_id -> asym_id mapping chain_id_to_asym: dict[str, int] = {c.chain_id: c.asym_id for c in chains} # Build asym_id -> token indices mapping asym_to_tokens: dict[int, list[int]] = defaultdict(list) for t in tokens: asym_to_tokens[t.asym_id].append(t.token_index) boundaries = torch.linspace(min_dist, max_dist, num_bins + 1) for dc in input.distogram_conditioning: asym_id_val = chain_id_to_asym.get(dc.chain_id) if asym_id_val is None: continue tok_indices = asym_to_tokens[asym_id_val] n_chain = len(tok_indices) distogram = torch.tensor(dc.distogram, dtype=torch.float32) if distogram.shape != (n_chain, n_chain): raise ValueError( f"Distogram shape {distogram.shape} doesn't match chain length {n_chain}" ) # Bin the distogram binned = torch.bucketize(distogram, boundaries[:-1]) - 1 binned = binned.clamp(0, num_bins - 1) for i, ti in enumerate(tok_indices): for j, tj in enumerate(tok_indices): disto_cond[ti, tj] = binned[i, j] disto_cond_mask[ti, tj] = True return disto_cond, disto_cond_mask def build_feature_tensors( chains: list[ChainInfo], tokens: list[TokenInfo], atoms: list[AtomInfo], input: StructurePredictionInput, ) -> dict[str, torch.Tensor]: """Build all model input tensors from tokens and atoms.""" n_tokens = len(tokens) n_real_atoms = len(atoms) # Pad atoms to nearest multiple of 32 target_atoms = math.ceil(n_real_atoms / 32) * 32 if n_real_atoms > 0 else 32 n_padding = target_atoms - n_real_atoms padding_atoms = [ AtomInfo( name="", element="", charge=0, ref_pos=_ZERO_POS.copy(), pos=_ZERO_POS.copy(), token_index=0, atom_index=n_real_atoms + i, space_uid=0, is_valid=False, ) for i in range(n_padding) ] all_atoms = atoms + padding_atoms n_atoms = len(all_atoms) # --- Token-level tensors --- token_index_arr = np.empty(n_tokens, dtype=np.int64) residue_index_arr = np.empty(n_tokens, dtype=np.int64) asym_id_arr = np.empty(n_tokens, dtype=np.int64) sym_id_arr = np.empty(n_tokens, dtype=np.int64) entity_id_arr = np.empty(n_tokens, dtype=np.int64) mol_type_arr = np.empty(n_tokens, dtype=np.int64) res_type_arr = np.empty(n_tokens, dtype=np.int64) input_ids_arr = np.empty(n_tokens, dtype=np.int64) for i, t in enumerate(tokens): token_index_arr[i] = t.token_index residue_index_arr[i] = t.residue_index asym_id_arr[i] = t.asym_id sym_id_arr[i] = t.sym_id entity_id_arr[i] = t.entity_id mol_type_arr[i] = t.mol_type res_type_arr[i] = t.res_type input_ids_arr[i] = t.input_id token_index = torch.from_numpy(token_index_arr) residue_index = torch.from_numpy(residue_index_arr) asym_id = torch.from_numpy(asym_id_arr) sym_id = torch.from_numpy(sym_id_arr) entity_id = torch.from_numpy(entity_id_arr) mol_type = torch.from_numpy(mol_type_arr) res_type = torch.from_numpy(res_type_arr) input_ids = torch.from_numpy(input_ids_arr) token_pad_mask = torch.ones(n_tokens, dtype=torch.bool) # --- Atom-level tensors --- ref_pos_arr = np.zeros((n_atoms, 3), dtype=np.float32) ref_element_arr = np.zeros(n_atoms, dtype=np.int64) ref_charge_arr = np.zeros(n_atoms, dtype=np.int8) ref_atom_name_chars_arr = np.zeros((n_atoms, 4), dtype=np.int64) ref_space_uid_arr = np.zeros(n_atoms, dtype=np.int64) atom_pad_mask_arr = np.zeros(n_atoms, dtype=np.bool_) atom_to_token_arr = np.zeros(n_atoms, dtype=np.int64) all_positions = np.zeros((n_atoms, 3), dtype=np.float64) is_valid_arr = np.zeros(n_atoms, dtype=np.bool_) for i, atom in enumerate(all_atoms): if atom.ref_pos is not None: ref_pos_arr[i] = atom.ref_pos ref_charge_arr[i] = atom.charge ref_space_uid_arr[i] = ( atom.space_uid if atom.space_uid >= 0 else atom.token_index ) atom_pad_mask_arr[i] = atom.is_valid is_valid_arr[i] = atom.is_valid all_positions[i] = atom.pos if atom.is_valid: ref_element_arr[i] = get_element_atomic_num(atom.element) name_indices = encode_atom_name(atom.name) ref_atom_name_chars_arr[i] = name_indices atom_to_token_arr[i] = atom.token_index ref_pos = torch.from_numpy(ref_pos_arr) ref_element = torch.from_numpy(ref_element_arr) ref_charge = torch.from_numpy(ref_charge_arr) ref_atom_name_chars = torch.from_numpy(ref_atom_name_chars_arr) ref_space_uid = torch.from_numpy(ref_space_uid_arr) atom_pad_mask = torch.from_numpy(atom_pad_mask_arr) atom_to_token = torch.from_numpy(atom_to_token_arr) # Coordinates — center on resolved atoms raw_coords = torch.from_numpy(all_positions) is_nonzero = np.any(all_positions != 0, axis=1) atom_resolved_arr = is_valid_arr & is_nonzero resolved_mask = torch.from_numpy(atom_resolved_arr) valid_mask = torch.from_numpy(is_valid_arr) if resolved_mask.any(): centroid = raw_coords[resolved_mask].mean(dim=0, keepdim=True) raw_coords = raw_coords - centroid raw_coords[~valid_mask] = 0.0 coords = raw_coords.float().unsqueeze(0) # [1, A, 3] atom_resolved_mask = torch.tensor(atom_resolved_arr, dtype=torch.bool) # --- Frames --- frames, _ = compute_frame_indices(tokens, atoms) frames_idx = torch.from_numpy(frames).to(torch.int64) # --- Token bonds --- token_bonds = compute_token_bonds(tokens, atoms, input, chains) # --- Representative atoms --- distogram_atom_idx = compute_representative_atoms(tokens, atoms) # --- MSA features --- msa_features = compute_msa_features(input, chains, tokens) # --- Distogram conditioning --- # disto_center is not needed for inference (no experimental coords) disto_center = torch.zeros(n_tokens, 3, dtype=torch.float32) disto_cond, disto_cond_mask = compute_distogram_conditioning( input, chains, tokens, disto_center ) # ref_pos: CCD conformer positions, used as-is for inference. # No random rotation or masking — at inference there are no resolved # experimental coordinates, so atom_resolved_mask is all False. # The model uses ref_pos for atom feature embedding. # --- Pocket (dropped) --- pocket_feature = torch.zeros(n_tokens, dtype=torch.long) return { # Token-level "token_index": token_index, "residue_index": residue_index, "asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id, "mol_type": mol_type, "res_type": res_type, "input_ids": input_ids, "token_bonds": token_bonds, "token_attention_mask": token_pad_mask, "pocket_feature": pocket_feature, # Atom-level "ref_pos": ref_pos, "ref_element": ref_element, "ref_charge": ref_charge, "ref_atom_name_chars": ref_atom_name_chars, "ref_space_uid": ref_space_uid, "gt_coords": coords, "atom_attention_mask": atom_pad_mask, "atom_to_token": atom_to_token, "is_resolved": atom_resolved_mask, "distogram_atom_idx": distogram_atom_idx, # Frames "frames_idx": frames_idx, # Distogram "disto_cond": disto_cond, "disto_cond_mask": disto_cond_mask, # MSA **msa_features, } # ============================================================================= # Top-level entry point # ============================================================================= def prepare_esmfold2_input( input: StructurePredictionInput, seed: int | None = None ) -> tuple[dict[str, torch.Tensor], list[ChainInfo]]: """Prepare ESMFold2 model inputs from StructurePredictionInput. Args: input: The structure prediction input (sequences, conditioning, etc.) seed: Random seed for SMILES conformer generation and augmentation. Returns: Tuple of (feature_dict, chain_infos) where feature_dict contains all tensors for the model forward pass, and chain_infos contains metadata for output processing. """ chains, tokens, atoms = build_chains_from_input(input, seed) features = build_feature_tensors(chains, tokens, atoms, input) return features, chains