| """ | |
| BERT Metagenome Embeddings - HuggingFace Spaces App | |
| """ | |
| import os | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import plotly.graph_objects as go | |
| from custom_layers import get_custom_objects | |
| # Model config | |
| MODEL_REPO = "genomenet/bert-metagenome" | |
| MODEL_FILE = "bert_1k_3.h5" | |
| WINDOW_SIZE = 1000 | |
| NUM_LAYERS = 24 | |
| EMBEDDING_DIM = 768 | |
| # Singleton model cache | |
| _model = None | |
| _embedding_models = {} | |
def get_base_model():
    """Load and cache the base model."""
    global _model
    if _model is None:
        print("Downloading model...")
        model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
        print(f"Loading model from {model_path}...")
        _model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
        print("Model loaded.")
        # Print model summary for debugging
        print(f"Model outputs: {_model.output_names}")
    return _model
def get_embedding_model(layer_idx=21):
    """Get (and cache) a sub-model that outputs a specific transformer layer."""
    global _embedding_models
    if layer_idx not in _embedding_models:
        model = get_base_model()
        layer_name = f"layer_transformer_block_{layer_idx}"
        try:
            _embedding_models[layer_idx] = tf.keras.Model(
                inputs=model.input,
                outputs=model.get_layer(layer_name).output
            )
        except ValueError:
            # Layer name not found in this checkpoint: fall back to the default
            # layer 21. Note the fallback is cached under the requested layer_idx.
            _embedding_models[layer_idx] = tf.keras.Model(
                inputs=model.input,
                outputs=model.get_layer("layer_transformer_block_21").output
            )
    return _embedding_models[layer_idx]
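# Usage sketch (shapes inferred from the config above; an assumption, not
# verified against this specific checkpoint):
#   sub = get_embedding_model(layer_idx=12)
#   out = sub.predict(np.zeros((1, WINDOW_SIZE), dtype=np.int32), verbose=0)
#   out.shape  # expected (1, WINDOW_SIZE, EMBEDDING_DIM) == (1, 1000, 768)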
def get_gpu_status():
    gpus = tf.config.list_physical_devices('GPU')
    return f"GPU: {gpus[0].name}" if gpus else "CPU only"
# Tokenization: A/C/G/T map to 1-4; N and anything unknown map to 5.
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}

def tokenize(sequence):
    sequence = sequence.upper().replace('U', 'T')  # accept RNA input
    return np.array([TOKEN_MAP.get(c, 5) for c in sequence], dtype=np.int32)
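# Example (illustrative): 'u' is uppercased and converted to 'T'; characters
# outside TOKEN_MAP fall back to token 5.
#   tokenize("acgun")  ->  array([1, 2, 3, 4, 5], dtype=int32)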
def validate_sequence(sequence):
    if not sequence or len(sequence.strip()) == 0:
        return False, "Sequence is empty"
    sequence = sequence.upper().replace('U', 'T')
    valid_chars = set('ACGTNRYSWKMBDHV')
    invalid = set(sequence) - valid_chars - set(' \n\r\t')
    if invalid:
        return False, f"Invalid characters: {invalid}"
    clean = ''.join(c for c in sequence if c in valid_chars)
    if len(clean) < WINDOW_SIZE:
        return False, f"Sequence too short: {len(clean)} < {WINDOW_SIZE} bp"
    return True, ""
def strip_fasta_header(text):
    lines = text.strip().split('\n')
    return ''.join(l for l in lines if not l.startswith('>')).replace(' ', '').replace('\t', '')
def compute_embedding_stats(embedding):
    """Compute statistics that may indicate sequence 'familiarity'."""
    emb = np.array(embedding)
    # L2 norm - magnitude of response
    l2_norm = np.linalg.norm(emb)
    # Mean activation
    mean_act = np.mean(emb)
    # Std - spread of activations
    std_act = np.std(emb)
    # Sparsity - fraction of near-zero activations
    sparsity = np.mean(np.abs(emb) < 0.1)
    # Activation entropy (discretized over 50 bins)
    hist, _ = np.histogram(emb, bins=50, density=True)
    hist = hist[hist > 0]
    entropy = -np.sum(hist * np.log(hist + 1e-10))
    # Excess kurtosis - peakedness (high = more concentrated activations)
    kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3
    return {
        'l2_norm': float(l2_norm),
        'mean': float(mean_act),
        'std': float(std_act),
        'sparsity': float(sparsity),
        'entropy': float(entropy),
        'kurtosis': float(kurtosis)
    }
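# Usage sketch on synthetic data (values illustrative, not from the model):
#   rng = np.random.default_rng(0)
#   stats = compute_embedding_stats(rng.normal(size=768))
#   stats['kurtosis']   # near 0 for Gaussian input (excess kurtosis)
#   stats['sparsity']   # fraction of activations with |x| < 0.1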
def embed_sequence(sequence, mode="mean", stride=100, layer=21):
    """Extract embeddings from sequence."""
    model = get_embedding_model(layer)
    seq_len = len(sequence)
    embeddings = []
    positions = []
    for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
        window = sequence[start:start + WINDOW_SIZE]
        tokens = np.expand_dims(tokenize(window), axis=0)
        emb = model.predict(tokens, verbose=0)
        embeddings.append(emb[0])
        positions.append(start)
    embeddings = np.array(embeddings)  # (n_windows, 1000, 768)
    if mode == "mean":
        window_emb = np.mean(embeddings, axis=1)  # pool over positions within each window
        return np.mean(window_emb, axis=0), window_emb, positions
    elif mode == "max":
        window_emb = np.max(embeddings, axis=1)
        return np.max(window_emb, axis=0), window_emb, positions
    elif mode == "per-window":
        window_emb = np.mean(embeddings, axis=1)
        return window_emb, window_emb, positions
    else:
        # Unknown mode: default to mean pooling.
        window_emb = np.mean(embeddings, axis=1)
        return np.mean(window_emb, axis=0), window_emb, positions
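# Usage sketch (runs the real model; "ACGT" * 500 is a stand-in input):
#   vec, per_window, starts = embed_sequence("ACGT" * 500, mode="mean", stride=100)
#   vec.shape          # (768,): pooled over positions, then over windows
#   per_window.shape   # (n_windows, 768), one row per start offset in `starts`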
# ln(vocab_size=6): surprise if the model predicted uniformly at random.
UNIFORM_SURPRISE = float(np.log(6))

MASK_TOKEN = 0  # PAD/OOV; used as the MLM mask slot
def compute_mlm_surprise(sequence, stride=100, mask_fraction=0.15, seed=42):
    """Per-window and per-base MLM surprise.

    For each sliding window, randomly mask ~mask_fraction of positions, run one
    forward pass through the full model (which ends in a Dense(vocab_size=6)),
    softmax the per-position logits, and take -log(p_true) at the masked
    positions. Returns:
      - per_window: list of (position, mean_surprise)
      - per_base_pos, per_base_vals: flat arrays of (position, surprise) samples,
        one entry per (window × masked_position). Overlapping windows give
        multiple observations per base.
    """
    model = get_base_model()
    tokens = tokenize(sequence)
    seq_len = len(tokens)
    rng = np.random.default_rng(seed)
    n_mask = max(1, int(WINDOW_SIZE * mask_fraction))
    starts = list(range(0, seq_len - WINDOW_SIZE + 1, stride))
    if not starts:
        return [], np.array([]), np.array([])
    # Build all windows + mask sets, run one batched forward pass.
    batch = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    truth = np.zeros((len(starts), WINDOW_SIZE), dtype=np.int32)
    mask_idxs = []
    for i, start in enumerate(starts):
        w = tokens[start:start + WINDOW_SIZE]
        truth[i] = w
        idx = rng.choice(WINDOW_SIZE, size=n_mask, replace=False)
        mask_idxs.append(idx)
        w_masked = w.copy()
        w_masked[idx] = MASK_TOKEN
        batch[i] = w_masked
    logits = model.predict(batch, verbose=0, batch_size=8)  # (n_win, 1000, 6)
    # Numerically stable softmax over the vocab axis.
    logits -= logits.max(axis=-1, keepdims=True)
    exp_l = np.exp(logits)
    probs = exp_l / exp_l.sum(axis=-1, keepdims=True)
    per_window = []
    per_base_pos = []
    per_base_vals = []
    for i, start in enumerate(starts):
        idx = mask_idxs[i]
        p_true = probs[i, idx, truth[i, idx]]
        surprises = -np.log(np.clip(p_true, 1e-10, None))
        per_window.append((start + WINDOW_SIZE // 2, float(surprises.mean())))
        per_base_pos.extend((start + idx).tolist())
        per_base_vals.extend(surprises.tolist())
    return per_window, np.array(per_base_pos), np.array(per_base_vals)
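# Usage sketch (illustrative): on a trivially predictable repeat, the mean
# surprise should sit far below UNIFORM_SURPRISE (ln 6 ~ 1.79 nats).
#   per_win, pos, vals = compute_mlm_surprise("AT" * 600, stride=100)
#   per_win[0]    # (window_center_bp, mean_surprise_nats)
#   vals.mean()   # expected well below UNIFORM_SURPRISE if the mask token is right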
def create_surprise_plot(per_window, per_base_pos, per_base_vals, seq_len):
    """Two-panel Plotly figure: per-window surprise line + per-base scatter."""
    from plotly.subplots import make_subplots
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
        row_heights=[0.55, 0.45],
        subplot_titles=(
            'per-WINDOW mean surprise — one point per 1000-bp window, plotted at its center',
            'per-BASE surprise — one dot per masked base (~15% of bases in each window)',
        ),
    )
    wx = [p for p, _ in per_window]
    wy = [s for _, s in per_window]
    fig.add_trace(go.Scatter(
        x=wx, y=wy, mode='lines+markers',
        line=dict(color='#18181b', width=2), marker=dict(size=7),
        hovertemplate='window center: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
        name='window mean', showlegend=False,
    ), row=1, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        annotation_text=f'uniform-random baseline (ln 6 = {UNIFORM_SURPRISE:.2f})',
        annotation_position='top right', annotation_font=dict(size=10, color='#71717a'),
        row=1, col=1,
    )
    fig.add_trace(go.Scatter(
        x=per_base_pos, y=per_base_vals, mode='markers',
        marker=dict(size=4, color=per_base_vals, colorscale='Reds',
                    cmin=0, cmax=UNIFORM_SURPRISE,
                    colorbar=dict(title=dict(text='surprise<br>(nats)', font=dict(size=10)),
                                  thickness=10, len=0.4, y=0.2, tickfont=dict(size=9))),
        hovertemplate='base position: %{x} bp<br>surprise: %{y:.3f} nats<extra></extra>',
        name='per base', showlegend=False,
    ), row=2, col=1)
    fig.add_hline(
        y=UNIFORM_SURPRISE, line_dash='dash', line_color='#a1a1aa',
        row=2, col=1,
    )
    fig.update_xaxes(title_text='position along input sequence (bp)', row=2, col=1,
                     range=[0, seq_len])
    fig.update_xaxes(range=[0, seq_len], row=1, col=1)
    fig.update_yaxes(title_text='surprise (nats)', row=1, col=1, rangemode='tozero')
    fig.update_yaxes(title_text='surprise (nats)', row=2, col=1, rangemode='tozero')
    fig.update_layout(height=560, margin=dict(l=60, r=20, t=70, b=60))
    for ann in fig['layout']['annotations']:
        if 'font' not in ann:
            ann['font'] = dict(size=11)
    return fig
def create_embedding_heatmap(embedding, title="Embedding"):
    """Create a heatmap of a single embedding vector."""
    embedding = np.array(embedding)
    n_dims = len(embedding)
    cols = 32
    rows = int(np.ceil(n_dims / cols))
    padded = np.full(rows * cols, np.nan)
    padded[:n_dims] = embedding
    grid = padded.reshape(rows, cols)
    finite = embedding[np.isfinite(embedding)]
    vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
    fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
    im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
    plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
    ax.set_xlabel('Dimension')
    ax.set_ylabel('Row')
    ax.set_title(f'{title} ({n_dims} dims)')
    ax.set_xticks(np.arange(0, cols, 8))
    plt.tight_layout()
    return fig
def create_trajectory_plot(window_embeddings, positions):
    """Create interactive trajectory heatmap."""
    emb = np.array(window_embeddings)
    n_windows, n_dims = emb.shape
    # Subsample dimensions
    step = max(1, n_dims // 100)
    emb_sub = emb[:, ::step]
    vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
    fig = go.Figure(go.Heatmap(
        z=emb_sub,
        x=list(range(emb_sub.shape[1])),
        y=[f"{p}" for p in positions],
        colorscale='RdBu_r',
        zmin=-vmax, zmax=vmax,
        colorbar=dict(title='Act.'),
        hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}<extra></extra>'
    ))
    fig.update_layout(
        xaxis=dict(title='Dimension' + (' (subsampled)' if step > 1 else '')),
        yaxis=dict(title='Window start (bp)'),
        height=max(350, n_windows * 15 + 100),
        margin=dict(l=60, r=20, t=30, b=50)
    )
    return fig
def create_familiarity_plot(window_embeddings, positions):
    """Per-window L2 norm + novelty (cosine distance to sequence mean) along the sequence.

    High L2 norm = strong response. High novelty = window looks different from the rest
    of the sequence (the model's internal 'surprise' relative to the sequence average).
    """
    from plotly.subplots import make_subplots
    emb = np.array(window_embeddings)
    n_windows = emb.shape[0]
    l2 = np.linalg.norm(emb, axis=1)
    mean_vec = emb.mean(axis=0)
    mean_norm = np.linalg.norm(mean_vec) + 1e-10
    cos_sim = (emb @ mean_vec) / (l2 * mean_norm + 1e-10)
    novelty = 1.0 - cos_sim
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.14,
        subplot_titles=(
            'per-WINDOW L2 norm of embedding (activation magnitude at layer L)',
            'per-WINDOW embedding novelty (1 − cos similarity to sequence-mean embedding)',
        ),
    )
    fig.add_trace(go.Scatter(
        x=positions, y=l2, mode='lines+markers',
        line=dict(color='#3b82f6', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>L2: %{y:.2f}<extra></extra>', showlegend=False
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=positions, y=novelty, mode='lines+markers',
        line=dict(color='#ef4444', width=2), marker=dict(size=6),
        hovertemplate='window start: %{x} bp<br>novelty: %{y:.3f}<extra></extra>', showlegend=False
    ), row=2, col=1)
    fig.update_xaxes(title_text='window start (bp along input sequence)', row=2, col=1)
    fig.update_yaxes(title_text='L2 norm', row=1, col=1)
    fig.update_yaxes(title_text='1 − cos sim', row=2, col=1)
    fig.update_layout(
        height=380 if n_windows > 1 else 260,
        margin=dict(l=60, r=20, t=50, b=50),
    )
    for ann in fig['layout']['annotations']:
        ann['font'] = dict(size=11)
    return fig
def create_dimension_plot(window_embeddings, positions, top_k=8):
    """Show top varying dimensions."""
    emb = np.array(window_embeddings)
    variances = np.var(emb, axis=0)
    top_dims = np.argsort(variances)[-top_k:][::-1]
    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
              '#ff7f00', '#a65628', '#f781bf', '#999999']
    fig = go.Figure()
    for i, dim in enumerate(top_dims):
        fig.add_trace(go.Scatter(
            x=positions, y=emb[:, dim],
            mode='lines', name=f'd{dim}',
            line=dict(color=colors[i % len(colors)], width=1.5)
        ))
    fig.update_layout(
        xaxis=dict(title='Position (bp)'),
        yaxis=dict(title='Activation'),
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=50, r=20, t=40, b=50)
    )
    return fig
# Example sequence: ~3 kb slice of E. coli K-12 MG1655 around the lac operon
# (NC_000913.3, positions 365529-368600). Covers the lac repressor binding region,
# the lacZ gene, and flanking regulatory sequence, so per-window plots show
# real biological structure transitions.
EXAMPLE_SEQUENCE = (
    "AACTGTTACCCGTAGGTAGTCACGCAACTCGCCGCACATCTGAACTTCAGCCTCCAGTACAGCGCGGCTGAA"
    "ATCATCATTAAAGCGAGTGGCAACATGGAAATCGCTGATTTGTGTAGTCGGTTTATGCAGCAACGAGACGTC"
    "ACGGAAAATGCCGCTCATCCGCCACATATCCTGATCTTCCAGATAACTGCCGTCACTCCAGCGCAGCACCAT"
    "CACCGCGAGGCGGTTTTCTCCGGCGCGTAAAAATGCGCTCAGGTCAAATTCAGACGGCAAACGACTGTCCTG"
    "GCCGTAACCGACCCAGCGCCCGTTGCACCACAGATGAAACGCCGAGTTAACGCCATCAAAAATAATTCGCGT"
    "CTGGCCTTCCTGTAGCCAGCTTTCATCAACATTAAATGTGAGCGAGTAACAACCCGTCGGATTCTCCGTGGG"
    "AACAAACGGCGGATTGACCGTAATGGGATAGGTCACGTTGGTGTAGATGGGCGCATCGTAACCGTGCATCTG"
    "CCAGTTTGAGGGGACGACGACAGTATCGGCCTCAGGAAGATCGCACTCCAGCCAGCTTTCCGGCACCGCTTC"
    "TGGTGCCGGAAACCAGGCAAAGCGCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCG"
    "GGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGG"
    "GTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATCCGTAATCATGGTCATAGCTGTTTCCTGTGT"
    "GAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCT"
    "AATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCC"
    "AGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCCAGGGTGGTTTTTCTT"
    "TTCACCAGTGAGACGGGCAACAGCTGATTGCCCTTCACCGCCTGGCCCTGAGAGAGTTGCAGCAAGCGGTCC"
    "ACGCTGGTTTGCCCCAGCAGGCGAAAATCCTGTTTGATGGTGGTTAACGGCGGGATATAACATGAGCTGTCT"
    "TCGGTATCGTCGTATCCCACTACCGAGATATCCGCACCAACGCGCAGCCCGGACTCGGTAATGGCGCGCATT"
    "GCGCCCAGCGCCATCTGATCGTTGGCAACCAGCATCGCAGTGGGAACGATGCCCTCATTCAGCATTTGCATG"
    "GTTTGTTGAAAACCGGACATGGCACTCCAGTCGCCTTCCCGTTCCGCTATCGGCTGAATTTGATTGCGAGTG"
    "AGATATTTATGCCAGCCAGCCAGACGCAGACGCGCCGAGACAGAACTTAATGGGCCCGCTAACAGCGCGATT"
    "TGCTGGTGACCCAATGCGACCAGATGCTCCACGCCCAGTCGCGTACCGTCTTCATGGGAGAAAATAATACTG"
    "TTGATGGGTGTCTGGTCAGAGACATCAAGAAATAACGCCGGAACATTAGTGCAGGCAGCTTCCACAGCAATG"
    "GCATCCTGGTCATCCAGCGGATAGTTAATGATCAGCCCACTGACGCGTTGCGCGAGAAGATTGTGCACCGCC"
    "GCTTTACAGGCTTCGACGCCGCTTCGTTCTACCATCGACACCACCACGCTGGCACCCAGTTGATCGGCGCGA"
    "GATTTAATCGCCGCGACAATTTGCGACGGCGCGTGCAGGGCCAGACTGGAGGTGGCAACGCCAATCAGCAAC"
    "GACTGTTTGCCCGCCAGTTGTTGTGCCACGCGGTTGGGAATGTAATTCAGCTCCGCCATCGCCGCTTCCACT"
    "TTTTCCCGCGTTTTCGCAGAAACGTGGCTGGCCTGGTTCACCACGCGGGAAACGGTCTGATAAGAGACACCG"
    "GCATACTCTGCGACATCGTATAACGTTACTGGTTTCACATTCACCACCCTGAATTGACTCTCTTCCGGGCGC"
    "TATCATGCCATACCGCGAAAGGTTTTGCGCCATTCGATGGTGTCAACGTAAATGCATGCCGCTTCGCCTTCC"
    "GGCCACCAGAATAGCCTGCGATTCAACCCCTTCTTCGATCTGTTTTGCTACCCGTTGTAGCGCCGGAAGATG"
    "CTTTTCCGCTGCCTGTTCAATGGTCATTGCGCTCGCCATATACACCAGATTCAGACAGCCAATCACCCGTTG"
    "TTCACTGCGCAGCGGTACGGCGATAGAGGCGATCTTCTCCTCCTGATCCCAGCCGCGGTAGTTCTGTCCGTA"
    "ACCCTCTTTGCGCGCGCGCGCCAGAATGGCTTCCAGCTTTAACGGTTCCCGTGCCAGTTGATAGTCATCACC"
    "GGGGCGGGAGGCTAACATTTCGATTAATTCCTTGCGGTCTTGTTCCGGGCAAAAGGCCAGCCAGGTCAGGCC"
    "CGAGGCGGTTTTCAGAAGCGGCAAACGTCGCCCGACCATTGCCCGGTGAAAGGATAAGCGGCTGAAACGGTG"
    "AGTGGTTTCGCGTACCACCATTGCATCAACATCCAGCGTGGACACATCTGTCGGCCATACCACTTCGCGCAA"
    "CAGATCGCCCAGCAGTGGGGCCGCCAGTGCAGAAATCCACTGTTCGTCACGAAATCCTTCGCTTAATTGCCG"
    "CACTTTGATGGTCAGTCGAAAACTATCATCGGAGGGGCTACGGCGGACATATCCCTCTTCCTGCAGCGTCTC"
    "CAGCAGTCGCCGCACAGTGGTGCGATGCAGGCCGCTGAGTTCCGCCAGCAGCCCGACGCTGGCACCGCCATC"
    "AAGTTTATTTAACATATTTAATAACATTAGACCGCGGGTTAAGCCGCGCACGGTTTTGTATTCCGTCTGCTC"
    "ATTGTTCTGCATATTAATTGACATTTCTATAGTTAAAACAACGTGGTGCACCTGGTGCACATTCGGGCATGT"
    "TTTGATTGTAGCCGAAAACACCCTTCCTATACTGAGCGCACAATAAAAAATCATTTACATGTTTTTAACAAA"
    "ATAAGTTGCGCTGTACTGTGCGCGCAACGACATTTTGTCCGAGTCGTG"
)
# Highly conserved example: ~3 kb slice across E. coli K-12 MG1655 rrnB operon
# (NC_000913.3:4035531-4038602), which contains the 16S rRNA gene (rrsB) and
# flanking regulatory regions. rRNA genes are among the most conserved in
# bacteria, so MLM surprise should be visibly lower than on lacZ, and the
# per-window embedding plot should be flatter (one coherent functional region).
EXAMPLE_16S = (
    "AAATTGAAGAGTTTGATCATGGCTCAGATTGAACGCTGGCGGCAGGCCTAACACATGCAAGTCGAACGGTAA"
    "CAGGAAGAAGCTTGCTTCTTTGCTGACGAGTGGCGGACGGGTGAGTAATGTCTGGGAAACTGCCTGATGGAG"
    "GGGGATAACTACTGGAAACGGTAGCTAATACCGCATAACGTCGCAAGACCAAAGAGGGGGACCTTCGGGCCT"
    "CTTGCCATCGGATGTGCCCAGATGGGATTAGCTAGTAGGTGGGGTAACGGCTCACCTAGGCGACGATCCCTA"
    "GCTGGTCTGAGAGGATGACCAGCCACACTGGAACTGAGACACGGTCCAGACTCCTACGGGAGGCAGCAGTGG"
    "GGAATATTGCACAATGGGCGCAAGCCTGATGCAGCCATGCCGCGTGTATGAAGAAGGCCTTCGGGTTGTAAA"
    "GTACTTTCAGCGGGGAGGAAGGGAGTAAAGTTAATACCTTTGCTCATTGACGTTACCCGCAGAAGAAGCACC"
    "GGCTAACTCCGTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAATTACTGGGCGTAAAGC"
    "GCACGCAGGCGGTTTGTTAAGTCAGATGTGAAATCCCCGGGCTCAACCTGGGAACTGCATCTGATACTGGCA"
    "AGCTTGAGTCTCGTAGAGGGGGGTAGAATTCCAGGTGTAGCGGTGAAATGCGTAGAGATCTGGAGGAATACC"
    "GGTGGCGAAGGCGGCCCCCTGGACGAAGACTGACGCTCAGGTGCGAAAGCGTGGGGAGCAAACAGGATTAGA"
    "TACCCTGGTAGTCCACGCCGTAAACGATGTCGACTTGGAGGTTGTGCCCTTGAGGCGTGGCTTCCGGAGCTA"
    "ACGCGTTAAGTCGACCGCCTGGGGAGTACGGCCGCAAGGTTAAAACTCAAATGAATTGACGGGGGCCCGCAC"
    "AAGCGGTGGAGCATGTGGTTTAATTCGATGCAACGCGAAGAACCTTACCTGGTCTTGACATCCACGGAAGTT"
    "TTCAGAGATGAGAATGTGCCTTCGGGAACCGTGAGACAGGTGCTGCATGGCTGTCGTCAGCTCGTGTTGTGA"
    "AATGTTGGGTTAAGTCCCGCAACGAGCGCAACCCTTATCCTTTGTTGCCAGCGGTCCGGCCGGGAACTCAAA"
    "GGAGACTGCCAGTGATAAACTGGAGGAAGGTGGGGATGACGTCAAGTCATCATGGCCCTTACGACCAGGGCT"
    "ACACACGTGCTACAATGGCGCATACAAAGAGAAGCGACCTCGCGAGAGCAAGCGGACCTCATAAAGTGCGTC"
    "GTAGTCCGGATTGGAGTCTGCAACTCGACTCCATGAAGTCGGAATCGCTAGTAATCGTGGATCAGAATGCCA"
    "CGGTGAATACGTTCCCGGGCCTTGTACACACCGCCCGTCACACCATGGGAGTGGGTTGCAAAAGAAGTAGGT"
    "AGCTTAACCTTCGGGAGGGCGCTTACCACTTTGTGATTCATGACTGGGGTGAAGTCGTAACAAGGTAACCGT"
    "AGGGGAACCTGCGGTTGGATCACCTCCTTACCTCAGAACAAGAAGTATACCAAACTCAACGGAGTTACCACC"
    "CGGTGGCAATTGCACAATTGATCATCAGAGAGAACCAATGCTGAACATCAAGAACAATGAAGGCCAGCAAGG"
    "TAATCCGAGCAGAGCTAAAGAGAACGGATACCCATGACCACGGAAGTGGTCAAGAATATAGGCATCCAAAGA"
    "CGATCAGATACCGTCGTAGTTCCGACCATAAACGATGAAGAACACGTCCATATAGCCAAGCTCCGTAGGACA"
    "AGAATAATGAGACAAAACACAAAGCCAACAATGAGCCCTAAGTGATGTCCGGGGAAACCAGAAAGACCCGTA"
    "GCCTGAAAGATTGCCGGCCACTTGGAACGCTGGATTGAGCACCCTGTAGAACATTTGTTTGAACAGGTGCGG"
    "ACCGAATAAAGCCACATGATGCAACTATGAATCTGAACTTGCAATGCTGAACGAATCGCGATAAACCTAAGG"
    "CAGAAGCGTACCCGGGAACATCAATAGACTGCGATGTGATAACGTACCCAAACTTATCCCAGGGCCCGTAAC"
    "TAAACTGCCCCTTTGCGCTTCGAGTAAAGGCATCAAATAGATATAGACTCATAATGCCACAGTCCAATTACA"
    "TGCCCGGAAGTTATTAATACTGCGAACGTTATACATACGAAGCCGTAAGGTAATTTGATAAGCGTAACCGAT"
    "AGCCCCGACAGCGAACTAGCAACCTTGGAGTATATGAACCCAAATATCTGTGAGGCCTGGAACGTCCGAGAT"
    "GAGAGTGCCACATACTCAAGACTCAAAGTCACCCGAAGGGAATTTGCATATGAGCTCGTCTGGCCAGGAGTT"
    "TTAAGAGGGGCGCAGATATCACCTAATACGATAGCTAGCCGAATGCACTACGCCAACCATCTAACGGACAGA"
    "GTAATGAACCACAAGCTCCGAAATGATGCTGAGAGACGCATGGCCAGTTCTCATCAGCCGTCGTGGTCAATC"
    "GGTGCTTGGCCATCACCATGGGGGCCCGCATCTGCCATCGACAGCGCTTTCATCGTAAACCGTCTTATGGAA"
    "AGACATTACAGCCAGTGTAAAATCCCGCACACTATTAGCCATCAAATCATATAAGGCATACGGTCAGTCAGT"
    "ATTCCGAAAGAACACCACCAGTGATAGTACCAAGAGCACGTATGAATACGATGCCGACCATAGCGGACAAAT"
    "CTCCCAATACGAGAGTAAAATAAGCAAATAATAGATATCCATGCATGGAGTCACCACAATAGAGCGCTACGT"
    "CGTCGTGAAGAGGGAAACAACCCAGACCGCCAGCTAAGGTCCCAAAGTCATGGTTAAGTGGGAAACGATGTG"
    "GGAAGGCCCAGACAGCCAGGATGTTGGCTTAGAAGCAGCCATCATTTA"
)
# Random DNA: uniform ACGT, 3 kb. Should be ~uninformative — model surprise
# should be close to the uniform baseline (ln 6 ≈ 1.79 nats) because there
# is no context-dependent pattern to learn.
_rng = np.random.default_rng(42)
EXAMPLE_RANDOM = "".join(_rng.choice(list("ACGT"), size=3000))

# Low-complexity repeat: 3 kb of AT dinucleotide repeats. Should give very
# LOW surprise — the model trivially predicts the next base from context.
EXAMPLE_REPEAT = "AT" * 1500
def process(sequence: str, mode: str, stride: int, layer: int):
    """Main processing function."""
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None, None, None, None, None
    embedding, window_embeddings, positions = embed_sequence(
        sequence, mode=mode, stride=stride, layer=layer
    )
    # Save embedding
    path = os.path.join(tempfile.gettempdir(), "embedding.npy")
    np.save(path, embedding)
    # Compute stats
    if mode == "per-window":
        # For per-window, compute stats on the mean embedding
        mean_emb = np.mean(embedding, axis=0)
        stats = compute_embedding_stats(mean_emb)
    else:
        stats = compute_embedding_stats(embedding)
    # Create summary
    if mode == "per-window":
        summary = f"""### Results

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| windows | {embedding.shape[0]} |
| shape | {embedding.shape} |

**Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}
"""
    else:
        summary = f"""### Results

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| mode | {mode} |
| dim | {len(embedding)} |

**Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%}
"""
    # Create visualizations
    heatmap_fig = None
    if mode != "per-window":
        heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}")
    multi_window = len(window_embeddings) > 1
    trajectory_fig = create_trajectory_plot(window_embeddings, positions) if multi_window else None
    familiarity_fig = create_familiarity_plot(window_embeddings, positions) if multi_window else None
    dims_fig = create_dimension_plot(window_embeddings, positions) if multi_window else None
    return summary, path, heatmap_fig, trajectory_fig, familiarity_fig, dims_fig
def process_surprise(sequence: str, stride: int, mask_fraction: float):
    """Compute MLM surprise across the sequence."""
    sequence = strip_fasta_header(sequence.strip())
    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None
    per_window, per_base_pos, per_base_vals = compute_mlm_surprise(
        sequence, stride=stride, mask_fraction=mask_fraction
    )
    if not per_window:
        return "**Error**: sequence too short for one window", None
    fig = create_surprise_plot(per_window, per_base_pos, per_base_vals, len(sequence))
    w_vals = np.array([s for _, s in per_window])
    lo_pos, lo_val = per_window[int(np.argmin(w_vals))]
    hi_pos, hi_val = per_window[int(np.argmax(w_vals))]
    summary = f"""### MLM surprise

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| windows | {len(per_window)} |
| mask fraction | {mask_fraction:.0%} |
| mean surprise | {w_vals.mean():.3f} nats |
| uniform baseline | {UNIFORM_SURPRISE:.3f} nats (ln 6) |
| most predictable window | {lo_val:.3f} nats @ ~{lo_pos:,} bp |
| most surprising window | {hi_val:.3f} nats @ ~{hi_pos:,} bp |

Lower = model confidently predicts the true base → conserved/typical pattern.
Higher = model is unsure → unusual region relative to training distribution.
"""
    return summary, fig
# Build interface. Note: theme belongs on gr.Blocks(), not on launch().
with gr.Blocks(
    title="BERT Metagenome Embeddings",
    theme=gr.themes.Base(
        primary_hue=gr.themes.colors.zinc,
        neutral_hue=gr.themes.colors.zinc,
    ),
    css=".gradio-container { max-width: 100% !important; }"
) as demo:
    gr.Markdown(
        "# bert-embedding\n"
        "BERT (24 layers, 430M params) pretrained with masked-language modeling "
        "on metagenomic contigs. Input: a DNA sequence ≥ 1000 bp. The model slides a "
        "1000 bp window over your sequence and produces two kinds of output — "
        "**embeddings** (a 768-dim hidden-state vector per window, under *Extract*) "
        "and **MLM surprise** (the model's confidence in reconstructing masked bases, "
        "under *MLM surprise*). See the *Guide* tab for how to read the plots."
    )
| with gr.Tab("Extract"): | |
| gr.Markdown( | |
| "Extract 768-dim hidden-state embeddings from a chosen transformer layer. " | |
| "Produces **per-window** vectors (one per 1000-bp window) that can be pooled, " | |
| "visualised as heatmaps, and compared via cosine similarity. " | |
| "The per-window plot below shows how the embedding *changes along the sequence* — " | |
| "it does **not** say anything about whether bases are predictable " | |
| "(that's the MLM surprise tab)." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=260): | |
| seq_input = gr.Textbox( | |
| label="sequence (≥ 1000 bp)", | |
| placeholder="Paste DNA (FASTA or raw)...", | |
| lines=8, | |
| value=EXAMPLE_SEQUENCE | |
| ) | |
| gr.Markdown("**load an example:**") | |
| with gr.Row(): | |
| ex_lacz = gr.Button("lacZ (mixed)", size="sm") | |
| ex_16s = gr.Button("16S rRNA (conserved)", size="sm") | |
| with gr.Row(): | |
| ex_rand = gr.Button("random DNA", size="sm") | |
| ex_rep = gr.Button("AT repeat (low complexity)", size="sm") | |
| mode_input = gr.Radio( | |
| choices=["mean", "max", "per-window"], | |
| value="mean", label="pooling", | |
| info="how to collapse positions within each window" | |
| ) | |
| layer_input = gr.Slider(0, 23, value=21, step=1, label="layer (0=shallow, 23=deep)") | |
| stride_input = gr.Slider(50, 500, value=100, step=50, label="stride (bp)", | |
| info="step between windows. lower = more windows, more compute") | |
| btn = gr.Button("extract", variant="primary") | |
| with gr.Column(scale=3, min_width=500): | |
| output = gr.Markdown() | |
| download = gr.File(label="download embedding (.npy)") | |
| gr.Markdown( | |
| "**Per-window plot below.** x-axis = window-start position along your input " | |
| "sequence. Each point is one 1000-bp window. " | |
| "*L2 norm* = how strongly the model's neurons fire on this window (bigger = " | |
| "stronger activation). *Novelty* = how different this window's embedding is " | |
| "from the average embedding across your sequence (1 − cos sim to mean); " | |
| "peaks = regions that stand out from the rest of **this** sequence." | |
| ) | |
| familiarity_plot = gr.Plot(label="per-window L2 norm & embedding-space novelty along your input sequence") | |
| gr.Markdown( | |
| "**Trajectory** (left): heatmap of (windows × ~100 subsampled embedding " | |
| "dimensions). Sharp horizontal bands = sudden embedding change → boundary. " | |
| "**Top varying dims** (right): the 8 dimensions that vary most across windows, " | |
| "plotted vs window position." | |
| ) | |
| with gr.Row(): | |
| trajectory_plot = gr.Plot(label="window × dimension heatmap (trajectory)") | |
| dims_plot = gr.Plot(label="top-8 most variable dimensions vs window position") | |
| gr.Markdown( | |
| "**Pooled embedding heatmap** (below): the single 768-dim vector after " | |
| "pooling across windows — only shown in `mean` / `max` mode. " | |
| "Red = positive activation, blue = negative." | |
| ) | |
| heatmap_plot = gr.Plot(label="pooled 768-dim embedding heatmap (24 × 32 grid)") | |
| btn.click( | |
| process, | |
| inputs=[seq_input, mode_input, stride_input, layer_input], | |
| outputs=[output, download, heatmap_plot, trajectory_plot, familiarity_plot, dims_plot], | |
| api_name="embed" | |
| ) | |
| ex_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=seq_input) | |
| ex_16s.click(lambda: EXAMPLE_16S, outputs=seq_input) | |
| ex_rand.click(lambda: EXAMPLE_RANDOM, outputs=seq_input) | |
| ex_rep.click(lambda: EXAMPLE_REPEAT, outputs=seq_input) | |
| with gr.Tab("MLM surprise"): | |
| gr.Markdown( | |
| "**What it does.** For each 1000-bp window, we randomly replace ~15% of bases " | |
| "with a mask token, ask the model to predict what was there, and record " | |
| "**−log p(true base)** at those positions. Low = model is confident (pattern " | |
| "matches training distribution). High = model is uncertain (unusual region).\n\n" | |
| "**Unit**: nats. Uniform-random guessing over {A,C,G,T,N,PAD} = **ln 6 ≈ 1.79 nats**. " | |
| "A perfectly confident correct prediction = **0 nats**.\n\n" | |
| "**Expected shape on the examples below**:\n" | |
| "- *lacZ* → moderate (~0.6–1.3 nats), with visible structure between CDS and UTR.\n" | |
| "- *16S rRNA* → lowest values, often flat across the whole region (highly conserved).\n" | |
| "- *random DNA* → near the 1.79 baseline, no trend (by construction unpredictable).\n" | |
| "- *AT repeat* → near 0 (trivial to predict once you see a few bases)." | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=260): | |
| surp_seq = gr.Textbox( | |
| label="sequence (≥ 1000 bp)", | |
| placeholder="Paste DNA (FASTA or raw)...", | |
| lines=8, | |
| value=EXAMPLE_SEQUENCE, | |
| ) | |
| gr.Markdown("**load an example:**") | |
| with gr.Row(): | |
| sx_lacz = gr.Button("lacZ (mixed)", size="sm") | |
| sx_16s = gr.Button("16S rRNA (conserved)", size="sm") | |
| with gr.Row(): | |
| sx_rand = gr.Button("random DNA", size="sm") | |
| sx_rep = gr.Button("AT repeat (low complexity)", size="sm") | |
| surp_stride = gr.Slider(50, 500, value=100, step=50, label="stride (bp)", | |
| info="step between windows. lower = more windows") | |
| surp_mask = gr.Slider(0.05, 0.5, value=0.15, step=0.05, | |
| label="mask fraction", | |
| info="fraction of positions masked in each window (0.15 matches BERT training)") | |
| surp_btn = gr.Button("score", variant="primary") | |
| with gr.Column(scale=3, min_width=500): | |
| surp_summary = gr.Markdown() | |
| surp_plot = gr.Plot(label="MLM surprise along the input sequence") | |
| surp_btn.click( | |
| process_surprise, | |
| inputs=[surp_seq, surp_stride, surp_mask], | |
| outputs=[surp_summary, surp_plot], | |
| api_name="surprise", | |
| ) | |
| sx_lacz.click(lambda: EXAMPLE_SEQUENCE, outputs=surp_seq) | |
| sx_16s.click(lambda: EXAMPLE_16S, outputs=surp_seq) | |
| sx_rand.click(lambda: EXAMPLE_RANDOM, outputs=surp_seq) | |
| sx_rep.click(lambda: EXAMPLE_REPEAT, outputs=surp_seq) | |
| with gr.Tab("Guide"): | |
| gr.Markdown(""" | |
| ### What this space actually does | |
| The model is a BERT trained to predict masked bases in DNA. Two things come out of it: | |
| | | **Extract (embeddings)** | **MLM surprise** | | |
| |---|---|---| | |
| | What | 768-dim hidden-state vector per window (layer L) | −log p(true base) at masked positions | | |
| | Y-axis | L2 norm / cosine distance to sequence mean | nats | | |
| | Measures | how the *representation* changes along the sequence | how *predictable* the bases are | | |
| | Depends on training data? | indirectly (via what the model learned to represent) | yes — bases typical of training data get low surprise | | |
| | "Unusual" means | this window's embedding differs from the rest of *this* sequence | this base is hard to predict from context, relative to what the model saw in pretraining | | |
| They look similar (both have spiky plots over position) but answer **different questions**. A | |
| region can be *embedding-novel* (structurally unlike the rest of your sequence) without being | |
| *MLM-surprising* (it could still be a predictable pattern the model knows). And vice versa. | |
| ### Reading the per-window plots | |
| - **x-axis** on every per-window plot = position along your input sequence, in bp. | |
| Each point = one 1000-bp window, plotted at the window's start (Extract tab) or center | |
| (MLM surprise tab). With `stride=100`, successive points are 100 bp apart and **overlap | |
| by 900 bp** — so plots are smoother than they look, and a single isolated peak is usually | |
| real, not noise. | |
| - **y-axis units**: | |
| - Extract → L2 norm (unitless magnitude of the 768-dim vector) and cosine-distance 1−cos. | |
| - MLM surprise → nats (natural-log likelihood). 0 = perfectly confident, ln 6 ≈ 1.79 = uniform. | |
| ### How the examples should look | |
| - **lacZ (~3 kb of E. coli)** — mixed content (regulatory + coding). MLM surprise will have | |
| peaks and troughs reflecting CDS vs UTR structure. Embedding novelty has some variation. | |
| - **16S rRNA (~3 kb of rrnB)** — highly conserved, functional RNA. MLM surprise should be | |
| visibly *lower* than on lacZ, often flat. Embedding trajectory looks like one consistent | |
| region. | |
| - **random DNA (3 kb, uniform ACGT)** — no pattern. MLM surprise should sit near the ln 6 | |
| baseline with no trend. Embedding novelty will look noisy. | |
| - **AT repeat (3 kb of ATATAT…)** — trivially predictable. MLM surprise should be very low | |
| (~0) across the whole sequence. | |
| ### Caveats | |
| - MLM surprise depends on which token we use as [MASK]. This space uses token `0` (PAD/OOV). | |
| If the original pretraining used a different sentinel, scores will be uniformly *pessimistic* | |
| (near the 1.79 baseline on everything). If you see that on the conserved 16S example, the | |
| mask token is wrong and we need to switch to `5` (N/AMB). | |
| - Only 15% of each window is scored per run. Points in the per-base scatter are sampled, not | |
| exhaustive — but because windows overlap 10× at stride 100, every base usually has several | |
| observations. | |
| - Both metrics are relative. Absolute values mean little; compare across regions of the same | |
| sequence, or across the example sequences above. | |
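
### Recomputing novelty yourself

A minimal sketch of the Extract-tab novelty metric, assuming you saved a `per-window`
embedding from the `/embed` endpoint (file name is whatever you downloaded):

```python
import numpy as np

emb = np.load("embedding.npy")            # (n_windows, 768) from per-window mode
l2 = np.linalg.norm(emb, axis=1)          # per-window activation magnitude
mean_vec = emb.mean(axis=0)               # sequence-mean embedding
cos_sim = (emb @ mean_vec) / (l2 * np.linalg.norm(mean_vec) + 1e-10)
novelty = 1.0 - cos_sim                   # the quantity in the novelty panel
```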
| """) | |
| with gr.Tab("API"): | |
| gr.Markdown(""" | |
| ### API | |
| ```python | |
| from gradio_client import Client | |
| import numpy as np | |
| client = Client("genomenet/bert-embedding") | |
| result = client.predict( | |
| sequence="ATGC...", # min 1000 bp | |
| mode="mean", # mean/max/per-window | |
| stride=100, | |
| layer=21, # 0-23 | |
| api_name="/embed" | |
| ) | |
| summary, emb_path, *plots = result | |
| embedding = np.load(emb_path) | |
| ``` | |
| **Per-window plots** (along sequence position): | |
| - **L2 norm**: activation magnitude — high = strong, structured response. | |
| - **Novelty** (1 − cosine similarity to mean embedding): how much the window differs | |
| from the rest of the sequence. Spikes = unusual regions relative to context. | |
| Numeric stats (L2, entropy, sparsity, kurtosis) are in the summary text. | |
| ### MLM surprise endpoint | |
| ```python | |
| summary, plot = client.predict( | |
| sequence="ATGC...", | |
| stride=100, | |
| mask_fraction=0.15, | |
| api_name="/surprise", | |
| ) | |
| ``` | |
| Returns per-window mean `-log(p_true)` at masked positions (in nats). | |
| Uniform-random baseline is `ln(6) ≈ 1.79 nats`. | |
| """) | |
| with gr.Tab("About"): | |
| gr.Markdown(""" | |
| ### Model | |
| | | | | |
| |---|---| | |
| | architecture | BERT, 24 layers, 768 hidden, 12 heads | | |
| | parameters | ~430M | | |
| | input | 1000 bp sliding window | | |
| | pretraining | metagenomic contigs + microbial genomes | | |
| ### Interpreting Statistics | |
| The embedding statistics provide indirect measures of how the model "responds" to a sequence: | |
| - **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences. | |
| - **Entropy**: How spread out the activations are. Lower entropy suggests more confident/structured representation. | |
| - **Sparsity**: Fraction of dimensions with near-zero activation. | |
| - **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations. | |
| **Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across | |
| different sequence types may reveal what the model considers typical vs. unusual. | |
| ### Links | |
| - Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome) | |
| - CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection) | |
| """) | |
if __name__ == "__main__":
    print("Loading model...")
    _ = get_base_model()
    print(f"Ready! {get_gpu_status()}")
    # theme is set on gr.Blocks() above; launch() does not accept a theme kwarg.
    demo.launch(server_name="0.0.0.0", server_port=7860)