| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
| from sepsis_mcp.simulation import ( |
| SimulationConfig, |
| estimate_gamma_from_correlation, |
| generate_simulation_splits, |
| run_model_swap_simulation_cell, |
| run_simulation_cell, |
| write_sim1_main_figure, |
| select_shifted_missingness_group, |
| ) |
| from sepsis_mcp.simulation_sweep import main |
|
|
|
|
def test_generate_simulation_splits_returns_expected_shapes_and_labels() -> None:
    """Each split has the configured row/feature counts and only 0/1 labels."""
    n_features = 10
    sim_config = SimulationConfig(
        n_train=120,
        n_cal=60,
        n_test=80,
        d=n_features,
        random_state=7,
    )

    generated = generate_simulation_splits(sim_config, delta_m=0.2, delta_x=0.0, seed=7)

    # One matrix attribute is checked per split.
    assert generated.train.complete_features.shape == (120, n_features)
    assert generated.calibration.observed_features.shape == (60, n_features)
    assert generated.test.missing_mask.shape == (80, n_features)

    binary_labels = {0, 1}
    for part in (generated.train, generated.calibration, generated.test):
        assert set(part.labels.unique()) <= binary_labels
|
|
|
|
def test_missingness_shift_increases_test_missingness_on_predictive_features() -> None:
    """A positive ``delta_m`` raises test-split missingness of the first feature.

    The calibration split of the shifted run is expected to stay near
    ``base_missing_rate``. The shifted test split must exceed both the unshifted
    test split and that calibration split.
    """
    cfg = SimulationConfig(
        n_train=250,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=0,
    )

    def first_feature_missing_rate(mask: pd.DataFrame) -> float:
        # Fraction of rows whose first column is flagged as missing.
        return mask.iloc[:, 0].mean()

    unshifted, shifted = (
        generate_simulation_splits(cfg, delta_m=dm, delta_x=0.0, seed=0)
        for dm in (0.0, 0.3)
    )

    test_rate_before = first_feature_missing_rate(unshifted.test.missing_mask)
    test_rate_after = first_feature_missing_rate(shifted.test.missing_mask)
    cal_rate = first_feature_missing_rate(shifted.calibration.missing_mask)

    assert test_rate_after > test_rate_before
    # Calibration is expected to sit close to the configured base rate.
    assert abs(cal_rate - cfg.base_missing_rate) < 0.06
    assert test_rate_after - cal_rate > 0.12
|
|
|
|
def test_select_shifted_missingness_group_picks_high_shift_feature() -> None:
    """The chosen grouping names an existing mask column and yields 0/1 group ids."""
    splits = generate_simulation_splits(
        SimulationConfig(n_train=180, n_cal=90, n_test=90, d=8, random_state=13),
        delta_m=0.25,
        delta_x=0.0,
        seed=13,
    )
    cal_mask = splits.calibration.missing_mask
    test_mask = splits.test.missing_mask

    grouping = select_shifted_missingness_group(
        cal_mask,
        test_mask,
        min_missing_fraction=0.05,
        max_missing_fraction=0.95,
        min_group_size=10,
    )

    assert grouping.variable_name in cal_mask.columns
    for group_ids in (grouping.calibration_group_ids, grouping.test_group_ids):
        assert set(group_ids.unique()) <= {0, 1}
|
|
|
|
def test_run_simulation_cell_emits_all_requested_methods() -> None:
    """Every requested method produces rows carrying the expected metric columns."""
    requested = (
        "standard",
        "mondrian_tilted",
        "predicted_risk_tercile",
        "random_grouping",
        "oracle_grouping",
    )
    required_columns = {
        "delta_m",
        "delta_x",
        "seed",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "evaluation_grouping_strategy",
    }
    cfg = SimulationConfig(n_train=220, n_cal=120, n_test=160, d=10, random_state=5)

    frame = run_simulation_cell(cfg, delta_m=0.2, delta_x=0.0, seed=5, methods=requested)

    assert isinstance(frame, pd.DataFrame)
    assert set(frame["method"]) == set(requested)
    assert required_columns <= set(frame.columns)
    # Every method is evaluated under the same grouping strategy.
    assert set(frame["evaluation_grouping_strategy"]) == {"coverage_gap_variable"}
|
|
|
|
def test_run_model_swap_simulation_cell_emits_task_d_methods() -> None:
    """The model-swap cell emits the Task D methods and test-group overlap diagnostics.

    Missingness-based methods are expected to overlap fully with the matched
    grouping (overlap == 1.0). The risk-based swap is expected not to.
    """
    cfg = SimulationConfig(
        n_train=220,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=9,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
    )
    frame = run_model_swap_simulation_cell(
        cfg,
        delta_m=0.2,
        delta_x=0.0,
        seed=9,
        calibrate_model_type="logistic_regression",
        deploy_model_type="xgboost",
    )

    expected_methods = {
        "missingness_mondrian_matched",
        "missingness_mondrian_swap",
        "risk_mondrian_swap",
    }
    expected_columns = {
        "delta_m",
        "delta_x",
        "gamma",
        "seed",
        "calibrate_model",
        "deploy_model",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "test_group_overlap_vs_matched",
    }
    assert set(frame["method"]) == expected_methods
    assert expected_columns <= set(frame.columns)

    overlap = frame["test_group_overlap_vs_matched"]
    is_missingness_method = frame["method"].str.contains("missingness")
    assert overlap[is_missingness_method].eq(1.0).all()

    risk_overlap = overlap[frame["method"] == "risk_mondrian_swap"].iloc[0]
    assert float(risk_overlap) < 1.0
|
|
|
|
def test_mar_missingness_mechanism_couples_missingness_with_severity() -> None:
    """Under ``mar_severity``, rows with above-median severity lose more features."""
    cfg = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
        random_state=11,
    )
    test_split = generate_simulation_splits(cfg, delta_m=0.3, delta_x=0.0, seed=11).test

    # Weighted severity score over the first four complete features.
    # Summed left to right so the floating-point order is x0*5 + x1*3 + x2*0.5 + x3*0.5.
    weights = (5.0, 3.0, 0.5, 0.5)
    features = test_split.complete_features
    severity = sum(features.iloc[:, col] * w for col, w in enumerate(weights))

    row_missing_rate = test_split.missing_mask.mean(axis=1)

    threshold = severity.median()
    high = float(row_missing_rate[severity >= threshold].mean())
    low = float(row_missing_rate[severity < threshold].mean())

    assert high > low
    assert high - low > 0.05
|
|
|
|
def test_mar_selective_changes_predictive_features_more_than_noise_features() -> None:
    """``mar_selective`` shifts missingness on leading columns more than trailing ones."""
    cfg = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_selective",
        mar_gamma=2.0,
        random_state=17,
    )
    mask = generate_simulation_splits(cfg, delta_m=0.2, delta_x=0.0, seed=17).test.missing_mask

    # This test treats columns 0-2 as predictive and columns 5+ as noise.
    # Columns 3-4 are deliberately left out of both groups.
    predictive_rate = float(mask.iloc[:, :3].mean().mean())
    noise_rate = float(mask.iloc[:, 5:].mean().mean())

    assert predictive_rate > noise_rate
    assert predictive_rate - noise_rate > 0.05
|
|
|
|
def test_mnar_extreme_increases_missingness_for_extreme_feature_values() -> None:
    """Under ``mnar_extreme``, feature 0 is missing more often when |x0| is large."""
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mnar_extreme",
        random_state=19,
    )

    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=19)
    abs_x0 = splits.test.complete_features.iloc[:, 0].abs()
    x0_missing = splits.test.missing_mask.iloc[:, 0]

    extreme_threshold = 1.5
    is_extreme = abs_x0 > extreme_threshold
    is_typical = abs_x0 <= extreme_threshold
    # An empty group makes .mean() NaN, and every comparison with NaN is False.
    # Without this guard, that case would look like a mechanism regression.
    assert is_extreme.any() and is_typical.any(), (
        f"need rows on both sides of |x0| = {extreme_threshold} to compare rates"
    )

    extreme_missing = float(x0_missing[is_extreme].mean())
    typical_missing = float(x0_missing[is_typical].mean())

    assert extreme_missing > typical_missing, (extreme_missing, typical_missing)
|
|
|
|
def test_simulation_sweep_main_writes_results_summary_and_plot(tmp_path: Path) -> None:
    """The Sim 1 sweep writes results, a per-method summary, the config JSON and a plot."""
    output_dir = tmp_path / "simulation-sweep"
    n_seeds = 2
    delta_m_grid = ("0.0", "0.2")
    expected_methods = {
        "standard",
        "mondrian_tilted",
        "predicted_risk_tercile",
        "random_grouping",
        "oracle_grouping",
    }

    main(
        [
            "--sim",
            "1",
            "--output-dir",
            str(output_dir),
            "--seeds",
            str(n_seeds),
            "--delta-m-grid",
            ",".join(delta_m_grid),
            "--model-type",
            "logistic_regression",
            "--n-train",
            "150",
            "--n-cal",
            "80",
            "--n-test",
            "120",
            "--d",
            "10",
        ]
    )

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")
    config = json.loads((output_dir / "simulation_config.json").read_text(encoding="utf-8"))

    # Expected row count = seeds x delta_m values x methods.
    # Derived from the CLI inputs so it cannot drift from the arguments above.
    assert len(results) == n_seeds * len(delta_m_grid) * len(expected_methods)
    assert set(summary["method"]) == expected_methods
    assert config["model_type"] == "logistic_regression"
    assert (output_dir / "sim1_gap_vs_delta_m.png").exists()
|
|
|
|
def test_simulation_sweep_main_supports_sim5(tmp_path: Path) -> None:
    """The Sim 5 sweep writes results/summary for the three Mondrian methods and a figure."""
    out = tmp_path / "sim5"
    sim5_methods = {
        "missingness_mondrian_matched",
        "missingness_mondrian_swap",
        "risk_mondrian_swap",
    }
    # Insertion order of this dict is the CLI argument order.
    cli_options = {
        "--sim": "5",
        "--output-dir": str(out),
        "--seeds": "2",
        "--delta-m-grid": "0.0,0.2",
        "--gamma-grid": "0.0,2.0",
        "--n-train": "150",
        "--n-cal": "80",
        "--n-test": "120",
        "--d": "10",
    }
    main([token for pair in cli_options.items() for token in pair])

    results = pd.read_csv(out / "sweep_results.csv")
    summary = pd.read_csv(out / "sweep_summary.csv")

    for frame in (results, summary):
        assert set(frame["method"]) == sim5_methods
    assert (out / "sim5_summary_figure.png").exists()
|
|
|
|
def test_write_sim1_main_figure_writes_two_panel_png(tmp_path: Path) -> None:
    """The Sim 1 main figure is written for a summary spanning two gamma values."""
    columns = ["gamma", "delta_m", "method", "max_group_coverage_gap_mean"]
    rows = [
        (0.0, 0.0, "standard", 0.03),
        (0.0, 0.2, "standard", 0.04),
        (0.0, 0.0, "mondrian_tilted", 0.02),
        (0.0, 0.2, "mondrian_tilted", 0.02),
        (2.0, 0.0, "standard", 0.06),
        (2.0, 0.2, "standard", 0.08),
        (2.0, 0.0, "mondrian_tilted", 0.05),
        (2.0, 0.2, "mondrian_tilted", 0.05),
    ]
    summary = pd.DataFrame(rows, columns=columns)
    figure_path = tmp_path / "sim1_main.png"

    write_sim1_main_figure(summary, output_path=figure_path, gamma_values=(0.0, 2.0))

    assert figure_path.exists()
|
|
|
|
def test_estimate_gamma_from_correlation_chooses_nearest_gamma() -> None:
    """The gamma whose tabulated correlation lies closest to the target is returned."""
    lookup = {0.0: 0.05, 1.0: 0.18, 2.0: 0.29, 3.0: 0.41}

    # 0.31 is 0.02 away from gamma=2.0's value and 0.10 away from gamma=3.0's.
    chosen = estimate_gamma_from_correlation(
        target_correlation=0.31,
        gamma_to_correlation=lookup,
    )

    assert chosen == 2.0
|
|
|
|
def test_simulation_sweep_main_supports_sim2_heatmap_outputs(tmp_path: Path) -> None:
    """The Sim 2 sweep writes results, summary, an advantage table and one heatmap per gamma."""
    sim2_dir = tmp_path / "simulation-sim2"
    argv = [
        "--sim", "2",
        "--output-dir", str(sim2_dir),
        "--seeds", "2",
        "--delta-m-grid", "0.0,0.2",
        "--delta-x-grid", "0.0,1.0",
        "--gamma-grid", "0.0,2.0",
        "--model-type", "logistic_regression",
        "--n-train", "150",
        "--n-cal", "80",
        "--n-test", "120",
        "--d", "10",
    ]
    main(argv)

    results = pd.read_csv(sim2_dir / "sweep_results.csv")
    summary = pd.read_csv(sim2_dir / "sweep_summary.csv")
    advantage = pd.read_csv(sim2_dir / "advantage_summary.csv")

    grid_keys = {"gamma", "delta_m", "delta_x"}
    assert grid_keys | {"method"} <= set(results.columns)
    assert grid_keys | {"method"} <= set(summary.columns)
    assert grid_keys | {"advantage_gap"} <= set(advantage.columns)
    # Heatmap file names encode gamma with "p" in place of the decimal point.
    for gamma_tag in ("0p0", "2p0"):
        assert (sim2_dir / f"sim2_advantage_heatmap_gamma{gamma_tag}.png").exists()
|
|
|
|
def test_simulation_sweep_main_supports_sim3_mechanism_outputs(tmp_path: Path) -> None:
    """The Sim 3 sweep covers all four missingness mechanisms and writes a bar chart."""
    out_dir = tmp_path / "simulation-sim3"
    options = [
        ("--sim", "3"),
        ("--output-dir", str(out_dir)),
        ("--seeds", "2"),
        ("--delta-m", "0.2"),
        ("--delta-x", "0.0"),
        ("--gamma", "2.0"),
        ("--model-type", "logistic_regression"),
        ("--n-train", "150"),
        ("--n-cal", "80"),
        ("--n-test", "120"),
        ("--d", "10"),
    ]
    main([token for flag_and_value in options for token in flag_and_value])

    results = pd.read_csv(out_dir / "sweep_results.csv")
    summary = pd.read_csv(out_dir / "sweep_summary.csv")

    for frame in (results, summary):
        assert {"mechanism", "method"} <= set(frame.columns)
    assert set(summary["mechanism"]) == {"mcar", "mar_severity", "mar_selective", "mnar_extreme"}
    assert (out_dir / "sim3_mechanism_bar.png").exists()
|
|