# misscp / tests/test_runner.py
# Anonymous
# Initial anonymous MissCP release
# 32f5a65
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
import pytest
from sepsis_mcp.cli import main
from sepsis_mcp.constants import LABEL_COLUMN
import sepsis_mcp.external_baselines as external_baselines_module
from sepsis_mcp.io import load_patient_frame
import sepsis_mcp.runner as runner_module
from sepsis_mcp.runner import RunConfig, _split_train_calibration_test
def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None:
base_hr = 78 if positive else 62
if shifted:
base_hr += 6
rows = ["HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"]
for hour in range(1, 9):
hr = base_hr + hour * (3 if positive else 1)
o2sat = 96 - hour if positive else 98 - (hour % 2)
if shifted and hour in {2, 3, 6}:
o2sat_value = "NaN"
else:
o2sat_value = str(o2sat)
sepsis_label = 1 if positive and hour >= 5 else 0
rows.append(
f"{hr}|{o2sat_value}|65|1|1|0|-5|{hour}|{sepsis_label}"
)
path.write_text("\n".join(rows), encoding="utf-8")
def test_cli_run_generates_metrics_outputs_and_plot(tmp_path: Path) -> None:
    """A ``run`` in ``both`` test mode writes metrics, subgroup coverage and a plot.

    Hospital A gets eight unshifted patients and hospital B four shifted ones.
    Both the ``A_to_A`` and ``A_to_B`` experiments must report every default
    method, in ``metrics.json`` and in ``subgroup_coverage.csv``.
    """
    data_root = tmp_path / "training"
    for hospital_dir, prefix, count, shifted in (
        (data_root / "training_setA", "pA", 8, False),
        (data_root / "training_setB", "pB", 4, True),
    ):
        hospital_dir.mkdir(parents=True)
        for index in range(count):
            _write_patient(
                hospital_dir / f"{prefix}{index:03d}.psv",
                positive=index % 2 == 0,
                shifted=shifted,
            )

    output_dir = tmp_path / "outputs"
    options = {
        "--data-root": str(data_root),
        "--train-hospital": "A",
        "--test-mode": "both",
        "--train-patients": "4",
        "--calibration-patients": "2",
        "--test-patients": "2",
        "--lookback-hours": "3",
        "--horizon-hours": "2",
        "--alpha": "0.2",
        "--output-dir": str(output_dir),
    }
    # Dicts keep insertion order, so argv matches the flag/value order above.
    main(["run", *(token for option in options.items() for token in option)])

    metrics_path, subgroup_path, plot_path = (
        output_dir / artifact
        for artifact in ("metrics.json", "subgroup_coverage.csv", "coverage_by_missingness.png")
    )
    for artifact_path in (metrics_path, subgroup_path, plot_path):
        assert artifact_path.exists()

    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(subgroup_path)

    expected_methods = {
        "standard",
        "missingness_aware",
        "mondrian_tilted",
        "label_conditional",
        "label_conditional_missingness_aware",
        "shrunk_weighted_missingness_aware",
        "weighted_missingness_aware",
    }
    experiments = metrics["experiments"]
    assert set(experiments) == {"A_to_A", "A_to_B"}
    assert set(experiments["A_to_A"]) == expected_methods
    assert "auroc" in experiments["A_to_A"]["standard"]
    assert "empirical_coverage" in experiments["A_to_B"]["missingness_aware"]
    assert set(subgroup_frame["method"]) == expected_methods
    assert set(subgroup_frame["experiment"]) == {"A_to_A", "A_to_B"}
def test_cli_run_with_random_drop_masking_encodes_mask_condition_in_outputs(tmp_path: Path) -> None:
    """A ``random_drop`` mask run records the strategy and suffixes experiment names.

    Both hospitals use unshifted patients, which contain no ``NaN`` O2Sat values
    of their own.
    """
    data_root = tmp_path / "training"
    for hospital_dir, prefix, count in (
        (data_root / "training_setA", "pA", 8),
        (data_root / "training_setB", "pB", 4),
    ):
        hospital_dir.mkdir(parents=True)
        for index in range(count):
            _write_patient(
                hospital_dir / f"{prefix}{index:03d}.psv",
                positive=index % 2 == 0,
                shifted=False,
            )

    output_dir = tmp_path / "masked-outputs"
    options = {
        "--data-root": str(data_root),
        "--train-hospital": "A",
        "--test-mode": "both",
        "--train-patients": "4",
        "--calibration-patients": "2",
        "--test-patients": "2",
        "--lookback-hours": "3",
        "--horizon-hours": "2",
        "--alpha": "0.2",
        "--mask-strategy": "random_drop",
        "--mask-rate": "0.4",
        "--output-dir": str(output_dir),
    }
    main(["run", *(token for option in options.items() for token in option)])

    metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
    assert metrics["config"]["mask_strategy"] == "random_drop"
    # Experiment keys carry the mask condition as a ``__random_drop`` suffix.
    assert set(metrics["experiments"]) == {"A_to_A__random_drop", "A_to_B__random_drop"}
def _patient_positive_count(patient_paths: list[Path]) -> int:
    """Return how many of *patient_paths* have at least one label value above 0."""
    positives = 0
    for patient_path in patient_paths:
        label_peak = load_patient_frame(patient_path)[LABEL_COLUMN].max()
        if label_peak > 0:
            positives += 1
    return positives
def test_split_train_calibration_test_shuffles_and_stratifies_patients(tmp_path: Path) -> None:
    """The patient split is shuffled, label-stratified, seeded and leak-free.

    Twelve hospital-A patients (alternating positive/negative) are split 4/4/4
    into train, calibration and test. The test checks that:

    * every split holds four patients, two of them positive;
    * no patient appears in more than one split;
    * the training split is not just the first four files;
    * ``random_state`` 0 and 1 give different training splits.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)
    for index in range(12):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=False,
        )

    def split_for_seed(random_state: int):
        # Identical configuration apart from the seed.
        config = RunConfig(
            data_root=data_root,
            train_hospital="A",
            test_mode="aa",
            train_patients=4,
            calibration_patients=4,
            test_patients=4,
            random_state=random_state,
        )
        return _split_train_calibration_test(config)["A_to_A"]

    split_zero = split_for_seed(0)
    split_one = split_for_seed(1)

    for patient_paths in split_zero:
        assert len(patient_paths) == 4
        assert _patient_positive_count(patient_paths) == 2

    # A patient showing up in two splits would leak labels between them.
    all_paths = [path for patient_paths in split_zero for path in patient_paths]
    assert len(set(all_paths)) == len(all_paths)

    first_four = [f"pA{index:03d}.psv" for index in range(4)]
    split_zero_train_names = [path.name for path in split_zero[0]]
    split_one_train_names = [path.name for path in split_one[0]]
    assert split_zero_train_names != first_four
    assert split_zero_train_names != split_one_train_names
def test_cli_run_with_coverage_gap_variable_grouping_writes_selection_metadata(tmp_path: Path) -> None:
    """``coverage_gap_variable`` grouping records its strategy and selected variable.

    Sixteen hospital-A patients are used, and the second half are shifted (NaN
    O2Sat at some hours). The run uses 4/4/4/4 train/selection/calibration/test
    patients with a logistic-regression model. The subgroup labels must not be
    exactly the ``{"low", "medium", "high"}`` set.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)
    for index in range(16):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=index >= 8,
        )

    output_dir = tmp_path / "structured-outputs"
    options = {
        "--data-root": str(data_root),
        "--train-hospital": "A",
        "--test-mode": "aa",
        "--model-type": "logistic_regression",
        **{f"--{split}-patients": "4" for split in ("train", "selection", "calibration", "test")},
        "--lookback-hours": "3",
        "--horizon-hours": "2",
        "--alpha": "0.2",
        "--missingness-grouping-strategy": "coverage_gap_variable",
        "--output-dir": str(output_dir),
    }
    main(["run", *(token for option in options.items() for token in option)])

    metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    structured = metrics["structured_grouping"]
    assert metrics["config"]["missingness_grouping_strategy"] == "coverage_gap_variable"
    assert structured["strategy"] == "coverage_gap_variable"
    assert "selected_variable" in structured
    assert set(subgroup_frame["group_label"]) != {"low", "medium", "high"}
def test_run_config_defaults_opt_in_baseline_flags_to_false(tmp_path: Path) -> None:
    """The optional baseline flags on ``RunConfig`` default to exactly ``False``."""
    defaults = RunConfig(data_root=tmp_path)
    for flag_name in ("enable_external_baselines", "enable_learned_partition"):
        assert getattr(defaults, flag_name) is False
def test_cli_run_with_learned_partition_requires_selection_split(tmp_path: Path) -> None:
    """``--enable-learned-partition`` without ``--selection-patients`` is rejected.

    The CLI must raise ``ValueError`` with a message mentioning "selection split".
    """
    data_root = tmp_path / "training"
    for hospital_dir, prefix, count, shifted in (
        (data_root / "training_setA", "pA", 8, False),
        (data_root / "training_setB", "pB", 4, True),
    ):
        hospital_dir.mkdir(parents=True)
        for index in range(count):
            _write_patient(
                hospital_dir / f"{prefix}{index:03d}.psv",
                positive=index % 2 == 0,
                shifted=shifted,
            )

    argv = [
        "run",
        "--data-root", str(data_root),
        "--train-hospital", "A",
        "--test-mode", "both",
        "--train-patients", "4",
        "--calibration-patients", "2",
        "--test-patients", "2",
        "--lookback-hours", "3",
        "--horizon-hours", "2",
        "--alpha", "0.2",
        "--enable-learned-partition",
        "--output-dir", str(tmp_path / "outputs"),
    ]
    with pytest.raises(ValueError, match="selection split"):
        main(argv)
def test_structured_grouping_uses_explicit_group_keyword_arguments(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Structured grouping passes group ids to the conformal classifier by keyword.

    ``fit`` and ``predict_sets`` on ``MissingnessAwareConformalClassifier`` are
    replaced with spies that record their arguments and then delegate to the
    real methods. After a ``coverage_gap_variable`` run:

    * some ``fit`` call must pass ``calibration_group_ids``, and at least one
      such call must leave ``calibration_missing_rates`` as ``None``;
    * some ``predict_sets`` call must pass ``test_group_ids``, and at least one
      such call must leave ``missing_rates`` as ``None``.
    """
    data_root = tmp_path / "training"
    hospital_dir = data_root / "training_setA"
    hospital_dir.mkdir(parents=True)
    for index in range(16):
        _write_patient(
            hospital_dir / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=index >= 8,
        )

    classifier_cls = runner_module.MissingnessAwareConformalClassifier
    real_fit = classifier_cls.fit
    real_predict_sets = classifier_cls.predict_sets
    fit_calls: list[dict[str, object]] = []
    predict_calls: list[dict[str, object]] = []

    # The spy parameter names match the real methods, so keyword calls from
    # the runner still bind correctly.
    def fit_spy(self, calibration_labels, calibration_positive_probabilities, calibration_missing_rates=None, **kwargs):
        fit_calls.append({"calibration_missing_rates": calibration_missing_rates, "kwargs": dict(kwargs)})
        return real_fit(self, calibration_labels, calibration_positive_probabilities, calibration_missing_rates, **kwargs)

    def predict_spy(self, positive_probabilities, missing_rates=None, **kwargs):
        predict_calls.append({"missing_rates": missing_rates, "kwargs": dict(kwargs)})
        return real_predict_sets(self, positive_probabilities, missing_rates, **kwargs)

    monkeypatch.setattr(classifier_cls, "fit", fit_spy)
    monkeypatch.setattr(classifier_cls, "predict_sets", predict_spy)

    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "aa",
            "--model-type", "logistic_regression",
            "--train-patients", "4",
            "--selection-patients", "4",
            "--calibration-patients", "4",
            "--test-patients", "4",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--missingness-grouping-strategy", "coverage_gap_variable",
            "--output-dir", str(tmp_path / "structured-outputs"),
        ]
    )

    grouped_fits = [call for call in fit_calls if "calibration_group_ids" in call["kwargs"]]
    grouped_predicts = [call for call in predict_calls if "test_group_ids" in call["kwargs"]]
    assert grouped_fits
    assert any(call["calibration_missing_rates"] is None for call in grouped_fits)
    assert grouped_predicts
    assert any(call["missing_rates"] is None for call in grouped_predicts)
def test_cli_run_adds_learned_partition_only_when_enabled(tmp_path: Path) -> None:
    """``learned_partition`` is reported if and only if it is enabled.

    The same run is executed twice: sixteen hospital-A patients, 4/4/4/4
    train/selection/calibration/test, and a logistic-regression model. With
    ``--enable-learned-partition``, the method must appear in the ``A_to_A``
    metrics and in the subgroup-coverage CSV. Without the flag, it must appear
    in neither.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)
    for index in range(16):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=index >= 8,
        )

    def run_and_collect_methods(output_dir: Path, *extra_flags: str) -> tuple[set[str], set[str]]:
        """Run the CLI into *output_dir*.

        Returns the ``A_to_A`` metric methods and the subgroup-coverage methods.
        """
        main(
            [
                "run",
                "--data-root", str(data_root),
                "--train-hospital", "A",
                "--test-mode", "aa",
                "--model-type", "logistic_regression",
                "--train-patients", "4",
                "--selection-patients", "4",
                "--calibration-patients", "4",
                "--test-patients", "4",
                "--lookback-hours", "3",
                "--horizon-hours", "2",
                "--alpha", "0.2",
                *extra_flags,
                "--output-dir", str(output_dir),
            ]
        )
        metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
        subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
        return set(metrics["experiments"]["A_to_A"]), set(subgroup_frame["method"])

    enabled_metric_methods, enabled_subgroup_methods = run_and_collect_methods(
        tmp_path / "learned-partition-outputs", "--enable-learned-partition"
    )
    assert "learned_partition" in enabled_metric_methods
    assert "learned_partition" in enabled_subgroup_methods

    # Same configuration without the flag: the method must not leak into outputs.
    disabled_metric_methods, disabled_subgroup_methods = run_and_collect_methods(
        tmp_path / "learned-partition-disabled-outputs"
    )
    assert "learned_partition" not in disabled_metric_methods
    assert "learned_partition" not in disabled_subgroup_methods
def test_cli_run_adds_external_gibbs_methods_only_when_enabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``--enable-external-baselines`` adds ``gibbs_general`` and ``gibbs_missingness``.

    ``importlib.import_module`` in ``external_baselines`` is patched so that
    importing ``conditionalconformal`` returns an in-memory fake. The test
    therefore does not depend on the real package being importable. Every other
    import falls through to the original function.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_b = data_root / "training_setB"
    training_a.mkdir(parents=True)
    training_b.mkdir(parents=True)
    for index in range(8):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=False,
        )
    for index in range(4):
        _write_patient(
            training_b / f"pB{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=True,
        )

    class FakeCondConf:
        """Minimal stand-in for ``conditionalconformal.CondConf``.

        The upper-case parameter names (``Phi_fn``, ``X``, ``Y``) are kept
        unchanged in case the code under test passes them by keyword. That is
        why they carry pep8-naming N803 suppressions.
        """

        def __init__(self, score_fn, Phi_fn, quantile_fn, infinite_params):  # noqa: N803
            self._score_fn = score_fn
            self._phi_fn = Phi_fn

        def setup_problem(self, X, Y, alpha=None):  # noqa: N803
            self._alpha = alpha

        def predict(self, *args, **kwargs):
            # Use the second positional argument as the feature batch when there
            # is one, otherwise the first; every row gets the constant 0.2.
            features = args[1] if len(args) > 1 else args[0]
            return [0.2] * len(features)

    original_import_module = external_baselines_module.importlib.import_module

    def fake_import_module(name: str, package: str | None = None):
        if name == "conditionalconformal":
            return type("FakeModule", (), {"CondConf": FakeCondConf})
        return original_import_module(name, package)

    monkeypatch.setattr(external_baselines_module.importlib, "import_module", fake_import_module)

    output_dir = tmp_path / "external-baselines-outputs"
    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "both",
            "--train-patients", "4",
            "--calibration-patients", "2",
            "--test-patients", "2",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--enable-external-baselines",
            "--output-dir", str(output_dir),
        ]
    )

    metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    assert {"gibbs_general", "gibbs_missingness"} <= set(metrics["experiments"]["A_to_A"])
    assert {"gibbs_general", "gibbs_missingness"} <= set(subgroup_frame["method"])