from __future__ import annotations import numpy as np import pandas as pd from sepsis_mcp.gossis_selection_bridge_analysis import ( _group_eta_summary, _select_eta_oracle_variable, _select_proxy_variable, _spearman_summary, ) def test_group_eta_summary_returns_binary_group_stats() -> None: scores = np.array([0.1, 0.2, 0.8, 0.9], dtype=float) group_ids = np.array([0, 0, 1, 1], dtype=int) summary = _group_eta_summary(scores, group_ids) assert summary["eta_max"] >= summary["eta_mean"] >= summary["eta_min"] >= 0.0 assert summary["group_size_min"] == 2 assert summary["group_size_max"] == 2 def test_select_proxy_variable_matches_current_tie_breaker() -> None: diagnostics = { "b": { "average_absolute_selection_gap": 0.30, "imbalance": 0.20, "minority_support": 120, }, "a": { "average_absolute_selection_gap": 0.30, "imbalance": 0.20, "minority_support": 110, }, "c": { "average_absolute_selection_gap": 0.20, "imbalance": 0.05, "minority_support": 500, }, } assert _select_proxy_variable(diagnostics) == "b" def test_select_eta_oracle_variable_prefers_higher_eta_then_gap() -> None: frame = pd.DataFrame( [ {"variable": "v1", "eta_max": 0.50, "eta_mean": 0.20, "selection_gap": 0.10, "minority_support": 150}, {"variable": "v2", "eta_max": 0.50, "eta_mean": 0.25, "selection_gap": 0.05, "minority_support": 200}, {"variable": "v3", "eta_max": 0.40, "eta_mean": 0.30, "selection_gap": 0.30, "minority_support": 400}, ] ) assert _select_eta_oracle_variable(frame) == "v2" def test_spearman_summary_reports_pooled_and_seed_level_stats() -> None: frame = pd.DataFrame( [ {"seed": 0, "selection_gap": 0.1, "eta_max": 0.2}, {"seed": 0, "selection_gap": 0.2, "eta_max": 0.3}, {"seed": 0, "selection_gap": 0.3, "eta_max": 0.4}, {"seed": 1, "selection_gap": 0.4, "eta_max": 0.5}, {"seed": 1, "selection_gap": 0.5, "eta_max": 0.6}, {"seed": 1, "selection_gap": 0.6, "eta_max": 0.7}, ] ) summary = _spearman_summary(frame, x="selection_gap", y="eta_max") assert summary["pair_count"] == 6 assert summary["seed_count"] == 2 assert summary["pooled_rho"] == 1.0 assert summary["seed_rho_mean"] == 1.0