misscp / tests /test_external_baselines_analysis.py
Anonymous
Initial anonymous MissCP release
32f5a65
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
def test_run_external_baselines_analysis_writes_overlap_interpretability_and_stress_outputs(tmp_path: Path) -> None:
    """Run the external-baselines analysis on one synthetic run directory.

    The fixture writes three CSV inputs into ``tmp_path / "run-a"``:

    * ``partition_overlap_assignments.csv`` -- four test-split samples whose
      ``learned_group`` and ``missingness_group`` columns are identical
      (a, b -> 0; c, d -> 1), so the test expects ARI and NMI of 1.
    * ``partition_feature_importance.csv`` -- two features with importances
      0.75 (flagged as a missingness proxy) and 0.25 (not flagged).
    * ``stress_aggregate.csv`` -- two methods at drop rates 0.0 and 0.3.

    It then checks the values read back from the files whose paths
    ``run_external_baselines_analysis`` returns.  Floating-point results are
    compared with a tolerance rather than ``==``: the expected deltas
    (0.72 - 0.90 and 0.30 - 0.10) are not exactly representable in binary
    floating point, so exact equality would depend on incidental rounding
    inside the implementation or the CSV round-trip.
    """
    import math

    from sepsis_mcp.external_baselines_analysis import run_external_baselines_analysis

    def close(actual: float, expected: float) -> bool:
        # Tolerant float comparison; accepts everything exact equality would.
        return math.isclose(actual, expected, rel_tol=1e-9, abs_tol=1e-9)

    run_dir = tmp_path / "run-a"
    run_dir.mkdir()

    # Learned partition and missingness partition agree on every sample.
    pd.DataFrame(
        [
            {
                "split": "test",
                "sample_id": "a",
                "hospital_id": "icu-a",
                "learned_group": 0,
                "learned_group_label": "leaf_0",
                "leaf_id": 3,
                "missingness_group": 0,
                "missingness_group_label": "low",
            },
            {
                "split": "test",
                "sample_id": "b",
                "hospital_id": "icu-a",
                "learned_group": 0,
                "learned_group_label": "leaf_0",
                "leaf_id": 3,
                "missingness_group": 0,
                "missingness_group_label": "low",
            },
            {
                "split": "test",
                "sample_id": "c",
                "hospital_id": "icu-a",
                "learned_group": 1,
                "learned_group_label": "leaf_1",
                "leaf_id": 4,
                "missingness_group": 1,
                "missingness_group_label": "high",
            },
            {
                "split": "test",
                "sample_id": "d",
                "hospital_id": "icu-a",
                "learned_group": 1,
                "learned_group_label": "leaf_1",
                "leaf_id": 4,
                "missingness_group": 1,
                "missingness_group_label": "high",
            },
        ]
    ).to_csv(run_dir / "partition_overlap_assignments.csv", index=False)

    # The proxy-flagged feature carries 0.75 of the total importance of 1.0.
    pd.DataFrame(
        [
            {"feature": "creatinine_max_nan", "importance": 0.75, "split_count": 2, "is_missingness_proxy": True},
            {"feature": "heart_rate_max", "importance": 0.25, "split_count": 1, "is_missingness_proxy": False},
        ]
    ).to_csv(run_dir / "partition_feature_importance.csv", index=False)

    # For gibbs_missingness, coverage moves 0.90 -> 0.72 and gap 0.10 -> 0.30
    # between drop_rate 0.0 and the maximum drop_rate 0.3.
    pd.DataFrame(
        [
            {"drop_rate": 0.0, "method": "standard", "mean_coverage": 0.95, "mean_gap": 0.05, "mean_set_size": 1.0},
            {"drop_rate": 0.3, "method": "standard", "mean_coverage": 0.92, "mean_gap": 0.08, "mean_set_size": 1.05},
            {"drop_rate": 0.0, "method": "gibbs_missingness", "mean_coverage": 0.90, "mean_gap": 0.10, "mean_set_size": 0.95},
            {"drop_rate": 0.3, "method": "gibbs_missingness", "mean_coverage": 0.72, "mean_gap": 0.30, "mean_set_size": 0.75},
        ]
    ).to_csv(run_dir / "stress_aggregate.csv", index=False)

    output_dir = tmp_path / "analysis"
    paths = run_external_baselines_analysis(run_dirs=(run_dir,), output_dir=output_dir)

    overlap_summary = pd.read_csv(paths["overlap_summary"])
    overlap_details = pd.read_csv(paths["overlap_details"])
    interpretability = pd.read_csv(paths["partition_interpretability"])
    stress_summary = pd.read_csv(paths["stress_degradation_summary"])
    analysis_summary = json.loads(paths["analysis_summary"].read_text(encoding="utf-8"))

    assert close(overlap_summary.loc[0, "ari"], 1.0)
    assert close(overlap_summary.loc[0, "nmi"], 1.0)
    assert overlap_details.loc[0, "overlap_count"] > 0
    assert close(interpretability.loc[0, "share_missingness_proxy_importance"], 0.75)

    gibbs_row = stress_summary.loc[stress_summary["method"] == "gibbs_missingness"].iloc[0]
    assert close(gibbs_row["coverage_change_at_max_drop"], -0.18)
    assert close(gibbs_row["gap_change_at_max_drop"], 0.2)

    assert analysis_summary["run_count"] == 1