microbe-model / scripts /26_marker_importance.py
Miyu Horiuchi
Deploy app from main@a3254bf (no paper/ binaries)
0ed74db
"""Diagnose where the HMM lift came from per phenotype target.
Re-runs train_all() on the HMM-covered subset and dumps, per target:
- the top 25 features by XGBoost gain, each tagged with its feature family
  (hmm:<category>, composition, codon, tetra, iso, mediadive, or baseline)
- importance summed per feature family, with HMM markers split by category
  (temperature, ph, oxygen, salt, vitamin, nitrogen, carbon, special) —
  useful for spotting which categories paid off across multiple targets
Output: artifacts/marker_importance.json + console summary.
"""
from __future__ import annotations
import json
import time
import pandas as pd
from microbe_model import config
from microbe_model.features.markers import all_markers, category_for
from microbe_model.train.baseline import train_all
# Phenotype target column -> learning task ("regression" / "classification").
# NOTE(review): main() below does not read this mapping; confirm whether other
# modules import it before changing or removing it.
PHENOTYPE_TARGETS = dict(
    optimal_temperature_c="regression",
    optimal_ph="regression",
    oxygen_requirement="classification",
    salt_tolerance_pct="regression",
)
# Lookup from a marker's friendly name to its Pfam key. The friendly name is the
# ``<friendly_name>`` part of ``hmm_<friendly_name>_<n|score|present>`` columns;
# pass the Pfam key to ``category_for`` to get the marker's category.
def _hmm_col_categories() -> dict[str, str]:
    """Return ``{friendly_name: pfam}`` for every entry of ``all_markers()``.

    ``all_markers()`` is iterated as ``pfam -> (friendly_name, _)``. If two
    Pfam keys share a friendly name, the one iterated last wins. A fresh dict
    is built on every call.
    """
    return {name: pfam for pfam, (name, _) in all_markers().items()}
# Lazily filled friendly-name -> Pfam lookup shared by every col_category call,
# so the marker table is walked once per process instead of once per column.
# Assumes all_markers() does not change while the process runs — TODO confirm.
_PFAM_BY_FRIENDLY_NAME: dict = {}

# Per-marker HMM column suffixes, tried in this order.
_HMM_SUFFIXES = ("_n", "_score", "_present")

# Non-HMM column prefix -> feature family, checked in order.
_PREFIX_CATEGORIES = (
    ("aa_frac_", "composition"),
    ("codon_", "codon"),
    ("tetra_", "tetra"),
    ("iso_", "iso"),
    ("md_", "mediadive"),
)


def col_category(col_name: str) -> str:
    """Return one of: hmm:<cat>, composition, codon, tetra, iso, mediadive, baseline.

    HMM columns look like ``hmm_<friendly_name>_<n|score|present>``. The
    friendly name is resolved to a Pfam key and then to a category via
    ``category_for``. Unresolvable HMM columns give ``"hmm:unknown"``. Any
    column without a recognised prefix is ``"baseline"``.
    """
    if col_name.startswith("hmm_"):
        if not _PFAM_BY_FRIENDLY_NAME:
            # First HMM lookup: build the friendly-name -> Pfam table once.
            _PFAM_BY_FRIENDLY_NAME.update(_hmm_col_categories())
        rest = col_name[len("hmm_"):]
        for suffix in _HMM_SUFFIXES:
            if rest.endswith(suffix):
                pfam = _PFAM_BY_FRIENDLY_NAME.get(rest[: -len(suffix)])
                if pfam:
                    return f"hmm:{category_for(pfam)}"
        return "hmm:unknown"
    for prefix, category in _PREFIX_CATEGORIES:
        if col_name.startswith(prefix):
            return category
    return "baseline"
def derive_group(row: pd.Series) -> str:
    """Return the taxonomic label stored in the ``group`` column.

    ``main`` passes this column to ``train_all`` as ``group_col_override``
    (presumably for grouped CV splits — confirm in train_all).

    Preference order is ``family``, then ``genus``, then the first word of
    ``species``. Non-string values (e.g. NaN/None) and blank or whitespace-only
    strings are treated as missing. Returns ``"__unknown__"`` when none of the
    three yields a usable label.
    """
    for col in ("family", "genus"):
        v = row.get(col)
        # Whitespace-only labels would otherwise form a bogus shared group.
        if isinstance(v, str) and v.strip():
            return v
    s = row.get("species")
    if isinstance(s, str):
        # split() on a whitespace-only string is empty; guard before indexing.
        words = s.split()
        if words:
            return words[0]
    return "__unknown__"
def encode_iso(df: pd.DataFrame, *, min_count: int = 5) -> tuple[pd.DataFrame, list[str]]:
    """One-hot encode the ``|``-separated isolation-source tags.

    For each of ``isolation_cat1`` / ``isolation_cat2`` present in ``df``,
    every tag occurring at least ``min_count`` times across non-null cells
    becomes an int 0/1 column named ``iso_<cat1|cat2>_<slug>``. The slug is
    the lower-cased tag with runs of characters outside ``[a-z0-9]``
    replaced by ``_`` and leading/trailing ``_`` stripped.

    Empty segments (from ``"a||b"`` or a trailing ``|``) are not treated as
    tags, and null cells match no tag.

    ``df`` is modified in place and also returned, together with the names
    of the columns this call added. A target column that already exists is
    skipped (neither overwritten nor reported). Tags are processed in sorted
    order, so if two tags map to the same slug only the first one is encoded.
    """
    import re
    from collections import Counter

    non_slug = re.compile(r"[^a-z0-9]+")
    new_cols: list[str] = []
    for level in ("isolation_cat1", "isolation_cat2"):
        if level not in df.columns:
            continue
        # Split each cell once; reuse the result for counting and for every tag column.
        split_cells = df[level].map(lambda v: [] if pd.isna(v) else v.split("|"))
        counts: Counter[str] = Counter()
        for parts in split_cells:
            counts.update(parts)
        tag_sets = split_cells.map(lambda parts: frozenset(parts))
        # Drop "" so no column is built from empty segments.
        kept = sorted(t for t, n in counts.items() if t and n >= min_count)
        level_tag = level.split("_")[1]  # "cat1" / "cat2"
        for tag in kept:
            slug = non_slug.sub("_", tag.lower()).strip("_")
            col = f"iso_{level_tag}_{slug}"
            if col in df.columns:
                continue
            df[col] = tag_sets.map(lambda s, t=tag: int(t in s))
            new_cols.append(col)
    return df, new_cols
def _load_training_frame() -> tuple[pd.DataFrame, list[str]]:
    """Build the HMM-covered training frame and its ordered feature list.

    The frame is built in these steps:
    - Inner-join BacDive phenotypes with genome features on ``bacdive_id`` +
      ``genome_accession``.
    - Keep only genomes listed in the HMM table, then attach the HMM features.
    - If the MediaDive parquet exists, left-join its features on ``bacdive_id``.
    - Add the ``group`` column (``derive_group``) and the ``iso_*`` one-hot
      columns (``encode_iso``).

    Returns ``(df, feature_cols)``. ``feature_cols`` lists the baseline,
    isolation, MediaDive and HMM columns, in that order.
    """
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(config.DATA / "features.parquet")
    hmm = pd.read_parquet(config.DATA / "hmm_features.parquet")
    # Cast the join key to int in every table so the merges compare like with like.
    pheno["bacdive_id"] = pheno["bacdive_id"].astype(int)
    feats["bacdive_id"] = feats["bacdive_id"].astype(int)
    df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
    # Filter first, so every row reaching the left join has an HMM match.
    df = df[df["genome_accession"].isin(hmm["genome_accession"])].copy()
    df = df.merge(hmm, on="genome_accession", how="left")

    md_cols: list[str] = []
    md_path = config.DATA / "mediadive_features.parquet"
    if md_path.exists():  # MediaDive features are optional
        md = pd.read_parquet(md_path)
        md["bacdive_id"] = md["bacdive_id"].astype(int)
        md_cols = [c for c in md.columns if c != "bacdive_id"]
        df = df.merge(md, on="bacdive_id", how="left")

    df["group"] = df.apply(derive_group, axis=1)
    df, iso_cols = encode_iso(df)

    base_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
    hmm_cols = [c for c in hmm.columns if c != "genome_accession"]
    return df, base_cols + iso_cols + md_cols + hmm_cols


def _summarize_target(target: str, r) -> dict:
    """Print a console summary for one target and return its JSON report entry.

    ``r`` is one value of the ``train_all`` result mapping. This function
    reads its ``importances`` (feature -> importance), ``task``, ``folds``
    and ``mean()``.
    """
    # Categorise each feature once; the category is needed in three places.
    categories = {col: col_category(col) for col in r.importances}
    top = sorted(r.importances.items(), key=lambda kv: kv[1], reverse=True)[:25]

    cat_totals: dict[str, float] = {}
    for col, imp in r.importances.items():
        cat = categories[col]
        cat_totals[cat] = cat_totals.get(cat, 0.0) + imp
    ranked_totals = sorted(cat_totals.items(), key=lambda kv: kv[1], reverse=True)

    # float() keeps json.dump working if mean() returns a numpy scalar such as
    # float32; the other report values are cast the same way.
    score = float(r.mean())
    entry = {
        "task": r.task,
        "score": score,
        "n_folds": len(r.folds),
        "top_25_features": [{"name": n, "importance": float(i),
                             "category": categories[n]} for n, i in top],
        "category_totals": {k: float(v) for k, v in ranked_totals},
    }

    score_label = "MAE" if r.task == "regression" else "F1_macro"
    print(f"--- {target} ({score_label}={score:.3f}, n_folds={len(r.folds)})")
    print(" top 10 features:")
    for n, i in top[:10]:
        print(f" {i:.4f} [{categories[n]:20s}] {n}")
    print(" category totals:")
    for cat, total in ranked_totals[:8]:
        print(f" {total:.4f} {cat}")
    print()
    return entry


def main() -> None:
    """Train on HMM-covered strains and write per-target feature-importance reports.

    The reports go to ``config.ARTIFACTS / "marker_importance.json"``; a
    summary is also printed to the console.
    """
    t0 = time.perf_counter()  # monotonic clock for elapsed-time reporting
    df, feature_cols = _load_training_frame()
    print(f"Training on {len(df):,} HMM-covered strains × {len(feature_cols)} features\n")
    results = train_all(df, feature_cols, group_col_override="group")

    report: dict[str, dict] = {}
    for target, r in results.items():
        if not r.importances:  # nothing to attribute for this target
            continue
        report[target] = _summarize_target(target, r)

    out = config.ARTIFACTS / "marker_importance.json"
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2)
    print(f"Wrote {out}")
    print(f"\nDone in {time.perf_counter() - t0:.1f}s")


if __name__ == "__main__":
    main()