Spaces:
Sleeping
Sleeping
Add embedding stats, full-width layout, 3-column design
Browse files
app.py
CHANGED
|
@@ -13,9 +13,7 @@ from huggingface_hub import hf_hub_download
|
|
| 13 |
import matplotlib
|
| 14 |
matplotlib.use('Agg')
|
| 15 |
import matplotlib.pyplot as plt
|
| 16 |
-
from matplotlib.colors import TwoSlopeNorm
|
| 17 |
import plotly.graph_objects as go
|
| 18 |
-
from plotly.subplots import make_subplots
|
| 19 |
|
| 20 |
from custom_layers import get_custom_objects
|
| 21 |
|
|
@@ -23,12 +21,12 @@ from custom_layers import get_custom_objects
|
|
| 23 |
MODEL_REPO = "genomenet/bert-metagenome"
|
| 24 |
MODEL_FILE = "bert_1k_3.h5"
|
| 25 |
WINDOW_SIZE = 1000
|
| 26 |
-
NUM_LAYERS = 24
|
| 27 |
EMBEDDING_DIM = 768
|
| 28 |
|
| 29 |
# Singleton model cache
|
| 30 |
_model = None
|
| 31 |
-
_embedding_models = {}
|
| 32 |
|
| 33 |
def get_base_model():
|
| 34 |
"""Load and cache the base model."""
|
|
@@ -39,6 +37,8 @@ def get_base_model():
|
|
| 39 |
print(f"Loading model from {model_path}...")
|
| 40 |
_model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
|
| 41 |
print("Model loaded.")
|
|
|
|
|
|
|
| 42 |
return _model
|
| 43 |
|
| 44 |
def get_embedding_model(layer_idx=21):
|
|
@@ -53,7 +53,6 @@ def get_embedding_model(layer_idx=21):
|
|
| 53 |
outputs=model.get_layer(layer_name).output
|
| 54 |
)
|
| 55 |
except ValueError:
|
| 56 |
-
# Fallback to layer 21 if requested layer not found
|
| 57 |
_embedding_models[layer_idx] = tf.keras.Model(
|
| 58 |
inputs=model.input,
|
| 59 |
outputs=model.get_layer("layer_transformer_block_21").output
|
|
@@ -62,25 +61,14 @@ def get_embedding_model(layer_idx=21):
|
|
| 62 |
|
| 63 |
def get_gpu_status():
|
| 64 |
gpus = tf.config.list_physical_devices('GPU')
|
| 65 |
-
if gpus
|
| 66 |
-
return f"GPU: {gpus[0].name}"
|
| 67 |
-
return "CPU only"
|
| 68 |
|
| 69 |
-
# Tokenization
|
| 70 |
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}
|
| 71 |
|
| 72 |
def tokenize(sequence):
|
| 73 |
-
"""Convert DNA sequence to integer token IDs."""
|
| 74 |
sequence = sequence.upper().replace('U', 'T')
|
| 75 |
-
|
| 76 |
-
for char in sequence:
|
| 77 |
-
if char in TOKEN_MAP:
|
| 78 |
-
tokens.append(TOKEN_MAP[char])
|
| 79 |
-
elif char in 'RYSWKMBDHV':
|
| 80 |
-
tokens.append(5)
|
| 81 |
-
else:
|
| 82 |
-
tokens.append(5)
|
| 83 |
-
return np.array(tokens, dtype=np.int32)
|
| 84 |
|
| 85 |
def validate_sequence(sequence):
|
| 86 |
if not sequence or len(sequence.strip()) == 0:
|
|
@@ -97,29 +85,57 @@ def validate_sequence(sequence):
|
|
| 97 |
|
| 98 |
def strip_fasta_header(text):
|
| 99 |
lines = text.strip().split('\n')
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def embed_sequence(sequence, mode="mean", stride=100, layer=21):
|
| 104 |
"""Extract embeddings from sequence."""
|
| 105 |
model = get_embedding_model(layer)
|
| 106 |
-
|
| 107 |
seq_len = len(sequence)
|
| 108 |
embeddings = []
|
| 109 |
positions = []
|
| 110 |
|
| 111 |
for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
|
| 112 |
window = sequence[start:start + WINDOW_SIZE]
|
| 113 |
-
tokens = tokenize(window)
|
| 114 |
-
tokens = np.expand_dims(tokens, axis=0)
|
| 115 |
-
|
| 116 |
emb = model.predict(tokens, verbose=0)
|
| 117 |
embeddings.append(emb[0])
|
| 118 |
positions.append(start)
|
| 119 |
|
| 120 |
embeddings = np.array(embeddings) # (n_windows, 1000, 768)
|
| 121 |
|
| 122 |
-
# Pool across sequence positions within each window
|
| 123 |
if mode == "mean":
|
| 124 |
window_emb = np.mean(embeddings, axis=1)
|
| 125 |
return np.mean(window_emb, axis=0), window_emb, positions
|
|
@@ -140,120 +156,116 @@ def create_embedding_heatmap(embedding, title="Embedding"):
|
|
| 140 |
cols = 32
|
| 141 |
rows = int(np.ceil(n_dims / cols))
|
| 142 |
|
| 143 |
-
# Pad to fill grid
|
| 144 |
padded = np.full(rows * cols, np.nan)
|
| 145 |
padded[:n_dims] = embedding
|
| 146 |
grid = padded.reshape(rows, cols)
|
| 147 |
|
| 148 |
-
# Symmetric normalization
|
| 149 |
finite = embedding[np.isfinite(embedding)]
|
| 150 |
-
if finite.size > 0
|
| 151 |
-
vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01)
|
| 152 |
-
else:
|
| 153 |
-
vmax = 1.0
|
| 154 |
|
| 155 |
-
fig, ax = plt.subplots(figsize=(
|
| 156 |
im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
ax.set_xlabel('Dimension', fontsize=9)
|
| 162 |
-
ax.set_ylabel('Row', fontsize=9)
|
| 163 |
-
ax.set_title(f'{title} ({n_dims} dims)', fontsize=10)
|
| 164 |
ax.set_xticks(np.arange(0, cols, 8))
|
| 165 |
-
|
| 166 |
plt.tight_layout()
|
| 167 |
return fig
|
| 168 |
|
| 169 |
-
def create_trajectory_plot(window_embeddings, positions
|
| 170 |
-
"""Create interactive trajectory
|
| 171 |
-
n_windows = len(window_embeddings)
|
| 172 |
-
|
| 173 |
-
# Subsample dimensions for visualization
|
| 174 |
emb = np.array(window_embeddings)
|
| 175 |
-
n_dims = emb.shape
|
| 176 |
-
if n_dims > 100:
|
| 177 |
-
step = n_dims // 100
|
| 178 |
-
emb_sub = emb[:, ::step]
|
| 179 |
-
else:
|
| 180 |
-
emb_sub = emb
|
| 181 |
|
| 182 |
-
#
|
| 183 |
-
|
|
|
|
| 184 |
|
| 185 |
-
# Symmetric color scale
|
| 186 |
vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
|
| 187 |
|
| 188 |
-
fig.
|
| 189 |
z=emb_sub,
|
| 190 |
x=list(range(emb_sub.shape[1])),
|
| 191 |
-
y=[f"{p}
|
| 192 |
colorscale='RdBu_r',
|
| 193 |
zmin=-vmax, zmax=vmax,
|
| 194 |
-
colorbar=dict(title='
|
| 195 |
-
hovertemplate='
|
| 196 |
))
|
| 197 |
|
| 198 |
fig.update_layout(
|
| 199 |
-
title=
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
height=max(300, n_windows * 20 + 100),
|
| 204 |
-
plot_bgcolor='#fafafa',
|
| 205 |
-
paper_bgcolor='#fafafa',
|
| 206 |
-
font=dict(family='Inter, system-ui, sans-serif', size=10)
|
| 207 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
return fig
|
| 210 |
|
| 211 |
-
def create_dimension_plot(window_embeddings, positions, top_k=
|
| 212 |
-
"""Show top varying dimensions
|
| 213 |
emb = np.array(window_embeddings)
|
| 214 |
-
|
| 215 |
-
# Find dimensions with highest variance
|
| 216 |
variances = np.var(emb, axis=0)
|
| 217 |
top_dims = np.argsort(variances)[-top_k:][::-1]
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
|
| 222 |
-
'#a65628', '#f781bf', '#999999', '#66c2a5', '#fc8d62']
|
| 223 |
|
|
|
|
| 224 |
for i, dim in enumerate(top_dims):
|
| 225 |
fig.add_trace(go.Scatter(
|
| 226 |
-
x=positions,
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
name=f'dim {dim}',
|
| 230 |
-
line=dict(color=colors[i % len(colors)], width=1.5),
|
| 231 |
-
hovertemplate=f'Dim {dim}<br>Pos: %{{x}}<br>Value: %{{y:.3f}}<extra></extra>'
|
| 232 |
))
|
| 233 |
|
| 234 |
fig.update_layout(
|
| 235 |
-
title=
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
plot_bgcolor='#fafafa',
|
| 241 |
-
paper_bgcolor='#fafafa',
|
| 242 |
-
font=dict(family='Inter, system-ui, sans-serif', size=10)
|
| 243 |
)
|
| 244 |
-
|
| 245 |
return fig
|
| 246 |
|
| 247 |
-
# Example sequence
|
| 248 |
EXAMPLE_SEQUENCE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTACGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCG"""
|
| 249 |
|
| 250 |
-
def process(sequence: str, mode: str, stride: int, layer: int
|
| 251 |
"""Main processing function."""
|
| 252 |
sequence = strip_fasta_header(sequence.strip())
|
| 253 |
|
| 254 |
is_valid, error = validate_sequence(sequence)
|
| 255 |
if not is_valid:
|
| 256 |
-
return f"**Error**: {error}", None, None, None, None
|
| 257 |
|
| 258 |
embedding, window_embeddings, positions = embed_sequence(
|
| 259 |
sequence, mode=mode, stride=stride, layer=layer
|
|
@@ -263,22 +275,29 @@ def process(sequence: str, mode: str, stride: int, layer: int, show_heatmap: boo
|
|
| 263 |
path = os.path.join(tempfile.gettempdir(), "embedding.npy")
|
| 264 |
np.save(path, embedding)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
# Create summary
|
| 267 |
if mode == "per-window":
|
| 268 |
-
|
| 269 |
-
summary = f"""## Embeddings extracted
|
| 270 |
|
| 271 |
| | |
|
| 272 |
|---|---|
|
| 273 |
| sequence | {len(sequence):,} bp |
|
| 274 |
| layer | {layer} |
|
| 275 |
| windows | {embedding.shape[0]} |
|
| 276 |
-
|
|
| 277 |
-
|
|
|
|
| 278 |
"""
|
| 279 |
else:
|
| 280 |
-
|
| 281 |
-
summary = f"""## Embedding extracted
|
| 282 |
|
| 283 |
| | |
|
| 284 |
|---|---|
|
|
@@ -287,81 +306,60 @@ def process(sequence: str, mode: str, stride: int, layer: int, show_heatmap: boo
|
|
| 287 |
| mode | {mode} |
|
| 288 |
| dim | {len(embedding)} |
|
| 289 |
|
| 290 |
-
**
|
| 291 |
"""
|
| 292 |
|
| 293 |
# Create visualizations
|
| 294 |
heatmap_fig = None
|
| 295 |
-
|
| 296 |
-
|
| 297 |
|
| 298 |
-
if
|
| 299 |
-
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
trajectory_fig = create_trajectory_plot(window_embeddings, positions, stride)
|
| 303 |
-
dims_fig = create_dimension_plot(window_embeddings, positions)
|
| 304 |
-
|
| 305 |
-
return summary, path, heatmap_fig, trajectory_fig, dims_fig
|
| 306 |
-
|
| 307 |
-
# CSS
|
| 308 |
-
CUSTOM_CSS = """
|
| 309 |
-
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500&display=swap');
|
| 310 |
-
* { font-family: 'Inter', system-ui, sans-serif !important; }
|
| 311 |
-
code, pre, textarea { font-family: 'SF Mono', Consolas, monospace !important; }
|
| 312 |
-
.gradio-container { max-width: 1100px !important; background: #fafafa !important; }
|
| 313 |
-
"""
|
| 314 |
|
| 315 |
# Build interface
|
| 316 |
-
with gr.Blocks(
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
""")
|
| 322 |
|
| 323 |
with gr.Tab("Extract"):
|
| 324 |
with gr.Row():
|
| 325 |
-
with gr.Column(scale=1):
|
| 326 |
seq_input = gr.Textbox(
|
| 327 |
label="sequence",
|
| 328 |
-
placeholder="Paste DNA
|
| 329 |
-
lines=
|
| 330 |
-
value=EXAMPLE_SEQUENCE
|
| 331 |
-
info="min 1000 bp"
|
| 332 |
)
|
| 333 |
with gr.Row():
|
| 334 |
mode_input = gr.Radio(
|
| 335 |
choices=["mean", "max", "per-window"],
|
| 336 |
-
value="mean",
|
| 337 |
-
label="pooling"
|
| 338 |
-
)
|
| 339 |
-
layer_input = gr.Slider(
|
| 340 |
-
minimum=0, maximum=23, value=21, step=1,
|
| 341 |
-
label="layer",
|
| 342 |
-
info="transformer block (0-23)"
|
| 343 |
-
)
|
| 344 |
-
with gr.Row():
|
| 345 |
-
stride_input = gr.Slider(
|
| 346 |
-
minimum=50, maximum=500, value=100, step=50,
|
| 347 |
-
label="stride"
|
| 348 |
)
|
| 349 |
with gr.Row():
|
| 350 |
-
|
| 351 |
-
|
| 352 |
btn = gr.Button("extract", variant="primary")
|
| 353 |
output = gr.Markdown()
|
| 354 |
-
download = gr.File(label="download")
|
| 355 |
|
| 356 |
-
with gr.Column(scale=
|
|
|
|
| 357 |
heatmap_plot = gr.Plot(label="embedding heatmap")
|
|
|
|
|
|
|
| 358 |
trajectory_plot = gr.Plot(label="window trajectory")
|
| 359 |
dims_plot = gr.Plot(label="top varying dimensions")
|
| 360 |
|
| 361 |
btn.click(
|
| 362 |
process,
|
| 363 |
-
inputs=[seq_input, mode_input, stride_input, layer_input
|
| 364 |
-
outputs=[output, download, heatmap_plot, trajectory_plot, dims_plot],
|
| 365 |
api_name="embed"
|
| 366 |
)
|
| 367 |
|
|
@@ -374,33 +372,25 @@ from gradio_client import Client
|
|
| 374 |
import numpy as np
|
| 375 |
|
| 376 |
client = Client("genomenet/bert-embedding")
|
| 377 |
-
|
| 378 |
result = client.predict(
|
| 379 |
-
sequence="
|
| 380 |
-
mode="mean",
|
| 381 |
stride=100,
|
| 382 |
-
layer=21,
|
| 383 |
-
show_heatmap=True,
|
| 384 |
-
show_trajectory=True,
|
| 385 |
api_name="/embed"
|
| 386 |
)
|
| 387 |
-
|
| 388 |
summary, emb_path, *plots = result
|
| 389 |
embedding = np.load(emb_path)
|
| 390 |
```
|
| 391 |
|
| 392 |
-
**
|
| 393 |
-
|
| 394 |
-
**
|
| 395 |
-
-
|
| 396 |
-
-
|
| 397 |
-
- `per-window`: Matrix `(n_windows, 768)`
|
| 398 |
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
git clone https://huggingface.co/spaces/genomenet/bert-embedding
|
| 402 |
-
pip install -r requirements.txt && python app.py
|
| 403 |
-
```
|
| 404 |
""")
|
| 405 |
|
| 406 |
with gr.Tab("About"):
|
|
@@ -411,20 +401,24 @@ pip install -r requirements.txt && python app.py
|
|
| 411 |
|---|---|
|
| 412 |
| architecture | BERT, 24 layers, 768 hidden, 12 heads |
|
| 413 |
| parameters | ~430M |
|
| 414 |
-
| input | 1000 bp
|
| 415 |
-
| output | 768-dim embedding per position |
|
| 416 |
| pretraining | metagenomic contigs + microbial genomes |
|
| 417 |
|
| 418 |
-
###
|
| 419 |
|
| 420 |
-
|
| 421 |
-
- **Trajectory**: How embeddings change across sliding windows. Useful for seeing sequence structure.
|
| 422 |
-
- **Top dimensions**: Dimensions with highest variance - most informative for distinguishing sequence regions.
|
| 423 |
|
| 424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
|
|
|
|
| 426 |
- Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
|
| 427 |
-
- CRISPR
|
| 428 |
""")
|
| 429 |
|
| 430 |
if __name__ == "__main__":
|
|
|
|
| 13 |
import matplotlib
|
| 14 |
matplotlib.use('Agg')
|
| 15 |
import matplotlib.pyplot as plt
|
|
|
|
| 16 |
import plotly.graph_objects as go
|
|
|
|
| 17 |
|
| 18 |
from custom_layers import get_custom_objects
|
| 19 |
|
|
|
|
| 21 |
MODEL_REPO = "genomenet/bert-metagenome"
|
| 22 |
MODEL_FILE = "bert_1k_3.h5"
|
| 23 |
WINDOW_SIZE = 1000
|
| 24 |
+
NUM_LAYERS = 24
|
| 25 |
EMBEDDING_DIM = 768
|
| 26 |
|
| 27 |
# Singleton model cache
|
| 28 |
_model = None
|
| 29 |
+
_embedding_models = {}
|
| 30 |
|
| 31 |
def get_base_model():
|
| 32 |
"""Load and cache the base model."""
|
|
|
|
| 37 |
print(f"Loading model from {model_path}...")
|
| 38 |
_model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
|
| 39 |
print("Model loaded.")
|
| 40 |
+
# Print model summary for debugging
|
| 41 |
+
print(f"Model outputs: {_model.output_names}")
|
| 42 |
return _model
|
| 43 |
|
| 44 |
def get_embedding_model(layer_idx=21):
|
|
|
|
| 53 |
outputs=model.get_layer(layer_name).output
|
| 54 |
)
|
| 55 |
except ValueError:
|
|
|
|
| 56 |
_embedding_models[layer_idx] = tf.keras.Model(
|
| 57 |
inputs=model.input,
|
| 58 |
outputs=model.get_layer("layer_transformer_block_21").output
|
|
|
|
| 61 |
|
| 62 |
def get_gpu_status():
    """Return a short label for the compute device TensorFlow reports.

    Returns:
        "GPU: <name>" using the first entry of the physical GPU list, or
        "CPU only" when that list is empty.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        return "CPU only"
    return f"GPU: {gpus[0].name}"
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
# Tokenization
|
| 67 |
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}
|
| 68 |
|
| 69 |
def tokenize(sequence):
    """Convert a nucleotide string into an int32 array of token IDs.

    The input is upper-cased and 'U' is rewritten to 'T' before lookup in
    TOKEN_MAP; any character missing from TOKEN_MAP is encoded as 5.

    Args:
        sequence: Nucleotide string to encode.

    Returns:
        1-D ``np.int32`` array with one token ID per character.
    """
    normalized = sequence.upper().replace('U', 'T')
    token_ids = (TOKEN_MAP.get(base, 5) for base in normalized)
    return np.fromiter(token_ids, dtype=np.int32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
def validate_sequence(sequence):
|
| 74 |
if not sequence or len(sequence.strip()) == 0:
|
|
|
|
| 85 |
|
| 86 |
def strip_fasta_header(text):
    """Return the bare sequence contained in FASTA or raw text.

    Lines whose first non-blank character is '>' are treated as FASTA
    headers and dropped; the remaining lines are concatenated with all
    whitespace removed.

    Args:
        text: User input, either a plain sequence or FASTA records.

    Returns:
        The sequence characters as a single string without whitespace.
    """
    # splitlines() understands "\n", "\r\n" and "\r", so text pasted with
    # Windows line endings does not leave carriage returns in the sequence.
    sequence_lines = (
        line for line in text.splitlines()
        if not line.lstrip().startswith('>')
    )
    # str.split() with no argument splits on any run of whitespace, which
    # removes spaces, tabs and other stray blanks inside a line.
    return ''.join(''.join(line.split()) for line in sequence_lines)
|
| 89 |
+
|
| 90 |
+
def compute_embedding_stats(embedding):
    """Compute summary statistics of an embedding's activation values.

    The statistics are intended as rough indicators of how the model
    responds to a sequence (a possible proxy for 'familiarity').

    Args:
        embedding: Array-like of numeric activations. All values are
            pooled together regardless of shape.

    Returns:
        dict of Python floats with the keys:
            'l2_norm'  - norm of the activations (``np.linalg.norm``),
            'mean'     - mean activation,
            'std'      - standard deviation of the activations,
            'sparsity' - fraction of activations with magnitude below 0.1,
            'entropy'  - Shannon entropy (nats) of a 50-bin histogram,
            'kurtosis' - excess kurtosis (fourth standardized moment - 3).
    """
    emb = np.asarray(embedding)

    l2_norm = np.linalg.norm(emb)
    mean_act = np.mean(emb)
    std_act = np.std(emb)
    sparsity = np.mean(np.abs(emb) < 0.1)

    # Discretized entropy must be computed from bin probabilities (counts
    # normalized to sum to 1). Histogram *densities* are scaled by
    # 1 / bin_width, so using them makes the value depend on the data range
    # and can drive it far below zero.
    counts, _ = np.histogram(emb, bins=50)
    occupied = counts[counts > 0]
    probs = occupied / counts.sum()
    entropy = -np.sum(probs * np.log(probs))

    # The small epsilon keeps constant embeddings (std == 0) from dividing
    # by zero.
    kurtosis = np.mean(((emb - mean_act) / (std_act + 1e-10)) ** 4) - 3

    return {
        'l2_norm': float(l2_norm),
        'mean': float(mean_act),
        'std': float(std_act),
        'sparsity': float(sparsity),
        'entropy': float(entropy),
        'kurtosis': float(kurtosis),
    }
|
| 122 |
|
| 123 |
def embed_sequence(sequence, mode="mean", stride=100, layer=21):
|
| 124 |
"""Extract embeddings from sequence."""
|
| 125 |
model = get_embedding_model(layer)
|
|
|
|
| 126 |
seq_len = len(sequence)
|
| 127 |
embeddings = []
|
| 128 |
positions = []
|
| 129 |
|
| 130 |
for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
|
| 131 |
window = sequence[start:start + WINDOW_SIZE]
|
| 132 |
+
tokens = np.expand_dims(tokenize(window), axis=0)
|
|
|
|
|
|
|
| 133 |
emb = model.predict(tokens, verbose=0)
|
| 134 |
embeddings.append(emb[0])
|
| 135 |
positions.append(start)
|
| 136 |
|
| 137 |
embeddings = np.array(embeddings) # (n_windows, 1000, 768)
|
| 138 |
|
|
|
|
| 139 |
if mode == "mean":
|
| 140 |
window_emb = np.mean(embeddings, axis=1)
|
| 141 |
return np.mean(window_emb, axis=0), window_emb, positions
|
|
|
|
| 156 |
cols = 32
|
| 157 |
rows = int(np.ceil(n_dims / cols))
|
| 158 |
|
|
|
|
| 159 |
padded = np.full(rows * cols, np.nan)
|
| 160 |
padded[:n_dims] = embedding
|
| 161 |
grid = padded.reshape(rows, cols)
|
| 162 |
|
|
|
|
| 163 |
finite = embedding[np.isfinite(embedding)]
|
| 164 |
+
vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
|
| 167 |
im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
|
| 168 |
+
plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
|
| 169 |
+
ax.set_xlabel('Dimension')
|
| 170 |
+
ax.set_ylabel('Row')
|
| 171 |
+
ax.set_title(f'{title} ({n_dims} dims)')
|
|
|
|
|
|
|
|
|
|
| 172 |
ax.set_xticks(np.arange(0, cols, 8))
|
|
|
|
| 173 |
plt.tight_layout()
|
| 174 |
return fig
|
| 175 |
|
| 176 |
+
def create_trajectory_plot(window_embeddings, positions):
    """Build an interactive heatmap of embeddings across sliding windows.

    Args:
        window_embeddings: 2-D array-like with one row per window and one
            column per embedding dimension.
        positions: Window start positions; used as the y-axis labels.

    Returns:
        A Plotly figure containing a single heatmap trace.
    """
    matrix = np.array(window_embeddings)
    window_count, dim_count = matrix.shape

    # Keep roughly 100 columns so wide embeddings stay readable.
    stride = max(1, dim_count // 100)
    sampled = matrix[:, ::stride]

    # Symmetric colour limits around zero, with a floor of 0.01.
    bound = max(abs(np.nanmin(sampled)), abs(np.nanmax(sampled)), 0.01)

    x_title = 'Dimension'
    if stride > 1:
        x_title += ' (subsampled)'

    heatmap = go.Heatmap(
        z=sampled,
        x=list(range(sampled.shape[1])),
        y=[format(start) for start in positions],
        colorscale='RdBu_r',
        zmin=-bound,
        zmax=bound,
        colorbar=dict(title='Act.'),
        hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}<extra></extra>',
    )
    fig = go.Figure(heatmap)

    fig.update_layout(
        xaxis=dict(title=x_title),
        yaxis=dict(title='Window start (bp)'),
        height=max(350, window_count * 15 + 100),
        margin=dict(l=60, r=20, t=30, b=50),
    )
    return fig
|
| 204 |
+
|
| 205 |
+
def create_stats_plot(stats):
    """Create a bar chart with one bar per embedding statistic.

    Values are plotted as-is on a shared y-axis (no rescaling is applied).

    Args:
        stats: Mapping with the keys 'l2_norm', 'mean', 'std', 'sparsity',
            'entropy' and 'kurtosis'.

    Returns:
        A Plotly figure with six bar traces, each labelled with its value.
    """
    # (display label, stats key, bar colour)
    bar_specs = (
        ('L2 Norm', 'l2_norm', '#3b82f6'),
        ('Mean', 'mean', '#10b981'),
        ('Std', 'std', '#f59e0b'),
        ('Sparsity', 'sparsity', '#ef4444'),
        ('Entropy', 'entropy', '#8b5cf6'),
        ('Kurtosis', 'kurtosis', '#ec4899'),
    )
    # Read every value up front so a missing key fails before any plotting.
    readings = [(label, stats[key], colour) for label, key, colour in bar_specs]

    fig = go.Figure()
    for label, value, colour in readings:
        bar = go.Bar(
            x=[label],
            y=[value],
            name=label,
            marker_color=colour,
            text=[f'{value:.3f}'],
            textposition='outside',
        )
        fig.add_trace(bar)

    fig.update_layout(
        showlegend=False,
        height=280,
        margin=dict(l=40, r=20, t=30, b=40),
        yaxis=dict(title='Value'),
    )
    return fig
|
| 232 |
|
| 233 |
+
def create_dimension_plot(window_embeddings, positions, top_k=8):
    """Plot the embedding dimensions that vary most across windows.

    Args:
        window_embeddings: 2-D array-like with one row per window and one
            column per embedding dimension.
        positions: x-axis value for each window (window start positions).
        top_k: Number of highest-variance dimensions to draw. Values of 0
            or below draw nothing.

    Returns:
        A Plotly figure with one line trace per selected dimension, ordered
        from highest to lowest variance.
    """
    emb = np.array(window_embeddings)
    variances = np.var(emb, axis=0)

    # Rank in descending order and take the head of that ranking. Slicing
    # from the front keeps top_k == 0 meaning "no dimensions"; a `[-0:]`
    # tail slice would select all of them.
    ranked = np.argsort(variances)[::-1]
    top_dims = ranked[:max(0, top_k)]

    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
              '#ff7f00', '#a65628', '#f781bf', '#999999']

    fig = go.Figure()
    for i, dim in enumerate(top_dims):
        fig.add_trace(go.Scatter(
            x=positions, y=emb[:, dim],
            mode='lines', name=f'd{dim}',
            # Cycle the palette when more than len(colors) lines are drawn.
            line=dict(color=colors[i % len(colors)], width=1.5)
        ))

    fig.update_layout(
        xaxis=dict(title='Position (bp)'),
        yaxis=dict(title='Activation'),
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=50, r=20, t=40, b=50)
    )
    return fig
|
| 258 |
|
| 259 |
+
# Example sequence
|
| 260 |
EXAMPLE_SEQUENCE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTACGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCG"""
|
| 261 |
|
| 262 |
+
def process(sequence: str, mode: str, stride: int, layer: int):
|
| 263 |
"""Main processing function."""
|
| 264 |
sequence = strip_fasta_header(sequence.strip())
|
| 265 |
|
| 266 |
is_valid, error = validate_sequence(sequence)
|
| 267 |
if not is_valid:
|
| 268 |
+
return f"**Error**: {error}", None, None, None, None, None
|
| 269 |
|
| 270 |
embedding, window_embeddings, positions = embed_sequence(
|
| 271 |
sequence, mode=mode, stride=stride, layer=layer
|
|
|
|
| 275 |
path = os.path.join(tempfile.gettempdir(), "embedding.npy")
|
| 276 |
np.save(path, embedding)
|
| 277 |
|
| 278 |
+
# Compute stats
|
| 279 |
+
if mode == "per-window":
|
| 280 |
+
# For per-window, compute stats on mean embedding
|
| 281 |
+
mean_emb = np.mean(embedding, axis=0)
|
| 282 |
+
stats = compute_embedding_stats(mean_emb)
|
| 283 |
+
else:
|
| 284 |
+
stats = compute_embedding_stats(embedding)
|
| 285 |
+
|
| 286 |
# Create summary
|
| 287 |
if mode == "per-window":
|
| 288 |
+
summary = f"""### Results
|
|
|
|
| 289 |
|
| 290 |
| | |
|
| 291 |
|---|---|
|
| 292 |
| sequence | {len(sequence):,} bp |
|
| 293 |
| layer | {layer} |
|
| 294 |
| windows | {embedding.shape[0]} |
|
| 295 |
+
| shape | {embedding.shape} |
|
| 296 |
+
|
| 297 |
+
**Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}
|
| 298 |
"""
|
| 299 |
else:
|
| 300 |
+
summary = f"""### Results
|
|
|
|
| 301 |
|
| 302 |
| | |
|
| 303 |
|---|---|
|
|
|
|
| 306 |
| mode | {mode} |
|
| 307 |
| dim | {len(embedding)} |
|
| 308 |
|
| 309 |
+
**Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%}
|
| 310 |
"""
|
| 311 |
|
| 312 |
# Create visualizations
|
| 313 |
heatmap_fig = None
|
| 314 |
+
if mode != "per-window":
|
| 315 |
+
heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}")
|
| 316 |
|
| 317 |
+
trajectory_fig = create_trajectory_plot(window_embeddings, positions) if len(window_embeddings) > 1 else None
|
| 318 |
+
stats_fig = create_stats_plot(stats)
|
| 319 |
+
dims_fig = create_dimension_plot(window_embeddings, positions) if len(window_embeddings) > 1 else None
|
| 320 |
|
| 321 |
+
return summary, path, heatmap_fig, trajectory_fig, stats_fig, dims_fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
# Build interface
|
| 324 |
+
with gr.Blocks(
|
| 325 |
+
title="BERT Metagenome Embeddings",
|
| 326 |
+
css=".gradio-container { max-width: 100% !important; }"
|
| 327 |
+
) as demo:
|
| 328 |
+
gr.Markdown("# bert-embedding\nExtract embeddings from DNA sequences. BERT (430M params) pretrained on metagenomes.")
|
|
|
|
| 329 |
|
| 330 |
with gr.Tab("Extract"):
|
| 331 |
with gr.Row():
|
| 332 |
+
with gr.Column(scale=1, min_width=300):
|
| 333 |
seq_input = gr.Textbox(
|
| 334 |
label="sequence",
|
| 335 |
+
placeholder="Paste DNA (FASTA or raw)...",
|
| 336 |
+
lines=5,
|
| 337 |
+
value=EXAMPLE_SEQUENCE
|
|
|
|
| 338 |
)
|
| 339 |
with gr.Row():
|
| 340 |
mode_input = gr.Radio(
|
| 341 |
choices=["mean", "max", "per-window"],
|
| 342 |
+
value="mean", label="pooling"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
)
|
| 344 |
with gr.Row():
|
| 345 |
+
layer_input = gr.Slider(0, 23, value=21, step=1, label="layer")
|
| 346 |
+
stride_input = gr.Slider(50, 500, value=100, step=50, label="stride")
|
| 347 |
btn = gr.Button("extract", variant="primary")
|
| 348 |
output = gr.Markdown()
|
| 349 |
+
download = gr.File(label="download .npy")
|
| 350 |
|
| 351 |
+
with gr.Column(scale=1, min_width=300):
|
| 352 |
+
stats_plot = gr.Plot(label="embedding statistics")
|
| 353 |
heatmap_plot = gr.Plot(label="embedding heatmap")
|
| 354 |
+
|
| 355 |
+
with gr.Column(scale=1, min_width=300):
|
| 356 |
trajectory_plot = gr.Plot(label="window trajectory")
|
| 357 |
dims_plot = gr.Plot(label="top varying dimensions")
|
| 358 |
|
| 359 |
btn.click(
|
| 360 |
process,
|
| 361 |
+
inputs=[seq_input, mode_input, stride_input, layer_input],
|
| 362 |
+
outputs=[output, download, heatmap_plot, trajectory_plot, stats_plot, dims_plot],
|
| 363 |
api_name="embed"
|
| 364 |
)
|
| 365 |
|
|
|
|
| 372 |
import numpy as np
|
| 373 |
|
| 374 |
client = Client("genomenet/bert-embedding")
|
|
|
|
| 375 |
result = client.predict(
|
| 376 |
+
sequence="ATGC...", # min 1000 bp
|
| 377 |
+
mode="mean", # mean/max/per-window
|
| 378 |
stride=100,
|
| 379 |
+
layer=21, # 0-23
|
|
|
|
|
|
|
| 380 |
api_name="/embed"
|
| 381 |
)
|
|
|
|
| 382 |
summary, emb_path, *plots = result
|
| 383 |
embedding = np.load(emb_path)
|
| 384 |
```
|
| 385 |
|
| 386 |
+
**Statistics**:
|
| 387 |
+
- **L2 Norm**: Magnitude of embedding. Higher = stronger model response.
|
| 388 |
+
- **Entropy**: Activation distribution spread. Lower = more structured/confident.
|
| 389 |
+
- **Sparsity**: Fraction of near-zero dims. Higher = sparser representation.
|
| 390 |
+
- **Kurtosis**: Peakedness. Higher = more concentrated activations.
|
|
|
|
| 391 |
|
| 392 |
+
These can serve as proxy "familiarity" scores - sequences similar to training data
|
| 393 |
+
tend to produce more structured embeddings (lower entropy, higher kurtosis).
|
|
|
|
|
|
|
|
|
|
| 394 |
""")
|
| 395 |
|
| 396 |
with gr.Tab("About"):
|
|
|
|
| 401 |
|---|---|
|
| 402 |
| architecture | BERT, 24 layers, 768 hidden, 12 heads |
|
| 403 |
| parameters | ~430M |
|
| 404 |
+
| input | 1000 bp sliding window |
|
|
|
|
| 405 |
| pretraining | metagenomic contigs + microbial genomes |
|
| 406 |
|
| 407 |
+
### Interpreting Statistics
|
| 408 |
|
| 409 |
+
The embedding statistics provide indirect measures of how the model "responds" to a sequence:
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
- **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences.
|
| 412 |
+
- **Entropy**: How spread out the activations are. Lower entropy suggests more confident/structured representation.
|
| 413 |
+
- **Sparsity**: Fraction of dimensions with near-zero activation.
|
| 414 |
+
- **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations.
|
| 415 |
+
|
| 416 |
+
**Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across
|
| 417 |
+
different sequence types may reveal what the model considers typical vs. unusual.
|
| 418 |
|
| 419 |
+
### Links
|
| 420 |
- Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
|
| 421 |
+
- CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
|
| 422 |
""")
|
| 423 |
|
| 424 |
if __name__ == "__main__":
|