Miyu Horiuchi Claude Opus 4.7 (1M context) commited on
Commit
d082ced
·
1 Parent(s): 383bb62

Add eval report generator + training table persistence + group-col override

Browse files

Adds the morning-readable end product:
- src/microbe_model/eval.py — render_report() generates artifacts/eval_report.md
with per-target metrics vs always-predict-mean baseline, fold-by-fold detail,
top features, and a limitations + next-steps section
- scripts/04_eval.py — thin wrapper to render the report from a finished training run
- scripts/03_train_baseline.py — saves the merged training_table.parquet for eval
to read, prefers BacDive's LPSN family over derived genus for GroupKFold
- src/microbe_model/train/baseline.py — train_all gains group_col_override

Each metric in the report is paired with a dumb-baseline (always-predict-mean
or always-predict-majority) so the reader can interpret it without context.

Tests still 12/12 passing, lint clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

scripts/03_train_baseline.py CHANGED
@@ -1,38 +1,54 @@
1
  """Train the multi-task XGBoost baseline.
2
 
3
- Joins phenotypes + features, derives a `family` column from `species` for group K-fold,
4
- and writes per-target metrics to artifacts/baseline_results.json.
5
  """
6
  from __future__ import annotations
7
 
 
 
8
  import pandas as pd
9
 
10
  from microbe_model import config
11
  from microbe_model.train.baseline import save_results, train_all
12
 
13
 
14
- def derive_family(species: str | None) -> str:
15
- """Crude family proxy: first word of binomial. Replace with GTDB lookup later."""
16
- if not species:
17
- return "__unknown__"
18
- return str(species).split()[0]
 
 
 
 
 
19
 
20
 
21
  def main() -> None:
 
22
  pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
23
  feats = pd.read_parquet(config.DATA / "features.parquet")
24
  df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
25
- df["family"] = df["species"].apply(derive_family)
26
 
27
  feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
28
- print(f"Training on {len(df)} strains × {len(feature_cols)} features.")
29
- print(f"Group counts (top 10): {df['family'].value_counts().head(10).to_dict()}")
30
 
31
- results = train_all(df, feature_cols)
 
 
 
 
 
 
 
 
 
32
 
33
  out = config.ARTIFACTS / "baseline_results.json"
34
  save_results(results, out)
35
- print(f"\nWrote results to {out}\n")
 
36
  for target, r in results.items():
37
  if r.folds:
38
  metric = r.folds[0].metric_name
 
1
  """Train the multi-task XGBoost baseline.
2
 
3
+ Joins phenotypes + features, derives a stable group column for GroupKFold, trains, saves
4
+ the merged training table for the eval renderer, and writes per-target metrics.
5
  """
6
  from __future__ import annotations
7
 
8
+ import time
9
+
10
  import pandas as pd
11
 
12
  from microbe_model import config
13
  from microbe_model.train.baseline import save_results, train_all
14
 
15
 
16
+ def derive_group(row: pd.Series) -> str:
17
+ """Group-K-fold key. Prefer LPSN family (from BacDive); fall back to genus then species."""
18
+ for col in ("family", "genus"):
19
+ val = row.get(col)
20
+ if isinstance(val, str) and val:
21
+ return val
22
+ species = row.get("species")
23
+ if isinstance(species, str) and species:
24
+ return species.split()[0]
25
+ return "__unknown__"
26
 
27
 
28
  def main() -> None:
29
+ t0 = time.time()
30
  pheno = pd.read_parquet(config.DATA / "bacdive_phenotypes.parquet")
31
  feats = pd.read_parquet(config.DATA / "features.parquet")
32
  df = pheno.merge(feats, on=["bacdive_id", "genome_accession"], how="inner")
33
+ df["group"] = df.apply(derive_group, axis=1)
34
 
35
  feature_cols = [c for c in feats.columns if c not in {"bacdive_id", "genome_accession"}]
 
 
36
 
37
+ print(f"Training table: {len(df):,} strains × {len(feature_cols)} features")
38
+ print(f"Distinct groups: {df['group'].nunique():,}")
39
+ print(f"Group sizes (top 10): {df['group'].value_counts().head(10).to_dict()}")
40
+ print()
41
+
42
+ training_table = config.DATA / "training_table.parquet"
43
+ df.to_parquet(training_table, index=False)
44
+ print(f"Wrote training table to {training_table}")
45
+
46
+ results = train_all(df, feature_cols, group_col_override="group")
47
 
48
  out = config.ARTIFACTS / "baseline_results.json"
49
  save_results(results, out)
50
+
51
+ print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
52
  for target, r in results.items():
53
  if r.folds:
54
  metric = r.folds[0].metric_name
scripts/04_eval.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Render the v0 eval report from the trained-results JSON + training table."""
2
+ from __future__ import annotations
3
+
4
+ from microbe_model import config
5
+ from microbe_model.eval import render_report
6
+
7
+
8
+ def main() -> None:
9
+ results_path = config.ARTIFACTS / "baseline_results.json"
10
+ dataset_path = config.DATA / "training_table.parquet"
11
+ out_path = config.ARTIFACTS / "eval_report.md"
12
+
13
+ if not results_path.exists():
14
+ raise SystemExit(f"Missing {results_path}. Run scripts/03_train_baseline.py first.")
15
+ if not dataset_path.exists():
16
+ raise SystemExit(f"Missing {dataset_path}. Run scripts/03_train_baseline.py first.")
17
+
18
+ render_report(results_path, dataset_path, out_path)
19
+ print(f"Wrote {out_path}")
20
+
21
+
22
+ if __name__ == "__main__":
23
+ main()
src/microbe_model/eval.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation report generation.
2
+
3
+ Renders a markdown report from a trained-results JSON (the output of train/baseline.py)
4
+ joined with the source dataset. Designed to be readable cold — every number includes
5
+ a comparison baseline so the reader can interpret it without context.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from datetime import UTC, datetime
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ from microbe_model import config
18
+
19
+
20
+ def _baseline_mae(y: np.ndarray) -> float:
21
+ """MAE of the always-predict-mean baseline (sanity floor)."""
22
+ if len(y) == 0:
23
+ return float("nan")
24
+ return float(np.mean(np.abs(y - np.mean(y))))
25
+
26
+
27
+ def _baseline_f1(y: np.ndarray) -> float:
28
+ """Macro-F1 of the always-predict-majority baseline."""
29
+ from sklearn.metrics import f1_score
30
+ if len(y) == 0:
31
+ return float("nan")
32
+ values, counts = np.unique(y, return_counts=True)
33
+ majority = values[np.argmax(counts)]
34
+ pred = np.full_like(y, majority)
35
+ return float(f1_score(y, pred, average="macro"))
36
+
37
+
38
+ def render_report(
39
+ results_path: Path,
40
+ dataset_path: Path,
41
+ out_path: Path,
42
+ *,
43
+ n_strains: int | None = None,
44
+ runtime_seconds: float | None = None,
45
+ ) -> None:
46
+ results: dict[str, Any] = json.loads(results_path.read_text())
47
+ df = pd.read_parquet(dataset_path)
48
+
49
+ lines: list[str] = []
50
+ lines.append("# microbe-model — v0 baseline eval report")
51
+ lines.append("")
52
+ lines.append(f"_Generated: {datetime.now(UTC).isoformat(timespec='seconds')}_")
53
+ lines.append("")
54
+
55
+ # Section: corpus
56
+ lines.append("## Corpus")
57
+ lines.append("")
58
+ lines.append(f"- Total strains in feature table: **{len(df):,}**")
59
+ if n_strains is not None:
60
+ lines.append(f"- Total strains attempted (had genome accession + label): {n_strains:,}")
61
+ lines.append(f"- Feature-extraction success rate: {100 * len(df) / max(1, n_strains):.1f}%")
62
+ if runtime_seconds is not None:
63
+ lines.append(f"- Featurize wall time: {runtime_seconds / 60:.1f} min")
64
+ lines.append("")
65
+
66
+ # Section: per-target results
67
+ lines.append("## Per-target results (5-fold GroupKFold by family)")
68
+ lines.append("")
69
+ lines.append("Metrics: regression = MAE (lower is better), classification = macro-F1 (higher is better).")
70
+ lines.append("Each is shown alongside the dumb-baseline (always-predict-mean / always-predict-majority).")
71
+ lines.append("")
72
+ lines.append("| Target | Task | n labeled | Model metric | Baseline | Improvement |")
73
+ lines.append("|---|---|---|---|---|---|")
74
+ for target, r in results.items():
75
+ if not r["folds"]:
76
+ lines.append(f"| {target} | {r['task']} | — | _skipped (insufficient data)_ | — | — |")
77
+ continue
78
+ y = df[target].dropna().to_numpy()
79
+ n_labeled = len(y)
80
+ if r["task"] == "regression":
81
+ baseline = _baseline_mae(y.astype(float))
82
+ mean = r["mean_metric"]
83
+ improvement = f"{(baseline - mean) / baseline * 100:+.1f}%"
84
+ lines.append(f"| `{target}` | regression | {n_labeled:,} | "
85
+ f"MAE={mean:.3f} | MAE={baseline:.3f} | {improvement} |")
86
+ else:
87
+ baseline = _baseline_f1(y)
88
+ mean = r["mean_metric"]
89
+ improvement = f"{(mean - baseline) / max(0.01, baseline) * 100:+.1f}%"
90
+ lines.append(f"| `{target}` | classification | {n_labeled:,} | "
91
+ f"F1={mean:.3f} | F1={baseline:.3f} | {improvement} |")
92
+ lines.append("")
93
+
94
+ # Section: per-fold detail
95
+ for target, r in results.items():
96
+ if not r["folds"]:
97
+ continue
98
+ lines.append(f"### `{target}` — fold-by-fold")
99
+ lines.append("")
100
+ lines.append("| Fold | Metric | Train | Test |")
101
+ lines.append("|---|---|---|---|")
102
+ for i, f in enumerate(r["folds"]):
103
+ lines.append(f"| {i+1} | {f['metric_name']} = {f['value']:.3f} | "
104
+ f"n={f['n_train']:,} | n={f['n_test']:,} |")
105
+ lines.append("")
106
+
107
+ top = r.get("top_features", {})
108
+ if top:
109
+ lines.append(f"**Top 10 features for `{target}`:**")
110
+ lines.append("")
111
+ for name, importance in list(top.items())[:10]:
112
+ lines.append(f"- `{name}` — {importance:.4f}")
113
+ lines.append("")
114
+
115
+ # Section: limitations
116
+ lines.append("## Known limitations")
117
+ lines.append("")
118
+ lines.append("- **Survivorship bias.** BacDive only contains organisms that have been cultured "
119
+ "successfully at least once. The model cannot generalize to truly uncultured strains "
120
+ "without explicit out-of-distribution evaluation.")
121
+ lines.append("- **Optimum derivation is heuristic.** Most BacDive temperature entries are tagged "
122
+ "as `growth` (positive growth at this temperature), not `optimum`. We approximate "
123
+ "the optimum as the median of positive-growth temperatures when no explicit "
124
+ "optimum is recorded — this can be off by 5°C or more for some strains.")
125
+ lines.append("- **Family grouping is naive.** The current `family` column is derived from the "
126
+ "genus (first word of binomial name). A proper LPSN/GTDB family assignment would "
127
+ "give tighter taxonomic grouping.")
128
+ lines.append("- **Feature set is shallow.** No HMM/KEGG annotations, no codon usage indices, no "
129
+ "tRNA counts. These are interpretable next steps before moving to genome LMs.")
130
+ lines.append("- **Pyrodigal accuracy.** Gene prediction quality drops on highly-fragmented "
131
+ "assemblies and atypical genetic codes. Not currently flagged in the feature set.")
132
+ lines.append("")
133
+
134
+ # Section: next steps
135
+ lines.append("## Next steps")
136
+ lines.append("")
137
+ lines.append("1. **Add tetranucleotide / codon-usage features.** ~50 extra columns, "
138
+ "well-known signal for thermophily.")
139
+ lines.append("2. **Replace naive family lookup with LPSN/GTDB join.** Reduces leakage in CV.")
140
+ lines.append("3. **Integrate KOMODO media DB** as a richer label source than BacDive alone.")
141
+ lines.append("4. **Move to genome embeddings** (Nucleotide Transformer / Evo-1 / DNABERT-2) "
142
+ "once the tabular ceiling is established.")
143
+ lines.append("5. **Active learning loop**: select novel-family strains where the model is "
144
+ "uncertain, prioritize these for wet-lab cultivation testing.")
145
+ lines.append("")
146
+
147
+ out_path.parent.mkdir(parents=True, exist_ok=True)
148
+ out_path.write_text("\n".join(lines))
149
+
150
+
151
+ if __name__ == "__main__":
152
+ render_report(
153
+ results_path=config.ARTIFACTS / "baseline_results.json",
154
+ dataset_path=config.DATA / "training_table.parquet",
155
+ out_path=config.ARTIFACTS / "eval_report.md",
156
+ )
src/microbe_model/train/baseline.py CHANGED
@@ -120,12 +120,18 @@ def train_target(
120
  return result
121
 
122
 
123
- def train_all(df: pd.DataFrame, feature_cols: list[str]) -> dict[str, TargetResult]:
 
 
 
 
 
124
  results: dict[str, TargetResult] = {}
 
125
  for target, task in config.PHENOTYPE_TARGETS.items():
126
  if target not in df.columns:
127
  continue
128
- results[target] = train_target(df, target, task, feature_cols)
129
  return results
130
 
131
 
 
120
  return result
121
 
122
 
123
+ def train_all(
124
+ df: pd.DataFrame,
125
+ feature_cols: list[str],
126
+ *,
127
+ group_col_override: str | None = None,
128
+ ) -> dict[str, TargetResult]:
129
  results: dict[str, TargetResult] = {}
130
+ group_col = group_col_override or "family"
131
  for target, task in config.PHENOTYPE_TARGETS.items():
132
  if target not in df.columns:
133
  continue
134
+ results[target] = train_target(df, target, task, feature_cols, group_col=group_col)
135
  return results
136
 
137