from __future__ import annotations

import json
from pathlib import Path

import pandas as pd
import pytest

from sepsis_mcp.cli import main
from sepsis_mcp.constants import LABEL_COLUMN
import sepsis_mcp.external_baselines as external_baselines_module
from sepsis_mcp.io import load_patient_frame
import sepsis_mcp.runner as runner_module
from sepsis_mcp.runner import RunConfig, _split_train_calibration_test

# Header row written at the top of every synthetic patient file.
_PATIENT_HEADER = "HR|O2Sat|Age|Gender|Unit1|Unit2|HospAdmTime|ICULOS|SepsisLabel"

# Hours whose O2Sat reading is replaced by "NaN" for shifted patients.
_SHIFTED_MISSING_HOURS = frozenset({2, 3, 6})

# Methods the tests expect in the outputs of a run without opt-in baseline flags.
_DEFAULT_METHODS = frozenset(
    {
        "standard",
        "missingness_aware",
        "mondrian_tilted",
        "label_conditional",
        "label_conditional_missingness_aware",
        "shrunk_weighted_missingness_aware",
        "weighted_missingness_aware",
    }
)

# Experiment names the tests expect from an unmasked ``--test-mode both`` run.
_CROSS_HOSPITAL_EXPERIMENTS = frozenset({"A_to_A", "A_to_B"})


def _write_patient(path: Path, *, positive: bool, shifted: bool) -> None:
    """Write an eight-hour, pipe-separated synthetic patient file to ``path``.

    Args:
        path: Destination file; overwritten with UTF-8 text.
        positive: When true, heart rate climbs faster, O2Sat falls every hour
            and ``SepsisLabel`` is 1 from hour 5 onwards.
        shifted: When true, the baseline heart rate is raised by 6 and the
            O2Sat value is written as ``NaN`` at hours 2, 3 and 6.
    """
    base_hr = 78 if positive else 62
    if shifted:
        base_hr += 6

    rows = [_PATIENT_HEADER]
    for hour in range(1, 9):
        hr = base_hr + hour * (3 if positive else 1)
        o2sat = 96 - hour if positive else 98 - (hour % 2)
        if shifted and hour in _SHIFTED_MISSING_HOURS:
            o2sat_value = "NaN"
        else:
            o2sat_value = str(o2sat)
        sepsis_label = 1 if positive and hour >= 5 else 0
        rows.append(f"{hr}|{o2sat_value}|65|1|1|0|-5|{hour}|{sepsis_label}")

    path.write_text("\n".join(rows), encoding="utf-8")


def _populate_hospital(
    data_root: Path,
    hospital: str,
    patient_count: int,
    *,
    shifted_from: int | None = None,
) -> Path:
    """Create ``training_set<hospital>`` under ``data_root`` and fill it with patients.

    Patients are named ``p<hospital><index:03d>.psv``; even indices are
    positive, odd indices are negative.

    Args:
        data_root: Parent directory (created if missing).
        hospital: Hospital letter used in the directory and file names.
        patient_count: Number of patient files to write.
        shifted_from: First patient index written with ``shifted=True``;
            ``None`` leaves every patient unshifted, ``0`` shifts all of them.

    Returns:
        The hospital directory that was created.
    """
    hospital_dir = data_root / f"training_set{hospital}"
    hospital_dir.mkdir(parents=True)
    for index in range(patient_count):
        _write_patient(
            hospital_dir / f"p{hospital}{index:03d}.psv",
            positive=index % 2 == 0,
            shifted=shifted_from is not None and index >= shifted_from,
        )
    return hospital_dir


def _cross_hospital_run_args(data_root: Path, output_dir: Path, *extra_flags: str) -> list[str]:
    """Build the ``run`` argument list for a ``--test-mode both`` run with a 4/2/2 split.

    ``extra_flags`` are inserted between ``--alpha`` and ``--output-dir``.
    """
    return [
        "run",
        "--data-root",
        str(data_root),
        "--train-hospital",
        "A",
        "--test-mode",
        "both",
        "--train-patients",
        "4",
        "--calibration-patients",
        "2",
        "--test-patients",
        "2",
        "--lookback-hours",
        "3",
        "--horizon-hours",
        "2",
        "--alpha",
        "0.2",
        *extra_flags,
        "--output-dir",
        str(output_dir),
    ]


def _single_hospital_run_args(data_root: Path, output_dir: Path, *extra_flags: str) -> list[str]:
    """Build the ``run`` argument list for a ``--test-mode aa`` logistic-regression run.

    The run uses a 4/4/4/4 train/selection/calibration/test split.
    ``extra_flags`` are inserted between ``--alpha`` and ``--output-dir``.
    """
    return [
        "run",
        "--data-root",
        str(data_root),
        "--train-hospital",
        "A",
        "--test-mode",
        "aa",
        "--model-type",
        "logistic_regression",
        "--train-patients",
        "4",
        "--selection-patients",
        "4",
        "--calibration-patients",
        "4",
        "--test-patients",
        "4",
        "--lookback-hours",
        "3",
        "--horizon-hours",
        "2",
        "--alpha",
        "0.2",
        *extra_flags,
        "--output-dir",
        str(output_dir),
    ]


def _load_metrics(output_dir: Path) -> dict:
    """Parse ``metrics.json`` from ``output_dir``."""
    return json.loads((output_dir / "metrics.json").read_text(encoding="utf-8"))


def test_cli_run_generates_metrics_outputs_and_plot(tmp_path: Path) -> None:
    """A default cross-hospital run writes metrics, subgroup CSV and the coverage plot."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 8)
    _populate_hospital(data_root, "B", 4, shifted_from=0)
    output_dir = tmp_path / "outputs"

    main(_cross_hospital_run_args(data_root, output_dir))

    metrics_path = output_dir / "metrics.json"
    subgroup_path = output_dir / "subgroup_coverage.csv"
    plot_path = output_dir / "coverage_by_missingness.png"
    assert metrics_path.exists()
    assert subgroup_path.exists()
    assert plot_path.exists()

    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
    subgroup_frame = pd.read_csv(subgroup_path)
    experiments = metrics["experiments"]

    assert set(experiments) == _CROSS_HOSPITAL_EXPERIMENTS
    assert set(experiments["A_to_A"]) == _DEFAULT_METHODS
    assert "auroc" in experiments["A_to_A"]["standard"]
    assert "empirical_coverage" in experiments["A_to_B"]["missingness_aware"]
    assert set(subgroup_frame["method"]) == _DEFAULT_METHODS
    assert set(subgroup_frame["experiment"]) == _CROSS_HOSPITAL_EXPERIMENTS


def test_cli_run_with_random_drop_masking_encodes_mask_condition_in_outputs(tmp_path: Path) -> None:
    """With random-drop masking, experiment names carry a ``__random_drop`` suffix."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 8)
    _populate_hospital(data_root, "B", 4)
    output_dir = tmp_path / "masked-outputs"

    main(
        _cross_hospital_run_args(
            data_root,
            output_dir,
            "--mask-strategy",
            "random_drop",
            "--mask-rate",
            "0.4",
        )
    )

    metrics = _load_metrics(output_dir)
    assert metrics["config"]["mask_strategy"] == "random_drop"
    assert set(metrics["experiments"]) == {"A_to_A__random_drop", "A_to_B__random_drop"}


def _patient_positive_count(patient_paths: list[Path]) -> int:
    """Count the patients whose label column reaches a value above zero."""
    return sum(int(load_patient_frame(path)[LABEL_COLUMN].max() > 0) for path in patient_paths)


def test_split_train_calibration_test_shuffles_and_stratifies_patients(tmp_path: Path) -> None:
    """Splits are balanced per partition and depend on ``random_state``."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 12)

    def split_for(seed: int):
        # Same configuration each time; only the random seed varies.
        config = RunConfig(
            data_root=data_root,
            train_hospital="A",
            test_mode="aa",
            train_patients=4,
            calibration_patients=4,
            test_patients=4,
            random_state=seed,
        )
        return _split_train_calibration_test(config)["A_to_A"]

    split_zero = split_for(0)
    split_one = split_for(1)

    for patient_paths in split_zero:
        assert len(patient_paths) == 4
        assert _patient_positive_count(patient_paths) == 2

    first_four = [f"pA{index:03d}.psv" for index in range(4)]
    split_zero_train_names = [path.name for path in split_zero[0]]
    split_one_train_names = [path.name for path in split_one[0]]
    assert split_zero_train_names != first_four
    assert split_zero_train_names != split_one_train_names


def test_cli_run_with_coverage_gap_variable_grouping_writes_selection_metadata(tmp_path: Path) -> None:
    """The coverage-gap grouping strategy is recorded together with its selected variable."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 16, shifted_from=8)
    output_dir = tmp_path / "structured-outputs"

    main(
        _single_hospital_run_args(
            data_root,
            output_dir,
            "--missingness-grouping-strategy",
            "coverage_gap_variable",
        )
    )

    metrics = _load_metrics(output_dir)
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    assert metrics["config"]["missingness_grouping_strategy"] == "coverage_gap_variable"
    assert metrics["structured_grouping"]["strategy"] == "coverage_gap_variable"
    assert "selected_variable" in metrics["structured_grouping"]
    assert set(subgroup_frame["group_label"]) != {"low", "medium", "high"}


def test_run_config_defaults_opt_in_baseline_flags_to_false(tmp_path: Path) -> None:
    """Both opt-in baseline flags default to ``False`` on ``RunConfig``."""
    config = RunConfig(data_root=tmp_path)
    assert config.enable_external_baselines is False
    assert config.enable_learned_partition is False


def test_cli_run_with_learned_partition_requires_selection_split(tmp_path: Path) -> None:
    """Enabling the learned partition without selection patients raises ``ValueError``."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 8)
    _populate_hospital(data_root, "B", 4, shifted_from=0)

    with pytest.raises(ValueError, match="selection split"):
        main(
            _cross_hospital_run_args(
                data_root,
                tmp_path / "outputs",
                "--enable-learned-partition",
            )
        )


def test_structured_grouping_uses_explicit_group_keyword_arguments(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Structured grouping passes group ids by keyword and leaves missing rates unset."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 16, shifted_from=8)

    fit_calls: list[dict[str, object]] = []
    predict_calls: list[dict[str, object]] = []
    classifier_cls = runner_module.MissingnessAwareConformalClassifier
    original_fit = classifier_cls.fit
    original_predict_sets = classifier_cls.predict_sets

    def fit_spy(
        self,
        calibration_labels,
        calibration_positive_probabilities,
        calibration_missing_rates=None,
        **kwargs,
    ):
        # Record how the runner called ``fit``, then defer to the real method.
        fit_calls.append(
            {
                "calibration_missing_rates": calibration_missing_rates,
                "kwargs": dict(kwargs),
            }
        )
        return original_fit(
            self,
            calibration_labels,
            calibration_positive_probabilities,
            calibration_missing_rates,
            **kwargs,
        )

    def predict_spy(self, positive_probabilities, missing_rates=None, **kwargs):
        # Record how the runner called ``predict_sets``, then defer to the real method.
        predict_calls.append(
            {
                "missing_rates": missing_rates,
                "kwargs": dict(kwargs),
            }
        )
        return original_predict_sets(
            self,
            positive_probabilities,
            missing_rates,
            **kwargs,
        )

    monkeypatch.setattr(classifier_cls, "fit", fit_spy)
    monkeypatch.setattr(classifier_cls, "predict_sets", predict_spy)

    main(
        _single_hospital_run_args(
            data_root,
            tmp_path / "structured-outputs",
            "--missingness-grouping-strategy",
            "coverage_gap_variable",
        )
    )

    assert any("calibration_group_ids" in call["kwargs"] for call in fit_calls)
    assert any(
        call["calibration_missing_rates"] is None
        for call in fit_calls
        if "calibration_group_ids" in call["kwargs"]
    )
    assert any("test_group_ids" in call["kwargs"] for call in predict_calls)
    assert any(
        call["missing_rates"] is None
        for call in predict_calls
        if "test_group_ids" in call["kwargs"]
    )


def test_cli_run_adds_learned_partition_only_when_enabled(tmp_path: Path) -> None:
    """``--enable-learned-partition`` adds the ``learned_partition`` method to the outputs."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 16, shifted_from=8)
    output_dir = tmp_path / "learned-partition-outputs"

    main(_single_hospital_run_args(data_root, output_dir, "--enable-learned-partition"))

    metrics = _load_metrics(output_dir)
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    assert "learned_partition" in set(metrics["experiments"]["A_to_A"])
    assert "learned_partition" in set(subgroup_frame["method"])


def test_cli_run_adds_external_gibbs_methods_only_when_enabled(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``--enable-external-baselines`` adds both Gibbs methods, using a stubbed backend."""
    data_root = tmp_path / "training"
    _populate_hospital(data_root, "A", 8)
    _populate_hospital(data_root, "B", 4, shifted_from=0)

    class FakeCondConf:
        """Stand-in for ``conditionalconformal.CondConf`` that returns a constant 0.2."""

        def __init__(self, score_fn, Phi_fn, quantile_fn, infinite_params):
            self._score_fn = score_fn
            self._phi_fn = Phi_fn

        def setup_problem(self, X, Y, alpha=None):  # noqa: N802
            self._alpha = alpha

        def predict(self, *args, **kwargs):
            # The inputs are taken from the second positional argument when
            # there is one, otherwise from the first.
            X = args[1] if len(args) > 1 else args[0]
            return [0.2] * len(X)

    original_import_module = external_baselines_module.importlib.import_module

    def fake_import_module(name: str, package: str | None = None):
        # Only the optional dependency is faked; every other import is real.
        if name == "conditionalconformal":
            return type("FakeModule", (), {"CondConf": FakeCondConf})
        return original_import_module(name, package)

    monkeypatch.setattr(
        external_baselines_module.importlib,
        "import_module",
        fake_import_module,
    )

    output_dir = tmp_path / "external-baselines-outputs"
    main(_cross_hospital_run_args(data_root, output_dir, "--enable-external-baselines"))

    metrics = _load_metrics(output_dir)
    subgroup_frame = pd.read_csv(output_dir / "subgroup_coverage.csv")
    gibbs_methods = {"gibbs_general", "gibbs_missingness"}
    assert gibbs_methods <= set(metrics["experiments"]["A_to_A"])
    assert gibbs_methods <= set(subgroup_frame["method"])