| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
| from sepsis_mcp.tabular_sweep import TabularSweepConfig, build_sweep_jobs, main |
|
|
|
|
| def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None: |
| base_hr = 78 if positive else 62 |
| if shifted: |
| base_hr += 6 |
|
|
| rows = ["HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"] |
| for hour in range(1, 9): |
| hr = base_hr + hour * (3 if positive else 1) |
| o2sat = 96 - hour if positive else 98 - (hour % 2) |
| if shifted and hour in {2, 3, 6}: |
| o2sat_value = "NaN" |
| else: |
| o2sat_value = str(o2sat) |
| sepsis_label = 1 if positive and hour >= 5 else 0 |
| rows.append(f"{hr}|{o2sat_value}|65|1|1|0|-5|{hour}|{sepsis_label}") |
|
|
| path.write_text("\n".join(rows), encoding="utf-8") |
|
|
|
|
def _write_dataset(root: Path) -> None:
    """Populate ``root`` with two small synthetic hospital cohorts.

    Creates ``root/training_setA`` with ten unshifted patients
    (``pA000.psv`` .. ``pA009.psv``) and ``root/training_setB`` with four
    shifted patients (``pB000.psv`` .. ``pB003.psv``). In both cohorts the
    even-numbered patients are written as positive.

    Args:
        root: Dataset root. Missing parents are created; ``mkdir`` raises
            ``FileExistsError`` if either cohort directory already exists.
    """
    # (directory, file-name prefix, number of patients, shifted?)
    cohorts = (
        (root / "training_setA", "pA", 10, False),
        (root / "training_setB", "pB", 4, True),
    )

    # Create both directories up front, before any patient file is written.
    for cohort_dir, _prefix, _count, _shifted in cohorts:
        cohort_dir.mkdir(parents=True)

    for cohort_dir, prefix, count, is_shifted in cohorts:
        for patient_number in range(count):
            _write_patient(
                cohort_dir / f"{prefix}{patient_number:03d}.psv",
                positive=patient_number % 2 == 0,
                shifted=is_shifted,
            )
|
|
|
|
def test_build_sweep_jobs_keeps_same_hospital_test_paths_fixed(tmp_path: Path) -> None:
    """Check the ``A_to_A`` entries of jobs built from a two-value calibration grid.

    Builds jobs for a 2 (calibration sizes) x 2 (model types) grid and
    compares the ``A_to_A`` entry of job 0 with that of job 2.

    NOTE(review): the job tuple layout is not visible in this file. Index
    ``[1]`` of an ``A_to_A`` entry is presumably the calibration paths and
    index ``[2]`` the test paths -- confirm against ``build_sweep_jobs``.
    """
    dataset_root = tmp_path / "training"
    _write_dataset(dataset_root)

    sweep_config = TabularSweepConfig(
        data_root=dataset_root,
        train_hospital="A",
        train_patients=4,
        calibration_patients_grid=(1, 2),
        test_patients=2,
        model_type_grid=("sklearn_gbdt", "logistic_regression"),
        alpha_grid=(0.2,),
        mask_strategy_grid=("none",),
        output_dir=tmp_path / "out",
    )
    jobs = build_sweep_jobs(sweep_config)

    # Jobs 0 and 2 are expected to differ in calibration size (1 vs 2).
    small_calibration_split = jobs[0][2]["A_to_A"]
    large_calibration_split = jobs[2][2]["A_to_A"]
    first_aa = [entry.name for entry in small_calibration_split[2]]
    second_aa = [entry.name for entry in large_calibration_split[2]]

    assert len(jobs) == 4
    # Index [2] must be identical across jobs regardless of calibration size.
    assert first_aa == second_aa
    assert len(jobs[0][2]["A_to_A"][1]) == 1
    assert len(jobs[2][2]["A_to_A"][1]) == 2
|
|
|
|
def test_tabular_sweep_main_writes_summary_csv(tmp_path: Path) -> None:
    """Run the sweep CLI end to end and verify the artefacts it writes.

    Invokes ``main`` with a small grid (two calibration sizes, two model
    types, one alpha, two mask strategies) and then checks the result CSVs,
    the saved JSON config, and the per-run directories under ``runs/``.
    """
    dataset_root = tmp_path / "training"
    _write_dataset(dataset_root)
    output_dir = tmp_path / "sweep-output"

    cli_args = [
        "--data-root", str(dataset_root),
        "--train-hospital", "A",
        "--test-mode", "both",
        "--train-patients", "4",
        "--calibration-patients-grid", "1,2",
        "--test-patients", "2",
        "--lookback-hours", "3",
        "--horizon-hours", "2",
        "--alpha-grid", "0.2",
        "--model-type-grid", "sklearn_gbdt,logistic_regression",
        "--mask-strategy-grid", "none,random_drop",
        "--mask-rate", "0.4",
        "--output-dir", str(output_dir),
    ]
    main(cli_args)

    results = pd.read_csv(output_dir / "sweep_results.csv")
    overall_summary = pd.read_csv(output_dir / "overall_summary.csv")
    subgroup_summary = pd.read_csv(output_dir / "subgroup_summary.csv")
    saved_config_text = (output_dir / "sweep_config.json").read_text(encoding="utf-8")
    config = json.loads(saved_config_text)

    required_columns = {
        "run_id",
        "experiment",
        "method",
        "model_type",
        "calibration_patients",
        "alpha",
        "mask_strategy",
    }
    expected_methods = {
        "standard",
        "missingness_aware",
        "label_conditional",
        "label_conditional_missingness_aware",
        "mondrian_tilted",
        "shrunk_weighted_missingness_aware",
        "weighted_missingness_aware",
    }

    # NOTE(review): 112 presumably = 8 runs x 2 experiments x 7 methods --
    # confirm against how main() assembles the result rows.
    assert len(results) == 112
    assert len(overall_summary) == 112
    assert {"low", "high"} <= set(subgroup_summary["group_label"])
    assert required_columns <= set(results.columns)
    assert set(results["method"]) == expected_methods
    assert set(results["model_type"]) == {"sklearn_gbdt", "logistic_regression"}
    assert config["calibration_patients_grid"] == [1, 2]

    # Eight entries are expected under runs/, each with its own metrics and
    # subgroup coverage files.
    run_dirs = sorted((output_dir / "runs").glob("*"))
    assert len(run_dirs) == 8
    assert all((run_dir / "metrics.json").exists() for run_dir in run_dirs)
    assert all((run_dir / "subgroup_coverage.csv").exists() for run_dir in run_dirs)
|
|