tfrere HF Staff Cursor commited on
Commit
7c8a9cd
·
1 Parent(s): 45f472e

§5 Folding: precompute.py — best-of-N tiering + 503 cold-start retries

Browse files

Long-context generation (MYC, TP53 ~20 kb) collapses more often than
HBB/INS, so use 5 retries above 5 kb and 3 below. Also tolerate the HF
Inference Endpoint scale-to-zero cold start by retrying 503s with a
short backoff, and cap AA sent to ESMFold at 1000 to avoid one stray
oversized hallucination breaking the whole fold step.

Co-authored-by: Cursor <cursoragent@cursor.com>

Files changed (1) hide show
  1. scripts/precompute.py +72 -36
scripts/precompute.py CHANGED
@@ -33,8 +33,17 @@ TRACK_MAX_BP = 6000
33
  NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"
34
  FOLD_PREFIX_LEN = 200
35
  FOLD_TEMPERATURE = 0.9
36
- FOLD_BEST_OF = 3
37
- FOLD_MAX_GENOMIC_BP = 2500 # genes whose last exon sits past this are skipped
 
 
 
 
 
 
 
 
 
38
 
39
  HERE = os.path.dirname(os.path.abspath(__file__))
40
  DATA = os.path.join(os.path.dirname(HERE), "data")
@@ -203,7 +212,8 @@ def find_longest_orf(dna, min_aa=30):
203
  BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")
204
 
205
 
206
- def carbon_continue(_client, prompt_dna, max_tokens, temperature):
 
207
  """Ask Carbon to continue a DNA prompt and return the cleaned continuation.
208
 
209
  Calls the backend /generate (SSE-streamed) rather than the OpenAI
@@ -211,26 +221,48 @@ def carbon_continue(_client, prompt_dna, max_tokens, temperature):
211
  follows the exact same pipeline as the live demo (left-padding to a
212
  multiple of 6, <dna> prefix, streaming framing) so the cached example
213
  is identical to what runFold() would produce.
 
 
 
 
 
 
214
  """
215
- with httpx.Client(timeout=180) as cx:
216
- r = cx.post(
217
- f"{BACKEND_URL}/generate",
218
- json={"prompt": prompt_dna, "max_tokens": max_tokens,
219
- "temperature": temperature, "top_p": 1.0},
220
- )
221
- r.raise_for_status()
222
- out = []
223
- for line in r.text.splitlines():
224
- line = line.strip()
225
- if not line.startswith("data:"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  continue
227
- payload = json.loads(line[5:].strip())
228
- if "error" in payload:
229
- raise RuntimeError(payload["error"])
230
- t = payload.get("text") or ""
231
- out.append(t)
232
- text = ("".join(out)).upper()
233
- return "".join(c for c in text if c in "ACGT")
234
 
235
 
236
  def nim_fold(api_key, sequence):
@@ -275,11 +307,8 @@ def precompute_folds(client):
275
  if not g.get("exons"):
276
  continue
277
  last_exon_end = g["exons"][-1]["end"]
278
- if last_exon_end > FOLD_MAX_GENOMIC_BP:
279
- print(f" skip {g['symbol']}: last exon ends at {last_exon_end} bp (> {FOLD_MAX_GENOMIC_BP})")
280
- g.pop("fold_example", None)
281
- continue
282
- print(f" folding {g['symbol']} (last exon end {last_exon_end} bp)…", flush=True)
283
  try:
284
  seq = g["seq"].upper()
285
  ref_mrna = splice_exons(seq, g["exons"])
@@ -295,41 +324,48 @@ def precompute_folds(client):
295
  # a fluky premature codon. Try a few times and keep the longest
296
  # ORF — closer to "what Carbon usually produces for this gene".
297
  carbon_orf = None
298
- for attempt in range(FOLD_BEST_OF):
299
  t0 = time.time()
300
  cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
301
  carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
302
  carbon_mrna = splice_exons(carbon_dna, g["exons"])
303
  orf = find_longest_orf(carbon_mrna, 30)
304
  n = len(orf["aa"]) if orf else 0
305
- print(f" carbon try {attempt+1}/{FOLD_BEST_OF}: ORF {n} aa ({time.time()-t0:.1f}s)")
306
  if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
307
  carbon_orf = orf
308
  if not carbon_orf:
309
- raise RuntimeError("Carbon's spliced mRNA has no valid ORF after %d tries" % FOLD_BEST_OF)
 
 
 
 
 
 
 
310
 
311
  t0 = time.time()
312
- ref_fold = nim_fold(api_key, ref_orf["aa"])
313
  print(f" ref fold: {ref_fold['n_residues']} aa, pLDDT {ref_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
314
  t0 = time.time()
315
- carbon_fold = nim_fold(api_key, carbon_orf["aa"])
316
  print(f" carbon fold: {carbon_fold['n_residues']} aa, pLDDT {carbon_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
317
 
318
- n = min(len(carbon_orf["aa"]), len(ref_orf["aa"]))
319
- identity = sum(1 for i in range(n) if carbon_orf["aa"][i] == ref_orf["aa"][i]) / n if n else 0.0
320
 
321
  g["fold_example"] = {
322
  "prefix_len": FOLD_PREFIX_LEN,
323
  "temperature": FOLD_TEMPERATURE,
324
- "carbon_aa": carbon_orf["aa"],
325
- "ref_aa": ref_orf["aa"],
326
  "carbon_pdb": carbon_fold["pdb"],
327
  "ref_pdb": ref_fold["pdb"],
328
  "carbon_plddt_mean": carbon_fold["plddt_mean"],
329
  "ref_plddt_mean": ref_fold["plddt_mean"],
330
  "identity_1d": identity,
331
  }
332
- print(f" ✓ identity {identity*100:.1f}% ({len(carbon_orf['aa'])}/{len(ref_orf['aa'])} aa)")
333
  except Exception as e:
334
  print(f" ✗ {e} (keeping previous fold_example if any)", file=sys.stderr)
335
  json.dump(genes, open(path, "w"), indent=2)
 
33
  NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"
34
  FOLD_PREFIX_LEN = 200
35
  FOLD_TEMPERATURE = 0.9
36
+ # Best-of-N retries per gene. Long-context generation (MYC/TP53 need
37
+ # ~20 kb of DNA) collapses more often than HBB/INS, so we give those
38
+ # more shots at producing a viable ORF.
39
+ FOLD_BEST_OF_SHORT = 3
40
+ FOLD_BEST_OF_LONG = 5
41
+ FOLD_LONG_THRESHOLD = 5000 # bp; above this, use BEST_OF_LONG retries
42
+ # Hard cap on AA sent to ESMFold — NVIDIA NIM tops out around ~1024 aa
43
+ # and we don't want a single oversized Carbon hallucination to fail the
44
+ # whole fold step. The reference proteins for our 4 demo genes are all
45
+ # well below this anyway (HBB 147, INS 110, MYC 439, TP53 393).
46
+ FOLD_MAX_AA = 1000
47
 
48
  HERE = os.path.dirname(os.path.abspath(__file__))
49
  DATA = os.path.join(os.path.dirname(HERE), "data")
 
212
  BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")
213
 
214
 
215
+ def carbon_continue(_client, prompt_dna, max_tokens, temperature,
216
+ max_503_retries=10, retry_wait_s=15):
217
  """Ask Carbon to continue a DNA prompt and return the cleaned continuation.
218
 
219
  Calls the backend /generate (SSE-streamed) rather than the OpenAI
 
221
  follows the exact same pipeline as the live demo (left-padding to a
222
  multiple of 6, <dna> prefix, streaming framing) so the cached example
223
  is identical to what runFold() would produce.
224
+
225
+ The HF Inference Endpoint is configured to scale-to-zero after a few
226
+ hours idle, so the first call after a cold period bubbles up a 503
227
+ from upstream. We wait the endpoint out with a fixed backoff instead
228
+ of giving up — subsequent calls in the same session hit warm pods
229
+ and return in seconds.
230
  """
231
+ last_err = None
232
+ for attempt in range(max_503_retries + 1):
233
+ try:
234
+ with httpx.Client(timeout=300) as cx:
235
+ r = cx.post(
236
+ f"{BACKEND_URL}/generate",
237
+ json={"prompt": prompt_dna, "max_tokens": max_tokens,
238
+ "temperature": temperature, "top_p": 1.0},
239
+ )
240
+ r.raise_for_status()
241
+ out = []
242
+ for line in r.text.splitlines():
243
+ line = line.strip()
244
+ if not line.startswith("data:"):
245
+ continue
246
+ payload = json.loads(line[5:].strip())
247
+ if "error" in payload:
248
+ msg = str(payload["error"])
249
+ if "503" in msg and attempt < max_503_retries:
250
+ raise RuntimeError(msg)
251
+ raise RuntimeError(msg)
252
+ t = payload.get("text") or ""
253
+ out.append(t)
254
+ text = ("".join(out)).upper()
255
+ return "".join(c for c in text if c in "ACGT")
256
+ except (httpx.HTTPStatusError, RuntimeError) as e:
257
+ msg = str(e)
258
+ last_err = e
259
+ if "503" in msg and attempt < max_503_retries:
260
+ print(f" … HF endpoint cold, waiting {retry_wait_s}s "
261
+ f"(attempt {attempt+1}/{max_503_retries+1})", flush=True)
262
+ time.sleep(retry_wait_s)
263
  continue
264
+ raise
265
+ raise last_err if last_err else RuntimeError("carbon_continue: unreachable")
 
 
 
 
 
266
 
267
 
268
  def nim_fold(api_key, sequence):
 
307
  if not g.get("exons"):
308
  continue
309
  last_exon_end = g["exons"][-1]["end"]
310
+ n_tries = FOLD_BEST_OF_LONG if last_exon_end > FOLD_LONG_THRESHOLD else FOLD_BEST_OF_SHORT
311
+ print(f" folding {g['symbol']} (last exon end {last_exon_end} bp, best-of-{n_tries})���", flush=True)
 
 
 
312
  try:
313
  seq = g["seq"].upper()
314
  ref_mrna = splice_exons(seq, g["exons"])
 
324
  # a fluky premature codon. Try a few times and keep the longest
325
  # ORF — closer to "what Carbon usually produces for this gene".
326
  carbon_orf = None
327
+ for attempt in range(n_tries):
328
  t0 = time.time()
329
  cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
330
  carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
331
  carbon_mrna = splice_exons(carbon_dna, g["exons"])
332
  orf = find_longest_orf(carbon_mrna, 30)
333
  n = len(orf["aa"]) if orf else 0
334
+ print(f" carbon try {attempt+1}/{n_tries}: ORF {n} aa ({time.time()-t0:.1f}s)")
335
  if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
336
  carbon_orf = orf
337
  if not carbon_orf:
338
+ raise RuntimeError("Carbon's spliced mRNA has no valid ORF after %d tries" % n_tries)
339
+
340
+ # Clamp to NIM's ~1024 aa ceiling. The reference proteins are
341
+ # all well below this; Carbon hallucinations can occasionally
342
+ # exceed it after a mutated stop codon, in which case we just
343
+ # fold the first FOLD_MAX_AA aa to keep the pipeline robust.
344
+ ref_aa = ref_orf["aa"][:FOLD_MAX_AA]
345
+ carbon_aa = carbon_orf["aa"][:FOLD_MAX_AA]
346
 
347
  t0 = time.time()
348
+ ref_fold = nim_fold(api_key, ref_aa)
349
  print(f" ref fold: {ref_fold['n_residues']} aa, pLDDT {ref_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
350
  t0 = time.time()
351
+ carbon_fold = nim_fold(api_key, carbon_aa)
352
  print(f" carbon fold: {carbon_fold['n_residues']} aa, pLDDT {carbon_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
353
 
354
+ n = min(len(carbon_aa), len(ref_aa))
355
+ identity = sum(1 for i in range(n) if carbon_aa[i] == ref_aa[i]) / n if n else 0.0
356
 
357
  g["fold_example"] = {
358
  "prefix_len": FOLD_PREFIX_LEN,
359
  "temperature": FOLD_TEMPERATURE,
360
+ "carbon_aa": carbon_aa,
361
+ "ref_aa": ref_aa,
362
  "carbon_pdb": carbon_fold["pdb"],
363
  "ref_pdb": ref_fold["pdb"],
364
  "carbon_plddt_mean": carbon_fold["plddt_mean"],
365
  "ref_plddt_mean": ref_fold["plddt_mean"],
366
  "identity_1d": identity,
367
  }
368
+ print(f" ✓ identity {identity*100:.1f}% ({len(carbon_aa)}/{len(ref_aa)} aa)")
369
  except Exception as e:
370
  print(f" ✗ {e} (keeping previous fold_example if any)", file=sys.stderr)
371
  json.dump(genes, open(path, "w"), indent=2)