genomenet committed on
Commit
038ad80
·
1 Parent(s): 0c6b9b9

Add embedding stats, full-width layout, 3-column design

Browse files
Files changed (1) hide show
  1. app.py +166 -172
app.py CHANGED
@@ -13,9 +13,7 @@ from huggingface_hub import hf_hub_download
13
  import matplotlib
14
  matplotlib.use('Agg')
15
  import matplotlib.pyplot as plt
16
- from matplotlib.colors import TwoSlopeNorm
17
  import plotly.graph_objects as go
18
- from plotly.subplots import make_subplots
19
 
20
  from custom_layers import get_custom_objects
21
 
@@ -23,12 +21,12 @@ from custom_layers import get_custom_objects
23
  MODEL_REPO = "genomenet/bert-metagenome"
24
  MODEL_FILE = "bert_1k_3.h5"
25
  WINDOW_SIZE = 1000
26
- NUM_LAYERS = 24 # Transformer blocks 0-23
27
  EMBEDDING_DIM = 768
28
 
29
  # Singleton model cache
30
  _model = None
31
- _embedding_models = {} # layer_idx -> embedding_model
32
 
33
  def get_base_model():
34
  """Load and cache the base model."""
@@ -39,6 +37,8 @@ def get_base_model():
39
  print(f"Loading model from {model_path}...")
40
  _model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
41
  print("Model loaded.")
 
 
42
  return _model
43
 
44
  def get_embedding_model(layer_idx=21):
@@ -53,7 +53,6 @@ def get_embedding_model(layer_idx=21):
53
  outputs=model.get_layer(layer_name).output
54
  )
55
  except ValueError:
56
- # Fallback to layer 21 if requested layer not found
57
  _embedding_models[layer_idx] = tf.keras.Model(
58
  inputs=model.input,
59
  outputs=model.get_layer("layer_transformer_block_21").output
@@ -62,25 +61,14 @@ def get_embedding_model(layer_idx=21):
62
 
63
  def get_gpu_status():
64
  gpus = tf.config.list_physical_devices('GPU')
65
- if gpus:
66
- return f"GPU: {gpus[0].name}"
67
- return "CPU only"
68
 
69
- # Tokenization - Integer token IDs
70
  TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}
71
 
72
  def tokenize(sequence):
73
- """Convert DNA sequence to integer token IDs."""
74
  sequence = sequence.upper().replace('U', 'T')
75
- tokens = []
76
- for char in sequence:
77
- if char in TOKEN_MAP:
78
- tokens.append(TOKEN_MAP[char])
79
- elif char in 'RYSWKMBDHV':
80
- tokens.append(5)
81
- else:
82
- tokens.append(5)
83
- return np.array(tokens, dtype=np.int32)
84
 
85
  def validate_sequence(sequence):
86
  if not sequence or len(sequence.strip()) == 0:
@@ -97,29 +85,57 @@ def validate_sequence(sequence):
97
 
98
  def strip_fasta_header(text):
99
  lines = text.strip().split('\n')
100
- seq_lines = [l for l in lines if not l.startswith('>')]
101
- return ''.join(seq_lines).replace(' ', '').replace('\t', '')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def embed_sequence(sequence, mode="mean", stride=100, layer=21):
104
  """Extract embeddings from sequence."""
105
  model = get_embedding_model(layer)
106
-
107
  seq_len = len(sequence)
108
  embeddings = []
109
  positions = []
110
 
111
  for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
112
  window = sequence[start:start + WINDOW_SIZE]
113
- tokens = tokenize(window)
114
- tokens = np.expand_dims(tokens, axis=0)
115
-
116
  emb = model.predict(tokens, verbose=0)
117
  embeddings.append(emb[0])
118
  positions.append(start)
119
 
120
  embeddings = np.array(embeddings) # (n_windows, 1000, 768)
121
 
122
- # Pool across sequence positions within each window
123
  if mode == "mean":
124
  window_emb = np.mean(embeddings, axis=1)
125
  return np.mean(window_emb, axis=0), window_emb, positions
@@ -140,120 +156,116 @@ def create_embedding_heatmap(embedding, title="Embedding"):
140
  cols = 32
141
  rows = int(np.ceil(n_dims / cols))
142
 
143
- # Pad to fill grid
144
  padded = np.full(rows * cols, np.nan)
145
  padded[:n_dims] = embedding
146
  grid = padded.reshape(rows, cols)
147
 
148
- # Symmetric normalization
149
  finite = embedding[np.isfinite(embedding)]
150
- if finite.size > 0:
151
- vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01)
152
- else:
153
- vmax = 1.0
154
 
155
- fig, ax = plt.subplots(figsize=(12, max(3, rows * 0.3)))
156
  im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
157
-
158
- cbar = plt.colorbar(im, ax=ax, shrink=0.8)
159
- cbar.set_label('Activation', fontsize=9)
160
-
161
- ax.set_xlabel('Dimension', fontsize=9)
162
- ax.set_ylabel('Row', fontsize=9)
163
- ax.set_title(f'{title} ({n_dims} dims)', fontsize=10)
164
  ax.set_xticks(np.arange(0, cols, 8))
165
-
166
  plt.tight_layout()
167
  return fig
168
 
169
- def create_trajectory_plot(window_embeddings, positions, stride):
170
- """Create interactive trajectory plot showing embedding evolution."""
171
- n_windows = len(window_embeddings)
172
-
173
- # Subsample dimensions for visualization
174
  emb = np.array(window_embeddings)
175
- n_dims = emb.shape[1]
176
- if n_dims > 100:
177
- step = n_dims // 100
178
- emb_sub = emb[:, ::step]
179
- else:
180
- emb_sub = emb
181
 
182
- # Create heatmap
183
- fig = go.Figure()
 
184
 
185
- # Symmetric color scale
186
  vmax = max(abs(np.nanmin(emb_sub)), abs(np.nanmax(emb_sub)), 0.01)
187
 
188
- fig.add_trace(go.Heatmap(
189
  z=emb_sub,
190
  x=list(range(emb_sub.shape[1])),
191
- y=[f"{p}-{p+WINDOW_SIZE}" for p in positions],
192
  colorscale='RdBu_r',
193
  zmin=-vmax, zmax=vmax,
194
- colorbar=dict(title='Activation'),
195
- hovertemplate='Window: %{y}<br>Dim: %{x}<br>Value: %{z:.3f}<extra></extra>'
196
  ))
197
 
198
  fig.update_layout(
199
- title=None,
200
- xaxis=dict(title='Dimension (subsampled)' if n_dims > 100 else 'Dimension',
201
- tickfont=dict(size=9)),
202
- yaxis=dict(title='Window position (bp)', tickfont=dict(size=9)),
203
- height=max(300, n_windows * 20 + 100),
204
- plot_bgcolor='#fafafa',
205
- paper_bgcolor='#fafafa',
206
- font=dict(family='Inter, system-ui, sans-serif', size=10)
207
  )
 
 
 
 
 
 
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  return fig
210
 
211
- def create_dimension_plot(window_embeddings, positions, top_k=10):
212
- """Show top varying dimensions across windows."""
213
  emb = np.array(window_embeddings)
214
-
215
- # Find dimensions with highest variance
216
  variances = np.var(emb, axis=0)
217
  top_dims = np.argsort(variances)[-top_k:][::-1]
218
 
219
- fig = go.Figure()
220
-
221
- colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00',
222
- '#a65628', '#f781bf', '#999999', '#66c2a5', '#fc8d62']
223
 
 
224
  for i, dim in enumerate(top_dims):
225
  fig.add_trace(go.Scatter(
226
- x=positions,
227
- y=emb[:, dim],
228
- mode='lines',
229
- name=f'dim {dim}',
230
- line=dict(color=colors[i % len(colors)], width=1.5),
231
- hovertemplate=f'Dim {dim}<br>Pos: %{{x}}<br>Value: %{{y:.3f}}<extra></extra>'
232
  ))
233
 
234
  fig.update_layout(
235
- title=None,
236
- xaxis=dict(title='Position (bp)', tickfont=dict(size=9)),
237
- yaxis=dict(title='Activation', tickfont=dict(size=9)),
238
- height=350,
239
- legend=dict(orientation='h', yanchor='bottom', y=1.02, font=dict(size=9)),
240
- plot_bgcolor='#fafafa',
241
- paper_bgcolor='#fafafa',
242
- font=dict(family='Inter, system-ui, sans-serif', size=10)
243
  )
244
-
245
  return fig
246
 
247
- # Example sequence (1100 bp)
248
  EXAMPLE_SEQUENCE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTACGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCG"""
249
 
250
- def process(sequence: str, mode: str, stride: int, layer: int, show_heatmap: bool, show_trajectory: bool):
251
  """Main processing function."""
252
  sequence = strip_fasta_header(sequence.strip())
253
 
254
  is_valid, error = validate_sequence(sequence)
255
  if not is_valid:
256
- return f"**Error**: {error}", None, None, None, None
257
 
258
  embedding, window_embeddings, positions = embed_sequence(
259
  sequence, mode=mode, stride=stride, layer=layer
@@ -263,22 +275,29 @@ def process(sequence: str, mode: str, stride: int, layer: int, show_heatmap: boo
263
  path = os.path.join(tempfile.gettempdir(), "embedding.npy")
264
  np.save(path, embedding)
265
 
 
 
 
 
 
 
 
 
266
  # Create summary
267
  if mode == "per-window":
268
- emb_shape = f"({embedding.shape[0]}, {embedding.shape[1]})"
269
- summary = f"""## Embeddings extracted
270
 
271
  | | |
272
  |---|---|
273
  | sequence | {len(sequence):,} bp |
274
  | layer | {layer} |
275
  | windows | {embedding.shape[0]} |
276
- | dim | {embedding.shape[1]} |
277
- | stride | {stride} bp |
 
278
  """
279
  else:
280
- emb_str = ", ".join([f"{x:.3f}" for x in embedding[:8]])
281
- summary = f"""## Embedding extracted
282
 
283
  | | |
284
  |---|---|
@@ -287,81 +306,60 @@ def process(sequence: str, mode: str, stride: int, layer: int, show_heatmap: boo
287
  | mode | {mode} |
288
  | dim | {len(embedding)} |
289
 
290
- **First 8 dims**: [{emb_str}, ...]
291
  """
292
 
293
  # Create visualizations
294
  heatmap_fig = None
295
- trajectory_fig = None
296
- dims_fig = None
297
 
298
- if show_heatmap and mode != "per-window":
299
- heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer} Embedding")
 
300
 
301
- if show_trajectory and len(window_embeddings) > 1:
302
- trajectory_fig = create_trajectory_plot(window_embeddings, positions, stride)
303
- dims_fig = create_dimension_plot(window_embeddings, positions)
304
-
305
- return summary, path, heatmap_fig, trajectory_fig, dims_fig
306
-
307
- # CSS
308
- CUSTOM_CSS = """
309
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500&display=swap');
310
- * { font-family: 'Inter', system-ui, sans-serif !important; }
311
- code, pre, textarea { font-family: 'SF Mono', Consolas, monospace !important; }
312
- .gradio-container { max-width: 1100px !important; background: #fafafa !important; }
313
- """
314
 
315
  # Build interface
316
- with gr.Blocks(title="BERT Metagenome Embeddings", css=CUSTOM_CSS) as demo:
317
- gr.Markdown("""
318
- # bert-embedding
319
-
320
- Extract embeddings from DNA sequences. BERT model (430M params) pretrained on metagenomic sequences.
321
- """)
322
 
323
  with gr.Tab("Extract"):
324
  with gr.Row():
325
- with gr.Column(scale=1):
326
  seq_input = gr.Textbox(
327
  label="sequence",
328
- placeholder="Paste DNA sequence (FASTA or raw)...",
329
- lines=6,
330
- value=EXAMPLE_SEQUENCE,
331
- info="min 1000 bp"
332
  )
333
  with gr.Row():
334
  mode_input = gr.Radio(
335
  choices=["mean", "max", "per-window"],
336
- value="mean",
337
- label="pooling"
338
- )
339
- layer_input = gr.Slider(
340
- minimum=0, maximum=23, value=21, step=1,
341
- label="layer",
342
- info="transformer block (0-23)"
343
- )
344
- with gr.Row():
345
- stride_input = gr.Slider(
346
- minimum=50, maximum=500, value=100, step=50,
347
- label="stride"
348
  )
349
  with gr.Row():
350
- show_heatmap = gr.Checkbox(label="heatmap", value=True)
351
- show_trajectory = gr.Checkbox(label="trajectory", value=True)
352
  btn = gr.Button("extract", variant="primary")
353
  output = gr.Markdown()
354
- download = gr.File(label="download")
355
 
356
- with gr.Column(scale=2):
 
357
  heatmap_plot = gr.Plot(label="embedding heatmap")
 
 
358
  trajectory_plot = gr.Plot(label="window trajectory")
359
  dims_plot = gr.Plot(label="top varying dimensions")
360
 
361
  btn.click(
362
  process,
363
- inputs=[seq_input, mode_input, stride_input, layer_input, show_heatmap, show_trajectory],
364
- outputs=[output, download, heatmap_plot, trajectory_plot, dims_plot],
365
  api_name="embed"
366
  )
367
 
@@ -374,33 +372,25 @@ from gradio_client import Client
374
  import numpy as np
375
 
376
  client = Client("genomenet/bert-embedding")
377
-
378
  result = client.predict(
379
- sequence="ATGCGATCGATCG...", # min 1000 bp
380
- mode="mean", # "mean", "max", or "per-window"
381
  stride=100,
382
- layer=21, # transformer layer 0-23
383
- show_heatmap=True,
384
- show_trajectory=True,
385
  api_name="/embed"
386
  )
387
-
388
  summary, emb_path, *plots = result
389
  embedding = np.load(emb_path)
390
  ```
391
 
392
- **Layers**: 0-23 (24 transformer blocks). Layer 21 is commonly used for embeddings.
393
-
394
- **Modes**:
395
- - `mean`: Single 768-dim vector (mean pooled)
396
- - `max`: Single 768-dim vector (max pooled)
397
- - `per-window`: Matrix `(n_windows, 768)`
398
 
399
- **Local**:
400
- ```bash
401
- git clone https://huggingface.co/spaces/genomenet/bert-embedding
402
- pip install -r requirements.txt && python app.py
403
- ```
404
  """)
405
 
406
  with gr.Tab("About"):
@@ -411,20 +401,24 @@ pip install -r requirements.txt && python app.py
411
  |---|---|
412
  | architecture | BERT, 24 layers, 768 hidden, 12 heads |
413
  | parameters | ~430M |
414
- | input | 1000 bp DNA (sliding window) |
415
- | output | 768-dim embedding per position |
416
  | pretraining | metagenomic contigs + microbial genomes |
417
 
418
- ### Visualization
419
 
420
- - **Heatmap**: 768 dimensions as colored grid. Blue=negative, Red=positive activation.
421
- - **Trajectory**: How embeddings change across sliding windows. Useful for seeing sequence structure.
422
- - **Top dimensions**: Dimensions with highest variance - most informative for distinguishing sequence regions.
423
 
424
- ### Links
 
 
 
 
 
 
425
 
 
426
  - Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
427
- - CRISPR Detection: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
428
  """)
429
 
430
  if __name__ == "__main__":
 
13
  import matplotlib
14
  matplotlib.use('Agg')
15
  import matplotlib.pyplot as plt
 
16
  import plotly.graph_objects as go
 
17
 
18
  from custom_layers import get_custom_objects
19
 
 
21
  MODEL_REPO = "genomenet/bert-metagenome"
22
  MODEL_FILE = "bert_1k_3.h5"
23
  WINDOW_SIZE = 1000
24
+ NUM_LAYERS = 24
25
  EMBEDDING_DIM = 768
26
 
27
  # Singleton model cache
28
  _model = None
29
+ _embedding_models = {}
30
 
31
  def get_base_model():
32
  """Load and cache the base model."""
 
37
  print(f"Loading model from {model_path}...")
38
  _model = tf.keras.models.load_model(model_path, custom_objects=get_custom_objects(), compile=False)
39
  print("Model loaded.")
40
+ # Print model summary for debugging
41
+ print(f"Model outputs: {_model.output_names}")
42
  return _model
43
 
44
  def get_embedding_model(layer_idx=21):
 
53
  outputs=model.get_layer(layer_name).output
54
  )
55
  except ValueError:
 
56
  _embedding_models[layer_idx] = tf.keras.Model(
57
  inputs=model.input,
58
  outputs=model.get_layer("layer_transformer_block_21").output
 
61
 
62
def get_gpu_status():
    """Return a short human-readable string describing the compute device.

    Reports the first visible GPU's name if TensorFlow detects any,
    otherwise falls back to a CPU-only message.
    """
    devices = tf.config.list_physical_devices('GPU')
    if devices:
        return f"GPU: {devices[0].name}"
    return "CPU only"
 
 
65
 
66
# Tokenization
TOKEN_MAP = {'A': 1, 'C': 2, 'G': 3, 'T': 4, 'N': 5}

def tokenize(sequence):
    """Encode a DNA/RNA string as an int32 array of token IDs.

    Input is upper-cased and RNA 'U' is mapped to 'T' first; any symbol
    not in TOKEN_MAP (ambiguity codes, junk characters) encodes as 5.
    """
    normalized = sequence.upper().replace('U', 'T')
    ids = [TOKEN_MAP.get(base, 5) for base in normalized]
    return np.array(ids, dtype=np.int32)
 
 
 
 
 
 
 
 
72
 
73
  def validate_sequence(sequence):
74
  if not sequence or len(sequence.strip()) == 0:
 
85
 
86
def strip_fasta_header(text):
    """Strip FASTA header lines and whitespace, returning the bare sequence.

    Drops every line starting with '>' and removes spaces and tabs from
    what remains. Uses splitlines() instead of split('\n') so CRLF/CR line
    endings (Windows-pasted FASTA) don't leave stray '\r' characters in the
    sequence, which would later fail character validation.
    """
    lines = text.strip().splitlines()
    seq = ''.join(l for l in lines if not l.startswith('>'))
    return seq.replace(' ', '').replace('\t', '')
89
+
90
def compute_embedding_stats(embedding):
    """Compute statistics that may indicate sequence 'familiarity'.

    Returns a dict of plain floats:
      - l2_norm:  Euclidean magnitude of the embedding vector.
      - mean/std: first and second moments of the activations.
      - sparsity: fraction of dimensions with |activation| < 0.1.
      - entropy:  Shannon-style entropy over a 50-bin density histogram
                  (NOTE(review): density values are not renormalized by bin
                  width, so this is a relative score, not true entropy).
      - kurtosis: excess kurtosis; epsilon guards the std==0 case.
    """
    emb = np.array(embedding)

    mean_act = np.mean(emb)
    std_act = np.std(emb)

    # Histogram-based spread score; only non-empty bins contribute.
    hist, _ = np.histogram(emb, bins=50, density=True)
    occupied = hist[hist > 0]
    entropy = -np.sum(occupied * np.log(occupied + 1e-10))

    # Standardized fourth moment minus 3 (excess kurtosis).
    z = (emb - mean_act) / (std_act + 1e-10)
    kurtosis = np.mean(z ** 4) - 3

    return {
        'l2_norm': float(np.linalg.norm(emb)),
        'mean': float(mean_act),
        'std': float(std_act),
        'sparsity': float(np.mean(np.abs(emb) < 0.1)),
        'entropy': float(entropy),
        'kurtosis': float(kurtosis)
    }
122
 
123
  def embed_sequence(sequence, mode="mean", stride=100, layer=21):
124
  """Extract embeddings from sequence."""
125
  model = get_embedding_model(layer)
 
126
  seq_len = len(sequence)
127
  embeddings = []
128
  positions = []
129
 
130
  for start in range(0, seq_len - WINDOW_SIZE + 1, stride):
131
  window = sequence[start:start + WINDOW_SIZE]
132
+ tokens = np.expand_dims(tokenize(window), axis=0)
 
 
133
  emb = model.predict(tokens, verbose=0)
134
  embeddings.append(emb[0])
135
  positions.append(start)
136
 
137
  embeddings = np.array(embeddings) # (n_windows, 1000, 768)
138
 
 
139
  if mode == "mean":
140
  window_emb = np.mean(embeddings, axis=1)
141
  return np.mean(window_emb, axis=0), window_emb, positions
 
156
  cols = 32
157
  rows = int(np.ceil(n_dims / cols))
158
 
 
159
  padded = np.full(rows * cols, np.nan)
160
  padded[:n_dims] = embedding
161
  grid = padded.reshape(rows, cols)
162
 
 
163
  finite = embedding[np.isfinite(embedding)]
164
+ vmax = max(abs(np.nanmin(finite)), abs(np.nanmax(finite)), 0.01) if finite.size > 0 else 1.0
 
 
 
165
 
166
+ fig, ax = plt.subplots(figsize=(14, max(4, rows * 0.35)))
167
  im = ax.imshow(grid, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
168
+ plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')
169
+ ax.set_xlabel('Dimension')
170
+ ax.set_ylabel('Row')
171
+ ax.set_title(f'{title} ({n_dims} dims)')
 
 
 
172
  ax.set_xticks(np.arange(0, cols, 8))
 
173
  plt.tight_layout()
174
  return fig
175
 
176
def create_trajectory_plot(window_embeddings, positions):
    """Create interactive trajectory heatmap.

    Rows are sliding-window start positions, columns are (possibly
    subsampled) embedding dimensions; color is the activation value on a
    symmetric red/blue scale.
    """
    matrix = np.array(window_embeddings)
    n_windows, n_dims = matrix.shape

    # Keep at most ~100 dimension columns for readability.
    step = max(1, n_dims // 100)
    shown = matrix[:, ::step]

    # Symmetric color range around zero, with a small floor.
    vmax = max(abs(np.nanmin(shown)), abs(np.nanmax(shown)), 0.01)

    heatmap = go.Heatmap(
        z=shown,
        x=list(range(shown.shape[1])),
        y=[f"{p}" for p in positions],
        colorscale='RdBu_r',
        zmin=-vmax, zmax=vmax,
        colorbar=dict(title='Act.'),
        hovertemplate='Pos: %{y} bp<br>Dim: %{x}<br>Val: %{z:.3f}<extra></extra>'
    )
    fig = go.Figure(heatmap)

    x_title = 'Dimension (subsampled)' if step > 1 else 'Dimension'
    fig.update_layout(
        xaxis=dict(title=x_title),
        yaxis=dict(title='Window start (bp)'),
        height=max(350, n_windows * 15 + 100),
        margin=dict(l=60, r=20, t=30, b=50)
    )
    return fig
204
+
205
def create_stats_plot(stats):
    """Create a bar chart of embedding statistics.

    One colored bar per metric in `stats`, with the numeric value printed
    above each bar. Legend is hidden since the x labels already name the
    metrics.
    """
    names = ['L2 Norm', 'Mean', 'Std', 'Sparsity', 'Entropy', 'Kurtosis']
    values = [stats['l2_norm'], stats['mean'], stats['std'],
              stats['sparsity'], stats['entropy'], stats['kurtosis']]
    colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#ec4899']

    # Normalize for display (different scales)
    fig = go.Figure()
    for name, val, color in zip(names, values, colors):
        bar = go.Bar(
            x=[name], y=[val],
            name=name,
            marker_color=color,
            text=[f'{val:.3f}'],
            textposition='outside'
        )
        fig.add_trace(bar)

    fig.update_layout(
        showlegend=False,
        height=280,
        margin=dict(l=40, r=20, t=30, b=40),
        yaxis=dict(title='Value')
    )
    return fig
232
 
233
def create_dimension_plot(window_embeddings, positions, top_k=8):
    """Show top varying dimensions.

    Finds the `top_k` embedding dimensions with the highest variance across
    windows and plots each as a line over genomic position.
    """
    matrix = np.array(window_embeddings)
    # Rank dimensions by variance across windows, highest first.
    per_dim_var = np.var(matrix, axis=0)
    top_dims = np.argsort(per_dim_var)[-top_k:][::-1]

    palette = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3',
               '#ff7f00', '#a65628', '#f781bf', '#999999']

    traces = []
    for rank, dim in enumerate(top_dims):
        traces.append(go.Scatter(
            x=positions, y=matrix[:, dim],
            mode='lines', name=f'd{dim}',
            line=dict(color=palette[rank % len(palette)], width=1.5)
        ))
    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    fig.update_layout(
        xaxis=dict(title='Position (bp)'),
        yaxis=dict(title='Activation'),
        height=300,
        legend=dict(orientation='h', y=1.1),
        margin=dict(l=50, r=20, t=40, b=50)
    )
    return fig
258
 
259
+ # Example sequence
260
  EXAMPLE_SEQUENCE = """ATGCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTACGATCGATCGATCGATCGTAGCTAGCTAGCTAGCTAGCTGATCGATCGATCGTAGCTAGCTAGCTGATCGATCGATCGATCG"""
261
 
262
def process(sequence: str, mode: str, stride: int, layer: int):
    """Main processing function.

    Cleans/validates the input sequence, extracts embeddings, saves them to
    a temp .npy file, and builds the markdown summary plus all figures.

    Returns (summary_md, npy_path, heatmap_fig, trajectory_fig, stats_fig,
    dims_fig); on validation failure every non-summary slot is None.
    """
    sequence = strip_fasta_header(sequence.strip())

    is_valid, error = validate_sequence(sequence)
    if not is_valid:
        return f"**Error**: {error}", None, None, None, None, None

    embedding, window_embeddings, positions = embed_sequence(
        sequence, mode=mode, stride=stride, layer=layer
    )

    # FIX: a fixed "embedding.npy" in the shared temp dir is clobbered when
    # two requests run concurrently; mkstemp gives each request its own file.
    fd, path = tempfile.mkstemp(suffix=".npy")
    os.close(fd)
    np.save(path, embedding)

    # Compute stats
    if mode == "per-window":
        # For per-window, compute stats on mean embedding
        mean_emb = np.mean(embedding, axis=0)
        stats = compute_embedding_stats(mean_emb)
    else:
        stats = compute_embedding_stats(embedding)

    # Create summary
    if mode == "per-window":
        summary = f"""### Results

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| windows | {embedding.shape[0]} |
| shape | {embedding.shape} |

**Stats** (on mean): L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}
"""
    else:
        summary = f"""### Results

| | |
|---|---|
| sequence | {len(sequence):,} bp |
| layer | {layer} |
| mode | {mode} |
| dim | {len(embedding)} |

**Stats**: L2={stats['l2_norm']:.1f}, entropy={stats['entropy']:.2f}, sparsity={stats['sparsity']:.1%}
"""

    # Create visualizations
    heatmap_fig = None
    if mode != "per-window":
        heatmap_fig = create_embedding_heatmap(embedding, f"Layer {layer}")

    # Trajectory/dimension plots need at least two windows to be meaningful.
    trajectory_fig = create_trajectory_plot(window_embeddings, positions) if len(window_embeddings) > 1 else None
    stats_fig = create_stats_plot(stats)
    dims_fig = create_dimension_plot(window_embeddings, positions) if len(window_embeddings) > 1 else None

    return summary, path, heatmap_fig, trajectory_fig, stats_fig, dims_fig
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
  # Build interface
324
+ with gr.Blocks(
325
+ title="BERT Metagenome Embeddings",
326
+ css=".gradio-container { max-width: 100% !important; }"
327
+ ) as demo:
328
+ gr.Markdown("# bert-embedding\nExtract embeddings from DNA sequences. BERT (430M params) pretrained on metagenomes.")
 
329
 
330
  with gr.Tab("Extract"):
331
  with gr.Row():
332
+ with gr.Column(scale=1, min_width=300):
333
  seq_input = gr.Textbox(
334
  label="sequence",
335
+ placeholder="Paste DNA (FASTA or raw)...",
336
+ lines=5,
337
+ value=EXAMPLE_SEQUENCE
 
338
  )
339
  with gr.Row():
340
  mode_input = gr.Radio(
341
  choices=["mean", "max", "per-window"],
342
+ value="mean", label="pooling"
 
 
 
 
 
 
 
 
 
 
 
343
  )
344
  with gr.Row():
345
+ layer_input = gr.Slider(0, 23, value=21, step=1, label="layer")
346
+ stride_input = gr.Slider(50, 500, value=100, step=50, label="stride")
347
  btn = gr.Button("extract", variant="primary")
348
  output = gr.Markdown()
349
+ download = gr.File(label="download .npy")
350
 
351
+ with gr.Column(scale=1, min_width=300):
352
+ stats_plot = gr.Plot(label="embedding statistics")
353
  heatmap_plot = gr.Plot(label="embedding heatmap")
354
+
355
+ with gr.Column(scale=1, min_width=300):
356
  trajectory_plot = gr.Plot(label="window trajectory")
357
  dims_plot = gr.Plot(label="top varying dimensions")
358
 
359
  btn.click(
360
  process,
361
+ inputs=[seq_input, mode_input, stride_input, layer_input],
362
+ outputs=[output, download, heatmap_plot, trajectory_plot, stats_plot, dims_plot],
363
  api_name="embed"
364
  )
365
 
 
372
  import numpy as np
373
 
374
  client = Client("genomenet/bert-embedding")
 
375
  result = client.predict(
376
+ sequence="ATGC...", # min 1000 bp
377
+ mode="mean", # mean/max/per-window
378
  stride=100,
379
+ layer=21, # 0-23
 
 
380
  api_name="/embed"
381
  )
 
382
  summary, emb_path, *plots = result
383
  embedding = np.load(emb_path)
384
  ```
385
 
386
+ **Statistics**:
387
+ - **L2 Norm**: Magnitude of embedding. Higher = stronger model response.
388
+ - **Entropy**: Activation distribution spread. Lower = more structured/confident.
389
+ - **Sparsity**: Fraction of near-zero dims. Higher = sparser representation.
390
+ - **Kurtosis**: Peakedness. Higher = more concentrated activations.
 
391
 
392
+ These can serve as proxy "familiarity" scores - sequences similar to training data
393
+ tend to produce more structured embeddings (lower entropy, higher kurtosis).
 
 
 
394
  """)
395
 
396
  with gr.Tab("About"):
 
401
  |---|---|
402
  | architecture | BERT, 24 layers, 768 hidden, 12 heads |
403
  | parameters | ~430M |
404
+ | input | 1000 bp sliding window |
 
405
  | pretraining | metagenomic contigs + microbial genomes |
406
 
407
+ ### Interpreting Statistics
408
 
409
+ The embedding statistics provide indirect measures of how the model "responds" to a sequence:
 
 
410
 
411
+ - **L2 Norm**: Total activation magnitude. Very high or low may indicate unusual sequences.
412
+ - **Entropy**: How spread out the activations are. Lower entropy suggests more confident/structured representation.
413
+ - **Sparsity**: Fraction of dimensions with near-zero activation.
414
+ - **Kurtosis**: How peaked the distribution is. Higher values = more concentrated activations.
415
+
416
+ **Note**: These are not direct "familiarity" probabilities, but patterns in these metrics across
417
+ different sequence types may reveal what the model considers typical vs. unusual.
418
 
419
+ ### Links
420
  - Model: [genomenet/bert-metagenome](https://huggingface.co/genomenet/bert-metagenome)
421
+ - CRISPR: [genomenet/crispr-array-detection](https://huggingface.co/spaces/genomenet/crispr-array-detection)
422
  """)
423
 
424
  if __name__ == "__main__":