microbe-model / scripts /10_train_media_recommender.py
Miyu Horiuchi
Deploy app from main@a3254bf (no paper/ binaries)
0ed74db
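"""Train per-medium classifiers and report metrics across all media meeting the count cutoff.

Outputs:
    artifacts/media_recommender_results.json — per-medium PR-AUC + ROC-AUC, fold-by-fold.
    artifacts/media_recommender_report.md — human-readable summary.
    models/recommender/ — production models refit on the full dataset for inference.

A headline summary (median PR-AUC / ROC-AUC, best- and worst-modeled media) is also
printed to stdout.
"""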
from __future__ import annotations

import time
from pathlib import Path

import pandas as pd

from microbe_model import config
from microbe_model.train.media_recommender import (
    build_training_table,
    save_models,
    save_results,
    train_per_medium,
    train_production_models,
)
def _merge_optional_features(
    feats: pd.DataFrame,
    path: Path,
    label: str,
    on: str = "genome_accession",
) -> pd.DataFrame:
    """Left-join an optional feature table stored at ``path`` onto ``feats``.

    Returns ``feats`` unchanged when ``path`` does not exist, so these extra
    feature families are optional inputs.

    Args:
        feats: Base feature table; must contain the ``on`` key.
        path: Parquet file holding the extra features, keyed by ``on``.
        label: Human-readable name used in the progress message.
        on: Join key shared by both tables.

    Raises:
        pandas.errors.MergeError: if the table at ``path`` has more than one row
            per ``on`` value.
    """
    if not path.exists():
        return feats
    extra = pd.read_parquet(path)
    # validate="many_to_one": a repeated key in the side table would otherwise
    # silently duplicate the matching rows of the training feature table.
    merged = feats.merge(extra, on=on, how="left", validate="many_to_one")
    n_cols = sum(1 for c in extra.columns if c != on)
    print(f"Joined {label} ({n_cols} cols) into features table")
    return merged


def _merge_isolation_metadata(feats: pd.DataFrame, path: Path) -> pd.DataFrame:
    """Left-join selected isolation-metadata columns onto ``feats`` by ``bacdive_id``.

    Keeps ``iso_lat``, ``iso_lon``, ``iso_collection_year`` and every column whose
    name starts with ``iso_continent_``, ``iso_country_`` or ``iso_host_kingdom_``.
    Returns ``feats`` unchanged when ``path`` does not exist. The caller's frame
    is not modified.

    Raises:
        pandas.errors.MergeError: if the metadata lists a ``bacdive_id`` twice.
    """
    if not path.exists():
        return feats
    iso_meta = pd.read_parquet(path)
    # Cast both join keys to int so they share a dtype for the merge.
    iso_meta["bacdive_id"] = iso_meta["bacdive_id"].astype(int)
    feats = feats.assign(bacdive_id=feats["bacdive_id"].astype(int))
    prefixes = ("iso_continent_", "iso_country_", "iso_host_kingdom_")
    keep = ["bacdive_id", "iso_lat", "iso_lon", "iso_collection_year"]
    keep += [c for c in iso_meta.columns if c.startswith(prefixes)]
    merged = feats.merge(
        iso_meta[keep], on="bacdive_id", how="left", validate="many_to_one"
    )
    # len(keep) - 1: the join key itself is not a new feature column.
    print(f"Joined isolation metadata ({len(keep) - 1} cols) into features table")
    return merged


def _family_groups(pheno: pd.DataFrame, strain_ids: pd.Index) -> pd.Series:
    """Return the ``family`` of each strain in ``strain_ids``, in the same order.

    The result (passed as ``groups`` to ``train_per_medium``) has exactly one
    entry per element of ``strain_ids``; a missing family becomes
    ``"__unknown__"``. If ``pheno`` lists a strain more than once, its first row
    wins.

    Raises:
        KeyError: if a strain id is absent from ``pheno``.
    """
    family = pheno.set_index("bacdive_id")["family"]
    # With duplicate ids, .loc would return every matching row and the groups
    # would no longer line up with the training rows.
    family = family[~family.index.duplicated(keep="first")]
    return family.loc[strain_ids].fillna("__unknown__")


def _print_summary(results: dict) -> None:
    """Print median PR-AUC / ROC-AUC plus the best- and worst-modeled media.

    ``results`` maps medium id to a per-medium result exposing ``medium_name``,
    ``n_positives``, ``n_negatives``, ``mean_pr_auc()`` and ``mean_roc_auc()``.
    """
    rows = [
        (mid, r.medium_name, r.n_positives, r.n_negatives, r.mean_pr_auc(), r.mean_roc_auc())
        for mid, r in results.items()
    ]
    summary = pd.DataFrame(
        rows, columns=["medium_id", "name", "n_pos", "n_neg", "pr_auc", "roc_auc"]
    ).sort_values("pr_auc", ascending=False)
    print(f"Median PR-AUC: {summary['pr_auc'].median():.3f}")
    print(f"Median ROC-AUC: {summary['roc_auc'].median():.3f}")
    print("\nTop 15 best-modeled media (by PR-AUC):")
    print(summary.head(15).to_string(index=False))
    print("\nWorst 5:")
    print(summary.tail(5).to_string(index=False))


def main() -> None:
    """Train, evaluate and persist the per-medium media recommender.

    1. Load phenotypes, features, strain↔medium links and medium metadata from
       ``config.DATA``. Join HMM, KEGG-module and isolation-metadata features
       when their parquet files exist.
    2. Train per-medium classifiers with ``train_per_medium`` (taxonomic family
       as ``groups``) and write the results JSON to ``config.ARTIFACTS``.
    3. Refit production models on all data and save them under
       ``config.ROOT / "models" / "recommender"``.
    4. Print a headline PR-AUC / ROC-AUC summary.
    """
    data = config.DATA
    pheno = pd.read_parquet(data / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(data / "features.parquet")
    sm = pd.read_parquet(data / "strain_media.parquet")
    md = pd.read_parquet(data / "media_metadata.parquet")
    medium_name_by_id = dict(zip(md["medium_id"].astype(str), md["name"], strict=True))

    feats = _merge_optional_features(feats, data / "hmm_features.parquet", "HMM features")
    feats = _merge_optional_features(
        feats, data / "kegg_modules.parquet", "KEGG module completeness"
    )
    feats = _merge_isolation_metadata(feats, data / "isolation_metadata.parquet")
    print(f"Inputs: {len(feats):,} feature rows, {len(sm):,} strain↔medium links")

    X, y_matrix, medium_ids = build_training_table(feats, sm, pheno)
    groups = _family_groups(pheno, X.index)
    print(f"Training table: {len(X):,} strains × {X.shape[1]} features × {len(medium_ids)} media")
    print(f"Distinct families: {groups.nunique():,}")
    print()

    t0 = time.perf_counter()
    results = train_per_medium(X, y_matrix, medium_name_by_id, groups)
    print(f"Trained {len(results)} per-medium classifiers in {time.perf_counter() - t0:.1f}s\n")

    out_json = config.ARTIFACTS / "media_recommender_results.json"
    save_results(results, out_json)
    print(f"Wrote {out_json}")

    # Production models are trained on ALL data for inference, then persisted.
    print("\nFitting production models on full dataset...")
    prod_models = train_production_models(X, y_matrix)
    models_dir = config.ROOT / "models" / "recommender"
    save_models(prod_models, list(X.columns), models_dir)
    print(f"Saved {len(prod_models)} production models to {models_dir}")

    _print_summary(results)


if __name__ == "__main__":
    main()