Spaces:
Running
Fix classification fold bug + add end-to-end integration tests
Browse filesBug: XGBClassifier requires contiguous class labels 0..k-1, but global LabelEncoder
followed by GroupKFold can produce train folds with non-contiguous subsets (e.g.
classes {0,1,3,4,5} when class 2 happened to be all in the test fold). xgboost
raised ValueError and the entire classification target failed.
Fix in src/microbe_model/train/baseline.py:
- Re-encode labels per fold (LabelEncoder fit on train fold only)
- Drop test samples whose class never appeared in train (correct behavior — model
cannot be evaluated on a class it has never seen)
- Skip folds where train fold has fewer than 2 distinct classes
Caught by a partial-data smoke run on the live featurize output. Now caught
automatically by tests/test_integration.py:
- test_train_all_handles_classification_with_missing_classes_per_fold — exercises
the bug case with synthetic data containing 5 oxygen classes across 12 families
- test_render_report_writes_markdown — full train→save→render path
- test_save_results_roundtrip — JSON serialization
All three pass. Total tests: 15/15.
Featurize at 11% (1799/17094 strains, 0 failures across the version-fallback path,
3 ascertainment failures elsewhere — all within tolerance).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- src/microbe_model/train/baseline.py +23 -12
- tests/test_integration.py +98 -0
|
@@ -60,10 +60,9 @@ def train_target(
|
|
| 60 |
return TargetResult(target=target, task=task)
|
| 61 |
|
| 62 |
if task == "classification":
|
| 63 |
-
|
| 64 |
-
y_enc = encoder.fit_transform(y.astype(str))
|
| 65 |
else:
|
| 66 |
-
|
| 67 |
|
| 68 |
n_unique_groups = groups.nunique()
|
| 69 |
splits = min(n_splits, max(2, n_unique_groups))
|
|
@@ -73,11 +72,21 @@ def train_target(
|
|
| 73 |
importance_acc = np.zeros(len(feature_cols), dtype=float)
|
| 74 |
fold_count = 0
|
| 75 |
|
| 76 |
-
|
|
|
|
| 77 |
if task == "classification":
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
| 80 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
model = xgb.XGBClassifier(
|
| 82 |
n_estimators=300,
|
| 83 |
max_depth=5,
|
|
@@ -86,10 +95,11 @@ def train_target(
|
|
| 86 |
n_jobs=-1,
|
| 87 |
eval_metric="mlogloss",
|
| 88 |
)
|
| 89 |
-
model.fit(X.iloc[tr_idx],
|
| 90 |
-
preds = model.predict(X.iloc[te_idx])
|
| 91 |
-
score = f1_score(
|
| 92 |
metric = "f1_macro"
|
|
|
|
| 93 |
else:
|
| 94 |
model = xgb.XGBRegressor(
|
| 95 |
n_estimators=500,
|
|
@@ -98,10 +108,11 @@ def train_target(
|
|
| 98 |
tree_method="hist",
|
| 99 |
n_jobs=-1,
|
| 100 |
)
|
| 101 |
-
model.fit(X.iloc[tr_idx],
|
| 102 |
preds = model.predict(X.iloc[te_idx])
|
| 103 |
-
score = mean_absolute_error(
|
| 104 |
metric = "mae"
|
|
|
|
| 105 |
|
| 106 |
result.folds.append(FoldResult(
|
| 107 |
target=target,
|
|
@@ -109,7 +120,7 @@ def train_target(
|
|
| 109 |
metric_name=metric,
|
| 110 |
value=float(score),
|
| 111 |
n_train=int(len(tr_idx)),
|
| 112 |
-
n_test=
|
| 113 |
))
|
| 114 |
importance_acc += model.feature_importances_
|
| 115 |
fold_count += 1
|
|
|
|
| 60 |
return TargetResult(target=target, task=task)
|
| 61 |
|
| 62 |
if task == "classification":
|
| 63 |
+
y_str = y.astype(str).to_numpy()
|
|
|
|
| 64 |
else:
|
| 65 |
+
y_arr = y.to_numpy(dtype=float)
|
| 66 |
|
| 67 |
n_unique_groups = groups.nunique()
|
| 68 |
splits = min(n_splits, max(2, n_unique_groups))
|
|
|
|
| 72 |
importance_acc = np.zeros(len(feature_cols), dtype=float)
|
| 73 |
fold_count = 0
|
| 74 |
|
| 75 |
+
split_iter = kfold.split(X, y_str if task == "classification" else y_arr, groups)
|
| 76 |
+
for tr_idx, te_idx in split_iter:
|
| 77 |
if task == "classification":
|
| 78 |
+
# Per-fold encoding: ensures contiguous 0..k-1 labels for xgboost.
|
| 79 |
+
# Test samples whose class never appears in train are dropped from eval.
|
| 80 |
+
fold_encoder = LabelEncoder()
|
| 81 |
+
y_tr = fold_encoder.fit_transform(y_str[tr_idx])
|
| 82 |
+
if len(fold_encoder.classes_) < 2:
|
| 83 |
continue
|
| 84 |
+
known = set(fold_encoder.classes_)
|
| 85 |
+
te_mask = np.array([c in known for c in y_str[te_idx]])
|
| 86 |
+
if te_mask.sum() == 0:
|
| 87 |
+
continue
|
| 88 |
+
y_te = fold_encoder.transform(y_str[te_idx][te_mask])
|
| 89 |
+
|
| 90 |
model = xgb.XGBClassifier(
|
| 91 |
n_estimators=300,
|
| 92 |
max_depth=5,
|
|
|
|
| 95 |
n_jobs=-1,
|
| 96 |
eval_metric="mlogloss",
|
| 97 |
)
|
| 98 |
+
model.fit(X.iloc[tr_idx], y_tr)
|
| 99 |
+
preds = model.predict(X.iloc[te_idx][te_mask])
|
| 100 |
+
score = f1_score(y_te, preds, average="macro")
|
| 101 |
metric = "f1_macro"
|
| 102 |
+
n_test = int(te_mask.sum())
|
| 103 |
else:
|
| 104 |
model = xgb.XGBRegressor(
|
| 105 |
n_estimators=500,
|
|
|
|
| 108 |
tree_method="hist",
|
| 109 |
n_jobs=-1,
|
| 110 |
)
|
| 111 |
+
model.fit(X.iloc[tr_idx], y_arr[tr_idx])
|
| 112 |
preds = model.predict(X.iloc[te_idx])
|
| 113 |
+
score = mean_absolute_error(y_arr[te_idx], preds)
|
| 114 |
metric = "mae"
|
| 115 |
+
n_test = int(len(te_idx))
|
| 116 |
|
| 117 |
result.folds.append(FoldResult(
|
| 118 |
target=target,
|
|
|
|
| 120 |
metric_name=metric,
|
| 121 |
value=float(score),
|
| 122 |
n_train=int(len(tr_idx)),
|
| 123 |
+
n_test=n_test,
|
| 124 |
))
|
| 125 |
importance_acc += model.feature_importances_
|
| 126 |
fold_count += 1
|
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Integration test: train_all + render_report end-to-end on synthetic data.
|
| 2 |
+
|
| 3 |
+
Exercises the contiguous-class fix in the classification path, the GroupKFold split,
|
| 4 |
+
and the markdown rendering — without needing real BacDive or NCBI data.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
|
| 14 |
+
from microbe_model.eval import render_report
|
| 15 |
+
from microbe_model.train.baseline import save_results, train_all
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _synthetic_dataset(n: int = 300, seed: int = 0) -> tuple[pd.DataFrame, list[str]]:
|
| 19 |
+
rng = np.random.default_rng(seed)
|
| 20 |
+
feature_cols = [f"f{i}" for i in range(8)]
|
| 21 |
+
df = pd.DataFrame(rng.normal(size=(n, 8)), columns=feature_cols)
|
| 22 |
+
df["bacdive_id"] = np.arange(n)
|
| 23 |
+
df["genome_accession"] = [f"GCA_{i:09d}.1" for i in range(n)]
|
| 24 |
+
df["family"] = [f"family_{i % 12}" for i in range(n)]
|
| 25 |
+
df["genus"] = [f"genus_{i % 30}" for i in range(n)]
|
| 26 |
+
df["species"] = [f"species_{i}" for i in range(n)]
|
| 27 |
+
|
| 28 |
+
# Regression target with real signal in f0 + noise
|
| 29 |
+
df["optimal_temperature_c"] = 30 + 5 * df["f0"] + rng.normal(scale=2, size=n)
|
| 30 |
+
df["optimal_ph"] = 7.0 + 0.5 * df["f1"] + rng.normal(scale=0.1, size=n)
|
| 31 |
+
df["salt_tolerance_pct"] = np.abs(2 + df["f2"] + rng.normal(scale=0.5, size=n))
|
| 32 |
+
|
| 33 |
+
# Classification target — sometimes only some classes appear in a fold
|
| 34 |
+
classes = ["aerobe", "anaerobe", "facultative", "microaerophile", "obligate aerobe"]
|
| 35 |
+
df["oxygen_requirement"] = rng.choice(classes, size=n)
|
| 36 |
+
|
| 37 |
+
# Inject some NaNs to mirror real BacDive sparsity
|
| 38 |
+
nan_mask = rng.random(n) > 0.7
|
| 39 |
+
df.loc[nan_mask, "optimal_ph"] = np.nan
|
| 40 |
+
nan_mask = rng.random(n) > 0.5
|
| 41 |
+
df.loc[nan_mask, "salt_tolerance_pct"] = np.nan
|
| 42 |
+
|
| 43 |
+
df["group"] = df["family"]
|
| 44 |
+
return df, feature_cols
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_train_all_handles_classification_with_missing_classes_per_fold(tmp_path: Path) -> None:
|
| 48 |
+
df, feature_cols = _synthetic_dataset(n=200)
|
| 49 |
+
results = train_all(df, feature_cols, group_col_override="group")
|
| 50 |
+
|
| 51 |
+
# All four targets should produce at least one fold of results
|
| 52 |
+
for target in ("optimal_temperature_c", "optimal_ph", "oxygen_requirement", "salt_tolerance_pct"):
|
| 53 |
+
assert target in results
|
| 54 |
+
assert results[target].folds, f"{target} produced no folds"
|
| 55 |
+
|
| 56 |
+
# Regression should beat the always-mean baseline since f0 carries real signal
|
| 57 |
+
temp_result = results["optimal_temperature_c"]
|
| 58 |
+
baseline_mae = float(np.mean(np.abs(
|
| 59 |
+
df["optimal_temperature_c"] - df["optimal_temperature_c"].mean()
|
| 60 |
+
)))
|
| 61 |
+
assert temp_result.mean() < baseline_mae, "model worse than always-mean baseline"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_render_report_writes_markdown(tmp_path: Path) -> None:
|
| 65 |
+
df, feature_cols = _synthetic_dataset(n=150)
|
| 66 |
+
results = train_all(df, feature_cols, group_col_override="group")
|
| 67 |
+
|
| 68 |
+
results_path = tmp_path / "results.json"
|
| 69 |
+
save_results(results, results_path)
|
| 70 |
+
|
| 71 |
+
table_path = tmp_path / "table.parquet"
|
| 72 |
+
df.to_parquet(table_path, index=False)
|
| 73 |
+
|
| 74 |
+
out_path = tmp_path / "report.md"
|
| 75 |
+
render_report(results_path, table_path, out_path)
|
| 76 |
+
text = out_path.read_text()
|
| 77 |
+
|
| 78 |
+
assert text.startswith("# microbe-model")
|
| 79 |
+
assert "## Per-target results" in text
|
| 80 |
+
assert "optimal_temperature_c" in text
|
| 81 |
+
assert "oxygen_requirement" in text
|
| 82 |
+
assert "## Known limitations" in text
|
| 83 |
+
assert "## Next steps" in text
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def test_save_results_roundtrip(tmp_path: Path) -> None:
|
| 87 |
+
df, feature_cols = _synthetic_dataset(n=100)
|
| 88 |
+
results = train_all(df, feature_cols, group_col_override="group")
|
| 89 |
+
|
| 90 |
+
path = tmp_path / "results.json"
|
| 91 |
+
save_results(results, path)
|
| 92 |
+
loaded = json.loads(path.read_text())
|
| 93 |
+
|
| 94 |
+
for target in results:
|
| 95 |
+
assert target in loaded
|
| 96 |
+
assert "task" in loaded[target]
|
| 97 |
+
assert "mean_metric" in loaded[target]
|
| 98 |
+
assert "folds" in loaded[target]
|