tfrere HF Staff Cursor committed on
Commit
543de2f
·
1 Parent(s): d5563b7

Add /fold endpoint proxying NVIDIA NIM ESMFold

Browse files

Frontend §5 (the "From sequence to structure" section) needs predicted
3D structures to render. We proxy through the backend so the NIM API key
stays on the server.

Body: {"sequence": "<AA>"}
Returns: {"pdb": str, "n_residues": int, "plddt_mean": float, "cached": bool}
or {"error": str} on failure.

Implementation notes:
- httpx (already pulled in transitively by openai) replaces urllib so SSL
works out of the box on macOS without certifi gymnastics. Declared
explicitly in requirements.txt.
- Sequence is filtered to the 20 standard AAs before submission — NIM
rejects anything outside that charset (incl. stop codons "*"), and
callers shouldn't need to know the exact regex.
- Truncated to 1024 aa (NIM's hard cap).
- pLDDT mean is extracted from the B-factor column on CA atoms so the
frontend can show a global confidence score without re-parsing the PDB.
- sha1(sequence) → result FIFO cache (max 256 entries) since ESMFold is
deterministic. Demo viewers re-folding the same gene over and over
cost zero NIM calls.

Smoke-tested locally:
47 aa peptide → 0.55s, pLDDT 66.9
HBB 147 aa → 0.62s, pLDDT 93.8
cache hit → 0.037s (16x speedup)

Co-authored-by: Cursor <cursoragent@cursor.com>

Files changed (2) hide show
  1. app.py +108 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,8 @@
 
1
  import json
2
  import os
3
 
 
4
  from fastapi import FastAPI, Request
5
  from fastapi.responses import FileResponse, StreamingResponse
6
  from fastapi.staticfiles import StaticFiles
@@ -15,6 +17,22 @@ MODEL_NAME = os.environ.get(
15
  "hf-carbon/carbon-3B-hybrid-loss-1T-mix2-v1",
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  HERE = os.path.dirname(os.path.abspath(__file__))
19
 
20
 
@@ -177,3 +195,93 @@ async def generate(request: Request):
177
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
178
 
179
  return StreamingResponse(stream(), media_type="text/event-stream")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
  import json
3
  import os
4
 
5
+ import httpx
6
  from fastapi import FastAPI, Request
7
  from fastapi.responses import FileResponse, StreamingResponse
8
  from fastapi.staticfiles import StaticFiles
 
17
  "hf-carbon/carbon-3B-hybrid-loss-1T-mix2-v1",
18
  )
19
 
20
+ # NVIDIA NIM ESMFold endpoint (alignment-free protein structure prediction).
21
+ # Schema: POST {"sequence": "<AA>"} → {"pdbs": ["<PDB string>"]}.
22
+ # Constraints: max 1024 aa, charset = 20 standard AAs only.
23
+ NIM_FOLD_URL = os.environ.get(
24
+ "NIM_FOLD_URL",
25
+ "https://health.api.nvidia.com/v1/biology/nvidia/esmfold",
26
+ )
27
+ FOLD_MAX_LEN = 1024
28
+ FOLD_AA_ALPHABET = "ARNDCQEGHILKMFPSTWYV"
29
+
30
+ # In-memory cache: sha1(sequence) → result dict. ESMFold is deterministic at
31
+ # temperature 0, so caching is safe and lets demo viewers replay the same
32
+ # protein for free. Bounded to keep memory predictable on long-running Spaces.
33
+ _FOLD_CACHE: dict[str, dict] = {}
34
+ _FOLD_CACHE_MAX = 256
35
+
36
  HERE = os.path.dirname(os.path.abspath(__file__))
37
 
38
 
 
195
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
196
 
197
  return StreamingResponse(stream(), media_type="text/event-stream")
198
+
199
+
200
+ def _extract_plddt(pdb: str) -> list[float]:
201
+ """Pull the per-residue pLDDT confidence out of the PDB B-factor column.
202
+
203
+ ESMFold writes its pLDDT score (0-100) into the B-factor field of every
204
+ atom. We sample CA atoms only so we get exactly one value per residue.
205
+ """
206
+ plddts: list[float] = []
207
+ for line in pdb.split("\n"):
208
+ if not line.startswith("ATOM"):
209
+ continue
210
+ if line[12:16].strip() != "CA":
211
+ continue
212
+ try:
213
+ plddts.append(float(line[60:66]))
214
+ except (ValueError, IndexError):
215
+ pass
216
+ return plddts
217
+
218
+
219
@app.post("/fold")
async def fold(request: Request):
    """Predict a protein's 3D structure from its amino-acid sequence.

    Thin proxy in front of the NVIDIA NIM ESMFold endpoint (``NIM_FOLD_URL``)
    so the ``NVIDIA_API_KEY`` stays on the server.

    Request body: ``{"sequence": "<AA>"}``. The sequence is upper-cased,
    stripped of every character outside ``FOLD_AA_ALPHABET`` (e.g. stop
    codons) and truncated to ``FOLD_MAX_LEN`` residues before submission.

    Returns:
        On success: ``{"pdb": str, "n_residues": int,
        "plddt_mean": float | None, "cached": bool}``. ``plddt_mean`` is
        ``None`` when no CA atom could be parsed from the returned PDB, and
        ``cached`` tells whether the result came from the in-memory cache.
        On failure: ``{"error": str}``. Malformed requests and upstream
        failures are reported this way rather than by raising.
    """
    # A malformed body must produce the documented error payload, not a 500.
    try:
        body = await request.json()
    except ValueError:
        return {"error": "request body must be valid JSON"}
    if not isinstance(body, dict):
        return {"error": "request body must be a JSON object"}

    raw = body.get("sequence") or ""
    if not isinstance(raw, str):
        return {"error": "sequence must be a string"}

    # NIM rejects anything outside the 20 standard AAs; strip eagerly so the
    # caller doesn't need to know the exact charset.
    seq = "".join(c for c in raw.upper() if c in FOLD_AA_ALPHABET)
    if not seq:
        return {"error": "sequence empty after filtering to standard amino acids"}
    seq = seq[:FOLD_MAX_LEN]

    key = hashlib.sha1(seq.encode()).hexdigest()
    cached = _FOLD_CACHE.get(key)
    if cached is not None:
        return {**cached, "cached": True}

    api_key = os.environ.get("NVIDIA_API_KEY")
    if not api_key:
        return {"error": "no NVIDIA_API_KEY env var — set it in .env"}

    # Use the async client: a blocking call inside an async endpoint would
    # stall the event loop (and every other request) for up to the timeout.
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(
                NIM_FOLD_URL,
                json={"sequence": seq},
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Accept": "application/json",
                },
            )
    except httpx.RequestError as e:
        return {"error": f"NIM call failed: {e}"}
    if resp.status_code != 200:
        return {"error": f"NIM HTTP {resp.status_code}: {resp.text[:300]}"}
    try:
        data = resp.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        return {"error": f"NIM returned non-JSON: {e}"}

    # Expected shape is {"pdbs": ["<PDB string>"]}; anything else counts as
    # a missing payload instead of crashing on an unexpected type.
    pdbs = data.get("pdbs") if isinstance(data, dict) else None
    pdb = pdbs[0] if isinstance(pdbs, list) and pdbs else None
    if not isinstance(pdb, str) or not pdb:
        return {"error": "NIM response had no PDB payload"}

    plddts = _extract_plddt(pdb)
    result = {
        "pdb": pdb,
        "n_residues": len(plddts),
        "plddt_mean": (sum(plddts) / len(plddts)) if plddts else None,
    }

    # FIFO eviction. Dicts preserve insertion order in Python 3.7+ so the
    # oldest entry is always next(iter(...)). Skip eviction when the key is
    # already present: a concurrent request for the same sequence may have
    # filled it while this one was awaiting NIM.
    if key not in _FOLD_CACHE and len(_FOLD_CACHE) >= _FOLD_CACHE_MAX:
        _FOLD_CACHE.pop(next(iter(_FOLD_CACHE)), None)
    _FOLD_CACHE[key] = result
    # The stored entry carries no "cached" flag; it is added per response.
    return {**result, "cached": False}
requirements.txt CHANGED
@@ -2,3 +2,4 @@ fastapi>=0.110
2
  uvicorn[standard]>=0.27
3
  openai>=1.40
4
  huggingface_hub>=0.24
 
 
2
  uvicorn[standard]>=0.27
3
  openai>=1.40
4
  huggingface_hub>=0.24
5
+ httpx>=0.27