misscp/tests/test_tabular_sweep.py
Anonymous
Initial anonymous MissCP release
32f5a65
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
from sepsis_mcp.tabular_sweep import TabularSweepConfig, build_sweep_jobs, main
def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None:
    """Write a synthetic eight-hour patient record as a pipe-separated file.

    The file has a header line followed by one row per hour for hours 1..8,
    joined with newlines and written without a trailing newline.

    Args:
        path: Destination file; it is overwritten if it already exists.
        positive: When true, HR starts at 78 and climbs by 3 per hour, O2Sat
            falls by one per hour from 96, and SepsisLabel is 1 from hour 5
            onwards. Otherwise HR starts at 62 and climbs by 1 per hour,
            O2Sat alternates between 97 and 98, and SepsisLabel stays 0.
        shifted: When true, the HR baseline is raised by 6 and the O2Sat
            value is written as the literal "NaN" at hours 2, 3 and 6.
    """
    header = "HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"
    starting_hr = (78 if positive else 62) + (6 if shifted else 0)
    hr_slope = 3 if positive else 1
    # Hours whose O2Sat reading is replaced by "NaN"; empty unless shifted.
    masked_hours = {2, 3, 6} if shifted else set()

    def _row(hour: int) -> str:
        """Render the record line for a single hour."""
        heart_rate = starting_hr + hour * hr_slope
        if hour in masked_hours:
            saturation = "NaN"
        elif positive:
            saturation = str(96 - hour)
        else:
            saturation = str(98 - (hour % 2))
        label = 1 if positive and hour >= 5 else 0
        return f"{heart_rate}|{saturation}|65|1|1|0|-5|{hour}|{label}"

    lines = [header, *(_row(hour) for hour in range(1, 9))]
    path.write_text("\n".join(lines), encoding="utf-8")
def _write_dataset(root: Path) -> None:
    """Populate ``root`` with two small synthetic training directories.

    ``training_setA`` receives ten unshifted patients (``pA000.psv`` ..
    ``pA009.psv``) and ``training_setB`` receives four shifted patients
    (``pB000.psv`` .. ``pB003.psv``). In both sets the even-numbered
    patients are written as positive and the odd-numbered ones as negative.

    Args:
        root: Directory under which both sets are created. Missing parents
            are created; an already existing set directory raises, since
            ``mkdir`` is called without ``exist_ok``.
    """
    # (directory name, file prefix, number of patients, shifted flag)
    layout = (
        ("training_setA", "pA", 10, False),
        ("training_setB", "pB", 4, True),
    )
    directories = [root / folder for folder, *_ in layout]

    # Create every directory up front, before any patient file is written.
    for directory in directories:
        directory.mkdir(parents=True)

    for directory, (_, prefix, count, is_shifted) in zip(directories, layout):
        for patient_number in range(count):
            _write_patient(
                directory / f"{prefix}{patient_number:03d}.psv",
                positive=patient_number % 2 == 0,
                shifted=is_shifted,
            )
def test_build_sweep_jobs_keeps_same_hospital_test_paths_fixed(tmp_path: Path) -> None:
    """Check that the A_to_A test paths do not change between sweep jobs.

    Builds a sweep over two calibration sizes and two model types (one alpha,
    one mask strategy) and compares the first and third jobs.
    """
    dataset_dir = tmp_path / "training"
    _write_dataset(dataset_dir)

    sweep_config = TabularSweepConfig(
        data_root=dataset_dir,
        train_hospital="A",
        train_patients=4,
        calibration_patients_grid=(1, 2),
        test_patients=2,
        model_type_grid=("sklearn_gbdt", "logistic_regression"),
        alpha_grid=(0.2,),
        mask_strategy_grid=("none",),
        output_dir=tmp_path / "out",
    )
    sweep_jobs = build_sweep_jobs(sweep_config)

    # NOTE(review): element 2 of each job is indexed by experiment name, and
    # within an experiment entry index 1 looks like the calibration paths and
    # index 2 the test paths — confirm against build_sweep_jobs.
    same_hospital_first = sweep_jobs[0][2]["A_to_A"]
    test_names_first = [patient_file.name for patient_file in same_hospital_first[2]]
    same_hospital_third = sweep_jobs[2][2]["A_to_A"]
    test_names_third = [patient_file.name for patient_file in same_hospital_third[2]]

    # 2 calibration sizes x 2 model types x 1 alpha x 1 mask strategy.
    assert len(sweep_jobs) == 4
    assert test_names_first == test_names_third
    # Jobs 0 and 2 are expected to differ in calibration size (1 vs. 2).
    assert len(same_hospital_first[1]) == 1
    assert len(same_hospital_third[1]) == 2
def test_tabular_sweep_main_writes_summary_csv(tmp_path: Path) -> None:
    """Run the sweep CLI end to end and verify the artefacts it writes.

    Checks the row counts and columns of the result and summary CSV files,
    the saved sweep configuration, and the per-run output directories.
    """
    data_root = tmp_path / "training"
    _write_dataset(data_root)
    output_dir = tmp_path / "sweep-output"

    # Insertion order is preserved, so flattening the mapping yields the
    # command line as "--flag value" pairs in the order written here.
    cli_options = {
        "--data-root": str(data_root),
        "--train-hospital": "A",
        "--test-mode": "both",
        "--train-patients": "4",
        "--calibration-patients-grid": "1,2",
        "--test-patients": "2",
        "--lookback-hours": "3",
        "--horizon-hours": "2",
        "--alpha-grid": "0.2",
        "--model-type-grid": "sklearn_gbdt,logistic_regression",
        "--mask-strategy-grid": "none,random_drop",
        "--mask-rate": "0.4",
        "--output-dir": str(output_dir),
    }
    argv = [token for option in cli_options.items() for token in option]
    main(argv)

    results = pd.read_csv(output_dir / "sweep_results.csv")
    overall_summary = pd.read_csv(output_dir / "overall_summary.csv")
    subgroup_summary = pd.read_csv(output_dir / "subgroup_summary.csv")
    config_text = (output_dir / "sweep_config.json").read_text(encoding="utf-8")
    config = json.loads(config_text)

    # NOTE(review): 112 presumably equals 8 runs x 2 experiments x 7 methods
    # — confirm against the sweep implementation.
    expected_rows = 112
    assert len(results) == expected_rows
    assert len(overall_summary) == expected_rows

    assert set(subgroup_summary["group_label"]).issuperset({"low", "high"})

    required_columns = {
        "run_id",
        "experiment",
        "method",
        "model_type",
        "calibration_patients",
        "alpha",
        "mask_strategy",
    }
    assert required_columns.issubset(set(results.columns))

    expected_methods = {
        "standard",
        "missingness_aware",
        "label_conditional",
        "label_conditional_missingness_aware",
        "mondrian_tilted",
        "shrunk_weighted_missingness_aware",
        "weighted_missingness_aware",
    }
    assert set(results["method"]) == expected_methods
    assert set(results["model_type"]) == {"sklearn_gbdt", "logistic_regression"}
    assert config["calibration_patients_grid"] == [1, 2]

    # 2 calibration sizes x 2 model types x 1 alpha x 2 mask strategies.
    run_dirs = sorted((output_dir / "runs").glob("*"))
    assert len(run_dirs) == 8
    for artifact_name in ("metrics.json", "subgroup_coverage.csv"):
        assert all((run_dir / artifact_name).exists() for run_dir in run_dirs)