"""Load precomputed explorer artifacts (no torch required at runtime).""" from __future__ import annotations import html import pickle import re import unicodedata from pathlib import Path import numpy as np import pandas as pd from streamlit_hf.lib.formatters import annotate_modality_column from streamlit_hf.lib.reactions import normalize_reaction_key REPO_ROOT = Path(__file__).resolve().parents[2] CACHE_DIR = REPO_ROOT / "streamlit_hf" / "cache" METABOLIC_MODEL_METADATA = REPO_ROOT / "data" / "datasets" / "metabolic_model_metadata.csv" def _is_valid_features_csv(path: Path) -> bool: if not path.is_file(): return False try: head = pd.read_csv(path, nrows=2) except Exception: return False return "feature" in head.columns and "importance_shift" in head.columns def load_latent_bundle(): path = CACHE_DIR / "latent_umap.pkl" if not path.is_file(): return None with open(path, "rb") as f: return pickle.load(f) def load_attention_summary(): path = CACHE_DIR / "attention_summary.pkl" if not path.is_file(): return None with open(path, "rb") as f: return pickle.load(f) def load_samples_df() -> pd.DataFrame | None: pq = CACHE_DIR / "samples.parquet" if pq.is_file(): df = pd.read_parquet(pq) return annotate_modality_column(df) if "modality" in df.columns else df return None def _add_within_modality_orders(df: pd.DataFrame) -> pd.DataFrame: """ Align scatter / table columns with the notebook. Parquet from precompute already has rank_shift_in_modal / rank_att_in_modal from the same merge-of-sorted-lists logic as the notebook; do not overwrite those with pandas ranks on rounded importances (tie order can differ and changes the RNA cloud). 
""" out = df.copy() if "modality" not in out.columns: return out if "rank_shift_in_modal" in out.columns and "rank_att_in_modal" in out.columns: out["shift_order_mod"] = out["rank_shift_in_modal"].astype(int) out["attention_order_mod"] = out["rank_att_in_modal"].astype(int) else: g = out.groupby("modality", observed=True) out["shift_order_mod"] = g["importance_shift"].rank(ascending=False, method="first").astype(int) out["attention_order_mod"] = g["importance_att"].rank(ascending=False, method="first").astype(int) out["rank_shift_in_modal"] = out["shift_order_mod"] out["rank_att_in_modal"] = out["attention_order_mod"] if "combined_order_mod" not in out.columns: g = out.groupby("modality", observed=True) out["combined_order_mod"] = g["mean_rank"].rank(ascending=True, method="first").astype(int) return out def load_metabolic_model_metadata() -> pd.DataFrame | None: """Directed reaction edges: substrate → product, grouped by supermodule (see CSV headers).""" if not METABOLIC_MODEL_METADATA.is_file(): return None return pd.read_csv(METABOLIC_MODEL_METADATA) def build_metabolic_model_table( meta: pd.DataFrame, supermodule_id: int | None = None, ) -> pd.DataFrame: """Rows from ``metabolic_model_metadata.csv`` (all file columns except a stray ``Unnamed: 0`` index column).""" need = {"Compound_IN_name", "Compound_OUT_name", "rxnName", "Supermodule_id", "Super.Module.class"} if not need.issubset(set(meta.columns)): return pd.DataFrame() m = meta.copy() if supermodule_id is not None: m = m[m["Supermodule_id"] == int(supermodule_id)] if m.empty: return pd.DataFrame() if "Unnamed: 0" in m.columns: m = m.drop(columns=["Unnamed: 0"]) return m.reset_index(drop=True) def _normalize_metabolite_token(name: str) -> str: t = unicodedata.normalize("NFD", str(name).strip().lower()) t = "".join(ch for ch in t if unicodedata.category(ch) != "Mn") t = re.sub(r"\s+", " ", t).strip() return t def _is_plausible_metabolite_name(name: str) -> bool: t = str(name).strip() if len(t) < 2: return 
False if t.endswith("-OUT"): return False if t in {"C00000", "***", "**", "*"}: return False if re.fullmatch(r"C\d{5,}", t): return False return True def _token_variants(raw: str) -> set[str]: base = _normalize_metabolite_token(raw) if not base: return set() beta = "\u03b2" alpha = "\u03b1" out = { base, base.replace(beta, "B").replace(alpha, "A").replace("ß", "ss"), } if base.startswith("B-") and len(base) > 2: out.add(f"{beta}-{base[2:]}") if base.startswith(f"{beta}-") and len(base) > 2: out.add(f"B-{base[2:]}") if "alanine" in base and (base.startswith("B-") or base.startswith(f"{beta}-")): out.add("beta-alanine") return {x for x in out if x} def _json_float(v) -> float | None: if v is None: return None try: x = float(v) except (TypeError, ValueError): return None if isinstance(x, float) and np.isnan(x): return None return x def build_metabolite_map_bundle( meta: pd.DataFrame | None, flux_df: pd.DataFrame | None, ) -> dict | None: """ Curated metabolites from metabolic_model_metadata.csv, enriched with flux rows from df_features where reaction strings match. Used by the metabolic map iframe (sidebar list + hover cards). 
""" need = {"Compound_IN_name", "Compound_OUT_name", "rxnName", "Super.Module.class", "Compound_IN_ID", "Compound_OUT_ID"} if meta is None or meta.empty or not need.issubset(meta.columns): return None fd = pd.DataFrame() if flux_df is not None and not flux_df.empty and "feature" in flux_df.columns: fd = flux_df.copy() fd["_rk"] = fd["feature"].map(normalize_reaction_key) fd = fd.drop_duplicates("_rk", keep="first").set_index("_rk", drop=False) reaction_importance_rank: dict[str, int] = {} if not fd.empty and "mean_rank" in fd.columns: for idx in fd.index: row = fd.loc[idx] if isinstance(row, pd.DataFrame): row = row.iloc[0] if "combined_order_mod" in row.index and pd.notna(row["combined_order_mod"]): reaction_importance_rank[idx] = int(row["combined_order_mod"]) if len(reaction_importance_rank) < len(fd): sub = fd.sort_values("mean_rank", ascending=True, kind="mergesort") for i, idx in enumerate(sub.index, start=1): reaction_importance_rank.setdefault(idx, i) buckets: dict[str, dict] = {} def touch(key: str, display: str) -> dict: if key not in buckets: buckets[key] = { "key": key, "name": display.strip(), "tokens": set(), "chebi": set(), "reactions": [], "supermodules": set(), } b = buckets[key] b["tokens"].update(_token_variants(display)) return b for _, row in meta.iterrows(): sub_raw = row["Compound_IN_name"] prod_raw = row["Compound_OUT_name"] rxn = str(row["rxnName"]).strip() rk = normalize_reaction_key(rxn) smod = row.get("Super.Module.class") smod_s = str(smod).strip() if smod is not None and str(smod) != "nan" else "" fr = None if rk in fd.index: fr = fd.loc[rk] if isinstance(fr, pd.DataFrame): fr = fr.iloc[0] mean_rank = _json_float(fr["mean_rank"]) if fr is not None and "mean_rank" in fr.index else None log_fc = _json_float(fr["log_fc"]) if fr is not None and "log_fc" in fr.index else None pval_adj = _json_float(fr["pval_adj"]) if fr is not None and "pval_adj" in fr.index else None pathway = None if fr is not None and "pathway" in fr.index: pv = 
fr["pathway"] if pd.notna(pv): pathway = str(pv).strip() fate_group = None if fr is not None and "group" in fr.index: g = fr["group"] if pd.notna(g): fate_group = str(g).strip() imp_r = reaction_importance_rank.get(rk) base_rx = { "reaction": rxn, "supermodule": smod_s, "mean_rank": mean_rank, "importance_rank": imp_r, "log_fc": log_fc, "pval_adj": pval_adj, "pathway": pathway, "fate_group": fate_group, } if _is_plausible_metabolite_name(sub_raw): k = _normalize_metabolite_token(sub_raw) b = touch(k, str(sub_raw).strip()) if smod_s: b["supermodules"].add(smod_s) b["chebi"].add(str(row["Compound_IN_ID"]).strip()) b["reactions"].append({**base_rx, "as": "substrate", "partner": str(prod_raw).strip()}) if _is_plausible_metabolite_name(prod_raw): k = _normalize_metabolite_token(prod_raw) b = touch(k, str(prod_raw).strip()) if smod_s: b["supermodules"].add(smod_s) b["chebi"].add(str(row["Compound_OUT_ID"]).strip()) b["reactions"].append({**base_rx, "as": "product", "partner": str(sub_raw).strip()}) if not buckets: return None by_key: dict[str, dict] = {} ordered: list[dict] = [] for key, b in buckets.items(): seen_rx: set[tuple[str, str]] = set() uniq_rx: list[dict] = [] for r in b["reactions"]: sig = (normalize_reaction_key(r["reaction"]), r["as"]) if sig in seen_rx: continue seen_rx.add(sig) uniq_rx.append(r) b["reactions"] = uniq_rx imp_ranks = [r["importance_rank"] for r in uniq_rx if r.get("importance_rank") is not None] best_importance = min(imp_ranks) if imp_ranks else None chebi_sorted = sorted({x for x in b["chebi"] if x and x not in {"nan", "C00000"}}) tokens_sorted = sorted(b["tokens"]) smods = sorted(b["supermodules"]) lines: list[str] = [f"{html.escape(b['name'])}"] if chebi_sorted: lines.append(f"Model IDs: {html.escape(', '.join(chebi_sorted[:8]))}") if smods: lines.append(f"Modules: {html.escape(' · '.join(smods[:4]))}") if best_importance is not None: lines.append(f"Strongest linked reaction: #{best_importance}") top_rx = sorted( uniq_rx, key=lambda r: ( 
                r.get("importance_rank") is None,
                r["importance_rank"] if r.get("importance_rank") is not None else 10**9,
            ),
        )[:5]
        if top_rx:
            lines.append("Linked reactions (# · log₂FC · fate)")
            for r in top_rx:
                bits = [html.escape(r["reaction"][:80] + ("…" if len(r["reaction"]) > 80 else ""))]
                if r.get("importance_rank") is not None:
                    bits.append(f"#{r['importance_rank']}")
                if r["log_fc"] is not None:
                    bits.append(f"log₂FC {r['log_fc']:.3f}")
                if r["fate_group"]:
                    bits.append(html.escape(r["fate_group"]))
                if r["pathway"]:
                    bits.append(f"({html.escape(r['pathway'])})")
                lines.append(" · ".join(bits))

        precursors = sorted(
            {r["partner"] for r in uniq_rx if r["as"] == "product" and r.get("partner") and _is_plausible_metabolite_name(r["partner"])}
        )
        products = sorted(
            {r["partner"] for r in uniq_rx if r["as"] == "substrate" and r.get("partner") and _is_plausible_metabolite_name(r["partner"])}
        )
        if precursors:
            lines.append(
                f"Model precursors (substrates in linked reactions)<br>"
                f"{html.escape(', '.join(precursors[:8]))}"
            )
        if products:
            lines.append(
                f"Model products (downstream in linked reactions)<br>"
                f"{html.escape(', '.join(products[:8]))}"
            )
        blurb = "<br>".join(lines)

        search_parts: list[str] = [b["name"], key, *tokens_sorted, *smods, *chebi_sorted]
        for r in uniq_rx:
            search_parts.extend(
                [
                    str(r.get("reaction") or ""),
                    str(r.get("pathway") or ""),
                    str(r.get("fate_group") or ""),
                    str(r.get("supermodule") or ""),
                    str(r.get("as") or ""),
                    str(r.get("partner") or ""),
                ]
            )
            if r.get("importance_rank") is not None:
                search_parts.append(str(r["importance_rank"]))
            if r.get("mean_rank") is not None:
                search_parts.append(str(r["mean_rank"]))
            if r.get("log_fc") is not None:
                search_parts.append(str(r["log_fc"]))
        search_parts.extend(precursors)
        search_parts.extend(products)
        search_text = re.sub(r"\s+", " ", " ".join(search_parts).lower()).strip()

        card = {
            "key": key,
            "name": b["name"],
            "tokens": tokens_sorted,
            "importance_rank": best_importance,
            "n_reactions": len(uniq_rx),
            "blurb_html": blurb,
            "search_text": search_text,
        }
        by_key[key] = card
        ordered.append(card)

    ordered.sort(
        key=lambda c: (
            c["importance_rank"] is None,
            c["importance_rank"] if c["importance_rank"] is not None else 10**9,
            str(c["name"]).lower(),
        )
    )
    return {"list": ordered, "by_key": by_key}


def load_df_features() -> pd.DataFrame | None:
    pq = CACHE_DIR / "df_features.parquet"
    if pq.is_file():
        return _add_within_modality_orders(pd.read_parquet(pq))
    csv_cache = CACHE_DIR / "df_features.csv"
    if csv_cache.is_file():
        return _add_within_modality_orders(pd.read_csv(csv_cache))
    analysis_csv = REPO_ROOT / "analysis" / "df_features.csv"
    if _is_valid_features_csv(analysis_csv):
        return _add_within_modality_orders(pd.read_csv(analysis_csv))
    return None


def latent_join_samples(bundle: dict, samples: pd.DataFrame | None) -> pd.DataFrame:
    """One row per UMAP point, aligned with bundle arrays."""
    df = pd.DataFrame(
        {
            "umap_x": bundle["umap_x"],
            "umap_y": bundle["umap_y"],
            "label": bundle["label_name"],
            "predicted_class": bundle["pred_name"],
            "correct": bundle["correct"].astype(bool),
            "fold": bundle["fold"].astype(int),
            "batch_no": bundle["batch_no"].astype(int),
            "pct": bundle["pct"],
            "modality": bundle["modality"],
            "dataset_idx": bundle["dataset_idx"].astype(int),
        }
    )
    if samples is not None and not samples.empty:
        s = samples.drop_duplicates(subset=["ind"], keep="first").set_index("ind")
        extra = s.reindex(df["dataset_idx"].values)
        for col in ["predicted_value", "clone_id", "clone_size", "cell_type"]:
            if col in extra.columns:
                df[col] = extra[col].values
    return annotate_modality_column(df)
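

# Illustrative sketch only: a hypothetical helper that the app never calls, with
# made-up toy values. It demonstrates the tie-order caveat in the docstring of
# _add_within_modality_orders. The pandas fallback there ranks with method="first",
# which breaks ties by row position. Rounding importances first can therefore
# create new ties and reorder rows within a modality. This is why precomputed
# rank_*_in_modal columns are kept rather than recomputed.
def _demo_rounding_changes_tie_order() -> tuple[list[int], list[int]]:
    import pandas as pd

    demo = pd.DataFrame(
        {
            "modality": ["RNA", "RNA", "RNA", "ATAC", "ATAC"],
            "importance_shift": [0.30, 0.4999, 0.5001, 0.20, 0.90],
        }
    )

    def within_modality_order(values: pd.Series) -> list[int]:
        # Descending rank per modality; ties go to the earlier row (method="first").
        ranks = values.groupby(demo["modality"], observed=True).rank(ascending=False, method="first")
        return ranks.astype(int).tolist()

    exact = within_modality_order(demo["importance_shift"])
    # Rounding to 2 decimals turns 0.4999 and 0.5001 into a 0.5 tie.
    rounded = within_modality_order(demo["importance_shift"].round(2))
    return exact, rounded


# _demo_rounding_changes_tie_order() -> ([3, 2, 1, 2, 1], [3, 1, 2, 2, 1]):
# the two RNA rows that tie after rounding swap places; the ATAC ranks are unchanged.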