Spaces:
Running
Running
| """Load precomputed explorer artifacts (no torch required at runtime).""" | |
| from __future__ import annotations | |
| import html | |
| import pickle | |
| import re | |
| import unicodedata | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| from streamlit_hf.lib.formatters import annotate_modality_column | |
| from streamlit_hf.lib.reactions import normalize_reaction_key | |
# Repository root: parents[2] climbs three levels from this file, assuming the
# layout <repo>/streamlit_hf/lib/<this file> (consistent with CACHE_DIR below).
# TODO(review): confirm this file's actual location in the package.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory of precomputed artifacts (pickles / parquet) read by the loaders below.
CACHE_DIR = REPO_ROOT / "streamlit_hf" / "cache"
# CSV of directed reaction edges (substrate -> product) for the metabolic map helpers.
METABOLIC_MODEL_METADATA = REPO_ROOT / "data" / "datasets" / "metabolic_model_metadata.csv"
| def _is_valid_features_csv(path: Path) -> bool: | |
| if not path.is_file(): | |
| return False | |
| try: | |
| head = pd.read_csv(path, nrows=2) | |
| except Exception: | |
| return False | |
| return "feature" in head.columns and "importance_shift" in head.columns | |
def load_latent_bundle():
    """Unpickle the cached latent-UMAP bundle, or return None when absent."""
    artifact = CACHE_DIR / "latent_umap.pkl"
    if not artifact.is_file():
        return None
    # NOTE: pickle is acceptable here because the artifact is produced locally
    # by the precompute step, not taken from untrusted input.
    return pickle.loads(artifact.read_bytes())
def load_attention_summary():
    """Unpickle the cached attention summary, or return None when absent."""
    artifact = CACHE_DIR / "attention_summary.pkl"
    if not artifact.is_file():
        return None
    # NOTE: pickle is acceptable here because the artifact is produced locally
    # by the precompute step, not taken from untrusted input.
    return pickle.loads(artifact.read_bytes())
def load_samples_df() -> pd.DataFrame | None:
    """Load the cached samples table; annotate modality labels when present."""
    parquet_path = CACHE_DIR / "samples.parquet"
    if not parquet_path.is_file():
        return None
    df = pd.read_parquet(parquet_path)
    if "modality" not in df.columns:
        return df
    return annotate_modality_column(df)
| def _add_within_modality_orders(df: pd.DataFrame) -> pd.DataFrame: | |
| """ | |
| Align scatter / table columns with the notebook. | |
| Parquet from precompute already has rank_shift_in_modal / rank_att_in_modal from the same | |
| merge-of-sorted-lists logic as the notebook; do not overwrite those with pandas ranks on | |
| rounded importances (tie order can differ and changes the RNA cloud). | |
| """ | |
| out = df.copy() | |
| if "modality" not in out.columns: | |
| return out | |
| if "rank_shift_in_modal" in out.columns and "rank_att_in_modal" in out.columns: | |
| out["shift_order_mod"] = out["rank_shift_in_modal"].astype(int) | |
| out["attention_order_mod"] = out["rank_att_in_modal"].astype(int) | |
| else: | |
| g = out.groupby("modality", observed=True) | |
| out["shift_order_mod"] = g["importance_shift"].rank(ascending=False, method="first").astype(int) | |
| out["attention_order_mod"] = g["importance_att"].rank(ascending=False, method="first").astype(int) | |
| out["rank_shift_in_modal"] = out["shift_order_mod"] | |
| out["rank_att_in_modal"] = out["attention_order_mod"] | |
| if "combined_order_mod" not in out.columns: | |
| g = out.groupby("modality", observed=True) | |
| out["combined_order_mod"] = g["mean_rank"].rank(ascending=True, method="first").astype(int) | |
| return out | |
def load_metabolic_model_metadata() -> pd.DataFrame | None:
    """Directed reaction edges: substrate → product, grouped by supermodule (see CSV headers)."""
    if METABOLIC_MODEL_METADATA.is_file():
        return pd.read_csv(METABOLIC_MODEL_METADATA)
    return None
| def build_metabolic_model_table( | |
| meta: pd.DataFrame, | |
| supermodule_id: int | None = None, | |
| ) -> pd.DataFrame: | |
| """Rows from ``metabolic_model_metadata.csv`` (all file columns except a stray ``Unnamed: 0`` index column).""" | |
| need = {"Compound_IN_name", "Compound_OUT_name", "rxnName", "Supermodule_id", "Super.Module.class"} | |
| if not need.issubset(set(meta.columns)): | |
| return pd.DataFrame() | |
| m = meta.copy() | |
| if supermodule_id is not None: | |
| m = m[m["Supermodule_id"] == int(supermodule_id)] | |
| if m.empty: | |
| return pd.DataFrame() | |
| if "Unnamed: 0" in m.columns: | |
| m = m.drop(columns=["Unnamed: 0"]) | |
| return m.reset_index(drop=True) | |
| def _normalize_metabolite_token(name: str) -> str: | |
| t = unicodedata.normalize("NFD", str(name).strip().lower()) | |
| t = "".join(ch for ch in t if unicodedata.category(ch) != "Mn") | |
| t = re.sub(r"\s+", " ", t).strip() | |
| return t | |
| def _is_plausible_metabolite_name(name: str) -> bool: | |
| t = str(name).strip() | |
| if len(t) < 2: | |
| return False | |
| if t.endswith("-OUT"): | |
| return False | |
| if t in {"C00000", "***", "**", "*"}: | |
| return False | |
| if re.fullmatch(r"C\d{5,}", t): | |
| return False | |
| return True | |
def _token_variants(raw: str) -> set[str]:
    """Return normalized spelling variants of a metabolite name for fuzzy matching.

    Variants cover Greek-letter vs. Latin-letter prefixes ("β-x" ↔ "b-x") and
    the spelled-out "beta-alanine" form.

    Bug fix: ``_normalize_metabolite_token`` lowercases its output, so the
    previous uppercase ``"B-"`` prefix checks could never match — the
    "b-x" → "β-x" branch and part of the alanine branch were dead code.
    Prefix checks and the ASCII fold now use lowercase letters, consistent
    with the normalized (lowercase) token space.
    """
    base = _normalize_metabolite_token(raw)
    if not base:
        return set()
    beta = "\u03b2"   # β
    alpha = "\u03b1"  # α
    out = {
        base,
        # ASCII fold of Greek letters / sharp-s, kept lowercase like `base`.
        base.replace(beta, "b").replace(alpha, "a").replace("ß", "ss"),
    }
    # Map between "b-..." and "β-..." spellings in both directions.
    if base.startswith("b-") and len(base) > 2:
        out.add(f"{beta}-{base[2:]}")
    if base.startswith(f"{beta}-") and len(base) > 2:
        out.add(f"b-{base[2:]}")
    if "alanine" in base and (base.startswith("b-") or base.startswith(f"{beta}-")):
        out.add("beta-alanine")
    return {x for x in out if x}
| def _json_float(v) -> float | None: | |
| if v is None: | |
| return None | |
| try: | |
| x = float(v) | |
| except (TypeError, ValueError): | |
| return None | |
| if isinstance(x, float) and np.isnan(x): | |
| return None | |
| return x | |
def build_metabolite_map_bundle(
    meta: pd.DataFrame | None,
    flux_df: pd.DataFrame | None,
) -> dict | None:
    """
    Curated metabolites from metabolic_model_metadata.csv, enriched with flux rows from df_features
    where reaction strings match. Used by the metabolic map iframe (sidebar list + hover cards).

    Returns ``{"list": [card, ...], "by_key": {key: card}}`` where cards are sorted
    by importance rank (unranked last, then by name), or None when the metadata is
    missing/unusable or yields no plausible metabolites.
    """
    need = {"Compound_IN_name", "Compound_OUT_name", "rxnName", "Super.Module.class", "Compound_IN_ID", "Compound_OUT_ID"}
    if meta is None or meta.empty or not need.issubset(meta.columns):
        return None
    # Index the flux table by normalized reaction key so metadata rows can be
    # joined on reaction strings despite formatting differences.
    fd = pd.DataFrame()
    if flux_df is not None and not flux_df.empty and "feature" in flux_df.columns:
        fd = flux_df.copy()
        fd["_rk"] = fd["feature"].map(normalize_reaction_key)
        fd = fd.drop_duplicates("_rk", keep="first").set_index("_rk", drop=False)
    # Per-reaction display rank: prefer precomputed combined_order_mod, then fill
    # the remaining reactions by ascending mean_rank (stable mergesort order).
    reaction_importance_rank: dict[str, int] = {}
    if not fd.empty and "mean_rank" in fd.columns:
        for idx in fd.index:
            row = fd.loc[idx]
            if isinstance(row, pd.DataFrame):  # duplicate index label -> take first row
                row = row.iloc[0]
            if "combined_order_mod" in row.index and pd.notna(row["combined_order_mod"]):
                reaction_importance_rank[idx] = int(row["combined_order_mod"])
        if len(reaction_importance_rank) < len(fd):
            sub = fd.sort_values("mean_rank", ascending=True, kind="mergesort")
            for i, idx in enumerate(sub.index, start=1):
                reaction_importance_rank.setdefault(idx, i)
    # One bucket per normalized metabolite name, accumulating token variants,
    # model IDs, supermodules and linked reactions from both sides of each edge.
    buckets: dict[str, dict] = {}

    def touch(key: str, display: str) -> dict:
        # Get-or-create the bucket for `key` and fold in token variants of `display`.
        if key not in buckets:
            buckets[key] = {
                "key": key,
                "name": display.strip(),
                "tokens": set(),
                "chebi": set(),
                "reactions": [],
                "supermodules": set(),
            }
        b = buckets[key]
        b["tokens"].update(_token_variants(display))
        return b

    for _, row in meta.iterrows():
        sub_raw = row["Compound_IN_name"]
        prod_raw = row["Compound_OUT_name"]
        rxn = str(row["rxnName"]).strip()
        rk = normalize_reaction_key(rxn)
        smod = row.get("Super.Module.class")
        # pandas renders missing values as the string "nan" after str(); treat as empty.
        smod_s = str(smod).strip() if smod is not None and str(smod) != "nan" else ""
        # Matching flux row for this reaction, when one exists under the normalized key.
        fr = None
        if rk in fd.index:
            fr = fd.loc[rk]
            if isinstance(fr, pd.DataFrame):
                fr = fr.iloc[0]
        mean_rank = _json_float(fr["mean_rank"]) if fr is not None and "mean_rank" in fr.index else None
        log_fc = _json_float(fr["log_fc"]) if fr is not None and "log_fc" in fr.index else None
        pval_adj = _json_float(fr["pval_adj"]) if fr is not None and "pval_adj" in fr.index else None
        pathway = None
        if fr is not None and "pathway" in fr.index:
            pv = fr["pathway"]
            if pd.notna(pv):
                pathway = str(pv).strip()
        fate_group = None
        if fr is not None and "group" in fr.index:
            g = fr["group"]
            if pd.notna(g):
                fate_group = str(g).strip()
        imp_r = reaction_importance_rank.get(rk)
        # Shared payload reused for both the substrate's and the product's entry.
        base_rx = {
            "reaction": rxn,
            "supermodule": smod_s,
            "mean_rank": mean_rank,
            "importance_rank": imp_r,
            "log_fc": log_fc,
            "pval_adj": pval_adj,
            "pathway": pathway,
            "fate_group": fate_group,
        }
        if _is_plausible_metabolite_name(sub_raw):
            k = _normalize_metabolite_token(sub_raw)
            b = touch(k, str(sub_raw).strip())
            if smod_s:
                b["supermodules"].add(smod_s)
            b["chebi"].add(str(row["Compound_IN_ID"]).strip())
            b["reactions"].append({**base_rx, "as": "substrate", "partner": str(prod_raw).strip()})
        if _is_plausible_metabolite_name(prod_raw):
            k = _normalize_metabolite_token(prod_raw)
            b = touch(k, str(prod_raw).strip())
            if smod_s:
                b["supermodules"].add(smod_s)
            b["chebi"].add(str(row["Compound_OUT_ID"]).strip())
            b["reactions"].append({**base_rx, "as": "product", "partner": str(sub_raw).strip()})
    if not buckets:
        return None
    # Turn each bucket into a display card: dedupe reactions, build the hover-card
    # HTML blurb and a lowercase search haystack, then sort by importance.
    by_key: dict[str, dict] = {}
    ordered: list[dict] = []
    for key, b in buckets.items():
        # Deduplicate by (normalized reaction, role), preserving first-seen order.
        seen_rx: set[tuple[str, str]] = set()
        uniq_rx: list[dict] = []
        for r in b["reactions"]:
            sig = (normalize_reaction_key(r["reaction"]), r["as"])
            if sig in seen_rx:
                continue
            seen_rx.add(sig)
            uniq_rx.append(r)
        b["reactions"] = uniq_rx
        # Best (lowest) importance rank among the linked reactions, if any are ranked.
        imp_ranks = [r["importance_rank"] for r in uniq_rx if r.get("importance_rank") is not None]
        best_importance = min(imp_ranks) if imp_ranks else None
        chebi_sorted = sorted({x for x in b["chebi"] if x and x not in {"nan", "C00000"}})
        tokens_sorted = sorted(b["tokens"])
        smods = sorted(b["supermodules"])
        # Hover-card blurb: escaped lines joined with <br/> below.
        lines: list[str] = [f"<strong>{html.escape(b['name'])}</strong>"]
        if chebi_sorted:
            lines.append(f"Model IDs: {html.escape(', '.join(chebi_sorted[:8]))}")
        if smods:
            lines.append(f"Modules: {html.escape(' · '.join(smods[:4]))}")
        if best_importance is not None:
            lines.append(f"Strongest linked reaction: #{best_importance}")
        # Up to five linked reactions, ranked ones first (unranked sort last).
        top_rx = sorted(
            uniq_rx,
            key=lambda r: (
                r.get("importance_rank") is None,
                r["importance_rank"] if r.get("importance_rank") is not None else 10**9,
            ),
        )[:5]
        if top_rx:
            lines.append("<span style='color:#656d76'>Linked reactions (# · log₂FC · fate)</span>")
            for r in top_rx:
                # Truncate long reaction strings to 80 chars with an ellipsis.
                bits = [html.escape(r["reaction"][:80] + ("…" if len(r["reaction"]) > 80 else ""))]
                if r.get("importance_rank") is not None:
                    bits.append(f"#{r['importance_rank']}")
                if r["log_fc"] is not None:
                    bits.append(f"log₂FC {r['log_fc']:.3f}")
                if r["fate_group"]:
                    bits.append(html.escape(r["fate_group"]))
                if r["pathway"]:
                    bits.append(f"({html.escape(r['pathway'])})")
                lines.append(" · ".join(bits))
        # Graph neighbours: partners where this metabolite is the product are its
        # precursors; partners where it is the substrate are its products.
        precursors = sorted(
            {r["partner"] for r in uniq_rx if r["as"] == "product" and r.get("partner") and _is_plausible_metabolite_name(r["partner"])}
        )
        products = sorted(
            {r["partner"] for r in uniq_rx if r["as"] == "substrate" and r.get("partner") and _is_plausible_metabolite_name(r["partner"])}
        )
        if precursors:
            lines.append(
                f"<span style='color:#656d76'>Model precursors (substrates in linked reactions)</span><br/>"
                f"{html.escape(', '.join(precursors[:8]))}"
            )
        if products:
            lines.append(
                f"<span style='color:#656d76'>Model products (downstream in linked reactions)</span><br/>"
                f"{html.escape(', '.join(products[:8]))}"
            )
        blurb = "<br/>".join(lines)
        # Flat lowercase haystack used by the sidebar search box.
        search_parts: list[str] = [b["name"], key, *tokens_sorted, *smods, *chebi_sorted]
        for r in uniq_rx:
            search_parts.extend(
                [
                    str(r.get("reaction") or ""),
                    str(r.get("pathway") or ""),
                    str(r.get("fate_group") or ""),
                    str(r.get("supermodule") or ""),
                    str(r.get("as") or ""),
                    str(r.get("partner") or ""),
                ]
            )
            if r.get("importance_rank") is not None:
                search_parts.append(str(r["importance_rank"]))
            if r.get("mean_rank") is not None:
                search_parts.append(str(r["mean_rank"]))
            if r.get("log_fc") is not None:
                search_parts.append(str(r["log_fc"]))
        search_parts.extend(precursors)
        search_parts.extend(products)
        search_text = re.sub(r"\s+", " ", " ".join(search_parts).lower()).strip()
        card = {
            "key": key,
            "name": b["name"],
            "tokens": tokens_sorted,
            "importance_rank": best_importance,
            "n_reactions": len(uniq_rx),
            "blurb_html": blurb,
            "search_text": search_text,
        }
        by_key[key] = card
        ordered.append(card)
    # Most important (lowest rank number) first; unranked cards last, then by name.
    ordered.sort(
        key=lambda c: (
            c["importance_rank"] is None,
            c["importance_rank"] if c["importance_rank"] is not None else 10**9,
            str(c["name"]).lower(),
        )
    )
    return {"list": ordered, "by_key": by_key}
def load_df_features() -> pd.DataFrame | None:
    """Load the per-feature table, preferring parquet cache, then CSV cache, then the analysis CSV."""
    candidates = (
        (CACHE_DIR / "df_features.parquet", pd.read_parquet, Path.is_file),
        (CACHE_DIR / "df_features.csv", pd.read_csv, Path.is_file),
        # The analysis CSV additionally gets a schema check before being trusted.
        (REPO_ROOT / "analysis" / "df_features.csv", pd.read_csv, _is_valid_features_csv),
    )
    for path, reader, accept in candidates:
        if accept(path):
            return _add_within_modality_orders(reader(path))
    return None
def latent_join_samples(bundle: dict, samples: pd.DataFrame | None) -> pd.DataFrame:
    """One row per UMAP point, aligned with bundle arrays.

    Parameters
    ----------
    bundle : dict
        Latent-cache arrays (umap_x/umap_y, label_name, pred_name, correct,
        fold, batch_no, pct, modality, dataset_idx) — assumed equal length.
    samples : pd.DataFrame | None
        Optional per-sample table with an "ind" column; selected columns are
        joined onto the points via dataset_idx.

    Returns
    -------
    pd.DataFrame
        Point table with modality labels annotated.

    Note: removed an unused local (``len(bundle["umap_x"])`` was computed and
    never read).
    """
    df = pd.DataFrame(
        {
            "umap_x": bundle["umap_x"],
            "umap_y": bundle["umap_y"],
            "label": bundle["label_name"],
            "predicted_class": bundle["pred_name"],
            "correct": bundle["correct"].astype(bool),
            "fold": bundle["fold"].astype(int),
            "batch_no": bundle["batch_no"].astype(int),
            "pct": bundle["pct"],
            "modality": bundle["modality"],
            "dataset_idx": bundle["dataset_idx"].astype(int),
        }
    )
    if samples is not None and not samples.empty:
        # "ind" identifies the sample's dataset index; keep the first row per index,
        # then align the extra columns to each point via dataset_idx.
        s = samples.drop_duplicates(subset=["ind"], keep="first").set_index("ind")
        extra = s.reindex(df["dataset_idx"].values)
        for col in ["predicted_value", "clone_id", "clone_size", "cell_type"]:
            if col in extra.columns:
                df[col] = extra[col].values
    return annotate_modality_column(df)