File size: 7,593 Bytes
32f5a65 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | from __future__ import annotations
import importlib
import numpy as np
import pandas as pd
import pytest
def test_gibbs_missingness_feature_transform_reuses_calibration_bin_edges() -> None:
    """Bin edges are learned once on the calibration frame and reused at test time.

    ``quantile_grid`` has two cut points, so two edges are expected. The edges
    are snapshotted immediately after fitting and must be unchanged after
    transforming both the calibration frame and a test frame whose missing-rate
    distribution differs; a ``transform`` that refits and overwrites
    ``bin_edges_`` from whatever frame it is given would fail here.
    """
    from sepsis_mcp.external_baselines import fit_gibbs_missingness_feature_transform

    calibration_frame = pd.DataFrame(
        {
            "global_missing_rate": [0.10, 0.20, 0.80, 0.90],
        }
    )
    test_frame = pd.DataFrame(
        {
            "global_missing_rate": [0.15, 0.85],
        }
    )
    transform = fit_gibbs_missingness_feature_transform(
        calibration_frame=calibration_frame,
        calibration_positive_probabilities=[0.2, 0.3, 0.7, 0.8],
        quantile_grid=(1 / 3, 2 / 3),
    )
    # Take a real copy so an in-place refit inside ``transform`` cannot hide itself.
    fitted_edges = np.array(transform.bin_edges_, copy=True)

    calibration_features = transform.transform(
        calibration_frame,
        positive_probabilities=[0.2, 0.3, 0.7, 0.8],
    )
    test_features = transform.transform(
        test_frame,
        positive_probabilities=[0.25, 0.75],
    )

    assert transform.feature_names_[0] == "model_positive_probability"
    assert transform.bin_edges_.shape == (2,)
    # Edges derived from an ascending quantile grid must be non-decreasing.
    assert np.all(np.diff(fitted_edges) >= 0)
    # The edges must not be re-derived from the frames passed to ``transform``.
    assert np.array_equal(transform.bin_edges_, fitted_edges)
    assert calibration_features.shape == (4, len(transform.feature_names_))
    assert test_features.shape == (2, len(transform.feature_names_))
    # Column 1 carries the test frame's raw global missing rate through unchanged.
    assert np.allclose(test_features[:, 1], [0.15, 0.85])
def test_gibbs_general_feature_transform_emits_named_columns() -> None:
    """The general transform prepends the model probability to the requested columns.

    ``global_missing_rate`` is present in the frame but absent from
    ``candidate_columns``, so it must not show up among the emitted features.
    """
    from sepsis_mcp.external_baselines import fit_gibbs_general_feature_transform

    frame = pd.DataFrame(
        {
            "age": [40, 70, 50],
            "severity": [1.0, 2.0, 1.5],
            "global_missing_rate": [0.1, 0.4, 0.2],
        }
    )

    fitted = fit_gibbs_general_feature_transform(
        calibration_frame=frame,
        calibration_positive_probabilities=[0.2, 0.7, 0.4],
        candidate_columns=["age", "severity"],
    )
    matrix = fitted.transform(frame, positive_probabilities=[0.2, 0.7, 0.4])

    expected_names = ["model_positive_probability", "age", "severity"]
    assert fitted.feature_names_ == expected_names
    assert matrix.shape == (3, 3)
def test_gibbs_general_feature_transform_imputes_non_finite_values_from_calibration() -> None:
    """NaN and +/-inf entries are replaced using fill values learned on calibration data.

    The finite calibration ``age`` values are 40 and 60, and the finite
    ``severity`` values are 1.0 and 2.0. The expected test-time fills
    (50.0 and 1.5) are the midpoints of those finite values.
    """
    from sepsis_mcp.external_baselines import fit_gibbs_general_feature_transform

    calibration = pd.DataFrame(
        {"age": [40.0, np.nan, 60.0], "severity": [1.0, 2.0, np.inf]}
    )
    held_out = pd.DataFrame(
        {"age": [np.nan, 55.0], "severity": [np.nan, -np.inf]}
    )

    fitted = fit_gibbs_general_feature_transform(
        calibration_frame=calibration,
        calibration_positive_probabilities=[0.2, 0.7, 0.4],
        candidate_columns=["age", "severity"],
    )
    calibration_matrix = fitted.transform(
        calibration, positive_probabilities=[0.2, 0.7, 0.4]
    )
    held_out_matrix = fitted.transform(held_out, positive_probabilities=[0.3, 0.6])

    # No non-finite value may survive in either output.
    for matrix in (calibration_matrix, held_out_matrix):
        assert np.isfinite(matrix).all()

    # Columns 1 and 2 follow ``candidate_columns`` order: age, then severity.
    assert np.allclose(held_out_matrix[:, 1], [50.0, 55.0])
    assert np.allclose(held_out_matrix[:, 2], [1.5, 1.5])
def test_learned_partition_classifier_assigns_groups_and_predicts_sets() -> None:
    """A two-leaf partition separates low/high ``signal`` rows and yields singleton sets."""
    from sepsis_mcp.external_baselines import LearnedPartitionConformalClassifier

    def signal_frame(*values: float) -> pd.DataFrame:
        # Single-column feature frame used for every split of the data.
        return pd.DataFrame({"signal": list(values)})

    selection = signal_frame(0.0, 0.1, 0.9, 1.0)
    calibration = signal_frame(0.05, 0.95, 0.15, 0.85)
    held_out = signal_frame(0.02, 0.98)

    model = LearnedPartitionConformalClassifier(
        alpha=0.4,
        min_group_size=2,
        max_leaf_nodes=2,
        random_state=0,
    )
    model = model.fit(
        selection_features=selection,
        selection_labels=[0, 0, 1, 1],
        selection_positive_probabilities=[0.1, 0.2, 0.8, 0.9],
        calibration_features=calibration,
        calibration_labels=[0, 1, 0, 1],
        calibration_positive_probabilities=[0.15, 0.85, 0.25, 0.75],
    )

    groups = model.predict_groups(held_out)
    label_sets = model.predict_sets(
        positive_probabilities=[0.1, 0.9],
        test_features=held_out,
    )
    summary = model.diagnostics_summary()

    assert groups.shape == (2,)
    # The two held-out rows must land in different groups.
    assert len(set(groups.tolist())) == 2
    assert label_sets == [{0}, {1}]
    assert summary["leaf_count"] == 2
def test_gibbs_classifier_raises_actionable_error_when_dependency_missing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``fit`` surfaces an ImportError mentioning ``conditionalconformal`` when it is absent."""
    from sepsis_mcp.external_baselines import GibbsConditionalConformalClassifier

    real_import_module = importlib.import_module

    def import_without_condconf(name: str, package: str | None = None):
        # Simulate the optional dependency being uninstalled; defer all other imports.
        if name != "conditionalconformal":
            return real_import_module(name, package)
        raise ModuleNotFoundError("No module named 'conditionalconformal'")

    monkeypatch.setattr(importlib, "import_module", import_without_condconf)

    classifier = GibbsConditionalConformalClassifier(alpha=0.2)
    features = np.asarray([[0.1, 0.2], [0.9, 0.8]], dtype=float)
    with pytest.raises(ImportError, match="conditionalconformal"):
        classifier.fit(
            calibration_labels=[0, 1],
            calibration_positive_probabilities=[0.1, 0.9],
            calibration_features=features,
        )
def test_gibbs_classifier_uses_finite_basis_condconf_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The classifier drives ``CondConf`` through its finite-basis path, one row at a time.

    A stand-in ``conditionalconformal`` module records how ``CondConf`` is
    constructed, set up and queried. The assertions require:

    - no custom quantile function and empty ``infinite_params``;
    - the calibration problem is set up with both calibration rows;
    - one ``predict`` call per test row at level 0.8 (= 1 - alpha for alpha=0.2);
    - non-randomized, non-exact prediction.
    """
    from sepsis_mcp.external_baselines import GibbsConditionalConformalClassifier

    captured: dict[str, object] = {}

    class FakeCondConf:
        """Minimal ``CondConf`` double that records its constructor and call arguments."""

        # Parameter names presumably mirror the real ``CondConf`` signature and
        # the code under test may pass them by keyword, so they are kept as-is;
        # pep8-naming's N803 (argument name should be lowercase) is suppressed.
        def __init__(self, score_fn, Phi_fn, quantile_fn, infinite_params):  # noqa: N803
            captured["quantile_fn"] = quantile_fn
            captured["infinite_params"] = infinite_params
            self._score_fn = score_fn
            self._phi_fn = Phi_fn

        def setup_problem(self, X, Y):  # noqa: N803
            captured["setup_shape"] = np.asarray(X).shape

        def predict(self, quantile, x_test, score_inv_fn, **kwargs):
            captured.setdefault("predict_quantiles", []).append(float(quantile))
            captured.setdefault("predict_shapes", []).append(np.asarray(x_test).shape)
            captured.setdefault("predict_kwargs", []).append(kwargs)
            # Always invert a fixed score cutoff of 0.2.
            return score_inv_fn(0.2, np.asarray(x_test))

    original_import_module = importlib.import_module

    def fake_import_module(name: str, package: str | None = None):
        if name == "conditionalconformal":
            return type("FakeModule", (), {"CondConf": FakeCondConf})
        return original_import_module(name, package)

    monkeypatch.setattr(importlib, "import_module", fake_import_module)

    classifier = GibbsConditionalConformalClassifier(alpha=0.2)
    classifier.fit(
        calibration_labels=[0, 1],
        calibration_positive_probabilities=[0.1, 0.9],
        calibration_features=np.asarray([[0.1, 0.2], [0.9, 0.8]], dtype=float),
    )
    thresholds = classifier.thresholds_for_test_features(
        np.asarray([[0.2, 0.3], [0.8, 0.7]], dtype=float)
    )

    assert captured["quantile_fn"] is None
    assert captured["infinite_params"] == {}
    # setup_problem was previously recorded but never checked.
    setup_shape = captured["setup_shape"]
    assert isinstance(setup_shape, tuple) and setup_shape[0] == 2
    assert captured["predict_quantiles"] == pytest.approx([0.8, 0.8])
    assert captured["predict_shapes"] == [(1, 2), (1, 2)]
    assert captured["predict_kwargs"] == [
        {"randomize": False, "exact": False},
        {"randomize": False, "exact": False},
    ]
    # The fake's 0.2 cutoff should come back unchanged as every row's threshold.
    assert np.allclose(thresholds, [0.2, 0.2])
|