Spaces:
Running
Running
| import csv | |
| import os | |
| import sqlite3 | |
| import sys | |
| import tempfile | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import torch | |
| from gradio_client import utils as gradio_client_utils | |
| from transformers import AutoModel, AutoTokenizer | |
| from model import MicrobiomeTransformer | |
| os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba-cache") | |
| import umap | |
# Keep a handle on the stock implementation so the wrapper below can defer to it.
_ORIGINAL_GRADIO_GET_TYPE = gradio_client_utils.get_type


def _patched_gradio_get_type(schema):
    """Resolve a schema's type name, tolerating bare boolean schemas.

    A schema that is literally ``True``/``False`` is reported as ``"boolean"``;
    every other schema is handed to the original ``gradio_client`` helper.
    """
    schema_is_bare_bool = isinstance(schema, bool)
    return "boolean" if schema_is_bare_bool else _ORIGINAL_GRADIO_GET_TYPE(schema)


# Install the wrapper in place of the library function (module-level monkeypatch).
gradio_client_utils.get_type = _patched_gradio_get_type
# Raise the csv module's per-field size limit (capped at sys.maxsize) so very long
# fields in the tab-separated inputs can be parsed.
csv.field_size_limit(min(sys.maxsize, 10**9))

# Maximum number of records used by a single analysis run.
MAX_GENES = 800
# Sequences longer than this many characters are cut; also passed as the tokenizer max_length.
MAX_SEQ_LEN = 1024
# Number of sequences embedded per forward pass. int() raises ValueError if the
# EMBED_BATCH_SIZE environment variable is set to a non-integer string.
BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
# Identifier handed to AutoTokenizer/AutoModel.from_pretrained.
PROKBERT_MODEL_ID = os.getenv("PROKBERT_MODEL_ID", "neuralbioinfo/prokbert-mini-long")
# Checkpoint file handed to torch.load; the default is a relative path.
CHECKPOINT_PATH = os.getenv("CHECKPOINT_PATH", "large-notext.pt")
# Directory that contains this source file.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
# Tab-separated OTU info file read when the SQLite index has to be built; the default is a
# relative path, unlike OTU_DB_PATH below which is anchored at APP_DIR.
OTU_INFO_PATH = os.getenv("OTU_INFO_PATH", "otus.97.allinfo")
# SQLite file that holds the OTU index.
OTU_DB_PATH = os.getenv("OTU_DB_PATH", os.path.join(APP_DIR, "otus.97.sqlite"))
# Example taxa file offered in the MicrobeAtlas tab when it exists on disk.
EXAMPLE_SAMPLE_PATH = os.path.join(APP_DIR, "sample_DRS000421_DRR000770_taxa.tsv")
# Link embedded in the MicrobeAtlas tab instructions.
MICROBEATLAS_SAMPLE_URL = "https://microbeatlas.org/sample_detail?sid=DRS000421&rid=null"
# Forwarded as trust_remote_code when loading the tokenizer/model; only the string "true"
# (case-insensitive) enables it, and it is enabled by default.
# NOTE(review): enabling this presumably lets the model repository run its own code — confirm
# the configured model id is trusted.
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
# Custom stylesheet handed to gr.Blocks(css=CSS). It defines the colour variables plus the
# .hero class used by the header HTML and the .soft-card, .section-note, .search-results and
# .fixed-table classes that components opt into through elem_classes.
CSS = """
:root {
  --bg: #f4f0e8;
  --panel: rgba(255, 252, 247, 0.88);
  --panel-strong: rgba(246, 240, 230, 0.96);
  --ink: #1d2a1f;
  --muted: #586454;
  --accent: #0e7a5f;
  --accent-2: #d8832f;
  --line: rgba(29, 42, 31, 0.12);
}
.gradio-container {
  background:
    radial-gradient(circle at top left, rgba(216, 131, 47, 0.18), transparent 28%),
    radial-gradient(circle at top right, rgba(14, 122, 95, 0.18), transparent 24%),
    linear-gradient(180deg, #f7f2e9 0%, #eee6d8 100%);
  color: var(--ink);
}
.hero {
  padding: 28px;
  border: 1px solid var(--line);
  border-radius: 24px;
  background: linear-gradient(135deg, rgba(255,255,255,0.85), rgba(241,232,218,0.92));
  box-shadow: 0 18px 60px rgba(69, 57, 34, 0.08);
}
.hero h1 {
  margin: 0 0 10px 0;
  font-size: 2.4rem;
  line-height: 1.05;
}
.hero p {
  margin: 0;
  max-width: 900px;
  color: var(--muted);
  font-size: 1rem;
}
.soft-card {
  border: 1px solid var(--line);
  border-radius: 22px;
  background: var(--panel);
  box-shadow: 0 12px 32px rgba(40, 36, 26, 0.06);
}
.section-note {
  color: var(--muted);
  font-size: 0.95rem;
}
.search-results {
  max-height: 320px;
  overflow-y: auto;
  border: 1px solid var(--line);
  border-radius: 16px;
  background: rgba(255, 255, 255, 0.72);
  padding: 10px 12px 2px 12px;
}
.fixed-table {
  border: 1px solid var(--line);
  border-radius: 16px;
  overflow: hidden;
}
.fixed-table .table-wrap,
.fixed-table .wrap,
.fixed-table .overflow-y-auto {
  max-height: 360px;
  overflow-y: auto;
}
"""
@dataclass
class LoadedModels:
    """Bundle of the loaded tokenizer, both models and the device they run on.

    The ``@dataclass`` decorator supplies the keyword-argument constructor that
    ``_load_models`` relies on; without it the class only carries annotations
    and cannot be instantiated with fields.
    """

    # Tokenizer loaded for PROKBERT_MODEL_ID.
    tokenizer: AutoTokenizer
    # Sequence encoder loaded for PROKBERT_MODEL_ID.
    prokbert: AutoModel
    # Scoring transformer restored from CHECKPOINT_PATH.
    microbiome: MicrobiomeTransformer
    # Device both models were moved to.
    device: torch.device
@dataclass
class OTUEntry:
    """One row of the ``otu_entries`` table.

    The ``@dataclass`` decorator supplies the keyword-argument constructor used
    when database rows are converted into entries.
    """

    # Identifier of the OTU (primary key of the table).
    otu_id: str
    # Short display name derived from the taxonomy string.
    label: str
    # Full semicolon-separated taxonomy string.
    taxonomy: str
    # Representative nucleotide sequence.
    sequence: str
    # Length of ``sequence`` as stored when the index was built.
    seq_len: int
    # Lower-cased "id label taxonomy" text used by the LIKE fallback search.
    search_text: str
# Process-wide cache of the loaded models; filled by the first _load_models() call.
_MODELS: LoadedModels | None = None
# Set to True once the otu_entries table has been found or built in this process.
_OTU_INDEX_READY = False
| def _extract_taxa_name(taxonomy: str) -> str: | |
| parts = [part.strip() for part in taxonomy.split(";") if part.strip()] | |
| if not parts: | |
| return "Unclassified" | |
| return parts[-1].replace("g__", "").replace("s__", "").replace("f__", "") | |
def _load_models() -> LoadedModels:
    """Return the shared model bundle, loading everything on the first call.

    The tokenizer and encoder are loaded for ``PROKBERT_MODEL_ID``; the
    scoring transformer is built with fixed hyper-parameters and restored from
    ``CHECKPOINT_PATH``. Both models are moved to CUDA when available (CPU
    otherwise) and switched to eval mode. The result is cached in ``_MODELS``.
    """
    global _MODELS
    if _MODELS is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        hub_options = {"trust_remote_code": TRUST_REMOTE_CODE}
        tok = AutoTokenizer.from_pretrained(PROKBERT_MODEL_ID, **hub_options)
        encoder = AutoModel.from_pretrained(PROKBERT_MODEL_ID, **hub_options)
        encoder.to(target)
        encoder.eval()

        # NOTE(review): torch.load may unpickle arbitrary objects depending on the torch
        # version's weights_only default — only point CHECKPOINT_PATH at trusted files.
        ckpt = torch.load(CHECKPOINT_PATH, map_location=target)
        # Accept either a wrapper dict holding "model_state_dict" or a bare state dict.
        weights = ckpt.get("model_state_dict", ckpt)

        scorer = MicrobiomeTransformer(
            input_dim_type1=384,
            input_dim_type2=1536,
            d_model=100,
            nhead=5,
            num_layers=5,
            dim_feedforward=400,
            dropout=0.1,
            use_output_activation=False,
        )
        # strict=False: keys that do not line up with the module are not treated as errors.
        scorer.load_state_dict(weights, strict=False)
        scorer.to(target)
        scorer.eval()

        _MODELS = LoadedModels(
            tokenizer=tok,
            prokbert=encoder,
            microbiome=scorer,
            device=target,
        )
    return _MODELS
def _open_otu_index() -> sqlite3.Connection:
    """Open a new connection to ``OTU_DB_PATH`` whose rows can be indexed by column name."""
    connection = sqlite3.connect(OTU_DB_PATH)
    # sqlite3.Row lets callers read columns as row["name"].
    connection.row_factory = sqlite3.Row
    return connection
def _iter_otu_info_rows(path: str):
    """Yield one ``otu_entries`` value tuple per usable line of the OTU info TSV."""
    with open(path, newline="") as handle:
        for row in csv.reader(handle, delimiter="\t"):
            # Columns 0 (id), 6 (sequence), 8 and 14 (taxonomy) are read below.
            if len(row) < 15:
                continue
            raw_id = row[0].strip()
            sequence = row[6].strip().upper()
            if not raw_id or not sequence:
                continue
            taxonomy = row[14].strip() or row[8].strip() or "Unclassified"
            # The id column may hold several ";"-separated parts; the last one is kept.
            otu_id = raw_id.split(";")[-1]
            label = _extract_taxa_name(taxonomy)
            yield (
                otu_id,
                label,
                taxonomy,
                sequence,
                len(sequence),
                f"{otu_id} {label} {taxonomy}".lower(),
            )


def _build_otu_index(conn: sqlite3.Connection) -> None:
    """Create and fill the OTU tables on *conn*; the caller owns the transaction."""
    conn.execute(
        """
        CREATE TABLE otu_entries (
            otu_id TEXT PRIMARY KEY,
            label TEXT NOT NULL,
            taxonomy TEXT NOT NULL,
            sequence TEXT NOT NULL,
            seq_len INTEGER NOT NULL,
            search_text TEXT NOT NULL
        )
        """
    )
    conn.execute(
        """
        CREATE VIRTUAL TABLE otu_search
        USING fts5(otu_id, label, taxonomy, content='otu_entries', content_rowid='rowid')
        """
    )
    insert_sql = """
        INSERT OR REPLACE INTO otu_entries
        (otu_id, label, taxonomy, sequence, seq_len, search_text)
        VALUES (?, ?, ?, ?, ?, ?)
    """
    batch = []
    for entry in _iter_otu_info_rows(OTU_INFO_PATH):
        batch.append(entry)
        # Insert in chunks so the whole file is never held in memory at once.
        if len(batch) >= 2000:
            conn.executemany(insert_sql, batch)
            batch.clear()
    if batch:
        conn.executemany(insert_sql, batch)
    # Fill the full-text table only after all rows exist, so rowids are final.
    conn.execute(
        """
        INSERT INTO otu_search(rowid, otu_id, label, taxonomy)
        SELECT rowid, otu_id, label, taxonomy FROM otu_entries
        """
    )
    conn.execute("CREATE INDEX idx_otu_entries_label ON otu_entries(label)")
    conn.execute("CREATE INDEX idx_otu_entries_taxonomy ON otu_entries(taxonomy)")


def _ensure_otu_index() -> None:
    """Make sure the SQLite OTU index exists, building it from the info TSV if needed.

    Does nothing once the index was confirmed in this process and the database
    file is still present. Otherwise the database is opened; when the
    ``otu_entries`` table is missing, the tables are created and filled from
    ``OTU_INFO_PATH`` inside a single transaction, so a failed build leaves no
    half-built table behind that a later call would mistake for a finished
    index. The connection is always closed before returning.

    Raises:
        gr.Error: if the index has to be built but ``OTU_INFO_PATH`` is missing.
    """
    global _OTU_INDEX_READY
    if _OTU_INDEX_READY and os.path.exists(OTU_DB_PATH):
        return

    missing_message = (
        f"Missing OTU index at {OTU_DB_PATH}. Ship the prebuilt SQLite file with the Space, "
        f"or provide {OTU_INFO_PATH} so the index can be built."
    )
    # Checked before connecting because sqlite3.connect would create an empty file.
    if not os.path.exists(OTU_DB_PATH) and not os.path.exists(OTU_INFO_PATH):
        raise gr.Error(missing_message)

    conn = _open_otu_index()
    try:
        existing = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='otu_entries'"
        ).fetchone()
        if existing is None:
            # The database file exists but holds no index and there is nothing to build from.
            if not os.path.exists(OTU_INFO_PATH):
                raise gr.Error(missing_message)
            # Explicit BEGIN so the CREATE statements are part of the same
            # transaction as the inserts and roll back together on failure.
            conn.execute("BEGIN")
            try:
                _build_otu_index(conn)
            except Exception:
                conn.rollback()
                raise
            conn.commit()
    finally:
        # "with conn:" would only commit/roll back; close explicitly to release the handle.
        conn.close()
    _OTU_INDEX_READY = True
| def _iter_fasta_records(path: str): | |
| header: str | None = None | |
| seq_chunks: List[str] = [] | |
| with open(path) as handle: | |
| for raw_line in handle: | |
| line = raw_line.strip() | |
| if not line: | |
| continue | |
| if line.startswith(">"): | |
| if header is not None: | |
| yield header, "".join(seq_chunks) | |
| header = line[1:].strip() | |
| seq_chunks = [] | |
| continue | |
| if header is None: | |
| raise gr.Error("Invalid FASTA: sequence data appeared before the first header line.") | |
| seq_chunks.append(line) | |
| if header is not None: | |
| yield header, "".join(seq_chunks) | |
def _rows_to_entries(rows: List[sqlite3.Row]) -> List[OTUEntry]:
    """Convert database rows into ``OTUEntry`` objects, preserving order.

    Each row must expose the six ``otu_entries`` columns by name.
    """
    columns = ("otu_id", "label", "taxonomy", "sequence", "seq_len", "search_text")
    entries = []
    for record in rows:
        values = {column: record[column] for column in columns}
        entries.append(OTUEntry(**values))
    return entries
def _fetch_otu_entries_by_ids(otu_ids: List[str]) -> Dict[str, OTUEntry]:
    """Look up OTU entries by id and return them keyed by ``otu_id``.

    Ids with no row in the index are absent from the result. The lookup is
    issued in chunks so that a large id list cannot exceed SQLite's limit on
    bound parameters per statement, and the connection is closed afterwards.

    Args:
        otu_ids: ids to fetch; duplicates are queried only once.

    Returns:
        Mapping from ``otu_id`` to its ``OTUEntry``; empty when *otu_ids* is empty.
    """
    _ensure_otu_index()
    if not otu_ids:
        return {}

    # Older SQLite builds default to 999 bound parameters per statement; stay well below.
    chunk_size = 500
    unique_ids = list(dict.fromkeys(otu_ids))
    rows = []
    conn = _open_otu_index()
    try:
        for start in range(0, len(unique_ids), chunk_size):
            chunk = unique_ids[start : start + chunk_size]
            # Only "?" markers are interpolated; the ids themselves are bound parameters.
            placeholders = ",".join("?" for _ in chunk)
            rows.extend(
                conn.execute(
                    f"""
                    SELECT otu_id, label, taxonomy, sequence, seq_len, search_text
                    FROM otu_entries
                    WHERE otu_id IN ({placeholders})
                    """,
                    chunk,
                ).fetchall()
            )
    finally:
        # "with conn:" would not close the connection; do it explicitly.
        conn.close()
    return {entry.otu_id: entry for entry in _rows_to_entries(rows)}
def _trim_sequence(sequence: str) -> Tuple[str, bool]:
    """Cut *sequence* to at most ``MAX_SEQ_LEN`` characters.

    Returns the (possibly shortened) sequence and a flag that is ``True`` only
    when characters were actually removed.
    """
    exceeds_cap = len(sequence) > MAX_SEQ_LEN
    clipped = sequence[:MAX_SEQ_LEN] if exceeds_cap else sequence
    return clipped, exceeds_cap
def _read_fasta(path: str) -> Tuple[List[dict], int, int]:
    """Parse a FASTA file into analysis records.

    Each record id is the first whitespace-separated word of the header (or
    ``"unnamed_record"`` for an empty header); sequences are upper-cased and
    trimmed to ``MAX_SEQ_LEN``.

    Returns:
        A tuple of (the first ``MAX_GENES`` records, the total number of
        records parsed, the number of parsed sequences that were trimmed).

    Raises:
        gr.Error: if the file contains no records.
    """
    parsed: List[dict] = []
    n_trimmed = 0
    for header, sequence in _iter_fasta_records(path):
        words = header.split()
        clipped, was_cut = _trim_sequence(sequence.upper())
        if was_cut:
            n_trimmed += 1
        parsed.append(
            {
                "id": words[0] if words else "unnamed_record",
                "sequence": clipped,
                "source": "FASTA",
                "taxonomy": "",
                "detail": f"{len(clipped)} nt",
            }
        )
    if not parsed:
        raise gr.Error("No FASTA records found.")
    # The totals cover every parsed record, not just the ones returned.
    return parsed[:MAX_GENES], len(parsed), n_trimmed
def _read_microbeatlas_sample(path: str) -> Tuple[List[dict], str]:
    """Translate a MicrobeAtlas taxa TSV into analysis records.

    The file must have a header row with a ``SHORT_TID`` column. Every data
    row's id is looked up in the OTU index; rows whose id is unknown are
    counted as missing. Rows that are too short to contain the ``SHORT_TID``
    column are skipped instead of raising ``IndexError``.

    Returns:
        A tuple of (at most ``MAX_GENES`` records, a one-line summary of how
        many OTUs were translated and how many had no sequence mapping).

    Raises:
        gr.Error: if the file is empty, lacks ``SHORT_TID``, or no id matched.
    """
    with open(path, newline="") as handle:
        reader = csv.reader(handle, delimiter="\t")
        header = next(reader, None)
        if header is None:
            raise gr.Error("The MicrobeAtlas file is empty.")
        column_index = {name.strip(): idx for idx, name in enumerate(header)}
        tid_idx = column_index.get("SHORT_TID")
        if tid_idx is None:
            raise gr.Error("Expected a MicrobeAtlas taxa file with a SHORT_TID column.")

        sample_rows = []
        for row in reader:
            # Covers blank lines and ragged rows that stop before SHORT_TID.
            if len(row) <= tid_idx:
                continue
            otu_id = row[tid_idx].strip()
            if otu_id:
                sample_rows.append((otu_id, row))

    otu_entries = _fetch_otu_entries_by_ids(sorted({otu_id for otu_id, _ in sample_rows}))

    records: List[dict] = []
    missing_ids: List[str] = []
    for otu_id, row in sample_rows:
        entry = otu_entries.get(otu_id)
        if entry is None:
            missing_ids.append(otu_id)
            continue
        seq, was_truncated = _trim_sequence(entry.sequence)
        detail_bits = []
        for column in ("COUNT", "ABUNDANCE"):
            idx = column_index.get(column)
            # These columns are optional and may also be absent from a short row.
            if idx is not None and idx < len(row):
                value = row[idx].strip()
                if value:
                    detail_bits.append(f"{column.lower()}={value}")
        if was_truncated:
            detail_bits.append("trimmed")
        records.append(
            {
                "id": otu_id,
                "sequence": seq,
                "source": "MicrobeAtlas",
                "taxonomy": entry.taxonomy,
                "detail": ", ".join(detail_bits) if detail_bits else f"{entry.seq_len} nt",
            }
        )

    if not records:
        raise gr.Error("No OTU IDs from this MicrobeAtlas file matched otus.97.allinfo.")
    used_records = records[:MAX_GENES]
    summary = (
        f"Translated {len(used_records)} OTUs from the MicrobeAtlas upload. "
        f"Missing sequence mappings for {len(missing_ids)} OTUs."
    )
    return used_records, summary
def _search_otu_records(query: str, limit: int = 80) -> List[OTUEntry]:
    """Search the OTU index for *query* and return up to *limit* entries.

    The query is lower-cased and split on whitespace; each word becomes a
    quoted FTS5 prefix term and the terms are OR-ed together. When the
    full-text search returns nothing (or rejects the expression), a substring
    search over ``search_text`` is used instead.

    Args:
        query: free text typed by the user; blank input returns ``[]``.
        limit: maximum number of rows to return.
    """
    needle = (query or "").strip().lower()
    if not needle:
        return []
    _ensure_otu_index()

    # Double embedded quotes so user text cannot terminate the FTS5 string literal.
    fts_query = " OR ".join(
        '"{}"*'.format(token.replace('"', '""')) for token in needle.split()
    )
    # Escape LIKE wildcards so "_" and "%" in ids such as 97_8697 match literally.
    escaped_needle = needle.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    like_pattern = f"%{escaped_needle}%"

    conn = _open_otu_index()
    try:
        try:
            rows = conn.execute(
                """
                SELECT e.otu_id, e.label, e.taxonomy, e.sequence, e.seq_len, e.search_text
                FROM otu_search s
                JOIN otu_entries e ON e.rowid = s.rowid
                WHERE otu_search MATCH ?
                ORDER BY rank
                LIMIT ?
                """,
                (fts_query, limit),
            ).fetchall()
        except sqlite3.OperationalError:
            # An expression FTS5 cannot parse should degrade to the substring search
            # rather than surface as a crash on every keystroke.
            rows = []
        if not rows:
            rows = conn.execute(
                """
                SELECT otu_id, label, taxonomy, sequence, seq_len, search_text
                FROM otu_entries
                WHERE search_text LIKE ? ESCAPE '\\'
                ORDER BY label, otu_id
                LIMIT ?
                """,
                (like_pattern, limit),
            ).fetchall()
    finally:
        # "with conn:" would not close the connection; do it explicitly.
        conn.close()
    return _rows_to_entries(rows)
| def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | |
| mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype) | |
| summed = (last_hidden_state * mask).sum(dim=1) | |
| counts = mask.sum(dim=1).clamp(min=1e-8) | |
| return summed / counts | |
def _embed_sequences(seqs: List[str], models: LoadedModels) -> np.ndarray:
    """Embed sequences with the encoder and mean-pool each one into a vector.

    Sequences are tokenized in batches of ``BATCH_SIZE`` (padded, truncated to
    ``MAX_SEQ_LEN``), run through ``models.prokbert`` without gradients, and
    pooled with ``_mean_pool``. The per-batch results are stacked row-wise.

    Raises:
        gr.Error: if the pooled vectors are not 384-dimensional.
    """
    chunks: List[np.ndarray] = []
    for start in range(0, len(seqs), BATCH_SIZE):
        encoded = models.tokenizer(
            seqs[start : start + BATCH_SIZE],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SEQ_LEN,
            padding=True,
        )
        # Move every tokenizer output onto the model's device.
        encoded = {name: tensor.to(models.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = models.prokbert(**encoded).last_hidden_state
        vectors = _mean_pool(hidden, encoded["attention_mask"])
        chunks.append(vectors.detach().cpu().numpy())

    stacked = np.vstack(chunks)
    width = stacked.shape[1]
    # 384 matches input_dim_type1 of the scoring model built in _load_models.
    if width != 384:
        raise gr.Error(
            f"Expected 384-d ProkBERT embeddings, got {width} from {PROKBERT_MODEL_ID}."
        )
    return stacked
def _infer_logits_and_final_embeddings(input_embeddings: np.ndarray, models: LoadedModels) -> Tuple[np.ndarray, np.ndarray]:
    """Score one community and return per-member logits and final hidden states.

    The embeddings are treated as a single batch item: they are projected with
    ``input_projection_type1``, passed through the transformer with no position
    marked as padding, and mapped to one logit per member by
    ``output_projection``.

    Args:
        input_embeddings: one row per member, as produced by ``_embed_sequences``.
        models: loaded model bundle.

    Returns:
        A tuple ``(logits, final_hidden)`` as NumPy arrays with the batch axis
        removed.
    """
    # Add the batch axis the model expects.
    x = torch.tensor(input_embeddings, dtype=torch.float32, device=models.device).unsqueeze(0)
    # Every position holds a real member, so nothing is flagged as padding.
    padding_mask = torch.zeros((1, x.shape[1]), dtype=torch.bool, device=models.device)
    with torch.no_grad():
        x_proj = models.microbiome.input_projection_type1(x)
        final_hidden = models.microbiome.transformer(x_proj, src_key_padding_mask=padding_mask)
        logits = models.microbiome.output_projection(final_hidden).squeeze(-1)
    return logits.squeeze(0).detach().cpu().numpy(), final_hidden.squeeze(0).detach().cpu().numpy()
def _plot_umap(vectors: np.ndarray, labels: List[str], logits: np.ndarray, title: str):
    """Project *vectors* to 2-D with UMAP and return a scatter figure.

    Points are coloured by their stability score (``logits``) and show their
    label, coordinates and score on hover.

    Raises:
        gr.Error: if fewer than two vectors are supplied.
    """
    n_points = len(vectors)
    if n_points < 2:
        raise gr.Error("UMAP needs at least 2 sequences.")

    # Neighbourhood size is kept between 2 and 15 and tied to the point count.
    neighbours = max(2, min(15, n_points - 1))
    # Random initialisation is used for very small inputs, spectral otherwise.
    init_mode = "random" if n_points <= 3 else "spectral"
    projector = umap.UMAP(
        n_components=2,
        n_neighbors=neighbours,
        min_dist=0.1,
        metric="cosine",
        random_state=42,
        init=init_mode,
    )
    embedded = projector.fit_transform(vectors)

    xs = list(map(float, embedded[:, 0]))
    ys = list(map(float, embedded[:, 1]))
    scores = list(map(float, logits))

    marker_style = {
        "size": 10,
        "color": scores,
        "colorscale": "Viridis",
        "line": {"width": 0.6, "color": "#1d2a1f"},
        "opacity": 0.92,
        "showscale": True,
        "colorbar": {"title": "stability score"},
    }
    points = go.Scatter(
        x=xs,
        y=ys,
        mode="markers",
        text=labels,
        # One-column array so the hover template can read customdata[0].
        customdata=np.array(scores).reshape(-1, 1),
        hovertemplate="<b>%{text}</b><br>UMAP 1=%{x:.3f}<br>UMAP 2=%{y:.3f}<br>stability score=%{customdata[0]:.4f}<extra></extra>",
        marker=marker_style,
    )
    figure = go.Figure(data=[points])
    figure.update_layout(
        title=title,
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        paper_bgcolor="rgba(255,255,255,0)",
        plot_bgcolor="rgba(255,255,255,0.75)",
        margin={"l": 10, "r": 10, "t": 60, "b": 10},
    )
    return figure
| def _display_label(record: dict) -> str: | |
| taxonomy = (record.get("taxonomy") or "").strip().strip(";") | |
| if taxonomy: | |
| return taxonomy | |
| return record["id"] | |
def _short_plot_label(label: str, max_len: int = 32) -> str:
    """Reduce *label* to its final taxon name, shortened to *max_len* characters.

    Names longer than *max_len* are cut to ``max_len - 1`` characters, stripped
    of trailing whitespace and suffixed with an ellipsis character.
    """
    name = _extract_taxa_name(label)
    if len(name) > max_len:
        clipped = name[: max_len - 1].rstrip()
        return f"{clipped}…"
    return name
def _plot_logits(logits: np.ndarray, labels: List[str]):
    """Return a bar chart of stability scores ordered from highest to lowest.

    Bars sit at integer positions; the axis ticks show shortened labels while
    the hover text shows the full label and the score.
    """
    ranking = np.argsort(logits)[::-1]
    ranked_labels = [labels[i] for i in ranking]
    ranked_scores = [float(logits[i]) for i in ranking]
    tick_text = [_short_plot_label(name) for name in ranked_labels]
    positions = list(range(len(ranked_labels)))

    bars = go.Bar(
        x=positions,
        y=ranked_scores,
        marker={"color": "#d8832f"},
        width=0.95,
        # One-column array so the hover template can read customdata[0].
        customdata=np.array(ranked_labels).reshape(-1, 1),
        hovertemplate="<b>%{customdata[0]}</b><br>stability score=%{y:.4f}<extra></extra>",
    )
    figure = go.Figure(data=[bars])
    figure.update_layout(
        title="Ranked Stability Scores",
        xaxis_title="Taxon",
        yaxis_title="Stability Score",
        bargap=0,
        paper_bgcolor="rgba(255,255,255,0)",
        plot_bgcolor="rgba(255,255,255,0.75)",
        # Extra bottom margin leaves room for the rotated tick labels.
        margin={"l": 10, "r": 10, "t": 60, "b": 140},
    )
    figure.update_xaxes(
        tickmode="array",
        tickvals=positions,
        ticktext=tick_text,
        tickangle=-45,
    )
    return figure
| def _records_to_member_table(records: List[dict]) -> List[List[object]]: | |
| rows: List[List[object]] = [] | |
| for record in records: | |
| rows.append( | |
| [ | |
| record["id"], | |
| record.get("taxonomy", ""), | |
| ] | |
| ) | |
| return rows | |
| def _write_tsv_download(prefix: str, headers: List[str], rows: List[List[object]]) -> str: | |
| with tempfile.NamedTemporaryFile( | |
| mode="w", newline="", suffix=".tsv", prefix=f"{prefix}_", delete=False, dir="/tmp" | |
| ) as handle: | |
| writer = csv.writer(handle, delimiter="\t") | |
| writer.writerow(headers) | |
| writer.writerows(rows) | |
| return handle.name | |
def _analyze_records(records: List[dict], source_title: str, extra_summary: str = ""):
    """Run the full pipeline on *records* and return everything the UI displays.

    At most ``MAX_GENES`` records are embedded, scored and plotted. Scores are
    paired with records by position, so records that share an id each keep
    their own score in the members table.

    Args:
        records: dicts with at least ``"id"`` and ``"sequence"``.
        source_title: prefix of the run summary.
        extra_summary: optional text appended to the run summary.

    Returns:
        An 8-tuple: summary text, input-embedding UMAP figure, final-embedding
        UMAP figure, ranked score figure, top rows (up to 50, highest score
        first), member rows (input order), path of the top-rows TSV, path of
        the members TSV.

    Raises:
        gr.Error: if fewer than two records are supplied.
    """
    if len(records) < 2:
        raise gr.Error("This explorer needs at least 2 sequences to compute the UMAP views.")

    models = _load_models()
    used_records = records[:MAX_GENES]
    labels = [_display_label(record) for record in used_records]
    seqs = [record["sequence"] for record in used_records]

    input_embeddings = _embed_sequences(seqs, models)
    logits, final_embeddings = _infer_logits_and_final_embeddings(input_embeddings, models)

    input_umap = _plot_umap(input_embeddings, labels, logits, "UMAP of Input DNA Embeddings")
    final_umap = _plot_umap(final_embeddings, labels, logits, "UMAP of Final Transformer Embeddings")
    logits_hist = _plot_logits(logits, labels)

    # Pair each record with its own logit by position; an id-keyed lookup would
    # give every duplicate id the score of the last record carrying that id.
    members = [
        [record["id"], float(score), record.get("taxonomy", "")]
        for record, score in zip(used_records, logits)
    ]
    order = np.argsort(logits)[::-1]
    # Copy the rows so the two tables do not share list objects.
    top_rows = [list(members[idx]) for idx in order[:50]]

    summary = (
        f"{source_title}: analyzed {len(used_records)} sequences "
        f"(cap={MAX_GENES}, trim={MAX_SEQ_LEN} nt)."
    )
    if extra_summary:
        summary = f"{summary} {extra_summary}"

    table_headers = ["id", "stability_score", "taxonomy"]
    top_tsv = _write_tsv_download("top_stability_scores", table_headers, top_rows)
    member_tsv = _write_tsv_download("analyzed_members", table_headers, members)
    return summary, input_umap, final_umap, logits_hist, top_rows, members, top_tsv, member_tsv
def analyze_fasta(fasta_file: str):
    """Handler for the FASTA tab: parse the uploaded file and analyze it.

    Raises:
        gr.Error: if no file was provided.
    """
    if fasta_file is None:
        raise gr.Error("Upload a FASTA file first.")
    parsed, total_records, n_trimmed = _read_fasta(fasta_file)
    return _analyze_records(
        parsed,
        "Raw FASTA upload",
        f"Loaded {total_records} records and truncated {n_trimmed} sequence(s).",
    )
def analyze_microbeatlas(sample_file: str):
    """Handler for the MicrobeAtlas tab: translate the taxa file and analyze it.

    Raises:
        gr.Error: if no file was provided.
    """
    if sample_file is None:
        raise gr.Error("Upload a MicrobeAtlas taxa TSV first.")
    translated, note = _read_microbeatlas_sample(sample_file)
    return _analyze_records(translated, "MicrobeAtlas import", note)
def search_taxa(query: str):
    """Handler for the search box: refresh the match list and its status line.

    Returns an update for the checkbox group (new choices, selection cleared)
    and a status message.
    """
    hits = _search_otu_records(query)
    if hits:
        # Each choice is a (display text, value) pair; the value is the OTU id.
        options = [(f"{hit.label} | {hit.otu_id}", hit.otu_id) for hit in hits]
        message = f"Found {len(hits)} matching OTUs. Select the ones you want to add to the community."
    else:
        options = []
        message = "No OTUs matched that taxon query."
    return gr.update(choices=options, value=[]), message
def add_to_community(selected_otu_ids: List[str], community_ids: List[str]):
    """Handler for "Add Selected OTUs": extend the community with new ids.

    Ids already present are skipped and adding stops once the community holds
    ``MAX_GENES`` ids. Returns the updated id list, the member table rows for
    the ids found in the index, and a status message.
    """
    members = list(community_ids or [])
    n_added = 0
    for candidate in selected_otu_ids or []:
        if candidate in members:
            continue
        if len(members) >= MAX_GENES:
            break
        members.append(candidate)
        n_added += 1

    lookup = _fetch_otu_entries_by_ids(members)
    records = []
    for member_id in members:
        entry = lookup.get(member_id)
        # Ids without an index entry stay in the state but get no table row.
        if entry is None:
            continue
        records.append(
            {
                "id": entry.otu_id,
                "sequence": entry.sequence[:MAX_SEQ_LEN],
                "source": "Community builder",
                "taxonomy": entry.taxonomy,
                "detail": entry.label,
            }
        )

    status = f"Community now contains {len(records)} OTUs. Added {n_added} new member(s)."
    return members, _records_to_member_table(records), status
def clear_community():
    """Handler for "Clear Community": reset the id list, the table and the status text."""
    empty_ids = []
    empty_rows = []
    return empty_ids, empty_rows, "Community cleared."
def analyze_community(community_ids: List[str]):
    """Handler for "Analyze Community": analyze the OTUs collected so far.

    Only the first ``MAX_GENES`` ids are used; ids with no index entry are
    dropped.

    Raises:
        gr.Error: if the community is empty or none of its ids resolve.
    """
    if not community_ids:
        raise gr.Error("Build a community first by searching taxa and adding OTUs.")

    capped_ids = community_ids[:MAX_GENES]
    lookup = _fetch_otu_entries_by_ids(capped_ids)
    resolved = (lookup.get(member_id) for member_id in capped_ids)
    records = [
        {
            "id": entry.otu_id,
            "sequence": entry.sequence[:MAX_SEQ_LEN],
            "source": "Community builder",
            "taxonomy": entry.taxonomy,
            "detail": entry.label,
        }
        for entry in resolved
        if entry is not None
    ]
    if not records:
        raise gr.Error("No valid OTU members remain in the current community.")
    return _analyze_records(records, "Community builder", "Selected by taxon search against otus.97.allinfo.")
# UI definition: three input tabs (FASTA upload, MicrobeAtlas import, community builder) that
# all write into one shared set of result components, followed by the event wiring.
# NOTE(review): the container nesting below was reconstructed from a listing that had lost its
# indentation — confirm which components belong to each Row/Column/Accordion.
with gr.Blocks(title="Microbiome Explorer", css=CSS, theme=gr.themes.Soft()) as demo:
    # Holds the list of OTU ids collected in the community builder.
    community_state = gr.State([])
    gr.HTML(
        """
        <section class="hero">
        <h1>Microbiome Stability Scoring Explorer</h1>
        <p>
        Upload raw FASTA, translate a MicrobeAtlas sample into representative OTU sequences,
        or build a synthetic community by taxonomy. Every route ends in the same pipeline:
        ProkBERT mean pooling, <code>large-notext</code> scoring, and linked embedding views.
        </p>
        </section>
        """
    )
    with gr.Tabs():
        with gr.Tab("Raw FASTA"):
            with gr.Column(elem_classes=["soft-card"]):
                gr.Markdown(
                    "Upload genes directly in FASTA format. Sequences longer than 1024 nt are trimmed and only the first 800 records are used."
                )
                fasta_in = gr.File(
                    label="FASTA file",
                    file_types=[".fa", ".fasta", ".fna", ".txt"],
                    type="filepath",
                )
                fasta_run_btn = gr.Button("Analyze FASTA", variant="primary")
        with gr.Tab("Import From MicrobeAtlas"):
            with gr.Column(elem_classes=["soft-card"]):
                gr.Markdown(
                    f"""
                    Bring in a taxa file exported from MicrobeAtlas. Go to
                    [MicrobeAtlas sample detail]({MICROBEATLAS_SAMPLE_URL}), click `Download`, and upload the taxa TSV here.
                    OTU IDs from `SHORT_TID` are translated to representative sequences using `otus.97.allinfo`.
                    """
                )
                microbeatlas_in = gr.File(
                    label="MicrobeAtlas taxa TSV",
                    file_types=[".tsv", ".txt"],
                    type="filepath",
                )
                # Offer the bundled example only when the file actually ships with the app.
                if os.path.exists(EXAMPLE_SAMPLE_PATH):
                    gr.Examples(
                        examples=[[EXAMPLE_SAMPLE_PATH]],
                        inputs=[microbeatlas_in],
                        label="Use example",
                    )
                else:
                    gr.Markdown(
                        "Example file not bundled in this deployment. Upload a MicrobeAtlas taxa TSV exported from the sample page above."
                    )
                microbeatlas_run_btn = gr.Button("Translate And Analyze", variant="primary")
        with gr.Tab("Build A Community"):
            with gr.Column(elem_classes=["soft-card"]):
                gr.Markdown(
                    "Search the OTU index by OTU ID, taxon label, or taxonomy string. Matching OTUs appear directly below as you type, so you can add them without opening another widget."
                )
                with gr.Row():
                    taxa_query = gr.Textbox(
                        label="Search taxa",
                        placeholder="Try Nitrospira, Lysobacter, Gammaproteobacteria, 97_8697 ...",
                        scale=6,
                    )
                community_search_status = gr.Markdown(elem_classes=["section-note"])
                taxa_matches = gr.CheckboxGroup(
                    label="Matching OTUs",
                    choices=[],
                    value=[],
                    elem_classes=["search-results"],
                )
                with gr.Row():
                    community_add_btn = gr.Button("Add Selected OTUs", variant="primary")
                    community_clear_btn = gr.Button("Clear Community")
                    community_run_btn = gr.Button("Analyze Community", variant="secondary")
                with gr.Accordion("Community Members", open=True):
                    community_table = gr.Dataframe(
                        headers=["id", "taxonomy"],
                        label="Current community",
                        wrap=True,
                        elem_classes=["fixed-table"],
                    )
                    community_status = gr.Markdown(elem_classes=["section-note"])
    # Result components shared by all three analysis buttons.
    with gr.Accordion("Analysis Results", open=True):
        run_summary = gr.Textbox(label="Run summary")
        with gr.Row():
            input_umap_plot = gr.Plot(label="Input embedding UMAP")
            final_umap_plot = gr.Plot(label="Final embedding UMAP")
        logits_plot = gr.Plot(label="Stability score distribution")
        with gr.Accordion("Top-scoring members", open=False):
            top_download = gr.DownloadButton("Download top members TSV")
            top_table = gr.Dataframe(
                headers=["id", "stability_score", "taxonomy"],
                label="Top members by stability score",
                wrap=True,
                elem_classes=["fixed-table"],
            )
        with gr.Accordion("Analyzed members", open=False):
            member_download = gr.DownloadButton("Download analyzed members TSV")
            member_table = gr.Dataframe(
                headers=["id", "stability_score", "taxonomy"],
                label="Members used in the run",
                wrap=True,
                elem_classes=["fixed-table"],
            )
    # Event wiring. The three analyze handlers return the same 8-tuple, so they share one
    # output ordering: summary, two UMAP plots, score plot, two tables, two download paths.
    fasta_run_btn.click(
        fn=analyze_fasta,
        inputs=[fasta_in],
        outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table, top_download, member_download],
    )
    microbeatlas_run_btn.click(
        fn=analyze_microbeatlas,
        inputs=[microbeatlas_in],
        outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table, top_download, member_download],
    )
    # Runs a search whenever the textbox value changes.
    taxa_query.change(
        fn=search_taxa,
        inputs=[taxa_query],
        outputs=[taxa_matches, community_search_status],
    )
    community_add_btn.click(
        fn=add_to_community,
        inputs=[taxa_matches, community_state],
        outputs=[community_state, community_table, community_status],
    )
    community_clear_btn.click(
        fn=clear_community,
        outputs=[community_state, community_table, community_status],
    )
    community_run_btn.click(
        fn=analyze_community,
        inputs=[community_state],
        outputs=[run_summary, input_umap_plot, final_umap_plot, logits_plot, top_table, member_table, top_download, member_download],
    )
# Script entry point: enable the request queue, then start serving the app.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()