"""Precompute /score and /fold outputs so demo.html can render results instantly without round-tripping the inference endpoints. Writes results back into data/genes.json and data/variants.json: - per-gene likelihood `track` (Carbon /score) - per-variant VEP `score` (Carbon /score) - per-gene `fold_example` (Carbon /generate + NVIDIA NIM ESMFold) Usage: python scripts/precompute.py # everything python scripts/precompute.py --folds # only the folding fixtures python scripts/precompute.py --no-folds # skip folding fixtures python scripts/precompute.py --folds --only-missing # only genes lacking fold_example """ import json import os import sys import time import httpx from openai import OpenAI from huggingface_hub import get_token ENDPOINT_URL = os.environ.get( "ENDPOINT_URL", "https://cr2l9w72ys5pp8le.us-east-1.aws.endpoints.huggingface.cloud/v1/", ) MODEL_NAME = os.environ.get("MODEL_NAME", "HuggingFaceBio/Carbon-3B") TRACK_MAX_BP = 24000 # covers TP53 (19,070 bp) with headroom; model max is 32k tokens (~196k bp) # Folding fixture parameters. T=0.9 because lower temperatures push the # model into low-complexity sinks on long-context generation (poly-T at # T<=0.5, AATC repeats at T=0.7), which is a known failure mode of # autoregressive LMs on long DNA — most of the human genome is repetitive # introns/intergenic, and greedy-ish sampling collapses there. T=0.9 # keeps enough exploration to stay in coding-region distribution. NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold" FOLD_PREFIX_LEN = 200 FOLD_TEMPERATURE = 0.9 # Best-of-N retries per gene. Long-context generation (MYC/TP53 need # ~20 kb of DNA) collapses more often than HBB/INS, so we give those # more shots at producing a viable ORF. 
FOLD_BEST_OF_SHORT = 3
FOLD_BEST_OF_LONG = 5
FOLD_LONG_THRESHOLD = 5000  # bp; above this, use BEST_OF_LONG retries

# Hard cap on AA sent to ESMFold — NVIDIA NIM tops out around ~1024 aa
# and we don't want a single oversized Carbon hallucination to fail the
# whole fold step. The reference proteins for our 4 demo genes are all
# well below this anyway (HBB 147, INS 110, MYC 439, TP53 393).
FOLD_MAX_AA = 1000

# The data directory is `<parent of this script's directory>/data`.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(os.path.dirname(HERE), "data")


def _read_json(path):
    """Load and return the JSON document stored at `path`."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


def _write_json(path, obj):
    """Write `obj` to `path` as JSON with a 2-space indent.

    The document is serialised *before* the file is opened for writing, so a
    serialisation error cannot leave a truncated data file behind.
    """
    text = json.dumps(obj, indent=2)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text)


def make_client():
    """Build the OpenAI-compatible client pointed at ENDPOINT_URL.

    The API key is the HF_TOKEN environment variable, falling back to
    whatever `huggingface_hub.get_token()` returns.

    Raises:
        RuntimeError: if neither source yields a token.
    """
    key = os.environ.get("HF_TOKEN") or get_token()
    if not key:
        raise RuntimeError("no HF token (set HF_TOKEN or `huggingface-cli login`)")
    return OpenAI(base_url=ENDPOINT_URL, api_key=key)


def left_pad_to_six(seq):
    """Left-pad `seq` with "A" so its length is a multiple of 6.

    Returns:
        (padded_seq, n_pad) where n_pad is the number of "A" characters that
        were prepended (0 when `seq` is empty or already aligned).
    """
    if not seq:
        return seq, 0
    n_pad = -len(seq) % 6
    return ("A" * n_pad) + seq, n_pad


def score(client, seq):
    """Return {tokens, token_logprobs, pad_bases} for the sequence.

    The sequence is left-padded to a multiple of 6 and sent as the prompt of
    a completions request with `echo=True` and `max_tokens=0`, so the
    returned logprobs describe the prompt itself rather than new text.
    `pad_bases` is the number of padding bases that were prepended.
    """
    seq_padded, pad_bases = left_pad_to_six(seq)
    r = client.completions.create(
        model=MODEL_NAME,
        prompt=seq_padded,
        max_tokens=0,
        echo=True,
        logprobs=5,
        temperature=0,
    )
    lp = r.choices[0].logprobs
    return {
        "tokens": list(lp.tokens),
        "token_logprobs": list(lp.token_logprobs),
        "pad_bases": pad_bases,
    }


def precompute_tracks(client):
    """Score every gene in data/genes.json and store the result as `track`.

    Only the first TRACK_MAX_BP bases of each gene's `seq` are scored; the
    number actually scored is recorded as `track["scored_length"]`. The file
    is rewritten in place once all genes have been processed.
    """
    path = os.path.join(DATA, "genes.json")
    genes = _read_json(path)
    for g in genes:
        seq = g["seq"][:TRACK_MAX_BP]
        print(f" scoring {g['symbol']} ({len(seq)}bp)…", flush=True)
        t0 = time.time()
        try:
            res = score(client, seq)
            res["scored_length"] = len(seq)
            g["track"] = res
            pad = res["pad_bases"]
            print(f" ✓ {len(res['tokens'])} tokens in {time.time()-t0:.1f}s"
                  + (f" (left-padded {pad}bp)" if pad else ""))
        except Exception as e:
            # Keep whatever track was already on this gene; a transient
            # endpoint hiccup shouldn't wipe valid precomputed data.
            print(f" ✗ {e}", file=sys.stderr)
    _write_json(path, genes)
    print(f" wrote {path}")


def sum_lp(lps):
    """Sum the non-None entries of `lps`.

    Returns:
        (total, count): the float sum and how many entries contributed.
        An empty or all-None input gives (0.0, 0).
    """
    total, count = 0.0, 0
    for x in lps:
        if x is not None:
            total += x
            count += 1
    return total, count


def precompute_vep(client):
    """Score ref vs. alt windows for every variant in data/variants.json.

    The alt window is the ref window with the single character at
    `var_offset` replaced by the variant's `alt` string. Each variant gets a
    `score` dict holding both log-prob sums, both per-token log-prob lists,
    the number of scored ref tokens (`n`) and `delta = alt_sum - ref_sum`.
    The file is rewritten in place afterwards.
    """
    path = os.path.join(DATA, "variants.json")
    variants = _read_json(path)
    for v in variants:
        ref = v["ref_window"]
        offset = v["var_offset"]
        alt = ref[:offset] + v["alt"] + ref[offset + 1:]
        print(f" scoring {v['rs']} ({v['name']})…", flush=True)
        try:
            r_ref = score(client, ref)
            r_alt = score(client, alt)
            ref_sum, ref_n = sum_lp(r_ref["token_logprobs"])
            alt_sum, _ = sum_lp(r_alt["token_logprobs"])
            delta = alt_sum - ref_sum
            v["score"] = {
                "ref_sum": ref_sum,
                "alt_sum": alt_sum,
                "ref_logprobs": r_ref["token_logprobs"],
                "alt_logprobs": r_alt["token_logprobs"],
                "n": ref_n,
                "delta": delta,
            }
            print(f" ✓ Δ = {delta:+.2f} (ref {ref_sum:.2f}, alt {alt_sum:.2f})")
        except Exception as e:
            print(f" ✗ {e}", file=sys.stderr)
            # NOTE(review): unlike precompute_tracks, a failure here drops
            # any previously stored score — confirm this is intentional.
            v.pop("score", None)
    _write_json(path, variants)
    print(f" wrote {path}")


# =========================================================================
# Folding fixture: per-gene Carbon continuation + ESMFold structures for
# both the reference and Carbon's spliced mRNA. Keeps the demo's first
# render instant even when the inference endpoints are cold.
# =========================================================================

# Standard genetic code, DNA codon -> one-letter amino acid ("*" = stop).
CODON_TABLE = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
    "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
    "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
    "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
    "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
    "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
    "TAT": "Y", "TAC": "Y", "TAA": "*", "TAG": "*",
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
    "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
    "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "TGT": "C", "TGC": "C", "TGA": "*", "TGG": "W",
    "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
    "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}


def splice_exons(dna, exons):
    """Concatenate exon slices, truncating any exon past the end of dna.

    `exons` is a sequence of dicts with `start`/`end` offsets into `dna`.
    Iteration stops at the first exon whose start lies at or beyond the end
    of `dna`, so later exons are ignored as well.
    """
    parts = []
    for e in exons:
        if e["start"] >= len(dna):
            break
        parts.append(dna[e["start"]:min(e["end"], len(dna))])
    return "".join(parts)


def find_longest_orf(dna, min_aa=30):
    """Mirror of demo.html findLongestORF — scans all 3 frames, returns the
    longest ATG-initiated ORF of at least `min_aa` amino acids.

    Prefers an ORF that ends on a clean stop codon. Falls back to a
    truncated ORF (reached the end of dna with no stop) — that case happens
    when Carbon mutates the canonical stop codon and the translation reads
    into the 3'UTR. Truncated entries are tagged so the UI can surface that
    biological detail.

    An ORF that runs into a codon missing from CODON_TABLE (e.g. one
    containing "N") is discarded.

    Returns:
        A dict {aa, frame, start_bp, end_bp, truncated}, or None when no ORF
        qualifies. `end_bp` is the offset of the stop codon, or of the first
        incomplete codon for a truncated ORF.
    """
    best_clean = None
    best_trunc = None
    for frame in range(3):
        i = frame
        while i + 3 <= len(dna):
            if dna[i:i + 3] != "ATG":
                i += 3
                continue
            aa = []
            j = i
            stopped = False
            invalid = False
            while j + 3 <= len(dna):
                a = CODON_TABLE.get(dna[j:j + 3])
                if a is None:
                    invalid = True
                    break
                if a == "*":
                    stopped = True
                    break
                aa.append(a)
                j += 3
            if not invalid and len(aa) >= min_aa:
                entry = {
                    "aa": "".join(aa),
                    "frame": frame,
                    "start_bp": i,
                    "end_bp": j,
                    "truncated": not stopped,
                }
                # Strict ">" keeps the first ORF found among equal lengths.
                if stopped:
                    if best_clean is None or len(aa) > len(best_clean["aa"]):
                        best_clean = entry
                else:
                    if best_trunc is None or len(aa) > len(best_trunc["aa"]):
                        best_trunc = entry
            i += 3
    return best_clean or best_trunc


BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")


def _parse_generate_stream(text):
    """Extract the generated DNA from a /generate SSE response body.

    Every `data:` line is decoded as JSON; its `text` field is appended to
    the output. The result is upper-cased and stripped of everything that
    is not A/C/G/T.

    Raises:
        RuntimeError: if any payload carries an `error` field.
    """
    pieces = []
    for line in text.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        body = line[5:].strip()
        # Defensive: tolerate an empty data line or an OpenAI-style "[DONE]"
        # sentinel instead of crashing in json.loads.
        if not body or body == "[DONE]":
            continue
        payload = json.loads(body)
        if "error" in payload:
            raise RuntimeError(str(payload["error"]))
        pieces.append(payload.get("text") or "")
    joined = "".join(pieces).upper()
    return "".join(c for c in joined if c in "ACGT")


def carbon_continue(_client, prompt_dna, max_tokens, temperature,
                    max_503_retries=10, retry_wait_s=15):
    """Ask Carbon to continue a DNA prompt and return the cleaned continuation.

    Calls the backend /generate (SSE-streamed) rather than the OpenAI client
    directly. Going through /generate guarantees the precompute follows the
    exact same pipeline as the live demo (left-padding to a multiple of 6,
    prefix, streaming framing) so the cached example is identical to what
    runFold() would produce. `_client` is accepted for signature
    compatibility and is not used.

    The HF Inference Endpoint is configured to scale-to-zero after a few
    hours idle, so the first call after a cold period bubbles up a 503 from
    upstream. We wait the endpoint out with a fixed backoff instead of
    giving up — subsequent calls in the same session hit warm pods and
    return in seconds.

    Args:
        prompt_dna: DNA prefix to continue.
        max_tokens: generation budget forwarded to the backend.
        temperature: sampling temperature forwarded to the backend.
        max_503_retries: how many times a 503 is retried before giving up.
        retry_wait_s: seconds slept between 503 retries.

    Returns:
        The continuation as an upper-case string containing only A/C/G/T.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response that is not a retried 503.
        RuntimeError: when the stream carries an `error` payload that is not
            a retried 503.
    """
    last_err = None
    for attempt in range(max_503_retries + 1):
        try:
            with httpx.Client(timeout=300) as cx:
                r = cx.post(
                    f"{BACKEND_URL}/generate",
                    json={
                        "prompt": prompt_dna,
                        "max_tokens": max_tokens,
                        "temperature": temperature,
                        "top_p": 1.0,
                    },
                )
                r.raise_for_status()
                return _parse_generate_stream(r.text)
        except (httpx.HTTPStatusError, RuntimeError) as e:
            last_err = e
            # A 503 can arrive as an HTTP status or relayed inside the
            # stream's error payload, so match on the message text.
            if "503" in str(e) and attempt < max_503_retries:
                print(f" … HF endpoint cold, waiting {retry_wait_s}s "
                      f"(attempt {attempt+1}/{max_503_retries+1})", flush=True)
                time.sleep(retry_wait_s)
                continue
            raise
    raise last_err if last_err else RuntimeError("carbon_continue: unreachable")


def _ca_plddts(pdb):
    """Return the value in columns 61-66 of every CA ATOM record in `pdb`.

    Records whose field does not parse as a float are skipped.
    """
    values = []
    for line in pdb.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            try:
                values.append(float(line[60:66].strip()))
            except ValueError:
                pass
    return values


def _fmt_plddt(value):
    """Format a mean pLDDT for logging; `None` (no CA atoms parsed) -> "n/a"."""
    return "n/a" if value is None else f"{value:.1f}"


def nim_fold(api_key, sequence):
    """Call NVIDIA NIM ESMFold and return {pdb, plddt_mean, n_residues}.

    Mirrors app.py's /fold logic — same endpoint, same JSON shape. The PDB
    text is taken from the first entry of `pdbs`, falling back to `pdb`.
    `n_residues` counts the CA atoms parsed from it and `plddt_mean` is the
    mean of their per-residue values, or None when no CA atom was parsed.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        RuntimeError: if the response contains no PDB text.
    """
    with httpx.Client(timeout=180) as cx:
        r = cx.post(
            NIM_FOLD_URL,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            },
            json={"sequence": sequence},
        )
        r.raise_for_status()
        data = r.json()
    pdb = (data.get("pdbs") or [None])[0] or data.get("pdb") or ""
    if not pdb:
        raise RuntimeError("NIM returned no PDB")
    plddts = _ca_plddts(pdb)
    return {
        "pdb": pdb,
        "n_residues": len(plddts),
        "plddt_mean": (sum(plddts) / len(plddts)) if plddts else None,
    }


def precompute_folds(client, only_missing=False):
    """Build the `fold_example` fixture for every gene that has exons.

    For each gene: splice the reference exons and take its longest ORF, ask
    Carbon to continue the first FOLD_PREFIX_LEN bases (best-of-N, keeping
    the longest ORF), fold both proteins with ESMFold and store sequences,
    PDBs, mean pLDDTs and positional identity on the gene. A failure on one
    gene is logged and leaves its previous `fold_example` untouched.
    data/genes.json is rewritten in place at the end.

    Args:
        client: forwarded to carbon_continue.
        only_missing: skip genes that already have a `fold_example`.

    Raises:
        RuntimeError: if NVIDIA_API_KEY is not set.
    """
    api_key = os.environ.get("NVIDIA_API_KEY")
    if not api_key:
        raise RuntimeError("NVIDIA_API_KEY missing (set in .env or env)")
    path = os.path.join(DATA, "genes.json")
    with open(path, encoding="utf-8") as fh:
        genes = json.load(fh)
    for g in genes:
        if not g.get("exons"):
            continue
        if only_missing and g.get("fold_example"):
            print(f" skipping {g['symbol']} (fold_example already cached)")
            continue
        last_exon_end = g["exons"][-1]["end"]
        n_tries = FOLD_BEST_OF_LONG if last_exon_end > FOLD_LONG_THRESHOLD else FOLD_BEST_OF_SHORT
        print(f" folding {g['symbol']} (last exon end {last_exon_end} bp, best-of-{n_tries})…", flush=True)
        try:
            seq = g["seq"].upper()
            ref_mrna = splice_exons(seq, g["exons"])
            ref_orf = find_longest_orf(ref_mrna, 30)
            if not ref_orf:
                raise RuntimeError("reference has no valid ORF after splice")

            prompt = "".join(c for c in seq[:FOLD_PREFIX_LEN] if c in "ACGT")
            # Generate up to the end of the last exon plus 60 bp of slack;
            # the token budget is gen_bp // 6 plus 8 spare tokens.
            gen_bp = max(0, last_exon_end - FOLD_PREFIX_LEN) + 60
            max_tokens = (gen_bp // 6) + 8

            # Best-of-N: at T>0 the model's continuation can stop early on
            # a fluky premature codon. Try a few times and keep the longest
            # ORF — closer to "what Carbon usually produces for this gene".
            carbon_orf = None
            for attempt in range(n_tries):
                t0 = time.time()
                cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
                carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
                carbon_mrna = splice_exons(carbon_dna, g["exons"])
                orf = find_longest_orf(carbon_mrna, 30)
                orf_len = len(orf["aa"]) if orf else 0
                print(f" carbon try {attempt+1}/{n_tries}: ORF {orf_len} aa ({time.time()-t0:.1f}s)")
                if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
                    carbon_orf = orf
            if not carbon_orf:
                raise RuntimeError(f"Carbon's spliced mRNA has no valid ORF after {n_tries} tries")

            # Clamp to NIM's ~1024 aa ceiling. The reference proteins are
            # all well below this; Carbon hallucinations can occasionally
            # exceed it after a mutated stop codon, in which case we just
            # fold the first FOLD_MAX_AA aa to keep the pipeline robust.
            ref_aa = ref_orf["aa"][:FOLD_MAX_AA]
            carbon_aa = carbon_orf["aa"][:FOLD_MAX_AA]

            # plddt_mean may be None (no CA atoms parsed); _fmt_plddt keeps
            # the log line from raising and discarding a successful fold.
            t0 = time.time()
            ref_fold = nim_fold(api_key, ref_aa)
            print(f" ref fold: {ref_fold['n_residues']} aa, pLDDT {_fmt_plddt(ref_fold['plddt_mean'])} ({time.time()-t0:.1f}s)")
            t0 = time.time()
            carbon_fold = nim_fold(api_key, carbon_aa)
            print(f" carbon fold: {carbon_fold['n_residues']} aa, pLDDT {_fmt_plddt(carbon_fold['plddt_mean'])} ({time.time()-t0:.1f}s)")

            # Position-by-position identity over the shorter of the two
            # sequences (no alignment is performed).
            overlap = min(len(carbon_aa), len(ref_aa))
            matches = sum(1 for a, b in zip(carbon_aa, ref_aa) if a == b)
            identity = matches / overlap if overlap else 0.0

            g["fold_example"] = {
                "prefix_len": FOLD_PREFIX_LEN,
                "temperature": FOLD_TEMPERATURE,
                "carbon_aa": carbon_aa,
                "ref_aa": ref_aa,
                "carbon_pdb": carbon_fold["pdb"],
                "ref_pdb": ref_fold["pdb"],
                "carbon_plddt_mean": carbon_fold["plddt_mean"],
                "ref_plddt_mean": ref_fold["plddt_mean"],
                "identity_1d": identity,
            }
            print(f" ✓ identity {identity*100:.1f}% ({len(carbon_aa)}/{len(ref_aa)} aa)")
        except Exception as e:
            print(f" ✗ {e} (keeping previous fold_example if any)", file=sys.stderr)
    # Serialise before opening for write so a serialisation error cannot
    # leave genes.json truncated.
    serialised = json.dumps(genes, indent=2)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(serialised)
    print(f" wrote {path}")


def main():
    """Command-line entry point.

    Flags (read from sys.argv, order-independent):
        --folds         skip the track and VEP steps
        --no-folds      skip the fold step
        --only-missing  in the fold step, skip genes that already have
                        a fold_example
    """
    argv = set(sys.argv[1:])
    only_folds = "--folds" in argv
    skip_folds = "--no-folds" in argv
    only_missing = "--only-missing" in argv

    client = make_client()
    if not only_folds:
        print("=== precomputing likelihood tracks ===")
        precompute_tracks(client)
        print()
        print("=== precomputing VEP scores ===")
        precompute_vep(client)
    if not skip_folds:
        print()
        print("=== precomputing fold fixtures ===")
        precompute_folds(client, only_missing=only_missing)


if __name__ == "__main__":
    main()