Spaces:
Sleeping
Sleeping
File size: 6,696 Bytes
8ab4ff2 f6e8c3d 8ab4ff2 f6e8c3d 8ab4ff2 f6e8c3d 8ab4ff2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | from __future__ import annotations
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
# Tiny hard-coded structure that StubBackend hands back in place of a real
# prediction: N/CA/C/O atoms for three residues (ALA, GLY, SER) on chain A,
# followed by TER/END records.
# NOTE(review): the ATOM records here are separated by single spaces, whereas
# PDB ATOM records are normally fixed-column. Confirm the spacing against the
# original file before feeding this text to a strict PDB parser.
DEMO_PDB = """HEADER CARBON FOLDING API STUB
ATOM 1 N ALA A 1 -0.500 1.300 0.000 1.00 80.00 N
ATOM 2 CA ALA A 1 0.000 0.000 0.000 1.00 80.00 C
ATOM 3 C ALA A 1 1.520 0.000 0.000 1.00 80.00 C
ATOM 4 O ALA A 1 2.110 -1.060 0.000 1.00 80.00 O
ATOM 5 N GLY A 2 2.160 1.170 0.000 1.00 82.00 N
ATOM 6 CA GLY A 2 3.600 1.260 0.000 1.00 82.00 C
ATOM 7 C GLY A 2 4.160 2.660 0.000 1.00 82.00 C
ATOM 8 O GLY A 2 3.480 3.660 0.000 1.00 82.00 O
ATOM 9 N SER A 3 5.430 2.730 0.000 1.00 76.00 N
ATOM 10 CA SER A 3 6.080 4.030 0.000 1.00 76.00 C
ATOM 11 C SER A 3 7.600 3.910 0.000 1.00 76.00 C
ATOM 12 O SER A 3 8.250 4.920 0.000 1.00 76.00 O
TER
END
"""
@dataclass(frozen=True)
class FoldOutput:
    """Immutable result record produced by a folding backend.

    The attributes cannot be rebound after construction (``frozen=True``);
    the containers they hold are ordinary mutable ``dict``/``list`` objects.
    """

    # Predicted structure as PDB-format text.
    pdb: str
    # Confidence values keyed by name; the backends in this module store
    # "mean_plddt" here.
    confidence: dict[str, Any]
    # Run statistics keyed by name, e.g. "runtime_seconds", "sequence_length".
    metrics: dict[str, Any]
    # Human-readable caveats about the run; may be empty.
    warnings: list[str]
class FoldingBackend(ABC):
    """Abstract interface every structure-prediction backend implements."""

    @abstractmethod
    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Predict a structure for ``sequence`` and return it as a ``FoldOutput``.

        Args:
            sequence: The sequence to fold.
            options: Backend-specific settings supplied by the caller.

        Raises:
            NotImplementedError: Always, when this base implementation is
                reached (for example through ``super().fold(...)``).
        """
        raise NotImplementedError
class StubBackend(FoldingBackend):
    """Backend that skips prediction and always returns the demo structure."""

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Return ``DEMO_PDB`` after a short pause that grows with sequence length.

        Args:
            sequence: Only its length is used (for the pause and the metrics).
            options: Ignored.

        Returns:
            A ``FoldOutput`` with a fixed mean pLDDT of 80.0 and a warning
            stating that the structure is a demo.
        """
        del options  # accepted for interface compatibility only
        began = time.monotonic()
        residue_count = len(sequence)

        # Pause for length / 10_000 seconds, clamped to the range [0, 0.1].
        pause_seconds = max(0.0, residue_count / 10_000)
        pause_seconds = min(0.1, pause_seconds)
        time.sleep(pause_seconds)

        elapsed = round(time.monotonic() - began, 4)
        return FoldOutput(
            pdb=DEMO_PDB,
            confidence={"mean_plddt": 80.0},
            metrics={"runtime_seconds": elapsed, "sequence_length": residue_count},
            warnings=["stub backend returned a demo structure"],
        )
class EsmFoldBackend(FoldingBackend):
    """Backend that runs an ESMFold checkpoint through Transformers.

    Nothing heavy happens at construction time: the tokenizer and model are
    loaded on the first ``fold`` call and reused by later calls.
    """

    def __init__(self, model_id: str = "facebook/esmfold_v1") -> None:
        """Remember which checkpoint to load; the load itself is deferred."""
        self.model_id = model_id
        self._loaded = False
        self._device = None
        self._tokenizer = None
        self._model = None

    def _load(self) -> None:
        """Load the tokenizer and model once and place the model on a device.

        CUDA is used when ``torch.cuda.is_available()`` reports it, otherwise
        the CPU. Calls after the first successful load return immediately.
        """
        if self._loaded:
            return

        # Imported here so that importing this module does not require torch
        # or transformers.
        import torch
        from transformers import AutoTokenizer, EsmForProteinFolding

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = torch.device(device_name)
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        model = EsmForProteinFolding.from_pretrained(
            self.model_id,
            low_cpu_mem_usage=True,
        )
        self._model = model
        model.eval()
        model.to(self._device)

        # Reduce memory use for longer demo proteins. This is supported by the
        # Transformers ESMFold implementation and is a no-op if unavailable.
        trunk = getattr(model, "trunk", None)
        if hasattr(trunk, "set_chunk_size"):
            chunk_size = int(os.getenv("ESMFOLD_CHUNK_SIZE", "64"))
            trunk.set_chunk_size(chunk_size)

        self._loaded = True

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Run ESMFold on ``sequence`` and package the prediction.

        Args:
            sequence: The sequence passed to the tokenizer as a batch of one.
            options: Ignored.

        Returns:
            A ``FoldOutput`` holding the PDB text, the mean pLDDT, timing and
            device metrics, and warnings for CPU execution or a mean pLDDT
            below 50.
        """
        del options  # accepted for interface compatibility only
        began = time.monotonic()
        self._load()

        import torch

        # _load() populates these; the asserts narrow the Optional types.
        assert self._device is not None
        assert self._tokenizer is not None
        assert self._model is not None

        encoded = self._tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
        model_inputs = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        with torch.no_grad():
            prediction = self._model(**model_inputs)

        pdb_text = _esmfold_output_to_pdb(prediction)
        mean_plddt = _mean_plddt(prediction)
        elapsed = time.monotonic() - began

        device_kind = self._device.type
        notes = []
        if device_kind != "cuda":
            notes.append("ESMFold ran on CPU; GPU is recommended")
        if mean_plddt is not None and mean_plddt < 50:
            notes.append("low mean pLDDT; predicted structure may be unreliable")

        return FoldOutput(
            pdb=pdb_text,
            confidence={"mean_plddt": mean_plddt},
            metrics={
                "runtime_seconds": round(elapsed, 4),
                "sequence_length": len(sequence),
                "device": device_kind,
            },
            warnings=notes,
        )
def _as_mapping(output: Any) -> dict[str, Any]:
    """Expose a model output as a plain key/value mapping.

    Three shapes are accepted, checked in this order:

    * a ``dict`` (or subclass), returned as-is without copying;
    * an object offering both ``to_tuple`` and ``keys``, copied key by key
      through ``output[key]``;
    * any object with a ``__dict__``, whose attributes not starting with an
      underscore are copied.

    Raises:
        TypeError: If ``output`` matches none of the shapes above.
    """
    if isinstance(output, dict):
        return output

    behaves_like_model_output = hasattr(output, "to_tuple") and hasattr(output, "keys")
    if behaves_like_model_output:
        return {key: output[key] for key in output.keys()}

    if not hasattr(output, "__dict__"):
        raise TypeError("unsupported ESMFold output type")

    public_attrs = {}
    for name, value in vars(output).items():
        if name.startswith("_"):
            continue
        public_attrs[name] = value
    return public_attrs
def _esmfold_output_to_pdb(output: Any) -> str:
    """Render the first entry of an ESMFold output as PDB text.

    The output is read through ``_as_mapping`` and must provide
    ``positions``, ``atom37_atom_exists``, ``plddt``, ``aatype`` and
    ``residue_index``; ``chain_index`` is optional. pLDDT values are written
    as B-factors and are multiplied by 100 when their maximum is at most 1.5.
    """
    import torch
    from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
    from transformers.models.esm.openfold_utils.protein import Protein as OpenFoldProtein
    from transformers.models.esm.openfold_utils.protein import to_pdb

    fields = _as_mapping(output)

    # Expand the last entry of "positions" while the values are still tensors.
    # NOTE(review): presumably the last entry is the final structure-module
    # iteration; confirm against the Transformers ESMFold output docs.
    atom37_positions = atom14_to_atom37(fields["positions"][-1], fields)

    # Host-side copy of every field: tensors become NumPy arrays, anything
    # else is carried over unchanged.
    host = {
        name: value.detach().cpu().numpy() if torch.is_tensor(value) else value
        for name, value in fields.items()
    }
    atom37_positions = atom37_positions.detach().cpu().numpy()

    atom_exists = host["atom37_atom_exists"]

    plddt_first = host["plddt"][0]
    if float(plddt_first.max()) <= 1.5:
        # Values look like fractions; rescale to the 0-100 range.
        plddt_first = plddt_first * 100.0

    protein = OpenFoldProtein(
        aatype=host["aatype"][0],
        atom_positions=atom37_positions[0],
        atom_mask=atom_exists[0],
        # NOTE(review): the +1 presumably turns 0-based indices into 1-based
        # residue numbers; confirm.
        residue_index=host["residue_index"][0] + 1,
        b_factors=plddt_first,
        chain_index=host.get("chain_index", [None])[0],
    )
    return to_pdb(protein)
def _mean_plddt(output: Any) -> float | None:
    """Return the mean pLDDT of an ESMFold output on a 0-100 scale.

    When the output also carries an ``atom37_atom_exists`` mask of the same
    type and shape as ``plddt``, the mean is weighted by that mask, so atom
    slots that do not exist (and are never written to the PDB) do not
    influence the score. Otherwise, or when the mask is all zero, the plain
    mean over every element is used.

    Args:
        output: Model output in any form ``_as_mapping`` accepts.

    Returns:
        The mean rounded to four decimals, or ``None`` when the output has no
        ``plddt`` entry. A mean of at most 1.5 is treated as a fraction and
        multiplied by 100.
    """
    data = _as_mapping(output)
    plddt = data.get("plddt")
    if plddt is None:
        return None

    if hasattr(plddt, "detach"):
        # Tensor input: drop autograd tracking and use float32 for the math.
        plddt = plddt.detach().float()

    value = None
    mask = data.get("atom37_atom_exists")
    if hasattr(mask, "detach"):
        mask = mask.detach().float()
    # Only apply the mask when it lines up element-for-element with plddt.
    mask_usable = (
        mask is not None
        and isinstance(mask, type(plddt))
        and getattr(mask, "shape", None) == getattr(plddt, "shape", None)
    )
    if mask_usable:
        present = float(mask.sum())
        if present > 0:
            value = float((plddt * mask).sum()) / present

    if value is None:
        value = float(plddt.mean())

    if value <= 1.5:
        value *= 100.0
    return round(value, 4)
def make_backend() -> FoldingBackend:
    """Build the folding backend selected by the ``FOLD_BACKEND`` variable.

    The value is stripped and lower-cased before matching and defaults to
    ``"esmfold"`` when the variable is unset. For ``"esmfold"`` the checkpoint
    id comes from ``ESMFOLD_MODEL_ID`` (default ``"facebook/esmfold_v1"``).

    Raises:
        ValueError: If the normalized value is neither ``"stub"`` nor
            ``"esmfold"``.
    """
    requested = os.environ.get("FOLD_BACKEND", "esmfold")
    backend = requested.strip().lower()

    if backend == "esmfold":
        model_id = os.environ.get("ESMFOLD_MODEL_ID", "facebook/esmfold_v1")
        return EsmFoldBackend(model_id)
    if backend == "stub":
        return StubBackend()

    raise ValueError(f"unsupported FOLD_BACKEND: {backend}")
|