from __future__ import annotations

import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
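
# Backbone-only Ala-Gly-Ser demo structure, returned verbatim by StubBackend.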
DEMO_PDB = """HEADER    CARBON FOLDING API STUB
ATOM      1  N   ALA A   1      -0.500   1.300   0.000  1.00 80.00           N
ATOM      2  CA  ALA A   1       0.000   0.000   0.000  1.00 80.00           C
ATOM      3  C   ALA A   1       1.520   0.000   0.000  1.00 80.00           C
ATOM      4  O   ALA A   1       2.110  -1.060   0.000  1.00 80.00           O
ATOM      5  N   GLY A   2       2.160   1.170   0.000  1.00 82.00           N
ATOM      6  CA  GLY A   2       3.600   1.260   0.000  1.00 82.00           C
ATOM      7  C   GLY A   2       4.160   2.660   0.000  1.00 82.00           C
ATOM      8  O   GLY A   2       3.480   3.660   0.000  1.00 82.00           O
ATOM      9  N   SER A   3       5.430   2.730   0.000  1.00 76.00           N
ATOM     10  CA  SER A   3       6.080   4.030   0.000  1.00 76.00           C
ATOM     11  C   SER A   3       7.600   3.910   0.000  1.00 76.00           C
ATOM     12  O   SER A   3       8.250   4.920   0.000  1.00 76.00           O
TER
END
"""


@dataclass
class FoldOutput:
    pdb: str
    confidence: dict[str, Any]
    metrics: dict[str, Any]
    warnings: list[str]


class FoldingBackend(ABC):
    @abstractmethod
    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        raise NotImplementedError


class StubBackend(FoldingBackend):
    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        del options  # the stub ignores all folding options
        started = time.monotonic()
        # Simulate a little work, scaled by sequence length and capped at
        # 100 ms, so the timing metrics look plausible.
        time.sleep(min(0.1, max(0.0, len(sequence) / 10_000)))
        return FoldOutput(
            pdb=DEMO_PDB,
            confidence={"mean_plddt": 80.0},
            metrics={
                "runtime_seconds": round(time.monotonic() - started, 4),
                "sequence_length": len(sequence),
            },
            warnings=["stub backend returned a demo structure"],
        )


class EsmFoldBackend(FoldingBackend):
    def __init__(self, model_id: str = "facebook/esmfold_v1") -> None:
        self.model_id = model_id
        self._loaded = False
        self._device = None
        self._tokenizer = None
        self._model = None

    def _load(self) -> None:
        """Lazily load torch and the ESMFold weights on first use."""
        if self._loaded:
            return
        # Imported here so the stub backend stays usable without torch or
        # transformers installed.
        import torch
        from transformers import AutoTokenizer, EsmForProteinFolding

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = EsmForProteinFolding.from_pretrained(
            self.model_id,
            low_cpu_mem_usage=True,
        )
        self._model.eval()
        self._model.to(self._device)
        # Reduce memory use for longer demo proteins. This is supported by the
        # Transformers ESMFold implementation and is a no-op if unavailable.
        if hasattr(self._model, "trunk") and hasattr(self._model.trunk, "set_chunk_size"):
            self._model.trunk.set_chunk_size(int(os.getenv("ESMFOLD_CHUNK_SIZE", "64")))
        self._loaded = True

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        del options  # folding options are not used by this backend yet
        started = time.monotonic()
        self._load()
        import torch

        assert self._device is not None
        assert self._tokenizer is not None
        assert self._model is not None
        # ESMFold takes the raw sequence without BOS/EOS special tokens.
        tokenized = self._tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
        tokenized = {key: value.to(self._device) for key, value in tokenized.items()}
        with torch.no_grad():
            output = self._model(**tokenized)
        pdb = _esmfold_output_to_pdb(output)
        mean_plddt = _mean_plddt(output)
        runtime = time.monotonic() - started
        warnings = []
        if self._device.type != "cuda":
            warnings.append("ESMFold ran on CPU; GPU is recommended")
        if mean_plddt is not None and mean_plddt < 50:
            warnings.append("low mean pLDDT; predicted structure may be unreliable")
        return FoldOutput(
            pdb=pdb,
            confidence={"mean_plddt": mean_plddt},
            metrics={
                "runtime_seconds": round(runtime, 4),
                "sequence_length": len(sequence),
                "device": self._device.type,
            },
            warnings=warnings,
        )


def _as_mapping(output: Any) -> dict[str, Any]:
    """Normalize an ESMFold output (dict, transformers ModelOutput, or plain
    object) into a plain dict."""
    if isinstance(output, dict):
        return output
    # transformers ModelOutput exposes both keys() and item access.
    if hasattr(output, "to_tuple") and hasattr(output, "keys"):
        return {key: output[key] for key in output.keys()}
    if hasattr(output, "__dict__"):
        return {key: value for key, value in vars(output).items() if not key.startswith("_")}
    raise TypeError("unsupported ESMFold output type")


def _esmfold_output_to_pdb(output: Any) -> str:
    """Convert a single-sequence ESMFold forward pass into PDB text."""
    import torch
    from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
    from transformers.models.esm.openfold_utils.protein import Protein as OpenFoldProtein
    from transformers.models.esm.openfold_utils.protein import to_pdb

    data = _as_mapping(output)
    # "positions" holds the structure-module trajectory; take the final frame
    # and expand the compact atom14 layout to the full atom37 layout.
    final_atom_positions = atom14_to_atom37(data["positions"][-1], data)
    cpu_data = {}
    for key, value in data.items():
        if torch.is_tensor(value):
            cpu_data[key] = value.detach().cpu().numpy()
        else:
            cpu_data[key] = value
    final_atom_positions = final_atom_positions.detach().cpu().numpy()
    final_atom_mask = cpu_data["atom37_atom_exists"]
    b_factors = cpu_data["plddt"][0]
    # Some builds report pLDDT on a 0-1 scale; normalize to 0-100 for the
    # PDB B-factor column.
    if float(b_factors.max()) <= 1.5:
        b_factors = b_factors * 100.0
    # Index [0] throughout: this helper assumes a batch of one sequence.
    protein = OpenFoldProtein(
        aatype=cpu_data["aatype"][0],
        atom_positions=final_atom_positions[0],
        atom_mask=final_atom_mask[0],
        residue_index=cpu_data["residue_index"][0] + 1,
        b_factors=b_factors,
        chain_index=cpu_data.get("chain_index", [None])[0],
    )
    return to_pdb(protein)


def _mean_plddt(output: Any) -> float | None:
    data = _as_mapping(output)
    plddt = data.get("plddt")
    if plddt is None:
        return None
    if hasattr(plddt, "detach"):  # torch tensor
        value = float(plddt.detach().float().mean().cpu().item())
    else:  # numpy array or similar
        value = float(plddt.mean())
    # Normalize a 0-1 pLDDT to the conventional 0-100 scale.
    if value <= 1.5:
        value *= 100.0
    return round(value, 4)


def make_backend() -> FoldingBackend:
    """Select a backend from the FOLD_BACKEND env var (default: esmfold)."""
    backend = os.getenv("FOLD_BACKEND", "esmfold").strip().lower()
    if backend == "stub":
        return StubBackend()
    if backend == "esmfold":
        return EsmFoldBackend(os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1"))
    raise ValueError(f"unsupported FOLD_BACKEND: {backend}")
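

# A minimal smoke-test sketch (an addition, not part of the original module):
# it forces the stub backend so it runs without torch or transformers
# installed, then prints the metrics and the demo PDB header.
if __name__ == "__main__":
    os.environ.setdefault("FOLD_BACKEND", "stub")
    result = make_backend().fold("AGS", options={})
    print(result.metrics)
    print(result.pdb.splitlines()[0])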