Miyu Horiuchi
Fix predictions parquet type mix + plumb feature_cols through eval
bbbea9d
"""Evaluation report generation.
Renders a markdown report from a training-results JSON (the output of train/baseline.py),
joined with the source dataset and, when available, a per-row predictions table. Designed
to be readable cold — every headline metric is shown next to a dumb baseline so the reader
can interpret it without context.
"""
from __future__ import annotations
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from microbe_model import config
def _baseline_mae(y: np.ndarray) -> float:
"""MAE of the always-predict-mean baseline (sanity floor)."""
if len(y) == 0:
return float("nan")
return float(np.mean(np.abs(y - np.mean(y))))
def _baseline_f1(y: np.ndarray) -> float:
"""Macro-F1 of the always-predict-majority baseline."""
from sklearn.metrics import f1_score
if len(y) == 0:
return float("nan")
values, counts = np.unique(y, return_counts=True)
majority = values[np.argmax(counts)]
pred = np.full_like(y, majority)
return float(f1_score(y, pred, average="macro"))
# Name prefixes that identify genome-derived feature columns. Only used when no
# explicit feature list is available (neither the ``feature_cols`` argument nor
# ``__meta__["feature_cols"]`` in the results JSON).
_FEATURE_PREFIXES: tuple[str, ...] = (
    "aa_frac_", "genome_size", "gc_", "n_predicted", "coding_",
    "mean_", "aromatic_", "pos_", "neg_", "ivywrel_", "median_",
)

# Label columns whose labeled-strain counts appear in the corpus section.
_LABEL_TARGETS: tuple[str, ...] = (
    "optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct",
)

# Labels summarised with numeric statistics (mean/std/quantiles) and used for
# the Spearman feature-correlation section.
_REGRESSION_TARGETS: tuple[str, ...] = (
    "optimal_temperature_c", "optimal_ph", "salt_tolerance_pct",
)

_LIMITATIONS_SECTION: tuple[str, ...] = (
    "## Known limitations",
    "",
    "- **Survivorship bias.** BacDive only contains organisms that have been cultured "
    "successfully at least once. The model cannot generalize to truly uncultured strains "
    "without explicit out-of-distribution evaluation.",
    "- **Optimum derivation is heuristic.** Most BacDive temperature entries are tagged "
    "as `growth` (positive growth at this temperature), not `optimum`. We approximate "
    "the optimum as the median of positive-growth temperatures when no explicit "
    "optimum is recorded — this can be off by 5°C or more for some strains.",
    "- **Family grouping is naive.** The current `family` column is derived from the "
    "genus (first word of binomial name). A proper LPSN/GTDB family assignment would "
    "give tighter taxonomic grouping.",
    "- **Feature set is shallow.** No HMM/KEGG annotations, no codon usage indices, no "
    "tRNA counts. These are interpretable next steps before moving to genome LMs.",
    "- **Pyrodigal accuracy.** Gene prediction quality drops on highly-fragmented "
    "assemblies and atypical genetic codes. Not currently flagged in the feature set.",
    "",
)

_NEXT_STEPS_SECTION: tuple[str, ...] = (
    "## Next steps",
    "",
    "1. **Add tetranucleotide / codon-usage features.** ~50 extra columns, "
    "well-known signal for thermophily.",
    "2. **Replace naive family lookup with LPSN/GTDB join.** Reduces leakage in CV.",
    "3. **Integrate KOMODO media DB** as a richer label source than BacDive alone.",
    "4. **Move to genome embeddings** (Nucleotide Transformer / Evo-1 / DNABERT-2) "
    "once the tabular ceiling is established.",
    "5. **Active learning loop**: select novel-family strains where the model is "
    "uncertain, prioritize these for wet-lab cultivation testing.",
    "",
)


def _infer_feature_cols(df: pd.DataFrame) -> list[str]:
    """Guess the numeric genome-derived feature columns of *df* by name prefix.

    One heuristic is shared by the feature count and the correlation section,
    so the "N features" headline describes the same columns that are analysed.
    """
    return [
        c for c in df.columns
        if isinstance(c, str)
        and c.startswith(_FEATURE_PREFIXES)
        and pd.api.types.is_numeric_dtype(df[c])
    ]


def _pct_improvement(model: float, baseline: float, *, higher_is_better: bool) -> float:
    """Relative improvement of *model* over *baseline*, in percent.

    The denominator is floored at 0.001 so a zero baseline (e.g. every label
    identical, giving a baseline MAE of 0) produces a large but finite number
    instead of raising ZeroDivisionError.
    """
    delta = model - baseline if higher_is_better else baseline - model
    return delta / max(0.001, baseline) * 100


def _summarise_target(df: pd.DataFrame, target: str, r: dict[str, Any]) -> tuple[int, float, float]:
    """Return ``(n_labeled, baseline_metric, pct_improvement)`` for a trained target.

    Regression targets are compared against the always-predict-mean MAE (lower
    is better). Every other task is compared against the always-predict-majority
    macro-F1 (higher is better).
    """
    y = df[target].dropna().to_numpy()
    if r["task"] == "regression":
        baseline = _baseline_mae(y.astype(float))
        improvement = _pct_improvement(r["mean_metric"], baseline, higher_is_better=False)
    else:
        baseline = _baseline_f1(y)
        improvement = _pct_improvement(r["mean_metric"], baseline, higher_is_better=True)
    return len(y), baseline, improvement


def _render_tldr(
    df: pd.DataFrame,
    results: dict[str, Any],
    summaries: dict[str, tuple[int, float, float]],
    feature_cols: list[str] | None,
) -> list[str]:
    """Render the TL;DR: one headline line per trained target plus corpus size."""
    out = ["## TL;DR", ""]
    headline: list[str] = []
    for target, (_, baseline, improvement) in summaries.items():
        r = results[target]
        if r["task"] == "regression":
            headline.append(
                f"- **`{target}`**: MAE = **{r['mean_metric']:.2f}** "
                f"(vs always-predict-mean {baseline:.2f}, **{improvement:+.0f}%**)"
            )
        else:
            headline.append(
                f"- **`{target}`**: macro-F1 = **{r['mean_metric']:.3f}** "
                f"(vs always-predict-majority {baseline:.3f}, **{improvement:+.0f}%**)"
            )
    out.extend(headline or ["- _No targets trained successfully — see logs._"])
    out.append("")
    n_features = len(feature_cols) if feature_cols is not None else len(_infer_feature_cols(df))
    out.append(
        f"Trained on **{len(df):,}** strains with **{n_features}** genome-derived features. "
        "Cross-validation: 5-fold GroupKFold by taxonomic family."
    )
    out.append("")
    return out


def _render_corpus(
    df: pd.DataFrame, n_strains: int | None, runtime_seconds: float | None
) -> list[str]:
    """Render corpus size, optional featurization stats, and per-target label counts."""
    out = ["## Corpus", "", f"- Total strains in feature table: **{len(df):,}**"]
    if n_strains is not None:
        out.append(f"- Total strains attempted (had genome accession + label): {n_strains:,}")
        out.append(f"- Feature-extraction success rate: {100 * len(df) / max(1, n_strains):.1f}%")
    if runtime_seconds is not None:
        out.append(f"- Featurize wall time: {runtime_seconds / 60:.1f} min")
    out.append("- Labeled-strain counts by target:")
    for target in _LABEL_TARGETS:
        if target in df.columns:
            # Two-space indent so markdown nests the item under the parent bullet.
            out.append(f"  - `{target}`: {df[target].notna().sum():,}")
    out.append("")
    return out


def _render_target_distributions(df: pd.DataFrame) -> list[str]:
    """Render summary statistics for numeric targets and class counts for oxygen."""
    out = ["## Target distributions", ""]
    for target in _REGRESSION_TARGETS:
        if target not in df.columns:
            continue
        y = df[target].dropna()
        if len(y) == 0:
            continue
        out.append(
            f"- `{target}`: n={len(y):,}, mean={y.mean():.2f}, "
            f"std={y.std():.2f}, p10={y.quantile(0.1):.2f}, "
            f"median={y.median():.2f}, p90={y.quantile(0.9):.2f}"
        )
    if "oxygen_requirement" in df.columns:
        out.append("- `oxygen_requirement`:")
        for cls, n in df["oxygen_requirement"].value_counts().head(10).items():
            # Two-space indent so markdown nests the item under the parent bullet.
            out.append(f"  - `{cls}`: {n:,}")
    out.append("")
    return out


def _render_results_table(
    results: dict[str, Any], summaries: dict[str, tuple[int, float, float]]
) -> list[str]:
    """Render the per-target model-vs-baseline table; untrained targets are marked skipped."""
    out = [
        "## Per-target results (5-fold GroupKFold by family)",
        "",
        "Metrics: regression = MAE (lower is better), classification = macro-F1 (higher is better).",
        "Each is shown alongside the dumb-baseline (always-predict-mean / always-predict-majority).",
        "",
        "| Target | Task | n labeled | Model metric | Baseline | Improvement |",
        "|---|---|---|---|---|---|",
    ]
    for target, r in results.items():
        if not r["folds"]:
            out.append(f"| {target} | {r['task']} | — | _skipped (insufficient data)_ | — | — |")
            continue
        n_labeled, baseline, improvement = summaries[target]
        mean = r["mean_metric"]
        if r["task"] == "regression":
            out.append(f"| `{target}` | regression | {n_labeled:,} | "
                       f"MAE={mean:.3f} | MAE={baseline:.3f} | {improvement:+.1f}% |")
        else:
            out.append(f"| `{target}` | classification | {n_labeled:,} | "
                       f"F1={mean:.3f} | F1={baseline:.3f} | {improvement:+.1f}% |")
    out.append("")
    return out


def _render_fold_detail(results: dict[str, Any]) -> list[str]:
    """Render a fold-by-fold table and the top-10 feature importances per trained target."""
    out: list[str] = []
    for target, r in results.items():
        if not r["folds"]:
            continue
        out += [f"### `{target}` — fold-by-fold", "", "| Fold | Metric | Train | Test |", "|---|---|---|---|"]
        for i, fold in enumerate(r["folds"], start=1):
            out.append(f"| {i} | {fold['metric_name']} = {fold['value']:.3f} | "
                       f"n={fold['n_train']:,} | n={fold['n_test']:,} |")
        out.append("")
        top = r.get("top_features", {})
        if top:
            out += [f"**Top 10 features for `{target}`:**", ""]
            for name, importance in list(top.items())[:10]:
                out.append(f"- `{name}` — {importance:.4f}")
            out.append("")
    return out


def _render_correlations(df: pd.DataFrame, feature_cols: list[str] | None) -> list[str]:
    """Render top-10 Spearman feature/target correlations as a biology sanity check."""
    cols = feature_cols if feature_cols is not None else _infer_feature_cols(df)
    if not cols:
        return []
    from microbe_model.explore import feature_target_correlations

    out = [
        "## Feature ↔ target correlations (Spearman, top 10)",
        "",
        "Sanity-checks the biology — features known to track each target should "
        "appear here at high |ρ|. E.g. `ivywrel_frac` should correlate with "
        "`optimal_temperature_c` (Zeldovich 2007 thermophile signature).",
        "",
    ]
    for target in _REGRESSION_TARGETS:
        # Same guard as the other sections: a missing target column is skipped.
        if target not in df.columns:
            continue
        corrs = feature_target_correlations(df, cols, target, top_n=10)
        if not corrs:
            continue
        out += [f"### `{target}`", "", "| Feature | Spearman ρ | p-value |", "|---|---|---|"]
        for row in corrs:
            out.append(f"| `{row['feature']}` | {row['spearman_rho']:+.3f} | "
                       f"{row['p_value']:.1e} |")
        out.append("")
    return out


def _render_family_errors(df: pd.DataFrame, predictions: pd.DataFrame | None) -> list[str]:
    """Render per-family MAE for regression predictions (top 15 families by count)."""
    if predictions is None or predictions.empty or "row_idx" not in predictions.columns:
        return []
    if "family" not in df.columns:
        return []
    # Only `family` is needed. Also selecting the unused `genus` column used to
    # make this section raise KeyError on tables without it.
    families = df[["family"]].rename_axis("row_idx").reset_index()
    joined = predictions.merge(families, on="row_idx", how="left")
    regression_preds = joined[joined["task"] == "regression"]
    if regression_preds.empty:
        return []
    out = [
        "## Per-family error breakdown (regression targets)",
        "",
        "Top 15 most-represented families, MAE per family. Highlights where the "
        "model is doing well vs. struggling.",
        "",
    ]
    for target in regression_preds["target"].unique():
        sub = regression_preds[regression_preds["target"] == target]
        # The predictions table may store values in a mixed/object column;
        # coerce to numeric before differencing.
        abs_error = (pd.to_numeric(sub["predicted"]) - pd.to_numeric(sub["observed"])).abs()
        grp = (sub.assign(abs_error=abs_error)
               .groupby("family", dropna=False)
               .agg(n=("abs_error", "size"), mae=("abs_error", "mean"))
               .sort_values("n", ascending=False)
               .head(15))
        if grp.empty:
            continue
        out += [f"### `{target}`", "", "| Family | n | MAE |", "|---|---|---|"]
        for fam, row in grp.iterrows():
            fam_label = fam if pd.notna(fam) else "_(no family)_"
            out.append(f"| {fam_label} | {int(row['n'])} | {row['mae']:.3f} |")
        out.append("")
    return out


def render_report(
    results_path: Path,
    dataset_path: Path,
    out_path: Path,
    *,
    n_strains: int | None = None,
    runtime_seconds: float | None = None,
    predictions_path: Path | None = None,
    feature_cols: list[str] | None = None,
) -> None:
    """Render the markdown evaluation report and write it to *out_path*.

    Args:
        results_path: Training-results JSON. Each top-level key is a target
            name mapping to a dict with ``task``, ``folds`` (list of dicts with
            ``metric_name``, ``value``, ``n_train``, ``n_test``),
            ``mean_metric`` and optionally ``top_features``. An optional
            ``__meta__`` entry may carry ``feature_cols``.
        dataset_path: Parquet feature/label table. Its index is joined
            against ``row_idx`` in the predictions table.
        out_path: Destination markdown file. Parent directories are created.
        n_strains: Number of strains attempted during featurization. When
            given, a success-rate line is added.
        runtime_seconds: Featurization wall time, reported in minutes.
        predictions_path: Optional parquet of per-row predictions (columns
            ``row_idx``, ``target``, ``task``, ``predicted``, ``observed``).
            It is ignored when absent on disk.
        feature_cols: Explicit feature column list. Falls back to
            ``__meta__["feature_cols"]``, then to a name-prefix heuristic.
    """
    # JSON is UTF-8 by spec; don't let the platform locale pick the codec.
    raw_results: dict[str, Any] = json.loads(results_path.read_text(encoding="utf-8"))
    meta = raw_results.pop("__meta__", {})
    if feature_cols is None and "feature_cols" in meta:
        feature_cols = meta["feature_cols"]
    results: dict[str, Any] = raw_results

    df = pd.read_parquet(dataset_path)
    predictions = (
        pd.read_parquet(predictions_path)
        if predictions_path is not None and predictions_path.exists()
        else None
    )

    # Baselines are computed once per trained target and shared by the TL;DR
    # and the per-target table, so the two sections cannot disagree.
    summaries = {
        target: _summarise_target(df, target, r)
        for target, r in results.items()
        if r["folds"]
    }

    lines: list[str] = [
        "# microbe-model — v0 baseline eval report",
        "",
        f"_Generated: {datetime.now(UTC).isoformat(timespec='seconds')}_",
        "",
    ]
    lines += _render_tldr(df, results, summaries, feature_cols)
    lines += _render_corpus(df, n_strains, runtime_seconds)
    lines += _render_target_distributions(df)
    lines += _render_results_table(results, summaries)
    lines += _render_fold_detail(results)
    lines += _render_correlations(df, feature_cols)
    lines += _render_family_errors(df, predictions)
    lines += _LIMITATIONS_SECTION
    lines += _NEXT_STEPS_SECTION

    out_path.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII text (—, ↔, ρ, °C). Writing with the locale
    # default codec fails on non-UTF-8 locales such as cp1252.
    out_path.write_text("\n".join(lines), encoding="utf-8")
if __name__ == "__main__":
    # Local run with the project's default artifact locations. n_strains and
    # runtime_seconds are not passed, so those corpus lines are omitted;
    # feature_cols is left to render_report's own fallback.
    artifacts = config.ARTIFACTS
    render_report(
        artifacts / "baseline_results.json",
        config.DATA / "training_table.parquet",
        artifacts / "eval_report.md",
        predictions_path=artifacts / "predictions.parquet",
    )