# misscp / tests/test_simulation.py
# Anonymous
# Initial anonymous MissCP release
# 32f5a65
from __future__ import annotations
import json
from pathlib import Path
import pandas as pd
from sepsis_mcp.simulation import (
SimulationConfig,
estimate_gamma_from_correlation,
generate_simulation_splits,
run_model_swap_simulation_cell,
run_simulation_cell,
write_sim1_main_figure,
select_shifted_missingness_group,
)
from sepsis_mcp.simulation_sweep import main
def test_generate_simulation_splits_returns_expected_shapes_and_labels() -> None:
    """Each split has the configured row count, ``d`` columns and 0/1 labels."""
    n_features = 10
    config = SimulationConfig(
        n_train=120,
        n_cal=60,
        n_test=80,
        d=n_features,
        random_state=7,
    )
    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=7)

    # Spot-check one frame per split: complete features, observed features, mask.
    assert splits.train.complete_features.shape == (120, n_features)
    assert splits.calibration.observed_features.shape == (60, n_features)
    assert splits.test.missing_mask.shape == (80, n_features)

    allowed_labels = {0, 1}
    for split in (splits.train, splits.calibration, splits.test):
        assert set(split.labels.unique()) <= allowed_labels
def test_missingness_shift_increases_test_missingness_on_predictive_features() -> None:
    """A positive ``delta_m`` raises test-split missingness of the first feature.

    The calibration split of the shifted run must stay close to
    ``config.base_missing_rate`` while its test split is clearly higher.
    """
    config = SimulationConfig(
        n_train=250,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=0,
    )

    def first_feature_missing_rate(mask: pd.DataFrame) -> float:
        # Fraction of rows in which column 0 is flagged as missing.
        return mask.iloc[:, 0].mean()

    unshifted_splits = generate_simulation_splits(config, delta_m=0.0, delta_x=0.0, seed=0)
    shifted_splits = generate_simulation_splits(config, delta_m=0.3, delta_x=0.0, seed=0)

    test_rate_without_shift = first_feature_missing_rate(unshifted_splits.test.missing_mask)
    test_rate_with_shift = first_feature_missing_rate(shifted_splits.test.missing_mask)
    calibration_rate = first_feature_missing_rate(shifted_splits.calibration.missing_mask)

    assert test_rate_with_shift > test_rate_without_shift
    assert abs(calibration_rate - config.base_missing_rate) < 0.06
    assert test_rate_with_shift - calibration_rate > 0.12
def test_select_shifted_missingness_group_picks_high_shift_feature() -> None:
    """The selected grouping names an existing feature and yields 0/1 group ids.

    NOTE(review): despite the test name, nothing here checks that the chosen
    variable is the one with the largest calibration-to-test shift.
    """
    config = SimulationConfig(
        n_train=180,
        n_cal=90,
        n_test=90,
        d=8,
        random_state=13,
    )
    splits = generate_simulation_splits(config, delta_m=0.25, delta_x=0.0, seed=13)
    calibration_mask = splits.calibration.missing_mask
    test_mask = splits.test.missing_mask

    grouping = select_shifted_missingness_group(
        calibration_mask,
        test_mask,
        min_missing_fraction=0.05,
        max_missing_fraction=0.95,
        min_group_size=10,
    )

    assert grouping.variable_name in calibration_mask.columns
    for group_ids in (grouping.calibration_group_ids, grouping.test_group_ids):
        assert set(group_ids.unique()) <= {0, 1}
def test_run_simulation_cell_emits_all_requested_methods() -> None:
    """One simulation cell returns a DataFrame with a row set for every requested method."""
    requested_methods = (
        "standard",
        "mondrian_tilted",
        "predicted_risk_tercile",
        "random_grouping",
        "oracle_grouping",
    )
    required_columns = {
        "delta_m",
        "delta_x",
        "seed",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "evaluation_grouping_strategy",
    }
    config = SimulationConfig(
        n_train=220,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=5,
    )

    results = run_simulation_cell(
        config,
        delta_m=0.2,
        delta_x=0.0,
        seed=5,
        methods=requested_methods,
    )

    assert isinstance(results, pd.DataFrame)
    assert set(results["method"]) == set(requested_methods)
    assert required_columns <= set(results.columns)
    # Every method is evaluated under the same grouping strategy.
    assert set(results["evaluation_grouping_strategy"]) == {"coverage_gap_variable"}
def test_run_model_swap_simulation_cell_emits_task_d_methods() -> None:
    """Calibrating with logistic regression and deploying XGBoost emits the Task-D methods.

    The test expects the two missingness-based methods to keep exactly the same
    test groups as the matched run (overlap == 1.0). It expects the risk-based
    swap to change its test groups (overlap < 1.0).
    """
    expected_methods = {
        "missingness_mondrian_matched",
        "missingness_mondrian_swap",
        "risk_mondrian_swap",
    }
    expected_columns = {
        "delta_m",
        "delta_x",
        "gamma",
        "seed",
        "calibrate_model",
        "deploy_model",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "test_group_overlap_vs_matched",
    }
    config = SimulationConfig(
        n_train=220,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=9,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
    )

    results = run_model_swap_simulation_cell(
        config,
        delta_m=0.2,
        delta_x=0.0,
        seed=9,
        calibrate_model_type="logistic_regression",
        deploy_model_type="xgboost",
    )

    assert set(results["method"]) == expected_methods
    assert expected_columns <= set(results.columns)

    overlap_column = "test_group_overlap_vs_matched"
    is_missingness_method = results["method"].str.contains("missingness")
    assert (results.loc[is_missingness_method, overlap_column] == 1.0).all()

    risk_overlap = results.loc[results["method"] == "risk_mondrian_swap", overlap_column].iloc[0]
    assert float(risk_overlap) < 1.0
def test_mar_missingness_mechanism_couples_missingness_with_severity() -> None:
    """Under ``mar_severity`` the more severe half of test rows has more missing values.

    Severity is a fixed linear score over the first four complete features.
    Rows are split at its median, and the mean per-row missing fraction is
    compared between the two halves.
    """
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
        random_state=11,
    )
    splits = generate_simulation_splits(config, delta_m=0.3, delta_x=0.0, seed=11)

    complete = splits.test.complete_features
    severity_weights = (5.0, 3.0, 0.5, 0.5)
    severity = sum(
        complete.iloc[:, column] * weight
        for column, weight in enumerate(severity_weights)
    )
    cutoff = severity.median()

    # Fraction of features missing in each test row.
    per_row_missing = splits.test.missing_mask.mean(axis=1)
    high_severity_missing = float(per_row_missing[severity >= cutoff].mean())
    low_severity_missing = float(per_row_missing[severity < cutoff].mean())

    assert high_severity_missing > low_severity_missing
    # Require a visible margin, not just a strict ordering.
    assert high_severity_missing - low_severity_missing > 0.05
def test_mar_selective_changes_predictive_features_more_than_noise_features() -> None:
    """Under ``mar_selective`` the first three features go missing more than features 5+."""
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_selective",
        mar_gamma=2.0,
        random_state=17,
    )
    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=17)
    test_mask = splits.test.missing_mask

    def average_missing(block: pd.DataFrame) -> float:
        # Mean over columns of each column's missing fraction.
        return float(block.mean().mean())

    predictive_missing = average_missing(test_mask.iloc[:, :3])
    noise_missing = average_missing(test_mask.iloc[:, 5:])

    assert predictive_missing > noise_missing
    assert predictive_missing - noise_missing > 0.05
def test_mnar_extreme_increases_missingness_for_extreme_feature_values() -> None:
    """Under ``mnar_extreme`` feature 0 is missing more often when ``|x0|`` is large.

    Test rows are split at ``|x0| = 1.5`` using the complete features. The
    missing-indicator rate of feature 0 is then compared between the two parts.
    """
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mnar_extreme",
        random_state=19,
    )
    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=19)

    abs_x0 = splits.test.complete_features.iloc[:, 0].abs()
    # Missing indicator of feature 0 for each test row (one column, not a row-wide rate).
    feature0_missing = splits.test.missing_mask.iloc[:, 0]
    is_extreme = abs_x0 > 1.5
    is_typical = abs_x0 <= 1.5

    # An empty subset would make ``.mean()`` NaN, and ``NaN > x`` is simply False.
    # That would fail the final assert with no hint about the cause.
    assert is_extreme.any(), "no test rows with |x0| > 1.5; increase n_test"
    assert is_typical.any(), "no test rows with |x0| <= 1.5; increase n_test"

    extreme_missing = float(feature0_missing[is_extreme].mean())
    typical_missing = float(feature0_missing[is_typical].mean())
    assert extreme_missing > typical_missing
def test_simulation_sweep_main_writes_results_summary_and_plot(tmp_path: Path) -> None:
    """A Sim 1 sweep writes the results/summary CSVs, the config JSON and the gap plot."""
    output_dir = tmp_path / "simulation-sweep"
    n_seeds = 2
    delta_m_grid = "0.0,0.2"
    expected_methods = {
        "standard",
        "mondrian_tilted",
        "predicted_risk_tercile",
        "random_grouping",
        "oracle_grouping",
    }
    argv = [
        "--sim", "1",
        "--output-dir", str(output_dir),
        "--seeds", str(n_seeds),
        "--delta-m-grid", delta_m_grid,
        "--model-type", "logistic_regression",
        "--n-train", "150",
        "--n-cal", "80",
        "--n-test", "120",
        "--d", "10",
    ]

    main(argv)

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")
    written_config = json.loads(
        (output_dir / "simulation_config.json").read_text(encoding="utf-8")
    )

    # Expected rows: seeds x delta_m grid values x methods (2 x 2 x 5).
    n_delta_m = len(delta_m_grid.split(","))
    assert results.shape[0] == n_seeds * n_delta_m * len(expected_methods)
    assert set(summary["method"]) == expected_methods
    assert written_config["model_type"] == "logistic_regression"
    assert (output_dir / "sim1_gap_vs_delta_m.png").exists()
def test_simulation_sweep_main_supports_sim5(tmp_path: Path) -> None:
    """A Sim 5 (model-swap) sweep reports the three swap methods and writes its figure."""
    output_dir = tmp_path / "sim5"
    remaining_flags = (
        "--seeds 2 "
        "--delta-m-grid 0.0,0.2 "
        "--gamma-grid 0.0,2.0 "
        "--n-train 150 "
        "--n-cal 80 "
        "--n-test 120 "
        "--d 10"
    ).split()
    main(["--sim", "5", "--output-dir", str(output_dir), *remaining_flags])

    swap_methods = {
        "missingness_mondrian_matched",
        "missingness_mondrian_swap",
        "risk_mondrian_swap",
    }
    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")

    assert set(results["method"]) == swap_methods
    assert set(summary["method"]) == swap_methods
    assert (output_dir / "sim5_summary_figure.png").exists()
def test_write_sim1_main_figure_writes_two_panel_png(tmp_path: Path) -> None:
    """``write_sim1_main_figure`` renders a two-gamma summary to a real PNG file.

    The summary covers two ``gamma`` values (presumably one panel each) and
    two methods over two ``delta_m`` values. The panel count is not inspected;
    the test only checks that a valid PNG stream was written.
    """
    summary = pd.DataFrame(
        [
            {"gamma": 0.0, "delta_m": 0.0, "method": "standard", "max_group_coverage_gap_mean": 0.03},
            {"gamma": 0.0, "delta_m": 0.2, "method": "standard", "max_group_coverage_gap_mean": 0.04},
            {"gamma": 0.0, "delta_m": 0.0, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.02},
            {"gamma": 0.0, "delta_m": 0.2, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.02},
            {"gamma": 2.0, "delta_m": 0.0, "method": "standard", "max_group_coverage_gap_mean": 0.06},
            {"gamma": 2.0, "delta_m": 0.2, "method": "standard", "max_group_coverage_gap_mean": 0.08},
            {"gamma": 2.0, "delta_m": 0.0, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.05},
            {"gamma": 2.0, "delta_m": 0.2, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.05},
        ]
    )
    output_path = tmp_path / "sim1_main.png"

    write_sim1_main_figure(summary, output_path=output_path, gamma_values=(0.0, 2.0))

    assert output_path.exists()
    # Existence alone would accept an empty or truncated file. Checking the
    # fixed 8-byte PNG signature confirms a PNG stream was actually written.
    assert output_path.read_bytes()[:8] == b"\x89PNG\r\n\x1a\n"
def test_estimate_gamma_from_correlation_chooses_nearest_gamma() -> None:
    """Target correlation 0.31 maps to gamma 2.0, whose tabulated correlation (0.29) is closest."""
    correlation_by_gamma = {0.0: 0.05, 1.0: 0.18, 2.0: 0.29, 3.0: 0.41}

    chosen_gamma = estimate_gamma_from_correlation(
        target_correlation=0.31,
        gamma_to_correlation=correlation_by_gamma,
    )

    assert chosen_gamma == 2.0
def test_simulation_sweep_main_supports_sim2_heatmap_outputs(tmp_path: Path) -> None:
    """A Sim 2 sweep over (gamma, delta_m, delta_x) writes CSVs and one heatmap per gamma."""
    output_dir = tmp_path / "simulation-sim2"
    argv = [
        "--sim", "2",
        "--output-dir", str(output_dir),
        "--seeds", "2",
        "--delta-m-grid", "0.0,0.2",
        "--delta-x-grid", "0.0,1.0",
        "--gamma-grid", "0.0,2.0",
        "--model-type", "logistic_regression",
        "--n-train", "150",
        "--n-cal", "80",
        "--n-test", "120",
        "--d", "10",
    ]
    main(argv)

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")
    advantage = pd.read_csv(output_dir / "advantage_summary.csv")

    grid_with_method = {"gamma", "delta_m", "delta_x", "method"}
    assert grid_with_method <= set(results.columns)
    assert grid_with_method <= set(summary.columns)
    assert {"gamma", "delta_m", "delta_x", "advantage_gap"} <= set(advantage.columns)

    heatmap_files = (
        "sim2_advantage_heatmap_gamma0p0.png",
        "sim2_advantage_heatmap_gamma2p0.png",
    )
    for heatmap_file in heatmap_files:
        assert (output_dir / heatmap_file).exists()
def test_simulation_sweep_main_supports_sim3_mechanism_outputs(tmp_path: Path) -> None:
    """A Sim 3 sweep compares missingness mechanisms at a single (delta_m, delta_x, gamma)."""
    output_dir = tmp_path / "simulation-sim3"
    argv = [
        "--sim", "3",
        "--output-dir", str(output_dir),
        "--seeds", "2",
        "--delta-m", "0.2",
        "--delta-x", "0.0",
        "--gamma", "2.0",
        "--model-type", "logistic_regression",
        "--n-train", "150",
        "--n-cal", "80",
        "--n-test", "120",
        "--d", "10",
    ]
    main(argv)

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")

    keyed_by = {"mechanism", "method"}
    assert keyed_by <= set(results.columns)
    assert keyed_by <= set(summary.columns)

    # All four mechanisms must appear in the summary.
    expected_mechanisms = {"mcar", "mar_severity", "mar_selective", "mnar_extreme"}
    assert set(summary["mechanism"]) == expected_mechanisms
    assert (output_dir / "sim3_mechanism_bar.png").exists()