"""Tests for the ``sepsis_mcp`` simulation helpers and the simulation sweep CLI."""

from __future__ import annotations

import json
from pathlib import Path

import pandas as pd

from sepsis_mcp.simulation import (
    SimulationConfig,
    estimate_gamma_from_correlation,
    generate_simulation_splits,
    run_model_swap_simulation_cell,
    run_simulation_cell,
    select_shifted_missingness_group,
    write_sim1_main_figure,
)
from sepsis_mcp.simulation_sweep import main

# Conformal methods requested from ``run_simulation_cell`` and expected in the
# Sim 1 sweep summary. Shared so the two tests cannot drift apart.
SIM1_METHODS = (
    "standard",
    "mondrian_tilted",
    "predicted_risk_tercile",
    "random_grouping",
    "oracle_grouping",
)

# Methods emitted by the calibrate/deploy model-swap experiment (Task D / Sim 5).
MODEL_SWAP_METHODS = frozenset(
    {
        "missingness_mondrian_matched",
        "missingness_mondrian_swap",
        "risk_mondrian_swap",
    }
)

# Missingness mechanisms expected in the Sim 3 summary.
SIM3_MECHANISMS = frozenset({"mcar", "mar_severity", "mar_selective", "mnar_extreme"})

# Small problem sizes shared by every CLI test so the sweeps stay fast.
_SMALL_SWEEP_SIZE_ARGS = (
    "--n-train",
    "150",
    "--n-cal",
    "80",
    "--n-test",
    "120",
    "--d",
    "10",
)


def _sweep_argv(sim: str, output_dir: Path, *extra_args: str, seeds: int = 2) -> list[str]:
    """Build an argv list for ``simulation_sweep.main``.

    The order is ``--sim``, ``--output-dir``, ``--seeds``, then *extra_args*,
    then the shared small problem sizes.
    """
    return [
        "--sim",
        sim,
        "--output-dir",
        str(output_dir),
        "--seeds",
        str(seeds),
        *extra_args,
        *_SMALL_SWEEP_SIZE_ARGS,
    ]


def test_generate_simulation_splits_returns_expected_shapes_and_labels() -> None:
    """Splits have the configured row/feature counts and binary labels."""
    config = SimulationConfig(
        n_train=120,
        n_cal=60,
        n_test=80,
        d=10,
        random_state=7,
    )

    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=7)

    assert splits.train.complete_features.shape == (120, 10)
    assert splits.calibration.observed_features.shape == (60, 10)
    assert splits.test.missing_mask.shape == (80, 10)
    for split in (splits.train, splits.calibration, splits.test):
        assert set(split.labels.unique()) <= {0, 1}


def test_missingness_shift_increases_test_missingness_on_predictive_features() -> None:
    """A positive ``delta_m`` raises test missingness on feature 0.

    Calibration missingness on that feature should stay near the base rate.
    """
    config = SimulationConfig(
        n_train=250,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=0,
    )

    no_shift = generate_simulation_splits(config, delta_m=0.0, delta_x=0.0, seed=0)
    shifted = generate_simulation_splits(config, delta_m=0.3, delta_x=0.0, seed=0)

    baseline_feature_missing = no_shift.test.missing_mask.iloc[:, 0].mean()
    shifted_feature_missing = shifted.test.missing_mask.iloc[:, 0].mean()
    calibration_feature_missing = shifted.calibration.missing_mask.iloc[:, 0].mean()

    assert shifted_feature_missing > baseline_feature_missing
    assert abs(calibration_feature_missing - config.base_missing_rate) < 0.06
    assert shifted_feature_missing - calibration_feature_missing > 0.12


def test_select_shifted_missingness_group_picks_high_shift_feature() -> None:
    """The chosen grouping variable is a real column and group ids are binary."""
    config = SimulationConfig(
        n_train=180,
        n_cal=90,
        n_test=90,
        d=8,
        random_state=13,
    )
    splits = generate_simulation_splits(config, delta_m=0.25, delta_x=0.0, seed=13)

    grouping = select_shifted_missingness_group(
        splits.calibration.missing_mask,
        splits.test.missing_mask,
        min_missing_fraction=0.05,
        max_missing_fraction=0.95,
        min_group_size=10,
    )

    assert grouping.variable_name in splits.calibration.missing_mask.columns
    assert set(grouping.calibration_group_ids.unique()) <= {0, 1}
    assert set(grouping.test_group_ids.unique()) <= {0, 1}


def test_run_simulation_cell_emits_all_requested_methods() -> None:
    """Every requested method appears in the results with the expected columns."""
    config = SimulationConfig(
        n_train=220,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=5,
    )

    results = run_simulation_cell(
        config,
        delta_m=0.2,
        delta_x=0.0,
        seed=5,
        methods=SIM1_METHODS,
    )

    assert isinstance(results, pd.DataFrame)
    assert set(results["method"]) == set(SIM1_METHODS)
    assert {
        "delta_m",
        "delta_x",
        "seed",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "evaluation_grouping_strategy",
    } <= set(results.columns)
    assert set(results["evaluation_grouping_strategy"]) == {"coverage_gap_variable"}


def test_run_model_swap_simulation_cell_emits_task_d_methods() -> None:
    """A model swap keeps missingness groups identical but changes risk groups."""
    config = SimulationConfig(
        n_train=220,
        n_cal=120,
        n_test=160,
        d=10,
        random_state=9,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
    )

    results = run_model_swap_simulation_cell(
        config,
        delta_m=0.2,
        delta_x=0.0,
        seed=9,
        calibrate_model_type="logistic_regression",
        deploy_model_type="xgboost",
    )

    assert set(results["method"]) == MODEL_SWAP_METHODS
    assert {
        "delta_m",
        "delta_x",
        "gamma",
        "seed",
        "calibrate_model",
        "deploy_model",
        "empirical_coverage",
        "average_set_size",
        "max_group_coverage_gap",
        "test_group_overlap_vs_matched",
    } <= set(results.columns)

    # Missingness-defined groups should not depend on which model is deployed.
    missingness_rows = results[results["method"].str.contains("missingness")]
    assert (missingness_rows["test_group_overlap_vs_matched"] == 1.0).all()

    # Risk-defined groups are model-dependent, so a swap must move some rows.
    risk_row = results[results["method"] == "risk_mondrian_swap"].iloc[0]
    assert float(risk_row["test_group_overlap_vs_matched"]) < 1.0


def test_mar_missingness_mechanism_couples_missingness_with_severity() -> None:
    """Under ``mar_severity``, higher-severity test rows have more missing features."""
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_severity",
        mar_gamma=2.0,
        random_state=11,
    )

    splits = generate_simulation_splits(config, delta_m=0.3, delta_x=0.0, seed=11)
    test_complete = splits.test.complete_features
    # Linear severity score over the first four features. NOTE(review): the
    # weights are assumed to track the generator's outcome model — confirm.
    severity = (
        (test_complete.iloc[:, 0] * 5.0)
        + (test_complete.iloc[:, 1] * 3.0)
        + (test_complete.iloc[:, 2] * 0.5)
        + (test_complete.iloc[:, 3] * 0.5)
    )
    missing_rate = splits.test.missing_mask.mean(axis=1)
    is_high_severity = severity >= severity.median()

    high_severity_missing = float(missing_rate[is_high_severity].mean())
    low_severity_missing = float(missing_rate[~is_high_severity].mean())

    assert high_severity_missing > low_severity_missing
    assert high_severity_missing - low_severity_missing > 0.05


def test_mar_selective_changes_predictive_features_more_than_noise_features() -> None:
    """Under ``mar_selective``, features 0-2 lose more values than features 5+."""
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mar_selective",
        mar_gamma=2.0,
        random_state=17,
    )

    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=17)
    predictive_missing = float(splits.test.missing_mask.iloc[:, :3].mean().mean())
    noise_missing = float(splits.test.missing_mask.iloc[:, 5:].mean().mean())

    assert predictive_missing > noise_missing
    assert predictive_missing - noise_missing > 0.05


def test_mnar_extreme_increases_missingness_for_extreme_feature_values() -> None:
    """Under ``mnar_extreme``, feature 0 is missing more often when |x0| > 1.5."""
    config = SimulationConfig(
        n_train=300,
        n_cal=150,
        n_test=300,
        d=10,
        missingness_mechanism="mnar_extreme",
        random_state=19,
    )

    splits = generate_simulation_splits(config, delta_m=0.2, delta_x=0.0, seed=19)
    x0 = splits.test.complete_features.iloc[:, 0].abs()
    row_missing = splits.test.missing_mask.iloc[:, 0]
    is_extreme = x0 > 1.5

    # The mean of an empty selection is NaN, which would make the comparison
    # below fail with an unhelpful message; fail explicitly instead.
    assert is_extreme.any(), "no extreme feature values were sampled"
    assert (~is_extreme).any(), "no typical feature values were sampled"

    extreme_missing = float(row_missing[is_extreme].mean())
    typical_missing = float(row_missing[~is_extreme].mean())

    assert extreme_missing > typical_missing


def test_simulation_sweep_main_writes_results_summary_and_plot(tmp_path: Path) -> None:
    """Sim 1 writes per-cell results, a method summary, its config, and a plot."""
    output_dir = tmp_path / "simulation-sweep"
    n_seeds = 2
    delta_m_grid = ("0.0", "0.2")

    main(
        _sweep_argv(
            "1",
            output_dir,
            "--delta-m-grid",
            ",".join(delta_m_grid),
            "--model-type",
            "logistic_regression",
            seeds=n_seeds,
        )
    )

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")
    config = json.loads((output_dir / "simulation_config.json").read_text(encoding="utf-8"))

    # Expected row count: seeds x delta_m grid values x methods.
    assert results.shape[0] == n_seeds * len(delta_m_grid) * len(SIM1_METHODS)
    assert set(summary["method"]) == set(SIM1_METHODS)
    assert config["model_type"] == "logistic_regression"
    assert (output_dir / "sim1_gap_vs_delta_m.png").exists()


def test_simulation_sweep_main_supports_sim5(tmp_path: Path) -> None:
    """Sim 5 emits only the model-swap methods and writes its summary figure."""
    output_dir = tmp_path / "sim5"

    main(
        _sweep_argv(
            "5",
            output_dir,
            "--delta-m-grid",
            "0.0,0.2",
            "--gamma-grid",
            "0.0,2.0",
        )
    )

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")

    assert set(results["method"]) == MODEL_SWAP_METHODS
    assert set(summary["method"]) == MODEL_SWAP_METHODS
    assert (output_dir / "sim5_summary_figure.png").exists()


def test_write_sim1_main_figure_writes_two_panel_png(tmp_path: Path) -> None:
    """The Sim 1 main figure is written for a two-gamma summary table."""
    summary = pd.DataFrame(
        [
            {"gamma": 0.0, "delta_m": 0.0, "method": "standard", "max_group_coverage_gap_mean": 0.03},
            {"gamma": 0.0, "delta_m": 0.2, "method": "standard", "max_group_coverage_gap_mean": 0.04},
            {"gamma": 0.0, "delta_m": 0.0, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.02},
            {"gamma": 0.0, "delta_m": 0.2, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.02},
            {"gamma": 2.0, "delta_m": 0.0, "method": "standard", "max_group_coverage_gap_mean": 0.06},
            {"gamma": 2.0, "delta_m": 0.2, "method": "standard", "max_group_coverage_gap_mean": 0.08},
            {"gamma": 2.0, "delta_m": 0.0, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.05},
            {"gamma": 2.0, "delta_m": 0.2, "method": "mondrian_tilted", "max_group_coverage_gap_mean": 0.05},
        ]
    )
    output_path = tmp_path / "sim1_main.png"

    write_sim1_main_figure(summary, output_path=output_path, gamma_values=(0.0, 2.0))

    assert output_path.exists()


def test_estimate_gamma_from_correlation_chooses_nearest_gamma() -> None:
    """The gamma whose calibrated correlation is closest to the target wins."""
    estimated = estimate_gamma_from_correlation(
        target_correlation=0.31,
        gamma_to_correlation={
            0.0: 0.05,
            1.0: 0.18,
            2.0: 0.29,
            3.0: 0.41,
        },
    )

    assert estimated == 2.0


def test_simulation_sweep_main_supports_sim2_heatmap_outputs(tmp_path: Path) -> None:
    """Sim 2 writes grid results, an advantage table, and one heatmap per gamma."""
    output_dir = tmp_path / "simulation-sim2"

    main(
        _sweep_argv(
            "2",
            output_dir,
            "--delta-m-grid",
            "0.0,0.2",
            "--delta-x-grid",
            "0.0,1.0",
            "--gamma-grid",
            "0.0,2.0",
            "--model-type",
            "logistic_regression",
        )
    )

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")
    advantage = pd.read_csv(output_dir / "advantage_summary.csv")

    assert {"gamma", "delta_m", "delta_x", "method"} <= set(results.columns)
    assert {"gamma", "delta_m", "delta_x", "method"} <= set(summary.columns)
    assert {"gamma", "delta_m", "delta_x", "advantage_gap"} <= set(advantage.columns)
    assert (output_dir / "sim2_advantage_heatmap_gamma0p0.png").exists()
    assert (output_dir / "sim2_advantage_heatmap_gamma2p0.png").exists()


def test_simulation_sweep_main_supports_sim3_mechanism_outputs(tmp_path: Path) -> None:
    """Sim 3 compares all missingness mechanisms and writes a bar chart."""
    output_dir = tmp_path / "simulation-sim3"

    main(
        _sweep_argv(
            "3",
            output_dir,
            "--delta-m",
            "0.2",
            "--delta-x",
            "0.0",
            "--gamma",
            "2.0",
            "--model-type",
            "logistic_regression",
        )
    )

    results = pd.read_csv(output_dir / "sweep_results.csv")
    summary = pd.read_csv(output_dir / "sweep_summary.csv")

    assert {"mechanism", "method"} <= set(results.columns)
    assert {"mechanism", "method"} <= set(summary.columns)
    assert set(summary["mechanism"]) == SIM3_MECHANISMS
    assert (output_dir / "sim3_mechanism_bar.png").exists()