Spaces:
Sleeping
Sleeping
Add ESMFold folding API
Browse files- .dockerignore +6 -0
- Dockerfile +26 -0
- README.md +41 -6
- app.py +2 -0
- folding_api_service/__init__.py +2 -0
- folding_api_service/app.py +198 -0
- folding_api_service/backends.py +181 -0
- requirements.txt +9 -0
- tests/test_api.py +76 -0
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.ruff_cache/
|
| 5 |
+
tests/
|
| 6 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
HF_HOME=/data/huggingface \
|
| 6 |
+
TRANSFORMERS_CACHE=/data/huggingface \
|
| 7 |
+
FOLD_BACKEND=esmfold \
|
| 8 |
+
MAX_PROTEIN_AA=400
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
RUN apt-get update && \
|
| 13 |
+
apt-get install -y --no-install-recommends git git-lfs && \
|
| 14 |
+
rm -rf /var/lib/apt/lists/* && \
|
| 15 |
+
git lfs install
|
| 16 |
+
|
| 17 |
+
COPY requirements.txt /app/requirements.txt
|
| 18 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 19 |
+
pip install --no-cache-dir -r /app/requirements.txt
|
| 20 |
+
|
| 21 |
+
COPY . /app
|
| 22 |
+
|
| 23 |
+
EXPOSE 7860
|
| 24 |
+
|
| 25 |
+
CMD ["uvicorn", "folding_api_service.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 26 |
+
|
README.md
CHANGED
|
@@ -1,10 +1,45 @@
|
|
| 1 |
---
|
| 2 |
-
title: Protein Folding
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Carbon Protein Folding API
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
fullWidth: true
|
| 9 |
+
tags:
|
| 10 |
+
- biology
|
| 11 |
+
- protein-folding
|
| 12 |
+
- esmfold
|
| 13 |
+
- fastapi
|
| 14 |
+
- carbon
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# Carbon Protein Folding API
|
| 18 |
+
|
| 19 |
+
FastAPI service for the Carbon DNA-to-Structure demo. The first live backend is ESMFold for single-chain protein folding.
|
| 20 |
+
|
| 21 |
+
## Endpoints
|
| 22 |
+
|
| 23 |
+
- `GET /health`
|
| 24 |
+
- `GET /tools`
|
| 25 |
+
- `POST /jobs`
|
| 26 |
+
- `GET /jobs/{job_id}`
|
| 27 |
+
|
| 28 |
+
`POST /jobs` accepts one protein entity and returns immediately with a `job_id`. Poll `GET /jobs/{job_id}` until the job reaches `succeeded` or `failed`.
|
| 29 |
+
|
| 30 |
+
## Configuration
|
| 31 |
+
|
| 32 |
+
Set these Space variables/secrets:
|
| 33 |
+
|
| 34 |
+
```sh
|
| 35 |
+
FOLD_BACKEND=esmfold
|
| 36 |
+
FOLD_API_TOKEN=...
|
| 37 |
+
MAX_PROTEIN_AA=400
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
For local CPU tests, use:
|
| 41 |
+
|
| 42 |
+
```sh
|
| 43 |
+
FOLD_BACKEND=stub uvicorn folding_api_service.app:app --host 0.0.0.0 --port 7860
|
| 44 |
+
```
|
| 45 |
+
|
app.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from folding_api_service.app import app
|
| 2 |
+
|
folding_api_service/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Protein folding API service for the Carbon demo."""
|
| 2 |
+
|
folding_api_service/app.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
from dataclasses import asdict, dataclass, field
|
| 9 |
+
from threading import Lock
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from fastapi import Depends, FastAPI, Header, HTTPException
|
| 13 |
+
from pydantic import BaseModel, Field
|
| 14 |
+
|
| 15 |
+
from .backends import FoldingBackend, FoldOutput, make_backend
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
PROTEIN_RE = re.compile(r"^[ACDEFGHIKLMNPQRSTVWYXBZUOJ]+$", re.IGNORECASE)
|
| 19 |
+
MAX_PROTEIN_AA = int(os.getenv("MAX_PROTEIN_AA", "400"))
|
| 20 |
+
API_TOKEN = os.getenv("FOLD_API_TOKEN", "").strip()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Entity(BaseModel):
|
| 24 |
+
id: str = Field(min_length=1, max_length=32)
|
| 25 |
+
type: str
|
| 26 |
+
sequence: str
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class JobRequest(BaseModel):
|
| 30 |
+
tool_id: str
|
| 31 |
+
entities: list[Entity]
|
| 32 |
+
options: dict[str, Any] = Field(default_factory=dict)
|
| 33 |
+
client_metadata: dict[str, Any] = Field(default_factory=dict)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class JobState:
|
| 38 |
+
job_id: str
|
| 39 |
+
tool_id: str
|
| 40 |
+
status: str
|
| 41 |
+
created_at: float
|
| 42 |
+
updated_at: float
|
| 43 |
+
progress: float = 0.0
|
| 44 |
+
result: dict[str, Any] | None = None
|
| 45 |
+
error: str | None = None
|
| 46 |
+
|
| 47 |
+
def public(self) -> dict[str, Any]:
|
| 48 |
+
payload = asdict(self)
|
| 49 |
+
payload.pop("created_at", None)
|
| 50 |
+
payload.pop("updated_at", None)
|
| 51 |
+
return payload
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class RuntimeState:
|
| 56 |
+
backend: FoldingBackend = field(default_factory=make_backend)
|
| 57 |
+
jobs: dict[str, JobState] = field(default_factory=dict)
|
| 58 |
+
lock: Lock = field(default_factory=Lock)
|
| 59 |
+
executor: ThreadPoolExecutor = field(default_factory=lambda: ThreadPoolExecutor(max_workers=1))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
state = RuntimeState()
|
| 63 |
+
app = FastAPI(title="Carbon Protein Folding API", version="0.1.0")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def require_auth(authorization: str | None = Header(default=None)) -> None:
|
| 67 |
+
if not API_TOKEN:
|
| 68 |
+
return
|
| 69 |
+
expected = f"Bearer {API_TOKEN}"
|
| 70 |
+
if authorization != expected:
|
| 71 |
+
raise HTTPException(status_code=401, detail="invalid or missing bearer token")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def validate_request(payload: JobRequest) -> str:
|
| 75 |
+
if payload.tool_id != "esmfold":
|
| 76 |
+
raise HTTPException(status_code=400, detail="only tool_id 'esmfold' is supported")
|
| 77 |
+
if len(payload.entities) != 1:
|
| 78 |
+
raise HTTPException(status_code=400, detail="exactly one protein entity is supported")
|
| 79 |
+
|
| 80 |
+
entity = payload.entities[0]
|
| 81 |
+
if entity.type.lower() != "protein":
|
| 82 |
+
raise HTTPException(status_code=400, detail="entity type must be 'protein'")
|
| 83 |
+
|
| 84 |
+
sequence = re.sub(r"\s+", "", entity.sequence).upper().replace("*", "")
|
| 85 |
+
if not sequence:
|
| 86 |
+
raise HTTPException(status_code=400, detail="protein sequence is empty")
|
| 87 |
+
if len(sequence) > MAX_PROTEIN_AA:
|
| 88 |
+
raise HTTPException(status_code=400, detail=f"protein sequence exceeds {MAX_PROTEIN_AA} aa")
|
| 89 |
+
if not PROTEIN_RE.match(sequence):
|
| 90 |
+
raise HTTPException(status_code=400, detail="protein sequence contains unsupported characters")
|
| 91 |
+
return sequence
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@app.get("/health")
|
| 95 |
+
def health() -> dict[str, Any]:
|
| 96 |
+
return {
|
| 97 |
+
"ok": True,
|
| 98 |
+
"backend": os.getenv("FOLD_BACKEND", "esmfold"),
|
| 99 |
+
"max_protein_aa": MAX_PROTEIN_AA,
|
| 100 |
+
"jobs": len(state.jobs),
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@app.get("/tools")
|
| 105 |
+
def tools(_: None = Depends(require_auth)) -> dict[str, Any]:
|
| 106 |
+
return {
|
| 107 |
+
"tools": [
|
| 108 |
+
{
|
| 109 |
+
"id": "esmfold",
|
| 110 |
+
"name": "ESMFold",
|
| 111 |
+
"status": "live",
|
| 112 |
+
"input_types": ["protein"],
|
| 113 |
+
"max_protein_aa": MAX_PROTEIN_AA,
|
| 114 |
+
"output_formats": ["pdb"],
|
| 115 |
+
"options": {
|
| 116 |
+
"seed": {"type": "integer", "supported": False},
|
| 117 |
+
"num_recycles": {"type": "integer", "supported": False},
|
| 118 |
+
"msa_mode": {"type": "string", "value": "none"},
|
| 119 |
+
},
|
| 120 |
+
}
|
| 121 |
+
]
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@app.post("/jobs")
|
| 126 |
+
def create_job(payload: JobRequest, _: None = Depends(require_auth)) -> dict[str, str]:
|
| 127 |
+
sequence = validate_request(payload)
|
| 128 |
+
job_id = uuid.uuid4().hex
|
| 129 |
+
now = time.time()
|
| 130 |
+
job = JobState(
|
| 131 |
+
job_id=job_id,
|
| 132 |
+
tool_id=payload.tool_id,
|
| 133 |
+
status="queued",
|
| 134 |
+
created_at=now,
|
| 135 |
+
updated_at=now,
|
| 136 |
+
)
|
| 137 |
+
with state.lock:
|
| 138 |
+
state.jobs[job_id] = job
|
| 139 |
+
state.executor.submit(run_job, job_id, sequence, payload.options)
|
| 140 |
+
return {"job_id": job_id, "status": "queued"}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@app.get("/jobs/{job_id}")
|
| 144 |
+
def get_job(job_id: str, _: None = Depends(require_auth)) -> dict[str, Any]:
|
| 145 |
+
with state.lock:
|
| 146 |
+
job = state.jobs.get(job_id)
|
| 147 |
+
if job is None:
|
| 148 |
+
raise HTTPException(status_code=404, detail="unknown job_id")
|
| 149 |
+
return job.public()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def run_job(job_id: str, sequence: str, options: dict[str, Any]) -> None:
|
| 153 |
+
_update_job(job_id, status="running", progress=0.05)
|
| 154 |
+
try:
|
| 155 |
+
output = state.backend.fold(sequence, options)
|
| 156 |
+
_update_job(
|
| 157 |
+
job_id,
|
| 158 |
+
status="succeeded",
|
| 159 |
+
progress=1.0,
|
| 160 |
+
result=_result_payload(output),
|
| 161 |
+
error=None,
|
| 162 |
+
)
|
| 163 |
+
except Exception as exc: # noqa: BLE001 - API should preserve job failure details.
|
| 164 |
+
_update_job(job_id, status="failed", progress=1.0, error=str(exc), result=None)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _update_job(
|
| 168 |
+
job_id: str,
|
| 169 |
+
*,
|
| 170 |
+
status: str,
|
| 171 |
+
progress: float,
|
| 172 |
+
result: dict[str, Any] | None = None,
|
| 173 |
+
error: str | None = None,
|
| 174 |
+
) -> None:
|
| 175 |
+
with state.lock:
|
| 176 |
+
job = state.jobs[job_id]
|
| 177 |
+
job.status = status
|
| 178 |
+
job.progress = progress
|
| 179 |
+
job.updated_at = time.time()
|
| 180 |
+
if result is not None:
|
| 181 |
+
job.result = result
|
| 182 |
+
if error is not None:
|
| 183 |
+
job.error = error
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def _result_payload(output: FoldOutput) -> dict[str, Any]:
|
| 187 |
+
return {
|
| 188 |
+
"structures": [
|
| 189 |
+
{
|
| 190 |
+
"format": "pdb",
|
| 191 |
+
"content": output.pdb,
|
| 192 |
+
"confidence": output.confidence,
|
| 193 |
+
}
|
| 194 |
+
],
|
| 195 |
+
"metrics": output.metrics,
|
| 196 |
+
"warnings": output.warnings,
|
| 197 |
+
}
|
| 198 |
+
|
folding_api_service/backends.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
DEMO_PDB = """HEADER CARBON FOLDING API STUB
|
| 11 |
+
ATOM 1 N ALA A 1 -0.500 1.300 0.000 1.00 80.00 N
|
| 12 |
+
ATOM 2 CA ALA A 1 0.000 0.000 0.000 1.00 80.00 C
|
| 13 |
+
ATOM 3 C ALA A 1 1.520 0.000 0.000 1.00 80.00 C
|
| 14 |
+
ATOM 4 O ALA A 1 2.110 -1.060 0.000 1.00 80.00 O
|
| 15 |
+
ATOM 5 N GLY A 2 2.160 1.170 0.000 1.00 82.00 N
|
| 16 |
+
ATOM 6 CA GLY A 2 3.600 1.260 0.000 1.00 82.00 C
|
| 17 |
+
ATOM 7 C GLY A 2 4.160 2.660 0.000 1.00 82.00 C
|
| 18 |
+
ATOM 8 O GLY A 2 3.480 3.660 0.000 1.00 82.00 O
|
| 19 |
+
ATOM 9 N SER A 3 5.430 2.730 0.000 1.00 76.00 N
|
| 20 |
+
ATOM 10 CA SER A 3 6.080 4.030 0.000 1.00 76.00 C
|
| 21 |
+
ATOM 11 C SER A 3 7.600 3.910 0.000 1.00 76.00 C
|
| 22 |
+
ATOM 12 O SER A 3 8.250 4.920 0.000 1.00 76.00 O
|
| 23 |
+
TER
|
| 24 |
+
END
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
|
| 29 |
+
class FoldOutput:
|
| 30 |
+
pdb: str
|
| 31 |
+
confidence: dict[str, Any]
|
| 32 |
+
metrics: dict[str, Any]
|
| 33 |
+
warnings: list[str]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class FoldingBackend(ABC):
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
|
| 39 |
+
raise NotImplementedError
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class StubBackend(FoldingBackend):
|
| 43 |
+
def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
|
| 44 |
+
del options
|
| 45 |
+
started = time.monotonic()
|
| 46 |
+
time.sleep(min(0.1, max(0.0, len(sequence) / 10_000)))
|
| 47 |
+
return FoldOutput(
|
| 48 |
+
pdb=DEMO_PDB,
|
| 49 |
+
confidence={"mean_plddt": 80.0},
|
| 50 |
+
metrics={"runtime_seconds": round(time.monotonic() - started, 4), "sequence_length": len(sequence)},
|
| 51 |
+
warnings=["stub backend returned a demo structure"],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class EsmFoldBackend(FoldingBackend):
|
| 56 |
+
def __init__(self, model_id: str = "facebook/esmfold_v1") -> None:
|
| 57 |
+
self.model_id = model_id
|
| 58 |
+
self._loaded = False
|
| 59 |
+
self._device = None
|
| 60 |
+
self._tokenizer = None
|
| 61 |
+
self._model = None
|
| 62 |
+
|
| 63 |
+
def _load(self) -> None:
|
| 64 |
+
if self._loaded:
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
import torch
|
| 68 |
+
from transformers import AutoTokenizer, EsmForProteinFolding
|
| 69 |
+
|
| 70 |
+
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 71 |
+
self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
| 72 |
+
self._model = EsmForProteinFolding.from_pretrained(
|
| 73 |
+
self.model_id,
|
| 74 |
+
low_cpu_mem_usage=True,
|
| 75 |
+
)
|
| 76 |
+
self._model.eval()
|
| 77 |
+
self._model.to(self._device)
|
| 78 |
+
|
| 79 |
+
# Reduce memory use for longer demo proteins. This is supported by the
|
| 80 |
+
# Transformers ESMFold implementation and is a no-op if unavailable.
|
| 81 |
+
if hasattr(self._model, "trunk") and hasattr(self._model.trunk, "set_chunk_size"):
|
| 82 |
+
self._model.trunk.set_chunk_size(int(os.getenv("ESMFOLD_CHUNK_SIZE", "64")))
|
| 83 |
+
|
| 84 |
+
self._loaded = True
|
| 85 |
+
|
| 86 |
+
def fold(self, sequence: str, options: dict[str, Any]) -> FoldOutput:
|
| 87 |
+
del options
|
| 88 |
+
started = time.monotonic()
|
| 89 |
+
self._load()
|
| 90 |
+
|
| 91 |
+
import torch
|
| 92 |
+
|
| 93 |
+
assert self._device is not None
|
| 94 |
+
assert self._tokenizer is not None
|
| 95 |
+
assert self._model is not None
|
| 96 |
+
|
| 97 |
+
tokenized = self._tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
|
| 98 |
+
tokenized = {key: value.to(self._device) for key, value in tokenized.items()}
|
| 99 |
+
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
output = self._model(**tokenized)
|
| 102 |
+
|
| 103 |
+
pdb = _esmfold_output_to_pdb(output)
|
| 104 |
+
mean_plddt = _mean_plddt(output)
|
| 105 |
+
runtime = time.monotonic() - started
|
| 106 |
+
warnings = []
|
| 107 |
+
if self._device.type != "cuda":
|
| 108 |
+
warnings.append("ESMFold ran on CPU; GPU is recommended")
|
| 109 |
+
if mean_plddt is not None and mean_plddt < 50:
|
| 110 |
+
warnings.append("low mean pLDDT; predicted structure may be unreliable")
|
| 111 |
+
|
| 112 |
+
return FoldOutput(
|
| 113 |
+
pdb=pdb,
|
| 114 |
+
confidence={"mean_plddt": mean_plddt},
|
| 115 |
+
metrics={
|
| 116 |
+
"runtime_seconds": round(runtime, 4),
|
| 117 |
+
"sequence_length": len(sequence),
|
| 118 |
+
"device": self._device.type,
|
| 119 |
+
},
|
| 120 |
+
warnings=warnings,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _as_mapping(output: Any) -> dict[str, Any]:
|
| 125 |
+
if isinstance(output, dict):
|
| 126 |
+
return output
|
| 127 |
+
if hasattr(output, "to_tuple") and hasattr(output, "keys"):
|
| 128 |
+
return {key: output[key] for key in output.keys()}
|
| 129 |
+
if hasattr(output, "__dict__"):
|
| 130 |
+
return {key: value for key, value in vars(output).items() if not key.startswith("_")}
|
| 131 |
+
raise TypeError("unsupported ESMFold output type")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _esmfold_output_to_pdb(output: Any) -> str:
|
| 135 |
+
import torch
|
| 136 |
+
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
|
| 137 |
+
from transformers.models.esm.openfold_utils.protein import Protein as OpenFoldProtein
|
| 138 |
+
from transformers.models.esm.openfold_utils.protein import to_pdb
|
| 139 |
+
|
| 140 |
+
data = _as_mapping(output)
|
| 141 |
+
final_atom_positions = atom14_to_atom37(data["positions"][-1], data)
|
| 142 |
+
|
| 143 |
+
cpu_data = {}
|
| 144 |
+
for key, value in data.items():
|
| 145 |
+
if torch.is_tensor(value):
|
| 146 |
+
cpu_data[key] = value.detach().cpu().numpy()
|
| 147 |
+
else:
|
| 148 |
+
cpu_data[key] = value
|
| 149 |
+
|
| 150 |
+
final_atom_positions = final_atom_positions.detach().cpu().numpy()
|
| 151 |
+
final_atom_mask = cpu_data["atom37_atom_exists"]
|
| 152 |
+
|
| 153 |
+
protein = OpenFoldProtein(
|
| 154 |
+
aatype=cpu_data["aatype"][0],
|
| 155 |
+
atom_positions=final_atom_positions[0],
|
| 156 |
+
atom_mask=final_atom_mask[0],
|
| 157 |
+
residue_index=cpu_data["residue_index"][0] + 1,
|
| 158 |
+
b_factors=cpu_data["plddt"][0],
|
| 159 |
+
chain_index=cpu_data.get("chain_index", [None])[0],
|
| 160 |
+
)
|
| 161 |
+
return to_pdb(protein)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _mean_plddt(output: Any) -> float | None:
|
| 165 |
+
data = _as_mapping(output)
|
| 166 |
+
plddt = data.get("plddt")
|
| 167 |
+
if plddt is None:
|
| 168 |
+
return None
|
| 169 |
+
if hasattr(plddt, "detach"):
|
| 170 |
+
return round(float(plddt.detach().float().mean().cpu().item()), 4)
|
| 171 |
+
return round(float(plddt.mean()), 4)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def make_backend() -> FoldingBackend:
|
| 175 |
+
backend = os.getenv("FOLD_BACKEND", "esmfold").strip().lower()
|
| 176 |
+
if backend == "stub":
|
| 177 |
+
return StubBackend()
|
| 178 |
+
if backend == "esmfold":
|
| 179 |
+
return EsmFoldBackend(os.getenv("ESMFOLD_MODEL_ID", "facebook/esmfold_v1"))
|
| 180 |
+
raise ValueError(f"unsupported FOLD_BACKEND: {backend}")
|
| 181 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=0.31
|
| 2 |
+
fastapi>=0.110
|
| 3 |
+
huggingface-hub>=0.23
|
| 4 |
+
numpy<2
|
| 5 |
+
pydantic>=2.7
|
| 6 |
+
scipy>=1.11
|
| 7 |
+
transformers>=4.44,<5
|
| 8 |
+
uvicorn[standard]>=0.29
|
| 9 |
+
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import unittest
|
| 4 |
+
|
| 5 |
+
os.environ.setdefault("FOLD_BACKEND", "stub")
|
| 6 |
+
os.environ.setdefault("MAX_PROTEIN_AA", "400")
|
| 7 |
+
|
| 8 |
+
from fastapi.testclient import TestClient # noqa: E402
|
| 9 |
+
|
| 10 |
+
from folding_api_service.app import app # noqa: E402
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FoldingApiTest(unittest.TestCase):
|
| 14 |
+
def setUp(self):
|
| 15 |
+
self.client = TestClient(app)
|
| 16 |
+
|
| 17 |
+
def test_health(self):
|
| 18 |
+
response = self.client.get("/health")
|
| 19 |
+
self.assertEqual(response.status_code, 200)
|
| 20 |
+
self.assertTrue(response.json()["ok"])
|
| 21 |
+
|
| 22 |
+
def test_tools(self):
|
| 23 |
+
response = self.client.get("/tools")
|
| 24 |
+
self.assertEqual(response.status_code, 200)
|
| 25 |
+
tools = response.json()["tools"]
|
| 26 |
+
self.assertEqual(tools[0]["id"], "esmfold")
|
| 27 |
+
self.assertEqual(tools[0]["max_protein_aa"], 400)
|
| 28 |
+
|
| 29 |
+
def test_submit_and_poll_stub_job(self):
|
| 30 |
+
response = self.client.post(
|
| 31 |
+
"/jobs",
|
| 32 |
+
json={
|
| 33 |
+
"tool_id": "esmfold",
|
| 34 |
+
"entities": [{"id": "A", "type": "protein", "sequence": "MLSDEDFKAVFGMTRSAFANLPLWKQQNLKKEKGLF"}],
|
| 35 |
+
"options": {"msa_mode": "none"},
|
| 36 |
+
},
|
| 37 |
+
)
|
| 38 |
+
self.assertEqual(response.status_code, 200)
|
| 39 |
+
job_id = response.json()["job_id"]
|
| 40 |
+
|
| 41 |
+
final = None
|
| 42 |
+
for _ in range(20):
|
| 43 |
+
poll = self.client.get(f"/jobs/{job_id}")
|
| 44 |
+
self.assertEqual(poll.status_code, 200)
|
| 45 |
+
final = poll.json()
|
| 46 |
+
if final["status"] in {"succeeded", "failed"}:
|
| 47 |
+
break
|
| 48 |
+
time.sleep(0.05)
|
| 49 |
+
|
| 50 |
+
self.assertIsNotNone(final)
|
| 51 |
+
self.assertEqual(final["status"], "succeeded")
|
| 52 |
+
self.assertIn("ATOM", final["result"]["structures"][0]["content"])
|
| 53 |
+
self.assertEqual(final["result"]["structures"][0]["format"], "pdb")
|
| 54 |
+
|
| 55 |
+
def test_validation_errors(self):
|
| 56 |
+
cases = [
|
| 57 |
+
{"tool_id": "omegafold", "entities": [{"id": "A", "type": "protein", "sequence": "MKT"}]},
|
| 58 |
+
{"tool_id": "esmfold", "entities": []},
|
| 59 |
+
{"tool_id": "esmfold", "entities": [{"id": "A", "type": "dna", "sequence": "ATG"}]},
|
| 60 |
+
{"tool_id": "esmfold", "entities": [{"id": "A", "type": "protein", "sequence": ""}]},
|
| 61 |
+
{"tool_id": "esmfold", "entities": [{"id": "A", "type": "protein", "sequence": "M1T"}]},
|
| 62 |
+
{"tool_id": "esmfold", "entities": [{"id": "A", "type": "protein", "sequence": "M" * 401}]},
|
| 63 |
+
]
|
| 64 |
+
for payload in cases:
|
| 65 |
+
with self.subTest(payload=payload):
|
| 66 |
+
response = self.client.post("/jobs", json={**payload, "options": {}})
|
| 67 |
+
self.assertEqual(response.status_code, 400)
|
| 68 |
+
|
| 69 |
+
def test_unknown_job(self):
|
| 70 |
+
response = self.client.get("/jobs/missing")
|
| 71 |
+
self.assertEqual(response.status_code, 404)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
unittest.main()
|
| 76 |
+
|