""" BERT Metagenome Embeddings - HuggingFace Spaces App """ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tempfile import gradio as gr import numpy as np import tensorflow as tf from huggingface_hub import hf_hub_download import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import plotly.graph_objects as go from custom_layers import get_custom_objects # Model config MODEL_REPO = "genomenet/bert-metagenome" MODEL_FILE = "bert_1k_3.h5" WINDOW_SIZE = 1000 NUM_LAYERS = 24 EMBEDDING_DIM = 768 # Singleton model cache _model = None _embedding_models = {} def get_base_model(): """Load and cache the base model.""" global _model if _model is None: print("Downloading model...") model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE) print(f"Loading model from {model_path}...") _model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False) print("Model loaded.") # Print model summary for debugging print(f"Model outputs: {_model.output_names}") return _model def get_embedding_model(layer_idx=21): """Get embedding model for a specific layer.""" global _embedding_models if layer_idx not in _embedding_models: model = get_base_model() layer_name = f"layer_transformer_block_{layer_idx}" try: _embedding_models[layer_idx] = tf.keras.Model( inputs=model.input, outputs=model.get_layer(layer_name).output ) except ValueError: _embedding_models[layer_idx] = tf.keras.Model( inputs=model.input, outputs=model.get_layer("layer_transformer_block_21").output ) return _embedding_models[layer_idx] def get_gpu_status(): gpus = tf.config.list_physical_devices('GPU') return f"GPU: {gpus[0].name}" if gpus else "CPU only" # Tokenization TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5} def tokenize(sequence): sequence = sequence.upper().replace('U', 'T') return np.array([TOKEN_MAP.get(c, 5) for c in sequence], dtype=np.int32) def validate_sequence(sequence): if not sequence or len(sequence.strip()) == 0: return False, "Sequence is empty" sequence = sequence.upper().replace('U', 'T') valid_chars = set('ACGTNRYSWKMBDHV') invalid = set(sequence) - valid_chars - set(' \n\r\t') if invalid: return False, f"Invalid characters: {invalid}" clean = ''.join(c for c in sequence if c in valid_chars) if len(clean) < WINDOW_SIZE: return False, f"Sequence too short: {len(clean)} < {WINDOW_SIZE} bp" return True, "" def strip_fasta_header(text): lines = text.strip().split('\n') return ''.join(l for l in lines if not l.startswith('>')).replace(' ', '').replace('\t', '') def compute_embedding_stats(embedding): """Compute statistics that may indicate sequence 'familiarity'.""" emb = np.array(embedding) # L2 norm - magnitude of response l2_norm = np.linalg.norm(emb) # Mean activation mean_act = np.mean(emb) # Std - spread of activations std_act = np.std(emb) # Sparsity - fraction of near-zero activations sparsity = np.mean(np.abs(emb) < 0.1) # Activation entropy (discretized) hist, _ = np.histogram(emb, bins=50, density=True) hist = hist[hist > 0] entropy = -np.sum(hist * np.log(hist + 1e-10)) # Kurtosis - peakedness (high = more concentrated activations) kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3 return { 'l2_norm': float(l2_norm), 'mean': float(mean_act), 'std': float(std_act), 'sparsity': float(sparsity), 'entropy': float(entropy), 'kurtosis': float(kurtosis) } def embed_sequence(sequence, mode="mean", stride=100, layer=21): """Extract embeddings from sequence.""" model = get_embedding_model(layer) seq_len = len(sequence) embeddings = [] 

def compute_embedding_stats(embedding):
    """Compute statistics that may indicate sequence 'familiarity'."""
    emb = np.array(embedding)
    # L2 norm - magnitude of response
    l2_norm = np.linalg.norm(emb)
    # Mean activation
    mean_act = np.mean(emb)
    # Std - spread of activations
    std_act = np.std(emb)
    # Sparsity - fraction of near-zero activations
    sparsity = np.mean(np.abs(emb) < 0.1)
    # Activation entropy (discretized)
    hist, _ = np.histogram(emb, bins=50, density=True)
    hist = hist[hist > 0]
    entropy = -np.sum(hist * np.log(hist + 1e-10))
    # Kurtosis - peakedness (high = more concentrated activations)
    kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3
    return {
        'l2_norm': float(l2_norm),
        'mean': float(mean_act),
        'std': float(std_act),
        'sparsity': float(sparsity),
        'entropy': float(entropy),
        'kurtosis': float(kurtosis)
    }

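
# Illustrative use of compute_embedding_stats (assumption: a standard-normal
# vector as a stand-in for a real 768-dim embedding; not called by the app).
# For N(0, 1) activations, l2_norm ≈ sqrt(768) ≈ 27.7, sparsity = P(|x| < 0.1)
# ≈ 0.08, and excess kurtosis ≈ 0; real embeddings typically deviate from these.
def _stats_example():
    rng = np.random.default_rng(0)
    stats = compute_embedding_stats(rng.normal(size=EMBEDDING_DIM))
    assert np.isclose(stats['l2_norm'], np.sqrt(EMBEDDING_DIM), rtol=0.1)
    return stats  # keys: l2_norm, mean, std, sparsity, entropy, kurtosis
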

def embed_sequence(sequence, mode="mean", stride=100, layer=21):
    """Extract embeddings from sequence."""
    model = get_embedding_model(layer)
    seq_len = len(sequence)
    embeddings = []
    positions = []
    for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
        window = sequence[start:start + WINDOW_SIZE]
        tokens = np.expand_dims(tokenize(window), axis=0)
        emb = model.predict(tokens, verbose=0)
        embeddings.append(emb[0])
        positions.append(start)
    embeddings = np.array(embeddings)  # (n_windows, 1000, 768)
    if mode == "mean":
        window_emb = np.mean(embeddings, axis=1)
        return np.mean(window_emb, axis=0), window_emb, positions
    elif mode == "max":
        window_emb = np.max(embeddings, axis=1)
        return np.max(window_emb, axis=0), window_emb, positions
    elif mode == "per-window":
        window_emb = np.mean(embeddings, axis=1)
        return window_emb, window_emb, positions
    else:
        # Unknown mode: fall back to mean pooling.
        window_emb = np.mean(embeddings, axis=1)
        return np.mean(window_emb, axis=0), window_emb, positions

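
# Shape summary for embed_sequence (derived from the code above, with
# WINDOW_SIZE=1000 and EMBEDDING_DIM=768):
#   n_windows = (len(sequence) - WINDOW_SIZE) // stride + 1
#   mode="mean" / "max"  -> pooled (768,),           window_emb (n_windows, 768)
#   mode="per-window"    -> pooled (n_windows, 768), window_emb (n_windows, 768)
# e.g. a 3,000 bp input at stride=100 yields 21 windows.
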

# ln(vocab_size=6): surprise if the model predicted uniformly at random.
UNIFORM_SURPRISE = float(np.log(6))
MASK_TOKEN = 0  # PAD/OOV; used as the MLM mask slot


def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
    """Per-window and per-base MLM surprise.

    For each sliding window, randomly mask ~mask_fraction of positions, run
    one forward pass through the full model (which ends in a
    Dense(vocab_size=6)), softmax the per-position logits, and take
    -log(p_true) at the masked positions.

    Returns:
    - per_window: list of (position, mean_surprise)
    - per_base_pos, per_base_vals: flat arrays of (position, surprise)
      samples, one entry per (window × masked_position). Overlapping windows
      give multiple observations per base.
    """
    model = get_base_model()
    tokens = tokenize(sequence)
    seq_len = len(tokens)
    rng = np.random.default_rng(seed)
    n_mask = max(1, int(WINDOW_SIZE * mask_fraction))
    starts = list(range(0, seq_len - WINDOW_SIZE + 1, stride))
    if not starts:
        return [], np.array([]), np.array([])
    # Build all windows + mask sets, run one batched forward pass.
    batch = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    truth = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    mask_idxs = []
    for i, start in enumerate(starts):
        w = tokens[start:start + WINDOW_SIZE]
        truth[i] = w
        idx = rng.choice(WINDOW_SIZE, size=n_mask, replace=False)
        mask_idxs.append(idx)
        w_masked = w.copy()
        w_masked[idx] = MASK_TOKEN
        batch[i] = w_masked
    logits = model.predict(batch, verbose=0, batch_size=8)  # (n_win, 1000, 6)
    logits -= logits.max(axis=-1, keepdims=True)
    exp_l = np.exp(logits)
    probs = exp_l / exp_l.sum(axis=-1, keepdims=True)
    per_window = []
    per_base_pos = []
    per_base_vals = []
    for i, start in enumerate(starts):
        idx = mask_idxs[i]
        p_true = probs[i, idx, truth[i, idx]]
        surprises = -np.log(np.clip(p_true, 1e-10, None))
        per_window.append((start + WINDOW_SIZE // 2, float(surprises.mean())))
        per_base_pos.extend((start + idx).tolist())
        per_base_vals.extend(surprises.tolist())
    return per_window, np.array(per_base_pos), np.array(per_base_vals)

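
# Worked example for the surprise unit (nats), as used by compute_mlm_surprise:
# guessing uniformly over the 6-token vocabulary costs -ln(1/6) = ln 6 ≈ 1.792
# nats (UNIFORM_SURPRISE); a confident correct prediction (p_true = 0.99) costs
# -ln(0.99) ≈ 0.010; uniform over just {A,C,G,T} (p_true = 0.25) costs
# ln 4 ≈ 1.386. Illustrative check, not called by the app:
def _surprise_unit_examples():
    assert np.isclose(-np.log(1 / 6), UNIFORM_SURPRISE)
    assert np.isclose(-np.log(0.99), 0.0101, atol=1e-3)
    assert np.isclose(-np.log(0.25), np.log(4))
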

def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
    """Two-panel Plotly figure: per-window surprise line + per-base scatter."""
    from plotly.subplots import make_subplots
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
        row_heights=[0.55, 0.45],
        subplot_titles=(
            'per-WINDOW mean surprise — one point per 1000-bp window, plotted at its center',
            'per-BASE surprise — one dot per masked base (~15% of bases in each window)',
        ),
    )
    wx = [p for p, _ in per_window]
    wy = [s for _, s in per_window]
    fig.add_trace(go.Scatter(
        x=wx, y=wy, mode='lines+markers',
        line=dict(color='#18181b', width=2), marker=dict(size=7),
        hovertemplate='window center: %{x} bp<br>surprise: %{y:.3f} nats',
        name='window mean', showlegend=False,
    ), row=1, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        annotation_text=f'uniform-random baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
        annotation_position='top right',
        annotation_font=dict(size=10, color='#71717a'),
        row=1, col=1,
    )
    fig.add_trace(go.Scatter(
        x=per_base_pos, y=per_base_vals, mode='markers',
        marker=dict(
            size=4, color=per_base_vals, colorscale='Reds',
            cmin=0, cmax=UNIFORM_SURPRISE,
            colorbar=dict(
                title=dict(text='surprise<br>(nats)', font=dict(size=10)),
                thickness=10, len=0.4, y=0.2, tickfont=dict(size=9)
            )
        ),
        hovertemplate='base position: %{x} bp<br>surprise: %{y:.3f} nats',
        name='per base', showlegend=False,
    ), row=2, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        row=2, col=1,
    )
    fig.update_xaxes(title_text='position along input sequence (bp)', row=2, col=1, range=[0, seq_len])
    fig.update_xaxes(range=[0, seq_len], row=1, col=1)
    fig.update_yaxes(title_text='surprise (nats)', row=1, col=1, rangemode='tozero')
    fig.update_yaxes(title_text='surprise (nats)', row=2, col=1, rangemode='tozero')
    fig.update_layout(height=560, margin=dict(l=60, r=20, t=70, b=60))
    for ann in fig['layout']['annotations']:
        if 'font' not in ann:
            ann['font'] = dict(size=11)
    return fig

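
# Sketch for the mask-token caveat raised in the Guide tab below (assumption:
# token 5, the N token, is the alternative sentinel worth testing; this helper
# is illustrative and not wired into the UI). It temporarily swaps the
# module-level MASK_TOKEN, scores the same sequence with each candidate, and
# reports the mean per-window surprise; the candidate that drops clearly below
# the ln 6 baseline on a conserved sequence is the sentinel used in pretraining.
def _compare_mask_tokens(sequence, candidates=(0, 5)):
    global MASK_TOKEN
    original = MASK_TOKEN
    results = {}
    try:
        for tok in candidates:
            MASK_TOKEN = tok
            per_window, _, _ = compute_mlm_surprise(sequence)
            results[tok] = float(np.mean([s for _, s in per_window]))
    finally:
        MASK_TOKEN = original
    return results
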

def create_embedding_heatmap(embedding, title="Embedding"):
    """Create a heatmap of a single embedding vector."""
    embedding = np.array(embedding)
    n_dims = len(embedding)
    cols = 32
    rows = int(np.ceil(n_dims / cols))
    padded = np.full(rows * cols, np.nan)
    padded[:n_dims] = embedding
    grid = padded.reshape(rows, cols)
    finite = embedding[np.isfinite(embedding)]
    vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
    fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
    im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
    plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Row')
    ax.set_title(f'{title} ({n_dims} dims)')
    ax.set_xticks(np.arange(0, cols, 8))
    plt.tight_layout()
    return fig


def create_trajectory_plot(window_embeddings, positions):
    """Create interactive trajectory heatmap."""
    emb = np.array(window_embeddings)
    n_windows, n_dims = emb.shape
    # Subsample dimensions
    step = max(1, n_dims // 100)
    emb_sub = emb[:, ::step]
    vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
    fig = go.Figure(go.Heatmap(
        z=emb_sub,
        x=list(range(emb_sub.shape[1])),
        y=[f"{p}" for p in positions],
        colorscale='RdBu_r', zmin=-vmax, zmax=vmax,
        colorbar=dict(title='Act.'),
        hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}'
    ))
    fig.update_layout(
        xaxis=dict(title='Dimension' + (' (subsampled)' if step > 1 else '')),
        yaxis=dict(title='Window start (bp)'),
        height=max(350, n_windows * 15 + 100),
        margin=dict(l=60, r=20, t=30, b=50)
    )
    return fig


def create_familiarity_plot(window_embeddings, positions):
    """Per-window L2 norm + novelty (cosine distance to sequence mean) along the sequence.

    High L2 norm = strong response. High novelty = window looks different from
    the rest of the sequence (the model's internal 'surprise' relative to the
    sequence average).
    """
    from plotly.subplots import make_subplots
    emb = np.array(window_embeddings)
    n_windows = emb.shape[0]
    l2 = np.linalg.norm(emb, axis=1)
    mean_vec = emb.mean(axis=0)
    mean_norm = np.linalg.norm(mean_vec) + 1e-10
    cos_sim = (emb @ mean_vec) / (l2 * mean_norm + 1e-10)
    novelty = 1.0 - cos_sim
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.14,
        subplot_titles=(
            'per-WINDOW L2 norm of embedding (activation magnitude at layer L)',
            'per-WINDOW embedding novelty (1 − cos similarity to sequence-mean embedding)',
        ),
    )
    fig.add_trace(go.Scatter(
        x=positions, y=l2, mode='lines+markers',
        line=dict(color='#3b82f6', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>L2: %{y:.2f}',
        showlegend=False
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=positions, y=novelty, mode='lines+markers',
        line=dict(color='#ef4444', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>novelty: %{y:.3f}',
        showlegend=False
    ), row=2, col=1)
    fig.update_xaxes(title_text='window start (bp along input sequence)', row=2, col=1)
    fig.update_yaxes(title_text='L2 norm', row=1, col=1)
    fig.update_yaxes(title_text='1 − cos sim', row=2, col=1)
    fig.update_layout(
        height=380 if n_windows > 1 else 260,
        margin=dict(l=60, r=20, t=50, b=50),
    )
    for ann in fig['layout']['annotations']:
        ann['font'] = dict(size=11)
    return fig

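
# Worked example of the novelty score used above (toy 2-dim vectors standing
# in for 768-dim embeddings; illustrative only). A window pointing in the same
# direction as the sequence-mean vector scores 0; an orthogonal one scores 1.
def _novelty_examples():
    def cos(a, b):
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    mean_vec = np.array([1.0, 0.0])
    assert np.isclose(1.0 - cos(np.array([2.0, 0.0]), mean_vec), 0.0)
    assert np.isclose(1.0 - cos(np.array([0.0, 3.0]), mean_vec), 1.0)
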

def create_dimension_plot(window_embeddings, positions, top_k=8):
    """Show top varying dimensions."""
    emb = np.array(window_embeddings)
    variances = np.var(emb, axis=0)
    top_dims = np.argsort(variances)[-top_k:][::-1]
    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
              '#ff7f00', '#a65628', '#f781bf', '#999999']
    fig = go.Figure()
    for i, dim in enumerate(top_dims):
        fig.add_trace(go.Scatter(
            x=positions, y=emb[:, dim], mode='lines',
            name=f'd{dim}',
            line=dict(color=colors[i % len(colors)], width=1.5)
        ))
    fig.update_layout(
        xaxis=dict(title='Position (bp)'),
        yaxis=dict(title='Activation'),
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=50, r=20, t=40, b=50)
    )
    return fig


# Example sequence: ~3 kb slice of E. coli K-12 MG1655 around the lac operon
# (NC_000913.3, positions 365529-368600). Covers the lac repressor binding region,
# the lacZ gene, and flanking regulatory sequence, so per-window plots show
# real biological structure transitions.
EXAMPLE_SEQUENCE = (
    "AACTGTTACCCGTAGGTAGTCACGCAACTCGCCGCACATCTGAACTTCAGCCTCCAGTACAGCGCGGCTGAA"
    "ATCATCATTAAAGCGAGTGGCAACATGGAAATCGCTGATTTGTGTAGTCGGTTTATGCAGCAACGAGACGTC"
    "ACGGAAAATGCCGCTCATCCGCCACATATCCTGATCTTCCAGATAACTGCCGTCACTCCAGCGCAGCACCAT"
    "CACCGCGAGGCGGTTTTCTCCGGCGCGTAAAAATGCGCTCAGGTCAAATTCAGACGGCAAACGACTGTCCTG"
    "GCCGTAACCGACCCAGCGCCCGTTGCACCACAGATGAAACGCCGAGTTAACGCCATCAAAAATAATTCGCGT"
    "CTGGCCTTCCTGTAGCCAGCTTTCATCAACATTAAATGTGAGCGAGTAACAACCCGTCGGATTCTCCGTGGG"
    "AACAAACGGCGGATTGACCGTAATGGGATAGGTCACGTTGGTGTAGATGGGCGCATCGTAACCGTGCATCTG"
    "CCAGTTTGAGGGGACGACGACAGTATCGGCCTCAGGAAGATCGCACTCCAGCCAGCTTTCCGGCACCGCTTC"
    "TGGTGCCGGAAACCAGGCAAAGCGCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCG"
    "GGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGG"
    "GTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATCCGTAATCATGGTCATAGCTGTTTCCTGTGT"
    "GAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCT"
    "AATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCC"
    "AGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCCAGGGTGGTTTTTCTT"
    "TTCACCAGTGAGACGGGCAACAGCTGATTGCCCTTCACCGCCTGGCCCTGAGAGAGTTGCAGCAAGCGGTCC"
    "ACGCTGGTTTGCCCCAGCAGGCGAAAATCCTGTTTGATGGTGGTTAACGGCGGGATATAACATGAGCTGTCT"
    "TCGGTATCGTCGTATCCCACTACCGAGATATCCGCACCAACGCGCAGCCCGGACTCGGTAATGGCGCGCATT"
    "GCGCCCAGCGCCATCTGATCGTTGGCAACCAGCATCGCAGTGGGAACGATGCCCTCATTCAGCATTTGCATG"
    "GTTTGTTGAAAACCGGACATGGCACTCCAGTCGCCTTCCCGTTCCGCTATCGGCTGAATTTGATTGCGAGTG"
    "AGATATTTATGCCAGCCAGCCAGACGCAGACGCGCCGAGACAGAACTTAATGGGCCCGCTAACAGCGCGATT"
    "TGCTGGTGACCCAATGCGACCAGATGCTCCACGCCCAGTCGCGTACCGTCTTCATGGGAGAAAATAATACTG"
    "TTGATGGGTGTCTGGTCAGAGACATCAAGAAATAACGCCGGAACATTAGTGCAGGCAGCTTCCACAGCAATG"
    "GCATCCTGGTCATCCAGCGGATAGTTAATGATCAGCCCACTGACGCGTTGCGCGAGAAGATTGTGCACCGCC"
    "GCTTTACAGGCTTCGACGCCGCTTCGTTCTACCATCGACACCACCACGCTGGCACCCAGTTGATCGGCGCGA"
    "GATTTAATCGCCGCGACAATTTGCGACGGCGCGTGCAGGGCCAGACTGGAGGTGGCAACGCCAATCAGCAAC"
    "GACTGTTTGCCCGCCAGTTGTTGTGCCACGCGGTTGGGAATGTAATTCAGCTCCGCCATCGCCGCTTCCACT"
    "TTTTCCCGCGTTTTCGCAGAAACGTGGCTGGCCTGGTTCACCACGCGGGAAACGGTCTGATAAGAGACACCG"
    "GCATACTCTGCGACATCGTATAACGTTACTGGTTTCACATTCACCACCCTGAATTGACTCTCTTCCGGGCGC"
"TATCATGCCATACCGCGAAAGGTTTTGCGCCATTCGATGGTGTCAACGTAAATGCATGCCGCTTCGCCTTCC" "GGCCACCAGAATAGCCTGCGATTCAACCCCTTCTTCGATCTGTTTTGCTACCCGTTGTAGCGCCGGAAGATG" "CTTTTCCGCTGCCTGTTCAATGGTCATTGCGCTCGCCATATACACCAGATTCAGACAGCCAATCACCCGTTG" "TTCACTGCGCAGCGGTACGGCGATAGAGGCGATCTTCTCCTCCTGATCCCAGCCGCGGTAGTTCTGTCCGTA" "ACCCTCTTTGCGCGCGCGCGCCAGAATGGCTTCCAGCTTTAACGGTTCCCGTGCCAGTTGATAGTCATCACC" "GGGGCGGGAGGCTAACATTTCGATTAATTCCTTGCGGTCTTGTTCCGGGCAAAAGGCCAGCCAGGTCAGGCC" "CGAGGCGGTTTTCAGAAGCGGCAAACGTCGCCCGACCATTGCCCGGTGAAAGGATAAGCGGCTGAAACGGTG" "AGTGGTTTCGCGTACCACCATTGCATCAACATCCAGCGTGGACACATCTGTCGGCCATACCACTTCGCGCAA" "CAGATCGCCCAGCAGTGGGGCCGCCAGTGCAGAAATCCACTGTTCGTCACGAAATCCTTCGCTTAATTGCCG" "CACTTTGATGGTCAGTCGAAAACTATCATCGGAGGGGCTACGGCGGACATATCCCTCTTCCTGCAGCGTCTC" "CAGCAGTCGCCGCACAGTGGTGCGATGCAGGCCGCTGAGTTCCGCCAGCAGCCCGACGCTGGCACCGCCATC" "AAGTTTATTTAACATATTTAATAACATTAGACCGCGGGTTAAGCCGCGCACGGTTTTGTATTCCGTCTGCTC" "ATTGTTCTGCATATTAATTGACATTTCTATAGTTAAAACAACGTGGTGCACCTGGTGCACATTCGGGCATGT" "TTTGATTGTAGCCGAAAACACCCTTCCTATACTGAGCGCACAATAAAAAATCATTTACATGTTTTTAACAAA" "ATAAGTTGCGCTGTACTGTGCGCGCAACGACATTTTGTCCGAGTCGTG" ) # Highly conserved example: ~3 kb slice across E. coli K-12 MG1655 rrnB operon # (NC_000913.3:4035531-4038602), which contains the 16S rRNA gene (rrsB) and # flanking regulatory regions. rRNA genes are among the most conserved in # bacteria, so MLM surprise should be visibly lower than on lacZ, and the # per-window embedding plot should be flatter (one coherent functional region). EXAMPLE_16S = ( "AAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACACATGCAAGTCGAACGGTAA" "CAGGAAGAAGCTTGCTTCTTTGCTGACGAGTGGCGGACGGGTGAGTAATGTCTGGGAAACTGCCTGATGGAG" "GGGGATAACTACTGGAAACGGTAGCTAATACCGCATAACGTCGCAAGACCAAAGAGGGGGACCTTCGGGCCT" "CTTGCCATCGGATGTGCCCAGATGGGATTAGCTAGTAGGTGGGGTAACGGCTCACCTAGGCGACGATCCCTA" "GCTGGTCTGAGAGGATGACCAGCCACACTGGAACTGAGACACGGTCCAGACTCCTACGGGAGGCAGCAGTGG" "GGAATATTGCACAATGGGCGCAAGCCTGATGCAGCCATGCCGCGTGTATGAAGAAGGCCTTCGGGTTGTAAA" "GTACTTTCAGCGGGGAGGAAGGGAGTAAAGTTAATACCTTTGCTCATTGACGTTACCCGCAGAAGAAGCACC" "GGCTAACTCCGTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAATTACTGGGCGTAAAGC" "GCACGCAGGCGGTTTGTTAAGTCAGATGTGAAATCCCCGGGCTCAACCTGGGAACTGCATCTGATACTGGCA" "AGCTTGAGTCTCGTAGAGGGGGGTAGAATTCCAGGTGTAGCGGTGAAATGCGTAGAGATCTGGAGGAATACC" "GGTGGCGAAGGCGGCCCCCTGGACGAAGACTGACGCTCAGGTGCGAAAGCGTGGGGAGCAAACAGGATTAGA" "TACCCTGGTAGTCCACGCCGTAAACGATGTCGACTTGGAGGTTGTGCCCTTGAGGCGTGGCTTCCGGAGCTA" "ACGCGTTAAGTCGACCGCCTGGGGAGTACGGCCGCAAGGTTAAAACTCAAATGAATTGACGGGGGCCCGCAC" "AAGCGGTGGAGCATGTGGTTTAATTCGATGCAACGCGAAGAACCTTACCTGGTCTTGACATCCACGGAAGTT" "TTCAGAGATGAGAATGTGCCTTCGGGAACCGTGAGACAGGTGCTGCATGGCTGTCGTCAGCTCGTGTTGTGA" "AATGTTGGGTTAAGTCCCGCAACGAGCGCAACCCTTATCCTTTGTTGCCAGCGGTCCGGCCGGGAACTCAAA" "GGAGACTGCCAGTGATAAACTGGAGGAAGGTGGGGATGACGTCAAGTCATCATGGCCCTTACGACCAGGGCT" "ACACACGTGCTACAATGGCGCATACAAAGAGAAGCGACCTCGCGAGAGCAAGCGGACCTCATAAAGTGCGTC" "GTAGTCCGGATTGGAGTCTGCAACTCGACTCCATGAAGTCGGAATCGCTAGTAATCGTGGATCAGAATGCCA" "CGGTGAATACGTTCCCGGGCCTTGTACACACCGCCCGTCACACCATGGGAGTGGGTTGCAAAAGAAGTAGGT" "AGCTTAACCTTCGGGAGGGCGCTTACCACTTTGTGATTCATGACTGGGGTGAAGTCGTAACAAGGTAACCGT" "AGGGGAACCTGCGGTTGGATCACCTCCTTACCTCAGAACAAGAAGTATACCAAACTCAACGGAGTTACCACC" "CGGTGGCAATTGCACAATTGATCATCAGAGAGAACCAATGCTGAACATCAAGAACAATGAAGGCCAGCAAGG" "TAATCCGAGCAGAGCTAAAGAGAACGGATACCCATGACCACGGAAGTGGTCAAGAATATAGGCATCCAAAGA" "CGATCAGATACCGTCGTAGTTCCGACCATAAACGATGAAGAACACGTCCATATAGCCAAGCTCCGTAGGACA" "AGAATAATGAGACAAAACACAAAGCCAACAATGAGCCCTAAGTGATGTCCGGGGAAACCAGAAAGACCCGTA" "GCCTGAAAGATTGCCGGCCACTTGGAACGCTGGATTGAGCACCCTGTAGAACATTTGTTTGAACAGGTGCGG" 
"ACCGAATAAAGCCACATGATGCAACTATGAATCTGAACTTGCAATGCTGAACGAATCGCGATAAACCTAAGG" "CAGAAGCGTACCCGGGAACATCAATAGACTGCGATGTGATAACGTACCCAAACTTATCCCAGGGCCCGTAAC" "TAAACTGCCCCTTTGCGCTTCGAGTAAAGGCATCAAATAGATATAGACTCATAATGCCACAGTCCAATTACA" "TGCCCGGAAGTTATTAATACTGCGAACGTTATACATACGAAGCCGTAAGGTAATTTGATAAGCGTAACCGAT" "AGCCCCGACAGCGAACTAGCAACCTTGGAGTATATGAACCCAAATATCTGTGAGGCCTGGAACGTCCGAGAT" "GAGAGTGCCACATACTCAAGACTCAAAGTCACCCGAAGGGAATTTGCATATGAGCTCGTCTGGCCAGGAGTT" "TTAAGAGGGGCGCAGATATCACCTAATACGATAGCTAGCCGAATGCACTACGCCAACCATCTAACGGACAGA" "GTAATGAACCACAAGCTCCGAAATGATGCTGAGAGACGCATGGCCAGTTCTCATCAGCCGTCGTGGTCAATC" "GGTGCTTGGCCATCACCATGGGGGCCCGCATCTGCCATCGACAGCGCTTTCATCGTAAACCGTCTTATGGAA" "AGACATTACAGCCAGTGTAAAATCCCGCACACTATTAGCCATCAAATCATATAAGGCATACGGTCAGTCAGT" "ATTCCGAAAGAACACCACCAGTGATAGTACCAAGAGCACGTATGAATACGATGCCGACCATAGCGGACAAAT" "CTCCCAATACGAGAGTAAAATAAGCAAATAATAGATATCCATGCATGGAGTCACCACAATAGAGCGCTACGT" "CGTCGTGAAGAGGGAAACAACCCAGACCGCCAGCTAAGGTCCCAAAGTCATGGTTAAGTGGGAAACGATGTG" "GGAAGGCCCAGACAGCCAGGATGTTGGCTTAGAAGCAGCCATCATTTA" ) # Random DNA: uniform ACGT, 3 kb. Should be ~uninformative — model surprise # should be close to the uniform baseline (ln 6 ≈ 1.79 nats) because there # is no context-dependent pattern to learn. _rng = np.random.default_rng(42) EXAMPLE_RANDOM = "".join(_rng.choice(list("ACGT"), size=3000)) # Low-complexity repeat: 3 kb of AT dinucleotide repeats. Should give very # LOW surprise — the model trivially predicts the next base from context. EXAMPLE_REPEAT = "AT" * 1500 def process(sequence: str, mode: str, stride: int, layer: int): """Main processing function.""" sequence = strip_fasta_header(sequence.strip()) is_valid, error = validate_sequence(sequence) if not is_valid: return f"**Error**: {error}", None, None, None, None, None embedding, window_embeddings, positions = embed_sequence( sequence, mode=mode, stride=stride, layer=layer ) # Save embedding path = os.path.join(tempfile.gettempdir(), "embedding.npy") np.save(path, embedding) # Compute stats if mode == "per-window": # For per-window, compute stats on mean embedding mean_emb = np.mean(embedding, axis=0) stats = compute_embedding_stats(mean_emb) else: stats = compute_embedding_stats(embedding) # Create summary if mode == "per-window": summary = f"""### Results | | | |---|---| | sequence | {len(sequence):,} bp | | layer | {layer} | | windows | {embedding.shape[0]} | | shape | {embedding.shape} | **Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f} """ else: summary = f"""### Results | | | |---|---| | sequence | {len(sequence):,} bp | | layer | {layer} | | mode | {mode} | | dim | {len(embedding)} | **Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%} """ # Create visualizations heatmap_fig = None if mode != "per-window": heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}") multi_window = len(window_embeddings) > 1 trajectory_fig = create_trajectory_plot(window_embeddings, positions) if multi_window else None familiarity_fig = create_familiarity_plot(window_embeddings, positions) if multi_window else None dims_fig = create_dimension_plot(window_embeddings, positions) if multi_window else None return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig def process_surprise(sequence: str, stride: int, mask_fraction: float): """Compute MLM surprise across the sequence.""" sequence = strip_fasta_header(sequence.strip()) is_valid, error = validate_sequence(sequence) if not is_valid: return f"**Error**: {error}", None per_window, 

def process_surprise(sequence: str, stride: int, mask_fraction: float):
    """Compute MLM surprise across the sequence."""
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None
    per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
        sequence, stride=stride, mask_fraction=mask_fraction
    )
    if not per_window:
        return "**Error**: sequence too short for one window", None
    fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))
    w_vals = np.array([s for _, s in per_window])
    lo_pos, lo_val = per_window[int(np.argmin(w_vals))]
    hi_pos, hi_val = per_window[int(np.argmax(w_vals))]
    summary = f"""### MLM surprise

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {w_vals.mean():.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {lo_val:.3f} nats @ ~{lo_pos:,} bp |
| most surprising window | {hi_val:.3f} nats @ ~{hi_pos:,} bp |

Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
    return summary, fig


# Build interface
with gr.Blocks(
    title="BERT Metagenome Embeddings",
    css=".gradio-container { max-width: 100% !important; }",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.zinc,
        neutral_hue=gr.themes.colors.zinc,
    ),
) as demo:
    gr.Markdown(
        "# bert-embedding\n"
        "BERT (24 layers, 430M params) pretrained with masked-language-modeling "
        "on metagenomic contigs. Input: DNA sequence ≥ 1000 bp. The model slides a "
        "1000 bp window over your sequence and produces two kinds of output — "
        "**embeddings** (768-dim hidden-state vector per window, under *Extract*) "
        "and **MLM surprise** (model's confidence in reconstructing masked bases, "
        "under *MLM surprise*). See the *Guide* tab for how to read the plots."
    )
    with gr.Tab("Extract"):
        gr.Markdown(
            "Extract 768-dim hidden-state embeddings from a chosen transformer layer. "
            "Produces **per-window** vectors (one per 1000-bp window) that can be pooled, "
            "visualised as heatmaps, and compared via cosine similarity. "
            "The per-window plot below shows how the embedding *changes along the sequence* — "
            "it does **not** say anything about whether bases are predictable "
            "(that's the MLM surprise tab)."
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=260):
                seq_input = gr.Textbox(
                    label="sequence (≥ 1000 bp)",
                    placeholder="Paste DNA (FASTA or raw)...",
                    lines=8,
                    value=EXAMPLE_SEQUENCE
                )
                gr.Markdown("**load an example:**")
                with gr.Row():
                    ex_lacz = gr.Button("lacZ (mixed)", size="sm")
                    ex_16s = gr.Button("16S rRNA (conserved)", size="sm")
                with gr.Row():
                    ex_rand = gr.Button("random DNA", size="sm")
                    ex_rep = gr.Button("AT repeat (low complexity)", size="sm")
                mode_input = gr.Radio(
                    choices=["mean", "max", "per-window"],
                    value="mean",
                    label="pooling",
                    info="how to collapse positions within each window"
                )
                layer_input = gr.Slider(0, 23, value=21, step=1,
                                        label="layer (0=shallow, 23=deep)")
                stride_input = gr.Slider(
                    50, 500, value=100, step=50, label="stride (bp)",
                    info="step between windows. lower = more windows, more compute"
                )
                btn = gr.Button("extract", variant="primary")
            with gr.Column(scale=3, min_width=500):
                output = gr.Markdown()
                download = gr.File(label="download embedding (.npy)")
                gr.Markdown(
                    "**Per-window plot below.** x-axis = window-start position along your input "
                    "sequence. Each point is one 1000-bp window. "
                    "*L2 norm* = how strongly the model's neurons fire on this window (bigger = "
                    "stronger activation). *Novelty* = how different this window's embedding is "
                    "from the average embedding across your sequence (1 − cos sim to mean); "
                    "peaks = regions that stand out from the rest of **this** sequence."
                )
                familiarity_plot = gr.Plot(
                    label="per-window L2 norm & embedding-space novelty along your input sequence"
                )
                gr.Markdown(
                    "**Trajectory** (left): heatmap of (windows × ~100 subsampled embedding "
                    "dimensions). Sharp horizontal bands = sudden embedding change → boundary. "
                    "**Top varying dims** (right): the 8 dimensions that vary most across windows, "
                    "plotted vs window position."
                )
                with gr.Row():
                    trajectory_plot = gr.Plot(label="window × dimension heatmap (trajectory)")
                    dims_plot = gr.Plot(label="top-8 most variable dimensions vs window position")
                gr.Markdown(
                    "**Pooled embedding heatmap** (below): the single 768-dim vector after "
                    "pooling across windows — only shown in `mean` / `max` mode. "
                    "Red = positive activation, blue = negative."
                )
                heatmap_plot = gr.Plot(label="pooled 768-dim embedding heatmap (24 × 32 grid)")
        btn.click(
            process,
            inputs=[seq_input, mode_input, stride_input, layer_input],
            outputs=[output, download, heatmap_plot, trajectory_plot, familiarity_plot, dims_plot],
            api_name="embed"
        )
        ex_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=seq_input)
        ex_16s.click(lambda: EXAMPLE_16S, outputs=seq_input)
        ex_rand.click(lambda: EXAMPLE_RANDOM, outputs=seq_input)
        ex_rep.click(lambda: EXAMPLE_REPEAT, outputs=seq_input)

    with gr.Tab("MLM surprise"):
        gr.Markdown(
            "**What it does.** For each 1000-bp window, we randomly replace ~15% of bases "
            "with a mask token, ask the model to predict what was there, and record "
            "**−log p(true base)** at those positions. Low = model is confident (pattern "
            "matches training distribution). High = model is uncertain (unusual region).\n\n"
            "**Unit**: nats. Uniform-random guessing over {A,C,G,T,N,PAD} = **ln 6 ≈ 1.79 nats**. "
            "A perfectly confident correct prediction = **0 nats**.\n\n"
            "**Expected shape on the examples below**:\n"
            "- *lacZ* → moderate (~0.6–1.3 nats), with visible structure between CDS and UTR.\n"
            "- *16S rRNA* → lowest values, often flat across the whole region (highly conserved).\n"
            "- *random DNA* → near the 1.79 baseline, no trend (by construction unpredictable).\n"
            "- *AT repeat* → near 0 (trivial to predict once you see a few bases)."
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=260):
                surp_seq = gr.Textbox(
                    label="sequence (≥ 1000 bp)",
                    placeholder="Paste DNA (FASTA or raw)...",
                    lines=8,
                    value=EXAMPLE_SEQUENCE,
                )
                gr.Markdown("**load an example:**")
                with gr.Row():
                    sx_lacz = gr.Button("lacZ (mixed)", size="sm")
                    sx_16s = gr.Button("16S rRNA (conserved)", size="sm")
                with gr.Row():
                    sx_rand = gr.Button("random DNA", size="sm")
                    sx_rep = gr.Button("AT repeat (low complexity)", size="sm")
                surp_stride = gr.Slider(
                    50, 500, value=100, step=50, label="stride (bp)",
                    info="step between windows. lower = more windows"
                )
                surp_mask = gr.Slider(
                    0.05, 0.5, value=0.15, step=0.05, label="mask fraction",
                    info="fraction of positions masked in each window (0.15 matches BERT training)"
                )
                surp_btn = gr.Button("score", variant="primary")
            with gr.Column(scale=3, min_width=500):
                surp_summary = gr.Markdown()
                surp_plot = gr.Plot(label="MLM surprise along the input sequence")
        surp_btn.click(
            process_surprise,
            inputs=[surp_seq, surp_stride, surp_mask],
            outputs=[surp_summary, surp_plot],
            api_name="surprise",
        )
        sx_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=surp_seq)
        sx_16s.click(lambda: EXAMPLE_16S, outputs=surp_seq)
        sx_rand.click(lambda: EXAMPLE_RANDOM, outputs=surp_seq)
        sx_rep.click(lambda: EXAMPLE_REPEAT, outputs=surp_seq)

    with gr.Tab("Guide"):
        gr.Markdown("""
### What this space actually does

The model is a BERT trained to predict masked bases in DNA.
Two things come out of it:

| | **Extract (embeddings)** | **MLM surprise** |
|---|---|---|
| What | 768-dim hidden-state vector per window (layer L) | −log p(true base) at masked positions |
| Y-axis | L2 norm / cosine distance to sequence mean | nats |
| Measures | how the *representation* changes along the sequence | how *predictable* the bases are |
| Depends on training data? | indirectly (via what the model learned to represent) | yes — bases typical of training data get low surprise |
| "Unusual" means | this window's embedding differs from the rest of *this* sequence | this base is hard to predict from context, relative to what the model saw in pretraining |

They look similar (both have spiky plots over position) but answer **different questions**. A region can be *embedding-novel* (structurally unlike the rest of your sequence) without being *MLM-surprising* (it could still be a predictable pattern the model knows). And vice versa.

### Reading the per-window plots

- **x-axis** on every per-window plot = position along your input sequence, in bp. Each point = one 1000-bp window, plotted at the window's start (Extract tab) or center (MLM surprise tab). With `stride=100`, successive points are 100 bp apart and **overlap by 900 bp** — so plots are smoother than they look, and a single isolated peak is usually real, not noise.
- **y-axis units**:
  - Extract → L2 norm (unitless magnitude of the 768-dim vector) and cosine distance 1 − cos.
  - MLM surprise → nats (natural-log likelihood). 0 = perfectly confident, ln 6 ≈ 1.79 = uniform.

### How the examples should look

- **lacZ (~3 kb of E. coli)** — mixed content (regulatory + coding). MLM surprise will have peaks and troughs reflecting CDS vs UTR structure. Embedding novelty has some variation.
- **16S rRNA (~3 kb of rrnB)** — highly conserved, functional RNA. MLM surprise should be visibly *lower* than on lacZ, often flat. Embedding trajectory looks like one consistent region.
- **random DNA (3 kb, uniform ACGT)** — no pattern. MLM surprise should sit near the ln 6 baseline with no trend. Embedding novelty will look noisy.
- **AT repeat (3 kb of ATATAT…)** — trivially predictable. MLM surprise should be very low (~0) across the whole sequence.

### Caveats

- MLM surprise depends on which token we use as [MASK]. This space uses token `0` (PAD/OOV). If the original pretraining used a different sentinel, scores will be uniformly *pessimistic* (near the 1.79 baseline on everything). If you see that on the conserved 16S example, the mask token is wrong and we need to switch to `5` (N/AMB).
- Only 15% of each window is scored per run. Points in the per-base scatter are sampled, not exhaustive — but because windows overlap 10× at stride 100, every base usually has several observations.
- Both metrics are relative. Absolute values mean little; compare across regions of the same sequence, or across the example sequences above.
""")

    with gr.Tab("API"):
        gr.Markdown("""
### API

```python
from gradio_client import Client
import numpy as np

client = Client("genomenet/bert-embedding")
result = client.predict(
    sequence="ATGC...",   # min 1000 bp
    mode="mean",          # mean/max/per-window
    stride=100,
    layer=21,             # 0-23
    api_name="/embed"
)
summary, emb_path, *plots = result
embedding = np.load(emb_path)
```

**Per-window plots** (along sequence position):

- **L2 norm**: activation magnitude — high = strong, structured response.
- **Novelty** (1 − cosine similarity to mean embedding): how much the window differs from the rest of the sequence.
  Spikes = unusual regions relative to context.

Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text.

### MLM surprise endpoint

```python
summary, plot = client.predict(
    sequence="ATGC...",
    stride=100,
    mask_fraction=0.15,
    api_name="/surprise",
)
```

Returns per-window mean `-log(p_true)` at masked positions (in nats). Uniform-random baseline is `ln(6) ≈ 1.79 nats`.
""")

    with gr.Tab("About"):
        gr.Markdown("""
### Model

| | |
|---|---|
| architecture | BERT, 24 layers, 768 hidden, 12 heads |
| parameters | ~430M |
| input | 1000 bp sliding window |
| pretraining | metagenomic contigs + microbial genomes |

### Interpreting Statistics

The embedding statistics provide indirect measures of how the model "responds" to a sequence:

- **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences.
- **Entropy**: How spread out the activations are. Lower entropy suggests a more confident/structured representation.
- **Sparsity**: Fraction of dimensions with near-zero activation.
- **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations.

**Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across different sequence types may reveal what the model considers typical vs. unusual.

### Links

- Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
- CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
""")

if __name__ == "__main__":
    print("Loading model...")
    _ = get_base_model()
    print(f"Ready! {get_gpu_status()}")
    demo.launch(server_name="0.0.0.0", server_port=7860)