| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import pandas as pd |
| import pytest |
|
|
| from sepsis_mcp.cli import main |
| from sepsis_mcp.constants import LABEL_COLUMN |
| import sepsis_mcp.external_baselines as external_baselines_module |
| from sepsis_mcp.io import load_patient_frame |
| import sepsis_mcp.runner as runner_module |
| from sepsis_mcp.runner import RunConfig, _split_train_calibration_test |
|
|
|
|
def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None:
    """Write a synthetic eight-hour patient record as a pipe-separated ``.psv`` file.

    Args:
        path: Destination file; any existing content is overwritten.
        positive: If true, heart rate is ``78 + 3 * hour``, O2Sat is
            ``96 - hour`` and ``SepsisLabel`` is 1 from hour 5 onward.
            Otherwise heart rate is ``62 + hour``, O2Sat alternates between
            97 (odd hours) and 98 (even hours), and the label stays 0.
        shifted: If true, the heart-rate baseline is raised by 6 and O2Sat is
            written as ``NaN`` at hours 2, 3 and 6, giving this patient a
            different distribution and extra missing values.

    The file contains a header row followed by one row per ICU hour
    (``ICULOS`` 1 through 8), with no trailing newline.
    """
    hr_baseline = (78 if positive else 62) + (6 if shifted else 0)
    hr_slope = 3 if positive else 1
    o2_missing_hours = {2, 3, 6} if shifted else set()

    lines = ["HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"]
    for hour in range(1, 9):
        heart_rate = hr_baseline + hr_slope * hour
        if hour in o2_missing_hours:
            o2_field = "NaN"
        else:
            o2_field = str(96 - hour if positive else 98 - hour % 2)
        label = 1 if positive and hour >= 5 else 0
        # Age, Gender, Unit1, Unit2 and HospAdmTime are fixed for every row.
        lines.append(f"{heart_rate}|{o2_field}|65|1|1|0|-5|{hour}|{label}")

    path.write_text("\n".join(lines), encoding="utf-8")
|
|
|
|
def test_cli_run_generates_metrics_outputs_and_plot(tmp_path: Path) -> None:
    """Run the CLI end to end on two synthetic hospitals and check every artifact.

    Hospital A contributes 8 unshifted patients and hospital B 4 shifted ones.
    Training on A with ``--test-mode both`` must yield ``A_to_A`` and ``A_to_B``
    experiments covering exactly the seven methods listed below, plus a
    subgroup-coverage CSV and a coverage-by-missingness plot.
    """
    data_root = tmp_path / "training"
    hospital_dirs = {
        "A": data_root / "training_setA",
        "B": data_root / "training_setB",
    }
    for hospital_dir in hospital_dirs.values():
        hospital_dir.mkdir(parents=True)

    cohorts = (("A", 8, False), ("B", 4, True))  # (hospital, patient count, shifted)
    for hospital, patient_count, shifted in cohorts:
        for idx in range(patient_count):
            _write_patient(
                hospital_dirs[hospital] / f"p{hospital}{idx:03d}.psv",
                positive=idx % 2 == 0,
                shifted=shifted,
            )

    output_dir = tmp_path / "outputs"
    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "both",
            "--train-patients", "4",
            "--calibration-patients", "2",
            "--test-patients", "2",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--output-dir", str(output_dir),
        ]
    )

    metrics_path = output_dir / "metrics.json"
    subgroup_path = output_dir / "subgroup_coverage.csv"
    plot_path = output_dir / "coverage_by_missingness.png"
    for artifact in (metrics_path, subgroup_path, plot_path):
        assert artifact.exists(), artifact

    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(subgroup_path)
    experiments = metrics["experiments"]

    expected_experiments = {"A_to_A", "A_to_B"}
    # Methods reported when no opt-in flags are passed.
    expected_methods = {
        "standard",
        "missingness_aware",
        "mondrian_tilted",
        "label_conditional",
        "label_conditional_missingness_aware",
        "shrunk_weighted_missingness_aware",
        "weighted_missingness_aware",
    }

    assert set(experiments) == expected_experiments
    assert set(experiments["A_to_A"]) == expected_methods
    assert "auroc" in experiments["A_to_A"]["standard"]
    assert "empirical_coverage" in experiments["A_to_B"]["missingness_aware"]
    assert set(subgroup_frame["method"]) == expected_methods
    assert set(subgroup_frame["experiment"]) == expected_experiments
|
|
|
|
def test_cli_run_with_random_drop_masking_encodes_mask_condition_in_outputs(tmp_path: Path) -> None:
    """``--mask-strategy random_drop`` is echoed in the config and tags every experiment key.

    Both hospitals hold unshifted patients. The resulting experiment names are
    expected to carry a ``__random_drop`` suffix.
    """
    data_root = tmp_path / "training"
    for hospital, patient_count in (("A", 8), ("B", 4)):
        hospital_dir = data_root / f"training_set{hospital}"
        hospital_dir.mkdir(parents=True)
        for idx in range(patient_count):
            _write_patient(
                hospital_dir / f"p{hospital}{idx:03d}.psv",
                positive=idx % 2 == 0,
                shifted=False,
            )

    output_dir = tmp_path / "masked-outputs"
    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "both",
            "--train-patients", "4",
            "--calibration-patients", "2",
            "--test-patients", "2",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--mask-strategy", "random_drop",
            "--mask-rate", "0.4",
            "--output-dir", str(output_dir),
        ]
    )

    metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
    assert metrics["config"]["mask_strategy"] == "random_drop"
    assert set(metrics["experiments"]) == {"A_to_A__random_drop", "A_to_B__random_drop"}
|
|
|
|
def _patient_positive_count(patient_paths: list[Path]) -> int:
    """Count the patients in *patient_paths* with at least one positive label.

    A patient counts once if the maximum of its ``LABEL_COLUMN`` is above zero.
    """
    positives = 0
    for patient_path in patient_paths:
        labels = load_patient_frame(patient_path)[LABEL_COLUMN]
        if labels.max() > 0:
            positives += 1
    return positives
|
|
|
|
def test_split_train_calibration_test_shuffles_and_stratifies_patients(tmp_path: Path) -> None:
    """Patient splits are sized, label-stratified, disjoint, shuffled and seed-dependent.

    Twelve hospital-A patients (even indices positive) are split 4/4/4 into
    train/calibration/test under ``random_state`` 0 and 1. The test checks:

    - each split holds 4 patients, 2 of them positive;
    - no patient appears in more than one split;
    - the seed-0 training split is not simply the first four files;
    - the seed-0 and seed-1 training splits differ.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)

    for index in range(12):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=False,
        )

    def split_for_seed(seed: int):
        """Return the ``A_to_A`` splits produced with ``random_state=seed``."""
        config = RunConfig(
            data_root=data_root,
            train_hospital="A",
            test_mode="aa",
            train_patients=4,
            calibration_patients=4,
            test_patients=4,
            random_state=seed,
        )
        return _split_train_calibration_test(config)["A_to_A"]

    split_zero = split_for_seed(0)
    split_one = split_for_seed(1)

    for patient_paths in split_zero:
        assert len(patient_paths) == 4
        # Half of the pool is positive, so stratification puts 2 positives in each split.
        assert _patient_positive_count(patient_paths) == 2

    # Each patient must land in at most one split. Sizes and label counts alone
    # would not catch a split that reuses patients, i.e. leakage between splits.
    all_names = [path.name for patient_paths in split_zero for path in patient_paths]
    assert len(all_names) == len(set(all_names))

    first_four = [f"pA{index:03d}.psv" for index in range(4)]
    split_zero_train_names = [path.name for path in split_zero[0]]
    split_one_train_names = [path.name for path in split_one[0]]
    assert split_zero_train_names != first_four
    assert split_zero_train_names != split_one_train_names
|
|
|
|
def test_cli_run_with_coverage_gap_variable_grouping_writes_selection_metadata(tmp_path: Path) -> None:
    """``coverage_gap_variable`` grouping records its choice in ``metrics.json``.

    With a selection split available, the structured grouping must report its
    strategy and a ``selected_variable``. The subgroup labels must also differ
    from the plain ``{"low", "medium", "high"}`` label set.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)

    # Patients 0-7 are unshifted and 8-15 shifted, so the cohort mixes complete
    # records with ones that have O2Sat gaps.
    for idx in range(16):
        _write_patient(
            training_a / f"pA{idx:03d}.psv",
            positive=idx % 2 == 0,
            shifted=idx >= 8,
        )

    output_dir = tmp_path / "structured-outputs"
    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "aa",
            "--model-type", "logistic_regression",
            "--train-patients", "4",
            "--selection-patients", "4",
            "--calibration-patients", "4",
            "--test-patients", "4",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--missingness-grouping-strategy", "coverage_gap_variable",
            "--output-dir", str(output_dir),
        ]
    )

    metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    structured = metrics["structured_grouping"]

    assert metrics["config"]["missingness_grouping_strategy"] == "coverage_gap_variable"
    assert structured["strategy"] == "coverage_gap_variable"
    assert "selected_variable" in structured
    assert set(subgroup_frame["group_label"]) != {"low", "medium", "high"}
|
|
|
|
def test_run_config_defaults_opt_in_baseline_flags_to_false(tmp_path: Path) -> None:
    """External baselines and the learned partition are opt-in: both default to ``False``."""
    defaults = RunConfig(data_root=tmp_path)
    for flag_name in ("enable_external_baselines", "enable_learned_partition"):
        assert getattr(defaults, flag_name) is False, flag_name
|
|
|
|
def test_cli_run_with_learned_partition_requires_selection_split(tmp_path: Path) -> None:
    """``--enable-learned-partition`` without ``--selection-patients`` raises ``ValueError``.

    The error message is expected to mention the missing "selection split".
    """
    data_root = tmp_path / "training"
    for hospital, patient_count, shifted in (("A", 8, False), ("B", 4, True)):
        hospital_dir = data_root / f"training_set{hospital}"
        hospital_dir.mkdir(parents=True)
        for idx in range(patient_count):
            _write_patient(
                hospital_dir / f"p{hospital}{idx:03d}.psv",
                positive=idx % 2 == 0,
                shifted=shifted,
            )

    cli_args = [
        "run",
        "--data-root", str(data_root),
        "--train-hospital", "A",
        "--test-mode", "both",
        "--train-patients", "4",
        "--calibration-patients", "2",
        "--test-patients", "2",
        "--lookback-hours", "3",
        "--horizon-hours", "2",
        "--alpha", "0.2",
        "--enable-learned-partition",
        "--output-dir", str(tmp_path / "outputs"),
    ]
    with pytest.raises(ValueError, match="selection split"):
        main(cli_args)
|
|
|
|
def test_structured_grouping_uses_explicit_group_keyword_arguments(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Structured grouping passes group ids to the classifier by keyword.

    ``MissingnessAwareConformalClassifier.fit`` and ``predict_sets`` are wrapped
    in recording spies that forward to the originals. After a
    ``coverage_gap_variable`` run, the test requires two things:

    - at least one ``fit`` call received ``calibration_group_ids`` with no
      positional missing rates;
    - at least one ``predict_sets`` call received ``test_group_ids`` with no
      positional missing rates.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)

    for idx in range(16):
        _write_patient(
            training_a / f"pA{idx:03d}.psv",
            positive=idx % 2 == 0,
            shifted=idx >= 8,
        )

    classifier_cls = runner_module.MissingnessAwareConformalClassifier
    original_fit = classifier_cls.fit
    original_predict_sets = classifier_cls.predict_sets
    fit_calls: list[dict[str, object]] = []
    predict_calls: list[dict[str, object]] = []

    # Spy parameter names mirror the real signatures so keyword calls still bind.
    def fit_spy(self, calibration_labels, calibration_positive_probabilities, calibration_missing_rates=None, **kwargs):
        fit_calls.append({"calibration_missing_rates": calibration_missing_rates, "kwargs": dict(kwargs)})
        return original_fit(self, calibration_labels, calibration_positive_probabilities, calibration_missing_rates, **kwargs)

    def predict_spy(self, positive_probabilities, missing_rates=None, **kwargs):
        predict_calls.append({"missing_rates": missing_rates, "kwargs": dict(kwargs)})
        return original_predict_sets(self, positive_probabilities, missing_rates, **kwargs)

    monkeypatch.setattr(classifier_cls, "fit", fit_spy)
    monkeypatch.setattr(classifier_cls, "predict_sets", predict_spy)

    main(
        [
            "run",
            "--data-root", str(data_root),
            "--train-hospital", "A",
            "--test-mode", "aa",
            "--model-type", "logistic_regression",
            "--train-patients", "4",
            "--selection-patients", "4",
            "--calibration-patients", "4",
            "--test-patients", "4",
            "--lookback-hours", "3",
            "--horizon-hours", "2",
            "--alpha", "0.2",
            "--missingness-grouping-strategy", "coverage_gap_variable",
            "--output-dir", str(tmp_path / "structured-outputs"),
        ]
    )

    grouped_fits = [call for call in fit_calls if "calibration_group_ids" in call["kwargs"]]
    grouped_predictions = [call for call in predict_calls if "test_group_ids" in call["kwargs"]]

    assert grouped_fits
    assert any(call["calibration_missing_rates"] is None for call in grouped_fits)
    assert grouped_predictions
    assert any(call["missing_rates"] is None for call in grouped_predictions)
|
|
|
|
def test_cli_run_adds_learned_partition_only_when_enabled(tmp_path: Path) -> None:
    """``learned_partition`` is reported if and only if ``--enable-learned-partition`` is passed.

    The CLI runs twice on the same 16 hospital-A patients (half shifted), with
    a selection split in both runs: once with the flag and once without. The
    method must appear in ``metrics.json`` and the subgroup table for the
    flagged run, and be absent from both for the unflagged run. Checking
    absence is what makes this an "only when enabled" test.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_a.mkdir(parents=True)

    for index in range(16):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=index >= 8,
        )

    def run_cli(output_dir: Path, *extra_flags: str) -> tuple[dict, pd.DataFrame]:
        """Run ``main`` with the shared arguments plus *extra_flags*; return (metrics, subgroup frame)."""
        main(
            [
                "run",
                "--data-root", str(data_root),
                "--train-hospital", "A",
                "--test-mode", "aa",
                "--model-type", "logistic_regression",
                "--train-patients", "4",
                "--selection-patients", "4",
                "--calibration-patients", "4",
                "--test-patients", "4",
                "--lookback-hours", "3",
                "--horizon-hours", "2",
                "--alpha", "0.2",
                *extra_flags,
                "--output-dir", str(output_dir),
            ]
        )
        metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
        return metrics, pd.read_csv(output_dir / "subgroup_coverage.csv")

    enabled_metrics, enabled_subgroups = run_cli(
        tmp_path / "learned-partition-outputs", "--enable-learned-partition"
    )
    assert "learned_partition" in set(enabled_metrics["experiments"]["A_to_A"])
    assert "learned_partition" in set(enabled_subgroups["method"])

    # Same data and split sizes with the flag omitted: the method must not be reported.
    disabled_metrics, disabled_subgroups = run_cli(tmp_path / "learned-partition-disabled-outputs")
    assert "learned_partition" not in set(disabled_metrics["experiments"]["A_to_A"])
    assert "learned_partition" not in set(disabled_subgroups["method"])
|
|
|
|
def test_cli_run_adds_external_gibbs_methods_only_when_enabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Gibbs baselines are reported if and only if ``--enable-external-baselines`` is passed.

    ``importlib.import_module`` (as seen by ``external_baselines``) is patched
    so that requesting ``conditionalconformal`` returns a stub exposing a fake
    ``CondConf``; every other import is forwarded unchanged. The CLI then runs
    twice on identical data, once with the flag and once without. The Gibbs
    methods must be present for the flagged run and absent otherwise.
    """
    data_root = tmp_path / "training"
    training_a = data_root / "training_setA"
    training_b = data_root / "training_setB"
    training_a.mkdir(parents=True)
    training_b.mkdir(parents=True)

    for index in range(8):
        _write_patient(
            training_a / f"pA{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=False,
        )

    for index in range(4):
        _write_patient(
            training_b / f"pB{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=True,
        )

    class FakeCondConf:
        """Minimal stand-in for ``conditionalconformal.CondConf``."""

        def __init__(self, score_fn, Phi_fn, quantile_fn, infinite_params):
            self._score_fn = score_fn
            self._phi_fn = Phi_fn

        def setup_problem(self, X, Y, alpha=None):
            self._alpha = alpha

        def predict(self, *args, **kwargs):
            # Take the batch from the second positional argument when present, else the first.
            X = args[1] if len(args) > 1 else args[0]
            # Same constant value for every row.
            return [0.2] * len(X)

    original_import_module = external_baselines_module.importlib.import_module

    def fake_import_module(name: str, package: str | None = None):
        if name == "conditionalconformal":
            return type("FakeModule", (), {"CondConf": FakeCondConf})
        return original_import_module(name, package)

    monkeypatch.setattr(external_baselines_module.importlib, "import_module", fake_import_module)

    def run_cli(output_dir: Path, *extra_flags: str) -> tuple[dict, pd.DataFrame]:
        """Run ``main`` with the shared arguments plus *extra_flags*; return (metrics, subgroup frame)."""
        main(
            [
                "run",
                "--data-root", str(data_root),
                "--train-hospital", "A",
                "--test-mode", "both",
                "--train-patients", "4",
                "--calibration-patients", "2",
                "--test-patients", "2",
                "--lookback-hours", "3",
                "--horizon-hours", "2",
                "--alpha", "0.2",
                *extra_flags,
                "--output-dir", str(output_dir),
            ]
        )
        metrics = json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))
        return metrics, pd.read_csv(output_dir / "subgroup_coverage.csv")

    gibbs_methods = {"gibbs_general", "gibbs_missingness"}

    enabled_metrics, enabled_subgroups = run_cli(
        tmp_path / "external-baselines-outputs", "--enable-external-baselines"
    )
    assert gibbs_methods <= set(enabled_metrics["experiments"]["A_to_A"])
    assert gibbs_methods <= set(enabled_subgroups["method"])

    # Flag omitted: neither Gibbs method may appear in the A_to_A results or the subgroup table.
    disabled_metrics, disabled_subgroups = run_cli(tmp_path / "external-baselines-disabled-outputs")
    assert gibbs_methods.isdisjoint(disabled_metrics["experiments"]["A_to_A"])
    assert gibbs_methods.isdisjoint(disabled_subgroups["method"])
|
|