"""Train per-medium binary classifiers to recommend cultivation media for a genome.
Setup:
- Filter to media used by >= MIN_STRAINS_PER_MEDIUM strains (default 100).
- For each such medium m, build a binary label: y_i = 1 if strain i has a
growth=yes link to m, else 0.
- Train one XGBoost classifier per medium with GroupKFold by family.
- At inference, output a (n_strains × n_media) probability matrix.
The deliverable: given a new (possibly uncultured) genome, output the top-K media
ranked by predicted probability. This is the "what should I try first?" output
microbiologists actually want.
Limitations:
- All BacDive `culture medium` entries are growth=yes — we have positive
examples but no explicit negatives. We construct negatives from strains that
have *some* media link but not this one. Because unrecorded growth is treated
as a negative, this may bias predictions against media that are merely
under-recorded.
- No concentration prediction yet — only recipe selection. v1 will add a
secondary regression head that adjusts compound concentrations.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.model_selection import GroupKFold
# Media with fewer positive (growth=yes) strain links than this are excluded
# from training; build_training_table applies the threshold.
MIN_STRAINS_PER_MEDIUM = 100
@dataclass
class MediumModelResult:
    """Cross-validation summary for the classifier of a single medium.

    Carries the medium's identifier and display name, the number of positive
    and negative strains, and one metrics dict per evaluated fold. The mean
    helpers read the ``"pr_auc"`` and ``"roc_auc"`` keys of those dicts.
    """

    medium_id: str
    medium_name: str
    n_positives: int
    n_negatives: int
    fold_metrics: list[dict] = field(default_factory=list)

    def _fold_mean(self, key: str) -> float:
        """Average ``key`` across recorded folds; NaN when no fold was recorded."""
        values = [fold[key] for fold in self.fold_metrics]
        return float(np.mean(values)) if values else float("nan")

    def mean_pr_auc(self) -> float:
        """Return the mean PR-AUC over all folds (NaN if there are none)."""
        return self._fold_mean("pr_auc")

    def mean_roc_auc(self) -> float:
        """Return the mean ROC-AUC over all folds (NaN if there are none)."""
        return self._fold_mean("roc_auc")
def build_training_table(
    features: pd.DataFrame,
    strain_media: pd.DataFrame,
    bacdive: pd.DataFrame,
    *,
    min_strains: int | None = None,
) -> tuple[pd.DataFrame, pd.DataFrame, list[str]]:
    """Return (X, y_matrix, medium_ids) for media meeting the strain-count threshold.

    X: (n_strains × n_features) feature DataFrame, indexed by bacdive_id
    y_matrix: (n_strains × n_media) {0,1} DataFrame, columns are medium_ids

    Args:
        features: Must contain a ``bacdive_id`` column; every other column
            except ``genome_accession`` is treated as a model feature.
        strain_media: Link table with ``bacdive_id``, ``medium_id`` and
            ``growth`` columns. Only rows with ``growth == "yes"`` are used.
        bacdive: Currently unused; kept so existing callers keep working.
        min_strains: Minimum number of distinct strains a medium needs to be
            kept. ``None`` (the default) uses ``MIN_STRAINS_PER_MEDIUM``.

    Raises:
        ValueError: If no ``bacdive_id`` appears in both input tables.
    """
    if min_strains is None:
        min_strains = MIN_STRAINS_PER_MEDIUM

    # Strains with both genome features and at least one medium link.
    shared_ids = set(features["bacdive_id"]) & set(strain_media["bacdive_id"])
    if not shared_ids:
        raise ValueError("No overlap between feature table and strain_media links.")

    X = features[features["bacdive_id"].isin(shared_ids)].set_index("bacdive_id").sort_index()
    X = X[[col for col in X.columns if col != "genome_accession"]]

    links = strain_media[
        strain_media["bacdive_id"].isin(shared_ids) & (strain_media["growth"] == "yes")
    ]
    # A strain may be linked to the same medium more than once; count it once
    # so the threshold really means "distinct strains", not "link rows".
    links = links.drop_duplicates(subset=["bacdive_id", "medium_id"])

    strains_per_medium = links.groupby("medium_id").size()
    keep_media = strains_per_medium[strains_per_medium >= min_strains].index.tolist()

    if not keep_media:
        # Nothing qualifies: return an explicit (n_strains × 0) label matrix
        # rather than pivoting an empty frame.
        y_matrix = pd.DataFrame(index=X.index, columns=keep_media, dtype=np.uint8)
        return X, y_matrix, keep_media

    links = links[links["medium_id"].isin(keep_media)]
    y_matrix = (
        links.assign(_one=1)
        .pivot_table(
            index="bacdive_id",
            columns="medium_id",
            values="_one",
            aggfunc="max",
            fill_value=0,
        )
        # Strains whose only links are to dropped media get all-zero rows.
        .reindex(index=X.index, columns=keep_media, fill_value=0)
        .astype(np.uint8)
    )
    return X, y_matrix, keep_media
def train_per_medium(
    X: pd.DataFrame,
    y_matrix: pd.DataFrame,
    medium_metadata: dict[str, str],
    groups: pd.Series,
    *,
    n_splits: int = 5,
    n_estimators: int = 200,
    max_depth: int = 5,
) -> dict[str, MediumModelResult]:
    """Train one classifier per medium with GroupKFold by `groups` (e.g. taxonomic family).

    Args:
        X: Feature matrix, one row per strain.
        y_matrix: {0,1} label matrix with one column per medium.
        medium_metadata: Maps ``str(medium_id)`` to a display name; missing
            ids get an empty name.
        groups: Group label per row, passed to ``GroupKFold.split`` as-is, so
            it is assumed to be aligned row-for-row with ``X``.
        n_splits: Upper bound on the number of folds; reduced to the number
            of distinct groups when there are fewer (but never below 2).
        n_estimators: Number of boosting rounds per model.
        max_depth: Maximum tree depth per model.

    Returns:
        ``{str(medium_id): MediumModelResult}``. A medium whose folds were all
        skipped still gets an entry, with an empty ``fold_metrics`` list.
    """
    results: dict[str, MediumModelResult] = {}
    splits = min(n_splits, max(2, groups.nunique()))
    kfold = GroupKFold(n_splits=splits)

    for medium_id in y_matrix.columns:
        y = y_matrix[medium_id].to_numpy()
        result = MediumModelResult(
            medium_id=str(medium_id),
            medium_name=medium_metadata.get(str(medium_id), ""),
            n_positives=int(y.sum()),
            n_negatives=int((y == 0).sum()),
        )

        for fold_idx, (tr_idx, te_idx) in enumerate(kfold.split(X, y, groups)):
            y_tr = y[tr_idx]
            y_te = y[te_idx]
            tr_pos, tr_neg = int(y_tr.sum()), int((y_tr == 0).sum())
            te_pos, te_neg = int(y_te.sum()), int((y_te == 0).sum())

            # Both classes are required on both sides. Checking up front
            # avoids fitting a model on a fold that cannot be scored, and
            # avoids training on positives only.
            if tr_pos < 5 or tr_neg < 1 or te_pos < 1 or te_neg < 1:
                continue

            model = xgb.XGBClassifier(
                n_estimators=n_estimators,
                max_depth=max_depth,
                learning_rate=0.05,
                tree_method="hist",
                n_jobs=-1,
                # Re-weight positives to offset class imbalance in this fold.
                scale_pos_weight=tr_neg / tr_pos,
                eval_metric="logloss",
            )
            model.fit(X.iloc[tr_idx], y_tr)
            proba = model.predict_proba(X.iloc[te_idx])[:, 1]

            # Backstop: keep skipping folds whose scores the metrics reject.
            try:
                roc = roc_auc_score(y_te, proba)
                pr = average_precision_score(y_te, proba)
            except ValueError:
                continue

            result.fold_metrics.append({
                "fold": fold_idx,
                "n_train": int(len(tr_idx)),
                "n_test": int(len(te_idx)),
                "n_test_positives": te_pos,
                "roc_auc": float(roc),
                "pr_auc": float(pr),
            })

        results[str(medium_id)] = result
    return results
def train_production_models(
    X: pd.DataFrame,
    y_matrix: pd.DataFrame,
    *,
    n_estimators: int = 300,
    max_depth: int = 5,
) -> dict[str, xgb.XGBClassifier]:
    """Fit the inference-time classifiers, one per medium, on the full dataset.

    No cross-validation split is made here. Media with fewer than 10 positive
    or fewer than 10 negative strains are skipped and do not appear in the
    result.

    Returns {medium_id: trained_model}, keyed by ``str(medium_id)``. Persisting
    the models is left to the caller — see scripts/10_train_media_recommender.py
    for the disk layout.
    """
    fitted: dict[str, xgb.XGBClassifier] = {}
    shared_params = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "learning_rate": 0.05,
        "tree_method": "hist",
        "n_jobs": -1,
        "eval_metric": "logloss",
    }

    for medium_id in y_matrix.columns:
        labels = y_matrix[medium_id].to_numpy()
        n_pos = labels.sum()
        n_neg = (labels == 0).sum()
        if n_pos < 10 or n_neg < 10:
            continue

        # Weight positives by the negative/positive ratio for this medium.
        classifier = xgb.XGBClassifier(
            scale_pos_weight=n_neg / max(1, n_pos),
            **shared_params,
        )
        classifier.fit(X, labels)
        fitted[str(medium_id)] = classifier

    return fitted
def save_models(
    models: dict[str, xgb.XGBClassifier],
    feature_cols: list[str],
    out_dir: Path,
) -> None:
    """Save each XGBoost model + feature column order for inference.

    Each model is written to ``out_dir / "medium_<safe_id>.ubj"``, where
    ``safe_id`` is the medium id with every non-alphanumeric character
    replaced by ``_``. The feature order goes to ``out_dir / "feature_cols.json"``.

    Raises:
        ValueError: If two medium ids sanitize to the same file name. Nothing
            is written in that case, so one model cannot silently overwrite
            another.
    """
    # Resolve every file name first so a collision is caught before any write.
    id_by_safe_name: dict[str, str] = {}
    for medium_id in models:
        safe_id = "".join(ch if ch.isalnum() else "_" for ch in medium_id)
        if safe_id in id_by_safe_name:
            raise ValueError(
                f"Medium ids {id_by_safe_name[safe_id]!r} and {medium_id!r} "
                f"both map to file name 'medium_{safe_id}.ubj'."
            )
        id_by_safe_name[safe_id] = medium_id

    out_dir.mkdir(parents=True, exist_ok=True)
    for safe_id, medium_id in id_by_safe_name.items():
        models[medium_id].save_model(out_dir / f"medium_{safe_id}.ubj")
    (out_dir / "feature_cols.json").write_text(json.dumps(feature_cols))
def load_models(out_dir: Path) -> tuple[dict[str, xgb.XGBClassifier], list[str]]:
    """Read back every per-medium model in ``out_dir`` plus the feature order.

    Models are discovered by the ``medium_*.ubj`` file pattern. Each is keyed
    by its file stem with the ``medium_`` prefix removed, i.e. the sanitized
    id used in the file name. The feature-column list comes from
    ``feature_cols.json`` in the same directory.
    """
    columns_file = out_dir / "feature_cols.json"
    feature_cols = json.loads(columns_file.read_text())

    loaded: dict[str, xgb.XGBClassifier] = {}
    for model_path in out_dir.glob("medium_*.ubj"):
        classifier = xgb.XGBClassifier()
        classifier.load_model(model_path)
        loaded[model_path.stem.removeprefix("medium_")] = classifier

    return loaded, feature_cols
def save_results(results: dict[str, MediumModelResult], path: Path) -> None:
    """Write per-medium cross-validation results to ``path`` as indented JSON.

    Each medium id maps to its name, class counts, mean PR-AUC and ROC-AUC,
    and the raw per-fold metrics. A mean that is not finite (NaN when no fold
    was evaluated) is written as ``null``, because the bare ``NaN`` token that
    ``json.dumps`` would otherwise emit is not valid JSON.
    """

    def _finite_or_none(value: float) -> float | None:
        # Map NaN/inf to None so the output stays valid JSON.
        return value if np.isfinite(value) else None

    payload = {
        mid: {
            "medium_name": r.medium_name,
            "n_positives": r.n_positives,
            "n_negatives": r.n_negatives,
            "mean_pr_auc": _finite_or_none(r.mean_pr_auc()),
            "mean_roc_auc": _finite_or_none(r.mean_roc_auc()),
            "folds": r.fold_metrics,
        }
        for mid, r in results.items()
    }
    path.write_text(json.dumps(payload, indent=2))
|