Spaces:
Sleeping
Sleeping
the-puzzler commited on
Commit ·
2c1ba2b
1
Parent(s): 1624b4f
Add MicrobeAtlas import and community builder UI
Browse files
app.py
CHANGED
|
@@ -1,24 +1,80 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
from dataclasses import dataclass
|
| 3 |
-
from typing import List, Tuple
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
import plotly.express as px
|
| 8 |
import torch
|
| 9 |
-
import umap
|
| 10 |
from Bio import SeqIO
|
| 11 |
from transformers import AutoModel, AutoTokenizer
|
| 12 |
|
| 13 |
from model import MicrobiomeTransformer
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
MAX_GENES = 800
|
| 17 |
MAX_SEQ_LEN = 1024
|
|
|
|
| 18 |
PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
|
| 19 |
CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass
|
|
@@ -29,7 +85,26 @@ class LoadedModels:
|
|
| 29 |
device: torch.device
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
_MODELS: LoadedModels | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def _load_models() -> LoadedModels:
|
|
@@ -38,7 +113,6 @@ def _load_models() -> LoadedModels:
|
|
| 38 |
return _MODELS
|
| 39 |
|
| 40 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 41 |
-
|
| 42 |
tokenizer = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
|
| 43 |
prokbert = AutoModel.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
|
| 44 |
prokbert.to(device)
|
|
@@ -69,28 +143,142 @@ def _load_models() -> LoadedModels:
|
|
| 69 |
return _MODELS
|
| 70 |
|
| 71 |
|
| 72 |
-
def
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
truncated = 0
|
| 76 |
|
| 77 |
for record in SeqIO.parse(path, "fasta"):
|
| 78 |
-
seq = str(record.seq).upper()
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
original_n = len(ids)
|
| 86 |
-
if original_n == 0:
|
| 87 |
-
raise ValueError("No FASTA records found.")
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
@@ -112,7 +300,7 @@ def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
|
|
| 112 |
max_length=MAX_SEQ_LEN,
|
| 113 |
padding=True,
|
| 114 |
)
|
| 115 |
-
inputs = {
|
| 116 |
|
| 117 |
with torch.no_grad():
|
| 118 |
outputs = models.prokbert(**inputs)
|
|
@@ -120,139 +308,356 @@ def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
|
|
| 120 |
|
| 121 |
pooled_batches.append(pooled.detach().cpu().numpy())
|
| 122 |
|
| 123 |
-
|
| 124 |
-
if
|
| 125 |
-
raise
|
| 126 |
-
f"Expected 384-d ProkBERT embeddings, got {
|
| 127 |
)
|
| 128 |
-
return
|
| 129 |
|
| 130 |
|
| 131 |
def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
|
| 132 |
x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
|
| 133 |
n = x.shape[1]
|
| 134 |
-
|
| 135 |
empty_text = torch.zeros((1, 0, 1536), dtype=torch.float32, device=models.device)
|
| 136 |
mask = torch.ones((1, n), dtype=torch.bool, device=models.device)
|
| 137 |
-
type_indicators = torch.zeros((1, n), dtype=torch.long, device=models.device)
|
| 138 |
-
|
| 139 |
-
batch = {
|
| 140 |
-
"embeddings_type1": x,
|
| 141 |
-
"embeddings_type2": empty_text,
|
| 142 |
-
"mask": mask,
|
| 143 |
-
"type_indicators": type_indicators,
|
| 144 |
-
}
|
| 145 |
|
| 146 |
with torch.no_grad():
|
| 147 |
-
x_proj = models.microbiome.input_projection_type1(
|
| 148 |
final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=~mask)
|
| 149 |
logits = models.microbiome.output_projection(final_hidden).squeeze(-1)
|
| 150 |
|
| 151 |
-
return (
|
| 152 |
-
logits.squeeze(0).detach().cpu().numpy(),
|
| 153 |
-
final_hidden.squeeze(0).detach().cpu().numpy(),
|
| 154 |
-
)
|
| 155 |
|
| 156 |
|
| 157 |
-
def
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
raise ValueError("Need at least 2 genes to compute UMAP.")
|
| 161 |
|
| 162 |
reducer = umap.UMAP(
|
| 163 |
n_components=2,
|
| 164 |
-
n_neighbors=min(15,
|
| 165 |
min_dist=0.1,
|
| 166 |
metric="cosine",
|
| 167 |
random_state=42,
|
| 168 |
)
|
| 169 |
coords = reducer.fit_transform(vectors)
|
| 170 |
-
|
| 171 |
-
"x": coords[:, 0],
|
| 172 |
-
"y": coords[:, 1],
|
| 173 |
-
"gene": labels,
|
| 174 |
-
value_name: np.linalg.norm(vectors, axis=1),
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
|
| 178 |
-
def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
|
| 179 |
-
df = _umap_df(vectors, labels, "norm")
|
| 180 |
fig = px.scatter(
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
hover_name=
|
| 185 |
-
|
| 186 |
title=title,
|
| 187 |
color_continuous_scale="Viridis",
|
| 188 |
)
|
| 189 |
-
fig.update_traces(marker={"size":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
return fig
|
| 191 |
|
| 192 |
|
| 193 |
-
def _plot_logits(logits: np.ndarray
|
| 194 |
fig = px.histogram(
|
| 195 |
x=logits,
|
| 196 |
-
nbins=min(50, max(
|
| 197 |
title="Logit Distribution Over Input DNA Embeddings",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
)
|
| 199 |
-
fig.update_layout(xaxis_title="Logit", yaxis_title="Count")
|
| 200 |
return fig
|
| 201 |
|
| 202 |
|
| 203 |
-
def
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
models = _load_models()
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
|
| 210 |
input_embeddings = _embed_sequences(seqs, models)
|
| 211 |
logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)
|
| 212 |
|
| 213 |
-
input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings
|
| 214 |
-
final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final
|
| 215 |
-
logits_hist = _plot_logits(logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
f"
|
| 220 |
-
f"Used {capped_n} (cap={MAX_GENES}). "
|
| 221 |
-
f"Truncated {truncated} sequence(s) to {MAX_SEQ_LEN} nt."
|
| 222 |
)
|
|
|
|
|
|
|
| 223 |
|
| 224 |
-
|
| 225 |
-
|
| 226 |
|
| 227 |
-
return info, input_umap, final_umap, logits_hist, top_rows
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
with gr.Blocks(title="Microbiome Space: ProkBERT -> large-notext") as demo:
|
| 231 |
-
gr.Markdown(
|
| 232 |
-
"""
|
| 233 |
-
# Microbiome Gene Scoring Explorer
|
| 234 |
-
Upload a FASTA of genes, embed with `prokbert-mini-long` (mean pooling), score with `large-notext`, and inspect embedding geometry + logit distribution.
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
)
|
| 241 |
|
| 242 |
-
with gr.Row():
|
| 243 |
-
fasta_in = gr.File(label="FASTA file", file_types=[".fa", ".fasta", ".fna", ".txt"], type="filepath")
|
| 244 |
-
run_btn = gr.Button("Run", variant="primary")
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
inputs=[fasta_in],
|
| 255 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
)
|
| 257 |
|
| 258 |
|
|
|
|
| 1 |
+
import csv
|
| 2 |
import os
|
| 3 |
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Tuple
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import numpy as np
|
| 8 |
import plotly.express as px
|
| 9 |
import torch
|
|
|
|
| 10 |
from Bio import SeqIO
|
| 11 |
from transformers import AutoModel, AutoTokenizer
|
| 12 |
|
| 13 |
from model import MicrobiomeTransformer
|
| 14 |
|
| 15 |
|
| 16 |
+
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba-cache")
|
| 17 |
+
|
| 18 |
+
import umap
|
| 19 |
+
|
| 20 |
+
|
| 21 |
# Upper bound on how many sequences a single analysis run uses.
MAX_GENES = 800
# Sequences longer than this are trimmed by _trim_sequence; the same value is
# passed as max_length inside _embed_sequences.
MAX_SEQ_LEN = 1024
# Read from EMBED_BATCH_SIZE (default "32"); its consumer is outside this view.
BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
# Model identifier handed to AutoTokenizer / AutoModel in _load_models.
PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
# Checkpoint path; presumably the MicrobiomeTransformer weights (loading code is not visible here).
CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
# Tab-separated OTU reference table parsed by _load_otu_db.
OTU_INFO_PATH = os.getenv("OTU_INFO_PATH", "otus.97.allinfo")
# Bundled taxa TSV offered as the example input on the MicrobeAtlas tab.
EXAMPLE_SAMPLE_PATH = "sample_DRS000421_DRR000770_taxa.tsv"
# Link shown in the MicrobeAtlas tab instructions.
MICROBEATLAS_SAMPLE_URL = "https://microbeatlas.org/sample_detail?sid=DRS000421&rid=null"
# Only the exact (case-insensitive) value "true" enables trust_remote_code.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
|
| 30 |
+
|
| 31 |
+
CSS = """
|
| 32 |
+
:root {
|
| 33 |
+
--bg: #f4f0e8;
|
| 34 |
+
--panel: rgba(255, 252, 247, 0.88);
|
| 35 |
+
--panel-strong: rgba(246, 240, 230, 0.96);
|
| 36 |
+
--ink: #1d2a1f;
|
| 37 |
+
--muted: #586454;
|
| 38 |
+
--accent: #0e7a5f;
|
| 39 |
+
--accent-2: #d8832f;
|
| 40 |
+
--line: rgba(29, 42, 31, 0.12);
|
| 41 |
+
}
|
| 42 |
+
.gradio-container {
|
| 43 |
+
background:
|
| 44 |
+
radial-gradient(circle at top left, rgba(216, 131, 47, 0.18), transparent 28%),
|
| 45 |
+
radial-gradient(circle at top right, rgba(14, 122, 95, 0.18), transparent 24%),
|
| 46 |
+
linear-gradient(180deg, #f7f2e9 0%, #eee6d8 100%);
|
| 47 |
+
color: var(--ink);
|
| 48 |
+
}
|
| 49 |
+
.hero {
|
| 50 |
+
padding: 28px;
|
| 51 |
+
border: 1px solid var(--line);
|
| 52 |
+
border-radius: 24px;
|
| 53 |
+
background: linear-gradient(135deg, rgba(255,255,255,0.85), rgba(241,232,218,0.92));
|
| 54 |
+
box-shadow: 0 18px 60px rgba(69, 57, 34, 0.08);
|
| 55 |
+
}
|
| 56 |
+
.hero h1 {
|
| 57 |
+
margin: 0 0 10px 0;
|
| 58 |
+
font-size: 2.4rem;
|
| 59 |
+
line-height: 1.05;
|
| 60 |
+
}
|
| 61 |
+
.hero p {
|
| 62 |
+
margin: 0;
|
| 63 |
+
max-width: 900px;
|
| 64 |
+
color: var(--muted);
|
| 65 |
+
font-size: 1rem;
|
| 66 |
+
}
|
| 67 |
+
.soft-card {
|
| 68 |
+
border: 1px solid var(--line);
|
| 69 |
+
border-radius: 22px;
|
| 70 |
+
background: var(--panel);
|
| 71 |
+
box-shadow: 0 12px 32px rgba(40, 36, 26, 0.06);
|
| 72 |
+
}
|
| 73 |
+
.section-note {
|
| 74 |
+
color: var(--muted);
|
| 75 |
+
font-size: 0.95rem;
|
| 76 |
+
}
|
| 77 |
+
"""
|
| 78 |
|
| 79 |
|
| 80 |
@dataclass
|
|
|
|
| 85 |
device: torch.device
|
| 86 |
|
| 87 |
|
| 88 |
+
@dataclass
class OTUEntry:
    """One row of the OTU reference table, as built by `_load_otu_db`."""

    otu_id: str  # last ";"-separated piece of the raw ID column
    label: str  # short display name derived from the taxonomy string
    taxonomy: str  # taxonomy string, or "Unclassified" when the source columns are empty
    sequence: str  # upper-cased sequence, stored untrimmed
    seq_len: int  # len(sequence) at load time
    search_text: str  # lower-cased "otu_id label taxonomy", used for substring search
|
| 96 |
+
|
| 97 |
+
|
| 98 |
# Module-level caches. Each starts as None; _load_models returns _MODELS and
# _load_otu_db fills the two OTU caches on its first call and reuses them after.
_MODELS: LoadedModels | None = None
_OTU_DB: Dict[str, OTUEntry] | None = None
_OTU_SEARCH: List[OTUEntry] | None = None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _extract_taxa_name(taxonomy: str) -> str:
|
| 104 |
+
parts = [part.strip() for part in taxonomy.split(";") if part.strip()]
|
| 105 |
+
if not parts:
|
| 106 |
+
return "Unclassified"
|
| 107 |
+
return parts[-1].replace("g__", "").replace("s__", "").replace("f__", "")
|
| 108 |
|
| 109 |
|
| 110 |
def _load_models() -> LoadedModels:
|
|
|
|
| 113 |
return _MODELS
|
| 114 |
|
| 115 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 116 |
tokenizer = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
|
| 117 |
prokbert = AutoModel.from_pretrained(PROKBERT_MODEL_ID, trust_remote_code=TRUST_REMOTE_CODE)
|
| 118 |
prokbert.to(device)
|
|
|
|
| 143 |
return _MODELS
|
| 144 |
|
| 145 |
|
| 146 |
+
def _load_otu_db() -> Tuple[Dict[str, OTUEntry], List[OTUEntry]]:
    """Load the OTU reference table once and cache it at module level.

    Reads the tab-separated file at ``OTU_INFO_PATH``. Rows with fewer than 15
    columns, or with an empty ID (column 0) or empty sequence (column 6), are
    skipped. Taxonomy is taken from column 14, falling back to column 8 and
    then to "Unclassified". The OTU ID is the last ";"-separated piece of the
    raw ID column.

    Returns:
        ``(otu_db, otu_search)``: a dict keyed by OTU ID and a list holding the
        same entries for linear substring search.

    Raises:
        gr.Error: If the reference file cannot be opened.
    """
    global _OTU_DB, _OTU_SEARCH
    if _OTU_DB is not None and _OTU_SEARCH is not None:
        return _OTU_DB, _OTU_SEARCH

    otu_db: Dict[str, OTUEntry] = {}

    try:
        handle = open(OTU_INFO_PATH, newline="", encoding="utf-8")
    except OSError as exc:
        # Surface a readable message in the UI instead of a raw traceback.
        raise gr.Error(f"Could not open the OTU reference file {OTU_INFO_PATH!r}: {exc}") from exc

    with handle:
        for row in csv.reader(handle, delimiter="\t"):
            if len(row) < 15:
                continue

            raw_id = row[0].strip()
            sequence = row[6].strip().upper()
            taxonomy = row[14].strip() or row[8].strip() or "Unclassified"

            if not raw_id or not sequence:
                continue

            otu_id = raw_id.split(";")[-1]
            label = _extract_taxa_name(taxonomy)
            otu_db[otu_id] = OTUEntry(
                otu_id=otu_id,
                label=label,
                taxonomy=taxonomy,
                sequence=sequence,
                seq_len=len(sequence),
                search_text=f"{otu_id} {label} {taxonomy}".lower(),
            )

    # Build the search list from the dict so a repeated OTU ID contributes a
    # single entry, matching what a lookup by ID would return.
    otu_search: List[OTUEntry] = list(otu_db.values())

    _OTU_DB = otu_db
    _OTU_SEARCH = otu_search
    return otu_db, otu_search
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _trim_sequence(sequence: str) -> Tuple[str, bool]:
    """Cap *sequence* at ``MAX_SEQ_LEN`` characters.

    Returns:
        The (possibly shortened) sequence and a flag that is True only when
        characters were actually dropped.
    """
    needs_trim = len(sequence) > MAX_SEQ_LEN
    kept = sequence[:MAX_SEQ_LEN] if needs_trim else sequence
    return kept, needs_trim
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _read_fasta(path: str) -> Tuple[List[dict], int, int]:
    """Parse a FASTA file into at most ``MAX_GENES`` sequence records.

    The whole file is scanned so the totals are accurate, but records beyond
    the cap are only counted and never stored. This keeps memory bounded on
    very large uploads.

    Args:
        path: Filesystem path of the FASTA file.

    Returns:
        ``(records, total, truncated)``. ``records`` holds the first
        ``MAX_GENES`` entries as dicts with ``id``, ``sequence``, ``source``,
        ``taxonomy`` and ``detail`` keys. ``total`` is the number of records in
        the file. ``truncated`` counts every record in the file longer than
        ``MAX_SEQ_LEN``, including those past the cap.

    Raises:
        gr.Error: If the file contains no FASTA records.
    """
    records: List[dict] = []
    total = 0
    truncated = 0

    for record in SeqIO.parse(path, "fasta"):
        total += 1

        if len(records) >= MAX_GENES:
            # Past the cap: keep the counters right without materialising
            # the sequence string or storing the record.
            truncated += int(len(record.seq) > MAX_SEQ_LEN)
            continue

        seq, was_truncated = _trim_sequence(str(record.seq).upper())
        truncated += int(was_truncated)
        records.append(
            {
                "id": record.id,
                "sequence": seq,
                "source": "FASTA",
                "taxonomy": "",
                "detail": f"{len(seq)} nt",
            }
        )

    if total == 0:
        raise gr.Error("No FASTA records found.")

    return records, total, truncated
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _read_microbeatlas_sample(path: str) -> Tuple[List[dict], str]:
    """Translate a MicrobeAtlas taxa TSV into sequence records.

    Each data row's ``SHORT_TID`` value is looked up in the OTU reference
    table. Matched rows become records carrying the reference sequence
    (trimmed to ``MAX_SEQ_LEN``). Unmatched IDs are counted and reported in the
    summary. Blank rows and rows too short to contain the ``SHORT_TID`` column
    are skipped.

    Args:
        path: Filesystem path of the tab-separated taxa file (header row first).

    Returns:
        ``(records, summary)``: at most ``MAX_GENES`` record dicts and a
        human-readable translation summary.

    Raises:
        gr.Error: If the file is empty, lacks a ``SHORT_TID`` column, or none
            of its OTU IDs match the reference table.
    """
    otu_db, _ = _load_otu_db()
    records: List[dict] = []
    missing_ids: List[str] = []

    with open(path, newline="") as handle:
        reader = csv.reader(handle, delimiter="\t")
        header = next(reader, None)
        if header is None:
            raise gr.Error("The MicrobeAtlas file is empty.")

        columns = [col.strip() for col in header]
        column_index = {name: idx for idx, name in enumerate(columns)}
        if "SHORT_TID" not in column_index:
            raise gr.Error("Expected a MicrobeAtlas taxa file with a SHORT_TID column.")

        tid_idx = column_index["SHORT_TID"]
        # Optional columns, resolved once instead of on every row.
        optional_columns = [
            (column, column_index[column])
            for column in ("COUNT", "ABUNDANCE")
            if column in column_index
        ]

        for row in reader:
            # Covers blank rows and ragged rows that end before SHORT_TID,
            # which would otherwise raise IndexError.
            if len(row) <= tid_idx:
                continue
            otu_id = row[tid_idx].strip()
            if not otu_id:
                continue
            entry = otu_db.get(otu_id)
            if entry is None:
                missing_ids.append(otu_id)
                continue

            seq, was_truncated = _trim_sequence(entry.sequence)
            detail_bits = []
            for column, idx in optional_columns:
                if idx < len(row):
                    value = row[idx].strip()
                    if value:
                        detail_bits.append(f"{column.lower()}={value}")
            if was_truncated:
                detail_bits.append("trimmed")

            records.append(
                {
                    "id": otu_id,
                    "sequence": seq,
                    "source": "MicrobeAtlas",
                    "taxonomy": entry.taxonomy,
                    "detail": ", ".join(detail_bits) if detail_bits else f"{entry.seq_len} nt",
                }
            )

    if not records:
        raise gr.Error("No OTU IDs from this MicrobeAtlas file matched otus.97.allinfo.")

    used_records = records[:MAX_GENES]
    summary = (
        f"Translated {len(used_records)} OTUs from the MicrobeAtlas upload. "
        f"Missing sequence mappings for {len(missing_ids)} OTUs."
    )
    return used_records, summary
|
| 271 |
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
+
def _search_otu_records(query: str, limit: int = 80) -> List[OTUEntry]:
    """Find OTU entries whose search text contains *query*.

    Matching is a case-insensitive substring test on each entry's
    ``search_text``. Results are ordered by lower-cased label, then OTU ID, and
    cut to *limit* entries. A blank query returns an empty list.
    """
    # The table is loaded before the blank-query check, as in the original flow.
    candidates = _load_otu_db()[1]
    needle = query.strip().lower()
    if needle == "":
        return []

    ordered = sorted(
        (entry for entry in candidates if needle in entry.search_text),
        key=lambda entry: (entry.label.lower(), entry.otu_id),
    )
    return ordered[:limit]
|
| 282 |
|
| 283 |
|
| 284 |
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 300 |
max_length=MAX_SEQ_LEN,
|
| 301 |
padding=True,
|
| 302 |
)
|
| 303 |
+
inputs = {key: value.to(models.device) for key, value in inputs.items()}
|
| 304 |
|
| 305 |
with torch.no_grad():
|
| 306 |
outputs = models.prokbert(**inputs)
|
|
|
|
| 308 |
|
| 309 |
pooled_batches.append(pooled.detach().cpu().numpy())
|
| 310 |
|
| 311 |
+
embeddings = np.vstack(pooled_batches)
|
| 312 |
+
if embeddings.shape[1] != 384:
|
| 313 |
+
raise gr.Error(
|
| 314 |
+
f"Expected 384-d ProkBERT embeddings, got {embeddings.shape[1]} from {PROKBERT_MODEL_ID}."
|
| 315 |
)
|
| 316 |
+
return embeddings
|
| 317 |
|
| 318 |
|
| 319 |
def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
    """Score a set of embeddings and return the transformer's final hidden states.

    The embeddings are converted to a float32 tensor on ``models.device`` and
    given a leading batch dimension of 1. They then pass through the model's
    ``input_projection_type1``, ``transformer`` and ``output_projection``
    under ``torch.no_grad()``.

    Args:
        input_embeddings: Array of per-sequence embeddings (one row per sequence).
        models: Loaded model bundle providing ``device`` and ``microbiome``.

    Returns:
        ``(logits, final_hidden)`` as NumPy arrays with the batch dimension
        removed.
    """
    x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
    n = x.shape[1]
    # Every position is real input, so the inverted mask handed to the
    # transformer as src_key_padding_mask is all False.
    mask = torch.ones((1, n), dtype=torch.bool, device=models.device)

    with torch.no_grad():
        x_proj = models.microbiome.input_projection_type1(x)
        final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=~mask)
        logits = models.microbiome.output_projection(final_hidden).squeeze(-1)

    return logits.squeeze(0).detach().cpu().numpy(), final_hidden.squeeze(0).detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
|
| 333 |
+
def _plot_umap(vectors: np.ndarray, labels: List[str], title: str):
    """Project *vectors* to 2-D with UMAP and return a Plotly scatter figure.

    Points are coloured by the Euclidean norm of their original vector and
    labelled on hover with the matching entry of *labels*.

    Args:
        vectors: One row per sequence.
        labels: Hover label for each row of *vectors*.
        title: Figure title.

    Returns:
        The configured Plotly figure.

    Raises:
        gr.Error: If fewer than 2 vectors are supplied.
    """
    n_points = len(vectors)
    if n_points < 2:
        raise gr.Error("UMAP needs at least 2 sequences.")

    n_components = 2
    reducer = umap.UMAP(
        n_components=n_components,
        # UMAP's parameter validation rejects n_neighbors below 2, which
        # min(15, n - 1) produced for exactly two points. Values above n - 1
        # are expected to be truncated by UMAP itself.
        n_neighbors=max(2, min(15, n_points - 1)),
        min_dist=0.1,
        metric="cosine",
        random_state=42,
        # NOTE(review): the default spectral initialisation is reported to fail
        # when there are no more points than n_components + 1; use random
        # initialisation only in that tiny-input case.
        init="random" if n_points <= n_components + 1 else "spectral",
    )
    coords = reducer.fit_transform(vectors)
    norms = np.linalg.norm(vectors, axis=1)

    fig = px.scatter(
        x=coords[:, 0],
        y=coords[:, 1],
        color=norms,
        hover_name=labels,
        labels={"x": "UMAP 1", "y": "UMAP 2", "color": "vector norm"},
        title=title,
        color_continuous_scale="Viridis",
    )
    fig.update_traces(marker={"size": 10, "line": {"width": 0.6, "color": "#1d2a1f"}, "opacity": 0.9})
    fig.update_layout(
        paper_bgcolor="rgba(255,255,255,0)",
        plot_bgcolor="rgba(255,255,255,0.75)",
        margin={"l": 10, "r": 10, "t": 60, "b": 10},
    )
    return fig
|
| 363 |
|
| 364 |
|
| 365 |
+
def _plot_logits(logits: np.ndarray):
    """Build a histogram figure of the logit values.

    The bin count is a quarter of the number of logits, kept within 12 to 50.
    """
    bin_count = min(50, max(12, len(logits) // 4))
    layout = {
        "xaxis_title": "Logit",
        "yaxis_title": "Count",
        "paper_bgcolor": "rgba(255,255,255,0)",
        "plot_bgcolor": "rgba(255,255,255,0.75)",
        "margin": {"l": 10, "r": 10, "t": 60, "b": 10},
    }

    fig = px.histogram(
        x=logits,
        nbins=bin_count,
        title="Logit Distribution Over Input DNA Embeddings",
        color_discrete_sequence=["#d8832f"],
    )
    fig.update_layout(**layout)
    return fig
|
| 380 |
|
| 381 |
|
| 382 |
+
def _records_to_member_table(records: List[dict]) -> List[List[object]]:
|
| 383 |
+
rows: List[List[object]] = []
|
| 384 |
+
for record in records:
|
| 385 |
+
rows.append(
|
| 386 |
+
[
|
| 387 |
+
record["id"],
|
| 388 |
+
record.get("source", ""),
|
| 389 |
+
record.get("taxonomy", ""),
|
| 390 |
+
record.get("detail", ""),
|
| 391 |
+
len(record["sequence"]),
|
| 392 |
+
]
|
| 393 |
+
)
|
| 394 |
+
return rows
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def _analyze_records(records: List[dict], source_title: str, extra_summary: str = ""):
    """Run the shared embed, score and plot pipeline on sequence records.

    Only the first ``MAX_GENES`` records are analysed.

    Args:
        records: Dicts with at least ``id`` and ``sequence`` keys.
        source_title: Prefix for the run summary.
        extra_summary: Optional text appended to the summary.

    Returns:
        ``(summary, input_umap, final_umap, logits_hist, top_rows, members)``.
        ``top_rows`` lists up to 50 records in descending logit order, and
        ``members`` is the table of every analysed record.

    Raises:
        gr.Error: If fewer than 2 records are supplied.
    """
    if len(records) < 2:
        raise gr.Error("This explorer needs at least 2 sequences to compute the UMAP views.")

    models = _load_models()
    used_records = records[:MAX_GENES]
    labels = [rec["id"] for rec in used_records]
    seqs = [rec["sequence"] for rec in used_records]

    input_embeddings = _embed_sequences(seqs, models)
    logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)

    input_umap = _plot_umap(input_embeddings, labels, "UMAP of Input DNA Embeddings")
    final_umap = _plot_umap(final_embeddings, labels, "UMAP of Final Transformer Embeddings")
    logits_hist = _plot_logits(logits)

    # Highest logit first; only the leading 50 positions are shown.
    ranking = np.argsort(logits)[::-1]
    top_rows = []
    for position in ranking[:50]:
        rec = used_records[position]
        top_rows.append(
            [
                rec["id"],
                float(logits[position]),
                rec.get("source", ""),
                rec.get("taxonomy", ""),
                rec.get("detail", ""),
            ]
        )

    summary = f"{source_title}: analyzed {len(used_records)} sequences (cap={MAX_GENES}, trim={MAX_SEQ_LEN} nt)."
    if extra_summary:
        summary = f"{summary} {extra_summary}"

    members = _records_to_member_table(used_records)
    return summary, input_umap, final_umap, logits_hist, top_rows, members
|
| 436 |
|
|
|
|
| 437 |
|
| 438 |
+
def analyze_fasta(fasta_file: str):
    """Handle the "Analyze FASTA" action: parse the upload and run the pipeline.

    Raises:
        gr.Error: If no file was provided.
    """
    if fasta_file is None:
        raise gr.Error("Upload a FASTA file first.")

    parsed_records, record_total, trimmed_total = _read_fasta(fasta_file)
    load_note = f"Loaded {record_total} records and truncated {trimmed_total} sequence(s)."
    return _analyze_records(parsed_records, "Raw FASTA upload", load_note)
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
+
def analyze_microbeatlas(sample_file: str):
    """Handle the "Translate And Analyze" action for a MicrobeAtlas taxa TSV.

    Raises:
        gr.Error: If no file was provided.
    """
    if sample_file is None:
        raise gr.Error("Upload a MicrobeAtlas taxa TSV first.")

    translated = _read_microbeatlas_sample(sample_file)
    sample_records, translation_note = translated
    return _analyze_records(sample_records, "MicrobeAtlas import", translation_note)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def search_taxa(query: str):
    """Search the OTU table and prepare the results for the community builder.

    Returns:
        A checkbox-group update (new choices, nothing selected), the preview
        rows ``[otu_id, label, taxonomy, seq_len]``, and a status message.
    """
    hits = _search_otu_records(query)

    # With no hits both lists are empty, which matches the "no match" response.
    choices = [(f"{hit.label} | {hit.otu_id}", hit.otu_id) for hit in hits]
    preview = [[hit.otu_id, hit.label, hit.taxonomy, hit.seq_len] for hit in hits]

    if hits:
        status = f"Found {len(hits)} matching OTUs. Select the ones you want to add to the community."
    else:
        status = "No OTUs matched that taxon query."

    return gr.update(choices=choices, value=[]), preview, status
|
| 469 |
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
+
def add_to_community(selected_otu_ids: List[str], community_ids: List[str]):
    """Append newly selected OTUs to the community and rebuild its member table.

    IDs already in the community are ignored. Adding stops once the community
    holds ``MAX_GENES`` members. IDs that are not in the OTU table are not added.

    Returns:
        The updated ID list, the member table rows, and a status message.
    """
    otu_db, _ = _load_otu_db()
    members = list(community_ids or [])
    newly_added = 0

    for candidate in selected_otu_ids or []:
        if candidate in members:
            continue
        if len(members) >= MAX_GENES:
            break
        if candidate in otu_db:
            members.append(candidate)
            newly_added += 1

    member_records = []
    for member_id in members:
        entry = otu_db.get(member_id)
        if entry is None:
            # IDs with no entry in the table are left out of the member table.
            continue
        member_records.append(
            {
                "id": entry.otu_id,
                "sequence": entry.sequence[:MAX_SEQ_LEN],
                "source": "Community builder",
                "taxonomy": entry.taxonomy,
                "detail": entry.label,
            }
        )

    status = f"Community now contains {len(member_records)} OTUs. Added {newly_added} new member(s)."
    return members, _records_to_member_table(member_records), status
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def clear_community():
    """Return an empty ID list, an empty member table and a status message."""
    no_ids: list = []
    no_rows: list = []
    return no_ids, no_rows, "Community cleared."
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def analyze_community(community_ids: List[str]):
    """Run the analysis pipeline on the OTUs currently in the community.

    Only the first ``MAX_GENES`` IDs are considered, and IDs missing from the
    OTU table are dropped.

    Raises:
        gr.Error: If the community is empty, or none of its IDs are in the
            OTU table.
    """
    otu_db, _ = _load_otu_db()
    if not community_ids:
        raise gr.Error("Build a community first by searching taxa and adding OTUs.")

    looked_up = (otu_db.get(otu_id) for otu_id in community_ids[:MAX_GENES])
    records = [
        {
            "id": entry.otu_id,
            "sequence": entry.sequence[:MAX_SEQ_LEN],
            "source": "Community builder",
            "taxonomy": entry.taxonomy,
            "detail": entry.label,
        }
        for entry in looked_up
        if entry is not None
    ]

    if len(records) == 0:
        raise gr.Error("No valid OTU members remain in the current community.")

    return _analyze_records(records, "Community builder", "Selected by taxon search against otus.97.allinfo.")
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
with gr.Blocks(title="Microbiome Explorer", css=CSS, theme=gr.themes.Soft()) as demo:
|
| 531 |
+
community_state = gr.State([])
|
| 532 |
+
|
| 533 |
+
gr.HTML(
|
| 534 |
+
"""
|
| 535 |
+
<section class="hero">
|
| 536 |
+
<h1>Microbiome Gene Scoring Explorer</h1>
|
| 537 |
+
<p>
|
| 538 |
+
Upload raw FASTA, translate a MicrobeAtlas sample into representative OTU sequences,
|
| 539 |
+
or build a synthetic community by taxonomy. Every route ends in the same pipeline:
|
| 540 |
+
ProkBERT mean pooling, <code>large-notext</code> scoring, and linked embedding views.
|
| 541 |
+
</p>
|
| 542 |
+
</section>
|
| 543 |
+
"""
|
| 544 |
+
)
|
| 545 |
|
| 546 |
+
with gr.Tabs():
|
| 547 |
+
with gr.Tab("Raw FASTA"):
|
| 548 |
+
with gr.Column(elem_classes=["soft-card"]):
|
| 549 |
+
gr.Markdown(
|
| 550 |
+
"Upload genes directly in FASTA format. Sequences longer than 1024 nt are trimmed and only the first 800 records are used."
|
| 551 |
+
)
|
| 552 |
+
fasta_in = gr.File(
|
| 553 |
+
label="FASTA file",
|
| 554 |
+
file_types=[".fa", ".fasta", ".fna", ".txt"],
|
| 555 |
+
type="filepath",
|
| 556 |
+
)
|
| 557 |
+
fasta_run_btn = gr.Button("Analyze FASTA", variant="primary")
|
| 558 |
+
|
| 559 |
+
with gr.Tab("Import From MicrobeAtlas"):
|
| 560 |
+
with gr.Column(elem_classes=["soft-card"]):
|
| 561 |
+
gr.Markdown(
|
| 562 |
+
f"""
|
| 563 |
+
Bring in a taxa file exported from MicrobeAtlas. Go to
|
| 564 |
+
[MicrobeAtlas sample detail]({MICROBEATLAS_SAMPLE_URL}), click `Download`, and upload the taxa TSV here.
|
| 565 |
+
OTU IDs from `SHORT_TID` are translated to representative sequences using `otus.97.allinfo`.
|
| 566 |
+
"""
|
| 567 |
+
)
|
| 568 |
+
microbeatlas_in = gr.File(
|
| 569 |
+
label="MicrobeAtlas taxa TSV",
|
| 570 |
+
file_types=[".tsv", ".txt"],
|
| 571 |
+
type="filepath",
|
| 572 |
+
)
|
| 573 |
+
gr.Examples(
|
| 574 |
+
examples=[[EXAMPLE_SAMPLE_PATH]],
|
| 575 |
+
inputs=[microbeatlas_in],
|
| 576 |
+
label="Use example",
|
| 577 |
+
)
|
| 578 |
+
microbeatlas_run_btn = gr.Button("Translate And Analyze", variant="primary")
|
| 579 |
+
|
| 580 |
+
with gr.Tab("Build A Community"):
|
| 581 |
+
with gr.Column(elem_classes=["soft-card"]):
|
| 582 |
+
gr.Markdown(
|
| 583 |
+
"Search `otus.97.allinfo` by OTU ID, taxon label, or taxonomy string. Add matching OTUs to a custom community, then score the assembled set."
|
| 584 |
+
)
|
| 585 |
+
with gr.Row():
|
| 586 |
+
taxa_query = gr.Textbox(
|
| 587 |
+
label="Search taxa",
|
| 588 |
+
placeholder="Try Nitrospira, Lysobacter, Gammaproteobacteria, 97_8697 ...",
|
| 589 |
+
scale=5,
|
| 590 |
+
)
|
| 591 |
+
taxa_search_btn = gr.Button("Search", variant="secondary", scale=1)
|
| 592 |
+
|
| 593 |
+
community_search_status = gr.Markdown(elem_classes=["section-note"])
|
| 594 |
+
taxa_matches = gr.CheckboxGroup(label="Matching OTUs")
|
| 595 |
+
taxa_matches_preview = gr.Dataframe(
|
| 596 |
+
headers=["otu_id", "label", "taxonomy", "seq_len"],
|
| 597 |
+
label="Match preview",
|
| 598 |
+
wrap=True,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
with gr.Row():
|
| 602 |
+
community_add_btn = gr.Button("Add Selected OTUs", variant="primary")
|
| 603 |
+
community_clear_btn = gr.Button("Clear Community")
|
| 604 |
+
community_run_btn = gr.Button("Analyze Community", variant="secondary")
|
| 605 |
+
|
| 606 |
+
with gr.Accordion("Community Members", open=False):
|
| 607 |
+
community_table = gr.Dataframe(
|
| 608 |
+
headers=["id", "source", "taxonomy", "detail", "seq_len"],
|
| 609 |
+
label="Current community",
|
| 610 |
+
wrap=True,
|
| 611 |
+
)
|
| 612 |
+
community_status = gr.Markdown(elem_classes=["section-note"])
|
| 613 |
+
|
| 614 |
+
with gr.Accordion("Analysis Results", open=True):
|
| 615 |
+
run_summary = gr.Textbox(label="Run summary")
|
| 616 |
+
with gr.Row():
|
| 617 |
+
input_umap_plot = gr.Plot(label="Input embedding UMAP")
|
| 618 |
+
final_umap_plot = gr.Plot(label="Final embedding UMAP")
|
| 619 |
+
logits_plot = gr.Plot(label="Logit distribution")
|
| 620 |
+
with gr.Accordion("Top-scoring members", open=False):
|
| 621 |
+
top_table = gr.Dataframe(
|
| 622 |
+
headers=["id", "logit", "source", "taxonomy", "detail"],
|
| 623 |
+
label="Top genes by logit",
|
| 624 |
+
wrap=True,
|
| 625 |
+
)
|
| 626 |
+
with gr.Accordion("Analyzed members", open=False):
|
| 627 |
+
member_table = gr.Dataframe(
|
| 628 |
+
headers=["id", "source", "taxonomy", "detail", "seq_len"],
|
| 629 |
+
label="Members used in the run",
|
| 630 |
+
wrap=True,
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
fasta_run_btn.click(
|
| 634 |
+
fn=analyze_fasta,
|
| 635 |
inputs=[fasta_in],
|
| 636 |
+
outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
|
| 637 |
+
)
|
| 638 |
+
microbeatlas_run_btn.click(
|
| 639 |
+
fn=analyze_microbeatlas,
|
| 640 |
+
inputs=[microbeatlas_in],
|
| 641 |
+
outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
|
| 642 |
+
)
|
| 643 |
+
taxa_search_btn.click(
|
| 644 |
+
fn=search_taxa,
|
| 645 |
+
inputs=[taxa_query],
|
| 646 |
+
outputs=[taxa_matches, taxa_matches_preview, community_search_status],
|
| 647 |
+
)
|
| 648 |
+
community_add_btn.click(
|
| 649 |
+
fn=add_to_community,
|
| 650 |
+
inputs=[taxa_matches, community_state],
|
| 651 |
+
outputs=[community_state, community_table, community_status],
|
| 652 |
+
)
|
| 653 |
+
community_clear_btn.click(
|
| 654 |
+
fn=clear_community,
|
| 655 |
+
outputs=[community_state, community_table, community_status],
|
| 656 |
+
)
|
| 657 |
+
community_run_btn.click(
|
| 658 |
+
fn=analyze_community,
|
| 659 |
+
inputs=[community_state],
|
| 660 |
+
outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table],
|
| 661 |
)
|
| 662 |
|
| 663 |
|