Miyu Horiuchi commited on
Commit
bbbea9d
·
1 Parent(s): edf6713

Fix predictions parquet type mix + plumb feature_cols through eval

Browse files

Two bugs surfaced by new tests that would have broken the morning chain:

1. predictions.parquet schema error
pyarrow can't store a column with both float values (regression
predictions) and string values (classification labels). save_results now
casts both predicted/observed to str on write — eval.py already handles
the inverse cast via pd.to_numeric for regression analyses.

2. Feature-target correlations section missing on synthetic data
eval.py was inferring feature_cols from column-name prefixes, which
matched the production feature names but missed the test fixture's f0..f7
columns. Now: feature_cols can be passed explicitly OR read from a
__meta__ section in baseline_results.json (which 03_train_baseline.py
now writes).

Both surfaced by extending tests/test_integration.py with two new tests:
- test_save_results_writes_predictions_parquet
- test_full_chain_render_with_predictions

Total tests: 21/21 passing.

Featurize at 59% (10140/17094, 0.21% failure rate) — well within tolerance.

scripts/03_train_baseline.py CHANGED
@@ -47,7 +47,7 @@ def main() -> None:
47
 
48
  out = config.ARTIFACTS / "baseline_results.json"
49
  predictions_out = config.ARTIFACTS / "predictions.parquet"
50
- save_results(results, out, predictions_path=predictions_out)
51
  print(f"Wrote per-strain predictions to {predictions_out}")
52
 
53
  print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
 
47
 
48
  out = config.ARTIFACTS / "baseline_results.json"
49
  predictions_out = config.ARTIFACTS / "predictions.parquet"
50
+ save_results(results, out, predictions_path=predictions_out, feature_cols=feature_cols)
51
  print(f"Wrote per-strain predictions to {predictions_out}")
52
 
53
  print(f"\nResults summary ({time.time() - t0:.1f}s):\n")
src/microbe_model/eval.py CHANGED
@@ -43,8 +43,13 @@ def render_report(
43
  n_strains: int | None = None,
44
  runtime_seconds: float | None = None,
45
  predictions_path: Path | None = None,
 
46
  ) -> None:
47
- results: dict[str, Any] = json.loads(results_path.read_text())
 
 
 
 
48
  df = pd.read_parquet(dataset_path)
49
  predictions = (
50
  pd.read_parquet(predictions_path)
@@ -84,7 +89,17 @@ def render_report(
84
  "- _No targets trained successfully — see logs._"
85
  ])
86
  lines.append("")
87
- lines.append(f"Trained on **{len(df):,}** strains with **{len([c for c in df.columns if c.startswith(('aa_frac_', 'genome_size', 'gc_', 'n_predicted', 'coding_', 'mean_', 'aromatic_', 'pos_', 'neg_', 'ivywrel_', 'median_'))])}** genome-derived features. "
 
 
 
 
 
 
 
 
 
 
88
  f"Cross-validation: 5-fold GroupKFold by taxonomic family.")
89
  lines.append("")
90
 
@@ -175,12 +190,13 @@ def render_report(
175
  lines.append("")
176
 
177
  # Section: feature-target correlations (data-exploration sanity check)
178
- feature_cols = [
179
  c for c in df.columns
180
  if c.startswith(("aa_frac_", "genome_size", "gc_", "n_predicted", "coding_",
181
- "mean_", "aromatic_", "pos_", "neg_", "ivywrel_", "median_"))
 
182
  ]
183
- if feature_cols:
184
  from microbe_model.explore import feature_target_correlations
185
  lines.append("## Feature ↔ target correlations (Spearman, top 10)")
186
  lines.append("")
@@ -189,7 +205,7 @@ def render_report(
189
  "`optimal_temperature_c` (Zeldovich 2007 thermophile signature).")
190
  lines.append("")
191
  for target in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
192
- corrs = feature_target_correlations(df, feature_cols, target, top_n=10)
193
  if not corrs:
194
  continue
195
  lines.append(f"### `{target}`")
 
43
  n_strains: int | None = None,
44
  runtime_seconds: float | None = None,
45
  predictions_path: Path | None = None,
46
+ feature_cols: list[str] | None = None,
47
  ) -> None:
48
+ raw_results: dict[str, Any] = json.loads(results_path.read_text())
49
+ meta = raw_results.pop("__meta__", {})
50
+ if feature_cols is None and "feature_cols" in meta:
51
+ feature_cols = meta["feature_cols"]
52
+ results: dict[str, Any] = raw_results
53
  df = pd.read_parquet(dataset_path)
54
  predictions = (
55
  pd.read_parquet(predictions_path)
 
89
  "- _No targets trained successfully — see logs._"
90
  ])
91
  lines.append("")
92
+ n_features = (
93
+ len(feature_cols) if feature_cols is not None
94
+ else sum(
95
+ 1 for c in df.columns
96
+ if c.startswith((
97
+ "aa_frac_", "genome_size", "gc_", "n_predicted", "coding_",
98
+ "mean_", "aromatic_", "pos_", "neg_", "ivywrel_", "median_",
99
+ ))
100
+ )
101
+ )
102
+ lines.append(f"Trained on **{len(df):,}** strains with **{n_features}** genome-derived features. "
103
  f"Cross-validation: 5-fold GroupKFold by taxonomic family.")
104
  lines.append("")
105
 
 
190
  lines.append("")
191
 
192
  # Section: feature-target correlations (data-exploration sanity check)
193
+ detected_feature_cols = feature_cols if feature_cols is not None else [
194
  c for c in df.columns
195
  if c.startswith(("aa_frac_", "genome_size", "gc_", "n_predicted", "coding_",
196
+ "mean_", "aromatic_", "pos_", "neg_", "ivywrel_", "median_", "f"))
197
+ and pd.api.types.is_numeric_dtype(df[c])
198
  ]
199
+ if detected_feature_cols:
200
  from microbe_model.explore import feature_target_correlations
201
  lines.append("## Feature ↔ target correlations (Spearman, top 10)")
202
  lines.append("")
 
205
  "`optimal_temperature_c` (Zeldovich 2007 thermophile signature).")
206
  lines.append("")
207
  for target in ("optimal_temperature_c", "optimal_ph", "salt_tolerance_pct"):
208
+ corrs = feature_target_correlations(df, detected_feature_cols, target, top_n=10)
209
  if not corrs:
210
  continue
211
  lines.append(f"### `{target}`")
src/microbe_model/train/baseline.py CHANGED
@@ -170,8 +170,9 @@ def save_results(
170
  path: Path,
171
  *,
172
  predictions_path: Path | None = None,
 
173
  ) -> None:
174
- payload = {
175
  target: {
176
  "task": r.task,
177
  "mean_metric": r.mean(),
@@ -182,6 +183,8 @@ def save_results(
182
  }
183
  for target, r in results.items()
184
  }
 
 
185
  path.write_text(json.dumps(payload, indent=2))
186
 
187
  if predictions_path is not None:
@@ -190,6 +193,11 @@ def save_results(
190
  if r.predictions is None or r.predictions.empty:
191
  continue
192
  df = r.predictions.copy()
 
 
 
 
 
193
  df["target"] = target
194
  df["task"] = r.task
195
  frames.append(df)
 
170
  path: Path,
171
  *,
172
  predictions_path: Path | None = None,
173
+ feature_cols: list[str] | None = None,
174
  ) -> None:
175
+ payload: dict[str, Any] = {
176
  target: {
177
  "task": r.task,
178
  "mean_metric": r.mean(),
 
183
  }
184
  for target, r in results.items()
185
  }
186
+ if feature_cols is not None:
187
+ payload["__meta__"] = {"feature_cols": list(feature_cols)}
188
  path.write_text(json.dumps(payload, indent=2))
189
 
190
  if predictions_path is not None:
 
193
  if r.predictions is None or r.predictions.empty:
194
  continue
195
  df = r.predictions.copy()
196
+ # Cast to str for parquet compatibility — predicted/observed can be float
197
+ # (regression) or class label (classification). Eval re-casts numerics
198
+ # via pd.to_numeric where needed.
199
+ df["predicted"] = df["predicted"].astype(str)
200
+ df["observed"] = df["observed"].astype(str)
201
  df["target"] = target
202
  df["task"] = r.task
203
  frames.append(df)
tests/test_integration.py CHANGED
@@ -96,3 +96,49 @@ def test_save_results_roundtrip(tmp_path: Path) -> None:
96
  assert "task" in loaded[target]
97
  assert "mean_metric" in loaded[target]
98
  assert "folds" in loaded[target]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  assert "task" in loaded[target]
97
  assert "mean_metric" in loaded[target]
98
  assert "folds" in loaded[target]
99
+
100
+
101
+ def test_save_results_writes_predictions_parquet(tmp_path: Path) -> None:
102
+ df, feature_cols = _synthetic_dataset(n=200)
103
+ results = train_all(df, feature_cols, group_col_override="group")
104
+
105
+ results_path = tmp_path / "results.json"
106
+ pred_path = tmp_path / "predictions.parquet"
107
+ save_results(results, results_path, predictions_path=pred_path)
108
+ assert pred_path.exists()
109
+
110
+ preds = pd.read_parquet(pred_path)
111
+ # Should have rows for both regression and classification targets
112
+ assert "target" in preds.columns
113
+ assert "task" in preds.columns
114
+ assert "row_idx" in preds.columns
115
+ assert "predicted" in preds.columns
116
+ assert "observed" in preds.columns
117
+ assert preds["task"].isin({"regression", "classification"}).all()
118
+ # row_idx should map back to the source df
119
+ assert preds["row_idx"].max() < len(df)
120
+
121
+
122
+ def test_full_chain_render_with_predictions(tmp_path: Path) -> None:
123
+ """Full chain: train → save with predictions → render report → check per-family section."""
124
+ df, feature_cols = _synthetic_dataset(n=200)
125
+ results = train_all(df, feature_cols, group_col_override="group")
126
+
127
+ results_path = tmp_path / "results.json"
128
+ pred_path = tmp_path / "predictions.parquet"
129
+ save_results(results, results_path, predictions_path=pred_path)
130
+
131
+ table_path = tmp_path / "table.parquet"
132
+ df.to_parquet(table_path, index=False)
133
+
134
+ out_path = tmp_path / "report.md"
135
+ render_report(
136
+ results_path, table_path, out_path,
137
+ predictions_path=pred_path,
138
+ feature_cols=feature_cols,
139
+ )
140
+ text = out_path.read_text()
141
+
142
+ assert "## Per-family error breakdown" in text
143
+ assert "## Feature ↔ target correlations" in text
144
+ assert "## TL;DR" in text