Spaces:
Running
Running
"""Train per-medium classifiers and report metrics across all media meeting the count cutoff.

Outputs:
    artifacts/media_recommender_results.json — per-medium PR-AUC + ROC-AUC, fold-by-fold.
    artifacts/media_recommender_report.md — human-readable summary.
    models/recommender/ (under config.ROOT) — production models fit on the full dataset.
"""
| from __future__ import annotations | |
| import time | |
| import pandas as pd | |
| from microbe_model import config | |
| from microbe_model.train.media_recommender import ( | |
| build_training_table, | |
| save_models, | |
| save_results, | |
| train_per_medium, | |
| train_production_models, | |
| ) | |
def main() -> None:
    """Train, evaluate, and persist the per-medium recommender models.

    Steps, in order:

    1. Read the phenotype, feature, strain↔medium and media-metadata parquet
       tables from ``config.DATA``.
    2. Left-join the optional HMM, KEGG-module and isolation-metadata tables
       into the feature table, each only when its parquet file exists.
    3. Build the training table and call ``train_per_medium``, passing each
       strain's ``family`` (``"__unknown__"`` when missing) as ``groups``.
    4. Write the results to
       ``config.ARTIFACTS / "media_recommender_results.json"``.
    5. Fit production models via ``train_production_models`` and save them
       under ``config.ROOT / "models" / "recommender"``.
    6. Print a summary table sorted by PR-AUC (descending).
    """
    data_dir = config.DATA
    pheno = pd.read_parquet(data_dir / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(data_dir / "features.parquet")
    sm = pd.read_parquet(data_dir / "strain_media.parquet")
    md = pd.read_parquet(data_dir / "media_metadata.parquet")

    # strict=True makes zip raise if the two columns ever differ in length.
    medium_ids_as_text = md["medium_id"].astype(str)
    medium_name_by_id = dict(zip(medium_ids_as_text, md["name"], strict=True))

    # Optional tables keyed by genome accession: (file name, label used in the log line).
    genome_key = "genome_accession"
    optional_genome_tables = (
        ("hmm_features.parquet", "HMM features"),
        ("kegg_modules.parquet", "KEGG module completeness"),
    )
    for file_name, label in optional_genome_tables:
        table_path = data_dir / file_name
        if not table_path.exists():
            continue
        extra = pd.read_parquet(table_path)
        feats = feats.merge(extra, on=genome_key, how="left")
        n_extra_cols = sum(1 for col in extra.columns if col != genome_key)
        print(f"Joined {label} ({n_extra_cols} cols) into features table")

    # Optional isolation metadata, keyed by BacDive id instead of genome accession.
    iso_meta_path = data_dir / "isolation_metadata.parquet"
    if iso_meta_path.exists():
        iso_meta = pd.read_parquet(iso_meta_path)
        # Cast both join keys to int so the merge compares like with like.
        iso_meta["bacdive_id"] = iso_meta["bacdive_id"].astype(int)
        feats["bacdive_id"] = feats["bacdive_id"].astype(int)
        iso_prefixes = ("iso_continent_", "iso_country_", "iso_host_kingdom_")
        iso_columns = [
            "bacdive_id",
            "iso_lat",
            "iso_lon",
            "iso_collection_year",
            *(col for col in iso_meta.columns if col.startswith(iso_prefixes)),
        ]
        feats = feats.merge(iso_meta[iso_columns], on="bacdive_id", how="left")
        # The join key itself is not a new feature, hence the "- 1".
        print(f"Joined isolation metadata ({len(iso_columns) - 1} cols) into features table")

    print(f"Inputs: {len(feats):,} feature rows, {len(sm):,} strain↔medium links")

    X, y_matrix, medium_ids = build_training_table(feats, sm, pheno)
    pheno_by_id = pheno.set_index("bacdive_id")
    groups = pheno_by_id.loc[X.index, "family"].fillna("__unknown__")
    print(f"Training table: {len(X):,} strains × {X.shape[1]} features × {len(medium_ids)} media")
    print(f"Distinct families: {groups.nunique():,}")
    print()

    started = time.time()
    results = train_per_medium(X, y_matrix, medium_name_by_id, groups)
    elapsed = time.time() - started
    print(f"Trained {len(results)} per-medium classifiers in {elapsed:.1f}s\n")

    out_json = config.ARTIFACTS / "media_recommender_results.json"
    save_results(results, out_json)
    print(f"Wrote {out_json}")

    # Fit on the full dataset and persist the models together with the feature column order.
    print("\nFitting production models on full dataset...")
    prod_models = train_production_models(X, y_matrix)
    models_dir = config.ROOT / "models" / "recommender"
    save_models(prod_models, list(X.columns), models_dir)
    print(f"Saved {len(prod_models)} production models to {models_dir}")

    # Headline summary: one row per medium, best PR-AUC first.
    summary_columns = ["medium_id", "name", "n_pos", "n_neg", "pr_auc", "roc_auc"]
    records = [
        (
            medium_id,
            res.medium_name,
            res.n_positives,
            res.n_negatives,
            res.mean_pr_auc(),
            res.mean_roc_auc(),
        )
        for medium_id, res in results.items()
    ]
    summary = pd.DataFrame(records, columns=summary_columns)
    summary = summary.sort_values("pr_auc", ascending=False)

    print(f"Median PR-AUC: {summary['pr_auc'].median():.3f}")
    print(f"Median ROC-AUC: {summary['roc_auc'].median():.3f}")
    print("\nTop 15 best-modeled media (by PR-AUC):")
    print(summary.head(15).to_string(index=False))
    print("\nWorst 5:")
    print(summary.tail(5).to_string(index=False))


if __name__ == "__main__":
    main()