# Last commit f6e8c3d (verified) by edbeeching (HF Staff): "Scale ESMFold pLDDT to percent"
from __future__ import annotations
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
def _pdb_atom_record(
    serial: int,
    name: str,
    res_name: str,
    chain_id: str,
    res_seq: int,
    x: float,
    y: float,
    z: float,
    occupancy: float,
    b_factor: float,
    element: str,
) -> str:
    """Format one ``ATOM`` record using the fixed-column PDB layout.

    PDB is a fixed-width format: readers locate serial numbers, coordinates,
    B-factors and element symbols by column position, so whitespace-separated
    fields are not sufficient.
    """
    # Atom names shorter than four characters start in column 14 (one leading
    # space), which is where single-letter elements such as C/N/O belong.
    atom_name = name if len(name) >= 4 else f" {name:<3}"
    return (
        f"ATOM  {serial:>5} {atom_name:<4} {res_name:>3} {chain_id:1}{res_seq:>4}    "
        f"{x:>8.3f}{y:>8.3f}{z:>8.3f}{occupancy:>6.2f}{b_factor:>6.2f}"
        f"{'':10}{element:>2}"
    )


# Backbone atoms of a three-residue (ALA-GLY-SER) demo peptide on chain A:
# (serial, atom name, residue name, residue number, x, y, z, B-factor, element)
_DEMO_ATOMS = (
    (1, "N", "ALA", 1, -0.500, 1.300, 0.000, 80.00, "N"),
    (2, "CA", "ALA", 1, 0.000, 0.000, 0.000, 80.00, "C"),
    (3, "C", "ALA", 1, 1.520, 0.000, 0.000, 80.00, "C"),
    (4, "O", "ALA", 1, 2.110, -1.060, 0.000, 80.00, "O"),
    (5, "N", "GLY", 2, 2.160, 1.170, 0.000, 82.00, "N"),
    (6, "CA", "GLY", 2, 3.600, 1.260, 0.000, 82.00, "C"),
    (7, "C", "GLY", 2, 4.160, 2.660, 0.000, 82.00, "C"),
    (8, "O", "GLY", 2, 3.480, 3.660, 0.000, 82.00, "O"),
    (9, "N", "SER", 3, 5.430, 2.730, 0.000, 76.00, "N"),
    (10, "CA", "SER", 3, 6.080, 4.030, 0.000, 76.00, "C"),
    (11, "C", "SER", 3, 7.600, 3.910, 0.000, 76.00, "C"),
    (12, "O", "SER", 3, 8.250, 4.920, 0.000, 76.00, "O"),
)

# Canned structure returned by the stub backend. Records are generated rather
# than hand-typed so that every ATOM line respects the PDB column layout.
DEMO_PDB = "\n".join(
    [
        "HEADER    CARBON FOLDING API STUB",
        *(
            _pdb_atom_record(serial, name, res_name, "A", res_seq, x, y, z, 1.0, b_factor, element)
            for serial, name, res_name, res_seq, x, y, z, b_factor, element in _DEMO_ATOMS
        ),
        "TER",
        "END",
        "",
    ]
)
@dataclass(eq=True, frozen=True)
class FoldOutput:
    """Immutable result of folding one sequence.

    Attributes:
        pdb: Predicted structure as PDB-format text.
        confidence: Confidence summary; the backends in this module report
            a ``mean_plddt`` entry.
        metrics: Run metadata such as runtime and sequence length.
        warnings: Human-readable caveats about the prediction.
    """

    pdb: str
    confidence: dict[str, Any]
    metrics: dict[str, Any]
    warnings: list[str]
class FoldingBackend(ABC):
    """Abstract interface implemented by every structure-prediction backend."""

    @abstractmethod
    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Predict a structure for ``sequence``.

        Args:
            sequence: Protein sequence to fold.
            options: Backend-specific settings (the backends in this module
                accept but ignore them).

        Returns:
            A ``FoldOutput`` describing the prediction.
        """
        raise NotImplementedError
class StubBackend(FoldingBackend):
    """Offline backend that always answers with the canned ``DEMO_PDB``."""

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Return the demo structure after a short, length-scaled pause.

        ``options`` is accepted only for interface compatibility and ignored.
        The reported confidence and warning are fixed values.
        """
        del options  # unused; present to satisfy the FoldingBackend signature
        t0 = time.monotonic()
        # Simulate work: 0.1 ms per residue, capped at 100 ms.
        pause = len(sequence) / 10_000
        time.sleep(min(0.1, max(0.0, pause)))
        elapsed = round(time.monotonic() - t0, 4)
        metrics = {"runtime_seconds": elapsed, "sequence_length": len(sequence)}
        return FoldOutput(
            pdb=DEMO_PDB,
            confidence={"mean_plddt": 80.0},
            metrics=metrics,
            warnings=["stub backend returned a demo structure"],
        )
class EsmFoldBackend(FoldingBackend):
    """Backend that runs a Hugging Face ``EsmForProteinFolding`` model.

    The tokenizer and model are loaded lazily on the first ``fold`` call, so
    constructing this backend is cheap and does not import torch or
    transformers.
    """

    def __init__(self, model_id: str = "facebook/esmfold_v1") -> None:
        self.model_id = model_id
        # Everything below is populated by _load() on first use.
        self._loaded = False
        self._device = None
        self._tokenizer = None
        self._model = None

    def _load(self) -> None:
        """Load tokenizer and model onto CUDA if available, else CPU (once).

        ``_loaded`` is only set after every step succeeds, so a failed load
        is retried on the next call.
        """
        if not self._loaded:
            import torch
            from transformers import AutoTokenizer, EsmForProteinFolding

            has_cuda = torch.cuda.is_available()
            self._device = torch.device("cuda" if has_cuda else "cpu")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            net = EsmForProteinFolding.from_pretrained(self.model_id, low_cpu_mem_usage=True)
            self._model = net
            net.eval()
            net.to(self._device)
            # Chunk the trunk computation to reduce memory use on longer
            # proteins; skipped when the model exposes no set_chunk_size.
            trunk = getattr(net, "trunk", None)
            if hasattr(trunk, "set_chunk_size"):
                trunk.set_chunk_size(int(os.getenv("ESMFOLD_CHUNK_SIZE", "64")))
            self._loaded = True

    def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
        """Fold ``sequence`` with ESMFold and return PDB text plus metadata.

        ``options`` is ignored. The reported runtime includes model loading
        on the first call, because the timer starts before ``_load()``.
        Warnings are added when running without CUDA and when the mean
        pLDDT is below 50.
        """
        del options  # unused; present to satisfy the FoldingBackend signature
        t0 = time.monotonic()
        self._load()

        import torch

        device, tokenizer, model = self._device, self._tokenizer, self._model
        assert device is not None
        assert tokenizer is not None
        assert model is not None

        encoded = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            prediction = model(**inputs)

        pdb_text = _esmfold_output_to_pdb(prediction)
        mean_plddt = _mean_plddt(prediction)
        elapsed = time.monotonic() - t0

        notes: list[str] = []
        if device.type != "cuda":
            notes.append("ESMFold ran on CPU; GPU is recommended")
        if mean_plddt is not None and mean_plddt < 50:
            notes.append("low mean pLDDT; predicted structure may be unreliable")

        run_metrics = {
            "runtime_seconds": round(elapsed, 4),
            "sequence_length": len(sequence),
            "device": device.type,
        }
        return FoldOutput(
            pdb=pdb_text,
            confidence={"mean_plddt": mean_plddt},
            metrics=run_metrics,
            warnings=notes,
        )
def _as_mapping(output: Any) -> dict[str, Any]:
    """Normalize a model output into a plain name-to-value mapping.

    Dicts (including dict subclasses) are returned as-is, not copied.
    Objects exposing both ``to_tuple`` and ``keys`` are read key by key via
    ``output[key]``. Any other object with a ``__dict__`` contributes its
    attributes whose names do not start with an underscore.

    Raises:
        TypeError: if ``output`` matches none of these shapes.
    """
    if isinstance(output, dict):
        return output
    if hasattr(output, "to_tuple") and hasattr(output, "keys"):
        return dict((k, output[k]) for k in output.keys())
    if not hasattr(output, "__dict__"):
        raise TypeError("unsupported ESMFold output type")
    return dict(
        (name, attr) for name, attr in vars(output).items() if not name.startswith("_")
    )
def _esmfold_output_to_pdb(output: Any) -> str:
    """Render the batch-index-0 prediction of an ESMFold output as PDB text.

    The atom14 positions taken from ``output["positions"][-1]`` are expanded
    to atom37 layout and written with the OpenFold PDB writer shipped in
    Transformers. Per-atom pLDDT is used as the B-factor column and scaled
    to 0-100 when it looks like a 0-1 fraction.

    Only the fields the writer needs are moved to host memory. Other entries
    of the output stay where they are instead of being copied off the device
    for nothing.

    Raises:
        KeyError: if a required field (positions, plddt, aatype,
            residue_index, atom37_atom_exists) is missing.
    """
    import torch
    from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
    from transformers.models.esm.openfold_utils.protein import Protein as OpenFoldProtein
    from transformers.models.esm.openfold_utils.protein import to_pdb

    def to_host(value: Any) -> Any:
        # Tensors (possibly on GPU) become NumPy arrays; anything else passes through.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return value

    data = _as_mapping(output)
    # [-1] selects the last entry along the leading axis of "positions"
    # (presumably the final structure-module iteration — confirm upstream).
    final_atom_positions = to_host(atom14_to_atom37(data["positions"][-1], data))

    b_factors = to_host(data["plddt"])[0]
    # pLDDT may arrive as a 0-1 fraction; report B-factors on a 0-100 scale.
    if float(b_factors.max()) <= 1.5:
        b_factors = b_factors * 100.0

    # A missing key and an explicit None are both passed on as None, rather
    # than indexing into None (which raised TypeError before).
    chain_index = data.get("chain_index")
    if chain_index is not None:
        chain_index = to_host(chain_index)[0]

    protein = OpenFoldProtein(
        aatype=to_host(data["aatype"])[0],
        atom_positions=final_atom_positions[0],
        atom_mask=to_host(data["atom37_atom_exists"])[0],
        # PDB residue numbers are 1-based; the model's index is assumed 0-based.
        residue_index=to_host(data["residue_index"])[0] + 1,
        b_factors=b_factors,
        chain_index=chain_index,
    )
    return to_pdb(protein)
def _mean_plddt(output: Any) -> float | None:
    """Return the mean pLDDT of ``output`` on a 0-100 scale, rounded to 4 dp.

    The values are treated as 0-1 fractions (and multiplied by 100) when the
    largest value is at most 1.5. This is the same max-based rule used when
    writing pLDDT into the PDB B-factor column.

    Deciding on the max rather than the mean keeps the two consistent. It
    also stops a percent-scale prediction with a very low mean from being
    mis-scaled to above 100, which would hide the low-confidence warning.

    Returns:
        The mean pLDDT, or None when the output has no ``plddt`` entry or
        it is empty.
    """
    data = _as_mapping(output)
    plddt = data.get("plddt")
    if plddt is None:
        return None
    if hasattr(plddt, "detach"):
        values = plddt.detach().float()
        if values.numel() == 0:
            return None
        mean = float(values.mean().cpu().item())
        peak = float(values.max().cpu().item())
    else:
        if getattr(plddt, "size", None) == 0:
            return None
        mean = float(plddt.mean())
        peak = float(plddt.max())
    if peak <= 1.5:
        mean *= 100.0
    return round(mean, 4)
def make_backend() -> FoldingBackend:
    """Build the folding backend selected by the ``FOLD_BACKEND`` env var.

    The value is case- and whitespace-insensitive and defaults to
    ``esmfold``. ``stub`` selects the demo backend. For ``esmfold``, the
    model id comes from ``ESMFOLD_MODEL_ID`` (default
    ``facebook/esmfold_v1``).

    Raises:
        ValueError: for any other backend name.
    """
    choice = os.getenv("FOLD_BACKEND", "esmfold").strip().lower()
    if choice not in ("stub", "esmfold"):
        raise ValueError(f"unsupported FOLD_BACKEND: {choice}")
    if choice == "stub":
        return StubBackend()
    model_id = os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1")
    return EsmFoldBackend(model_id)