"""
BERT Metagenome Embeddings - HuggingFace Spaces App
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tempfile
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from custom_layers import get_custom_objects
# Model config
MODEL_REPO = "genomenet/bert-metagenome"
MODEL_FILE = "bert_1k_3.h5"
WINDOW_SIZE = 1000
NUM_LAYERS = 24
EMBEDDING_DIM = 768
# Singleton model cache
_model = None
_embedding_models = {}
def get_base_model():
    """Load and cache the base model."""
    global _model
    if _model is None:
        print("Downloading model...")
        model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
        print(f"Loading model from {model_path}...")
        _model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
        print("Model loaded.")
        # Print model summary for debugging
        print(f"Model outputs: {_model.output_names}")
    return _model
def get_embedding_model(layer_idx=21):
    """Get embedding model for a specific layer."""
    global _embedding_models
    if layer_idx not in _embedding_models:
        model = get_base_model()
        layer_name = f"layer_transformer_block_{layer_idx}"
        try:
            _embedding_models[layer_idx] = tf.keras.Model(
                inputs=model.input,
                outputs=model.get_layer(layer_name).output
            )
        except ValueError:
            # Layer name not found in this checkpoint: fall back to the
            # default deep layer rather than failing outright.
            _embedding_models[layer_idx] = tf.keras.Model(
                inputs=model.input,
                outputs=model.get_layer("layer_transformer_block_21").output
            )
    return _embedding_models[layer_idx]
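# Hedged helper (assumes the "layer_transformer_block_{i}" naming used above):
# list the transformer-block layers actually present in the loaded checkpoint,
# useful when a requested layer index silently falls back to layer 21.
def list_transformer_layers():
    model = get_base_model()
    return [layer.name for layer in model.layers if "transformer_block" in layer.name]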
def get_gpu_status():
    gpus = tf.config.list_physical_devices('GPU')
    return f"GPU: {gpus[0].name}" if gpus else "CPU only"
# Tokenization
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}
def tokenize(sequence):
    sequence = sequence.upper().replace('U', 'T')
    return np.array([TOKEN_MAP.get(c, 5) for c in sequence], dtype=np.int32)
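# Example (illustrative): tokenize("acgtn") -> array([1, 2, 3, 4, 5], dtype=int32).
# Any character outside TOKEN_MAP (e.g. IUPAC ambiguity codes R/Y/S/W) maps to 5 (N).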
def validate_sequence(sequence):
    if not sequence or len(sequence.strip()) == 0:
        return False, "Sequence is empty"
    sequence = sequence.upper().replace('U', 'T')
    valid_chars = set('ACGTNRYSWKMBDHV')
    invalid = set(sequence) - valid_chars - set(' \n\r\t')
    if invalid:
        return False, f"Invalid characters: {invalid}"
    clean = ''.join(c for c in sequence if c in valid_chars)
    if len(clean) < WINDOW_SIZE:
        return False, f"Sequence too short: {len(clean)} < {WINDOW_SIZE} bp"
    return True, ""
def strip_fasta_header(text):
    lines = text.strip().split('\n')
    return ''.join(l for l in lines if not l.startswith('>')).replace(' ', '').replace('\t', '')
def compute_embedding_stats(embedding):
    """Compute statistics that may indicate sequence 'familiarity'."""
    emb = np.array(embedding)
    # L2 norm - magnitude of response
    l2_norm = np.linalg.norm(emb)
    # Mean activation
    mean_act = np.mean(emb)
    # Std - spread of activations
    std_act = np.std(emb)
    # Sparsity - fraction of near-zero activations
    sparsity = np.mean(np.abs(emb) < 0.1)
    # Activation entropy (discretized)
    hist, _ = np.histogram(emb, bins=50, density=True)
    hist = hist[hist > 0]
    entropy = -np.sum(hist * np.log(hist + 1e-10))
    # Kurtosis - peakedness (high = more concentrated activations)
    kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3
    return {
        'l2_norm': float(l2_norm),
        'mean': float(mean_act),
        'std': float(std_act),
        'sparsity': float(sparsity),
        'entropy': float(entropy),
        'kurtosis': float(kurtosis)
    }
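# Sanity check (illustrative): for a 768-dim standard-normal vector these land
# near l2_norm ~ sqrt(768) ~ 27.7, sparsity ~ 0.08, excess kurtosis ~ 0.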
def embed_sequence(sequence, mode="mean", stride=100, layer=21):
    """Extract embeddings from sequence."""
    model = get_embedding_model(layer)
    seq_len = len(sequence)
    embeddings = []
    positions = []
    for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
        window = sequence[start:start + WINDOW_SIZE]
        tokens = np.expand_dims(tokenize(window), axis=0)
        emb = model.predict(tokens, verbose=0)
        embeddings.append(emb[0])
        positions.append(start)
    embeddings = np.array(embeddings)  # (n_windows, 1000, 768)
    if mode == "mean":
        window_emb = np.mean(embeddings, axis=1)
        return np.mean(window_emb, axis=0), window_emb, positions
    elif mode == "max":
        window_emb = np.max(embeddings, axis=1)
        return np.max(window_emb, axis=0), window_emb, positions
    elif mode == "per-window":
        window_emb = np.mean(embeddings, axis=1)
        return window_emb, window_emb, positions
    else:
        # Unknown mode: default to mean pooling.
        window_emb = np.mean(embeddings, axis=1)
        return np.mean(window_emb, axis=0), window_emb, positions
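# Illustrative helper (not wired into the UI): cosine similarity between two
# pooled embeddings from embed_sequence, i.e. the comparison the Extract tab
# text describes. Assumes both vectors come from the same layer and pooling mode.
def cosine_similarity(vec_a, vec_b):
    a = np.asarray(vec_a, dtype=np.float64)
    b = np.asarray(vec_b, dtype=np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))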
# ln(vocab_size=6): surprise if the model predicted uniformly at random.
UNIFORM_SURPRISE = float(np.log(6))
MASK_TOKEN = 0 # PAD/OOV; used as the MLM mask slot
def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
    """Per-window and per-base MLM surprise.

    For each sliding window, randomly mask ~mask_fraction of positions, run one
    forward pass through the full model (which ends in a Dense(vocab_size=6)),
    softmax the per-position logits, and take -log(p_true) at the masked
    positions. Returns:
    - per_window: list of (position, mean_surprise)
    - per_base_pos, per_base_vals: flat arrays of (position, surprise) samples,
      one entry per (window × masked_position). Overlapping windows give
      multiple observations per base.
    """
    model = get_base_model()
    tokens = tokenize(sequence)
    seq_len = len(tokens)
    rng = np.random.default_rng(seed)
    n_mask = max(1, int(WINDOW_SIZE * mask_fraction))
    starts = list(range(0, seq_len - WINDOW_SIZE + 1, stride))
    if not starts:
        return [], np.array([]), np.array([])
    # Build all windows + mask sets, run one batched forward pass.
    batch = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    truth = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    mask_idxs = []
    for i, start in enumerate(starts):
        w = tokens[start:start + WINDOW_SIZE]
        truth[i] = w
        idx = rng.choice(WINDOW_SIZE, size=n_mask, replace=False)
        mask_idxs.append(idx)
        w_masked = w.copy()
        w_masked[idx] = MASK_TOKEN
        batch[i] = w_masked
    logits = model.predict(batch, verbose=0, batch_size=8)  # (n_win, 1000, 6)
    # Numerically stable softmax over the vocab axis.
    logits -= logits.max(axis=-1, keepdims=True)
    exp_l = np.exp(logits)
    probs = exp_l / exp_l.sum(axis=-1, keepdims=True)
    per_window = []
    per_base_pos = []
    per_base_vals = []
    for i, start in enumerate(starts):
        idx = mask_idxs[i]
        p_true = probs[i, idx, truth[i, idx]]
        surprises = -np.log(np.clip(p_true, 1e-10, None))
        per_window.append((start + WINDOW_SIZE // 2, float(surprises.mean())))
        per_base_pos.extend((start + idx).tolist())
        per_base_vals.extend(surprises.tolist())
    return per_window, np.array(per_base_pos), np.array(per_base_vals)
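# Illustrative helper (not used by the plots): overlapping windows give several
# surprise samples per base; this collapses them into one mean value per position.
def mean_surprise_per_base(per_base_pos, per_base_vals):
    pos = np.asarray(per_base_pos)
    vals = np.asarray(per_base_vals)
    uniq = np.unique(pos)
    return uniq, np.array([vals[pos == p].mean() for p in uniq])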
def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
    """Two-panel Plotly figure: per-window surprise line + per-base scatter."""
    from plotly.subplots import make_subplots
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
        row_heights=[0.55, 0.45],
        subplot_titles=(
            'per-WINDOW mean surprise – one point per 1000-bp window, plotted at its center',
            'per-BASE surprise – one dot per masked base (~15% of bases in each window)',
        ),
    )
    wx = [p for p, _ in per_window]
    wy = [s for _, s in per_window]
    fig.add_trace(go.Scatter(
        x=wx, y=wy, mode='lines+markers',
        line=dict(color='#18181b', width=2), marker=dict(size=7),
        hovertemplate='window center: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
        name='window mean', showlegend=False,
    ), row=1, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        annotation_text=f'uniform-random baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
        annotation_position='top right', annotation_font=dict(size=10, color='#71717a'),
        row=1, col=1,
    )
    fig.add_trace(go.Scatter(
        x=per_base_pos, y=per_base_vals, mode='markers',
        marker=dict(size=4, color=per_base_vals, colorscale='Reds',
                    cmin=0, cmax=UNIFORM_SURPRISE,
                    colorbar=dict(title=dict(text='surprise<br>(nats)', font=dict(size=10)),
                                  thickness=10, len=0.4, y=0.2, tickfont=dict(size=9))),
        hovertemplate='base position: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
        name='per base', showlegend=False,
    ), row=2, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        row=2, col=1,
    )
    fig.update_xaxes(title_text='position along input sequence (bp)', row=2, col=1,
                     range=[0, seq_len])
    fig.update_xaxes(range=[0, seq_len], row=1, col=1)
    fig.update_yaxes(title_text='surprise (nats)', row=1, col=1, rangemode='tozero')
    fig.update_yaxes(title_text='surprise (nats)', row=2, col=1, rangemode='tozero')
    fig.update_layout(height=560, margin=dict(l=60, r=20, t=70, b=60))
    for ann in fig['layout']['annotations']:
        if 'font' not in ann:
            ann['font'] = dict(size=11)
    return fig
def create_embedding_heatmap(embedding, title="Embedding"):
    """Create a heatmap of a single embedding vector."""
    embedding = np.array(embedding)
    n_dims = len(embedding)
    cols = 32
    rows = int(np.ceil(n_dims / cols))
    padded = np.full(rows * cols, np.nan)
    padded[:n_dims] = embedding
    grid = padded.reshape(rows, cols)
    finite = embedding[np.isfinite(embedding)]
    vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
    fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
    im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
    plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Row')
    ax.set_title(f'{title} ({n_dims} dims)')
    ax.set_xticks(np.arange(0, cols, 8))
    plt.tight_layout()
    return fig
def create_trajectory_plot(window_embeddings, positions):
    """Create interactive trajectory heatmap."""
    emb = np.array(window_embeddings)
    n_windows, n_dims = emb.shape
    # Subsample dimensions
    step = max(1, n_dims // 100)
    emb_sub = emb[:, ::step]
    vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
    fig = go.Figure(go.Heatmap(
        z=emb_sub,
        x=list(range(emb_sub.shape[1])),
        y=[f"{p}" for p in positions],
        colorscale='RdBu_r',
        zmin=-vmax, zmax=vmax,
        colorbar=dict(title='Act.'),
        hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}<extra></extra>'
    ))
    fig.update_layout(
        xaxis=dict(title='Dimension' + (' (subsampled)' if step > 1 else '')),
        yaxis=dict(title='Window start (bp)'),
        height=max(350, n_windows * 15 + 100),
        margin=dict(l=60, r=20, t=30, b=50)
    )
    return fig
def create_familiarity_plot(window_embeddings, positions):
    """Per-window L2 norm + novelty (cosine distance to sequence mean) along the sequence.

    High L2 norm = strong response. High novelty = window looks different from the rest
    of the sequence (the model's internal 'surprise' relative to the sequence average).
    """
    from plotly.subplots import make_subplots
    emb = np.array(window_embeddings)
    n_windows = emb.shape[0]
    l2 = np.linalg.norm(emb, axis=1)
    mean_vec = emb.mean(axis=0)
    mean_norm = np.linalg.norm(mean_vec) + 1e-10
    cos_sim = (emb @ mean_vec) / (l2 * mean_norm + 1e-10)
    novelty = 1.0 - cos_sim
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.14,
        subplot_titles=(
            'per-WINDOW L2 norm of embedding (activation magnitude at layer L)',
            'per-WINDOW embedding novelty (1 - cos similarity to sequence-mean embedding)',
        ),
    )
    fig.add_trace(go.Scatter(
        x=positions, y=l2, mode='lines+markers',
        line=dict(color='#3b82f6', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>L2: %{y:.2f}<extra></extra>', showlegend=False
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=positions, y=novelty, mode='lines+markers',
        line=dict(color='#ef4444', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>novelty: %{y:.3f}<extra></extra>', showlegend=False
    ), row=2, col=1)
    fig.update_xaxes(title_text='window start (bp along input sequence)', row=2, col=1)
    fig.update_yaxes(title_text='L2 norm', row=1, col=1)
    fig.update_yaxes(title_text='1 - cos sim', row=2, col=1)
    fig.update_layout(
        height=380 if n_windows > 1 else 260,
        margin=dict(l=60, r=20, t=50, b=50),
    )
    for ann in fig['layout']['annotations']:
        ann['font'] = dict(size=11)
    return fig
def create_dimension_plot(window_embeddings, positions, top_k=8):
    """Show top varying dimensions."""
    emb = np.array(window_embeddings)
    variances = np.var(emb, axis=0)
    top_dims = np.argsort(variances)[-top_k:][::-1]
    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
              '#ff7f00', '#a65628', '#f781bf', '#999999']
    fig = go.Figure()
    for i, dim in enumerate(top_dims):
        fig.add_trace(go.Scatter(
            x=positions, y=emb[:, dim],
            mode='lines', name=f'd{dim}',
            line=dict(color=colors[i % len(colors)], width=1.5)
        ))
    fig.update_layout(
        xaxis=dict(title='Position (bp)'),
        yaxis=dict(title='Activation'),
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=50, r=20, t=40, b=50)
    )
    return fig
# Example sequence: ~3 kb slice of E. coli K-12 MG1655 around the lacZ operon
# (NC_000913.3, positions 365529-368600). Covers the lac repressor binding region,
# the lacZ gene, and flanking regulatory sequence, so per-window plots show
# real biological structure transitions.
EXAMPLE_SEQUENCE = (
"AACTGTTACCCGTAGGTAGTCACGCAACTCGCCGCACATCTGAACTTCAGCCTCCAGTACAGCGCGGCTGAA"
"ATCATCATTAAAGCGAGTGGCAACATGGAAATCGCTGATTTGTGTAGTCGGTTTATGCAGCAACGAGACGTC"
"ACGGAAAATGCCGCTCATCCGCCACATATCCTGATCTTCCAGATAACTGCCGTCACTCCAGCGCAGCACCAT"
"CACCGCGAGGCGGTTTTCTCCGGCGCGTAAAAATGCGCTCAGGTCAAATTCAGACGGCAAACGACTGTCCTG"
"GCCGTAACCGACCCAGCGCCCGTTGCACCACAGATGAAACGCCGAGTTAACGCCATCAAAAATAATTCGCGT"
"CTGGCCTTCCTGTAGCCAGCTTTCATCAACATTAAATGTGAGCGAGTAACAACCCGTCGGATTCTCCGTGGG"
"AACAAACGGCGGATTGACCGTAATGGGATAGGTCACGTTGGTGTAGATGGGCGCATCGTAACCGTGCATCTG"
"CCAGTTTGAGGGGACGACGACAGTATCGGCCTCAGGAAGATCGCACTCCAGCCAGCTTTCCGGCACCGCTTC"
"TGGTGCCGGAAACCAGGCAAAGCGCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCG"
"GGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGG"
"GTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATCCGTAATCATGGTCATAGCTGTTTCCTGTGT"
"GAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCT"
"AATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCC"
"AGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCCAGGGTGGTTTTTCTT"
"TTCACCAGTGAGACGGGCAACAGCTGATTGCCCTTCACCGCCTGGCCCTGAGAGAGTTGCAGCAAGCGGTCC"
"ACGCTGGTTTGCCCCAGCAGGCGAAAATCCTGTTTGATGGTGGTTAACGGCGGGATATAACATGAGCTGTCT"
"TCGGTATCGTCGTATCCCACTACCGAGATATCCGCACCAACGCGCAGCCCGGACTCGGTAATGGCGCGCATT"
"GCGCCCAGCGCCATCTGATCGTTGGCAACCAGCATCGCAGTGGGAACGATGCCCTCATTCAGCATTTGCATG"
"GTTTGTTGAAAACCGGACATGGCACTCCAGTCGCCTTCCCGTTCCGCTATCGGCTGAATTTGATTGCGAGTG"
"AGATATTTATGCCAGCCAGCCAGACGCAGACGCGCCGAGACAGAACTTAATGGGCCCGCTAACAGCGCGATT"
"TGCTGGTGACCCAATGCGACCAGATGCTCCACGCCCAGTCGCGTACCGTCTTCATGGGAGAAAATAATACTG"
"TTGATGGGTGTCTGGTCAGAGACATCAAGAAATAACGCCGGAACATTAGTGCAGGCAGCTTCCACAGCAATG"
"GCATCCTGGTCATCCAGCGGATAGTTAATGATCAGCCCACTGACGCGTTGCGCGAGAAGATTGTGCACCGCC"
"GCTTTACAGGCTTCGACGCCGCTTCGTTCTACCATCGACACCACCACGCTGGCACCCAGTTGATCGGCGCGA"
"GATTTAATCGCCGCGACAATTTGCGACGGCGCGTGCAGGGCCAGACTGGAGGTGGCAACGCCAATCAGCAAC"
"GACTGTTTGCCCGCCAGTTGTTGTGCCACGCGGTTGGGAATGTAATTCAGCTCCGCCATCGCCGCTTCCACT"
"TTTTCCCGCGTTTTCGCAGAAACGTGGCTGGCCTGGTTCACCACGCGGGAAACGGTCTGATAAGAGACACCG"
"GCATACTCTGCGACATCGTATAACGTTACTGGTTTCACATTCACCACCCTGAATTGACTCTCTTCCGGGCGC"
"TATCATGCCATACCGCGAAAGGTTTTGCGCCATTCGATGGTGTCAACGTAAATGCATGCCGCTTCGCCTTCC"
"GGCCACCAGAATAGCCTGCGATTCAACCCCTTCTTCGATCTGTTTTGCTACCCGTTGTAGCGCCGGAAGATG"
"CTTTTCCGCTGCCTGTTCAATGGTCATTGCGCTCGCCATATACACCAGATTCAGACAGCCAATCACCCGTTG"
"TTCACTGCGCAGCGGTACGGCGATAGAGGCGATCTTCTCCTCCTGATCCCAGCCGCGGTAGTTCTGTCCGTA"
"ACCCTCTTTGCGCGCGCGCGCCAGAATGGCTTCCAGCTTTAACGGTTCCCGTGCCAGTTGATAGTCATCACC"
"GGGGCGGGAGGCTAACATTTCGATTAATTCCTTGCGGTCTTGTTCCGGGCAAAAGGCCAGCCAGGTCAGGCC"
"CGAGGCGGTTTTCAGAAGCGGCAAACGTCGCCCGACCATTGCCCGGTGAAAGGATAAGCGGCTGAAACGGTG"
"AGTGGTTTCGCGTACCACCATTGCATCAACATCCAGCGTGGACACATCTGTCGGCCATACCACTTCGCGCAA"
"CAGATCGCCCAGCAGTGGGGCCGCCAGTGCAGAAATCCACTGTTCGTCACGAAATCCTTCGCTTAATTGCCG"
"CACTTTGATGGTCAGTCGAAAACTATCATCGGAGGGGCTACGGCGGACATATCCCTCTTCCTGCAGCGTCTC"
"CAGCAGTCGCCGCACAGTGGTGCGATGCAGGCCGCTGAGTTCCGCCAGCAGCCCGACGCTGGCACCGCCATC"
"AAGTTTATTTAACATATTTAATAACATTAGACCGCGGGTTAAGCCGCGCACGGTTTTGTATTCCGTCTGCTC"
"ATTGTTCTGCATATTAATTGACATTTCTATAGTTAAAACAACGTGGTGCACCTGGTGCACATTCGGGCATGT"
"TTTGATTGTAGCCGAAAACACCCTTCCTATACTGAGCGCACAATAAAAAATCATTTACATGTTTTTAACAAA"
"ATAAGTTGCGCTGTACTGTGCGCGCAACGACATTTTGTCCGAGTCGTG"
)
# Highly conserved example: ~3 kb slice across E. coli K-12 MG1655 rrnB operon
# (NC_000913.3:4035531-4038602), which contains the 16S rRNA gene (rrsB) and
# flanking regulatory regions. rRNA genes are among the most conserved in
# bacteria, so MLM surprise should be visibly lower than on lacZ, and the
# per-window embedding plot should be flatter (one coherent functional region).
EXAMPLE_16S = (
"AAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACACATGCAAGTCGAACGGTAA"
"CAGGAAGAAGCTTGCTTCTTTGCTGACGAGTGGCGGACGGGTGAGTAATGTCTGGGAAACTGCCTGATGGAG"
"GGGGATAACTACTGGAAACGGTAGCTAATACCGCATAACGTCGCAAGACCAAAGAGGGGGACCTTCGGGCCT"
"CTTGCCATCGGATGTGCCCAGATGGGATTAGCTAGTAGGTGGGGTAACGGCTCACCTAGGCGACGATCCCTA"
"GCTGGTCTGAGAGGATGACCAGCCACACTGGAACTGAGACACGGTCCAGACTCCTACGGGAGGCAGCAGTGG"
"GGAATATTGCACAATGGGCGCAAGCCTGATGCAGCCATGCCGCGTGTATGAAGAAGGCCTTCGGGTTGTAAA"
"GTACTTTCAGCGGGGAGGAAGGGAGTAAAGTTAATACCTTTGCTCATTGACGTTACCCGCAGAAGAAGCACC"
"GGCTAACTCCGTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAATTACTGGGCGTAAAGC"
"GCACGCAGGCGGTTTGTTAAGTCAGATGTGAAATCCCCGGGCTCAACCTGGGAACTGCATCTGATACTGGCA"
"AGCTTGAGTCTCGTAGAGGGGGGTAGAATTCCAGGTGTAGCGGTGAAATGCGTAGAGATCTGGAGGAATACC"
"GGTGGCGAAGGCGGCCCCCTGGACGAAGACTGACGCTCAGGTGCGAAAGCGTGGGGAGCAAACAGGATTAGA"
"TACCCTGGTAGTCCACGCCGTAAACGATGTCGACTTGGAGGTTGTGCCCTTGAGGCGTGGCTTCCGGAGCTA"
"ACGCGTTAAGTCGACCGCCTGGGGAGTACGGCCGCAAGGTTAAAACTCAAATGAATTGACGGGGGCCCGCAC"
"AAGCGGTGGAGCATGTGGTTTAATTCGATGCAACGCGAAGAACCTTACCTGGTCTTGACATCCACGGAAGTT"
"TTCAGAGATGAGAATGTGCCTTCGGGAACCGTGAGACAGGTGCTGCATGGCTGTCGTCAGCTCGTGTTGTGA"
"AATGTTGGGTTAAGTCCCGCAACGAGCGCAACCCTTATCCTTTGTTGCCAGCGGTCCGGCCGGGAACTCAAA"
"GGAGACTGCCAGTGATAAACTGGAGGAAGGTGGGGATGACGTCAAGTCATCATGGCCCTTACGACCAGGGCT"
"ACACACGTGCTACAATGGCGCATACAAAGAGAAGCGACCTCGCGAGAGCAAGCGGACCTCATAAAGTGCGTC"
"GTAGTCCGGATTGGAGTCTGCAACTCGACTCCATGAAGTCGGAATCGCTAGTAATCGTGGATCAGAATGCCA"
"CGGTGAATACGTTCCCGGGCCTTGTACACACCGCCCGTCACACCATGGGAGTGGGTTGCAAAAGAAGTAGGT"
"AGCTTAACCTTCGGGAGGGCGCTTACCACTTTGTGATTCATGACTGGGGTGAAGTCGTAACAAGGTAACCGT"
"AGGGGAACCTGCGGTTGGATCACCTCCTTACCTCAGAACAAGAAGTATACCAAACTCAACGGAGTTACCACC"
"CGGTGGCAATTGCACAATTGATCATCAGAGAGAACCAATGCTGAACATCAAGAACAATGAAGGCCAGCAAGG"
"TAATCCGAGCAGAGCTAAAGAGAACGGATACCCATGACCACGGAAGTGGTCAAGAATATAGGCATCCAAAGA"
"CGATCAGATACCGTCGTAGTTCCGACCATAAACGATGAAGAACACGTCCATATAGCCAAGCTCCGTAGGACA"
"AGAATAATGAGACAAAACACAAAGCCAACAATGAGCCCTAAGTGATGTCCGGGGAAACCAGAAAGACCCGTA"
"GCCTGAAAGATTGCCGGCCACTTGGAACGCTGGATTGAGCACCCTGTAGAACATTTGTTTGAACAGGTGCGG"
"ACCGAATAAAGCCACATGATGCAACTATGAATCTGAACTTGCAATGCTGAACGAATCGCGATAAACCTAAGG"
"CAGAAGCGTACCCGGGAACATCAATAGACTGCGATGTGATAACGTACCCAAACTTATCCCAGGGCCCGTAAC"
"TAAACTGCCCCTTTGCGCTTCGAGTAAAGGCATCAAATAGATATAGACTCATAATGCCACAGTCCAATTACA"
"TGCCCGGAAGTTATTAATACTGCGAACGTTATACATACGAAGCCGTAAGGTAATTTGATAAGCGTAACCGAT"
"AGCCCCGACAGCGAACTAGCAACCTTGGAGTATATGAACCCAAATATCTGTGAGGCCTGGAACGTCCGAGAT"
"GAGAGTGCCACATACTCAAGACTCAAAGTCACCCGAAGGGAATTTGCATATGAGCTCGTCTGGCCAGGAGTT"
"TTAAGAGGGGCGCAGATATCACCTAATACGATAGCTAGCCGAATGCACTACGCCAACCATCTAACGGACAGA"
"GTAATGAACCACAAGCTCCGAAATGATGCTGAGAGACGCATGGCCAGTTCTCATCAGCCGTCGTGGTCAATC"
"GGTGCTTGGCCATCACCATGGGGGCCCGCATCTGCCATCGACAGCGCTTTCATCGTAAACCGTCTTATGGAA"
"AGACATTACAGCCAGTGTAAAATCCCGCACACTATTAGCCATCAAATCATATAAGGCATACGGTCAGTCAGT"
"ATTCCGAAAGAACACCACCAGTGATAGTACCAAGAGCACGTATGAATACGATGCCGACCATAGCGGACAAAT"
"CTCCCAATACGAGAGTAAAATAAGCAAATAATAGATATCCATGCATGGAGTCACCACAATAGAGCGCTACGT"
"CGTCGTGAAGAGGGAAACAACCCAGACCGCCAGCTAAGGTCCCAAAGTCATGGTTAAGTGGGAAACGATGTG"
"GGAAGGCCCAGACAGCCAGGATGTTGGCTTAGAAGCAGCCATCATTTA"
)
# Random DNA: uniform ACGT, 3 kb. Should be ~uninformative – model surprise
# should be close to the uniform baseline (ln 6 ≈ 1.79 nats) because there
# is no context-dependent pattern to learn.
_rng = np.random.default_rng(42)
EXAMPLE_RANDOM = "".join(_rng.choice(list("ACGT"), size=3000))
# Low-complexity repeat: 3 kb of AT dinucleotide repeats. Should give very
# LOW surprise – the model trivially predicts the next base from context.
EXAMPLE_REPEAT = "AT" * 1500
def process(sequence: str, mode: str, stride: int, layer: int):
    """Main processing function."""
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None, None, None, None, None
    embedding, window_embeddings, positions = embed_sequence(
        sequence, mode=mode, stride=stride, layer=layer
    )
    # Save embedding
    path = os.path.join(tempfile.gettempdir(), "embedding.npy")
    np.save(path, embedding)
    # Compute stats
    if mode == "per-window":
        # For per-window, compute stats on mean embedding
        mean_emb = np.mean(embedding, axis=0)
        stats = compute_embedding_stats(mean_emb)
    else:
        stats = compute_embedding_stats(embedding)
    # Create summary
    if mode == "per-window":
        summary = f"""### Results
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| windows | {embedding.shape[0]} |
| shape | {embedding.shape} |
**Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}
"""
    else:
        summary = f"""### Results
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| mode | {mode} |
| dim | {len(embedding)} |
**Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%}
"""
    # Create visualizations
    heatmap_fig = None
    if mode != "per-window":
        heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}")
    multi_window = len(window_embeddings) > 1
    trajectory_fig = create_trajectory_plot(window_embeddings, positions) if multi_window else None
    familiarity_fig = create_familiarity_plot(window_embeddings, positions) if multi_window else None
    dims_fig = create_dimension_plot(window_embeddings, positions) if multi_window else None
    return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
def process_surprise(sequence: str, stride: int, mask_fraction: float):
    """Compute MLM surprise across the sequence."""
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None
    per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
        sequence, stride=stride, mask_fraction=mask_fraction
    )
    if not per_window:
        return "**Error**: sequence too short for one window", None
    fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))
    w_vals = np.array([s for _, s in per_window])
    lo_pos, lo_val = per_window[int(np.argmin(w_vals))]
    hi_pos, hi_val = per_window[int(np.argmax(w_vals))]
    summary = f"""### MLM surprise
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {w_vals.mean():.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {lo_val:.3f} nats @ ~{lo_pos:,} bp |
| most surprising window | {hi_val:.3f} nats @ ~{hi_pos:,} bp |
Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
    return summary, fig
# Build interface
with gr.Blocks(
    title="BERT Metagenome Embeddings",
    css=".gradio-container { max-width: 100% !important; }",
    # theme belongs on the Blocks constructor; launch() does not accept it
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.zinc,
        neutral_hue=gr.themes.colors.zinc,
    ),
) as demo:
    gr.Markdown(
        "# bert-embedding\n"
        "BERT (24 layers, 430M params) pretrained with masked-language-modeling "
        "on metagenomic contigs. Input: DNA sequence ≥ 1000 bp. The model slides a "
        "1000 bp window over your sequence and produces two kinds of output – "
        "**embeddings** (a 768-dim hidden-state vector per window, under *Extract*) "
        "and **MLM surprise** (the model's confidence in reconstructing masked bases, "
        "under *MLM surprise*). See the *Guide* tab for how to read the plots."
    )
with gr.Tab("Extract"):
gr.Markdown(
"Extract 768-dim hidden-state embeddings from a chosen transformer layer. "
"Produces **per-window** vectors (one per 1000-bp window) that can be pooled, "
"visualised as heatmaps, and compared via cosine similarity. "
"The per-window plot below shows how the embedding *changes along the sequence* β "
"it does **not** say anything about whether bases are predictable "
"(that's the MLM surprise tab)."
)
with gr.Row():
with gr.Column(scale=1, min_width=260):
seq_input = gr.Textbox(
label="sequence (β₯ 1000 bp)",
placeholder="Paste DNA (FASTA or raw)...",
lines=8,
value=EXAMPLE_SEQUENCE
)
gr.Markdown("**load an example:**")
with gr.Row():
ex_lacz = gr.Button("lacZ (mixed)", size="sm")
ex_16s = gr.Button("16S rRNA (conserved)", size="sm")
with gr.Row():
ex_rand = gr.Button("random DNA", size="sm")
ex_rep = gr.Button("AT repeat (low complexity)", size="sm")
mode_input = gr.Radio(
choices=["mean", "max", "per-window"],
value="mean", label="pooling",
info="how to collapse positions within each window"
)
layer_input = gr.Slider(0, 23, value=21, step=1, label="layer (0=shallow, 23=deep)")
stride_input = gr.Slider(50, 500, value=100, step=50, label="stride (bp)",
info="step between windows. lower = more windows, more compute")
btn = gr.Button("extract", variant="primary")
with gr.Column(scale=3, min_width=500):
output = gr.Markdown()
download = gr.File(label="download embedding (.npy)")
gr.Markdown(
"**Per-window plot below.** x-axis = window-start position along your input "
"sequence. Each point is one 1000-bp window. "
"*L2 norm* = how strongly the model's neurons fire on this window (bigger = "
"stronger activation). *Novelty* = how different this window's embedding is "
"from the average embedding across your sequence (1 β cos sim to mean); "
"peaks = regions that stand out from the rest of **this** sequence."
)
familiarity_plot = gr.Plot(label="per-window L2 norm & embedding-space novelty along your input sequence")
gr.Markdown(
"**Trajectory** (left): heatmap of (windows Γ ~100 subsampled embedding "
"dimensions). Sharp horizontal bands = sudden embedding change β boundary. "
"**Top varying dims** (right): the 8 dimensions that vary most across windows, "
"plotted vs window position."
)
with gr.Row():
trajectory_plot = gr.Plot(label="window Γ dimension heatmap (trajectory)")
dims_plot = gr.Plot(label="top-8 most variable dimensions vs window position")
gr.Markdown(
"**Pooled embedding heatmap** (below): the single 768-dim vector after "
"pooling across windows β only shown in `mean` / `max` mode. "
"Red = positive activation, blue = negative."
)
heatmap_plot = gr.Plot(label="pooled 768-dim embedding heatmap (24 Γ 32 grid)")
btn.click(
process,
inputs=[seq_input, mode_input, stride_input, layer_input],
outputs=[output, download, heatmap_plot, trajectory_plot, familiarity_plot, dims_plot],
api_name="embed"
)
ex_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=seq_input)
ex_16s.click(lambda: EXAMPLE_16S, outputs=seq_input)
ex_rand.click(lambda: EXAMPLE_RANDOM, outputs=seq_input)
ex_rep.click(lambda: EXAMPLE_REPEAT, outputs=seq_input)
with gr.Tab("MLM surprise"):
gr.Markdown(
"**What it does.** For each 1000-bp window, we randomly replace ~15% of bases "
"with a mask token, ask the model to predict what was there, and record "
"**βlog p(true base)** at those positions. Low = model is confident (pattern "
"matches training distribution). High = model is uncertain (unusual region).\n\n"
"**Unit**: nats. Uniform-random guessing over {A,C,G,T,N,PAD} = **ln 6 β 1.79 nats**. "
"A perfectly confident correct prediction = **0 nats**.\n\n"
"**Expected shape on the examples below**:\n"
"- *lacZ* β moderate (~0.6β1.3 nats), with visible structure between CDS and UTR.\n"
"- *16S rRNA* β lowest values, often flat across the whole region (highly conserved).\n"
"- *random DNA* β near the 1.79 baseline, no trend (by construction unpredictable).\n"
"- *AT repeat* β near 0 (trivial to predict once you see a few bases)."
)
with gr.Row():
with gr.Column(scale=1, min_width=260):
surp_seq = gr.Textbox(
label="sequence (β₯ 1000 bp)",
placeholder="Paste DNA (FASTA or raw)...",
lines=8,
value=EXAMPLE_SEQUENCE,
)
gr.Markdown("**load an example:**")
with gr.Row():
sx_lacz = gr.Button("lacZ (mixed)", size="sm")
sx_16s = gr.Button("16S rRNA (conserved)", size="sm")
with gr.Row():
sx_rand = gr.Button("random DNA", size="sm")
sx_rep = gr.Button("AT repeat (low complexity)", size="sm")
surp_stride = gr.Slider(50, 500, value=100, step=50, label="stride (bp)",
info="step between windows. lower = more windows")
surp_mask = gr.Slider(0.05, 0.5, value=0.15, step=0.05,
label="mask fraction",
info="fraction of positions masked in each window (0.15 matches BERT training)")
surp_btn = gr.Button("score", variant="primary")
with gr.Column(scale=3, min_width=500):
surp_summary = gr.Markdown()
surp_plot = gr.Plot(label="MLM surprise along the input sequence")
surp_btn.click(
process_surprise,
inputs=[surp_seq, surp_stride, surp_mask],
outputs=[surp_summary, surp_plot],
api_name="surprise",
)
sx_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=surp_seq)
sx_16s.click(lambda: EXAMPLE_16S, outputs=surp_seq)
sx_rand.click(lambda: EXAMPLE_RANDOM, outputs=surp_seq)
sx_rep.click(lambda: EXAMPLE_REPEAT, outputs=surp_seq)
with gr.Tab("Guide"):
gr.Markdown("""
### What this space actually does
The model is a BERT trained to predict masked bases in DNA. Two things come out of it:
| | **Extract (embeddings)** | **MLM surprise** |
|---|---|---|
| What | 768-dim hidden-state vector per window (layer L) | βlog p(true base) at masked positions |
| Y-axis | L2 norm / cosine distance to sequence mean | nats |
| Measures | how the *representation* changes along the sequence | how *predictable* the bases are |
| Depends on training data? | indirectly (via what the model learned to represent) | yes β bases typical of training data get low surprise |
| "Unusual" means | this window's embedding differs from the rest of *this* sequence | this base is hard to predict from context, relative to what the model saw in pretraining |
They look similar (both have spiky plots over position) but answer **different questions**. A
region can be *embedding-novel* (structurally unlike the rest of your sequence) without being
*MLM-surprising* (it could still be a predictable pattern the model knows). And vice versa.
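
For the embedding side, "novelty" is nothing more than cosine distance to the
sequence-mean embedding. In NumPy, mirroring what the Extract tab computes per window:

```python
emb = window_embeddings                      # (n_windows, 768)
mean_vec = emb.mean(axis=0)
cos = emb @ mean_vec / (
    np.linalg.norm(emb, axis=1) * np.linalg.norm(mean_vec) + 1e-10)
novelty = 1.0 - cos                          # 0 = typical; larger = stands out
```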
### Reading the per-window plots
- **x-axis** on every per-window plot = position along your input sequence, in bp.
  Each point = one 1000-bp window, plotted at the window's start (Extract tab) or center
  (MLM surprise tab). With `stride=100`, successive points are 100 bp apart and **overlap
  by 900 bp** – so plots are smoother than they look, and a single isolated peak is usually
  real, not noise.
- **y-axis units**:
  - Extract → L2 norm (unitless magnitude of the 768-dim vector) and cosine distance 1 - cos.
  - MLM surprise → nats (negative log-likelihood). 0 = perfectly confident, ln 6 ≈ 1.79 = uniform.

### How the examples should look
- **lacZ (~3 kb of E. coli)** – mixed content (regulatory + coding). MLM surprise will have
  peaks and troughs reflecting CDS vs UTR structure. Embedding novelty has some variation.
- **16S rRNA (~3 kb of rrnB)** – highly conserved, functional RNA. MLM surprise should be
  visibly *lower* than on lacZ, often flat. Embedding trajectory looks like one consistent
  region.
- **random DNA (3 kb, uniform ACGT)** – no pattern. MLM surprise should sit near the ln 6
  baseline with no trend. Embedding novelty will look noisy.
- **AT repeat (3 kb of ATATAT…)** – trivially predictable. MLM surprise should be very low
  (~0) across the whole sequence.

### Caveats
- MLM surprise depends on which token we use as [MASK]. This space uses token `0` (PAD/OOV).
  If the original pretraining used a different sentinel, scores will be uniformly *pessimistic*
  (near the 1.79 baseline on everything). If you see that on the conserved 16S example, the
  mask token is wrong and we need to switch to `5` (N/AMB).
- Only ~15% of each window is scored per run. Points in the per-base scatter are sampled, not
  exhaustive – but because windows overlap 10× at stride 100, every base usually has several
  observations.
- Both metrics are relative. Absolute values mean little; compare across regions of the same
  sequence, or across the example sequences above.
""")
with gr.Tab("API"):
gr.Markdown("""
### API
```python
from gradio_client import Client
import numpy as np
client = Client("genomenet/bert-embedding")
result = client.predict(
sequence="ATGC...", # min 1000 bp
mode="mean", # mean/max/per-window
stride=100,
layer=21, # 0-23
api_name="/embed"
)
summary, emb_path, *plots = result
embedding = np.load(emb_path)
```
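
With `mode="per-window"`, the saved `.npy` holds one row per 1000-bp window instead of a
single pooled vector:

```python
windows = np.load(emb_path)   # shape (n_windows, 768) in per-window mode
```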
**Per-window plots** (along sequence position):
- **L2 norm**: activation magnitude – high = strong, structured response.
- **Novelty** (1 - cosine similarity to mean embedding): how much the window differs
  from the rest of the sequence. Spikes = unusual regions relative to context.

Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.

### MLM surprise endpoint

```python
summary, plot = client.predict(
    sequence="ATGC...",
    stride=100,
    mask_fraction=0.15,
    api_name="/surprise",
)
```

Returns per-window mean `-log(p_true)` at masked positions (in nats).
The uniform-random baseline is `ln(6) ≈ 1.79 nats`.
""")
with gr.Tab("About"):
gr.Markdown("""
### Model
| | |
|---|---|
| architecture | BERT, 24 layers, 768 hidden, 12 heads |
| parameters | ~430M |
| input | 1000 bp sliding window |
| pretraining | metagenomic contigs + microbial genomes |
### Interpreting Statistics
The embedding statistics provide indirect measures of how the model "responds" to a sequence:
- **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences.
- **Entropy**: How spread out the activations are. Lower entropy suggests more confident/structured representation.
- **Sparsity**: Fraction of dimensions with near-zero activation.
- **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations.
**Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across
different sequence types may reveal what the model considers typical vs. unusual.
### Links
- Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
- CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
""")
if __name__ == "__main__":
    print("Loading model...")
    _ = get_base_model()
    print(f"Ready! {get_gpu_status()}")
    # theme is set on the gr.Blocks constructor above; launch() has no theme arg
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )