Spaces:
Running
Running
§5 Folding: precompute.py — best-of-N tiering + 503 cold-start retries
Browse files
Long-context generation (MYC, TP53 ~20 kb) collapses more often than
HBB/INS, so use 5 retries above 5 kb and 3 below. Also tolerate the HF
Inference Endpoint scale-to-zero cold start by retrying 503s with a
short backoff, and cap AA sent to ESMFold at 1000 to avoid one stray
oversized hallucination breaking the whole fold step.
Co-authored-by: Cursor <cursoragent@cursor.com>
- scripts/precompute.py +72 -36
scripts/precompute.py
CHANGED
|
@@ -33,8 +33,17 @@ TRACK_MAX_BP = 6000
|
|
| 33 |
NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"
|
| 34 |
FOLD_PREFIX_LEN = 200
|
| 35 |
FOLD_TEMPERATURE = 0.9
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
HERE = os.path.dirname(os.path.abspath(__file__))
|
| 40 |
DATA = os.path.join(os.path.dirname(HERE), "data")
|
|
@@ -203,7 +212,8 @@ def find_longest_orf(dna, min_aa=30):
|
|
| 203 |
BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")
|
| 204 |
|
| 205 |
|
| 206 |
-
def carbon_continue(_client, prompt_dna, max_tokens, temperature
|
|
|
|
| 207 |
"""Ask Carbon to continue a DNA prompt and return the cleaned continuation.
|
| 208 |
|
| 209 |
Calls the backend /generate (SSE-streamed) rather than the OpenAI
|
|
@@ -211,26 +221,48 @@ def carbon_continue(_client, prompt_dna, max_tokens, temperature):
|
|
| 211 |
follows the exact same pipeline as the live demo (left-padding to a
|
| 212 |
multiple of 6, <dna> prefix, streaming framing) so the cached example
|
| 213 |
is identical to what runFold() would produce.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"""
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
continue
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
raise RuntimeError(payload["error"])
|
| 230 |
-
t = payload.get("text") or ""
|
| 231 |
-
out.append(t)
|
| 232 |
-
text = ("".join(out)).upper()
|
| 233 |
-
return "".join(c for c in text if c in "ACGT")
|
| 234 |
|
| 235 |
|
| 236 |
def nim_fold(api_key, sequence):
|
|
@@ -275,11 +307,8 @@ def precompute_folds(client):
|
|
| 275 |
if not g.get("exons"):
|
| 276 |
continue
|
| 277 |
last_exon_end = g["exons"][-1]["end"]
|
| 278 |
-
if last_exon_end >
|
| 279 |
-
|
| 280 |
-
g.pop("fold_example", None)
|
| 281 |
-
continue
|
| 282 |
-
print(f" folding {g['symbol']} (last exon end {last_exon_end} bp)…", flush=True)
|
| 283 |
try:
|
| 284 |
seq = g["seq"].upper()
|
| 285 |
ref_mrna = splice_exons(seq, g["exons"])
|
|
@@ -295,41 +324,48 @@ def precompute_folds(client):
|
|
| 295 |
# a fluky premature codon. Try a few times and keep the longest
|
| 296 |
# ORF — closer to "what Carbon usually produces for this gene".
|
| 297 |
carbon_orf = None
|
| 298 |
-
for attempt in range(
|
| 299 |
t0 = time.time()
|
| 300 |
cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
|
| 301 |
carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
|
| 302 |
carbon_mrna = splice_exons(carbon_dna, g["exons"])
|
| 303 |
orf = find_longest_orf(carbon_mrna, 30)
|
| 304 |
n = len(orf["aa"]) if orf else 0
|
| 305 |
-
print(f" carbon try {attempt+1}/{
|
| 306 |
if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
|
| 307 |
carbon_orf = orf
|
| 308 |
if not carbon_orf:
|
| 309 |
-
raise RuntimeError("Carbon's spliced mRNA has no valid ORF after %d tries" %
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
t0 = time.time()
|
| 312 |
-
ref_fold = nim_fold(api_key,
|
| 313 |
print(f" ref fold: {ref_fold['n_residues']} aa, pLDDT {ref_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
|
| 314 |
t0 = time.time()
|
| 315 |
-
carbon_fold = nim_fold(api_key,
|
| 316 |
print(f" carbon fold: {carbon_fold['n_residues']} aa, pLDDT {carbon_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
|
| 317 |
|
| 318 |
-
n = min(len(
|
| 319 |
-
identity = sum(1 for i in range(n) if
|
| 320 |
|
| 321 |
g["fold_example"] = {
|
| 322 |
"prefix_len": FOLD_PREFIX_LEN,
|
| 323 |
"temperature": FOLD_TEMPERATURE,
|
| 324 |
-
"carbon_aa":
|
| 325 |
-
"ref_aa":
|
| 326 |
"carbon_pdb": carbon_fold["pdb"],
|
| 327 |
"ref_pdb": ref_fold["pdb"],
|
| 328 |
"carbon_plddt_mean": carbon_fold["plddt_mean"],
|
| 329 |
"ref_plddt_mean": ref_fold["plddt_mean"],
|
| 330 |
"identity_1d": identity,
|
| 331 |
}
|
| 332 |
-
print(f" ✓ identity {identity*100:.1f}% ({len(
|
| 333 |
except Exception as e:
|
| 334 |
print(f" ✗ {e} (keeping previous fold_example if any)", file=sys.stderr)
|
| 335 |
json.dump(genes, open(path, "w"), indent=2)
|
|
|
|
| 33 |
NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"
|
| 34 |
FOLD_PREFIX_LEN = 200
|
| 35 |
FOLD_TEMPERATURE = 0.9
|
| 36 |
+
# Best-of-N retries per gene. Long-context generation (MYC/TP53 need
|
| 37 |
+
# ~20 kb of DNA) collapses more often than HBB/INS, so we give those
|
| 38 |
+
# more shots at producing a viable ORF.
|
| 39 |
+
FOLD_BEST_OF_SHORT = 3
|
| 40 |
+
FOLD_BEST_OF_LONG = 5
|
| 41 |
+
FOLD_LONG_THRESHOLD = 5000 # bp; above this, use BEST_OF_LONG retries
|
| 42 |
+
# Hard cap on AA sent to ESMFold — NVIDIA NIM tops out around ~1024 aa
|
| 43 |
+
# and we don't want a single oversized Carbon hallucination to fail the
|
| 44 |
+
# whole fold step. The reference proteins for our 4 demo genes are all
|
| 45 |
+
# well below this anyway (HBB 147, INS 110, MYC 439, TP53 393).
|
| 46 |
+
FOLD_MAX_AA = 1000
|
| 47 |
|
| 48 |
HERE = os.path.dirname(os.path.abspath(__file__))
|
| 49 |
DATA = os.path.join(os.path.dirname(HERE), "data")
|
|
|
|
| 212 |
BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")
|
| 213 |
|
| 214 |
|
| 215 |
+
def carbon_continue(_client, prompt_dna, max_tokens, temperature,
|
| 216 |
+
max_503_retries=10, retry_wait_s=15):
|
| 217 |
"""Ask Carbon to continue a DNA prompt and return the cleaned continuation.
|
| 218 |
|
| 219 |
Calls the backend /generate (SSE-streamed) rather than the OpenAI
|
|
|
|
| 221 |
follows the exact same pipeline as the live demo (left-padding to a
|
| 222 |
multiple of 6, <dna> prefix, streaming framing) so the cached example
|
| 223 |
is identical to what runFold() would produce.
|
| 224 |
+
|
| 225 |
+
The HF Inference Endpoint is configured to scale-to-zero after a few
|
| 226 |
+
hours idle, so the first call after a cold period bubbles up a 503
|
| 227 |
+
from upstream. We wait the endpoint out with a fixed backoff instead
|
| 228 |
+
of giving up — subsequent calls in the same session hit warm pods
|
| 229 |
+
and return in seconds.
|
| 230 |
"""
|
| 231 |
+
last_err = None
|
| 232 |
+
for attempt in range(max_503_retries + 1):
|
| 233 |
+
try:
|
| 234 |
+
with httpx.Client(timeout=300) as cx:
|
| 235 |
+
r = cx.post(
|
| 236 |
+
f"{BACKEND_URL}/generate",
|
| 237 |
+
json={"prompt": prompt_dna, "max_tokens": max_tokens,
|
| 238 |
+
"temperature": temperature, "top_p": 1.0},
|
| 239 |
+
)
|
| 240 |
+
r.raise_for_status()
|
| 241 |
+
out = []
|
| 242 |
+
for line in r.text.splitlines():
|
| 243 |
+
line = line.strip()
|
| 244 |
+
if not line.startswith("data:"):
|
| 245 |
+
continue
|
| 246 |
+
payload = json.loads(line[5:].strip())
|
| 247 |
+
if "error" in payload:
|
| 248 |
+
msg = str(payload["error"])
|
| 249 |
+
if "503" in msg and attempt < max_503_retries:
|
| 250 |
+
raise RuntimeError(msg)
|
| 251 |
+
raise RuntimeError(msg)
|
| 252 |
+
t = payload.get("text") or ""
|
| 253 |
+
out.append(t)
|
| 254 |
+
text = ("".join(out)).upper()
|
| 255 |
+
return "".join(c for c in text if c in "ACGT")
|
| 256 |
+
except (httpx.HTTPStatusError, RuntimeError) as e:
|
| 257 |
+
msg = str(e)
|
| 258 |
+
last_err = e
|
| 259 |
+
if "503" in msg and attempt < max_503_retries:
|
| 260 |
+
print(f" … HF endpoint cold, waiting {retry_wait_s}s "
|
| 261 |
+
f"(attempt {attempt+1}/{max_503_retries+1})", flush=True)
|
| 262 |
+
time.sleep(retry_wait_s)
|
| 263 |
continue
|
| 264 |
+
raise
|
| 265 |
+
raise last_err if last_err else RuntimeError("carbon_continue: unreachable")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
|
| 268 |
def nim_fold(api_key, sequence):
|
|
|
|
| 307 |
if not g.get("exons"):
|
| 308 |
continue
|
| 309 |
last_exon_end = g["exons"][-1]["end"]
|
| 310 |
+
n_tries = FOLD_BEST_OF_LONG if last_exon_end > FOLD_LONG_THRESHOLD else FOLD_BEST_OF_SHORT
|
| 311 |
+
print(f" folding {g['symbol']} (last exon end {last_exon_end} bp, best-of-{n_tries})…", flush=True)
|
|
|
|
|
|
|
|
|
|
| 312 |
try:
|
| 313 |
seq = g["seq"].upper()
|
| 314 |
ref_mrna = splice_exons(seq, g["exons"])
|
|
|
|
| 324 |
# a fluky premature codon. Try a few times and keep the longest
|
| 325 |
# ORF — closer to "what Carbon usually produces for this gene".
|
| 326 |
carbon_orf = None
|
| 327 |
+
for attempt in range(n_tries):
|
| 328 |
t0 = time.time()
|
| 329 |
cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
|
| 330 |
carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
|
| 331 |
carbon_mrna = splice_exons(carbon_dna, g["exons"])
|
| 332 |
orf = find_longest_orf(carbon_mrna, 30)
|
| 333 |
n = len(orf["aa"]) if orf else 0
|
| 334 |
+
print(f" carbon try {attempt+1}/{n_tries}: ORF {n} aa ({time.time()-t0:.1f}s)")
|
| 335 |
if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
|
| 336 |
carbon_orf = orf
|
| 337 |
if not carbon_orf:
|
| 338 |
+
raise RuntimeError("Carbon's spliced mRNA has no valid ORF after %d tries" % n_tries)
|
| 339 |
+
|
| 340 |
+
# Clamp to NIM's ~1024 aa ceiling. The reference proteins are
|
| 341 |
+
# all well below this; Carbon hallucinations can occasionally
|
| 342 |
+
# exceed it after a mutated stop codon, in which case we just
|
| 343 |
+
# fold the first FOLD_MAX_AA aa to keep the pipeline robust.
|
| 344 |
+
ref_aa = ref_orf["aa"][:FOLD_MAX_AA]
|
| 345 |
+
carbon_aa = carbon_orf["aa"][:FOLD_MAX_AA]
|
| 346 |
|
| 347 |
t0 = time.time()
|
| 348 |
+
ref_fold = nim_fold(api_key, ref_aa)
|
| 349 |
print(f" ref fold: {ref_fold['n_residues']} aa, pLDDT {ref_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
|
| 350 |
t0 = time.time()
|
| 351 |
+
carbon_fold = nim_fold(api_key, carbon_aa)
|
| 352 |
print(f" carbon fold: {carbon_fold['n_residues']} aa, pLDDT {carbon_fold['plddt_mean']:.1f} ({time.time()-t0:.1f}s)")
|
| 353 |
|
| 354 |
+
n = min(len(carbon_aa), len(ref_aa))
|
| 355 |
+
identity = sum(1 for i in range(n) if carbon_aa[i] == ref_aa[i]) / n if n else 0.0
|
| 356 |
|
| 357 |
g["fold_example"] = {
|
| 358 |
"prefix_len": FOLD_PREFIX_LEN,
|
| 359 |
"temperature": FOLD_TEMPERATURE,
|
| 360 |
+
"carbon_aa": carbon_aa,
|
| 361 |
+
"ref_aa": ref_aa,
|
| 362 |
"carbon_pdb": carbon_fold["pdb"],
|
| 363 |
"ref_pdb": ref_fold["pdb"],
|
| 364 |
"carbon_plddt_mean": carbon_fold["plddt_mean"],
|
| 365 |
"ref_plddt_mean": ref_fold["plddt_mean"],
|
| 366 |
"identity_1d": identity,
|
| 367 |
}
|
| 368 |
+
print(f" ✓ identity {identity*100:.1f}% ({len(carbon_aa)}/{len(ref_aa)} aa)")
|
| 369 |
except Exception as e:
|
| 370 |
print(f" ✗ {e} (keeping previous fold_example if any)", file=sys.stderr)
|
| 371 |
json.dump(genes, open(path, "w"), indent=2)
|