"""Protein folding backends.

Defines the ``FoldingBackend`` interface, a stub backend that returns a canned
structure, an ESMFold backend built on the Hugging Face Transformers
implementation, and the ``make_backend`` factory that picks one of them from
the ``FOLD_BACKEND`` environment variable.
"""

from __future__ import annotations

import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

# Canned three-residue (ALA-GLY-SER) backbone returned by ``StubBackend``.
# Records use the fixed-column PDB layout so column-based parsers can read them.
DEMO_PDB = """HEADER    CARBON FOLDING API STUB
ATOM      1  N   ALA A   1      -0.500   1.300   0.000  1.00 80.00           N
ATOM      2  CA  ALA A   1       0.000   0.000   0.000  1.00 80.00           C
ATOM      3  C   ALA A   1       1.520   0.000   0.000  1.00 80.00           C
ATOM      4  O   ALA A   1       2.110  -1.060   0.000  1.00 80.00           O
ATOM      5  N   GLY A   2       2.160   1.170   0.000  1.00 82.00           N
ATOM      6  CA  GLY A   2       3.600   1.260   0.000  1.00 82.00           C
ATOM      7  C   GLY A   2       4.160   2.660   0.000  1.00 82.00           C
ATOM      8  O   GLY A   2       3.480   3.660   0.000  1.00 82.00           O
ATOM      9  N   SER A   3       5.430   2.730   0.000  1.00 76.00           N
ATOM     10  CA  SER A   3       6.080   4.030   0.000  1.00 76.00           C
ATOM     11  C   SER A   3       7.600   3.910   0.000  1.00 76.00           C
ATOM     12  O   SER A   3       8.250   4.920   0.000  1.00 76.00           O
TER
END
"""


@dataclass(frozen=True)
class FoldOutput:
    """Result of one folding request.

    Attributes:
        pdb: Predicted structure as PDB-format text.
        confidence: Confidence values; the backends in this module store
            ``"mean_plddt"`` here.
        metrics: Run metadata; the backends in this module store
            ``"runtime_seconds"`` and ``"sequence_length"`` (ESMFold also
            stores ``"device"``).
        warnings: Human-readable caveats about the prediction.
    """

    pdb: str
    confidence: dict[str, Any]
    metrics: dict[str, Any]
    warnings: list[str]


class FoldingBackend(ABC):
    """Interface every folding backend implements."""

    @abstractmethod
    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Predict a structure for ``sequence`` and return it as a ``FoldOutput``."""
        raise NotImplementedError


class StubBackend(FoldingBackend):
    """Backend that ignores the input and returns ``DEMO_PDB``."""

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Return the demo structure after a short, length-dependent delay.

        Args:
            sequence: Input sequence; only its length is used.
            options: Accepted for interface compatibility and ignored.

        Returns:
            A ``FoldOutput`` wrapping ``DEMO_PDB`` with a fixed mean pLDDT of
            80.0 and a warning stating that the structure is a demo.
        """
        del options
        started = time.monotonic()
        # Simulated latency: grows with sequence length, capped at 0.1 s.
        time.sleep(min(0.1, max(0.0, len(sequence) / 10_000)))
        return FoldOutput(
            pdb=DEMO_PDB,
            confidence={"mean_plddt": 80.0},
            metrics={
                "runtime_seconds": round(time.monotonic() - started, 4),
                "sequence_length": len(sequence),
            },
            warnings=["stub backend returned a demo structure"],
        )


class EsmFoldBackend(FoldingBackend):
    """ESMFold backend using the Transformers ``EsmForProteinFolding`` model.

    The tokenizer and model are loaded on the first ``fold`` call, so creating
    an instance imports neither ``torch`` nor ``transformers``.
    """

    def __init__(self, model_id: str = "facebook/esmfold_v1") -> None:
        """Remember ``model_id``; nothing is loaded until ``fold`` is called."""
        self.model_id = model_id
        self._loaded = False
        self._device: Any = None
        self._tokenizer: Any = None
        self._model: Any = None

    def _load(self) -> None:
        """Load tokenizer and model once and move the model to the chosen device.

        Uses CUDA when ``torch.cuda.is_available()`` is true, otherwise CPU.
        ``ESMFOLD_CHUNK_SIZE`` (default ``"64"``) sets the trunk chunk size.
        """
        if self._loaded:
            return

        import torch
        from transformers import AutoTokenizer, EsmForProteinFolding

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = EsmForProteinFolding.from_pretrained(
            self.model_id,
            low_cpu_mem_usage=True,
        )
        self._model.eval()
        self._model.to(self._device)

        # Reduce memory use for longer demo proteins. This is supported by the
        # Transformers ESMFold implementation and is skipped if unavailable.
        trunk = getattr(self._model, "trunk", None)
        if trunk is not None and hasattr(trunk, "set_chunk_size"):
            trunk.set_chunk_size(int(os.getenv("ESMFOLD_CHUNK_SIZE", "64")))

        # Set last so that a failed load is retried on the next call.
        self._loaded = True

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Predict a structure for ``sequence`` with ESMFold.

        Args:
            sequence: Sequence passed to the tokenizer as a single-item batch.
            options: Accepted for interface compatibility and ignored.

        Returns:
            A ``FoldOutput`` with the PDB text, the mean pLDDT, timing/device
            metrics, and warnings for CPU execution or a mean pLDDT below 50.

        Raises:
            ValueError: If ``sequence`` is empty.
        """
        del options
        # Reject empty input before paying for the model load.
        if not sequence:
            raise ValueError("sequence must not be empty")

        # The timer starts before _load(), so the first call's runtime
        # includes the one-off model load.
        started = time.monotonic()
        self._load()

        import torch

        assert self._device is not None
        assert self._tokenizer is not None
        assert self._model is not None

        tokenized = self._tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
        tokenized = {key: value.to(self._device) for key, value in tokenized.items()}
        with torch.no_grad():
            output = self._model(**tokenized)

        pdb = _esmfold_output_to_pdb(output)
        mean_plddt = _mean_plddt(output)
        runtime = time.monotonic() - started

        warnings: list[str] = []
        if self._device.type != "cuda":
            warnings.append("ESMFold ran on CPU; GPU is recommended")
        if mean_plddt is not None and mean_plddt < 50:
            warnings.append("low mean pLDDT; predicted structure may be unreliable")

        return FoldOutput(
            pdb=pdb,
            confidence={"mean_plddt": mean_plddt},
            metrics={
                "runtime_seconds": round(runtime, 4),
                "sequence_length": len(sequence),
                "device": self._device.type,
            },
            warnings=warnings,
        )


def _as_mapping(output: Any) -> dict[str, Any]:
    """Return the model output as a plain ``dict``.

    Accepts a dict (returned as-is, not copied), an object exposing both
    ``to_tuple`` and ``keys``, or any object with a ``__dict__`` (its public
    attributes are used).

    Raises:
        TypeError: If ``output`` matches none of the supported shapes.
    """
    if isinstance(output, dict):
        return output
    if hasattr(output, "to_tuple") and hasattr(output, "keys"):
        return {key: output[key] for key in output.keys()}
    if hasattr(output, "__dict__"):
        return {key: value for key, value in vars(output).items() if not key.startswith("_")}
    raise TypeError("unsupported ESMFold output type")


def _esmfold_output_to_pdb(output: Any) -> str:
    """Convert an ESMFold model output into PDB text for the first batch item.

    pLDDT values are written as B-factors; when their maximum is at most 1.5
    they are treated as fractions and scaled by 100.
    """
    import torch
    from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
    from transformers.models.esm.openfold_utils.protein import Protein as OpenFoldProtein
    from transformers.models.esm.openfold_utils.protein import to_pdb

    data = _as_mapping(output)
    # Run the atom14 -> atom37 conversion while the inputs are still tensors;
    # only afterwards is everything moved to NumPy.
    # NOTE(review): positions[-1] is presumably the final refinement step.
    final_atom_positions = atom14_to_atom37(data["positions"][-1], data)

    cpu_data = {
        key: value.detach().cpu().numpy() if torch.is_tensor(value) else value
        for key, value in data.items()
    }
    final_atom_positions = final_atom_positions.detach().cpu().numpy()
    final_atom_mask = cpu_data["atom37_atom_exists"]

    b_factors = cpu_data["plddt"][0]
    if float(b_factors.max()) <= 1.5:
        b_factors = b_factors * 100.0

    protein = OpenFoldProtein(
        aatype=cpu_data["aatype"][0],
        atom_positions=final_atom_positions[0],
        atom_mask=final_atom_mask[0],
        # Shift by one so residue numbering starts at 1.
        residue_index=cpu_data["residue_index"][0] + 1,
        b_factors=b_factors,
        chain_index=cpu_data.get("chain_index", [None])[0],
    )
    return to_pdb(protein)


def _mean_plddt(output: Any) -> float | None:
    """Return the mean pLDDT on a 0-100 scale, rounded to four decimals.

    When ``atom37_atom_exists`` is present with the same shape as ``plddt``,
    only atoms flagged as existing contribute to the mean. This is the same
    mask ``_esmfold_output_to_pdb`` applies to the structure, so slots for
    atoms a residue does not have cannot skew the score. Without a usable
    mask the plain mean over all values is returned.

    Returns:
        The mean pLDDT, or ``None`` when the output has no ``plddt`` entry.
    """
    data = _as_mapping(output)
    plddt = data.get("plddt")
    if plddt is None:
        return None

    mask = data.get("atom37_atom_exists")
    # Tensors are detached and upcast so reduced-precision outputs are
    # averaged in float32; array-likes without ``detach`` are used as given.
    if hasattr(plddt, "detach"):
        plddt = plddt.detach().float()
    if hasattr(mask, "detach"):
        mask = mask.detach().float()

    value: float | None = None
    if (
        mask is not None
        and hasattr(mask, "shape")
        and tuple(mask.shape) == tuple(getattr(plddt, "shape", ()))
    ):
        existing = float(mask.sum())
        if existing > 0:
            value = float((plddt * mask).sum()) / existing
    if value is None:
        value = float(plddt.mean())

    # Values at or below 1.5 are taken to be fractions rather than percentages.
    if value <= 1.5:
        value *= 100.0
    return round(value, 4)


def make_backend() -> FoldingBackend:
    """Create the backend named by ``FOLD_BACKEND`` (default ``"esmfold"``).

    The name is stripped and lower-cased before matching. For ESMFold the
    model id comes from ``ESMFOLD_MODEL_ID`` (default
    ``"facebook/esmfold_v1"``).

    Raises:
        ValueError: If ``FOLD_BACKEND`` names an unsupported backend.
    """
    backend = os.getenv("FOLD_BACKEND", "esmfold").strip().lower()
    if backend == "stub":
        return StubBackend()
    if backend == "esmfold":
        return EsmFoldBackend(os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1"))
    raise ValueError(f"unsupported FOLD_BACKEND: {backend}")