genomenet Claude Opus 4.7 (1M context) committed on
Commit
bbe3d0a
·
1 Parent(s): f48b1be

Add MLM surprise tab: per-base -log p(true) along the sequence

Browse files

Uses the model's MLM head (final Dense-6 over the nucleotide vocabulary) that
the model was actually pretrained with. For each sliding window we mask ~15%
of positions, run one forward pass, softmax the logits, and read off
-log(p_true) at the masked positions. Aggregation:

- per-window mean surprise -> line plot with ln(6) uniform baseline
- per-base scatter at masked positions -> finer-grained view of local spikes

Low values = model confidently reconstructs the base from context (conserved
or training-typical motifs). High values near ln(6) = model is near-uniform
(unusual relative to training distribution).

One forward pass per window at the default stride, so runtime is the same
as the embedding extraction path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +176 -0
app.py CHANGED
@@ -149,6 +149,98 @@ def embed_sequence(sequence, mode="mean", stride=100, layer=21):
149
  window_emb = np.mean(embeddings, axis=1)
150
  return np.mean(window_emb, axis=0), window_emb, positions
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def create_embedding_heatmap(embedding, title="Embedding"):
153
  """Create a heatmap of a single embedding vector."""
154
  embedding = np.array(embedding)
@@ -381,6 +473,43 @@ def process(sequence: str, mode: str, stride: int, layer: int):
381
 
382
  return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # Build interface
385
  with gr.Blocks(
386
  title="BERT Metagenome Embeddings",
@@ -422,6 +551,39 @@ with gr.Blocks(
422
  api_name="embed"
423
  )
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  with gr.Tab("API"):
426
  gr.Markdown("""
427
  ### API
@@ -448,6 +610,20 @@ embedding = np.load(emb_path)
448
  from the rest of the sequence. Spikes = unusual regions relative to context.
449
 
450
  Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  """)
452
 
453
  with gr.Tab("About"):
 
149
  window_emb = np.mean(embeddings, axis=1)
150
  return np.mean(window_emb, axis=0), window_emb, positions
151
 
152
# ln(vocab_size=6): the surprise if the model predicted uniformly at random.
UNIFORM_SURPRISE = float(np.log(6))
MASK_TOKEN = 0  # PAD/OOV token id; reused here as the MLM mask slot


def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
    """Per-window and per-base MLM surprise.

    For each sliding window, randomly mask ~mask_fraction of positions, run one
    forward pass through the full model (which ends in a Dense(vocab_size=6)),
    softmax the per-position logits, and take -log(p_true) at the masked
    positions.

    Args:
        sequence: nucleotide sequence; converted to token ids via tokenize().
        stride: step in bases between consecutive window starts.
        mask_fraction: fraction of each window to mask; the resulting count is
            clamped into [1, WINDOW_SIZE] so out-of-range values cannot make
            the sampler fail.
        seed: RNG seed, so repeated runs mask the same positions.

    Returns:
        per_window: list of (window-center position, mean surprise in nats).
        per_base_pos, per_base_vals: flat arrays of (position, surprise)
            samples, one entry per (window x masked position). Overlapping
            windows give multiple observations per base.
    """
    model = get_base_model()
    # NOTE(review): fancy indexing below assumes tokenize() returns a 1-D
    # integer np.ndarray — confirm against its definition.
    tokens = tokenize(sequence)
    seq_len = len(tokens)
    rng = np.random.default_rng(seed)
    # Clamp: rng.choice(..., replace=False) raises if n_mask > WINDOW_SIZE,
    # and at least one position must be masked for the metric to exist.
    n_mask = min(WINDOW_SIZE, max(1, int(WINDOW_SIZE * mask_fraction)))

    per_window = []
    per_base_pos = []
    per_base_vals = []

    for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
        window = tokens[start:start + WINDOW_SIZE].copy()
        true_tokens = window.copy()  # keep the unmasked targets
        mask_idx = rng.choice(WINDOW_SIZE, size=n_mask, replace=False)
        window[mask_idx] = MASK_TOKEN

        # One forward pass per window; logits are (WINDOW_SIZE, vocab=6).
        logits = model.predict(window[np.newaxis, :], verbose=0)[0]
        # Numerically stable softmax: shift by the per-position max first.
        logits -= logits.max(axis=-1, keepdims=True)
        exp_l = np.exp(logits)
        probs = exp_l / exp_l.sum(axis=-1, keepdims=True)

        # -log p(true base) at each masked position; clip guards log(0).
        surprises = -np.log(np.clip(probs[mask_idx, true_tokens[mask_idx]], 1e-10, None))
        per_window.append((start + WINDOW_SIZE // 2, float(surprises.mean())))
        per_base_pos.extend((start + mask_idx).tolist())
        per_base_vals.extend(surprises.tolist())

    return per_window, np.array(per_base_pos), np.array(per_base_vals)
197
+
198
+
199
def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
    """Two-panel Plotly figure: per-window surprise line + per-base scatter.

    Args:
        per_window: list of (window-center position in bp, mean surprise in nats).
        per_base_pos: positions (bp) of the individual masked-base samples.
        per_base_vals: per-base surprise values (nats), parallel to per_base_pos.
        seq_len: full sequence length; both x-axes are pinned to [0, seq_len].

    Returns:
        A two-row Plotly figure with a shared x-axis: row 1 is the window-level
        line with the ln(6) uniform baseline, row 2 a per-base scatter colored
        by surprise.
    """
    from plotly.subplots import make_subplots  # local import; only this function needs it
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
        row_heights=[0.6, 0.4],
        subplot_titles=('per-window mean surprise (lower = model finds region predictable)',
                        'per-base surprise at masked positions (dots; darker = more surprising)')
    )

    # Row 1: mean surprise per window, plotted at the window centers.
    wx = [p for p, _ in per_window]
    wy = [s for _, s in per_window]
    fig.add_trace(go.Scatter(
        x=wx, y=wy, mode='lines+markers',
        line=dict(color='#18181b', width=2), marker=dict(size=6),
        hovertemplate='center %{x} bp<br>surprise %{y:.3f} nats<extra></extra>',
        showlegend=False,
    ), row=1, col=1)
    # Dashed reference line at ln(6): the surprise of a uniform prediction.
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        annotation_text=f'uniform baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
        annotation_position='top right', annotation_font=dict(size=10, color='#71717a'),
        row=1, col=1,
    )
    # Row 2: one dot per masked base, colored by surprise, capped at the
    # uniform baseline so the colorscale has a fixed, interpretable range.
    fig.add_trace(go.Scatter(
        x=per_base_pos, y=per_base_vals, mode='markers',
        marker=dict(size=4, color=per_base_vals, colorscale='Reds',
                    cmin=0, cmax=UNIFORM_SURPRISE,
                    colorbar=dict(title=dict(text='nats', font=dict(size=10)),
                                  thickness=10, len=0.35, y=0.18, tickfont=dict(size=9))),
        hovertemplate='pos %{x} bp<br>surprise %{y:.3f}<extra></extra>',
        showlegend=False,
    ), row=2, col=1)

    # Pin both panels to the full sequence extent; anchor y-axes at zero.
    fig.update_xaxes(title_text='position (bp)', row=2, col=1, range=[0, seq_len])
    fig.update_xaxes(range=[0, seq_len], row=1, col=1)
    fig.update_yaxes(title_text='nats', row=1, col=1, rangemode='tozero')
    fig.update_yaxes(title_text='nats', row=2, col=1, rangemode='tozero')
    fig.update_layout(height=520, margin=dict(l=50, r=20, t=50, b=50))
    # Shrink any layout annotation (e.g. subplot titles) that has no explicit
    # font, leaving the hline annotation's font (set above) untouched.
    for ann in fig['layout']['annotations']:
        if 'font' not in ann:
            ann['font'] = dict(size=11)
    return fig
242
+
243
+
244
  def create_embedding_heatmap(embedding, title="Embedding"):
245
  """Create a heatmap of a single embedding vector."""
246
  embedding = np.array(embedding)
 
473
 
474
  return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
475
 
476
+
477
def process_surprise(sequence: str, stride: int, mask_fraction: float):
    """Compute MLM surprise across the sequence."""
    # Normalize the input and reject invalid sequences up front.
    sequence = strip_fasta_header(sequence.strip())
    ok, error = validate_sequence(sequence)
    if not ok:
        return f"**Error**: {error}", None

    per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
        sequence, stride=stride, mask_fraction=mask_fraction
    )
    if not per_window:
        return "**Error**: sequence too short for one window", None

    fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))

    # Window-level extremes for the summary table; min/max return the first
    # occurrence on ties, matching argmin/argmax semantics.
    lo_pos, lo_val = min(per_window, key=lambda pw: pw[1])
    hi_pos, hi_val = max(per_window, key=lambda pw: pw[1])
    mean_surprise = np.mean([score for _, score in per_window])
    summary = f"""### MLM surprise

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {mean_surprise:.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {lo_val:.3f} nats @ ~{lo_pos:,} bp |
| most surprising window | {hi_val:.3f} nats @ ~{hi_pos:,} bp |

Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
    return summary, fig
511
+
512
+
513
  # Build interface
514
  with gr.Blocks(
515
  title="BERT Metagenome Embeddings",
 
551
  api_name="embed"
552
  )
553
 
554
    # MLM surprise tab: score a sequence via process_surprise and show the
    # summary table plus the two-panel surprise plot.
    with gr.Tab("MLM surprise"):
        gr.Markdown("""
Per-base "surprise" from the model's masked-language-modeling head.
Each window randomly masks ~15% of positions, one forward pass predicts them,
and we measure how hard the model finds each true base to reconstruct.
**Lower** = conserved/predictable pattern. **Higher** = unusual region.
""")
        with gr.Row():
            # Left column: sequence input and the two tuning sliders.
            with gr.Column(scale=1, min_width=260):
                surp_seq = gr.Textbox(
                    label="sequence",
                    placeholder="Paste DNA (FASTA or raw)...",
                    lines=8,
                    value=EXAMPLE_SEQUENCE,
                )
                surp_stride = gr.Slider(50, 500, value=100, step=50, label="stride",
                                        info="lower = finer resolution, more compute")
                surp_mask = gr.Slider(0.05, 0.5, value=0.15, step=0.05,
                                      label="mask fraction",
                                      info="fraction of positions masked per window")
                surp_btn = gr.Button("score", variant="primary")

            # Right column: markdown summary and the surprise plot.
            with gr.Column(scale=3, min_width=500):
                surp_summary = gr.Markdown()
                surp_plot = gr.Plot(label="surprise along sequence")

        # Wire the button to the handler; also exposed as the /surprise API.
        surp_btn.click(
            process_surprise,
            inputs=[surp_seq, surp_stride, surp_mask],
            outputs=[surp_summary, surp_plot],
            api_name="surprise",
        )
586
+
587
  with gr.Tab("API"):
588
  gr.Markdown("""
589
  ### API
 
610
  from the rest of the sequence. Spikes = unusual regions relative to context.
611
 
612
  Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.
613
+
614
+ ### MLM surprise endpoint
615
+
616
+ ```python
617
+ summary, plot = client.predict(
618
+ sequence="ATGC...",
619
+ stride=100,
620
+ mask_fraction=0.15,
621
+ api_name="/surprise",
622
+ )
623
+ ```
624
+
625
+ Returns per-window mean `-log(p_true)` at masked positions (in nats).
626
+ Uniform-random baseline is `ln(6) ≈ 1.79 nats`.
627
  """)
628
 
629
  with gr.Tab("About"):