the-puzzler committed on
Commit
2c1ba2b
·
1 Parent(s): 1624b4f

Add MicrobeAtlas import and community builder UI

Browse files
Files changed (1) hide show
  1. app.py +504 -99
app.py CHANGED
@@ -1,24 +1,80 @@
 
1
  import os
2
  from dataclasses import dataclass
3
- from typing import List, Tuple
4
 
5
  import gradio as gr
6
  import numpy as np
7
  import plotly.express as px
8
  import torch
9
- import umap
10
  from Bio import SeqIO
11
  from transformers import AutoModel, AutoTokenizer
12
 
13
  from model import MicrobiomeTransformer
14
 
15
 
 
 
 
 
 
16
  MAX_GENES = 800
17
  MAX_SEQ_LEN = 1024
 
18
  PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
19
  CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
20
- BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
21
- TRUST_REMOTE_CODE = "true"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  @dataclass
@@ -29,7 +85,26 @@ class LoadedModels:
29
  device: torch.device
30
 
31
 
 
 
 
 
 
 
 
 
 
 
32
  _MODELS: LoadedModels | None = None
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def _load_models() -> LoadedModels:
@@ -38,7 +113,6 @@ def _load_models() -> LoadedModels:
38
  return _MODELS
39
 
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
-
42
  tokenizer = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
43
  prokbert = AutoModel.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
44
  prokbert.to(device)
@@ -69,28 +143,142 @@ def _load_models() -> LoadedModels:
69
  return _MODELS
70
 
71
 
72
- def _read_fasta(path: str) -> Tuple[List[str], List[str], int, int]:
73
- ids: List[str] = []
74
- seqs: List[str] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  truncated = 0
76
 
77
  for record in SeqIO.parse(path, "fasta"):
78
- seq = str(record.seq).upper()
79
- if len(seq) > MAX_SEQ_LEN:
80
- seq = seq[:MAX_SEQ_LEN]
81
- truncated += 1
82
- ids.append(record.id)
83
- seqs.append(seq)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- original_n = len(ids)
86
- if original_n == 0:
87
- raise ValueError("No FASTA records found.")
88
 
89
- if original_n > MAX_GENES:
90
- ids = ids[:MAX_GENES]
91
- seqs = seqs[:MAX_GENES]
 
 
92
 
93
- return ids, seqs, original_n, truncated
 
 
94
 
95
 
96
  def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
@@ -112,7 +300,7 @@ def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
112
  max_length=MAX_SEQ_LEN,
113
  padding=True,
114
  )
115
- inputs = {k: v.to(models.device) for k, v in inputs.items()}
116
 
117
  with torch.no_grad():
118
  outputs = models.prokbert(**inputs)
@@ -120,139 +308,356 @@ def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
120
 
121
  pooled_batches.append(pooled.detach().cpu().numpy())
122
 
123
- emb = np.vstack(pooled_batches)
124
- if emb.shape[1] != 384:
125
- raise ValueError(
126
- f"Expected 384-d ProkBERT embeddings, got {emb.shape[1]} dimensions from {PROKBERT_MODEL_ID}."
127
  )
128
- return emb
129
 
130
 
131
  def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
132
  x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
133
  n = x.shape[1]
134
-
135
  empty_text = torch.zeros((1, 0, 1536), dtype=torch.float32, device=models.device)
136
  mask = torch.ones((1, n), dtype=torch.bool, device=models.device)
137
- type_indicators = torch.zeros((1, n), dtype=torch.long, device=models.device)
138
-
139
- batch = {
140
- "embeddings_type1": x,
141
- "embeddings_type2": empty_text,
142
- "mask": mask,
143
- "type_indicators": type_indicators,
144
- }
145
 
146
  with torch.no_grad():
147
- x_proj = models.microbiome.input_projection_type1(batch["embeddings_type1"])
148
  final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=~mask)
149
  logits = models.microbiome.output_projection(final_hidden).squeeze(-1)
150
 
151
- return (
152
- logits.squeeze(0).detach().cpu().numpy(),
153
- final_hidden.squeeze(0).detach().cpu().numpy(),
154
- )
155
 
156
 
157
- def _umap_df(vectors: np.ndarray, labels: List[str], value_name: str):
158
- n = vectors.shape[0]
159
- if n < 2:
160
- raise ValueError("Need at least 2 genes to compute UMAP.")
161
 
162
  reducer = umap.UMAP(
163
  n_components=2,
164
- n_neighbors=min(15, n - 1),
165
  min_dist=0.1,
166
  metric="cosine",
167
  random_state=42,
168
  )
169
  coords = reducer.fit_transform(vectors)
170
- return {
171
- "x": coords[:, 0],
172
- "y": coords[:, 1],
173
- "gene": labels,
174
- value_name: np.linalg.norm(vectors, axis=1),
175
- }
176
-
177
 
178
- def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
179
- df = _umap_df(vectors, labels, "norm")
180
  fig = px.scatter(
181
- df,
182
- x="x",
183
- y="y",
184
- hover_name="gene",
185
- color="norm",
186
  title=title,
187
  color_continuous_scale="Viridis",
188
  )
189
- fig.update_traces(marker={"size": 9, "line": {"width": 0.5, "color": "black"}})
 
 
 
 
 
190
  return fig
191
 
192
 
193
- def _plot_logits(logits: np.ndarray, labels: List[str]):
194
  fig = px.histogram(
195
  x=logits,
196
- nbins=min(50, max(10, len(logits) // 4)),
197
  title="Logit Distribution Over Input DNA Embeddings",
 
 
 
 
 
 
 
 
198
  )
199
- fig.update_layout(xaxis_title="Logit", yaxis_title="Count")
200
  return fig
201
 
202
 
203
- def run_pipeline(fasta_file: str):
204
- if fasta_file is None:
205
- raise gr.Error("Upload a FASTA file first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  models = _load_models()
208
- labels, seqs, original_n, truncated = _read_fasta(fasta_file)
 
 
209
 
210
  input_embeddings = _embed_sequences(seqs, models)
211
  logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
212
 
213
- input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings (ProkBERT Mean-Pooled)")
214
- final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final Embeddings (After large-notext Transformer)")
215
- logits_hist = _plot_logits(logits, labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
- capped_n = len(labels)
218
- info = (
219
- f"Loaded {original_n} genes. "
220
- f"Used {capped_n} (cap={MAX_GENES}). "
221
- f"Truncated {truncated} sequence(s) to {MAX_SEQ_LEN} nt."
222
  )
 
 
223
 
224
- top_idx = np.argsort(logits)[::-1]
225
- top_rows = [[labels[i], float(logits[i])] for i in top_idx[: min(50, len(labels))]]
226
 
227
- return info, input_umap, final_umap, logits_hist, top_rows
228
 
 
 
 
 
 
 
229
 
230
- with gr.Blocks(title="Microbiome Space: ProkBERT -> large-notext") as demo:
231
- gr.Markdown(
232
- """
233
- # Microbiome Gene Scoring Explorer
234
- Upload a FASTA of genes, embed with `prokbert-mini-long` (mean pooling), score with `large-notext`, and inspect embedding geometry + logit distribution.
235
 
236
- Constraints:
237
- - Max genes per run: 800
238
- - Max gene length: 1024 nt (longer sequences are truncated)
239
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  )
241
 
242
- with gr.Row():
243
- fasta_in = gr.File(label="FASTA file", file_types=[".fa", ".fasta", ".fna", ".txt"], type="filepath")
244
- run_btn = gr.Button("Run", variant="primary")
245
 
246
- status = gr.Textbox(label="Run Summary")
247
- input_umap_plot = gr.Plot(label="Input Embedding UMAP")
248
- final_umap_plot = gr.Plot(label="Final Embedding UMAP")
249
- logits_plot = gr.Plot(label="Logit Distribution")
250
- top_table = gr.Dataframe(headers=["gene_id", "logit"], label="Top genes by logit")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- run_btn.click(
253
- fn=run_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  inputs=[fasta_in],
255
- outputs=[status, input_umap_plot, final_umap_plot, logits_plot, top_table],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  )
257
 
258
 
 
1
+ import csv
2
  import os
3
  from dataclasses import dataclass
4
+ from typing import Dict, List, Tuple
5
 
6
  import gradio as gr
7
  import numpy as np
8
  import plotly.express as px
9
  import torch
 
10
  from Bio import SeqIO
11
  from transformers import AutoModel, AutoTokenizer
12
 
13
  from model import MicrobiomeTransformer
14
 
15
 
16
+ os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba-cache")
17
+
18
+ import umap
19
+
20
+
21
  MAX_GENES = 800
22
  MAX_SEQ_LEN = 1024
23
+ BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
24
  PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
25
  CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
26
+ OTU_INFO_PATH = os.getenv("OTU_INFO_PATH", "otus.97.allinfo")
27
+ EXAMPLE_SAMPLE_PATH = "sample_DRS000421_DRR000770_taxa.tsv"
28
+ MICROBEATLAS_SAMPLE_URL = "https://microbeatlas.org/sample_detail?sid=DRS000421&rid=null"
29
+ TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
30
+
31
+ CSS = """
32
+ :root {
33
+ --bg: #f4f0e8;
34
+ --panel: rgba(255, 252, 247, 0.88);
35
+ --panel-strong: rgba(246, 240, 230, 0.96);
36
+ --ink: #1d2a1f;
37
+ --muted: #586454;
38
+ --accent: #0e7a5f;
39
+ --accent-2: #d8832f;
40
+ --line: rgba(29, 42, 31, 0.12);
41
+ }
42
+ .gradio-container {
43
+ background:
44
+ radial-gradient(circle at top left, rgba(216, 131, 47, 0.18), transparent 28%),
45
+ radial-gradient(circle at top right, rgba(14, 122, 95, 0.18), transparent 24%),
46
+ linear-gradient(180deg, #f7f2e9 0%, #eee6d8 100%);
47
+ color: var(--ink);
48
+ }
49
+ .hero {
50
+ padding: 28px;
51
+ border: 1px solid var(--line);
52
+ border-radius: 24px;
53
+ background: linear-gradient(135deg, rgba(255,255,255,0.85), rgba(241,232,218,0.92));
54
+ box-shadow: 0 18px 60px rgba(69, 57, 34, 0.08);
55
+ }
56
+ .hero h1 {
57
+ margin: 0 0 10px 0;
58
+ font-size: 2.4rem;
59
+ line-height: 1.05;
60
+ }
61
+ .hero p {
62
+ margin: 0;
63
+ max-width: 900px;
64
+ color: var(--muted);
65
+ font-size: 1rem;
66
+ }
67
+ .soft-card {
68
+ border: 1px solid var(--line);
69
+ border-radius: 22px;
70
+ background: var(--panel);
71
+ box-shadow: 0 12px 32px rgba(40, 36, 26, 0.06);
72
+ }
73
+ .section-note {
74
+ color: var(--muted);
75
+ font-size: 0.95rem;
76
+ }
77
+ """
78
 
79
 
80
  @dataclass
 
85
  device: torch.device
86
 
87
 
88
+ @dataclass
89
+ class OTUEntry:
90
+ otu_id: str
91
+ label: str
92
+ taxonomy: str
93
+ sequence: str
94
+ seq_len: int
95
+ search_text: str
96
+
97
+
98
  _MODELS: LoadedModels | None = None
99
+ _OTU_DB: Dict[str, OTUEntry] | None = None
100
+ _OTU_SEARCH: List[OTUEntry] | None = None
101
+
102
+
103
+ def _extract_taxa_name(taxonomy: str) -> str:
104
+ parts = [part.strip() for part in taxonomy.split(";") if part.strip()]
105
+ if not parts:
106
+ return "Unclassified"
107
+ return parts[-1].replace("g__", "").replace("s__", "").replace("f__", "")
108
 
109
 
110
  def _load_models() -> LoadedModels:
 
113
  return _MODELS
114
 
115
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
116
  tokenizer = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
117
  prokbert = AutoModel.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
118
  prokbert.to(device)
 
143
  return _MODELS
144
 
145
 
146
+ def _load_otu_db() -> Tuple[Dict[str, OTUEntry], List[OTUEntry]]:
147
+ global _OTU_DB, _OTU_SEARCH
148
+ if _OTU_DB is not None and _OTU_SEARCH is not None:
149
+ return _OTU_DB, _OTU_SEARCH
150
+
151
+ otu_db: Dict[str, OTUEntry] = {}
152
+ otu_search: List[OTUEntry] = []
153
+
154
+ with open(OTU_INFO_PATH, newline="") as handle:
155
+ reader = csv.reader(handle, delimiter="\t")
156
+ for row in reader:
157
+ if len(row) < 15:
158
+ continue
159
+
160
+ raw_id = row[0].strip()
161
+ sequence = row[6].strip().upper()
162
+ taxonomy = row[14].strip() or row[8].strip() or "Unclassified"
163
+
164
+ if not raw_id or not sequence:
165
+ continue
166
+
167
+ otu_id = raw_id.split(";")[-1]
168
+ label = _extract_taxa_name(taxonomy)
169
+ entry = OTUEntry(
170
+ otu_id=otu_id,
171
+ label=label,
172
+ taxonomy=taxonomy,
173
+ sequence=sequence,
174
+ seq_len=len(sequence),
175
+ search_text=f"{otu_id} {label} {taxonomy}".lower(),
176
+ )
177
+ otu_db[otu_id] = entry
178
+ otu_search.append(entry)
179
+
180
+ _OTU_DB = otu_db
181
+ _OTU_SEARCH = otu_search
182
+ return otu_db, otu_search
183
+
184
+
185
+ def _trim_sequence(sequence: str) -> Tuple[str, bool]:
186
+ if len(sequence) > MAX_SEQ_LEN:
187
+ return sequence[:MAX_SEQ_LEN], True
188
+ return sequence, False
189
+
190
+
191
+ def _read_fasta(path: str) -> Tuple[List[dict], int, int]:
192
+ records: List[dict] = []
193
  truncated = 0
194
 
195
  for record in SeqIO.parse(path, "fasta"):
196
+ seq, was_truncated = _trim_sequence(str(record.seq).upper())
197
+ truncated += int(was_truncated)
198
+ records.append(
199
+ {
200
+ "id": record.id,
201
+ "sequence": seq,
202
+ "source": "FASTA",
203
+ "taxonomy": "",
204
+ "detail": f"{len(seq)} nt",
205
+ }
206
+ )
207
+
208
+ if not records:
209
+ raise gr.Error("No FASTA records found.")
210
+
211
+ return records[:MAX_GENES], len(records), truncated
212
+
213
+
214
+ def _read_microbeatlas_sample(path: str) -> Tuple[List[dict], str]:
215
+ otu_db, _ = _load_otu_db()
216
+ records: List[dict] = []
217
+ missing_ids: List[str] = []
218
+
219
+ with open(path, newline="") as handle:
220
+ reader = csv.reader(handle, delimiter="\t")
221
+ header = next(reader, None)
222
+ if header is None:
223
+ raise gr.Error("The MicrobeAtlas file is empty.")
224
+
225
+ columns = [col.strip() for col in header]
226
+ column_index = {name: idx for idx, name in enumerate(columns)}
227
+ if "SHORT_TID" not in column_index:
228
+ raise gr.Error("Expected a MicrobeAtlas taxa file with a SHORT_TID column.")
229
+
230
+ for row in reader:
231
+ if not row:
232
+ continue
233
+ otu_id = row[column_index["SHORT_TID"]].strip()
234
+ if not otu_id:
235
+ continue
236
+ entry = otu_db.get(otu_id)
237
+ if entry is None:
238
+ missing_ids.append(otu_id)
239
+ continue
240
+
241
+ seq, was_truncated = _trim_sequence(entry.sequence)
242
+ detail_bits = []
243
+ for column in ("COUNT", "ABUNDANCE"):
244
+ idx = column_index.get(column)
245
+ if idx is not None and idx < len(row):
246
+ value = row[idx].strip()
247
+ if value:
248
+ detail_bits.append(f"{column.lower()}={value}")
249
+ if was_truncated:
250
+ detail_bits.append("trimmed")
251
+
252
+ records.append(
253
+ {
254
+ "id": otu_id,
255
+ "sequence": seq,
256
+ "source": "MicrobeAtlas",
257
+ "taxonomy": entry.taxonomy,
258
+ "detail": ", ".join(detail_bits) if detail_bits else f"{entry.seq_len} nt",
259
+ }
260
+ )
261
+
262
+ if not records:
263
+ raise gr.Error("No OTU IDs from this MicrobeAtlas file matched otus.97.allinfo.")
264
+
265
+ used_records = records[:MAX_GENES]
266
+ summary = (
267
+ f"Translated {len(used_records)} OTUs from the MicrobeAtlas upload. "
268
+ f"Missing sequence mappings for {len(missing_ids)} OTUs."
269
+ )
270
+ return used_records, summary
271
 
 
 
 
272
 
273
+ def _search_otu_records(query: str, limit: int = 80) -> List[OTUEntry]:
274
+ _, otu_search = _load_otu_db()
275
+ needle = query.strip().lower()
276
+ if not needle:
277
+ return []
278
 
279
+ matches = [entry for entry in otu_search if needle in entry.search_text]
280
+ matches.sort(key=lambda entry: (entry.label.lower(), entry.otu_id))
281
+ return matches[:limit]
282
 
283
 
284
  def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
 
300
  max_length=MAX_SEQ_LEN,
301
  padding=True,
302
  )
303
+ inputs = {key: value.to(models.device) for key, value in inputs.items()}
304
 
305
  with torch.no_grad():
306
  outputs = models.prokbert(**inputs)
 
308
 
309
  pooled_batches.append(pooled.detach().cpu().numpy())
310
 
311
+ embeddings = np.vstack(pooled_batches)
312
+ if embeddings.shape[1] != 384:
313
+ raise gr.Error(
314
+ f"Expected 384-d ProkBERT embeddings, got {embeddings.shape[1]} from {PROKBERT_MODEL_ID}."
315
  )
316
+ return embeddings
317
 
318
 
319
  def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
320
  x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
321
  n = x.shape[1]
 
322
  empty_text = torch.zeros((1, 0, 1536), dtype=torch.float32, device=models.device)
323
  mask = torch.ones((1, n), dtype=torch.bool, device=models.device)
 
 
 
 
 
 
 
 
324
 
325
  with torch.no_grad():
326
+ x_proj = models.microbiome.input_projection_type1(x)
327
  final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=~mask)
328
  logits = models.microbiome.output_projection(final_hidden).squeeze(-1)
329
 
330
+ return logits.squeeze(0).detach().cpu().numpy(), final_hidden.squeeze(0).detach().cpu().numpy()
 
 
 
331
 
332
 
333
+ def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
334
+ if len(vectors) < 2:
335
+ raise gr.Error("UMAP needs at least 2 sequences.")
 
336
 
337
  reducer = umap.UMAP(
338
  n_components=2,
339
+ n_neighbors=min(15, len(vectors) - 1),
340
  min_dist=0.1,
341
  metric="cosine",
342
  random_state=42,
343
  )
344
  coords = reducer.fit_transform(vectors)
345
+ norms = np.linalg.norm(vectors, axis=1)
 
 
 
 
 
 
346
 
 
 
347
  fig = px.scatter(
348
+ x=coords[:, 0],
349
+ y=coords[:, 1],
350
+ color=norms,
351
+ hover_name=labels,
352
+ labels={"x": "UMAP 1", "y": "UMAP 2", "color": "vector norm"},
353
  title=title,
354
  color_continuous_scale="Viridis",
355
  )
356
+ fig.update_traces(marker={"size": 10, "line": {"width": 0.6, "color": "#1d2a1f"}, "opacity": 0.9})
357
+ fig.update_layout(
358
+ paper_bgcolor="rgba(255,255,255,0)",
359
+ plot_bgcolor="rgba(255,255,255,0.75)",
360
+ margin={"l": 10, "r": 10, "t": 60, "b": 10},
361
+ )
362
  return fig
363
 
364
 
365
+ def _plot_logits(logits: np.ndarray):
366
  fig = px.histogram(
367
  x=logits,
368
+ nbins=min(50, max(12, len(logits) // 4)),
369
  title="Logit Distribution Over Input DNA Embeddings",
370
+ color_discrete_sequence=["#d8832f"],
371
+ )
372
+ fig.update_layout(
373
+ xaxis_title="Logit",
374
+ yaxis_title="Count",
375
+ paper_bgcolor="rgba(255,255,255,0)",
376
+ plot_bgcolor="rgba(255,255,255,0.75)",
377
+ margin={"l": 10, "r": 10, "t": 60, "b": 10},
378
  )
 
379
  return fig
380
 
381
 
382
+ def _records_to_member_table(records: List[dict]) -> List[List[object]]:
383
+ rows: List[List[object]] = []
384
+ for record in records:
385
+ rows.append(
386
+ [
387
+ record["id"],
388
+ record.get("source", ""),
389
+ record.get("taxonomy", ""),
390
+ record.get("detail", ""),
391
+ len(record["sequence"]),
392
+ ]
393
+ )
394
+ return rows
395
+
396
+
397
+ def _analyze_records(records: List[dict], source_title: str, extra_summary: str = ""):
398
+ if len(records) < 2:
399
+ raise gr.Error("This explorer needs at least 2 sequences to compute the UMAP views.")
400
 
401
  models = _load_models()
402
+ used_records = records[:MAX_GENES]
403
+ labels = [record["id"] for record in used_records]
404
+ seqs = [record["sequence"] for record in used_records]
405
 
406
  input_embeddings = _embed_sequences(seqs, models)
407
  logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
408
 
409
+ input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings")
410
+ final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final Transformer Embeddings")
411
+ logits_hist = _plot_logits(logits)
412
+
413
+ rows = []
414
+ order = np.argsort(logits)[::-1]
415
+ for idx in order:
416
+ record = used_records[idx]
417
+ rows.append(
418
+ [
419
+ record["id"],
420
+ float(logits[idx]),
421
+ record.get("source", ""),
422
+ record.get("taxonomy", ""),
423
+ record.get("detail", ""),
424
+ ]
425
+ )
426
 
427
+ summary = (
428
+ f"{source_title}: analyzed {len(used_records)} sequences "
429
+ f"(cap={MAX_GENES}, trim={MAX_SEQ_LEN} nt)."
 
 
430
  )
431
+ if extra_summary:
432
+ summary = f"{summary} {extra_summary}"
433
 
434
+ members = _records_to_member_table(used_records)
435
+ return summary, input_umap, final_umap, logits_hist, rows[:50], members
436
 
 
437
 
438
+ def analyze_fasta(fasta_file: str):
439
+ if fasta_file is None:
440
+ raise gr.Error("Upload a FASTA file first.")
441
+ records, original_n, truncated = _read_fasta(fasta_file)
442
+ extra = f"Loaded {original_n} records and truncated {truncated} sequence(s)."
443
+ return _analyze_records(records, "Raw FASTA upload", extra)
444
 
 
 
 
 
 
445
 
446
+ def analyze_microbeatlas(sample_file: str):
447
+ if sample_file is None:
448
+ raise gr.Error("Upload a MicrobeAtlas taxa TSV first.")
449
+ records, translation_summary = _read_microbeatlas_sample(sample_file)
450
+ return _analyze_records(records, "MicrobeAtlas import", translation_summary)
451
+
452
+
453
+ def search_taxa(query: str):
454
+ matches = _search_otu_records(query)
455
+ if not matches:
456
+ return (
457
+ gr.update(choices=[], value=[]),
458
+ [],
459
+ "No OTUs matched that taxon query.",
460
+ )
461
+
462
+ choices = [(f"{entry.label} | {entry.otu_id}", entry.otu_id) for entry in matches]
463
+ preview = [[entry.otu_id, entry.label, entry.taxonomy, entry.seq_len] for entry in matches]
464
+ return (
465
+ gr.update(choices=choices, value=[]),
466
+ preview,
467
+ f"Found {len(matches)} matching OTUs. Select the ones you want to add to the community.",
468
  )
469
 
 
 
 
470
 
471
+ def add_to_community(selected_otu_ids: List[str], community_ids: List[str]):
472
+ otu_db, _ = _load_otu_db()
473
+ current = list(community_ids or [])
474
+ added = 0
475
+
476
+ for otu_id in selected_otu_ids or []:
477
+ if otu_id in current:
478
+ continue
479
+ if len(current) >= MAX_GENES:
480
+ break
481
+ if otu_id in otu_db:
482
+ current.append(otu_id)
483
+ added += 1
484
+
485
+ records = [
486
+ {
487
+ "id": otu_db[otu_id].otu_id,
488
+ "sequence": otu_db[otu_id].sequence[:MAX_SEQ_LEN],
489
+ "source": "Community builder",
490
+ "taxonomy": otu_db[otu_id].taxonomy,
491
+ "detail": otu_db[otu_id].label,
492
+ }
493
+ for otu_id in current
494
+ if otu_id in otu_db
495
+ ]
496
+ status = f"Community now contains {len(records)} OTUs. Added {added} new member(s)."
497
+ return current, _records_to_member_table(records), status
498
+
499
+
500
+ def clear_community():
501
+ return [], [], "Community cleared."
502
+
503
+
504
+ def analyze_community(community_ids: List[str]):
505
+ otu_db, _ = _load_otu_db()
506
+ if not community_ids:
507
+ raise gr.Error("Build a community first by searching taxa and adding OTUs.")
508
+
509
+ records = []
510
+ for otu_id in community_ids[:MAX_GENES]:
511
+ entry = otu_db.get(otu_id)
512
+ if entry is None:
513
+ continue
514
+ records.append(
515
+ {
516
+ "id": entry.otu_id,
517
+ "sequence": entry.sequence[:MAX_SEQ_LEN],
518
+ "source": "Community builder",
519
+ "taxonomy": entry.taxonomy,
520
+ "detail": entry.label,
521
+ }
522
+ )
523
+
524
+ if not records:
525
+ raise gr.Error("No valid OTU members remain in the current community.")
526
+
527
+ return _analyze_records(records, "Community builder", "Selected by taxon search against otus.97.allinfo.")
528
+
529
+
530
+ with gr.Blocks(title="Microbiome Explorer", css=CSS, theme=gr.themes.Soft()) as demo:
531
+ community_state = gr.State([])
532
+
533
+ gr.HTML(
534
+ """
535
+ <section class="hero">
536
+ <h1>Microbiome Gene Scoring Explorer</h1>
537
+ <p>
538
+ Upload raw FASTA, translate a MicrobeAtlas sample into representative OTU sequences,
539
+ or build a synthetic community by taxonomy. Every route ends in the same pipeline:
540
+ ProkBERT mean pooling, <code>large-notext</code> scoring, and linked embedding views.
541
+ </p>
542
+ </section>
543
+ """
544
+ )
545
 
546
+ with gr.Tabs():
547
+ with gr.Tab("Raw FASTA"):
548
+ with gr.Column(elem_classes=["soft-card"]):
549
+ gr.Markdown(
550
+ "Upload genes directly in FASTA format. Sequences longer than 1024 nt are trimmed and only the first 800 records are used."
551
+ )
552
+ fasta_in = gr.File(
553
+ label="FASTA file",
554
+ file_types=[".fa", ".fasta", ".fna", ".txt"],
555
+ type="filepath",
556
+ )
557
+ fasta_run_btn = gr.Button("Analyze FASTA", variant="primary")
558
+
559
+ with gr.Tab("Import From MicrobeAtlas"):
560
+ with gr.Column(elem_classes=["soft-card"]):
561
+ gr.Markdown(
562
+ f"""
563
+ Bring in a taxa file exported from MicrobeAtlas. Go to
564
+ [MicrobeAtlas sample detail]({MICROBEATLAS_SAMPLE_URL}), click `Download`, and upload the taxa TSV here.
565
+ OTU IDs from `SHORT_TID` are translated to representative sequences using `otus.97.allinfo`.
566
+ """
567
+ )
568
+ microbeatlas_in = gr.File(
569
+ label="MicrobeAtlas taxa TSV",
570
+ file_types=[".tsv", ".txt"],
571
+ type="filepath",
572
+ )
573
+ gr.Examples(
574
+ examples=[[EXAMPLE_SAMPLE_PATH]],
575
+ inputs=[microbeatlas_in],
576
+ label="Use example",
577
+ )
578
+ microbeatlas_run_btn = gr.Button("Translate And Analyze", variant="primary")
579
+
580
+ with gr.Tab("Build A Community"):
581
+ with gr.Column(elem_classes=["soft-card"]):
582
+ gr.Markdown(
583
+ "Search `otus.97.allinfo` by OTU ID, taxon label, or taxonomy string. Add matching OTUs to a custom community, then score the assembled set."
584
+ )
585
+ with gr.Row():
586
+ taxa_query = gr.Textbox(
587
+ label="Search taxa",
588
+ placeholder="Try Nitrospira, Lysobacter, Gammaproteobacteria, 97_8697 ...",
589
+ scale=5,
590
+ )
591
+ taxa_search_btn = gr.Button("Search", variant="secondary", scale=1)
592
+
593
+ community_search_status = gr.Markdown(elem_classes=["section-note"])
594
+ taxa_matches = gr.CheckboxGroup(label="Matching OTUs")
595
+ taxa_matches_preview = gr.Dataframe(
596
+ headers=["otu_id", "label", "taxonomy", "seq_len"],
597
+ label="Match preview",
598
+ wrap=True,
599
+ )
600
+
601
+ with gr.Row():
602
+ community_add_btn = gr.Button("Add Selected OTUs", variant="primary")
603
+ community_clear_btn = gr.Button("Clear Community")
604
+ community_run_btn = gr.Button("Analyze Community", variant="secondary")
605
+
606
+ with gr.Accordion("Community Members", open=False):
607
+ community_table = gr.Dataframe(
608
+ headers=["id", "source", "taxonomy", "detail", "seq_len"],
609
+ label="Current community",
610
+ wrap=True,
611
+ )
612
+ community_status = gr.Markdown(elem_classes=["section-note"])
613
+
614
+ with gr.Accordion("Analysis Results", open=True):
615
+ run_summary = gr.Textbox(label="Run summary")
616
+ with gr.Row():
617
+ input_umap_plot = gr.Plot(label="Input embedding UMAP")
618
+ final_umap_plot = gr.Plot(label="Final embedding UMAP")
619
+ logits_plot = gr.Plot(label="Logit distribution")
620
+ with gr.Accordion("Top-scoring members", open=False):
621
+ top_table = gr.Dataframe(
622
+ headers=["id", "logit", "source", "taxonomy", "detail"],
623
+ label="Top genes by logit",
624
+ wrap=True,
625
+ )
626
+ with gr.Accordion("Analyzed members", open=False):
627
+ member_table = gr.Dataframe(
628
+ headers=["id", "source", "taxonomy", "detail", "seq_len"],
629
+ label="Members used in the run",
630
+ wrap=True,
631
+ )
632
+
633
+ fasta_run_btn.click(
634
+ fn=analyze_fasta,
635
  inputs=[fasta_in],
636
+ outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
637
+ )
638
+ microbeatlas_run_btn.click(
639
+ fn=analyze_microbeatlas,
640
+ inputs=[microbeatlas_in],
641
+ outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
642
+ )
643
+ taxa_search_btn.click(
644
+ fn=search_taxa,
645
+ inputs=[taxa_query],
646
+ outputs=[taxa_matches, taxa_matches_preview, community_search_status],
647
+ )
648
+ community_add_btn.click(
649
+ fn=add_to_community,
650
+ inputs=[taxa_matches, community_state],
651
+ outputs=[community_state, community_table, community_status],
652
+ )
653
+ community_clear_btn.click(
654
+ fn=clear_community,
655
+ outputs=[community_state, community_table, community_status],
656
+ )
657
+ community_run_btn.click(
658
+ fn=analyze_community,
659
+ inputs=[community_state],
660
+ outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
661
  )
662
 
663