# misscp/tests/test_gossis_selection_bridge_analysis.py
# Anonymous
# Initial anonymous MissCP release
# 32f5a65
from __future__ import annotations
import numpy as np
import pandas as pd
from sepsis_mcp.gossis_selection_bridge_analysis import (
_group_eta_summary,
_select_eta_oracle_variable,
_select_proxy_variable,
_spearman_summary,
)
def test_group_eta_summary_returns_binary_group_stats() -> None:
    """Check eta ordering and group sizes for two groups of two samples each."""
    # Group 0 holds the two low scores, group 1 the two high scores.
    group_ids = np.array([0, 0, 1, 1], dtype=int)
    scores = np.array([0.1, 0.2, 0.8, 0.9], dtype=float)

    stats = _group_eta_summary(scores, group_ids)

    # The eta statistics must be non-negative and ordered max >= mean >= min.
    assert stats["eta_max"] >= stats["eta_mean"] >= stats["eta_min"] >= 0.0

    # Both groups contain exactly two samples, so min and max size coincide.
    for size_key in ("group_size_min", "group_size_max"):
        assert stats[size_key] == 2
def test_select_proxy_variable_matches_current_tie_breaker() -> None:
    """Pin the variable chosen when the top two candidates tie on gap and imbalance."""
    # "b" and "a" share the same selection gap and imbalance and differ only in
    # minority support (120 vs 110); "c" has a smaller gap.
    tied_fields = {"average_absolute_selection_gap": 0.30, "imbalance": 0.20}
    # Insertion order ("b" before "a") is preserved in case the tie-breaker
    # is sensitive to it.
    candidates = {
        "b": {**tied_fields, "minority_support": 120},
        "a": {**tied_fields, "minority_support": 110},
        "c": {
            "average_absolute_selection_gap": 0.20,
            "imbalance": 0.05,
            "minority_support": 500,
        },
    }

    chosen = _select_proxy_variable(candidates)

    assert chosen == "b"
def test_select_eta_oracle_variable_prefers_higher_eta_then_gap() -> None:
    """Expect "v2" to be selected from three candidate rows."""
    column_names = [
        "variable",
        "eta_max",
        "eta_mean",
        "selection_gap",
        "minority_support",
    ]
    # "v1" and "v2" tie on eta_max (0.50); "v3" has a lower eta_max (0.40).
    # NOTE(review): of the tied pair, "v2" has the smaller selection_gap but the
    # larger eta_mean and minority_support, so which column breaks the tie is
    # not evident from this fixture - confirm against the implementation.
    rows = [
        ("v1", 0.50, 0.20, 0.10, 150),
        ("v2", 0.50, 0.25, 0.05, 200),
        ("v3", 0.40, 0.30, 0.30, 400),
    ]
    candidates = pd.DataFrame(rows, columns=column_names)

    chosen = _select_eta_oracle_variable(candidates)

    assert chosen == "v2"
def test_spearman_summary_reports_pooled_and_seed_level_stats() -> None:
    """Check pair/seed counts and correlations for perfectly monotone data.

    The fixture has two seeds with three rows each; ``eta_max`` increases
    strictly with ``selection_gap`` both within each seed and across all rows,
    so the pooled and the per-seed mean rank correlations are expected to be 1.
    """
    frame = pd.DataFrame(
        [
            {"seed": 0, "selection_gap": 0.1, "eta_max": 0.2},
            {"seed": 0, "selection_gap": 0.2, "eta_max": 0.3},
            {"seed": 0, "selection_gap": 0.3, "eta_max": 0.4},
            {"seed": 1, "selection_gap": 0.4, "eta_max": 0.5},
            {"seed": 1, "selection_gap": 0.5, "eta_max": 0.6},
            {"seed": 1, "selection_gap": 0.6, "eta_max": 0.7},
        ]
    )

    summary = _spearman_summary(frame, x="selection_gap", y="eta_max")

    # Counts are integers, so exact equality is appropriate.
    assert summary["pair_count"] == 6
    assert summary["seed_count"] == 2

    # Correlations are floating-point results; compare with a tolerance so a
    # last-bit rounding difference (e.g. 0.9999999999999999) cannot fail the test.
    assert np.isclose(summary["pooled_rho"], 1.0)
    assert np.isclose(summary["seed_rho_mean"], 1.0)