Spaces:
Sleeping
Add MLM surprise tab: per-base -log p(true) along the sequence
Uses the model's MLM head (final Dense-6 over the nucleotide vocabulary) that
the model was actually pretrained with. For each sliding window we mask ~15%
of positions, run one forward pass, softmax the logits, and read off
-log(p_true) at the masked positions. Aggregation:
- per-window mean surprise -> line plot with ln(6) uniform baseline
- per-base scatter at masked positions -> finer-grained view of local spikes
Low values = model confidently reconstructs the base from context (conserved
or training-typical motifs). High values near ln(6) = model is near-uniform
(unusual relative to training distribution).
One forward pass per window at the default stride, so runtime is the same
as the embedding extraction path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
@@ -149,6 +149,98 @@ def embed_sequence(sequence, mode="mean", stride=100, layer=21):
|
|
| 149 |
window_emb = np.mean(embeddings, axis=1)
|
| 150 |
return np.mean(window_emb, axis=0), window_emb, positions
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def create_embedding_heatmap(embedding, title="Embedding"):
|
| 153 |
"""Create a heatmap of a single embedding vector."""
|
| 154 |
embedding = np.array(embedding)
|
|
@@ -381,6 +473,43 @@ def process(sequence: str, mode: str, stride: int, layer: int):
|
|
| 381 |
|
| 382 |
return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
# Build interface
|
| 385 |
with gr.Blocks(
|
| 386 |
title="BERT Metagenome Embeddings",
|
|
@@ -422,6 +551,39 @@ with gr.Blocks(
|
|
| 422 |
api_name="embed"
|
| 423 |
)
|
| 424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
with gr.Tab("API"):
|
| 426 |
gr.Markdown("""
|
| 427 |
### API
|
|
@@ -448,6 +610,20 @@ embedding = np.load(emb_path)
|
|
| 448 |
from the rest of the sequence. Spikes = unusual regions relative to context.
|
| 449 |
|
| 450 |
Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
""")
|
| 452 |
|
| 453 |
with gr.Tab("About"):
|
|
|
|
| 149 |
window_emb = np.mean(embeddings, axis=1)
|
| 150 |
return np.mean(window_emb, axis=0), window_emb, positions
|
| 151 |
|
# ln(vocab_size=6): the surprise a uniform-random predictor would score.
UNIFORM_SURPRISE = float(np.log(6))
MASK_TOKEN = 0  # PAD/OOV; used as the MLM mask slot


def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
    """Per-window and per-base MLM surprise.

    For each sliding window, randomly mask ~mask_fraction of positions, run one
    forward pass through the full model (which ends in a Dense(vocab_size=6)),
    softmax the per-position logits, and take -log(p_true) at the masked
    positions. Returns:

    - per_window: list of (position, mean_surprise)
    - per_base_pos, per_base_vals: flat arrays of (position, surprise) samples,
      one entry per (window x masked_position). Overlapping windows give
      multiple observations per base.
    """
    model = get_base_model()
    tokens = tokenize(sequence)
    rng = np.random.default_rng(seed)
    # Number of positions hidden in every window (always at least one).
    num_masked = max(1, int(WINDOW_SIZE * mask_fraction))

    per_window = []
    per_base_pos = []
    per_base_vals = []

    final_offset = len(tokens) - WINDOW_SIZE
    for offset in range(0, final_offset + 1, stride):
        originals = tokens[offset:offset + WINDOW_SIZE].copy()
        corrupted = originals.copy()
        hidden = rng.choice(WINDOW_SIZE, size=num_masked, replace=False)
        corrupted[hidden] = MASK_TOKEN

        # One forward pass; per-position logits over the 6-token vocabulary.
        logits = model.predict(corrupted[np.newaxis, :], verbose=0)[0]
        # Numerically-stable softmax along the vocabulary axis.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        weights = np.exp(shifted)
        probs = weights / weights.sum(axis=-1, keepdims=True)

        # -log p(true base) at each masked slot; clip guards against log(0).
        p_true = np.clip(probs[hidden, originals[hidden]], 1e-10, None)
        surprises = -np.log(p_true)

        per_window.append((offset + WINDOW_SIZE // 2, float(surprises.mean())))
        per_base_pos.extend((offset + hidden).tolist())
        per_base_vals.extend(surprises.tolist())

    return per_window, np.array(per_base_pos), np.array(per_base_vals)
def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
    """Two-panel Plotly figure: per-window surprise line + per-base scatter.

    Top panel draws the per-window mean with the ln(6) uniform-prediction
    baseline as a dashed reference; bottom panel shows one dot per
    (window, masked base) observation, colour-graded by surprise.
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
        row_heights=[0.6, 0.4],
        subplot_titles=('per-window mean surprise (lower = model finds region predictable)',
                        'per-base surprise at masked positions (dots; darker = more surprising)')
    )

    # Top panel: one (center, mean surprise) point per window, joined as a line.
    centers = [pos for pos, _ in per_window]
    means = [val for _, val in per_window]
    fig.add_trace(
        go.Scatter(
            x=centers,
            y=means,
            mode='lines+markers',
            line=dict(color='#18181b', width=2),
            marker=dict(size=6),
            hovertemplate='center %{x} bp<br>surprise %{y:.3f} nats<extra></extra>',
            showlegend=False,
        ),
        row=1, col=1,
    )

    # Dashed reference: what a uniform-random predictor would score.
    fig.add_hline(
        y=UNIFORM_SURPRISE,
        line_dash='dash',
        line_color='#a1a1aa',
        annotation_text=f'uniform baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
        annotation_position='top right',
        annotation_font=dict(size=10, color='#71717a'),
        row=1, col=1,
    )

    # Bottom panel: raw per-base observations shaded on a Reds scale.
    colorbar_style = dict(title=dict(text='nats', font=dict(size=10)),
                          thickness=10, len=0.35, y=0.18, tickfont=dict(size=9))
    fig.add_trace(
        go.Scatter(
            x=per_base_pos,
            y=per_base_vals,
            mode='markers',
            marker=dict(size=4, color=per_base_vals, colorscale='Reds',
                        cmin=0, cmax=UNIFORM_SURPRISE, colorbar=colorbar_style),
            hovertemplate='pos %{x} bp<br>surprise %{y:.3f}<extra></extra>',
            showlegend=False,
        ),
        row=2, col=1,
    )

    # Shared x-range so the two panels line up base-for-base.
    for panel in (1, 2):
        fig.update_xaxes(range=[0, seq_len], row=panel, col=1)
        fig.update_yaxes(title_text='nats', row=panel, col=1, rangemode='tozero')
    fig.update_xaxes(title_text='position (bp)', row=2, col=1)
    fig.update_layout(height=520, margin=dict(l=50, r=20, t=50, b=50))

    # Subplot titles arrive without an explicit font; give them one size.
    for ann in fig['layout']['annotations']:
        if 'font' not in ann:
            ann['font'] = dict(size=11)
    return fig
+
|
| 243 |
+
|
| 244 |
def create_embedding_heatmap(embedding, title="Embedding"):
|
| 245 |
"""Create a heatmap of a single embedding vector."""
|
| 246 |
embedding = np.array(embedding)
|
|
|
|
| 473 |
|
| 474 |
return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
|
| 475 |
|
| 476 |
def process_surprise(sequence: str, stride: int, mask_fraction: float):
    """Score MLM surprise for one sequence; return (markdown summary, figure).

    On invalid or too-short input, returns an error markdown string and None
    instead of raising, so the Gradio UI shows the message inline.
    """
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None

    per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
        sequence, stride=stride, mask_fraction=mask_fraction
    )
    if not per_window:
        return "**Error**: sequence too short for one window", None

    fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))

    # Headline numbers: overall mean plus the extreme windows (first-occurrence
    # ties, matching argmin/argmax semantics).
    scores = np.array([surprise for _, surprise in per_window])
    best_pos, best_val = min(per_window, key=lambda pw: pw[1])
    worst_pos, worst_val = max(per_window, key=lambda pw: pw[1])

    summary = f"""### MLM surprise

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {scores.mean():.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {best_val:.3f} nats @ ~{best_pos:,} bp |
| most surprising window | {worst_val:.3f} nats @ ~{worst_pos:,} bp |

Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
    return summary, fig
|
| 512 |
+
|
| 513 |
# Build interface
|
| 514 |
with gr.Blocks(
|
| 515 |
title="BERT Metagenome Embeddings",
|
|
|
|
| 551 |
api_name="embed"
|
| 552 |
)
|
| 553 |
|
| 554 |
    # --- MLM surprise tab: mask-and-predict scoring along the sequence ---
    with gr.Tab("MLM surprise"):
        gr.Markdown("""
        Per-base "surprise" from the model's masked-language-modeling head.
        Each window randomly masks ~15% of positions, one forward pass predicts them,
        and we measure how hard the model finds each true base to reconstruct.
        **Lower** = conserved/predictable pattern. **Higher** = unusual region.
        """)
        with gr.Row():
            # Left column: sequence input plus the two scoring knobs.
            with gr.Column(scale=1, min_width=260):
                surp_seq = gr.Textbox(
                    label="sequence",
                    placeholder="Paste DNA (FASTA or raw)...",
                    lines=8,
                    value=EXAMPLE_SEQUENCE,
                )
                # Window step in bp; smaller stride = more windows/forward passes.
                surp_stride = gr.Slider(50, 500, value=100, step=50, label="stride",
                                        info="lower = finer resolution, more compute")
                # Fraction of each window hidden before the forward pass.
                surp_mask = gr.Slider(0.05, 0.5, value=0.15, step=0.05,
                                      label="mask fraction",
                                      info="fraction of positions masked per window")
                surp_btn = gr.Button("score", variant="primary")

            # Right column: markdown stats table + the two-panel plot.
            with gr.Column(scale=3, min_width=500):
                surp_summary = gr.Markdown()
                surp_plot = gr.Plot(label="surprise along sequence")

        # Wire the button to process_surprise; also exposed as /surprise API.
        surp_btn.click(
            process_surprise,
            inputs=[surp_seq, surp_stride, surp_mask],
            outputs=[surp_summary, surp_plot],
            api_name="surprise",
        )
|
| 587 |
with gr.Tab("API"):
|
| 588 |
gr.Markdown("""
|
| 589 |
### API
|
|
|
|
| 610 |
from the rest of the sequence. Spikes = unusual regions relative to context.
|
| 611 |
|
| 612 |
Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.
|
| 613 |
+
|
| 614 |
+
### MLM surprise endpoint
|
| 615 |
+
|
| 616 |
+
```python
|
| 617 |
+
summary, plot = client.predict(
|
| 618 |
+
sequence="ATGC...",
|
| 619 |
+
stride=100,
|
| 620 |
+
mask_fraction=0.15,
|
| 621 |
+
api_name="/surprise",
|
| 622 |
+
)
|
| 623 |
+
```
|
| 624 |
+
|
| 625 |
+
Returns per-window mean `-log(p_true)` at masked positions (in nats).
|
| 626 |
+
Uniform-random baseline is `ln(6) ≈ 1.79 nats`.
|
| 627 |
""")
|
| 628 |
|
| 629 |
with gr.Tab("About"):
|