misscp / tests /test_conformal.py
Anonymous
Initial anonymous MissCP release
32f5a65
from __future__ import annotations
import pytest
import sepsis_mcp.conformal as conformal_module
from sepsis_mcp.conformal import (
CPMDAExactClassifier,
ConformalizedQuantileRegressor,
GroupedPropensityWeightedConformalClassifier,
LabelConditionalConformalClassifier,
LabelConditionalMissingnessAwareConformalClassifier,
MissingnessAwareConformalClassifier,
MissingnessAwareConformalizedQuantileRegressor,
MondrianTiltedConformalClassifier,
PropensityWeightedConformalClassifier,
ShrunkWeightedMissingnessConformalClassifier,
SplitConformalClassifier,
WeightedMissingnessGroupConformalClassifier,
WeightedMissingnessConformalClassifier,
binary_nonconformity_scores,
)
from sepsis_mcp.metrics import average_set_size, empirical_coverage, abstention_rate, prediction_set_size_breakdown
def test_binary_nonconformity_scores_follow_true_label_definition() -> None:
    """Score equals one minus the probability assigned to the true label."""
    labels = [0, 1, 1, 0]
    probs = [0.1, 0.8, 0.3, 0.9]
    # label 0 -> p, label 1 -> 1 - p
    expected = [0.1, 0.2, 0.7, 0.9]
    assert binary_nonconformity_scores(labels, probs).tolist() == pytest.approx(expected)
def test_split_conformal_prediction_sets_use_single_global_threshold() -> None:
    """One calibrated threshold is shared by every test point."""
    split_cp = SplitConformalClassifier(alpha=0.4)
    labels = [0, 1, 1, 0]
    probs = [0.1, 0.8, 0.3, 0.9]
    split_cp.fit(
        calibration_labels=labels,
        calibration_positive_probabilities=probs,
    )
    sets = split_cp.predict_sets([0.8, 0.55, 0.2])
    # Calibration scores are [0.1, 0.2, 0.7, 0.9]; the expected cut-off is 0.7.
    assert split_cp.threshold_ == pytest.approx(0.7)
    assert sets == [{1}, {0, 1}, {0}]
def test_split_conformal_uses_infinite_threshold_when_nominal_rank_exceeds_sample() -> None:
    """Two calibration points at alpha=0.1 are too few, so the threshold is +inf."""
    split_cp = SplitConformalClassifier(alpha=0.1)
    split_cp.fit(
        calibration_labels=[0, 1],
        calibration_positive_probabilities=[0.1, 0.8],
    )
    sets = split_cp.predict_sets([0.05, 0.95])
    assert split_cp.threshold_ == float("inf")
    # An infinite threshold admits both labels regardless of the input.
    both_labels = {0, 1}
    assert sets == [both_labels, both_labels]
def test_missingness_aware_conformal_uses_group_specific_thresholds() -> None:
    """Missing-rate groups get their own thresholds; test points use their group's."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    sets = model.predict_sets(
        positive_probabilities=[0.9, 0.45, 0.45],
        missing_rates=[0.08, 0.5, 0.9],
    )
    assert model.global_threshold_ == pytest.approx(0.55)
    for group_id, expected in {0: 0.15, 1: 0.55, 2: 0.9}.items():
        assert model.group_thresholds_[group_id] == pytest.approx(expected)
    assert sets == [{1}, {0, 1}, {0, 1}]
def test_missingness_aware_conformal_falls_back_to_global_threshold_for_small_groups() -> None:
    """With min_group_size=3 no group qualifies, so no per-group thresholds are kept."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=3)
    model.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    sets = model.predict_sets(positive_probabilities=[0.45], missing_rates=[0.08])
    assert model.group_thresholds_ == {}
    assert sets == [{0, 1}]
def test_missingness_aware_conformal_accepts_explicit_calibration_group_ids() -> None:
    """Calibration groups may be given directly as ids instead of missing rates."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_group_ids=[0, 0, 1, 1, 2, 2],
    )
    assert model.global_threshold_ == pytest.approx(0.55)
    expected_by_group = {0: 0.15, 1: 0.55, 2: 0.9}
    for group_id, expected in expected_by_group.items():
        assert model.group_thresholds_[group_id] == pytest.approx(expected)
def test_missingness_aware_conformal_accepts_explicit_test_group_ids() -> None:
    """Groups fitted from missing rates can be targeted by explicit test group ids."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    test_probs = [0.9, 0.45, 0.45]
    sets = model.predict_sets(positive_probabilities=test_probs, test_group_ids=[0, 1, 2])
    assert sets == [{1}, {0, 1}, {0, 1}]
def test_mondrian_tilted_with_zero_shrinkage_matches_pure_mondrian() -> None:
    """With shrinkage_lambda=0 the tilted classifier reproduces pure Mondrian sets.

    The Mondrian sets are also pinned to concrete values. Otherwise the equality
    check would pass vacuously if both classifiers were wrong in the same way.
    """
    calibration_labels = [0, 1, 1, 0, 1, 0]
    calibration_probabilities = [0.1, 0.85, 0.8, 0.55, 0.35, 0.9]
    calibration_missing_rates = [0.05, 0.1, 0.2, 0.7, 0.8, 0.95]
    calibration_group_ids = [0, 0, 0, 1, 1, 1]
    test_probabilities = [0.9, 0.45]
    test_missing_rates = [0.08, 0.85]
    test_group_ids = [0, 1]
    mondrian = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2).fit(
        calibration_labels=calibration_labels,
        calibration_positive_probabilities=calibration_probabilities,
        calibration_group_ids=calibration_group_ids,
    )
    tilted = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=2,
        shrinkage_lambda=0.0,
        bandwidth=0.1,
    ).fit(
        calibration_labels=calibration_labels,
        calibration_positive_probabilities=calibration_probabilities,
        calibration_missing_rates=calibration_missing_rates,
        calibration_group_ids=calibration_group_ids,
    )
    tilted_sets = tilted.predict_sets(
        positive_probabilities=test_probabilities,
        test_missing_rates=test_missing_rates,
        test_group_ids=test_group_ids,
    )
    mondrian_sets = mondrian.predict_sets(
        positive_probabilities=test_probabilities,
        test_group_ids=test_group_ids,
    )
    assert tilted_sets == mondrian_sets
    # Group 0 scores are [0.1, 0.15, 0.2] and group 1 scores are [0.55, 0.65, 0.9].
    # The expected sets assume per-group thresholds of 0.15 and 0.65
    # (the 2nd-smallest score in each group of three at alpha=0.5).
    assert mondrian_sets == [{1}, {0, 1}]
def test_mondrian_tilted_falls_back_to_global_threshold_for_small_groups() -> None:
    """Groups of three are below min_group_size=4, so every point gets the global threshold."""
    tilted = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=4,
        shrinkage_lambda=0.7,
        bandwidth=0.1,
    ).fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        calibration_group_ids=[0, 0, 0, 1, 1, 1],
    )
    per_point = tilted.thresholds_for_test_points(
        test_missing_rates=[0.08, 0.85],
        test_group_ids=[0, 1],
    )
    expected = [tilted.global_threshold_] * 2
    assert per_point.tolist() == pytest.approx(expected)
def test_mondrian_tilted_accepts_explicit_test_group_ids() -> None:
    """predict_sets accepts explicit test group ids and yields one valid set per point."""
    calibrator = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=2,
        shrinkage_lambda=0.5,
        bandwidth=0.1,
    ).fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        calibration_group_ids=[0, 0, 0, 1, 1, 1],
    )
    prediction_sets = calibrator.predict_sets(
        positive_probabilities=[0.9, 0.45],
        test_missing_rates=[0.08, 0.85],
        test_group_ids=[0, 1],
    )
    assert len(prediction_sets) == 2
    assert all(isinstance(prediction_set, set) for prediction_set in prediction_sets)
    # Binary classifier: a set may contain only the labels 0 and 1.
    assert all(prediction_set <= {0, 1} for prediction_set in prediction_sets)
def test_cp_mda_exact_uses_exact_mask_signature_matches() -> None:
    """Test masks with enough exact top-k signature matches do not fall back."""
    calibration_masks = [
        [0, 0, 1],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 0],
    ]
    cp_mda = CPMDAExactClassifier(alpha=0.4, top_k_features=2, min_match=2).fit(
        calibration_labels=[0, 1, 0, 1],
        calibration_positive_probabilities=[0.1, 0.8, 0.2, 0.9],
        calibration_masks=calibration_masks,
        feature_names=["a", "b", "c"],
    )
    sets = cp_mda.predict_sets(
        positive_probabilities=[0.15, 0.85],
        test_masks=[[0, 0, 1], [1, 1, 0]],
    )
    summary = cp_mda.diagnostics_summary()
    # On features a and b, each test mask agrees with exactly two calibration rows.
    assert cp_mda.selected_feature_names_ == ["a", "b"]
    assert cp_mda.last_match_counts_.tolist() == [2, 2]
    assert summary["fallback_rate"] == pytest.approx(0.0)
    assert sets == [{0}, {1}]
def test_weighted_missingness_group_conformal_reweights_binary_indicator() -> None:
    """Reweighted threshold (0.15) differs from the unweighted global one (0.8)."""
    weighted = WeightedMissingnessGroupConformalClassifier(alpha=0.5).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.90, 0.85],
        calibration_group_ids=[0, 0, 1, 1],
    )
    # NOTE(review): test_group_ids has four entries but there are only two
    # probabilities. It appears to describe the test population's group mix
    # (3 of 4 in group 1) rather than per-point groups; confirm against the
    # implementation.
    sets = weighted.predict_sets(
        positive_probabilities=[0.05, 0.95],
        test_group_ids=[1, 1, 1, 0],
    )
    assert weighted.global_threshold_ == pytest.approx(0.8)
    assert weighted.threshold_ == pytest.approx(0.15)
    assert sets == [{0}, {1}]
def test_kernel_smoothed_missingness_h_infinity_matches_standard_threshold() -> None:
    """With bandwidth=inf every test mask receives the global threshold (0.7)."""
    # Looked up dynamically so the rest of this module still imports if the class
    # is absent; the assertion message makes that failure self-explanatory.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not define KernelSmoothedMissingnessConformalClassifier"
    )
    calibrator = classifier_cls(alpha=0.4, bandwidth=float("inf")).fit(
        calibration_labels=[0, 1, 1, 0],
        calibration_positive_probabilities=[0.1, 0.8, 0.3, 0.9],
        calibration_masks=[[0], [0], [1], [1]],
    )
    assert calibrator.global_threshold_ == pytest.approx(0.7)
    assert calibrator.thresholds_for_test_masks([[0], [1]]).tolist() == pytest.approx([0.7, 0.7])
def test_kernel_smoothed_missingness_h_zero_uses_only_exact_mask_matches() -> None:
    """With bandwidth=0 each test mask is calibrated only on identical masks.

    Mask [1] sees the scores [0.1, 0.15]; mask [0] sees the scores [0.8, 0.9].
    """
    # Looked up dynamically so the rest of this module still imports if the class
    # is absent; the assertion message makes that failure self-explanatory.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not define KernelSmoothedMissingnessConformalClassifier"
    )
    calibrator = classifier_cls(alpha=0.5, bandwidth=0.0).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.90, 0.85],
        calibration_masks=[[0], [0], [1], [1]],
    )
    assert calibrator.thresholds_for_test_masks([[1], [0]]).tolist() == pytest.approx([0.15, 0.9])
def test_kernel_smoothed_missingness_intermediate_bandwidth_softly_downweights_far_masks() -> None:
    """A finite, nonzero bandwidth gives an in-between threshold and a fractional n_eff."""
    # Looked up dynamically so the rest of this module still imports if the class
    # is absent; the assertion message makes that failure self-explanatory.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not define KernelSmoothedMissingnessConformalClassifier"
    )
    calibrator = classifier_cls(alpha=0.4, bandwidth=1.0).fit(
        calibration_labels=[0, 0, 0, 0, 0],
        calibration_positive_probabilities=[0.95, 0.80, 0.60, 0.20, 0.10],
        calibration_masks=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 1]],
    )
    threshold = calibrator.thresholds_for_test_masks([[1, 1]])[0]
    n_eff = calibrator.effective_sample_size([[1, 1]])[0]
    assert 0.2 < threshold < 0.8
    # More than the two exact matches contribute, but not all five at full weight.
    assert 1.0 < n_eff < 5.0
def test_prediction_set_size_breakdown_reports_binary_set_fractions() -> None:
    """Fractions of empty, {0}, {1} and {0, 1} sets among five prediction sets."""
    sets = [set(), {0}, {1}, {0, 1}, {0}]
    expected = {
        "empty_frac": 1 / 5,
        "singleton_0_frac": 2 / 5,
        "singleton_1_frac": 1 / 5,
        "full_frac": 1 / 5,
    }
    assert prediction_set_size_breakdown(sets) == pytest.approx(expected)
def test_cp_mda_exact_falls_back_to_global_threshold_when_matches_are_too_small() -> None:
    """Two exact matches are below min_match=3, so the point falls back to global."""
    calibration_masks = [
        [0, 0, 1],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 0],
    ]
    cp_mda = CPMDAExactClassifier(alpha=0.4, top_k_features=2, min_match=3).fit(
        calibration_labels=[0, 1, 0, 1],
        calibration_positive_probabilities=[0.1, 0.8, 0.2, 0.9],
        calibration_masks=calibration_masks,
        feature_names=["a", "b", "c"],
    )
    sets = cp_mda.predict_sets(positive_probabilities=[0.55], test_masks=[[0, 0, 1]])
    summary = cp_mda.diagnostics_summary()
    assert cp_mda.last_fallback_mask_.tolist() == [True]
    assert summary["mean_match_count"] == pytest.approx(2.0)
    assert summary["fallback_rate"] == pytest.approx(1.0)
    assert sets == [set()]
def test_missingness_aware_conformal_rejects_both_fit_group_inputs() -> None:
    """Supplying both missing rates and group ids at fit time is ambiguous and must raise."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    fit_kwargs = dict(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        calibration_group_ids=[0, 0, 1, 1, 2, 2],
    )
    expected_message = "Provide exactly one of calibration_missing_rates or calibration_group_ids"
    with pytest.raises(ValueError, match=expected_message):
        model.fit(**fit_kwargs)
def test_missingness_aware_conformal_rejects_missing_rate_prediction_in_explicit_group_mode() -> None:
    """After fitting on explicit group ids, missing-rate-based assignment is refused."""
    model = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_group_ids=[0, 0, 1, 1, 2, 2],
    )
    expected_message = "explicit-group mode requires explicit group ids at prediction time"
    with pytest.raises(ValueError, match=expected_message):
        model.assign_groups([0.08, 0.5, 0.9])
    with pytest.raises(ValueError, match=expected_message):
        model.predict_sets(
            positive_probabilities=[0.9, 0.45, 0.45],
            missing_rates=[0.08, 0.5, 0.9],
        )
def test_label_conditional_conformal_uses_label_specific_thresholds() -> None:
    """Labels 0 and 1 each receive their own calibrated threshold."""
    labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
    probs = [0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35]
    model = LabelConditionalConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(calibration_labels=labels, calibration_positive_probabilities=probs)
    sets = model.predict_sets([0.18, 0.53, 0.72])
    assert model.global_threshold_ == pytest.approx(0.45)
    assert model.label_thresholds_[0] == pytest.approx(0.45)
    assert model.label_thresholds_[1] == pytest.approx(0.5)
    assert sets == [{0}, {1}, {1}]
def test_label_conditional_conformal_falls_back_to_global_threshold() -> None:
    """Each label has six calibration points, below min_group_size=7, so none is kept."""
    labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
    probs = [0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35]
    model = LabelConditionalConformalClassifier(alpha=0.5, min_group_size=7)
    model.fit(calibration_labels=labels, calibration_positive_probabilities=probs)
    sets = model.predict_sets([0.45, 0.55])
    assert model.label_thresholds_ == {}
    assert model.global_threshold_ == pytest.approx(0.45)
    assert sets == [{0}, {1}]
def test_label_conditional_missingness_aware_accepts_explicit_calibration_group_ids() -> None:
    """Explicit group ids produce per-(label, group) thresholds alongside per-label ones."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_group_ids=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    )
    assert model.global_threshold_ == pytest.approx(0.45)
    assert model.label_thresholds_[0] == pytest.approx(0.45)
    assert model.label_thresholds_[1] == pytest.approx(0.5)
    # Keys are (label, group).
    expected_cells = {
        (0, 0): 0.2,
        (1, 0): 0.2,
        (0, 1): 0.45,
        (1, 1): 0.5,
        (0, 2): 0.8,
        (1, 2): 0.7,
    }
    for cell, expected in expected_cells.items():
        assert model.label_group_thresholds_[cell] == pytest.approx(expected)
def test_label_conditional_missingness_aware_accepts_explicit_test_group_ids() -> None:
    """Groups fitted from missing rates can be addressed by explicit test group ids."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )
    sets = model.predict_sets(positive_probabilities=[0.18, 0.53, 0.72], test_group_ids=[0, 1, 2])
    assert sets == [{0}, {1}, {0, 1}]
def test_label_conditional_missingness_aware_rejects_both_fit_group_inputs() -> None:
    """Passing both missing rates and group ids to fit must raise ValueError."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    fit_kwargs = dict(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
        calibration_group_ids=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    )
    expected_message = "Provide exactly one of calibration_missing_rates or calibration_group_ids"
    with pytest.raises(ValueError, match=expected_message):
        model.fit(**fit_kwargs)
def test_label_conditional_missingness_aware_rejects_missing_rate_prediction_in_explicit_group_mode() -> None:
    """Once fitted on explicit group ids, prediction by missing rate is refused."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_group_ids=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    )
    expected_message = "explicit-group mode requires explicit group ids at prediction time"
    with pytest.raises(ValueError, match=expected_message):
        model.assign_groups([0.015, 0.515, 0.915])
    with pytest.raises(ValueError, match=expected_message):
        model.predict_sets(
            positive_probabilities=[0.18, 0.53, 0.72],
            missing_rates=[0.015, 0.515, 0.915],
        )
def test_conformalized_quantile_regressor_expands_intervals_with_global_threshold() -> None:
    """One calibrated margin is subtracted from lower and added to upper quantiles."""
    cqr = ConformalizedQuantileRegressor(alpha=0.25)
    cqr.fit(
        calibration_targets=[1.0, 1.0, 1.0, 1.0],
        calibration_lower_quantiles=[1.0, 0.9, 0.0, -0.5],
        calibration_upper_quantiles=[1.0, 1.0, 1.0, 1.0],
    )
    lower, upper = cqr.predict_intervals(
        lower_quantiles=[2.0, 4.0],
        upper_quantiles=[3.0, 5.0],
    )
    margin = 1.5
    assert cqr.threshold_ == pytest.approx(margin)
    assert lower.tolist() == pytest.approx([2.0 - margin, 4.0 - margin])
    assert upper.tolist() == pytest.approx([3.0 + margin, 5.0 + margin])
def test_missingness_aware_conformalized_quantile_regressor_uses_group_specific_thresholds() -> None:
    """Each group widens its intervals by its own calibrated margin."""
    cqr = MissingnessAwareConformalizedQuantileRegressor(alpha=0.5, min_group_size=2)
    cqr.fit(
        calibration_targets=[1.0, 1.0, 1.0, 1.0],
        calibration_lower_quantiles=[1.0, 0.9, 0.0, -0.5],
        calibration_upper_quantiles=[1.0, 1.0, 1.0, 1.0],
        calibration_group_ids=[0, 0, 1, 1],
    )
    lower, upper = cqr.predict_intervals(
        lower_quantiles=[2.0, 2.0],
        upper_quantiles=[3.0, 3.0],
        test_group_ids=[0, 1],
    )
    assert cqr.global_threshold_ == pytest.approx(1.0)
    assert cqr.group_thresholds_[0] == pytest.approx(0.1)
    assert cqr.group_thresholds_[1] == pytest.approx(1.5)
    # Group 0 is widened by 0.1 on each side, group 1 by 1.5.
    assert lower.tolist() == pytest.approx([1.9, 0.5])
    assert upper.tolist() == pytest.approx([3.1, 4.5])
def test_conformalized_quantile_regressor_trim_fraction_reduces_extreme_threshold() -> None:
    """Trimming 20% of calibration scores removes the influence of the single 100.0 outlier."""
    untrimmed = ConformalizedQuantileRegressor(alpha=0.2)
    trimmed = ConformalizedQuantileRegressor(alpha=0.2, trim_fraction=0.2)
    targets = [0.0, 0.0, 0.0, 0.0, 100.0]
    lower = [0.0] * 5
    upper = [0.0] * 5
    untrimmed.fit(targets, lower, upper)
    trimmed.fit(targets, lower, upper)
    assert trimmed.threshold_ < untrimmed.threshold_
    assert untrimmed.threshold_ == pytest.approx(100.0)
    assert trimmed.threshold_ == pytest.approx(0.0)
def test_conformalized_quantile_regressor_normalized_score_scales_with_interval_width() -> None:
    """With normalized scores, the widening is proportional to each interval's width."""
    cqr = ConformalizedQuantileRegressor(alpha=0.5, score_function="normalized")
    cqr.fit(
        calibration_targets=[2.0, 4.0],
        calibration_lower_quantiles=[1.0, 2.0],
        calibration_upper_quantiles=[3.0, 6.0],
    )
    lower, upper = cqr.predict_intervals(
        lower_quantiles=[10.0, 10.0],
        upper_quantiles=[12.0, 14.0],
    )
    assert cqr.threshold_ == pytest.approx(0.5)
    # Widths 2 and 4 at threshold 0.5 widen each side by 1 and 2 respectively.
    assert lower.tolist() == pytest.approx([9.0, 8.0])
    assert upper.tolist() == pytest.approx([13.0, 16.0])
def test_label_conditional_missingness_aware_uses_label_group_thresholds() -> None:
    """Missing-rate groups crossed with labels yield per-(label, group) thresholds."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    model.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )
    sets = model.predict_sets(
        positive_probabilities=[0.18, 0.53, 0.72],
        missing_rates=[0.015, 0.515, 0.915],
    )
    assert model.global_threshold_ == pytest.approx(0.45)
    assert model.label_thresholds_[0] == pytest.approx(0.45)
    assert model.label_thresholds_[1] == pytest.approx(0.5)
    # Keys are (label, group).
    expected_cells = {
        (0, 0): 0.2,
        (1, 0): 0.2,
        (0, 1): 0.45,
        (1, 1): 0.5,
        (0, 2): 0.8,
        (1, 2): 0.7,
    }
    for cell, expected in expected_cells.items():
        assert model.label_group_thresholds_[cell] == pytest.approx(expected)
    assert sets == [{0}, {1}, {0, 1}]
def test_label_conditional_missingness_aware_falls_back_to_label_thresholds() -> None:
    """Each (label, group) cell holds two points, below min_group_size=3, so only label thresholds remain."""
    model = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=3)
    model.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )
    sets = model.predict_sets(positive_probabilities=[0.53], missing_rates=[0.515])
    assert model.label_group_thresholds_ == {}
    assert model.label_thresholds_[0] == pytest.approx(0.45)
    assert model.label_thresholds_[1] == pytest.approx(0.5)
    assert sets == [{1}]
def test_weighted_missingness_conformal_uses_local_similarity_weights() -> None:
    """Kernel weights on missing rate give each test point its own local threshold."""
    weighted = WeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1)
    weighted.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    test_rates = [0.08, 0.8]
    local_thresholds = weighted.thresholds_for_missing_rates(test_rates)
    sets = weighted.predict_sets(positive_probabilities=[0.19, 0.45], missing_rates=test_rates)
    assert weighted.global_threshold_ == pytest.approx(0.55)
    assert local_thresholds == pytest.approx([0.15, 0.65])
    # 0.19 under threshold 0.15 admits neither label, so its set is empty.
    assert sets == [set(), {0, 1}]
def test_weighted_missingness_conformal_falls_back_to_global_threshold() -> None:
    """A test missing rate far from all calibration rates, with a tiny bandwidth, falls back to global."""
    weighted = WeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=1e-12)
    weighted.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    far_rate = [1.5]
    local_thresholds = weighted.thresholds_for_missing_rates(far_rate)
    sets = weighted.predict_sets(positive_probabilities=[0.45], missing_rates=far_rate)
    assert local_thresholds == pytest.approx([0.55])
    assert sets == [{0, 1}]
def test_shrunk_weighted_missingness_conformal_lambda_zero_matches_standard() -> None:
    """shrinkage_lambda=0 ignores locality: every point gets the global threshold 0.55."""
    shrunk = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=0.0)
    shrunk.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    test_rates = [0.08, 0.8]
    local_thresholds = shrunk.thresholds_for_missing_rates(test_rates)
    sets = shrunk.predict_sets(positive_probabilities=[0.45, 0.45], missing_rates=test_rates)
    assert local_thresholds == pytest.approx([0.55, 0.55])
    assert sets == [{0, 1}, {0, 1}]
def test_shrunk_weighted_missingness_conformal_lambda_one_matches_weighted() -> None:
    """shrinkage_lambda=1 yields the fully local thresholds [0.15, 0.65]."""
    shrunk = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=1.0)
    shrunk.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    test_rates = [0.08, 0.8]
    local_thresholds = shrunk.thresholds_for_missing_rates(test_rates)
    sets = shrunk.predict_sets(positive_probabilities=[0.19, 0.45], missing_rates=test_rates)
    assert local_thresholds == pytest.approx([0.15, 0.65])
    assert sets == [set(), {0, 1}]
def test_shrunk_weighted_missingness_conformal_intermediate_lambda_softens_local_threshold() -> None:
    """shrinkage_lambda=0.5 lifts the low local threshold from 0.15 to 0.2."""
    shrunk = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=0.5)
    shrunk.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    test_rates = [0.08, 0.8]
    local_thresholds = shrunk.thresholds_for_missing_rates(test_rates)
    sets = shrunk.predict_sets(positive_probabilities=[0.19, 0.45], missing_rates=test_rates)
    assert local_thresholds == pytest.approx([0.2, 0.65])
    # 0.19 is now within the 0.2 threshold for label 0, so the set is {0} rather than empty.
    assert sets == [{0}, {0, 1}]
def test_propensity_weighted_conformal_logistic_reweights_toward_test_like_missingness() -> None:
    """Logistic propensity weights pull the threshold below the unweighted 0.8."""
    weighted = PropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="logistic",
        random_state=0,
    ).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65],
        calibration_features=[[0.0], [0.0], [1.0], [1.0]],
        test_features=[[1.0]] * 4,
    )
    sets = weighted.predict_sets([0.18, 0.92])
    assert weighted.global_threshold_ == pytest.approx(0.8)
    assert weighted.threshold_ < weighted.global_threshold_
    assert sets == [{0}, {1}]
    summary = weighted.diagnostics_summary()
    assert summary["propensity_model"] == "logistic"
    assert summary["importance_weight_max"] > summary["importance_weight_min"]
def test_propensity_weighted_conformal_tree_reweights_toward_test_like_missingness() -> None:
    """A tree propensity model also pulls the threshold below the unweighted 0.8."""
    weighted = PropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="tree",
        random_state=0,
    ).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65],
        calibration_features=[[0.0], [0.0], [1.0], [1.0]],
        test_features=[[1.0]] * 4,
    )
    sets = weighted.predict_sets([0.18, 0.92])
    assert weighted.global_threshold_ == pytest.approx(0.8)
    assert weighted.threshold_ < weighted.global_threshold_
    assert sets == [{0}, {1}]
    summary = weighted.diagnostics_summary()
    assert summary["propensity_model"] == "tree"
def test_propensity_weighted_conformal_rejects_empty_test_features() -> None:
    """An empty test feature matrix cannot train a propensity model and must raise."""
    weighted = PropensityWeightedConformalClassifier(alpha=0.5, propensity_model="logistic")
    no_test_rows: list = []
    with pytest.raises(ValueError, match="must be non-empty"):
        weighted.fit(
            calibration_labels=[0, 1],
            calibration_positive_probabilities=[0.2, 0.8],
            calibration_features=[[0.0], [1.0]],
            test_features=no_test_rows,
        )
def test_grouped_propensity_weighted_conformal_reweights_within_group() -> None:
    """Only group 0 appears in the test data, so it is reweighted and group 1 falls back."""
    grouped = GroupedPropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="logistic",
        random_state=0,
        min_group_size=2,
    ).fit(
        calibration_labels=[0, 0, 1, 1, 0, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65, 0.20, 0.85],
        calibration_group_ids=[0, 0, 0, 0, 1, 1],
        calibration_features=[[0.0], [0.0], [1.0], [1.0], [0.0], [0.0]],
        test_group_ids=[0, 0, 0],
        test_features=[[1.0]] * 3,
    )
    sets = grouped.predict_sets(positive_probabilities=[0.18, 0.92], test_group_ids=[0, 0])
    unweighted_group0 = grouped.group_unweighted_thresholds_[0]
    assert unweighted_group0 == pytest.approx(0.8)
    assert grouped.group_weighted_thresholds_[0] < grouped.group_unweighted_thresholds_[0]
    assert sets == [{0}, {1}]
    summary = grouped.diagnostics_summary()
    assert summary["reweighted_group_count"] == 1
    assert summary["fallback_group_count"] == 1
def test_prediction_set_metrics_cover_coverage_size_and_abstention() -> None:
    """Coverage, mean set size and abstention rate over three prediction sets."""
    sets = [{1}, {0, 1}, {0}]
    truth = [1, 1, 0]
    assert empirical_coverage(sets, truth) == pytest.approx(1.0)
    # Set sizes are 1, 2 and 1.
    assert average_set_size(sets) == pytest.approx((1 + 2 + 1) / 3)
    # {0, 1} is the only non-singleton set here.
    assert abstention_rate(sets) == pytest.approx(1 / 3)