Spaces:
Running
Add eval report generator + training table persistence + group-col override
Browse filesAdds the morning-readable end product:
- src/microbe_model/eval.py — render_report() generates artifacts/eval_report.md
with per-target metrics vs always-predict-mean baseline, fold-by-fold detail,
top features, and a limitations + next-steps section
- scripts/04_eval.py — thin wrapper to render the report from a finished training run
- scripts/03_train_baseline.py — saves the merged training_table.parquet for eval
to read, prefers BacDive's LPSN family over derived genus for GroupKFold
- src/microbe_model/train/baseline.py — train_all gains group_col_override
Each metric in the report is paired with a dumb-baseline (always-predict-mean
or always-predict-majority) so the reader can interpret it without context.
Tests still 12/12 passing, lint clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- scripts/03_train_baseline.py +28 -12
- scripts/04_eval.py +23 -0
- src/microbe_model/eval.py +156 -0
- src/microbe_model/train/baseline.py +8 -2
|
@@ -1,38 +1,54 @@
|
|
| 1 |
"""Train the multi-task XGBoost baseline.
|
| 2 |
|
| 3 |
-
Joins phenotypes + features, derives a
|
| 4 |
-
and writes per-target metrics
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
|
|
|
|
|
| 8 |
import pandas as pd
|
| 9 |
|
| 10 |
from microbe_model import config
|
| 11 |
from microbe_model.train.baseline import save_results, train_all
|
| 12 |
|
| 13 |
|
| 14 |
-
def
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def main() -> None:
|
|
|
|
| 22 |
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
|
| 23 |
feats = pd.read_parquet(config.DATA / "features.parquet")
|
| 24 |
df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 25 |
-
df["
|
| 26 |
|
| 27 |
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
| 28 |
-
print(f"Training on {len(df)} strains × {len(feature_cols)} features.")
|
| 29 |
-
print(f"Group counts (top 10): {df['family'].value_counts().head(10).to_dict()}")
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
out = config.ARTIFACTS / "baseline_results.json"
|
| 34 |
save_results(results, out)
|
| 35 |
-
|
|
|
|
| 36 |
for target, r in results.items():
|
| 37 |
if r.folds:
|
| 38 |
metric = r.folds[0].metric_name
|
|
|
|
| 1 |
"""Train the multi-task XGBoost baseline.
|
| 2 |
|
| 3 |
+
Joins phenotypes + features, derives a stable group column for GroupKFold, trains, saves
|
| 4 |
+
the merged training table for the eval renderer, and writes per-target metrics.
|
| 5 |
"""
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
import pandas as pd
|
| 11 |
|
| 12 |
from microbe_model import config
|
| 13 |
from microbe_model.train.baseline import save_results, train_all
|
| 14 |
|
| 15 |
|
| 16 |
+
def derive_group(row: pd.Series) -> str:
|
| 17 |
+
"""Group-K-fold key. Prefer LPSN family (from BacDive); fall back to genus then species."""
|
| 18 |
+
for col in ("family", "genus"):
|
| 19 |
+
val = row.get(col)
|
| 20 |
+
if isinstance(val, str) and val:
|
| 21 |
+
return val
|
| 22 |
+
species = row.get("species")
|
| 23 |
+
if isinstance(species, str) and species:
|
| 24 |
+
return species.split()[0]
|
| 25 |
+
return "__unknown__"
|
| 26 |
|
| 27 |
|
| 28 |
def main() -> None:
|
| 29 |
+
t0 = time.time()
|
| 30 |
pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
|
| 31 |
feats = pd.read_parquet(config.DATA / "features.parquet")
|
| 32 |
df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
|
| 33 |
+
df["group"] = df.apply(derive_group, axis=1)
|
| 34 |
|
| 35 |
feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
print(f"Training table: {len(df):,} strains × {len(feature_cols)} features")
|
| 38 |
+
print(f"Distinct groups: {df['group'].nunique():,}")
|
| 39 |
+
print(f"Group sizes (top 10): {df['group'].value_counts().head(10).to_dict()}")
|
| 40 |
+
print()
|
| 41 |
+
|
| 42 |
+
training_table = config.DATA / "training_table.parquet"
|
| 43 |
+
df.to_parquet(training_table, index=False)
|
| 44 |
+
print(f"Wrote training table to {training_table}")
|
| 45 |
+
|
| 46 |
+
results = train_all(df, feature_cols, group_col_override="group")
|
| 47 |
|
| 48 |
out = config.ARTIFACTS / "baseline_results.json"
|
| 49 |
save_results(results, out)
|
| 50 |
+
|
| 51 |
+
print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
|
| 52 |
for target, r in results.items():
|
| 53 |
if r.folds:
|
| 54 |
metric = r.folds[0].metric_name
|
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Render the v0 eval report from the trained-results JSON + training table."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from microbe_model import config
|
| 5 |
+
from microbe_model.eval import render_report
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main() -> None:
|
| 9 |
+
results_path = config.ARTIFACTS / "baseline_results.json"
|
| 10 |
+
dataset_path = config.DATA / "training_table.parquet"
|
| 11 |
+
out_path = config.ARTIFACTS / "eval_report.md"
|
| 12 |
+
|
| 13 |
+
if not results_path.exists():
|
| 14 |
+
raise SystemExit(f"Missing {results_path}. Run scripts/03_train_baseline.py first.")
|
| 15 |
+
if not dataset_path.exists():
|
| 16 |
+
raise SystemExit(f"Missing {dataset_path}. Run scripts/03_train_baseline.py first.")
|
| 17 |
+
|
| 18 |
+
render_report(results_path, dataset_path, out_path)
|
| 19 |
+
print(f"Wrote {out_path}")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
main()
|
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation report generation.
|
| 2 |
+
|
| 3 |
+
Renders a markdown report from a trained-results JSON (the output of train/baseline.py)
|
| 4 |
+
joined with the source dataset. Designed to be readable cold — every number includes
|
| 5 |
+
a comparison baseline so the reader can interpret it without context.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from datetime import UTC, datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
from microbe_model import config
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _baseline_mae(y: np.ndarray) -> float:
|
| 21 |
+
"""MAE of the always-predict-mean baseline (sanity floor)."""
|
| 22 |
+
if len(y) == 0:
|
| 23 |
+
return float("nan")
|
| 24 |
+
return float(np.mean(np.abs(y - np.mean(y))))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _baseline_f1(y: np.ndarray) -> float:
|
| 28 |
+
"""Macro-F1 of the always-predict-majority baseline."""
|
| 29 |
+
from sklearn.metrics import f1_score
|
| 30 |
+
if len(y) == 0:
|
| 31 |
+
return float("nan")
|
| 32 |
+
values, counts = np.unique(y, return_counts=True)
|
| 33 |
+
majority = values[np.argmax(counts)]
|
| 34 |
+
pred = np.full_like(y, majority)
|
| 35 |
+
return float(f1_score(y, pred, average="macro"))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def render_report(
|
| 39 |
+
results_path: Path,
|
| 40 |
+
dataset_path: Path,
|
| 41 |
+
out_path: Path,
|
| 42 |
+
*,
|
| 43 |
+
n_strains: int | None = None,
|
| 44 |
+
runtime_seconds: float | None = None,
|
| 45 |
+
) -> None:
|
| 46 |
+
results: dict[str, Any] = json.loads(results_path.read_text())
|
| 47 |
+
df = pd.read_parquet(dataset_path)
|
| 48 |
+
|
| 49 |
+
lines: list[str] = []
|
| 50 |
+
lines.append("# microbe-model — v0 baseline eval report")
|
| 51 |
+
lines.append("")
|
| 52 |
+
lines.append(f"_Generated: {datetime.now(UTC).isoformat(timespec='seconds')}_")
|
| 53 |
+
lines.append("")
|
| 54 |
+
|
| 55 |
+
# Section: corpus
|
| 56 |
+
lines.append("## Corpus")
|
| 57 |
+
lines.append("")
|
| 58 |
+
lines.append(f"- Total strains in feature table: **{len(df):,}**")
|
| 59 |
+
if n_strains is not None:
|
| 60 |
+
lines.append(f"- Total strains attempted (had genome accession + label): {n_strains:,}")
|
| 61 |
+
lines.append(f"- Feature-extraction success rate: {100 * len(df) / max(1, n_strains):.1f}%")
|
| 62 |
+
if runtime_seconds is not None:
|
| 63 |
+
lines.append(f"- Featurize wall time: {runtime_seconds / 60:.1f} min")
|
| 64 |
+
lines.append("")
|
| 65 |
+
|
| 66 |
+
# Section: per-target results
|
| 67 |
+
lines.append("## Per-target results (5-fold GroupKFold by family)")
|
| 68 |
+
lines.append("")
|
| 69 |
+
lines.append("Metrics: regression = MAE (lower is better), classification = macro-F1 (higher is better).")
|
| 70 |
+
lines.append("Each is shown alongside the dumb-baseline (always-predict-mean / always-predict-majority).")
|
| 71 |
+
lines.append("")
|
| 72 |
+
lines.append("| Target | Task | n labeled | Model metric | Baseline | Improvement |")
|
| 73 |
+
lines.append("|---|---|---|---|---|---|")
|
| 74 |
+
for target, r in results.items():
|
| 75 |
+
if not r["folds"]:
|
| 76 |
+
lines.append(f"| {target} | {r['task']} | — | _skipped (insufficient data)_ | — | — |")
|
| 77 |
+
continue
|
| 78 |
+
y = df[target].dropna().to_numpy()
|
| 79 |
+
n_labeled = len(y)
|
| 80 |
+
if r["task"] == "regression":
|
| 81 |
+
baseline = _baseline_mae(y.astype(float))
|
| 82 |
+
mean = r["mean_metric"]
|
| 83 |
+
improvement = f"{(baseline - mean) / baseline * 100:+.1f}%"
|
| 84 |
+
lines.append(f"| `{target}` | regression | {n_labeled:,} | "
|
| 85 |
+
f"MAE={mean:.3f} | MAE={baseline:.3f} | {improvement} |")
|
| 86 |
+
else:
|
| 87 |
+
baseline = _baseline_f1(y)
|
| 88 |
+
mean = r["mean_metric"]
|
| 89 |
+
improvement = f"{(mean - baseline) / max(0.01, baseline) * 100:+.1f}%"
|
| 90 |
+
lines.append(f"| `{target}` | classification | {n_labeled:,} | "
|
| 91 |
+
f"F1={mean:.3f} | F1={baseline:.3f} | {improvement} |")
|
| 92 |
+
lines.append("")
|
| 93 |
+
|
| 94 |
+
# Section: per-fold detail
|
| 95 |
+
for target, r in results.items():
|
| 96 |
+
if not r["folds"]:
|
| 97 |
+
continue
|
| 98 |
+
lines.append(f"### `{target}` — fold-by-fold")
|
| 99 |
+
lines.append("")
|
| 100 |
+
lines.append("| Fold | Metric | Train | Test |")
|
| 101 |
+
lines.append("|---|---|---|---|")
|
| 102 |
+
for i, f in enumerate(r["folds"]):
|
| 103 |
+
lines.append(f"| {i+1} | {f['metric_name']} = {f['value']:.3f} | "
|
| 104 |
+
f"n={f['n_train']:,} | n={f['n_test']:,} |")
|
| 105 |
+
lines.append("")
|
| 106 |
+
|
| 107 |
+
top = r.get("top_features", {})
|
| 108 |
+
if top:
|
| 109 |
+
lines.append(f"**Top 10 features for `{target}`:**")
|
| 110 |
+
lines.append("")
|
| 111 |
+
for name, importance in list(top.items())[:10]:
|
| 112 |
+
lines.append(f"- `{name}` — {importance:.4f}")
|
| 113 |
+
lines.append("")
|
| 114 |
+
|
| 115 |
+
# Section: limitations
|
| 116 |
+
lines.append("## Known limitations")
|
| 117 |
+
lines.append("")
|
| 118 |
+
lines.append("- **Survivorship bias.** BacDive only contains organisms that have been cultured "
|
| 119 |
+
"successfully at least once. The model cannot generalize to truly uncultured strains "
|
| 120 |
+
"without explicit out-of-distribution evaluation.")
|
| 121 |
+
lines.append("- **Optimum derivation is heuristic.** Most BacDive temperature entries are tagged "
|
| 122 |
+
"as `growth` (positive growth at this temperature), not `optimum`. We approximate "
|
| 123 |
+
"the optimum as the median of positive-growth temperatures when no explicit "
|
| 124 |
+
"optimum is recorded — this can be off by 5°C or more for some strains.")
|
| 125 |
+
lines.append("- **Family grouping is naive.** The current `family` column is derived from the "
|
| 126 |
+
"genus (first word of binomial name). A proper LPSN/GTDB family assignment would "
|
| 127 |
+
"give tighter taxonomic grouping.")
|
| 128 |
+
lines.append("- **Feature set is shallow.** No HMM/KEGG annotations, no codon usage indices, no "
|
| 129 |
+
"tRNA counts. These are interpretable next steps before moving to genome LMs.")
|
| 130 |
+
lines.append("- **Pyrodigal accuracy.** Gene prediction quality drops on highly-fragmented "
|
| 131 |
+
"assemblies and atypical genetic codes. Not currently flagged in the feature set.")
|
| 132 |
+
lines.append("")
|
| 133 |
+
|
| 134 |
+
# Section: next steps
|
| 135 |
+
lines.append("## Next steps")
|
| 136 |
+
lines.append("")
|
| 137 |
+
lines.append("1. **Add tetranucleotide / codon-usage features.** ~50 extra columns, "
|
| 138 |
+
"well-known signal for thermophily.")
|
| 139 |
+
lines.append("2. **Replace naive family lookup with LPSN/GTDB join.** Reduces leakage in CV.")
|
| 140 |
+
lines.append("3. **Integrate KOMODO media DB** as a richer label source than BacDive alone.")
|
| 141 |
+
lines.append("4. **Move to genome embeddings** (Nucleotide Transformer / Evo-1 / DNABERT-2) "
|
| 142 |
+
"once the tabular ceiling is established.")
|
| 143 |
+
lines.append("5. **Active learning loop**: select novel-family strains where the model is "
|
| 144 |
+
"uncertain, prioritize these for wet-lab cultivation testing.")
|
| 145 |
+
lines.append("")
|
| 146 |
+
|
| 147 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 148 |
+
out_path.write_text("\n".join(lines))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
if __name__ == "__main__":
|
| 152 |
+
render_report(
|
| 153 |
+
results_path=config.ARTIFACTS / "baseline_results.json",
|
| 154 |
+
dataset_path=config.DATA / "training_table.parquet",
|
| 155 |
+
out_path=config.ARTIFACTS / "eval_report.md",
|
| 156 |
+
)
|
|
@@ -120,12 +120,18 @@ def train_target(
|
|
| 120 |
return result
|
| 121 |
|
| 122 |
|
| 123 |
-
def train_all(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
results: dict[str, TargetResult] = {}
|
|
|
|
| 125 |
for target, task in config.PHENOTYPE_TARGETS.items():
|
| 126 |
if target not in df.columns:
|
| 127 |
continue
|
| 128 |
-
results[target] = train_target(df, target, task, feature_cols)
|
| 129 |
return results
|
| 130 |
|
| 131 |
|
|
|
|
| 120 |
return result
|
| 121 |
|
| 122 |
|
| 123 |
+
def train_all(
|
| 124 |
+
df: pd.DataFrame,
|
| 125 |
+
feature_cols: list[str],
|
| 126 |
+
*,
|
| 127 |
+
group_col_override: str | None = None,
|
| 128 |
+
) -> dict[str, TargetResult]:
|
| 129 |
results: dict[str, TargetResult] = {}
|
| 130 |
+
group_col = group_col_override or "family"
|
| 131 |
for target, task in config.PHENOTYPE_TARGETS.items():
|
| 132 |
if target not in df.columns:
|
| 133 |
continue
|
| 134 |
+
results[target] = train_target(df, target, task, feature_cols, group_col=group_col)
|
| 135 |
return results
|
| 136 |
|
| 137 |
|