"""
BERT Metagenome Embeddings - HuggingFace Spaces App
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tempfile
import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from custom_layers import get_custom_objects
# Model config
MODEL_REPO = "genomenet/bert-metagenome"
MODEL_FILE = "bert_1k_3.h5"
WINDOW_SIZE = 1000
NUM_LAYERS = 24
EMBEDDING_DIM = 768
# Singleton model cache
_model = None
_embedding_models = {}
def get_base_model():
"""Load and cache the base model."""
global _model
if _model is None:
print("Downloading model...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
print(f"Loading model from {model_path}...")
_model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
print("Model loaded.")
        # Log the model's output layer names for debugging
print(f"Model outputs: {_model.output_names}")
return _model
def get_embedding_model(layer_idx=21):
"""Get embedding model for a specific layer."""
global _embedding_models
if layer_idx not in _embedding_models:
model = get_base_model()
layer_name = f"layer_transformer_block_{layer_idx}"
try:
_embedding_models[layer_idx] = tf.keras.Model(
inputs=model.input,
outputs=model.get_layer(layer_name).output
)
        except ValueError:
            # Requested layer name not found in this checkpoint; fall back to the
            # default layer 21 and note it in the logs rather than failing silently.
            print(f"Layer {layer_idx} not found; falling back to layer 21.")
            _embedding_models[layer_idx] = tf.keras.Model(
                inputs=model.input,
                outputs=model.get_layer("layer_transformer_block_21").output
            )
return _embedding_models[layer_idx]
def get_gpu_status():
gpus = tf.config.list_physical_devices('GPU')
return f"GPU: {gpus[0].name}" if gpus else "CPU only"
# Tokenization
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}
def tokenize(sequence):
sequence = sequence.upper().replace('U', 'T')
return np.array([TOKEN_MAP.get(c, 5) for c in sequence], dtype=np.int32)
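# Worked example: tokenize("ACGTU") -> array([1, 2, 3, 4, 4]) (U is mapped to T);
# any other character (ambiguity codes, stray symbols) maps to 5, the N token.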
def validate_sequence(sequence):
if not sequence or len(sequence.strip()) == 0:
return False, "Sequence is empty"
sequence = sequence.upper().replace('U', 'T')
valid_chars = set('ACGTNRYSWKMBDHV')
invalid = set(sequence) - valid_chars - set(' \n\r\t')
if invalid:
return False, f"Invalid characters: {invalid}"
clean = ''.join(c for c in sequence if c in valid_chars)
if len(clean) < WINDOW_SIZE:
return False, f"Sequence too short: {len(clean)} < {WINDOW_SIZE} bp"
return True, ""
def strip_fasta_header(text):
    """Drop FASTA header lines and all whitespace (spaces, tabs, CR/LF)."""
    lines = text.strip().splitlines()  # handles \n, \r\n and bare \r line endings
    return ''.join(l for l in lines if not l.startswith('>')).replace(' ', '').replace('\t', '')
def compute_embedding_stats(embedding):
"""Compute statistics that may indicate sequence 'familiarity'."""
emb = np.array(embedding)
# L2 norm - magnitude of response
l2_norm = np.linalg.norm(emb)
# Mean activation
mean_act = np.mean(emb)
# Std - spread of activations
std_act = np.std(emb)
# Sparsity - fraction of near-zero activations
sparsity = np.mean(np.abs(emb) < 0.1)
    # Activation entropy over 50 bins, using normalized bin probabilities
    # (density=True values are not probabilities, so they don't give a valid entropy)
    hist, _ = np.histogram(emb, bins=50)
    p = hist[hist > 0] / hist.sum()
    entropy = -np.sum(p * np.log(p))
# Kurtosis - peakedness (high = more concentrated activations)
kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3
return {
'l2_norm': float(l2_norm),
'mean': float(mean_act),
'std': float(std_act),
'sparsity': float(sparsity),
'entropy': float(entropy),
'kurtosis': float(kurtosis)
}
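# Usage sketch for compute_embedding_stats (hypothetical 768-dim input):
#   stats = compute_embedding_stats(np.random.randn(768))
#   -> dict of floats: l2_norm, mean, std, sparsity, entropy, kurtosis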
def embed_sequence(sequence, mode="mean", stride=100, layer=21):
"""Extract embeddings from sequence."""
model = get_embedding_model(layer)
seq_len = len(sequence)
embeddings = []
positions = []
for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
window = sequence[start:start + WINDOW_SIZE]
tokens = np.expand_dims(tokenize(window), axis=0)
emb = model.predict(tokens, verbose=0)
embeddings.append(emb[0])
positions.append(start)
embeddings = np.array(embeddings) # (n_windows, 1000, 768)
if mode == "mean":
window_emb = np.mean(embeddings, axis=1)
return np.mean(window_emb, axis=0), window_emb, positions
elif mode == "max":
window_emb = np.max(embeddings, axis=1)
return np.max(window_emb, axis=0), window_emb, positions
elif mode == "per-window":
window_emb = np.mean(embeddings, axis=1)
return window_emb, window_emb, positions
else:
window_emb = np.mean(embeddings, axis=1)
return np.mean(window_emb, axis=0), window_emb, positions
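# Usage sketch for embed_sequence (hypothetical `seq`, a DNA string >= 1000 bp):
#   pooled, per_window, starts = embed_sequence(seq, mode="mean", stride=200)
#   pooled.shape == (768,); per_window.shape == (n_windows, 768)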
# ln(vocab_size=6): surprise if the model predicted uniformly at random.
UNIFORM_SURPRISE = float(np.log(6))
MASK_TOKEN = 0 # PAD/OOV; used as the MLM mask slot
def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
"""Per-window and per-base MLM surprise.
For each sliding window, randomly mask ~mask_fraction of positions, run one
forward pass through the full model (which ends in a Dense(vocab_size=6)),
softmax the per-position logits, and take -log(p_true) at the masked
positions. Returns:
- per_window: list of (position, mean_surprise)
- per_base_pos, per_base_vals: flat arrays of (position, surprise) samples,
one entry per (window × masked_position). Overlapping windows give
multiple observations per base.
"""
model = get_base_model()
tokens = tokenize(sequence)
seq_len = len(tokens)
rng = np.random.default_rng(seed)
n_mask = max(1, int(WINDOW_SIZE * mask_fraction))
starts = list(range(0, seq_len - WINDOW_SIZE + 1, stride))
if not starts:
return [], np.array([]), np.array([])
# Build all windows + mask sets, run one batched forward pass.
batch = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
truth = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
mask_idxs = []
for i, start in enumerate(starts):
w = tokens[start:start + WINDOW_SIZE]
truth[i] = w
idx = rng.choice(WINDOW_SIZE, size=n_mask, replace=False)
mask_idxs.append(idx)
w_masked = w.copy()
w_masked[idx] = MASK_TOKEN
batch[i] = w_masked
logits = model.predict(batch, verbose=0, batch_size=8) # (n_win, 1000, 6)
logits -= logits.max(axis=-1, keepdims=True)
exp_l = np.exp(logits)
probs = exp_l / exp_l.sum(axis=-1, keepdims=True)
per_window = []
per_base_pos = []
per_base_vals = []
for i, start in enumerate(starts):
idx = mask_idxs[i]
p_true = probs[i, idx, truth[i, idx]]
surprises = -np.log(np.clip(p_true, 1e-10, None))
per_window.append((start + WINDOW_SIZE // 2, float(surprises.mean())))
per_base_pos.extend((start + idx).tolist())
per_base_vals.extend(surprises.tolist())
return per_window, np.array(per_base_pos), np.array(per_base_vals)
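# Usage sketch for compute_mlm_surprise (hypothetical `seq` >= 1000 bp):
#   per_window, pos, vals = compute_mlm_surprise(seq, stride=200)
#   per_window[0] -> (window_center_bp, mean_surprise_in_nats)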
def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
"""Two-panel Plotly figure: per-window surprise line + per-base scatter."""
from plotly.subplots import make_subplots
fig = make_subplots(
rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
row_heights=[0.55, 0.45],
subplot_titles=(
'per-WINDOW mean surprise — one point per 1000-bp window, plotted at its center',
'per-BASE surprise — one dot per masked base (~15% of bases in each window)',
),
)
wx = [p for p, _ in per_window]
wy = [s for _, s in per_window]
fig.add_trace(go.Scatter(
x=wx, y=wy, mode='lines+markers',
line=dict(color='#18181b', width=2), marker=dict(size=7),
hovertemplate='window center: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
name='window mean', showlegend=False,
), row=1, col=1)
fig.add_hline(
y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
annotation_text=f'uniform-random baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
annotation_position='top right', annotation_font=dict(size=10, color='#71717a'),
row=1, col=1,
)
fig.add_trace(go.Scatter(
x=per_base_pos, y=per_base_vals, mode='markers',
marker=dict(size=4, color=per_base_vals, colorscale='Reds',
cmin=0, cmax=UNIFORM_SURPRISE,
colorbar=dict(title=dict(text='surprise<br>(nats)', font=dict(size=10)),
thickness=10, len=0.4, y=0.2, tickfont=dict(size=9))),
hovertemplate='base position: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
name='per base', showlegend=False,
), row=2, col=1)
fig.add_hline(
y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
row=2, col=1,
)
fig.update_xaxes(title_text='position along input sequence (bp)', row=2, col=1,
range=[0, seq_len])
fig.update_xaxes(range=[0, seq_len], row=1, col=1)
fig.update_yaxes(title_text='surprise (nats)', row=1, col=1, rangemode='tozero')
fig.update_yaxes(title_text='surprise (nats)', row=2, col=1, rangemode='tozero')
fig.update_layout(height=560, margin=dict(l=60, r=20, t=70, b=60))
for ann in fig['layout']['annotations']:
if 'font' not in ann:
ann['font'] = dict(size=11)
return fig
def create_embedding_heatmap(embedding, title="Embedding"):
"""Create a heatmap of a single embedding vector."""
embedding = np.array(embedding)
n_dims = len(embedding)
cols = 32
rows = int(np.ceil(n_dims / cols))
padded = np.full(rows * cols, np.nan)
padded[:n_dims] = embedding
grid = padded.reshape(rows, cols)
finite = embedding[np.isfinite(embedding)]
vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
ax.set_xlabel('Dimension')
ax.set_ylabel('Row')
ax.set_title(f'{title} ({n_dims} dims)')
ax.set_xticks(np.arange(0, cols, 8))
plt.tight_layout()
return fig
def create_trajectory_plot(window_embeddings, positions):
"""Create interactive trajectory heatmap."""
emb = np.array(window_embeddings)
n_windows, n_dims = emb.shape
    # Subsample dimensions to ~100 columns so the heatmap stays readable
    step = max(1, n_dims // 100)
    emb_sub = emb[:, ::step]
    dim_idx = list(range(0, n_dims, step))  # true dimension index of each kept column
    vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
    fig = go.Figure(go.Heatmap(
        z=emb_sub,
        x=dim_idx,
y=[f"{p}" for p in positions],
colorscale='RdBu_r',
zmin=-vmax, zmax=vmax,
colorbar=dict(title='Act.'),
hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}<extra></extra>'
))
fig.update_layout(
xaxis=dict(title='Dimension' + (' (subsampled)' if step > 1 else '')),
yaxis=dict(title='Window start (bp)'),
height=max(350, n_windows * 15 + 100),
margin=dict(l=60, r=20, t=30, b=50)
)
return fig
def create_familiarity_plot(window_embeddings, positions):
"""Per-window L2 norm + novelty (cosine distance to sequence mean) along the sequence.
High L2 norm = strong response. High novelty = window looks different from the rest
of the sequence (the model's internal 'surprise' relative to the sequence average).
"""
from plotly.subplots import make_subplots
emb = np.array(window_embeddings)
n_windows = emb.shape[0]
l2 = np.linalg.norm(emb, axis=1)
mean_vec = emb.mean(axis=0)
mean_norm = np.linalg.norm(mean_vec) + 1e-10
cos_sim = (emb @ mean_vec) / (l2 * mean_norm + 1e-10)
novelty = 1.0 - cos_sim
fig = make_subplots(
rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.14,
subplot_titles=(
'per-WINDOW L2 norm of embedding (activation magnitude at layer L)',
'per-WINDOW embedding novelty (1 − cos similarity to sequence-mean embedding)',
),
)
fig.add_trace(go.Scatter(
x=positions, y=l2, mode='lines+markers',
line=dict(color='#3b82f6', width=2), marker=dict(size=6),
hovertemplate='window start: %{x} bp<br>L2: %{y:.2f}<extra></extra>', showlegend=False
), row=1, col=1)
fig.add_trace(go.Scatter(
x=positions, y=novelty, mode='lines+markers',
line=dict(color='#ef4444', width=2), marker=dict(size=6),
hovertemplate='window start: %{x} bp<br>novelty: %{y:.3f}<extra></extra>', showlegend=False
), row=2, col=1)
fig.update_xaxes(title_text='window start (bp along input sequence)', row=2, col=1)
fig.update_yaxes(title_text='L2 norm', row=1, col=1)
fig.update_yaxes(title_text='1 − cos sim', row=2, col=1)
fig.update_layout(
height=380 if n_windows > 1 else 260,
margin=dict(l=60, r=20, t=50, b=50),
)
for ann in fig['layout']['annotations']:
ann['font'] = dict(size=11)
return fig
def create_dimension_plot(window_embeddings, positions, top_k=8):
"""Show top varying dimensions."""
emb = np.array(window_embeddings)
variances = np.var(emb, axis=0)
top_dims = np.argsort(variances)[-top_k:][::-1]
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
'#ff7f00', '#a65628', '#f781bf', '#999999']
fig = go.Figure()
for i, dim in enumerate(top_dims):
fig.add_trace(go.Scatter(
x=positions, y=emb[:, dim],
mode='lines', name=f'd{dim}',
line=dict(color=colors[i % len(colors)], width=1.5)
))
fig.update_layout(
xaxis=dict(title='Position (bp)'),
yaxis=dict(title='Activation'),
height=300,
legend=dict(orientation='h', y=1.1),
margin=dict(l=50, r=20, t=40, b=50)
)
return fig
# Example sequence: ~3 kb slice of E. coli K-12 MG1655 around the lac operon
# (NC_000913.3, positions 365529-368600). Covers the lac repressor binding region,
# the lacZ gene, and flanking regulatory sequence, so per-window plots show
# real biological structure transitions.
EXAMPLE_SEQUENCE = (
"AACTGTTACCCGTAGGTAGTCACGCAACTCGCCGCACATCTGAACTTCAGCCTCCAGTACAGCGCGGCTGAA"
"ATCATCATTAAAGCGAGTGGCAACATGGAAATCGCTGATTTGTGTAGTCGGTTTATGCAGCAACGAGACGTC"
"ACGGAAAATGCCGCTCATCCGCCACATATCCTGATCTTCCAGATAACTGCCGTCACTCCAGCGCAGCACCAT"
"CACCGCGAGGCGGTTTTCTCCGGCGCGTAAAAATGCGCTCAGGTCAAATTCAGACGGCAAACGACTGTCCTG"
"GCCGTAACCGACCCAGCGCCCGTTGCACCACAGATGAAACGCCGAGTTAACGCCATCAAAAATAATTCGCGT"
"CTGGCCTTCCTGTAGCCAGCTTTCATCAACATTAAATGTGAGCGAGTAACAACCCGTCGGATTCTCCGTGGG"
"AACAAACGGCGGATTGACCGTAATGGGATAGGTCACGTTGGTGTAGATGGGCGCATCGTAACCGTGCATCTG"
"CCAGTTTGAGGGGACGACGACAGTATCGGCCTCAGGAAGATCGCACTCCAGCCAGCTTTCCGGCACCGCTTC"
"TGGTGCCGGAAACCAGGCAAAGCGCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCG"
"GGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGG"
"GTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATCCGTAATCATGGTCATAGCTGTTTCCTGTGT"
"GAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCT"
"AATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCC"
"AGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCCAGGGTGGTTTTTCTT"
"TTCACCAGTGAGACGGGCAACAGCTGATTGCCCTTCACCGCCTGGCCCTGAGAGAGTTGCAGCAAGCGGTCC"
"ACGCTGGTTTGCCCCAGCAGGCGAAAATCCTGTTTGATGGTGGTTAACGGCGGGATATAACATGAGCTGTCT"
"TCGGTATCGTCGTATCCCACTACCGAGATATCCGCACCAACGCGCAGCCCGGACTCGGTAATGGCGCGCATT"
"GCGCCCAGCGCCATCTGATCGTTGGCAACCAGCATCGCAGTGGGAACGATGCCCTCATTCAGCATTTGCATG"
"GTTTGTTGAAAACCGGACATGGCACTCCAGTCGCCTTCCCGTTCCGCTATCGGCTGAATTTGATTGCGAGTG"
"AGATATTTATGCCAGCCAGCCAGACGCAGACGCGCCGAGACAGAACTTAATGGGCCCGCTAACAGCGCGATT"
"TGCTGGTGACCCAATGCGACCAGATGCTCCACGCCCAGTCGCGTACCGTCTTCATGGGAGAAAATAATACTG"
"TTGATGGGTGTCTGGTCAGAGACATCAAGAAATAACGCCGGAACATTAGTGCAGGCAGCTTCCACAGCAATG"
"GCATCCTGGTCATCCAGCGGATAGTTAATGATCAGCCCACTGACGCGTTGCGCGAGAAGATTGTGCACCGCC"
"GCTTTACAGGCTTCGACGCCGCTTCGTTCTACCATCGACACCACCACGCTGGCACCCAGTTGATCGGCGCGA"
"GATTTAATCGCCGCGACAATTTGCGACGGCGCGTGCAGGGCCAGACTGGAGGTGGCAACGCCAATCAGCAAC"
"GACTGTTTGCCCGCCAGTTGTTGTGCCACGCGGTTGGGAATGTAATTCAGCTCCGCCATCGCCGCTTCCACT"
"TTTTCCCGCGTTTTCGCAGAAACGTGGCTGGCCTGGTTCACCACGCGGGAAACGGTCTGATAAGAGACACCG"
"GCATACTCTGCGACATCGTATAACGTTACTGGTTTCACATTCACCACCCTGAATTGACTCTCTTCCGGGCGC"
"TATCATGCCATACCGCGAAAGGTTTTGCGCCATTCGATGGTGTCAACGTAAATGCATGCCGCTTCGCCTTCC"
"GGCCACCAGAATAGCCTGCGATTCAACCCCTTCTTCGATCTGTTTTGCTACCCGTTGTAGCGCCGGAAGATG"
"CTTTTCCGCTGCCTGTTCAATGGTCATTGCGCTCGCCATATACACCAGATTCAGACAGCCAATCACCCGTTG"
"TTCACTGCGCAGCGGTACGGCGATAGAGGCGATCTTCTCCTCCTGATCCCAGCCGCGGTAGTTCTGTCCGTA"
"ACCCTCTTTGCGCGCGCGCGCCAGAATGGCTTCCAGCTTTAACGGTTCCCGTGCCAGTTGATAGTCATCACC"
"GGGGCGGGAGGCTAACATTTCGATTAATTCCTTGCGGTCTTGTTCCGGGCAAAAGGCCAGCCAGGTCAGGCC"
"CGAGGCGGTTTTCAGAAGCGGCAAACGTCGCCCGACCATTGCCCGGTGAAAGGATAAGCGGCTGAAACGGTG"
"AGTGGTTTCGCGTACCACCATTGCATCAACATCCAGCGTGGACACATCTGTCGGCCATACCACTTCGCGCAA"
"CAGATCGCCCAGCAGTGGGGCCGCCAGTGCAGAAATCCACTGTTCGTCACGAAATCCTTCGCTTAATTGCCG"
"CACTTTGATGGTCAGTCGAAAACTATCATCGGAGGGGCTACGGCGGACATATCCCTCTTCCTGCAGCGTCTC"
"CAGCAGTCGCCGCACAGTGGTGCGATGCAGGCCGCTGAGTTCCGCCAGCAGCCCGACGCTGGCACCGCCATC"
"AAGTTTATTTAACATATTTAATAACATTAGACCGCGGGTTAAGCCGCGCACGGTTTTGTATTCCGTCTGCTC"
"ATTGTTCTGCATATTAATTGACATTTCTATAGTTAAAACAACGTGGTGCACCTGGTGCACATTCGGGCATGT"
"TTTGATTGTAGCCGAAAACACCCTTCCTATACTGAGCGCACAATAAAAAATCATTTACATGTTTTTAACAAA"
"ATAAGTTGCGCTGTACTGTGCGCGCAACGACATTTTGTCCGAGTCGTG"
)
# Highly conserved example: ~3 kb slice across E. coli K-12 MG1655 rrnB operon
# (NC_000913.3:4035531-4038602), which contains the 16S rRNA gene (rrsB) and
# flanking regulatory regions. rRNA genes are among the most conserved in
# bacteria, so MLM surprise should be visibly lower than on lacZ, and the
# per-window embedding plot should be flatter (one coherent functional region).
EXAMPLE_16S = (
"AAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACACATGCAAGTCGAACGGTAA"
"CAGGAAGAAGCTTGCTTCTTTGCTGACGAGTGGCGGACGGGTGAGTAATGTCTGGGAAACTGCCTGATGGAG"
"GGGGATAACTACTGGAAACGGTAGCTAATACCGCATAACGTCGCAAGACCAAAGAGGGGGACCTTCGGGCCT"
"CTTGCCATCGGATGTGCCCAGATGGGATTAGCTAGTAGGTGGGGTAACGGCTCACCTAGGCGACGATCCCTA"
"GCTGGTCTGAGAGGATGACCAGCCACACTGGAACTGAGACACGGTCCAGACTCCTACGGGAGGCAGCAGTGG"
"GGAATATTGCACAATGGGCGCAAGCCTGATGCAGCCATGCCGCGTGTATGAAGAAGGCCTTCGGGTTGTAAA"
"GTACTTTCAGCGGGGAGGAAGGGAGTAAAGTTAATACCTTTGCTCATTGACGTTACCCGCAGAAGAAGCACC"
"GGCTAACTCCGTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAATTACTGGGCGTAAAGC"
"GCACGCAGGCGGTTTGTTAAGTCAGATGTGAAATCCCCGGGCTCAACCTGGGAACTGCATCTGATACTGGCA"
"AGCTTGAGTCTCGTAGAGGGGGGTAGAATTCCAGGTGTAGCGGTGAAATGCGTAGAGATCTGGAGGAATACC"
"GGTGGCGAAGGCGGCCCCCTGGACGAAGACTGACGCTCAGGTGCGAAAGCGTGGGGAGCAAACAGGATTAGA"
"TACCCTGGTAGTCCACGCCGTAAACGATGTCGACTTGGAGGTTGTGCCCTTGAGGCGTGGCTTCCGGAGCTA"
"ACGCGTTAAGTCGACCGCCTGGGGAGTACGGCCGCAAGGTTAAAACTCAAATGAATTGACGGGGGCCCGCAC"
"AAGCGGTGGAGCATGTGGTTTAATTCGATGCAACGCGAAGAACCTTACCTGGTCTTGACATCCACGGAAGTT"
"TTCAGAGATGAGAATGTGCCTTCGGGAACCGTGAGACAGGTGCTGCATGGCTGTCGTCAGCTCGTGTTGTGA"
"AATGTTGGGTTAAGTCCCGCAACGAGCGCAACCCTTATCCTTTGTTGCCAGCGGTCCGGCCGGGAACTCAAA"
"GGAGACTGCCAGTGATAAACTGGAGGAAGGTGGGGATGACGTCAAGTCATCATGGCCCTTACGACCAGGGCT"
"ACACACGTGCTACAATGGCGCATACAAAGAGAAGCGACCTCGCGAGAGCAAGCGGACCTCATAAAGTGCGTC"
"GTAGTCCGGATTGGAGTCTGCAACTCGACTCCATGAAGTCGGAATCGCTAGTAATCGTGGATCAGAATGCCA"
"CGGTGAATACGTTCCCGGGCCTTGTACACACCGCCCGTCACACCATGGGAGTGGGTTGCAAAAGAAGTAGGT"
"AGCTTAACCTTCGGGAGGGCGCTTACCACTTTGTGATTCATGACTGGGGTGAAGTCGTAACAAGGTAACCGT"
"AGGGGAACCTGCGGTTGGATCACCTCCTTACCTCAGAACAAGAAGTATACCAAACTCAACGGAGTTACCACC"
"CGGTGGCAATTGCACAATTGATCATCAGAGAGAACCAATGCTGAACATCAAGAACAATGAAGGCCAGCAAGG"
"TAATCCGAGCAGAGCTAAAGAGAACGGATACCCATGACCACGGAAGTGGTCAAGAATATAGGCATCCAAAGA"
"CGATCAGATACCGTCGTAGTTCCGACCATAAACGATGAAGAACACGTCCATATAGCCAAGCTCCGTAGGACA"
"AGAATAATGAGACAAAACACAAAGCCAACAATGAGCCCTAAGTGATGTCCGGGGAAACCAGAAAGACCCGTA"
"GCCTGAAAGATTGCCGGCCACTTGGAACGCTGGATTGAGCACCCTGTAGAACATTTGTTTGAACAGGTGCGG"
"ACCGAATAAAGCCACATGATGCAACTATGAATCTGAACTTGCAATGCTGAACGAATCGCGATAAACCTAAGG"
"CAGAAGCGTACCCGGGAACATCAATAGACTGCGATGTGATAACGTACCCAAACTTATCCCAGGGCCCGTAAC"
"TAAACTGCCCCTTTGCGCTTCGAGTAAAGGCATCAAATAGATATAGACTCATAATGCCACAGTCCAATTACA"
"TGCCCGGAAGTTATTAATACTGCGAACGTTATACATACGAAGCCGTAAGGTAATTTGATAAGCGTAACCGAT"
"AGCCCCGACAGCGAACTAGCAACCTTGGAGTATATGAACCCAAATATCTGTGAGGCCTGGAACGTCCGAGAT"
"GAGAGTGCCACATACTCAAGACTCAAAGTCACCCGAAGGGAATTTGCATATGAGCTCGTCTGGCCAGGAGTT"
"TTAAGAGGGGCGCAGATATCACCTAATACGATAGCTAGCCGAATGCACTACGCCAACCATCTAACGGACAGA"
"GTAATGAACCACAAGCTCCGAAATGATGCTGAGAGACGCATGGCCAGTTCTCATCAGCCGTCGTGGTCAATC"
"GGTGCTTGGCCATCACCATGGGGGCCCGCATCTGCCATCGACAGCGCTTTCATCGTAAACCGTCTTATGGAA"
"AGACATTACAGCCAGTGTAAAATCCCGCACACTATTAGCCATCAAATCATATAAGGCATACGGTCAGTCAGT"
"ATTCCGAAAGAACACCACCAGTGATAGTACCAAGAGCACGTATGAATACGATGCCGACCATAGCGGACAAAT"
"CTCCCAATACGAGAGTAAAATAAGCAAATAATAGATATCCATGCATGGAGTCACCACAATAGAGCGCTACGT"
"CGTCGTGAAGAGGGAAACAACCCAGACCGCCAGCTAAGGTCCCAAAGTCATGGTTAAGTGGGAAACGATGTG"
"GGAAGGCCCAGACAGCCAGGATGTTGGCTTAGAAGCAGCCATCATTTA"
)
# Random DNA: uniform ACGT, 3 kb. Should be ~uninformative — model surprise
# should be close to the uniform baseline (ln 6 ≈ 1.79 nats) because there
# is no context-dependent pattern to learn.
_rng = np.random.default_rng(42)
EXAMPLE_RANDOM = "".join(_rng.choice(list("ACGT"), size=3000))
# Low-complexity repeat: 3 kb of AT dinucleotide repeats. Should give very
# LOW surprise — the model trivially predicts the next base from context.
EXAMPLE_REPEAT = "AT" * 1500
def process(sequence: str, mode: str, stride: int, layer: int):
"""Main processing function."""
sequence = strip_fasta_header(sequence.strip())
is_valid, error = validate_sequence(sequence)
if not is_valid:
return f"**Error**: {error}", None, None, None, None, None
embedding, window_embeddings, positions = embed_sequence(
sequence, mode=mode, stride=stride, layer=layer
)
    # Save embedding to a unique temp file (a fixed filename would be clobbered
    # by concurrent users of the Space)
    fd, path = tempfile.mkstemp(suffix=".npy")
    os.close(fd)
    np.save(path, embedding)
# Compute stats
if mode == "per-window":
# For per-window, compute stats on mean embedding
mean_emb = np.mean(embedding, axis=0)
stats = compute_embedding_stats(mean_emb)
else:
stats = compute_embedding_stats(embedding)
# Create summary
if mode == "per-window":
summary = f"""### Results
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| windows | {embedding.shape[0]} |
| shape | {embedding.shape} |
**Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}
"""
else:
summary = f"""### Results
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| mode | {mode} |
| dim | {len(embedding)} |
**Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%}
"""
# Create visualizations
heatmap_fig = None
if mode != "per-window":
heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}")
multi_window = len(window_embeddings) > 1
trajectory_fig = create_trajectory_plot(window_embeddings, positions) if multi_window else None
familiarity_fig = create_familiarity_plot(window_embeddings, positions) if multi_window else None
dims_fig = create_dimension_plot(window_embeddings, positions) if multi_window else None
return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
def process_surprise(sequence: str, stride: int, mask_fraction: float):
"""Compute MLM surprise across the sequence."""
sequence = strip_fasta_header(sequence.strip())
is_valid, error = validate_sequence(sequence)
if not is_valid:
return f"**Error**: {error}", None
per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
sequence, stride=stride, mask_fraction=mask_fraction
)
if not per_window:
return "**Error**: sequence too short for one window", None
fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))
w_vals = np.array([s for _, s in per_window])
lo_pos, lo_val = per_window[int(np.argmin(w_vals))]
hi_pos, hi_val = per_window[int(np.argmax(w_vals))]
summary = f"""### MLM surprise
| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {w_vals.mean():.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {lo_val:.3f} nats @ ~{lo_pos:,} bp |
| most surprising window | {hi_val:.3f} nats @ ~{hi_pos:,} bp |
Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
return summary, fig
# Build interface
with gr.Blocks(
    title="BERT Metagenome Embeddings",
    # theme belongs on gr.Blocks, not launch(); launch() has no theme kwarg
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.zinc,
        neutral_hue=gr.themes.colors.zinc,
    ),
    css=".gradio-container { max-width: 100% !important; }"
) as demo:
gr.Markdown(
"# bert-embedding\n"
"BERT (24 layers, 430M params) pretrained with masked-language-modeling "
"on metagenomic contigs. Input: DNA sequence ≥ 1000 bp. The model slides a "
"1000 bp window over your sequence and produces two kinds of output — "
"**embeddings** (768-dim hidden-state vector per window, under *Extract*) "
"and **MLM surprise** (model's confidence in reconstructing masked bases, "
"under *MLM surprise*). See the *Guide* tab for how to read the plots."
)
with gr.Tab("Extract"):
gr.Markdown(
"Extract 768-dim hidden-state embeddings from a chosen transformer layer. "
"Produces **per-window** vectors (one per 1000-bp window) that can be pooled, "
"visualised as heatmaps, and compared via cosine similarity. "
"The per-window plot below shows how the embedding *changes along the sequence* — "
"it does **not** say anything about whether bases are predictable "
"(that's the MLM surprise tab)."
)
with gr.Row():
with gr.Column(scale=1, min_width=260):
seq_input = gr.Textbox(
label="sequence (≥ 1000 bp)",
placeholder="Paste DNA (FASTA or raw)...",
lines=8,
value=EXAMPLE_SEQUENCE
)
gr.Markdown("**load an example:**")
with gr.Row():
ex_lacz = gr.Button("lacZ (mixed)", size="sm")
ex_16s = gr.Button("16S rRNA (conserved)", size="sm")
with gr.Row():
ex_rand = gr.Button("random DNA", size="sm")
ex_rep = gr.Button("AT repeat (low complexity)", size="sm")
mode_input = gr.Radio(
choices=["mean", "max", "per-window"],
value="mean", label="pooling",
info="how to collapse positions within each window"
)
layer_input = gr.Slider(0, 23, value=21, step=1, label="layer (0=shallow, 23=deep)")
stride_input = gr.Slider(50, 500, value=100, step=50, label="stride (bp)",
info="step between windows. lower = more windows, more compute")
btn = gr.Button("extract", variant="primary")
with gr.Column(scale=3, min_width=500):
output = gr.Markdown()
download = gr.File(label="download embedding (.npy)")
gr.Markdown(
"**Per-window plot below.** x-axis = window-start position along your input "
"sequence. Each point is one 1000-bp window. "
"*L2 norm* = how strongly the model's neurons fire on this window (bigger = "
"stronger activation). *Novelty* = how different this window's embedding is "
"from the average embedding across your sequence (1 − cos sim to mean); "
"peaks = regions that stand out from the rest of **this** sequence."
)
familiarity_plot = gr.Plot(label="per-window L2 norm & embedding-space novelty along your input sequence")
gr.Markdown(
"**Trajectory** (left): heatmap of (windows × ~100 subsampled embedding "
"dimensions). Sharp horizontal bands = sudden embedding change → boundary. "
"**Top varying dims** (right): the 8 dimensions that vary most across windows, "
"plotted vs window position."
)
with gr.Row():
trajectory_plot = gr.Plot(label="window × dimension heatmap (trajectory)")
dims_plot = gr.Plot(label="top-8 most variable dimensions vs window position")
gr.Markdown(
"**Pooled embedding heatmap** (below): the single 768-dim vector after "
"pooling across windows — only shown in `mean` / `max` mode. "
"Red = positive activation, blue = negative."
)
heatmap_plot = gr.Plot(label="pooled 768-dim embedding heatmap (24 × 32 grid)")
btn.click(
process,
inputs=[seq_input, mode_input, stride_input, layer_input],
outputs=[output, download, heatmap_plot, trajectory_plot, familiarity_plot, dims_plot],
api_name="embed"
)
ex_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=seq_input)
ex_16s.click(lambda: EXAMPLE_16S, outputs=seq_input)
ex_rand.click(lambda: EXAMPLE_RANDOM, outputs=seq_input)
ex_rep.click(lambda: EXAMPLE_REPEAT, outputs=seq_input)
with gr.Tab("MLM surprise"):
gr.Markdown(
"**What it does.** For each 1000-bp window, we randomly replace ~15% of bases "
"with a mask token, ask the model to predict what was there, and record "
"**−log p(true base)** at those positions. Low = model is confident (pattern "
"matches training distribution). High = model is uncertain (unusual region).\n\n"
"**Unit**: nats. Uniform-random guessing over {A,C,G,T,N,PAD} = **ln 6 ≈ 1.79 nats**. "
"A perfectly confident correct prediction = **0 nats**.\n\n"
"**Expected shape on the examples below**:\n"
"- *lacZ* → moderate (~0.6–1.3 nats), with visible structure between CDS and UTR.\n"
"- *16S rRNA* → lowest values, often flat across the whole region (highly conserved).\n"
"- *random DNA* → near the 1.79 baseline, no trend (by construction unpredictable).\n"
"- *AT repeat* → near 0 (trivial to predict once you see a few bases)."
)
with gr.Row():
with gr.Column(scale=1, min_width=260):
surp_seq = gr.Textbox(
label="sequence (≥ 1000 bp)",
placeholder="Paste DNA (FASTA or raw)...",
lines=8,
value=EXAMPLE_SEQUENCE,
)
gr.Markdown("**load an example:**")
with gr.Row():
sx_lacz = gr.Button("lacZ (mixed)", size="sm")
sx_16s = gr.Button("16S rRNA (conserved)", size="sm")
with gr.Row():
sx_rand = gr.Button("random DNA", size="sm")
sx_rep = gr.Button("AT repeat (low complexity)", size="sm")
surp_stride = gr.Slider(50, 500, value=100, step=50, label="stride (bp)",
info="step between windows. lower = more windows")
surp_mask = gr.Slider(0.05, 0.5, value=0.15, step=0.05,
label="mask fraction",
info="fraction of positions masked in each window (0.15 matches BERT training)")
surp_btn = gr.Button("score", variant="primary")
with gr.Column(scale=3, min_width=500):
surp_summary = gr.Markdown()
surp_plot = gr.Plot(label="MLM surprise along the input sequence")
surp_btn.click(
process_surprise,
inputs=[surp_seq, surp_stride, surp_mask],
outputs=[surp_summary, surp_plot],
api_name="surprise",
)
sx_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=surp_seq)
sx_16s.click(lambda: EXAMPLE_16S, outputs=surp_seq)
sx_rand.click(lambda: EXAMPLE_RANDOM, outputs=surp_seq)
sx_rep.click(lambda: EXAMPLE_REPEAT, outputs=surp_seq)
with gr.Tab("Guide"):
gr.Markdown("""
### What this space actually does
The model is a BERT trained to predict masked bases in DNA. Two things come out of it:
| | **Extract (embeddings)** | **MLM surprise** |
|---|---|---|
| What | 768-dim hidden-state vector per window (layer L) | −log p(true base) at masked positions |
| Y-axis | L2 norm / cosine distance to sequence mean | nats |
| Measures | how the *representation* changes along the sequence | how *predictable* the bases are |
| Depends on training data? | indirectly (via what the model learned to represent) | yes — bases typical of training data get low surprise |
| "Unusual" means | this window's embedding differs from the rest of *this* sequence | this base is hard to predict from context, relative to what the model saw in pretraining |
They look similar (both have spiky plots over position) but answer **different questions**. A
region can be *embedding-novel* (structurally unlike the rest of your sequence) without being
*MLM-surprising* (it could still be a predictable pattern the model knows). And vice versa.
### Reading the per-window plots
- **x-axis** on every per-window plot = position along your input sequence, in bp.
Each point = one 1000-bp window, plotted at the window's start (Extract tab) or center
(MLM surprise tab). With `stride=100`, successive points are 100 bp apart and **overlap
by 900 bp** — so plots are smoother than they look, and a single isolated peak is usually
real, not noise.
- **y-axis units**:
  - Extract → L2 norm (unitless magnitude of the 768-dim vector) and cosine-distance 1−cos (see the sketch below).
  - MLM surprise → nats (natural-log likelihood). 0 = perfectly confident, ln 6 ≈ 1.79 = uniform.
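A minimal numpy sketch of the Extract-tab novelty metric, mirroring what the app
computes; `win_emb` stands in for the `(n_windows, 768)` per-window array:

```python
import numpy as np

l2 = np.linalg.norm(win_emb, axis=1)             # per-window activation magnitude
mean_vec = win_emb.mean(axis=0)                  # sequence-mean embedding
cos_sim = (win_emb @ mean_vec) / (l2 * np.linalg.norm(mean_vec) + 1e-10)
novelty = 1.0 - cos_sim                          # peaks = windows unlike the rest
```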
### How the examples should look
- **lacZ (~3 kb of E. coli)** — mixed content (regulatory + coding). MLM surprise will have
peaks and troughs reflecting CDS vs UTR structure. Embedding novelty has some variation.
- **16S rRNA (~3 kb of rrnB)** — highly conserved, functional RNA. MLM surprise should be
visibly *lower* than on lacZ, often flat. Embedding trajectory looks like one consistent
region.
- **random DNA (3 kb, uniform ACGT)** — no pattern. MLM surprise should sit near the ln 6
baseline with no trend. Embedding novelty will look noisy.
- **AT repeat (3 kb of ATATAT…)** — trivially predictable. MLM surprise should be very low
(~0) across the whole sequence.
### Caveats
- MLM surprise depends on which token we use as [MASK]. This space uses token `0` (PAD/OOV).
If the original pretraining used a different sentinel, scores will be uniformly *pessimistic*
(near the 1.79 baseline on everything). If you see that on the conserved 16S example, the
mask token is wrong and we need to switch to `5` (N/AMB).
- Only 15% of each window is scored per run. Points in the per-base scatter are sampled, not
exhaustive — but because windows overlap 10× at stride 100, every base usually has several
observations.
- Both metrics are relative. Absolute values mean little; compare across regions of the same
sequence, or across the example sequences above.
""")
with gr.Tab("API"):
gr.Markdown("""
### API
```python
from gradio_client import Client
import numpy as np
client = Client("genomenet/bert-embedding")
result = client.predict(
sequence="ATGC...", # min 1000 bp
mode="mean", # mean/max/per-window
stride=100,
layer=21, # 0-23
api_name="/embed"
)
summary, emb_path, *plots = result
embedding = np.load(emb_path)
```
**Per-window plots** (along sequence position):
- **L2 norm**: activation magnitude — high = strong, structured response.
- **Novelty** (1 − cosine similarity to mean embedding): how much the window differs
from the rest of the sequence. Spikes = unusual regions relative to context.
Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.
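A minimal client-side sketch comparing two sequences by cosine similarity of their
pooled embeddings (`seq_a` / `seq_b` are hypothetical DNA strings of ≥ 1000 bp each):

```python
def pooled_embedding(seq):
    # /embed with mode="mean" returns one 768-dim vector for the whole sequence
    summary, emb_path, *plots = client.predict(
        sequence=seq, mode="mean", stride=100, layer=21, api_name="/embed"
    )
    return np.load(emb_path)

a, b = pooled_embedding(seq_a), pooled_embedding(seq_b)
cos = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {cos:.3f}")  # closer to 1 = more similar representations
```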
### MLM surprise endpoint
```python
summary, plot = client.predict(
sequence="ATGC...",
stride=100,
mask_fraction=0.15,
api_name="/surprise",
)
```
Returns per-window mean `-log(p_true)` at masked positions (in nats).
Uniform-random baseline is `ln(6) ≈ 1.79 nats`.
""")
with gr.Tab("About"):
gr.Markdown("""
### Model
| | |
|---|---|
| architecture | BERT, 24 layers, 768 hidden, 12 heads |
| parameters | ~430M |
| input | 1000 bp sliding window |
| pretraining | metagenomic contigs + microbial genomes |
### Interpreting Statistics
The embedding statistics provide indirect measures of how the model "responds" to a sequence:
- **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences.
- **Entropy**: How spread out the activations are. Lower entropy suggests a more confident, structured representation.
- **Sparsity**: Fraction of dimensions with near-zero activation.
- **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations.
**Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across
different sequence types may reveal what the model considers typical vs. unusual.
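For a rough baseline (plain statistics, not model-specific): an i.i.d. standard-normal
768-dim vector has L2 ≈ √768 ≈ 27.7, sparsity ≈ 8% (P(|x| < 0.1)), and excess kurtosis ≈ 0.
How far an embedding sits from such reference points is more informative than any single
absolute value.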
### Links
- Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
- CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
""")
if __name__ == "__main__":
print("Loading model...")
_ = get_base_model()
print(f"Ready! {get_gpu_status()}")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )