File size: 5,570 Bytes
7ba05d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbbea9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Integration test: train_all + render_report end-to-end on synthetic data.

Exercises the contiguous-class fix in the classification path, the GroupKFold split,
and the markdown rendering — without needing real BacDive or NCBI data.
"""
from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pandas as pd

from microbe_model.eval import render_report
from microbe_model.train.baseline import save_results, train_all


def _synthetic_dataset(n: int = 300, seed: int = 0) -> tuple[pd.DataFrame, list[str]]:
    """Build a small BacDive-shaped table with known signal for end-to-end tests.

    Columns, in order: eight standard-normal features ``f0``..``f7``; the
    identifier/taxonomy columns ``bacdive_id``, ``genome_accession``, ``family``
    (12 distinct values), ``genus`` (30 distinct values) and ``species`` (unique
    per row); three regression targets — ``optimal_temperature_c`` (driven by
    ``f0``), ``optimal_ph`` (driven by ``f1``) and ``salt_tolerance_pct``
    (driven by ``f2``, always non-negative); a five-way ``oxygen_requirement``
    label drawn uniformly at random; and ``group``, a copy of ``family`` used as
    the grouping key for cross-validation.

    About 30% of ``optimal_ph`` and 50% of ``salt_tolerance_pct`` are blanked to
    NaN to mimic the sparsity of real BacDive records.

    Args:
        n: Number of rows to generate.
        seed: Seed for ``np.random.default_rng``; equal seeds give equal tables.

    Returns:
        ``(table, feature_cols)`` where ``feature_cols`` names the eight features.
    """
    n_features = 8
    gen = np.random.default_rng(seed)
    feature_cols = [f"f{k}" for k in range(n_features)]
    table = pd.DataFrame(gen.normal(size=(n, n_features)), columns=feature_cols)

    row_ids = range(n)
    table["bacdive_id"] = np.arange(n)
    table["genome_accession"] = [f"GCA_{k:09d}.1" for k in row_ids]
    table["family"] = [f"family_{k % 12}" for k in row_ids]
    table["genus"] = [f"genus_{k % 30}" for k in row_ids]
    table["species"] = [f"species_{k}" for k in row_ids]

    # Noise is drawn in a fixed order (temperature, pH, salt) so a given seed is reproducible.
    temp_noise = gen.normal(scale=2, size=n)
    ph_noise = gen.normal(scale=0.1, size=n)
    salt_noise = gen.normal(scale=0.5, size=n)
    table["optimal_temperature_c"] = 30 + 5 * table["f0"] + temp_noise
    table["optimal_ph"] = 7.0 + 0.5 * table["f1"] + ph_noise
    table["salt_tolerance_pct"] = np.abs(2 + table["f2"] + salt_noise)

    # Uniformly random labels; a cross-validation fold may not see every class.
    oxygen_classes = ["aerobe", "anaerobe", "facultative", "microaerophile", "obligate aerobe"]
    table["oxygen_requirement"] = gen.choice(oxygen_classes, size=n)

    # Blank out a random subset of two targets to mirror real-world sparsity.
    ph_missing = gen.random(n) > 0.7
    salt_missing = gen.random(n) > 0.5
    table["optimal_ph"] = table["optimal_ph"].mask(ph_missing)
    table["salt_tolerance_pct"] = table["salt_tolerance_pct"].mask(salt_missing)

    table["group"] = table["family"]
    return table, feature_cols


def test_train_all_handles_classification_with_missing_classes_per_fold(tmp_path: Path) -> None:
    """train_all yields folds for every target, and the temperature model beats a mean predictor.

    NOTE(review): ``oxygen_requirement`` labels are drawn uniformly from five
    classes, so with n=200 each training fold very likely contains all of them.
    Whether the missing-class path is really exercised depends on train_all's
    fold count — consider confining one class to a single family to force it.
    """
    data, features = _synthetic_dataset(n=200)
    per_target = train_all(data, features, group_col_override="group")

    expected_targets = (
        "optimal_temperature_c",
        "optimal_ph",
        "oxygen_requirement",
        "salt_tolerance_pct",
    )
    for name in expected_targets:
        assert name in per_target
        assert per_target[name].folds, f"{name} produced no folds"

    # f0 drives temperature in the synthetic data, so the model's score
    # (presumably the mean per-fold MAE — lower is better) should undercut
    # the MAE of always predicting the column mean.
    temps = data["optimal_temperature_c"]
    mean_predictor_mae = float(np.mean(np.abs(temps - temps.mean())))
    model_score = per_target["optimal_temperature_c"].mean()
    assert model_score < mean_predictor_mae, "model worse than always-mean baseline"


def test_render_report_writes_markdown(tmp_path: Path) -> None:
    """render_report turns saved results plus the source table into a sectioned markdown report."""
    data, features = _synthetic_dataset(n=150)
    outcome = train_all(data, features, group_col_override="group")

    results_file = tmp_path / "results.json"
    table_file = tmp_path / "table.parquet"
    report_file = tmp_path / "report.md"

    save_results(outcome, results_file)
    data.to_parquet(table_file, index=False)
    render_report(results_file, table_file, report_file)

    report = report_file.read_text()
    assert report.startswith("# microbe-model")
    required_fragments = (
        "## Per-target results",
        "optimal_temperature_c",
        "oxygen_requirement",
        "## Known limitations",
        "## Next steps",
    )
    for fragment in required_fragments:
        assert fragment in report


def test_save_results_roundtrip(tmp_path: Path) -> None:
    """save_results writes JSON holding task, mean_metric and folds for every trained target."""
    data, features = _synthetic_dataset(n=100)
    outcome = train_all(data, features, group_col_override="group")

    out_file = tmp_path / "results.json"
    save_results(outcome, out_file)
    payload = json.loads(out_file.read_text())

    required_keys = ("task", "mean_metric", "folds")
    for name in outcome:
        assert name in payload
        entry = payload[name]
        for key in required_keys:
            assert key in entry


def test_save_results_writes_predictions_parquet(tmp_path: Path) -> None:
    """save_results(..., predictions_path=...) writes a long-format predictions parquet.

    The file must be non-empty and carry ``target``/``task``/``row_idx``/
    ``predicted``/``observed`` columns. It must contain rows for both regression
    and classification targets, and every ``row_idx`` must be a valid position
    in the source frame.
    """
    df, feature_cols = _synthetic_dataset(n=200)
    results = train_all(df, feature_cols, group_col_override="group")

    results_path = tmp_path / "results.json"
    pred_path = tmp_path / "predictions.parquet"
    save_results(results, results_path, predictions_path=pred_path)
    assert pred_path.exists()

    preds = pd.read_parquet(pred_path)
    assert not preds.empty, "predictions parquet is empty"
    for col in ("target", "task", "row_idx", "predicted", "observed"):
        assert col in preds.columns, f"missing column {col!r}"

    # The synthetic table has three regression targets and one classification
    # target, so both task kinds must appear. An `isin(...).all()` check alone
    # would also pass if one kind were silently dropped.
    assert set(preds["task"]) == {"regression", "classification"}

    # row_idx should map back to the source df: within [0, len(df)).
    assert preds["row_idx"].min() >= 0
    assert preds["row_idx"].max() < len(df)


def test_full_chain_render_with_predictions(tmp_path: Path) -> None:
    """End-to-end: train, persist results and predictions, render, then check the extra sections.

    render_report is given ``predictions_path`` and ``feature_cols``. The
    resulting report must contain the per-family error breakdown, the
    feature/target correlation section and the TL;DR heading.
    """
    data, features = _synthetic_dataset(n=200)
    outcome = train_all(data, features, group_col_override="group")

    results_file = tmp_path / "results.json"
    preds_file = tmp_path / "predictions.parquet"
    table_file = tmp_path / "table.parquet"
    report_file = tmp_path / "report.md"

    save_results(outcome, results_file, predictions_path=preds_file)
    data.to_parquet(table_file, index=False)
    render_report(
        results_file,
        table_file,
        report_file,
        predictions_path=preds_file,
        feature_cols=features,
    )

    report = report_file.read_text()
    expected_sections = (
        "## Per-family error breakdown",
        "## Feature ↔ target correlations",
        "## TL;DR",
    )
    for heading in expected_sections:
        assert heading in report