from __future__ import annotations import io import os import re from dataclasses import asdict, dataclass from pathlib import Path from subprocess import check_output from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any import biotite.structure as bs import biotite.structure.io.pdbx as pdbx import brotli import msgpack import numpy as np import torch from biotite.structure.io.pdbx import ( CIFCategory, CIFColumn, CIFData, CIFFile, set_structure, ) from . import esmfold2_residue_constants as residue_constants from .esmfold2_metrics import compute_lddt, compute_rmsd from .esmfold2_protein_complex import ProteinComplex, ProteinComplexMetadata @dataclass class MolecularComplexResult: """Result of molecular complex folding""" complex: MolecularComplex plddt: torch.Tensor | None = None ptm: float | None = None iptm: float | None = None pae: torch.Tensor | None = None distogram: torch.Tensor | None = None pair_chains_iptm: torch.Tensor | None = None output_embedding_sequence: torch.Tensor | None = None output_embedding_pair_pooled: torch.Tensor | None = None residue_index: torch.Tensor | None = None entity_id: torch.Tensor | None = None sae_features: np.ndarray | None = None # [L, n_features] ttt_metrics: dict[str, Any] | None = None @dataclass class MolecularComplexMetadata: """Metadata for MolecularComplex objects.""" entity_lookup: dict[int, str] chain_lookup: dict[int, str] assembly_composition: dict[str, list[str]] | None = None @dataclass class Molecule: """Represents a single molecule/token within a MolecularComplex.""" token: str token_idx: int atom_positions: np.ndarray # [N_atoms, 3] atom_elements: np.ndarray # [N_atoms] element strings atom_names: np.ndarray | None = None # [N_atoms] atom names (optional) atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional) residue_type: int = 0 molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3 confidence: float = 0.0 @dataclass(frozen=True) class MolecularComplex: """ Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands. Uses a flat atom representation with token-based sequence indexing, supporting all atom types beyond the traditional atom37 protein representation. """ id: str sequence: list[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP'] # Flat atom arrays - simplified representation atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates atom_elements: np.ndarray # [N_atoms] element strings # Token-to-atom mapping for efficient access token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array # Chain information chain_id: np.ndarray # [N_tokens] chain identifier for each token # Confidence data plddt: np.ndarray # Per-token confidence scores [N_tokens] # Metadata metadata: MolecularComplexMetadata # Optional atom names and hetero flags (preserved from original structures) atom_names: np.ndarray | None = None # [N_atoms] atom names (optional) atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional) def __post_init__(self): """Validate array dimensions.""" n_tokens = len(self.sequence) n_atoms = len(self.atom_positions) assert ( self.token_to_atoms.shape[0] == n_tokens ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens" assert ( self.chain_id.shape[0] == n_tokens ), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens" assert ( self.plddt.shape[0] == n_tokens ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens" if self.atom_names is not None: assert ( self.atom_names.shape[0] == n_atoms ), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms" if self.atom_hetero is not None: assert ( self.atom_hetero.shape[0] == n_atoms ), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms" def __len__(self) -> int: """Return number of tokens.""" return len(self.sequence) def __getitem__(self, idx: int) -> Molecule: """Access individual molecules/tokens by index.""" if idx >= len(self.sequence) or idx < 0: raise IndexError( f"Token index {idx} out of range for {len(self.sequence)} tokens" ) token = self.sequence[idx] start_atom, end_atom = self.token_to_atoms[idx] # Extract atom data for this token token_atom_positions = self.atom_positions[start_atom:end_atom] token_atom_elements = self.atom_elements[start_atom:end_atom] token_atom_names = None if self.atom_names is not None: token_atom_names = self.atom_names[start_atom:end_atom] token_atom_hetero = None if self.atom_hetero is not None: token_atom_hetero = self.atom_hetero[start_atom:end_atom] # Default values for residue/molecule type (would be extended based on actual implementation) residue_type = 0 # Default to standard residue molecule_type = 0 # Default to protein return Molecule( token=token, token_idx=idx, atom_positions=token_atom_positions, atom_elements=token_atom_elements, atom_names=token_atom_names, atom_hetero=token_atom_hetero, residue_type=residue_type, molecule_type=molecule_type, confidence=self.plddt[idx], ) @property def atom_coordinates(self) -> np.ndarray: """Get flat array of all atom coordinates [N_atoms, 3].""" return self.atom_positions # Conversion methods @classmethod def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex": """Convert a ProteinComplex to MolecularComplex. Args: pc: ProteinComplex object with atom37 representation Returns: MolecularComplex with flat atom arrays and token-based indexing """ from . import esmfold2_residue_constants # Extract sequence without chain breaks sequence_no_breaks = pc.sequence.replace("|", "") sequence_tokens = [ residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks ] # Convert atom37 to flat arrays flat_positions = [] flat_elements = [] flat_names = [] flat_hetero = [] token_to_atoms = [] atom_idx = 0 for i, aa in enumerate(pc.sequence): if aa == "|": # Skip chain break tokens continue # Get atom37 positions and mask for this residue. # ProteinComplex arrays are indexed by sequence position (including |), # so use `i` not a separate residue counter. res_positions = pc.atom37_positions[i] # [37, 3] res_mask = pc.atom37_mask[i] # [37] # Track start position for this token token_start = atom_idx # Process each atom type in atom37 representation for atom_type_idx, atom_name in enumerate(residue_constants.atom_types): if res_mask[atom_type_idx]: # Atom is present # Add position flat_positions.append(res_positions[atom_type_idx]) # Determine element from atom name element = ( atom_name[0] if atom_name else "C" ) # First character is element flat_elements.append(element) # Add atom name flat_names.append(atom_name) # Add hetero flag (all proteins are non-hetero) flat_hetero.append(False) atom_idx += 1 # Record token-to-atom mapping [start_idx, end_idx) token_to_atoms.append([token_start, atom_idx]) # Convert to numpy arrays atom_positions = np.array(flat_positions, dtype=np.float32) atom_elements = np.array(flat_elements, dtype=object) atom_names = np.array(flat_names, dtype=object) atom_hetero = np.array(flat_hetero, dtype=bool) token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) # Extract confidence scores and chain_ids (skip chain breaks) confidence_scores = [] chain_ids = [] for seq_idx, aa in enumerate(pc.sequence): if aa != "|": confidence_scores.append(pc.confidence[seq_idx]) chain_ids.append(pc.chain_id[seq_idx]) confidence_array = np.array(confidence_scores, dtype=np.float32) chain_id_array = np.array(chain_ids, dtype=np.int64) # Create metadata - convert entity IDs to strings for MolecularComplexMetadata entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()} metadata = MolecularComplexMetadata( entity_lookup=entity_lookup_str, chain_lookup=pc.metadata.chain_lookup, assembly_composition=pc.metadata.assembly_composition, ) return cls( id=pc.id, sequence=sequence_tokens, atom_positions=atom_positions, atom_elements=atom_elements, token_to_atoms=token_to_atoms_array, chain_id=chain_id_array, plddt=confidence_array, metadata=metadata, atom_names=atom_names, atom_hetero=atom_hetero, ) def to_protein_complex(self) -> ProteinComplex: """Convert MolecularComplex back to ProteinComplex format. Extracts only protein tokens and converts from flat atom representation back to atom37 format used by ProteinComplex. Returns: ProteinComplex with protein residues only, excluding ligands/nucleic acids """ from . import esmfold2_residue_constants # No need for element mapping - already using element characters # Filter for protein tokens only (skip ligands, nucleic acids) protein_tokens = [] protein_indices = [] for i, token in enumerate(self.sequence): # Check if token is a standard 3-letter amino acid code if token in residue_constants.restype_3to1: protein_tokens.append(token) protein_indices.append(i) if not protein_tokens: raise ValueError("No protein tokens found in MolecularComplex") n_residues = len(protein_tokens) # Initialize atom37 arrays atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32) atom37_mask = np.zeros((n_residues, 37), dtype=bool) # Extract confidence scores and chain_ids for protein residues only protein_confidence = self.plddt[protein_indices] protein_chain_ids = self.chain_id[protein_indices] # Convert tokens back to single-letter sequence with chain breaks single_letter_residues = [] prev_chain_id = None for i, (token, chain_id_val) in enumerate( zip(protein_tokens, protein_chain_ids) ): # Add chain break if we're switching to a new chain if prev_chain_id is not None and chain_id_val != prev_chain_id: single_letter_residues.append("|") single_letter_residues.append(residue_constants.restype_3to1[token]) prev_chain_id = chain_id_val single_letter_sequence = "".join(single_letter_residues) # Calculate final sequence length (includes chain breaks) sequence_length = len(single_letter_sequence) # Convert flat atoms back to atom37 representation using atom names for res_idx, token_idx in enumerate(protein_indices): token = self.sequence[token_idx] start_atom, end_atom = self.token_to_atoms[token_idx] res_atom_positions = self.atom_positions[start_atom:end_atom] res_atom_names = ( np.array(self.atom_names[start_atom:end_atom], dtype=str) if self.atom_names is not None else np.array([], dtype=str) ) # Build a mapping from normalized atom name -> position for this residue # Normalize to uppercase and strip whitespace for robust matching name_to_pos: dict[str, np.ndarray] = {} for i, nm in enumerate(res_atom_names): key = nm.upper().strip() # Prefer first occurrence; ignore duplicates/altlocs if key not in name_to_pos: name_to_pos[key] = res_atom_positions[i] # Place atoms into atom37 by matching stored atom names to atom37 indices. # This handles all atoms present in the flat representation, not just # the canonical residue_atoms for this residue type. This preserves # atoms that were in the original atom37_mask even if they're atypical # for the residue (e.g., from alternate conformations or data quirks). for atom_name_str, pos in name_to_pos.items(): idx37 = residue_constants.atom_order.get(atom_name_str) if idx37 is not None: atom37_positions[res_idx, idx37] = pos atom37_mask[res_idx, idx37] = True # Create arrays that match sequence length (including chain breaks) # Initialize arrays with proper size chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64) entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64) sym_id_expanded = np.zeros(sequence_length, dtype=np.int64) residue_index_expanded = np.zeros(sequence_length, dtype=np.int64) insertion_code_expanded = np.array([""] * sequence_length, dtype=object) confidence_expanded = np.zeros(sequence_length, dtype=np.float32) atom37_positions_expanded = np.full( (sequence_length, 37, 3), np.nan, dtype=np.float32 ) atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool) # Map residue data to sequence positions (skipping chain breaks) residue_idx = 0 residue_counter_per_chain = {} for seq_pos, char in enumerate(single_letter_sequence): if char != "|": # This is a residue position chain_id_val = protein_chain_ids[residue_idx] chain_id_expanded[seq_pos] = chain_id_val entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping # Track residue numbering per chain if chain_id_val not in residue_counter_per_chain: residue_counter_per_chain[chain_id_val] = 1 else: residue_counter_per_chain[chain_id_val] += 1 residue_index_expanded[seq_pos] = residue_counter_per_chain[ chain_id_val ] confidence_expanded[seq_pos] = protein_confidence[residue_idx] atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx] atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx] residue_idx += 1 # Chain break positions keep default values (-1, False, etc.) # Use the expanded arrays chain_id = chain_id_expanded entity_id = entity_id_expanded sym_id = sym_id_expanded residue_index = residue_index_expanded insertion_code = insertion_code_expanded protein_confidence = confidence_expanded atom37_positions = atom37_positions_expanded atom37_mask = atom37_mask_expanded # Create protein complex metadata preserving chain information # Convert MolecularComplex metadata to ProteinComplex format unique_chain_ids = np.unique(protein_chain_ids) entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids} chain_lookup = { int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid))) for cid in unique_chain_ids } protein_metadata = ProteinComplexMetadata( entity_lookup=entity_lookup, chain_lookup=chain_lookup, assembly_composition=self.metadata.assembly_composition, ) return ProteinComplex( id=self.id, sequence=single_letter_sequence, entity_id=entity_id, chain_id=chain_id, sym_id=sym_id, residue_index=residue_index, insertion_code=insertion_code, atom37_positions=atom37_positions, atom37_mask=atom37_mask, confidence=protein_confidence, metadata=protein_metadata, ) @classmethod def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": """Read MolecularComplex from mmcif file or string. Args: inp: Path to mmCIF file or mmCIF content as string id: Optional identifier to assign to the complex Returns: MolecularComplex with all molecules (proteins, ligands, nucleic acids) """ from io import StringIO # Check if input is a file path or mmCIF string content if os.path.exists(inp): # Input is a file path mmcif_file = pdbx.CIFFile.read(inp) else: # Input is mmCIF string content mmcif_file = pdbx.CIFFile.read(StringIO(inp)) # Get structure - handle missing model information gracefully try: structure = pdbx.get_structure( mmcif_file, model=1, extra_fields=["b_factor"] ) except (KeyError, ValueError): # Fallback for mmCIF files without model information try: structure = pdbx.get_structure(mmcif_file) except Exception: # Last resort: use the first available model or all atoms structure = pdbx.get_structure(mmcif_file, model=None) # Type hint for pyright - structure is an AtomArray which is iterable if TYPE_CHECKING: structure: Any = structure # Read label_asym_id from the raw CIF atom_site category. # Biotite's atom.chain_id uses auth_asym_id, which collapses ligands # onto their parent protein chain. label_asym_id gives each entity a # distinct chain identifier. block = mmcif_file.block label_asym_ids: list[str] | None = None if "atom_site" in block: atom_site = block["atom_site"] if "label_asym_id" in atom_site: _col = atom_site["label_asym_id"] _raw = ( _col.as_array(str) if hasattr(_col, "as_array") else np.array(list(_col), dtype=str) # type: ignore[arg-type] ) # biotite's get_structure(model=1) filters to model 1 AND # removes alternate conformations. We must apply the same # filters to label_asym_id to keep arrays aligned. keep = np.ones(len(_raw), dtype=bool) if "pdbx_PDB_model_num" in atom_site: _mc = atom_site["pdbx_PDB_model_num"] _models = ( _mc.as_array(str) if hasattr(_mc, "as_array") else np.array(list(_mc), dtype=str) # type: ignore[arg-type] ) keep &= _models == "1" if "label_alt_id" in atom_site: _ac = atom_site["label_alt_id"] _alts = ( _ac.as_array(str) if hasattr(_ac, "as_array") else np.array(list(_ac), dtype=str) # type: ignore[arg-type] ) keep &= np.isin(_alts, [".", "?", "", "A"]) filtered = _raw[keep] if len(filtered) == len(structure): label_asym_ids = filtered.tolist() # If lengths still don't match, fall back to atom.chain_id # Get entity information from mmCIF entity_info = {} try: if "entity" in block: entity_category = block["entity"] if "id" in entity_category and "type" in entity_category: entity_ids = entity_category["id"] entity_types = entity_category["type"] # Convert CIFColumn to list for iteration if hasattr(entity_ids, "__iter__") and hasattr( entity_types, "__iter__" ): # Type annotation to help pyright understand these are iterable entity_ids_list = list(entity_ids) # type: ignore entity_types_list = list(entity_types) # type: ignore for eid, etype in zip(entity_ids_list, entity_types_list): entity_info[eid] = etype except Exception: pass # Initialize arrays for flat atom representation sequence_tokens = [] flat_positions = [] flat_elements = [] flat_names = [] flat_hetero = [] token_to_atoms = [] confidence_scores = [] chain_ids = [] # Track chain IDs for each token atom_idx = 0 # Group atoms by chain and residue. # Use label_asym_id (distinct per entity) when available, otherwise # fall back to biotite's chain_id (auth_asym_id). chain_residue_groups: dict[str, dict[tuple[int, str], dict]] = {} for atom_i, atom in enumerate(structure): chain_id = ( label_asym_ids[atom_i] if label_asym_ids is not None else atom.chain_id ) res_id = atom.res_id res_name = atom.res_name if chain_id not in chain_residue_groups: chain_residue_groups[chain_id] = {} # Key by (res_id, res_name) to distinguish residues that share # the same res_id but have different res_name (e.g. a protein # residue and a ligand that were on the same auth chain). res_key = (res_id, res_name) if res_key not in chain_residue_groups[chain_id]: chain_residue_groups[chain_id][res_key] = { "atoms": [], "res_name": res_name, "is_hetero": atom.hetero, } chain_residue_groups[chain_id][res_key]["atoms"].append(atom) # Create a mapping from chain_id to numeric indices chain_id_to_numeric = { chain_id: idx for idx, chain_id in enumerate(sorted(chain_residue_groups.keys())) } # Process each chain and residue for chain_id in sorted(chain_residue_groups.keys()): residues = chain_residue_groups[chain_id] numeric_chain_id = chain_id_to_numeric[chain_id] for res_key in sorted(residues.keys()): residue_data = residues[res_key] res_name = residue_data["res_name"] atoms = residue_data["atoms"] is_hetero = residue_data["is_hetero"] # Skip water molecules if res_name == "HOH": continue # Determine token name if not is_hetero and res_name in residue_constants.restype_3to1: # Standard amino acid token_name = res_name elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]: # Nucleotide token_name = res_name else: # Ligand or other molecule token_name = res_name sequence_tokens.append(token_name) chain_ids.append( numeric_chain_id ) # Store the numeric chain ID for this token token_start = atom_idx # Add all atoms from this residue for atom in atoms: flat_positions.append(atom.coord) # Get element character element = atom.element flat_elements.append(element) # Get atom name atom_name = atom.atom_name flat_names.append(atom_name) # Get hetero flag hetero_flag = atom.hetero flat_hetero.append(hetero_flag) atom_idx += 1 # Record token-to-atom mapping token_to_atoms.append([token_start, atom_idx]) # Add confidence score (B-factor if available, otherwise 1.0) bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0 confidence_scores.append(min(bfactor / 100.0, 1.0)) # Convert to numpy arrays if not flat_positions: # Create minimal arrays if no atoms found atom_positions = np.zeros((0, 3), dtype=np.float32) atom_elements = np.zeros(0, dtype=object) atom_names = np.zeros(0, dtype=object) atom_hetero = np.zeros(0, dtype=bool) token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32) chain_id_array = ( np.array(chain_ids, dtype=np.int64) if chain_ids else np.zeros(len(sequence_tokens), dtype=np.int64) ) else: atom_positions = np.array(flat_positions, dtype=np.float32) atom_elements = np.array(flat_elements, dtype=object) atom_names = np.array(flat_names, dtype=object) atom_hetero = np.array(flat_hetero, dtype=bool) token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32) chain_id_array = np.array(chain_ids, dtype=np.int64) confidence_array = np.array(confidence_scores, dtype=np.float32) # Create metadata using the chain_id_to_numeric mapping if chain_residue_groups: chain_lookup = { numeric_id: chain_id for chain_id, numeric_id in chain_id_to_numeric.items() } else: chain_lookup = {} metadata = MolecularComplexMetadata( entity_lookup=entity_info, chain_lookup=chain_lookup, assembly_composition=None, ) # Set complex ID - if input was a path, use the stem; otherwise use default if os.path.exists(inp): complex_id = id or Path(inp).stem else: complex_id = id or "complex_from_string" return cls( id=complex_id, sequence=sequence_tokens, atom_positions=atom_positions, atom_elements=atom_elements, token_to_atoms=token_to_atoms_array, chain_id=chain_id_array, plddt=confidence_array, metadata=metadata, atom_names=atom_names, atom_hetero=atom_hetero, ) def _get_entity_mapping( self, ) -> tuple[dict[str, list[str]], dict[str, int], dict[int, tuple[str, ...]]]: """Compute chain→sequence, chain→entity_id, and entity_id→sequence mappings. Returns: (chain_sequences, chain_to_entity, entity_sequences) """ chain_sequences: dict[str, list[str]] = {} for token_idx in range(len(self.token_to_atoms)): chain_id_numeric = self.chain_id[token_idx] chain_id_str = self.metadata.chain_lookup.get( int(chain_id_numeric), chr(65 + int(chain_id_numeric)) ) if chain_id_str not in chain_sequences: chain_sequences[chain_id_str] = [] chain_sequences[chain_id_str].append(self.sequence[token_idx]) sequence_to_entity: dict[tuple[str, ...], int] = {} chain_to_entity: dict[str, int] = {} entity_sequences: dict[int, tuple[str, ...]] = {} entity_id_counter = 1 for chain_id_str, sequence in chain_sequences.items(): seq_tuple = tuple(sequence) if seq_tuple not in sequence_to_entity: sequence_to_entity[seq_tuple] = entity_id_counter entity_sequences[entity_id_counter] = seq_tuple entity_id_counter += 1 chain_to_entity[chain_id_str] = sequence_to_entity[seq_tuple] return chain_sequences, chain_to_entity, entity_sequences def _add_entity_information( self, cif_file: CIFFile, entity_sequences: dict[int, tuple[str, ...]] ) -> None: """Add _entity category to CIF file so OST can identify ligands vs polymers.""" entity_ids: list[str] = [] entity_types: list[str] = [] entity_descriptions: list[str] = [] for eid in sorted(entity_sequences.keys()): seq = entity_sequences[eid] entity_ids.append(str(eid)) has_protein = any(t in residue_constants.restype_3to1 for t in seq) has_na = any( t in ("A", "T", "G", "C", "U", "DA", "DT", "DG", "DC") for t in seq ) if has_protein or has_na: entity_types.append("polymer") if has_protein: entity_descriptions.append(f"Polymer entity {eid} (protein)") else: entity_descriptions.append(f"Polymer entity {eid} (nucleic acid)") else: entity_types.append("non-polymer") entity_descriptions.append(f"Non-polymer entity {eid}") if entity_ids: cif_file.block["entity"] = CIFCategory( name="entity", columns={ "id": CIFColumn( data=CIFData(array=np.array(entity_ids), dtype=np.str_) ), "type": CIFColumn( data=CIFData(array=np.array(entity_types), dtype=np.str_) ), "pdbx_description": CIFColumn( data=CIFData(array=np.array(entity_descriptions), dtype=np.str_) ), }, ) # Add _struct_asym to map chain IDs to entity IDs _, chain_to_entity, _ = self._get_entity_mapping() if chain_to_entity: asym_ids = sorted(chain_to_entity.keys()) asym_entity_ids = [str(chain_to_entity[c]) for c in asym_ids] cif_file.block["struct_asym"] = CIFCategory( name="struct_asym", columns={ "id": CIFColumn( data=CIFData(array=np.array(asym_ids), dtype=np.str_) ), "entity_id": CIFColumn( data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_) ), }, ) def to_mmcif(self) -> str: """Write MolecularComplex to mmcif string using biotite. Returns: String representation of the complex in mmCIF format """ # Pre-allocate AtomArray n_atoms = len(self.atom_positions) atom_array = bs.AtomArray(length=n_atoms) # Set coordinates directly (already vectorized) atom_array.coord = self.atom_positions # Pre-allocate per-atom arrays atom_res_ids = np.zeros(n_atoms, dtype=np.int32) atom_chain_ids = np.empty(n_atoms, dtype=object) atom_res_names = np.empty(n_atoms, dtype=object) atom_hetero = np.zeros(n_atoms, dtype=bool) atom_bfactors = np.zeros(n_atoms, dtype=np.float32) atom_names = np.empty(n_atoms, dtype=object) # Build entity mappings: chains with identical sequences share entity ID _, chain_to_entity, entity_sequences = self._get_entity_mapping() atom_entity_ids = np.zeros(n_atoms, dtype=np.int32) # Track residue IDs per chain chain_res_counters: dict[int, int] = {} # Vectorized expansion of token-level to atom-level annotations for token_idx, (start, end) in enumerate(self.token_to_atoms): token = self.sequence[token_idx] chain_id_numeric = self.chain_id[token_idx] chain_id_str = self.metadata.chain_lookup.get( int(chain_id_numeric), chr(65 + int(chain_id_numeric)) ) # Track residue numbering per chain if chain_id_numeric not in chain_res_counters: chain_res_counters[chain_id_numeric] = 1 res_id = chain_res_counters[chain_id_numeric] chain_res_counters[chain_id_numeric] += 1 # Determine if protein is_protein = token in residue_constants.restype_3to1 # Get atom names for this residue if self.atom_names is not None: # Use stored atom names (preserves original names from mmCIF) names = list(self.atom_names[start:end]) elif is_protein: # Fallback: use standard protein atom names standard_names = residue_constants.residue_atoms.get( token, ["N", "CA", "C", "O"] ) names = standard_names[: end - start] # Pad if needed while len(names) < (end - start): names.append(f"X{len(names)+1}") else: # Fallback: generate names for ligands/nucleic acids names = [f"C{i+1}" for i in range(end - start)] # Vectorized assignment for this token's atoms atom_res_ids[start:end] = res_id atom_chain_ids[start:end] = chain_id_str atom_res_names[start:end] = token # Use stored hetero flags if available, otherwise guess based on protein status if self.atom_hetero is not None: atom_hetero[start:end] = self.atom_hetero[start:end] else: atom_hetero[start:end] = not is_protein atom_bfactors[start:end] = self.plddt[token_idx] * 100.0 atom_names[start:end] = names atom_entity_ids[start:end] = chain_to_entity.get(chain_id_str, 1) # Set all AtomArray attributes at once (convert object arrays to proper string arrays) # res_name uses U8 to accommodate CCD codes up to 5 characters (e.g., A1AZ2); # chain_id uses U16 because chain names like ``ligand_1`` / ``ligand_2`` / # auth-asym ids of arbitrary length are possible. atom_array.res_id = atom_res_ids atom_array.chain_id = np.array(atom_chain_ids, dtype="U16") atom_array.res_name = np.array(atom_res_names, dtype="U8") atom_array.hetero = atom_hetero atom_array.atom_name = np.array(atom_names, dtype="U4") atom_array.add_annotation("b_factor", dtype=float) atom_array.b_factor = atom_bfactors atom_array.add_annotation("entity_id", dtype=int) atom_array.entity_id = atom_entity_ids # Use existing elements or infer them from atom names if self.atom_elements is not None and len(self.atom_elements) == n_atoms: # Convert object array to proper string array for biotite atom_array.element = np.array(self.atom_elements, dtype="U4") else: # Use biotite's built-in element inference atom_array.element = bs.infer_elements(atom_array) # Create CIF file and set structure cif_file = CIFFile() set_structure(cif_file, atom_array, data_block=self.id) # Manually fix label_entity_id (biotite doesn't use entity_id annotation correctly) if "atom_site" in cif_file.block: atom_site = cif_file.block["atom_site"] if "label_asym_id" in atom_site and "label_entity_id" in atom_site: label_asym_ids = atom_site["label_asym_id"] if hasattr(label_asym_ids, "as_array"): chain_ids_list = label_asym_ids.as_array(str).tolist() elif hasattr(label_asym_ids, "__iter__"): chain_ids_list = list(label_asym_ids) # type: ignore[arg-type] else: chain_ids_list = [] updated_entity_ids = [ str(chain_to_entity.get(cid, 1)) for cid in chain_ids_list ] if updated_entity_ids: atom_site["label_entity_id"] = CIFColumn( data=CIFData(array=np.array(updated_entity_ids), dtype=np.str_) ) # Add _entity category for OST compatibility self._add_entity_information(cif_file, entity_sequences) # Convert to string output = io.StringIO() cif_file.write(output) return output.getvalue() def dockq(self, native: "MolecularComplex") -> Any: """Compute DockQ score against native structure. Args: native: Native MolecularComplex to compute DockQ against Returns: DockQ result containing score and alignment information """ # Imports moved to top of file # Convert both complexes to ProteinComplex format for DockQ computation # This extracts only the protein portion and converts to PDB format try: self_pc = self.to_protein_complex() native_pc = native.to_protein_complex() except ValueError as e: raise ValueError( f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" ) # Normalize chain IDs for PDB compatibility self_pc = self_pc.normalize_chain_ids_for_pdb() native_pc = native_pc.normalize_chain_ids_for_pdb() # Use the existing ProteinComplex.dockq() method try: dockq_result = self_pc.dockq(native_pc) return dockq_result except Exception: # Fallback to manual DockQ computation if ProteinComplex.dockq() fails return self._compute_dockq_manual(native) def _compute_dockq_manual(self, native: "MolecularComplex") -> Any: """Manual DockQ computation fallback.""" # Imports moved to top of file # Convert both complexes to ProteinComplex format try: self_pc = self.to_protein_complex() native_pc = native.to_protein_complex() except ValueError as e: raise ValueError( f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}" ) # Normalize chain IDs for PDB compatibility self_pc = self_pc.normalize_chain_ids_for_pdb() native_pc = native_pc.normalize_chain_ids_for_pdb() # Write temporary PDB files and run DockQ with TemporaryDirectory() as tdir: dir_path = Path(tdir) self_pdb = dir_path / "self.pdb" native_pdb = dir_path / "native.pdb" # Write PDB files self_pc.to_pdb(self_pdb) native_pc.to_pdb(native_pdb) # Run DockQ try: output = check_output(["DockQ", str(self_pdb), str(native_pdb)]) output_text = output.decode() # Parse DockQ output lines = output_text.split("\n") # Find the total DockQ score dockq_score = None for line in lines: if "Total DockQ" in line: match = re.search(r"Total DockQ.*: ([\d.]+)", line) if match: dockq_score = float(match.group(1)) break if dockq_score is None: # Try to find individual DockQ scores for line in lines: if line.startswith("DockQ") and ":" in line: try: dockq_score = float(line.split(":")[1].strip()) break except (ValueError, IndexError): continue if dockq_score is None: raise ValueError("Could not parse DockQ score from output") # Return a simple result structure return { "total_dockq": dockq_score, "raw_output": output_text, "aligned": self, # Return self as aligned structure } except FileNotFoundError: raise RuntimeError( "DockQ is not installed. Please install DockQ to use this method." ) except Exception as e: raise RuntimeError(f"DockQ computation failed: {e}") def rmsd(self, target: "MolecularComplex", **kwargs) -> float: """Compute RMSD against target structure. Args: target: Target MolecularComplex to compute RMSD against **kwargs: Additional arguments passed to compute_rmsd Returns: float: RMSD value between the two structures """ # Imports moved to top of file # Ensure both complexes have the same number of tokens if len(self) != len(target): raise ValueError( f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" ) # Extract center positions for each token (using centroid of atoms) mobile_coords = [] target_coords = [] atom_mask = [] for i in range(len(self)): # Get atom positions for this token mobile_start, mobile_end = self.token_to_atoms[i] target_start, target_end = target.token_to_atoms[i] # Extract atom positions mobile_atoms = self.atom_positions[mobile_start:mobile_end] target_atoms = target.atom_positions[target_start:target_end] # Check if both tokens have atoms if len(mobile_atoms) == 0 or len(target_atoms) == 0: # Skip tokens with no atoms continue # For simplicity, use the centroid of atoms as the representative position mobile_center = mobile_atoms.mean(axis=0) target_center = target_atoms.mean(axis=0) mobile_coords.append(mobile_center) target_coords.append(target_center) atom_mask.append(True) if len(mobile_coords) == 0: raise ValueError("No valid atoms found for RMSD computation") # Convert to tensors mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( 0 ) # [1, N, 3] target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( 0 ) # [1, N, 3] mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] # Compute RMSD using existing infrastructure rmsd_value = compute_rmsd( mobile=mobile_tensor, target=target_tensor, atom_exists_mask=mask_tensor, reduction="batch", **kwargs, ) return float(rmsd_value) def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float: """Compute LDDT score against target structure. Args: target: Target MolecularComplex to compute LDDT against **kwargs: Additional arguments passed to compute_lddt Returns: float: LDDT value between the two structures """ # Imports moved to top of file # Ensure both complexes have the same number of tokens if len(self) != len(target): raise ValueError( f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}" ) # Extract center positions for each token (using centroid of atoms) mobile_coords = [] target_coords = [] atom_mask = [] for i in range(len(self)): # Get atom positions for this token mobile_start, mobile_end = self.token_to_atoms[i] target_start, target_end = target.token_to_atoms[i] # Extract atom positions mobile_atoms = self.atom_positions[mobile_start:mobile_end] target_atoms = target.atom_positions[target_start:target_end] # Check if both tokens have atoms if len(mobile_atoms) == 0 or len(target_atoms) == 0: # Skip tokens with no atoms mobile_coords.append(np.full(3, np.nan)) target_coords.append(np.full(3, np.nan)) atom_mask.append(False) continue # For simplicity, use the centroid of atoms as the representative position mobile_center = mobile_atoms.mean(axis=0) target_center = target_atoms.mean(axis=0) mobile_coords.append(mobile_center) target_coords.append(target_center) atom_mask.append(True) if not any(atom_mask): raise ValueError("No valid atoms found for LDDT computation") # Convert to tensors mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze( 0 ) # [1, N, 3] target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze( 0 ) # [1, N, 3] mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N] # Compute LDDT using existing infrastructure lddt_value = compute_lddt( all_atom_pred_pos=mobile_tensor, all_atom_positions=target_tensor, all_atom_mask=mask_tensor, per_residue=False, # Return overall LDDT score **kwargs, ) return float(lddt_value) def state_dict(self): """This state dict is optimized for storage, so it turns things to fp16 whenever possible and converts numpy arrays to lists for JSON serialization. """ dct = {k: v for k, v in vars(self).items()} for k, v in dct.items(): if isinstance(v, np.ndarray): match v.dtype: case np.int64: dct[k] = v.astype(np.int32).tolist() case np.float64 | np.float32: dct[k] = v.astype(np.float16).tolist() case _: dct[k] = v.tolist() elif isinstance(v, MolecularComplexMetadata): dct[k] = asdict(v) return dct def to_blob(self) -> bytes: return brotli.compress(msgpack.dumps(self.state_dict()), quality=5) @classmethod def from_state_dict(cls, dct): for k, v in dct.items(): if isinstance(v, list) and k in [ "atom_positions", "atom_elements", "atom_names", "atom_hetero", "token_to_atoms", "chain_id", "plddt", ]: dct[k] = np.array(v) for k, v in dct.items(): if isinstance(v, np.ndarray): if k in ["atom_positions", "plddt"]: dct[k] = v.astype(np.float32) elif k in ["token_to_atoms", "chain_id"]: dct[k] = ( v.astype(np.int32) if k == "token_to_atoms" else v.astype(np.int64) ) dct["metadata"] = MolecularComplexMetadata(**dct["metadata"]) # Backward compatibility: if chain_id is missing, create default array if "chain_id" not in dct: # Default all tokens to chain 0 dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64) return cls(**dct) @classmethod def from_blob(cls, input: Path | str | io.BytesIO | bytes): match input: case Path() | str(): bytes = Path(input).read_bytes() case io.BytesIO(): bytes = input.getvalue() case _: bytes = input return cls.from_state_dict( msgpack.loads(brotli.decompress(bytes), strict_map_key=False) )