lvwerra HF Staff commited on
Commit
b9445bd
·
1 Parent(s): 5798713

Recipe §3: extend BP section to cover scoring as a second endpoint

Browse files

The same marginalization that powers FNS at training time also factors
the scoring endpoint, where you read P(actual base | context) directly
off the per-position marginals instead of forcing a token. Renamed the
section "BP-level inference" (id bpinference) to reflect both endpoints,
rewrote the lede, split the visual's last station into step 3a (generate)
and step 3b (score), and added a "score" tab to the code snippet showing
score_sequence() via the -remote checkpoints.

Files changed (1) hide show
  1. demo.html +94 -36
demo.html CHANGED
@@ -1204,8 +1204,8 @@ for name, ids in zip(species_prefixes, new_ids):
1204
  The sections below walk through each of those choices: how the tokenizer changes
1205
  what a "token" means in DNA <a class="lede-chip" href="#tokenizer">§1</a>, how
1206
  FNS rescues training in the BF16 regime <a class="lede-chip" href="#loss">§2</a>,
1207
- how bp-level generation falls out of the same marginalisation
1208
- <a class="lede-chip" href="#bpgen">§3</a>, what's in the training corpus
1209
  <a class="lede-chip" href="#data">§4</a>, what the architecture looks like
1210
  <a class="lede-chip" href="#architecture">§5</a>, how 8k-token pretraining reaches
1211
  786 kbp at inference <a class="lede-chip" href="#longcontext">§6</a>, how Carbon
@@ -1331,23 +1331,27 @@ for name, ids in zip(species_prefixes, new_ids):
1331
  </section>
1332
 
1333
  <!-- ============================================================ -->
1334
- <!-- §8.5 · BP-LEVEL GENERATION -->
1335
  <!-- ============================================================ -->
1336
- <section id="bpgen" class="section--two-col">
1337
  <div class="section-narrative">
1338
- <div class="section-num">§3 · BP-level generation</div>
1339
- <div class="section-title">Sample bases, not 6-mers</div>
1340
  <p class="lede">
1341
- The 6-mer tokenizer makes Carbon fast, but it's coarse at sampling time: each
1342
- step advances the sequence by 6 bases at once, temperature acts on a 4,096-way
1343
- distribution rather than per nucleotide, and stopping at an odd base count is
1344
- awkward. The same marginalisation that powers FNS at training time inverts the
1345
- tokenizer at inference: softmax over the 6-mer logits, then for each position
1346
- <code>p</code> sum the probabilities of every 6-mer that shares a given base at
1347
- <code>p</code>, and you recover six per-position 4-way base distributions.
1348
- Sample (or argmax) each independently, look up the matching 6-mer token id,
1349
- and force that token as the next selection. The decoder still emits one token
1350
- per step so throughput is unchanged, but the choice is now base-pair resolved.
 
 
 
 
1351
  </p>
1352
  </div>
1353
 
@@ -1499,16 +1503,33 @@ for name, ids in zip(species_prefixes, new_ids):
1499
  </div>
1500
  </div>
1501
 
1502
- <div style="text-align:center;color:#888;font-size:11px">▼ &nbsp; argmax (greedy) or multinomial (sampled) per position, then reassemble</div>
1503
 
1504
- <div>
1505
- <div style="font-size:10px;color:#888;letter-spacing:1px;text-transform:uppercase;margin-bottom:6px">step 3 · forced as the next 6-mer token</div>
1506
- <div style="display:flex;align-items:center;justify-content:center;gap:10px;padding:12px;background:#fafaf6;border:1px solid #eee">
1507
- <div style="display:flex;gap:6px;font-size:18px;font-weight:700;color:#1A7A40;letter-spacing:2px">
1508
- <span>A</span><span>C</span><span>G</span><span>T</span><span>A</span><span>T</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1509
  </div>
1510
- <span style="color:#888">→</span>
1511
- <span style="font-size:11px;color:#666">matching 6-mer token id forced via <code>scores.fill_(-inf); scores[id] = 0</code></span>
1512
  </div>
1513
  </div>
1514
 
@@ -1517,25 +1538,33 @@ for name, ids in zip(species_prefixes, new_ids):
1517
 
1518
  <div class="takeaway">
1519
  <strong>When to switch on bp-level</strong>
1520
- Reach for plain 6-mer sampling when 6-base granularity is fine: throughput-bound
1521
- decoding, long retrieval haystacks, large-scale screening. Switch to bp-level
1522
- when you need exact base counts, per-position masks, or temperature and top-p
1523
- applied at the base axis rather than the 4,096-way 6-mer axis. Same model, same
1524
- weights, same sampling controls; only the last step of the logits chain changes.
1525
- The <code>HuggingFaceBio/carbon-generate</code> repo ships this as a transformers
1526
- <code>custom_generate</code> method, so plain <code>LlamaForCausalLM</code>
1527
- checkpoints get bp-level generation without a custom modeling file or
1528
- <code>trust_remote_code</code> on the weights.
 
 
 
 
 
 
 
1529
  </div>
1530
 
1531
  <details class="code-snippet">
1532
  <summary>Run this from code</summary>
1533
  <div class="code-snippet__body">
1534
  <div class="code-snippet__tabs">
1535
- <button class="code-snippet__tab active" data-tab="local" type="button">transformers</button>
 
1536
  </div>
1537
  <button class="code-snippet__copy" type="button">Copy</button>
1538
- <div class="code-snippet__panel active" data-tab="local"><pre><code>from transformers import AutoModelForCausalLM, AutoTokenizer
1539
  import torch
1540
 
1541
  tok = AutoTokenizer.from_pretrained(
@@ -1549,7 +1578,7 @@ model = AutoModelForCausalLM.from_pretrained(
1549
  prompt = "&lt;dna&gt;ATGCGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCTACG"
1550
  inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
1551
 
1552
- # `custom_generate` injects a logits processor that marginalises the
1553
  # 6-mer logits to per-base distributions and samples each of the 6
1554
  # positions independently, then forces the matching 6-mer token. All
1555
  # standard generation knobs (temperature, top_p, top_k, repetition_penalty)
@@ -1566,6 +1595,35 @@ out = model.generate(
1566
  # Slice off the prompt and decode the continuation as plain DNA.
1567
  new_ids = out[0, inputs["input_ids"].shape[1]:]
1568
  print(tok.decode(new_ids, skip_special_tokens=True))</code></pre></div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1569
  </div>
1570
  </details>
1571
  </div>
 
1204
  The sections below walk through each of those choices: how the tokenizer changes
1205
  what a "token" means in DNA <a class="lede-chip" href="#tokenizer">§1</a>, how
1206
  FNS rescues training in the BF16 regime <a class="lede-chip" href="#loss">§2</a>,
1207
+ how bp-level generation and scoring fall out of the same marginalization
1208
+ <a class="lede-chip" href="#bpinference">§3</a>, what's in the training corpus
1209
  <a class="lede-chip" href="#data">§4</a>, what the architecture looks like
1210
  <a class="lede-chip" href="#architecture">§5</a>, how 8k-token pretraining reaches
1211
  786 kbp at inference <a class="lede-chip" href="#longcontext">§6</a>, how Carbon
 
1331
  </section>
1332
 
1333
  <!-- ============================================================ -->
1334
+ <!-- §8.5 · BP-LEVEL INFERENCE -->
1335
  <!-- ============================================================ -->
1336
+ <section id="bpinference" class="section--two-col">
1337
  <div class="section-narrative">
1338
+ <div class="section-num">§3 · BP-level inference</div>
1339
+ <div class="section-title">Bases, not 6-mers</div>
1340
  <p class="lede">
1341
+ The 6-mer tokenizer makes Carbon fast, but it's coarse in both directions
1342
+ of inference. When <em>generating</em>, each step advances the sequence by
1343
+ 6 bases at once and temperature acts on a 4,096-way distribution rather
1344
+ than per nucleotide. When <em>scoring</em> an existing sequence, the raw
1345
+ next-token likelihood answers "how likely is this 6-mer in context?", not
1346
+ "how likely is this exact base at this exact position?", which is the
1347
+ version you want for variant-effect prediction. The same marginalization
1348
+ that powers FNS at training time fixes both: softmax over the 6-mer
1349
+ logits, then for each position <code>p</code> sum the probabilities of
1350
+ every 6-mer that shares a given base at <code>p</code>, and you recover
1351
+ six per-position 4-way base distributions. To generate, sample (or argmax)
1352
+ each independently and force the matching 6-mer token. To score, read
1353
+ <em>P(actual base | context)</em> directly off the marginals at every
1354
+ position. Same logits, same math, two endpoints.
1355
  </p>
1356
  </div>
1357
 
 
1503
  </div>
1504
  </div>
1505
 
1506
+ <div style="text-align:center;color:#888;font-size:11px">▼ &nbsp; same marginals feed two endpoints: generate (force a token) or score (read off P(base))</div>
1507
 
1508
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px">
1509
+ <!-- step 3a · generation endpoint -->
1510
+ <div>
1511
+ <div style="font-size:10px;color:#888;letter-spacing:1px;text-transform:uppercase;margin-bottom:6px">step 3a · generate</div>
1512
+ <div style="display:flex;flex-direction:column;align-items:center;justify-content:center;gap:6px;padding:12px;background:#fafaf6;border:1px solid #eee;height:88px;box-sizing:border-box">
1513
+ <div style="display:flex;gap:6px;font-size:18px;font-weight:700;color:#1A7A40;letter-spacing:2px">
1514
+ <span>A</span><span>C</span><span>G</span><span>T</span><span>A</span><span>T</span>
1515
+ </div>
1516
+ <div style="font-size:10px;color:#666;text-align:center;line-height:1.4">
1517
+ argmax / multinomial &rarr; force matching 6-mer token
1518
+ </div>
1519
+ </div>
1520
+ </div>
1521
+
1522
+ <!-- step 3b · scoring endpoint -->
1523
+ <div>
1524
+ <div style="font-size:10px;color:#888;letter-spacing:1px;text-transform:uppercase;margin-bottom:6px">step 3b · score</div>
1525
+ <div style="display:flex;flex-direction:column;align-items:center;justify-content:center;gap:6px;padding:12px;background:#fafaf6;border:1px solid #eee;height:88px;box-sizing:border-box">
1526
+ <div style="display:flex;gap:8px;font-size:11px;color:#1A7A40;font-weight:600;font-feature-settings:'tnum'">
1527
+ <span>.83</span><span>.71</span><span>.92</span><span>.67</span><span>.48</span><span>.79</span>
1528
+ </div>
1529
+ <div style="font-size:10px;color:#666;text-align:center;line-height:1.4">
1530
+ read P(actual base | context) at each position
1531
+ </div>
1532
  </div>
 
 
1533
  </div>
1534
  </div>
1535
 
 
1538
 
1539
  <div class="takeaway">
1540
  <strong>When to switch on bp-level</strong>
1541
+ Use plain 6-mer decoding when 6-base granularity is fine: throughput-bound
1542
+ generation, long retrieval haystacks, large-scale screening. Reach for
1543
+ bp-level <em>generation</em> when you need exact base counts, per-position
1544
+ masks, or temperature applied at the base axis rather than the 4,096-way
1545
+ 6-mer axis. Reach for bp-level <em>scoring</em> whenever the task is about
1546
+ a specific base: variant-effect prediction, single-nucleotide mutational
1547
+ scans, comparing the likelihood of a reference and an alternate allele at
1548
+ one position. Two complementary delivery paths: generation ships as a
1549
+ transformers <code>custom_generate</code> method at
1550
+ <code>HuggingFaceBio/carbon-generate</code> that works on the plain
1551
+ <code>Carbon-3B</code>/<code>8B</code>/<code>500M</code> checkpoints
1552
+ (standard <code>LlamaForCausalLM</code>, no custom modeling file).
1553
+ Scoring ships in the <code>-remote</code> variants of those same
1554
+ checkpoints, which add a <code>score_sequence(seq)</code> method that
1555
+ returns per-base distributions and the probability of the observed base
1556
+ at every position.
1557
  </div>
1558
 
1559
  <details class="code-snippet">
1560
  <summary>Run this from code</summary>
1561
  <div class="code-snippet__body">
1562
  <div class="code-snippet__tabs">
1563
+ <button class="code-snippet__tab active" data-tab="generate" type="button">generate</button>
1564
+ <button class="code-snippet__tab" data-tab="score" type="button">score</button>
1565
  </div>
1566
  <button class="code-snippet__copy" type="button">Copy</button>
1567
+ <div class="code-snippet__panel active" data-tab="generate"><pre><code>from transformers import AutoModelForCausalLM, AutoTokenizer
1568
  import torch
1569
 
1570
  tok = AutoTokenizer.from_pretrained(
 
1578
  prompt = "&lt;dna&gt;ATGCGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCTACG"
1579
  inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
1580
 
1581
+ # `custom_generate` injects a logits processor that marginalizes the
1582
  # 6-mer logits to per-base distributions and samples each of the 6
1583
  # positions independently, then forces the matching 6-mer token. All
1584
  # standard generation knobs (temperature, top_p, top_k, repetition_penalty)
 
1595
  # Slice off the prompt and decode the continuation as plain DNA.
1596
  new_ids = out[0, inputs["input_ids"].shape[1]:]
1597
  print(tok.decode(new_ids, skip_special_tokens=True))</code></pre></div>
1598
+ <div class="code-snippet__panel" data-tab="score"><pre><code>from transformers import AutoModelForCausalLM, AutoTokenizer
1599
+ import torch, math
1600
+
1601
+ # The -remote variants bundle modeling code that exposes
1602
+ # `score_sequence(seq)` directly on the model. It returns, for every
1603
+ # position in the input DNA, the marginal P(base | context) and the
1604
+ # probability of the observed base.
1605
+ tok = AutoTokenizer.from_pretrained(
1606
+ "HuggingFaceBio/Carbon-3B-remote", trust_remote_code=True,
1607
+ )
1608
+ model = AutoModelForCausalLM.from_pretrained(
1609
+ "HuggingFaceBio/Carbon-3B-remote",
1610
+ trust_remote_code=True,
1611
+ dtype=torch.bfloat16, device_map="auto",
1612
+ )
1613
+
1614
+ ref = "ATGCGCTAGCTACGATCGATCGTAGCTAGCTAGCTAGCTACG"
1615
+ alt = ref[:20] + "G" + ref[21:] # single-base substitution at pos 20
1616
+
1617
+ # bp_probs: [seq_len, 4] marginal P(A/T/C/G | context) at each position
1618
+ # actual: [seq_len] P(observed base | context) at each position
1619
+ bp_probs_ref, actual_ref = model.score_sequence(ref)
1620
+ bp_probs_alt, actual_alt = model.score_sequence(alt)
1621
+
1622
+ # log-likelihood delta at the substituted position
1623
+ # is the per-base variant-effect score in its simplest form.
1624
+ delta = math.log(actual_alt[20].item() + 1e-12) \
1625
+ - math.log(actual_ref[20].item() + 1e-12)
1626
+ print(f"log P(alt) - log P(ref) at pos 20: {delta:+.3f}")</code></pre></div>
1627
  </div>
1628
  </details>
1629
  </div>