Spaces:
Sleeping
Sleeping
from __future__ import annotations

import os
import re
import secrets
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from threading import Lock
from typing import Any

from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, Field

from .backends import FoldingBackend, FoldOutput, make_backend
# Accepted amino-acid alphabet: the 20 standard residues plus ambiguity
# codes (X/B/Z/J) and the rare residues U (Sec) and O (Pyl).
PROTEIN_RE = re.compile(r"^[ACDEFGHIKLMNPQRSTVWYXBZUOJ]+$", re.IGNORECASE)
# Hard cap on accepted sequence length, overridable via the environment.
MAX_PROTEIN_AA = int(os.getenv("MAX_PROTEIN_AA", "400"))
# Static bearer token; an empty value disables authentication (see require_auth).
API_TOKEN = os.getenv("FOLD_API_TOKEN", "").strip()
class Entity(BaseModel):
    """One input entity of a folding request (only proteins are accepted)."""

    id: str = Field(min_length=1, max_length=32)  # client-chosen identifier
    type: str  # must be "protein" (checked in validate_request)
    sequence: str  # raw amino-acid sequence; whitespace and '*' stripped later
class JobRequest(BaseModel):
    """Body of a job-creation request."""

    tool_id: str  # only "esmfold" is accepted (checked in validate_request)
    entities: list[Entity]  # exactly one protein entity is supported
    options: dict[str, Any] = Field(default_factory=dict)  # passed through to the backend
    client_metadata: dict[str, Any] = Field(default_factory=dict)  # opaque; not read by the server
@dataclass
class JobState:
    """Mutable server-side record of one folding job's lifecycle.

    Fix: the ``@dataclass`` decorator was missing even though the class is
    constructed with keyword arguments (see create_job) and serialized with
    ``asdict`` below — both require a dataclass. ``dataclass`` was imported
    but unused, confirming the decorator was dropped.
    """

    job_id: str
    tool_id: str
    status: str  # "queued" | "running" | "succeeded" | "failed"
    created_at: float  # unix timestamp; internal only, stripped from public()
    updated_at: float  # unix timestamp; internal only, stripped from public()
    progress: float = 0.0  # 0.0 .. 1.0
    result: dict[str, Any] | None = None  # set on success
    error: str | None = None  # set on failure

    def public(self) -> dict[str, Any]:
        """Return a JSON-safe view of the job without internal timestamps."""
        payload = asdict(self)
        payload.pop("created_at", None)
        payload.pop("updated_at", None)
        return payload
@dataclass
class RuntimeState:
    """Process-wide container for the backend, job table, lock, and worker pool.

    Fix: the ``@dataclass`` decorator was missing; ``field(default_factory=...)``
    has no effect outside a dataclass (the attributes would just be ``Field``
    objects, and ``RuntimeState()`` would leave instances without them).
    """

    backend: FoldingBackend = field(default_factory=make_backend)
    jobs: dict[str, JobState] = field(default_factory=dict)
    lock: Lock = field(default_factory=Lock)  # guards `jobs` and job mutation
    # Single worker thread — presumably to serialize model inference; confirm
    # before raising max_workers.
    executor: ThreadPoolExecutor = field(
        default_factory=lambda: ThreadPoolExecutor(max_workers=1)
    )
# Singleton runtime state shared by all request handlers.
state = RuntimeState()
# ASGI application instance.
app = FastAPI(title="Carbon Protein Folding API", version="0.1.0")
def require_auth(authorization: str | None = Header(default=None)) -> None:
    """FastAPI dependency enforcing a static bearer token.

    Authentication is disabled entirely when FOLD_API_TOKEN is unset or empty.

    Raises:
        HTTPException: 401 when the Authorization header is missing or wrong.
    """
    if not API_TOKEN:
        return  # no token configured -> open access (deliberate)
    expected = f"Bearer {API_TOKEN}"
    # compare_digest is timing-safe: plain `!=` short-circuits on the first
    # differing byte, leaking token prefixes to a remote attacker.
    if authorization is None or not secrets.compare_digest(authorization, expected):
        raise HTTPException(status_code=401, detail="invalid or missing bearer token")
def validate_request(payload: JobRequest) -> str:
    """Validate a job request and return the cleaned amino-acid sequence.

    Any unsupported tool, entity shape, or malformed sequence is rejected
    with an HTTP 400.
    """

    def reject(message: str) -> None:
        raise HTTPException(status_code=400, detail=message)

    if payload.tool_id != "esmfold":
        reject("only tool_id 'esmfold' is supported")
    if len(payload.entities) != 1:
        reject("exactly one protein entity is supported")

    protein = payload.entities[0]
    if protein.type.lower() != "protein":
        reject("entity type must be 'protein'")

    # Normalize: drop whitespace, uppercase, strip stop codons ('*').
    cleaned = re.sub(r"\s+", "", protein.sequence).upper().replace("*", "")
    if not cleaned:
        reject("protein sequence is empty")
    if len(cleaned) > MAX_PROTEIN_AA:
        reject(f"protein sequence exceeds {MAX_PROTEIN_AA} aa")
    if PROTEIN_RE.match(cleaned) is None:
        reject("protein sequence contains unsupported characters")
    return cleaned
@app.get("/health")
def health() -> dict[str, Any]:
    """Unauthenticated liveness probe with basic runtime information.

    NOTE(review): this handler was never registered in this copy even though
    `app` is created and otherwise unused — the route decorator appears to
    have been stripped. "/health" is the conventional path; confirm against
    the API client before relying on it.
    """
    return {
        "ok": True,
        "backend": os.getenv("FOLD_BACKEND", "esmfold"),
        "max_protein_aa": MAX_PROTEIN_AA,
        "jobs": len(state.jobs),
    }
@app.get("/tools")
def tools(_: None = Depends(require_auth)) -> dict[str, Any]:
    """List the folding tools this server exposes (currently only ESMFold).

    NOTE(review): the route decorator appears stripped in this copy; "/tools"
    is the conventional path — confirm against the API client.
    """
    return {
        "tools": [
            {
                "id": "esmfold",
                "name": "ESMFold",
                "status": "live",
                "input_types": ["protein"],
                "max_protein_aa": MAX_PROTEIN_AA,
                "output_formats": ["pdb"],
                # Advertise which job options the backend honors.
                "options": {
                    "seed": {"type": "integer", "supported": False},
                    "num_recycles": {"type": "integer", "supported": False},
                    "msa_mode": {"type": "string", "value": "none"},
                },
            }
        ]
    }
@app.post("/jobs")
def create_job(payload: JobRequest, _: None = Depends(require_auth)) -> dict[str, str]:
    """Validate the request, register a queued job, and hand it to the worker.

    NOTE(review): the route decorator appears stripped in this copy;
    POST /jobs is the conventional choice — confirm against the API client.

    Returns:
        {"job_id": ..., "status": "queued"} for polling via get_job.
    """
    sequence = validate_request(payload)  # raises HTTPException(400) on bad input
    job_id = uuid.uuid4().hex
    now = time.time()
    job = JobState(
        job_id=job_id,
        tool_id=payload.tool_id,
        status="queued",
        created_at=now,
        updated_at=now,
    )
    with state.lock:
        state.jobs[job_id] = job
    # Submit after releasing the lock: the worker immediately re-acquires it
    # in _update_job, so there is no point holding it across submit().
    state.executor.submit(run_job, job_id, sequence, payload.options)
    return {"job_id": job_id, "status": "queued"}
@app.get("/jobs/{job_id}")
def get_job(job_id: str, _: None = Depends(require_auth)) -> dict[str, Any]:
    """Return the public view of a job, or 404 for an unknown id.

    NOTE(review): the route decorator appears stripped in this copy;
    GET /jobs/{job_id} is the conventional choice — confirm against the client.
    """
    with state.lock:
        job = state.jobs.get(job_id)
        if job is None:
            raise HTTPException(status_code=404, detail="unknown job_id")
        # Snapshot under the lock so the worker cannot mutate mid-serialization.
        return job.public()
def run_job(job_id: str, sequence: str, options: dict[str, Any]) -> None:
    """Worker-thread entry point: fold `sequence` and record the outcome.

    Always leaves the job in a terminal state ("succeeded" or "failed");
    any backend exception is captured as the job's error string.
    """
    _update_job(job_id, status="running", progress=0.05)
    try:
        # _result_payload stays inside the try so a malformed backend output
        # also marks the job failed rather than killing the worker silently.
        folded = state.backend.fold(sequence, options)
        outcome: dict[str, Any] = {
            "status": "succeeded",
            "result": _result_payload(folded),
            "error": None,
        }
    except Exception as exc:  # noqa: BLE001 - API should preserve job failure details.
        outcome = {"status": "failed", "error": str(exc), "result": None}
    _update_job(job_id, progress=1.0, **outcome)
_UNSET: Any = object()  # sentinel: distinguishes "argument omitted" from an explicit None


def _update_job(
    job_id: str,
    *,
    status: str,
    progress: float,
    result: dict[str, Any] | None = _UNSET,
    error: str | None = _UNSET,
) -> None:
    """Atomically update a job's status/progress and optionally result/error.

    Fix: the original treated ``result=None`` / ``error=None`` as "leave
    untouched", so an explicitly passed None could never clear a stored value
    — yet run_job passes ``error=None`` on success with exactly that intent.
    A sentinel default now separates "omitted" (leave untouched, unchanged
    behavior) from "explicit None" (overwrite).

    Raises:
        KeyError: if job_id is unknown (callers only pass ids they created).
    """
    with state.lock:
        job = state.jobs[job_id]
        job.status = status
        job.progress = progress
        job.updated_at = time.time()
        if result is not _UNSET:
            job.result = result
        if error is not _UNSET:
            job.error = error
| def _result_payload(output: FoldOutput) -> dict[str, Any]: | |
| return { | |
| "structures": [ | |
| { | |
| "format": "pdb", | |
| "content": output.pdb, | |
| "confidence": output.confidence, | |
| } | |
| ], | |
| "metrics": output.metrics, | |
| "warnings": output.warnings, | |
| } | |