Miyu Horiuchi commited on
Commit
d3cbd87
·
1 Parent(s): 3d34be9

Phase E modeling: per-medium classifiers + recommender training script

Browse files

src/microbe_model/train/media_recommender.py:
- build_training_table: joins genome features + strain_media into (X, y_matrix)
where y_matrix is a sparse {0,1} (n_strains × n_media) presence indicator
- train_per_medium: one XGBoost per medium with GroupKFold by family,
scale_pos_weight to handle imbalance, PR-AUC + ROC-AUC reported per fold
- save_results: writes JSON with metrics, ready for eval rendering

scripts/10_train_media_recommender.py: end-to-end training entry point.
Filters to media with >= 100 strains (~50-80 most-used recipes), reports
median PR-AUC across all of them, lists top-15 best-modeled and worst-5.

Limitations documented in the module docstring:
- BacDive only records growth=yes — we have positive examples only. Negatives
are constructed implicitly (other media used by the same strain), which may
bias toward media that are simply under-recorded.
- No concentration prediction yet — recipe selection only. Concentration
regression head deferred to v1.

Will run after v1 featurize + MediaDive fetch both complete (~4 hrs from now).
The deliverable: given an uncultured genome, output top-K media to try.

scripts/10_train_media_recommender.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train per-medium classifiers and report metrics across all media meeting the count cutoff.
2
+
3
+ Outputs:
4
+ artifacts/media_recommender_results.json — per-medium PR-AUC + ROC-AUC, fold-by-fold.
5
+ artifacts/media_recommender_report.md — human-readable summary.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import time
10
+
11
+ import pandas as pd
12
+
13
+ from microbe_model import config
14
+ from microbe_model.train.media_recommender import (
15
+ build_training_table,
16
+ save_results,
17
+ train_per_medium,
18
+ )
19
+
20
+
21
+ def main() -> None:
22
+ pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
23
+ feats = pd.read_parquet(config.DATA / "features.parquet")
24
+ sm = pd.read_parquet(config.DATA / "strain_media.parquet")
25
+ md = pd.read_parquet(config.DATA / "media_metadata.parquet")
26
+ medium_name_by_id = dict(zip(md["medium_id"].astype(str), md["name"], strict=True))
27
+
28
+ print(f"Inputs: {len(feats):,} feature rows, {len(sm):,} strain↔medium links")
29
+
30
+ X, y_matrix, medium_ids = build_training_table(feats, sm, pheno)
31
+ groups = pheno.set_index("bacdive_id").loc[X.index, "family"].fillna("__unknown__")
32
+ print(f"Training table: {len(X):,} strains × {X.shape[1]} features × {len(medium_ids)} media")
33
+ print(f"Distinct families: {groups.nunique():,}")
34
+ print()
35
+
36
+ t0 = time.time()
37
+ results = train_per_medium(X, y_matrix, medium_name_by_id, groups)
38
+ print(f"Trained {len(results)} per-medium classifiers in {time.time() - t0:.1f}s\n")
39
+
40
+ out_json = config.ARTIFACTS / "media_recommender_results.json"
41
+ save_results(results, out_json)
42
+ print(f"Wrote {out_json}\n")
43
+
44
+ # Headline summary
45
+ rows = [(mid, r.medium_name, r.n_positives, r.n_negatives, r.mean_pr_auc(), r.mean_roc_auc())
46
+ for mid, r in results.items()]
47
+ summary = pd.DataFrame(rows, columns=["medium_id", "name", "n_pos", "n_neg",
48
+ "pr_auc", "roc_auc"])
49
+ summary = summary.sort_values("pr_auc", ascending=False)
50
+ print(f"Median PR-AUC: {summary['pr_auc'].median():.3f}")
51
+ print(f"Median ROC-AUC: {summary['roc_auc'].median():.3f}")
52
+ print("\nTop 15 best-modeled media (by PR-AUC):")
53
+ print(summary.head(15).to_string(index=False))
54
+ print("\nWorst 5:")
55
+ print(summary.tail(5).to_string(index=False))
56
+
57
+
58
+ if __name__ == "__main__":
59
+ main()
src/microbe_model/train/media_recommender.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Train per-medium binary classifiers to recommend cultivation media for a genome.
2
+
3
+ Setup:
4
+ - Filter to media used by >= MIN_STRAINS_PER_MEDIUM strains (default 100).
5
+ - For each such medium m, build a binary label: y_i = 1 if strain i has a
6
+ growth=yes link to m, else 0.
7
+ - Train one XGBoost classifier per medium with GroupKFold by family.
8
+ - At inference, output a (n_strains × n_media) probability matrix.
9
+
10
+ The deliverable: given a new (possibly uncultured) genome, output the top-K media
11
+ ranked by predicted probability. This is the "what should I try first?" output
12
+ microbiologists actually want.
13
+
14
+ Limitations:
15
+ - All BacDive `culture medium` entries are growth=yes — we have positive
16
+ examples but no explicit negatives. We construct negatives from strains that
17
+ have *some* media link but not this one. This may bias toward media that are
18
+ just under-recorded.
19
+ - No concentration prediction yet — only recipe selection. v1 will add a
20
+ secondary regression head that adjusts compound concentrations.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+
28
+ import numpy as np
29
+ import pandas as pd
30
+ import xgboost as xgb
31
+ from sklearn.metrics import average_precision_score, roc_auc_score
32
+ from sklearn.model_selection import GroupKFold
33
+
34
+ MIN_STRAINS_PER_MEDIUM = 100
35
+
36
+
37
+ @dataclass
38
+ class MediumModelResult:
39
+ medium_id: str
40
+ medium_name: str
41
+ n_positives: int
42
+ n_negatives: int
43
+ fold_metrics: list[dict] = field(default_factory=list)
44
+
45
+ def mean_pr_auc(self) -> float:
46
+ if not self.fold_metrics:
47
+ return float("nan")
48
+ return float(np.mean([m["pr_auc"] for m in self.fold_metrics]))
49
+
50
+ def mean_roc_auc(self) -> float:
51
+ if not self.fold_metrics:
52
+ return float("nan")
53
+ return float(np.mean([m["roc_auc"] for m in self.fold_metrics]))
54
+
55
+
56
+ def build_training_table(
57
+ features: pd.DataFrame,
58
+ strain_media: pd.DataFrame,
59
+ bacdive: pd.DataFrame,
60
+ ) -> tuple[pd.DataFrame, pd.DataFrame, list[str]]:
61
+ """Return (X, y_matrix, medium_ids) for media meeting the strain-count threshold.
62
+
63
+ X: (n_strains × n_features) feature DataFrame, indexed by bacdive_id
64
+ y_matrix: (n_strains × n_media) {0,1} DataFrame, columns are medium_ids
65
+ """
66
+ # Strains with both genome features and at least one positive medium link
67
+ strain_ids = sorted(set(features["bacdive_id"]).intersection(set(strain_media["bacdive_id"])))
68
+ if not strain_ids:
69
+ raise ValueError("No overlap between feature table and strain_media links.")
70
+
71
+ X = features[features["bacdive_id"].isin(strain_ids)].set_index("bacdive_id").sort_index()
72
+ feature_cols = [c for c in X.columns if c not in {"genome_accession"}]
73
+ X = X[feature_cols]
74
+
75
+ # Build sparse positive-link table → wide y matrix
76
+ sm = strain_media[strain_media["bacdive_id"].isin(strain_ids)]
77
+ sm = sm[sm["growth"] == "yes"]
78
+ counts = sm.groupby("medium_id").size()
79
+ keep_media = counts[counts >= MIN_STRAINS_PER_MEDIUM].index.tolist()
80
+ sm = sm[sm["medium_id"].isin(keep_media)]
81
+
82
+ y_matrix = (
83
+ sm.assign(_one=1)
84
+ .pivot_table(index="bacdive_id", columns="medium_id", values="_one", fill_value=0)
85
+ .reindex(index=X.index, columns=keep_media, fill_value=0)
86
+ .astype(np.uint8)
87
+ )
88
+
89
+ return X, y_matrix, keep_media
90
+
91
+
92
+ def train_per_medium(
93
+ X: pd.DataFrame,
94
+ y_matrix: pd.DataFrame,
95
+ medium_metadata: dict[str, str],
96
+ groups: pd.Series,
97
+ *,
98
+ n_splits: int = 5,
99
+ n_estimators: int = 200,
100
+ max_depth: int = 5,
101
+ ) -> dict[str, MediumModelResult]:
102
+ """Train one classifier per medium with GroupKFold by `groups` (e.g. taxonomic family)."""
103
+ results: dict[str, MediumModelResult] = {}
104
+ splits = min(n_splits, max(2, groups.nunique()))
105
+ kfold = GroupKFold(n_splits=splits)
106
+
107
+ for medium_id in y_matrix.columns:
108
+ y = y_matrix[medium_id].to_numpy()
109
+ n_pos, n_neg = int(y.sum()), int((y == 0).sum())
110
+ result = MediumModelResult(
111
+ medium_id=str(medium_id),
112
+ medium_name=medium_metadata.get(str(medium_id), ""),
113
+ n_positives=n_pos,
114
+ n_negatives=n_neg,
115
+ )
116
+
117
+ # Need both classes in train/test
118
+ for fold_idx, (tr_idx, te_idx) in enumerate(kfold.split(X, y, groups)):
119
+ y_tr = y[tr_idx]
120
+ y_te = y[te_idx]
121
+ if y_tr.sum() < 5 or y_te.sum() < 1:
122
+ continue
123
+
124
+ scale_pos_weight = (y_tr == 0).sum() / max(1, y_tr.sum())
125
+ model = xgb.XGBClassifier(
126
+ n_estimators=n_estimators,
127
+ max_depth=max_depth,
128
+ learning_rate=0.05,
129
+ tree_method="hist",
130
+ n_jobs=-1,
131
+ scale_pos_weight=scale_pos_weight,
132
+ eval_metric="logloss",
133
+ )
134
+ model.fit(X.iloc[tr_idx], y_tr)
135
+ proba = model.predict_proba(X.iloc[te_idx])[:, 1]
136
+ try:
137
+ roc = roc_auc_score(y_te, proba)
138
+ pr = average_precision_score(y_te, proba)
139
+ except ValueError:
140
+ continue
141
+ result.fold_metrics.append({
142
+ "fold": fold_idx,
143
+ "n_train": int(len(tr_idx)),
144
+ "n_test": int(len(te_idx)),
145
+ "n_test_positives": int(y_te.sum()),
146
+ "roc_auc": float(roc),
147
+ "pr_auc": float(pr),
148
+ })
149
+
150
+ results[str(medium_id)] = result
151
+ return results
152
+
153
+
154
+ def save_results(results: dict[str, MediumModelResult], path: Path) -> None:
155
+ payload = {
156
+ mid: {
157
+ "medium_name": r.medium_name,
158
+ "n_positives": r.n_positives,
159
+ "n_negatives": r.n_negatives,
160
+ "mean_pr_auc": r.mean_pr_auc(),
161
+ "mean_roc_auc": r.mean_roc_auc(),
162
+ "folds": r.fold_metrics,
163
+ }
164
+ for mid, r in results.items()
165
+ }
166
+ path.write_text(json.dumps(payload, indent=2))