File size: 4,057 Bytes
d3cbd87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31110fe
d3cbd87
 
31110fe
d3cbd87
 
 
 
 
 
 
 
 
 
0ed74db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3cbd87
 
 
 
 
 
 
 
 
 
 
 
 
 
31110fe
 
 
 
 
 
 
 
d3cbd87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
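"""Train per-medium classifiers and report metrics for every medium that meets the count cutoff.

Outputs:
  artifacts/media_recommender_results.json — per-medium PR-AUC + ROC-AUC, fold-by-fold.
  artifacts/media_recommender_report.md   — human-readable summary. NOTE(review): main() does
                                            not write this file directly; confirm a helper does.
  models/recommender/                     — production models refit on the full dataset.

A headline summary (median PR-AUC / ROC-AUC, top 15 and bottom 5 media by PR-AUC) is also
printed to stdout.
"""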
from __future__ import annotations

import time

import pandas as pd

from microbe_model import config
from microbe_model.train.media_recommender import (
    build_training_table,
    save_models,
    save_results,
    train_per_medium,
    train_production_models,
)


def _merge_optional_table(feats: pd.DataFrame, filename: str, label: str) -> pd.DataFrame:
    """Left-join an optional per-genome feature table onto *feats*, if the file exists.

    Args:
        feats: Feature table containing a ``genome_accession`` column.
        filename: Parquet file name, resolved relative to ``config.DATA``.
        label: Human-readable description used in the progress message.

    Returns:
        *feats* unchanged when the file is absent; otherwise *feats* left-merged with the
        table on ``genome_accession``.
    """
    path = config.DATA / filename
    if not path.exists():
        return feats
    extra = pd.read_parquet(path)
    merged = feats.merge(extra, on="genome_accession", how="left")
    n_cols = sum(1 for c in extra.columns if c != "genome_accession")
    print(f"Joined {label} ({n_cols} cols) into features table")
    return merged


def _merge_isolation_metadata(feats: pd.DataFrame) -> pd.DataFrame:
    """Left-join selected isolation-metadata columns onto *feats* by ``bacdive_id``, if present.

    Both tables' ``bacdive_id`` columns are cast to ``int`` first so the join keys agree.
    The columns kept are lat/lon/collection year plus every column whose name starts with
    ``iso_continent_``, ``iso_country_`` or ``iso_host_kingdom_``.

    Returns:
        *feats* unchanged when the metadata file is absent; otherwise a new merged frame.
        The caller's *feats* object is not mutated.
    """
    path = config.DATA / "isolation_metadata.parquet"
    if not path.exists():
        return feats
    iso_meta = pd.read_parquet(path)
    iso_meta["bacdive_id"] = iso_meta["bacdive_id"].astype(int)
    # assign() returns a copy, so the caller's DataFrame is left untouched.
    feats = feats.assign(bacdive_id=feats["bacdive_id"].astype(int))
    keep = ["bacdive_id", "iso_lat", "iso_lon", "iso_collection_year"]
    keep += [c for c in iso_meta.columns if c.startswith(("iso_continent_", "iso_country_", "iso_host_kingdom_"))]
    merged = feats.merge(iso_meta[keep], on="bacdive_id", how="left")
    print(f"Joined isolation metadata ({len(keep) - 1} cols) into features table")
    return merged


def _print_summary(results) -> None:
    """Print median PR-AUC / ROC-AUC and the best 15 and worst 5 media ranked by PR-AUC.

    Args:
        results: Mapping of medium id -> per-medium result, as returned by
            ``train_per_medium``; each value exposes ``medium_name``, ``n_positives``,
            ``n_negatives``, ``mean_pr_auc()`` and ``mean_roc_auc()``.
    """
    rows = [
        (mid, r.medium_name, r.n_positives, r.n_negatives, r.mean_pr_auc(), r.mean_roc_auc())
        for mid, r in results.items()
    ]
    summary = pd.DataFrame(rows, columns=["medium_id", "name", "n_pos", "n_neg",
                                          "pr_auc", "roc_auc"])
    summary = summary.sort_values("pr_auc", ascending=False)
    print(f"Median PR-AUC: {summary['pr_auc'].median():.3f}")
    print(f"Median ROC-AUC: {summary['roc_auc'].median():.3f}")
    print("\nTop 15 best-modeled media (by PR-AUC):")
    print(summary.head(15).to_string(index=False))
    print("\nWorst 5:")
    print(summary.tail(5).to_string(index=False))


def main() -> None:
    """Run the full media-recommender pipeline.

    Steps:
      1. Load the phenotype, feature, strain-medium and media-metadata tables.
      2. Join any optional feature tables (HMM, KEGG, isolation metadata) that exist.
      3. Train per-medium classifiers and evaluate them, writing the results JSON.
      4. Refit production models on all data and save them.
      5. Print a headline summary.
    """
    pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
    feats = pd.read_parquet(config.DATA / "features.parquet")
    sm = pd.read_parquet(config.DATA / "strain_media.parquet")
    md = pd.read_parquet(config.DATA / "media_metadata.parquet")
    # strict=True: the id and name columns of the same table must line up one-to-one.
    medium_name_by_id = dict(zip(md["medium_id"].astype(str), md["name"], strict=True))

    # Optional feature sources; each one is skipped silently when its parquet file is absent.
    feats = _merge_optional_table(feats, "hmm_features.parquet", "HMM features")
    feats = _merge_optional_table(feats, "kegg_modules.parquet", "KEGG module completeness")
    feats = _merge_isolation_metadata(feats)

    print(f"Inputs: {len(feats):,} feature rows, {len(sm):,} strain↔medium links")

    X, y_matrix, medium_ids = build_training_table(feats, sm, pheno)
    # Taxonomic family per training strain. Strains with no family share one "__unknown__"
    # group. This is passed to train_per_medium, presumably for grouped CV splits — confirm.
    groups = pheno.set_index("bacdive_id").loc[X.index, "family"].fillna("__unknown__")
    print(f"Training table: {len(X):,} strains × {X.shape[1]} features × {len(medium_ids)} media")
    print(f"Distinct families: {groups.nunique():,}")
    print()

    # perf_counter is monotonic, so the reported duration is immune to wall-clock adjustments.
    t0 = time.perf_counter()
    results = train_per_medium(X, y_matrix, medium_name_by_id, groups)
    print(f"Trained {len(results)} per-medium classifiers in {time.perf_counter() - t0:.1f}s\n")

    out_json = config.ARTIFACTS / "media_recommender_results.json"
    save_results(results, out_json)
    print(f"Wrote {out_json}")

    # Train production models on ALL data for inference + persist
    print("\nFitting production models on full dataset...")
    prod_models = train_production_models(X, y_matrix)
    models_dir = config.ROOT / "models" / "recommender"
    save_models(prod_models, list(X.columns), models_dir)
    print(f"Saved {len(prod_models)} production models to {models_dir}")

    _print_summary(results)


if __name__ == "__main__":
    main()