File size: 16,469 Bytes
2c4851e
 
 
 
 
 
 
971b586
 
b64beb5
 
 
 
971b586
 
 
 
 
 
2c4851e
971b586
 
 
1d3e72f
 
 
 
 
5ea40ce
971b586
2c4851e
 
 
 
 
 
 
 
 
7c8a9cd
 
 
 
 
 
 
 
 
 
 
2c4851e
971b586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea40ce
 
971b586
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c4851e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c8a9cd
 
2c4851e
 
 
 
 
 
 
7c8a9cd
 
 
 
 
 
2c4851e
7c8a9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c4851e
7c8a9cd
 
2c4851e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b64beb5
2c4851e
 
 
 
 
 
 
 
b64beb5
 
 
2c4851e
7c8a9cd
 
2c4851e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c8a9cd
2c4851e
 
 
 
 
 
7c8a9cd
2c4851e
 
 
7c8a9cd
 
 
 
 
 
 
 
2c4851e
 
7c8a9cd
2c4851e
 
7c8a9cd
2c4851e
 
7c8a9cd
 
2c4851e
 
 
 
7c8a9cd
 
2c4851e
 
 
 
 
 
7c8a9cd
2c4851e
 
 
 
 
 
971b586
2c4851e
 
 
b64beb5
2c4851e
971b586
2c4851e
 
 
 
 
 
 
 
 
b64beb5
971b586
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
"""Precompute /score and /fold outputs so demo.html can render results
instantly without round-tripping the inference endpoints.

Writes results back into data/genes.json and data/variants.json:
- per-gene likelihood `track` (Carbon /score)
- per-variant VEP `score` (Carbon /score)
- per-gene `fold_example` (Carbon /generate + NVIDIA NIM ESMFold)

Usage:
    python scripts/precompute.py                       # everything
    python scripts/precompute.py --folds               # only the folding fixtures
    python scripts/precompute.py --no-folds            # skip folding fixtures
    python scripts/precompute.py --folds --only-missing # only genes lacking fold_example
"""
import json
import os
import sys
import time

import httpx
from openai import OpenAI
from huggingface_hub import get_token

# Base URL handed to the OpenAI client in make_client(); override with
# $ENDPOINT_URL.
ENDPOINT_URL = os.environ.get(
    "ENDPOINT_URL",
    "https://cr2l9w72ys5pp8le.us-east-1.aws.endpoints.huggingface.cloud/v1/",
)
# Model id sent with every completions request in score(); override with
# $MODEL_NAME.
MODEL_NAME = os.environ.get("MODEL_NAME", "HuggingFaceBio/Carbon-3B")
# precompute_tracks() scores only the first TRACK_MAX_BP bases of each gene.
TRACK_MAX_BP = 24000   # covers TP53 (19,070 bp) with headroom; model max is 32k tokens (~196k bp)

# Folding fixture parameters. T=0.9 because lower temperatures push the
# model into low-complexity sinks on long-context generation (poly-T at
# T<=0.5, AATC repeats at T=0.7), which is a known failure mode of
# autoregressive LMs on long DNA — most of the human genome is repetitive
# introns/intergenic, and greedy-ish sampling collapses there. T=0.9
# keeps enough exploration to stay in coding-region distribution.
NIM_FOLD_URL = "https://health.api.nvidia.com/v1/biology/nvidia/esmfold"
# Number of leading reference bases used as the Carbon prompt.
FOLD_PREFIX_LEN = 200
FOLD_TEMPERATURE = 0.9
# Best-of-N retries per gene. Long-context generation (MYC/TP53 need
# ~20 kb of DNA) collapses more often than HBB/INS, so we give those
# more shots at producing a viable ORF.
FOLD_BEST_OF_SHORT = 3
FOLD_BEST_OF_LONG  = 5
FOLD_LONG_THRESHOLD = 5000  # bp; above this, use BEST_OF_LONG retries
# Hard cap on AA sent to ESMFold — NVIDIA NIM tops out around ~1024 aa
# and we don't want a single oversized Carbon hallucination to fail the
# whole fold step. The reference proteins for our 4 demo genes are all
# well below this anyway (HBB 147, INS 110, MYC 439, TP53 393).
FOLD_MAX_AA = 1000

# Directory containing this script, and the sibling data/ directory that
# holds genes.json and variants.json.
HERE = os.path.dirname(os.path.abspath(__file__))
DATA = os.path.join(os.path.dirname(HERE), "data")


def make_client():
    """Return an OpenAI client bound to the Carbon inference endpoint.

    The API key is taken from $HF_TOKEN, falling back to the token stored
    by `huggingface-cli login` (via huggingface_hub.get_token()).

    Raises:
        RuntimeError: if neither source yields a token.
    """
    token = os.environ.get("HF_TOKEN") or get_token()
    if token:
        return OpenAI(base_url=ENDPOINT_URL, api_key=token)
    raise RuntimeError("no HF token (set HF_TOKEN or `huggingface-cli login`)")


def left_pad_to_six(seq):
    """Prepend 'A' bases so len(seq) becomes a multiple of 6.

    Returns (padded_seq, n_added). Falsy input (e.g. "" or None) and input
    whose length is already a multiple of 6 come back unchanged with 0.
    NOTE(review): the multiple-of-6 requirement presumably reflects the
    model's 6-bp tokenization — confirm against the tokenizer.
    """
    missing = (-len(seq)) % 6 if seq else 0
    if missing:
        seq = "A" * missing + seq
    return seq, missing


def score(client, seq):
    """Score *seq* with Carbon and return its per-token log-probabilities.

    The sequence is left-padded to a multiple of 6 with left_pad_to_six(),
    prefixed with "<dna>", and sent as a zero-length, echoed, temperature-0
    completion so the endpoint returns log-probs for the prompt itself.

    Returns:
        dict with ``tokens`` and ``token_logprobs`` (lists copied from the
        response) and ``pad_bases`` (number of 'A's that were prepended).
    """
    padded, n_pad = left_pad_to_six(seq)
    completion = client.completions.create(
        model=MODEL_NAME,
        prompt="<dna>" + padded,
        max_tokens=0,
        echo=True,
        logprobs=5,
        temperature=0,
    )
    echoed = completion.choices[0].logprobs
    tokens = list(echoed.tokens)
    token_lps = list(echoed.token_logprobs)
    return {"tokens": tokens, "token_logprobs": token_lps, "pad_bases": n_pad}


def precompute_tracks(client):
    """Score every gene in data/genes.json and store it as ``g["track"]``.

    Each gene's ``seq`` is truncated to TRACK_MAX_BP bases and scored with
    score(); the result gains a ``scored_length`` field. The full gene list
    is written back to genes.json once all genes have been attempted.

    A failure for one gene is reported on stderr and leaves that gene's
    existing ``track`` untouched.
    """
    path = os.path.join(DATA, "genes.json")
    # Context managers close the handles deterministically (the write is
    # flushed before "wrote" is printed) instead of relying on GC.
    with open(path, encoding="utf-8") as f:
        genes = json.load(f)
    for g in genes:
        seq = g["seq"][:TRACK_MAX_BP]
        print(f"  scoring {g['symbol']} ({len(seq)}bp)…", flush=True)
        t0 = time.time()
        try:
            res = score(client, seq)
            res["scored_length"] = len(seq)
            g["track"] = res
            pad = res["pad_bases"]
            print(f"    ✓ {len(res['tokens'])} tokens in {time.time()-t0:.1f}s" + (f"  (left-padded {pad}bp)" if pad else ""))
        except Exception as e:
            # Keep whatever track was already on this gene; a transient
            # endpoint hiccup shouldn't wipe valid precomputed data.
            print(f"    ✗ {e}", file=sys.stderr)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(genes, f, indent=2)
    print(f"  wrote {path}")


def sum_lp(lps):
    """Add up the non-None values in *lps*.

    Returns (total, count): the float sum of the present values (0.0 when
    there are none) and how many values contributed. None entries — e.g.
    the first echoed token, which has no log-prob — are skipped.
    """
    total = 0.0
    count = 0
    for value in (x for x in lps if x is not None):
        # Plain left-to-right addition (not sum(), which is compensated on
        # Python 3.12+) so results match the original bit-for-bit.
        total = total + value
        count += 1
    return total, count


def precompute_vep(client):
    """Score every variant in data/variants.json as alt-vs-ref log-likelihood.

    For each variant the reference window ``ref_window`` and an alternate
    window (the base at ``var_offset`` replaced by ``alt``) are both scored
    with score(). The summed non-None token log-probs give
    ``delta = alt_sum - ref_sum``, which is stored together with the raw
    per-token values in ``v["score"]``. The list is written back to
    variants.json once all variants have been attempted.

    A failure for one variant is reported on stderr and removes any
    ``score`` previously stored on it.
    """
    path = os.path.join(DATA, "variants.json")
    with open(path, encoding="utf-8") as f:
        variants = json.load(f)
    for v in variants:
        ref = v["ref_window"]
        off = v["var_offset"]
        # Exactly one reference base is replaced — assumes an SNV-style variant.
        alt = ref[:off] + v["alt"] + ref[off + 1:]
        print(f"  scoring {v['rs']} ({v['name']})…", flush=True)
        try:
            r_ref = score(client, ref)
            r_alt = score(client, alt)
            ref_sum, ref_n = sum_lp(r_ref["token_logprobs"])
            alt_sum, _ = sum_lp(r_alt["token_logprobs"])
            delta = alt_sum - ref_sum
            v["score"] = {
                "ref_sum": ref_sum, "alt_sum": alt_sum,
                "ref_logprobs": r_ref["token_logprobs"],
                "alt_logprobs": r_alt["token_logprobs"],
                "n": ref_n,  # count of non-None ref log-probs
                "delta": delta,
            }
            print(f"    ✓ Δ = {delta:+.2f}  (ref {ref_sum:.2f}, alt {alt_sum:.2f})")
        except Exception as e:
            print(f"    ✗ {e}", file=sys.stderr)
            v.pop("score", None)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(variants, f, indent=2)
    print(f"  wrote {path}")


# =========================================================================
# Folding fixture: per-gene Carbon continuation + ESMFold structures for
# both the reference and Carbon's spliced mRNA. Keeps the demo's first
# render instant even when the inference endpoints are cold.
# =========================================================================

# Standard genetic code: codon -> one-letter amino acid, "*" marks a stop.
# Codons are generated with the SECOND base as the outermost loop, then the
# first, then the third (each over T, C, A, G), so every group of 16 letters
# below is one middle-base column of the classic codon wheel.
CODON_TABLE = dict(zip(
    (first + second + third
     for second in "TCAG" for first in "TCAG" for third in "TCAG"),
    "FFLL" "LLLL" "IIIM" "VVVV"    # xTx codons
    "SSSS" "PPPP" "TTTT" "AAAA"    # xCx codons
    "YY**" "HHQQ" "NNKK" "DDEE"    # xAx codons
    "CC*W" "RRRR" "SSRR" "GGGG",   # xGx codons
))


def splice_exons(dna, exons):
    """Join the exon slices of *dna* into a single spliced string.

    Each exon is a dict with ``start``/``end`` keys used as slice bounds,
    taken in the order given. An exon running past the end of *dna* is cut
    short; the first exon whose ``start`` is at or beyond the end stops the
    splice, so no later exon is used.
    """
    dna_len = len(dna)
    kept = []
    for exon in exons:
        begin = exon["start"]
        if begin < dna_len:
            kept.append(dna[begin:min(exon["end"], dna_len)])
        else:
            break
    return "".join(kept)


def find_longest_orf(dna, min_aa=30):
    """Mirror of demo.html findLongestORF — scans all 3 frames, returns
    the longest ATG-initiated ORF of at least `min_aa` amino acids.

    Prefers an ORF that ends on a clean stop codon. Falls back to a
    truncated ORF (reached the end of dna with no stop) — that case
    happens when Carbon mutates the canonical stop codon and the
    translation reads into the 3'UTR. Truncated entries are tagged
    so the UI can surface that biological detail.

    An ORF that runs into a codon absent from CODON_TABLE (e.g. one
    containing N or lowercase bases) is discarded. Among equally long
    candidates the first one found wins (frame 0 first, then by position).

    Returns:
        dict with ``aa`` (protein, stop excluded), ``frame`` (0-2),
        ``start_bp`` (index of the ATG), ``end_bp`` (index just past the
        last translated codon, i.e. the stop codon's start when one was
        found) and ``truncated`` (True when no stop was reached), or None
        when nothing qualifies.
    """
    best_clean = None
    best_trunc = None
    n = len(dna)
    for frame in range(3):
        i = frame
        while i + 3 <= n:
            if dna[i:i+3] != "ATG":
                i += 3
                continue
            aa = []
            j = i
            stopped = False
            invalid = False
            while j + 3 <= n:
                a = CODON_TABLE.get(dna[j:j+3])
                if a is None:
                    invalid = True
                    break
                if a == "*":
                    stopped = True
                    break
                aa.append(a)
                j += 3
            if not invalid and len(aa) >= min_aa:
                entry = {"aa": "".join(aa), "frame": frame, "start_bp": i, "end_bp": j, "truncated": not stopped}
                if stopped:
                    if best_clean is None or len(aa) > len(best_clean["aa"]):
                        best_clean = entry
                elif best_trunc is None or len(aa) > len(best_trunc["aa"]):
                    best_trunc = entry
            # Any later in-frame ATG before j would translate to a strict
            # suffix of this ORF with the same ending (stop, invalid codon or
            # end of dna): it is shorter and can never win under the strict
            # `>` comparisons above. Resume after the codon that ended this
            # ORF rather than re-translating them — identical result, linear
            # instead of quadratic on ATG-rich / repetitive sequence.
            i = j + 3
    return best_clean or best_trunc


# Base URL of the backend whose /generate endpoint carbon_continue() posts to;
# override with $CARBON_BACKEND (defaults to a local server on port 7870).
BACKEND_URL = os.environ.get("CARBON_BACKEND", "http://127.0.0.1:7870")


def carbon_continue(_client, prompt_dna, max_tokens, temperature,
                    max_503_retries=10, retry_wait_s=15):
    """Ask Carbon to continue a DNA prompt and return the cleaned continuation.

    Calls the backend /generate endpoint (SSE-framed response) rather than
    the OpenAI client directly. Going through /generate guarantees the
    precompute follows the exact same pipeline as the live demo
    (left-padding to a multiple of 6, <dna> prefix, streaming framing) so
    the cached example is identical to what runFold() would produce. The
    response body is read in full and then split into ``data:`` lines; it
    is not consumed incrementally.

    The HF Inference Endpoint is configured to scale-to-zero after a few
    hours idle, so the first call after a cold period bubbles up a 503
    from upstream. We wait the endpoint out with a fixed backoff instead
    of giving up — subsequent calls in the same session hit warm pods
    and return in seconds.

    Args:
        _client: unused here (presumably kept for call-site compatibility).
        prompt_dna: DNA prompt to continue.
        max_tokens: generation budget forwarded to /generate.
        temperature: sampling temperature forwarded to /generate (top_p=1.0).
        max_503_retries: how many extra attempts to make on 503-like errors.
        retry_wait_s: fixed sleep between those attempts, in seconds.

    Returns:
        The concatenated ``text`` of all events, upper-cased, with every
        character other than A/C/G/T removed.

    Raises:
        httpx.HTTPStatusError: non-2xx response from the backend.
        RuntimeError: an ``error`` event in the stream.
        Either is retried only while its text contains "503" and attempts
        remain; otherwise it propagates immediately.
    """
    last_err = None
    for attempt in range(max_503_retries + 1):
        try:
            with httpx.Client(timeout=300) as cx:
                r = cx.post(
                    f"{BACKEND_URL}/generate",
                    json={"prompt": prompt_dna, "max_tokens": max_tokens,
                          "temperature": temperature, "top_p": 1.0},
                )
                r.raise_for_status()
                out = []
                for line in r.text.splitlines():
                    line = line.strip()
                    if not line.startswith("data:"):
                        continue
                    payload = json.loads(line[5:].strip())
                    if "error" in payload:
                        # Surface as RuntimeError; the handler below decides
                        # whether it is a retryable cold-start 503.
                        raise RuntimeError(str(payload["error"]))
                    out.append(payload.get("text") or "")
            text = "".join(out).upper()
            return "".join(c for c in text if c in "ACGT")
        except (httpx.HTTPStatusError, RuntimeError) as e:
            last_err = e
            if "503" in str(e) and attempt < max_503_retries:
                print(f"    … HF endpoint cold, waiting {retry_wait_s}s "
                      f"(attempt {attempt+1}/{max_503_retries+1})", flush=True)
                time.sleep(retry_wait_s)
                continue
            raise
    raise last_err if last_err else RuntimeError("carbon_continue: unreachable")


def nim_fold(api_key, sequence):
    """Fold *sequence* with NVIDIA NIM ESMFold.

    Mirrors app.py's /fold logic — same endpoint, same JSON shape.

    Returns:
        dict with ``pdb`` (PDB text), ``n_residues`` (number of ATOM/CA
        records whose B-factor column parsed as a float) and ``plddt_mean``
        (mean of those B-factors, or None if there were none).

    Raises:
        httpx.HTTPStatusError: on a non-2xx response.
        RuntimeError: if the response carries no PDB text.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    with httpx.Client(timeout=180) as http:
        resp = http.post(NIM_FOLD_URL, headers=headers, json={"sequence": sequence})
        resp.raise_for_status()
        body = resp.json()

    # Prefer the first entry of "pdbs", then a single "pdb" field.
    first_pdb = (body.get("pdbs") or [None])[0]
    pdb_text = first_pdb or body.get("pdb") or ""
    if not pdb_text:
        raise RuntimeError("NIM returned no PDB")

    # ESMFold writes per-residue pLDDT into the B-factor column (cols 61-66);
    # one CA atom per residue.
    ca_bfactors = []
    for rec in pdb_text.splitlines():
        if not rec.startswith("ATOM"):
            continue
        if rec[12:16].strip() != "CA":
            continue
        try:
            ca_bfactors.append(float(rec[60:66].strip()))
        except ValueError:
            pass

    count = len(ca_bfactors)
    mean = (sum(ca_bfactors) / count) if ca_bfactors else None
    return {"pdb": pdb_text, "n_residues": count, "plddt_mean": mean}


def precompute_folds(client, only_missing=False):
    """Build and cache ``fold_example`` for every gene in data/genes.json.

    For each gene that has ``exons``:

    1. Splice the upper-cased reference ``seq`` and take its longest ORF.
    2. Prompt Carbon with the first FOLD_PREFIX_LEN bases and generate far
       enough to cover the last exon; repeat best-of-N (FOLD_BEST_OF_SHORT
       or FOLD_BEST_OF_LONG tries, split at FOLD_LONG_THRESHOLD) and keep
       the longest ORF of the spliced continuation.
    3. Fold both proteins (each clamped to FOLD_MAX_AA residues) with
       NVIDIA NIM ESMFold.
    4. Store sequences, PDBs, mean pLDDTs and position-wise 1-D identity in
       ``g["fold_example"]``.

    genes.json is rewritten once at the end.

    Args:
        client: forwarded to carbon_continue() as its first argument.
        only_missing: skip genes that already carry a ``fold_example``.

    Raises:
        RuntimeError: if NVIDIA_API_KEY is not set. Per-gene failures are
            printed to stderr and leave any existing ``fold_example`` intact.
    """
    api_key = os.environ.get("NVIDIA_API_KEY")
    if not api_key:
        raise RuntimeError("NVIDIA_API_KEY missing (set in .env or env)")

    def fmt_plddt(x):
        # nim_fold() reports plddt_mean=None when no CA B-factors parse;
        # formatting None with :.1f would raise TypeError and discard an
        # otherwise successful fold via the except below.
        return "n/a" if x is None else f"{x:.1f}"

    path = os.path.join(DATA, "genes.json")
    with open(path, encoding="utf-8") as f:
        genes = json.load(f)
    for g in genes:
        if not g.get("exons"):
            continue
        if only_missing and g.get("fold_example"):
            print(f"  skipping {g['symbol']} (fold_example already cached)")
            continue
        last_exon_end = g["exons"][-1]["end"]
        n_tries = FOLD_BEST_OF_LONG if last_exon_end > FOLD_LONG_THRESHOLD else FOLD_BEST_OF_SHORT
        print(f"  folding {g['symbol']} (last exon end {last_exon_end} bp, best-of-{n_tries})…", flush=True)
        try:
            seq = g["seq"].upper()
            ref_mrna = splice_exons(seq, g["exons"])
            ref_orf = find_longest_orf(ref_mrna, 30)
            if not ref_orf:
                raise RuntimeError("reference has no valid ORF after splice")

            # NOTE(review): non-ACGT bases in the prefix are dropped, which
            # makes the prompt shorter than FOLD_PREFIX_LEN and shifts
            # carbon_dna relative to the exon coordinates — confirm the
            # reference prefixes are clean.
            prompt = "".join(c for c in seq[:FOLD_PREFIX_LEN] if c in "ACGT")
            # Generate up to the last exon end plus 60 bp of slack.
            gen_bp = max(0, last_exon_end - FOLD_PREFIX_LEN) + 60
            # Presumably ~6 bp per generated token, plus a little headroom.
            max_tokens = (gen_bp // 6) + 8

            # Best-of-N: at T>0 the model's continuation can stop early on
            # a fluky premature codon. Try a few times and keep the longest
            # ORF — closer to "what Carbon usually produces for this gene".
            carbon_orf = None
            for attempt in range(n_tries):
                t0 = time.time()
                cont = carbon_continue(client, prompt, max_tokens, FOLD_TEMPERATURE)
                carbon_dna = (prompt + cont)[: FOLD_PREFIX_LEN + gen_bp]
                carbon_mrna = splice_exons(carbon_dna, g["exons"])
                orf = find_longest_orf(carbon_mrna, 30)
                n = len(orf["aa"]) if orf else 0
                print(f"    carbon try {attempt+1}/{n_tries}: ORF {n} aa ({time.time()-t0:.1f}s)")
                if orf and (carbon_orf is None or len(orf["aa"]) > len(carbon_orf["aa"])):
                    carbon_orf = orf
            if not carbon_orf:
                raise RuntimeError("Carbon's spliced mRNA has no valid ORF after %d tries" % n_tries)

            # Clamp to NIM's ~1024 aa ceiling. The reference proteins are
            # all well below this; Carbon hallucinations can occasionally
            # exceed it after a mutated stop codon, in which case we just
            # fold the first FOLD_MAX_AA aa to keep the pipeline robust.
            ref_aa    = ref_orf["aa"][:FOLD_MAX_AA]
            carbon_aa = carbon_orf["aa"][:FOLD_MAX_AA]

            t0 = time.time()
            ref_fold = nim_fold(api_key, ref_aa)
            print(f"    ref fold: {ref_fold['n_residues']} aa, pLDDT {fmt_plddt(ref_fold['plddt_mean'])} ({time.time()-t0:.1f}s)")
            t0 = time.time()
            carbon_fold = nim_fold(api_key, carbon_aa)
            print(f"    carbon fold: {carbon_fold['n_residues']} aa, pLDDT {fmt_plddt(carbon_fold['plddt_mean'])} ({time.time()-t0:.1f}s)")

            # Position-wise identity over the shorter of the two proteins.
            n = min(len(carbon_aa), len(ref_aa))
            matches = sum(1 for a, b in zip(carbon_aa, ref_aa) if a == b)
            identity = matches / n if n else 0.0

            g["fold_example"] = {
                "prefix_len": FOLD_PREFIX_LEN,
                "temperature": FOLD_TEMPERATURE,
                "carbon_aa": carbon_aa,
                "ref_aa": ref_aa,
                "carbon_pdb": carbon_fold["pdb"],
                "ref_pdb": ref_fold["pdb"],
                "carbon_plddt_mean": carbon_fold["plddt_mean"],
                "ref_plddt_mean": ref_fold["plddt_mean"],
                "identity_1d": identity,
            }
            print(f"    ✓ identity {identity*100:.1f}%  ({len(carbon_aa)}/{len(ref_aa)} aa)")
        except Exception as e:
            print(f"    ✗ {e}  (keeping previous fold_example if any)", file=sys.stderr)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(genes, f, indent=2)
    print(f"  wrote {path}")


def main():
    """Command-line entry point; see the module docstring for flags.

    --folds runs only the fold fixtures, --no-folds skips them, and
    --only-missing restricts the fold step to genes without a cached
    fold_example. An HF client is created up front in every mode.
    """
    flags = set(sys.argv[1:])
    run_scores = "--folds" not in flags
    run_folds = "--no-folds" not in flags
    only_missing = "--only-missing" in flags

    client = make_client()
    if run_scores:
        print("=== precomputing likelihood tracks ===")
        precompute_tracks(client)
        print()
        print("=== precomputing VEP scores ===")
        precompute_vep(client)
    if run_folds:
        print()
        print("=== precomputing fold fixtures ===")
        precompute_folds(client, only_missing=only_missing)


if __name__ == "__main__":
    main()