| from __future__ import annotations |
|
|
| import pytest |
| import sepsis_mcp.conformal as conformal_module |
|
|
| from sepsis_mcp.conformal import ( |
| CPMDAExactClassifier, |
| ConformalizedQuantileRegressor, |
| GroupedPropensityWeightedConformalClassifier, |
| LabelConditionalConformalClassifier, |
| LabelConditionalMissingnessAwareConformalClassifier, |
| MissingnessAwareConformalClassifier, |
| MissingnessAwareConformalizedQuantileRegressor, |
| MondrianTiltedConformalClassifier, |
| PropensityWeightedConformalClassifier, |
| ShrunkWeightedMissingnessConformalClassifier, |
| SplitConformalClassifier, |
| WeightedMissingnessGroupConformalClassifier, |
| WeightedMissingnessConformalClassifier, |
| binary_nonconformity_scores, |
| ) |
| from sepsis_mcp.metrics import average_set_size, empirical_coverage, abstention_rate, prediction_set_size_breakdown |
|
|
|
|
def test_binary_nonconformity_scores_follow_true_label_definition() -> None:
    """The score is one minus the probability assigned to the true label."""
    labels = [0, 1, 1, 0]
    probs_positive = [0.1, 0.8, 0.3, 0.9]
    # Label 0 scores p, label 1 scores 1 - p.
    expected = [0.1, 0.2, 0.7, 0.9]

    scores = binary_nonconformity_scores(labels, probs_positive)

    assert scores.tolist() == pytest.approx(expected)
|
|
|
|
def test_split_conformal_prediction_sets_use_single_global_threshold() -> None:
    """Split conformal applies one calibrated threshold to every test point."""
    calibrator = SplitConformalClassifier(alpha=0.4)
    # Calibration scores are [0.1, 0.2, 0.7, 0.9]; 0.7 is the expected cut-off.
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0],
        calibration_positive_probabilities=[0.1, 0.8, 0.3, 0.9],
    )
    test_probabilities = [0.8, 0.55, 0.2]

    prediction_sets = calibrator.predict_sets(test_probabilities)

    assert calibrator.threshold_ == pytest.approx(0.7)
    assert prediction_sets == [{1}, {0, 1}, {0}]
|
|
|
|
def test_split_conformal_uses_infinite_threshold_when_nominal_rank_exceeds_sample() -> None:
    """Too few calibration points for alpha=0.1 yields an infinite threshold."""
    calibrator = SplitConformalClassifier(alpha=0.1)
    calibrator.fit(
        calibration_labels=[0, 1],
        calibration_positive_probabilities=[0.1, 0.8],
    )

    sets = calibrator.predict_sets([0.05, 0.95])

    # An infinite threshold admits both labels for every test point.
    assert calibrator.threshold_ == float("inf")
    assert sets == [{0, 1}, {0, 1}]
|
|
|
|
def test_missingness_aware_conformal_uses_group_specific_thresholds() -> None:
    """Each missingness group large enough gets its own threshold."""
    labels = [0, 1, 1, 0, 1, 0]
    probabilities = [0.1, 0.85, 0.8, 0.55, 0.35, 0.9]
    missing_rates = [0.05, 0.1, 0.2, 0.7, 0.8, 0.95]

    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=labels,
        calibration_positive_probabilities=probabilities,
        calibration_missing_rates=missing_rates,
    )
    prediction_sets = calibrator.predict_sets(
        positive_probabilities=[0.9, 0.45, 0.45],
        missing_rates=[0.08, 0.5, 0.9],
    )

    assert calibrator.global_threshold_ == pytest.approx(0.55)
    expected_group_thresholds = {0: 0.15, 1: 0.55, 2: 0.9}
    for group_id, expected in expected_group_thresholds.items():
        assert calibrator.group_thresholds_[group_id] == pytest.approx(expected)
    assert prediction_sets == [{1}, {0, 1}, {0, 1}]
|
|
|
|
def test_missingness_aware_conformal_falls_back_to_global_threshold_for_small_groups() -> None:
    """Groups below ``min_group_size`` store no threshold and use the global one."""
    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=3)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )

    sets = calibrator.predict_sets(positive_probabilities=[0.45], missing_rates=[0.08])

    # No group meets min_group_size=3, so the per-group table stays empty.
    assert calibrator.group_thresholds_ == {}
    assert sets == [{0, 1}]
|
|
|
|
def test_missingness_aware_conformal_accepts_explicit_calibration_group_ids() -> None:
    """Explicit calibration group ids replace missing-rate based grouping."""
    group_ids = [0, 0, 1, 1, 2, 2]
    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_group_ids=group_ids,
    )

    assert calibrator.global_threshold_ == pytest.approx(0.55)
    expected_group_thresholds = {0: 0.15, 1: 0.55, 2: 0.9}
    for group_id, expected in expected_group_thresholds.items():
        assert calibrator.group_thresholds_[group_id] == pytest.approx(expected)
|
|
|
|
def test_missingness_aware_conformal_accepts_explicit_test_group_ids() -> None:
    """Test points can be routed to groups by id instead of by missing rate."""
    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )

    sets = calibrator.predict_sets(
        positive_probabilities=[0.9, 0.45, 0.45],
        test_group_ids=[0, 1, 2],
    )

    assert sets == [{1}, {0, 1}, {0, 1}]
|
|
|
|
def test_mondrian_tilted_with_zero_shrinkage_matches_pure_mondrian() -> None:
    """With ``shrinkage_lambda=0`` the tilted classifier matches plain Mondrian CP."""
    calibration_labels = [0, 1, 1, 0, 1, 0]
    calibration_probabilities = [0.1, 0.85, 0.8, 0.55, 0.35, 0.9]
    calibration_missing_rates = [0.05, 0.1, 0.2, 0.7, 0.8, 0.95]
    calibration_group_ids = [0, 0, 0, 1, 1, 1]
    test_probabilities = [0.9, 0.45]
    test_missing_rates = [0.08, 0.85]
    test_group_ids = [0, 1]

    mondrian = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2).fit(
        calibration_labels=calibration_labels,
        calibration_positive_probabilities=calibration_probabilities,
        calibration_group_ids=calibration_group_ids,
    )
    tilted = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=2,
        shrinkage_lambda=0.0,
        bandwidth=0.1,
    ).fit(
        calibration_labels=calibration_labels,
        calibration_positive_probabilities=calibration_probabilities,
        calibration_missing_rates=calibration_missing_rates,
        calibration_group_ids=calibration_group_ids,
    )

    tilted_sets = tilted.predict_sets(
        positive_probabilities=test_probabilities,
        test_missing_rates=test_missing_rates,
        test_group_ids=test_group_ids,
    )
    mondrian_sets = mondrian.predict_sets(
        positive_probabilities=test_probabilities,
        test_group_ids=test_group_ids,
    )

    assert tilted_sets == mondrian_sets
|
|
|
|
def test_mondrian_tilted_falls_back_to_global_threshold_for_small_groups() -> None:
    """Groups below ``min_group_size`` use the global threshold for every test point."""
    calibrator = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=4,
        shrinkage_lambda=0.7,
        bandwidth=0.1,
    ).fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        calibration_group_ids=[0, 0, 0, 1, 1, 1],
    )

    thresholds = calibrator.thresholds_for_test_points(
        test_missing_rates=[0.08, 0.85],
        test_group_ids=[0, 1],
    )

    # Each group has three members (< 4), so both points get the global value.
    expected = [calibrator.global_threshold_] * 2
    assert thresholds.tolist() == pytest.approx(expected)
|
|
|
|
def test_mondrian_tilted_accepts_explicit_test_group_ids() -> None:
    """Tilted Mondrian CP predicts well-formed binary sets when given test group ids."""
    calibrator = MondrianTiltedConformalClassifier(
        alpha=0.5,
        min_group_size=2,
        shrinkage_lambda=0.5,
        bandwidth=0.1,
    ).fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        calibration_group_ids=[0, 0, 0, 1, 1, 1],
    )

    prediction_sets = calibrator.predict_sets(
        positive_probabilities=[0.9, 0.45],
        test_missing_rates=[0.08, 0.85],
        test_group_ids=[0, 1],
    )

    assert len(prediction_sets) == 2
    for prediction_set in prediction_sets:
        assert isinstance(prediction_set, set)
        # A binary classifier may only ever emit the labels 0 and/or 1.
        assert prediction_set <= {0, 1}
|
|
|
|
def test_cp_mda_exact_uses_exact_mask_signature_matches() -> None:
    """Test points are calibrated only on points with the same selected-feature mask."""
    calibration_masks = [
        [0, 0, 1],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 0],
    ]
    calibrator = CPMDAExactClassifier(alpha=0.4, top_k_features=2, min_match=2).fit(
        calibration_labels=[0, 1, 0, 1],
        calibration_positive_probabilities=[0.1, 0.8, 0.2, 0.9],
        calibration_masks=calibration_masks,
        feature_names=["a", "b", "c"],
    )

    test_masks = [
        [0, 0, 1],
        [1, 1, 0],
    ]
    prediction_sets = calibrator.predict_sets(
        positive_probabilities=[0.15, 0.85],
        test_masks=test_masks,
    )
    diagnostics = calibrator.diagnostics_summary()

    assert calibrator.selected_feature_names_ == ["a", "b"]
    # On features (a, b) each test mask matches exactly two calibration rows.
    assert calibrator.last_match_counts_.tolist() == [2, 2]
    assert diagnostics["fallback_rate"] == pytest.approx(0.0)
    assert prediction_sets == [{0}, {1}]
|
|
|
|
def test_weighted_missingness_group_conformal_reweights_binary_indicator() -> None:
    """Reweighting calibration groups toward the test group mix lowers the threshold."""
    calibrator = WeightedMissingnessGroupConformalClassifier(alpha=0.5).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.90, 0.85],
        calibration_group_ids=[0, 0, 1, 1],
    )

    # NOTE(review): test_group_ids has four entries for two probabilities. It
    # appears to describe the test-time group mix (three in group 1, one in
    # group 0) used for reweighting, not per-point ids. The expected 0.15 is
    # consistent with that reading; confirm against the implementation.
    prediction_sets = calibrator.predict_sets(
        positive_probabilities=[0.05, 0.95],
        test_group_ids=[1, 1, 1, 0],
    )

    assert calibrator.global_threshold_ == pytest.approx(0.8)
    assert calibrator.threshold_ == pytest.approx(0.15)
    assert prediction_sets == [{0}, {1}]
|
|
|
|
def test_kernel_smoothed_missingness_h_infinity_matches_standard_threshold() -> None:
    """An infinite bandwidth reproduces the standard global threshold for every mask."""
    # Looked up dynamically so a missing export fails only this test, with a
    # clear message, rather than breaking collection of the whole module.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not export KernelSmoothedMissingnessConformalClassifier"
    )

    calibrator = classifier_cls(alpha=0.4, bandwidth=float("inf")).fit(
        calibration_labels=[0, 1, 1, 0],
        calibration_positive_probabilities=[0.1, 0.8, 0.3, 0.9],
        calibration_masks=[[0], [0], [1], [1]],
    )

    assert calibrator.global_threshold_ == pytest.approx(0.7)
    assert calibrator.thresholds_for_test_masks([[0], [1]]).tolist() == pytest.approx([0.7, 0.7])
|
|
|
|
def test_kernel_smoothed_missingness_h_zero_uses_only_exact_mask_matches() -> None:
    """A zero bandwidth calibrates each mask only on rows with an identical mask."""
    # Looked up dynamically so a missing export fails only this test, with a
    # clear message, rather than breaking collection of the whole module.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not export KernelSmoothedMissingnessConformalClassifier"
    )

    calibrator = classifier_cls(alpha=0.5, bandwidth=0.0).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.90, 0.85],
        calibration_masks=[[0], [0], [1], [1]],
    )

    # Mask [1] rows score {0.1, 0.15}; mask [0] rows score {0.8, 0.9}.
    assert calibrator.thresholds_for_test_masks([[1], [0]]).tolist() == pytest.approx([0.15, 0.9])
|
|
|
|
def test_kernel_smoothed_missingness_intermediate_bandwidth_softly_downweights_far_masks() -> None:
    """A finite bandwidth gives partial weight to non-matching masks."""
    # Looked up dynamically so a missing export fails only this test, with a
    # clear message, rather than breaking collection of the whole module.
    classifier_cls = getattr(
        conformal_module,
        "KernelSmoothedMissingnessConformalClassifier",
        None,
    )
    assert classifier_cls is not None, (
        "sepsis_mcp.conformal does not export KernelSmoothedMissingnessConformalClassifier"
    )

    calibrator = classifier_cls(alpha=0.4, bandwidth=1.0).fit(
        calibration_labels=[0, 0, 0, 0, 0],
        calibration_positive_probabilities=[0.95, 0.80, 0.60, 0.20, 0.10],
        calibration_masks=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 1]],
    )

    threshold = calibrator.thresholds_for_test_masks([[1, 1]])[0]
    n_eff = calibrator.effective_sample_size([[1, 1]])[0]

    # The threshold sits strictly between the exact-match scores and the far
    # scores, and the effective sample size is neither 1 nor all 5 rows.
    assert 0.2 < threshold < 0.8
    assert 1.0 < n_eff < 5.0
|
|
|
|
def test_prediction_set_size_breakdown_reports_binary_set_fractions() -> None:
    """The breakdown reports the fraction of empty, singleton and full sets."""
    sets = [set(), {0}, {1}, {0, 1}, {0}]
    expected = {
        "empty_frac": 1 / 5,
        "singleton_0_frac": 2 / 5,
        "singleton_1_frac": 1 / 5,
        "full_frac": 1 / 5,
    }

    assert prediction_set_size_breakdown(sets) == pytest.approx(expected)
|
|
|
|
def test_cp_mda_exact_falls_back_to_global_threshold_when_matches_are_too_small() -> None:
    """With fewer than ``min_match`` exact matches, the global threshold is used."""
    masks = [
        [0, 0, 1],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 0],
    ]
    calibrator = CPMDAExactClassifier(alpha=0.4, top_k_features=2, min_match=3).fit(
        calibration_labels=[0, 1, 0, 1],
        calibration_positive_probabilities=[0.1, 0.8, 0.2, 0.9],
        calibration_masks=masks,
        feature_names=["a", "b", "c"],
    )

    sets = calibrator.predict_sets(positive_probabilities=[0.55], test_masks=[[0, 0, 1]])
    diagnostics = calibrator.diagnostics_summary()

    # Two matches < min_match=3, so the single test point falls back.
    assert calibrator.last_fallback_mask_.tolist() == [True]
    assert diagnostics["mean_match_count"] == pytest.approx(2.0)
    assert diagnostics["fallback_rate"] == pytest.approx(1.0)
    assert sets == [set()]
|
|
|
|
def test_missingness_aware_conformal_rejects_both_fit_group_inputs() -> None:
    """Passing both missing rates and group ids to ``fit`` is an error."""
    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    fit_kwargs = {
        "calibration_labels": [0, 1, 1, 0, 1, 0],
        "calibration_positive_probabilities": [0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        "calibration_missing_rates": [0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
        "calibration_group_ids": [0, 0, 1, 1, 2, 2],
    }
    expected_message = "Provide exactly one of calibration_missing_rates or calibration_group_ids"

    with pytest.raises(ValueError, match=expected_message):
        calibrator.fit(**fit_kwargs)
|
|
|
|
def test_missingness_aware_conformal_rejects_missing_rate_prediction_in_explicit_group_mode() -> None:
    """After fitting on group ids, routing by missing rate is rejected."""
    calibrator = MissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_group_ids=[0, 0, 1, 1, 2, 2],
    )
    expected_message = "explicit-group mode requires explicit group ids at prediction time"

    # Direct group assignment from missing rates must fail ...
    with pytest.raises(ValueError, match=expected_message):
        calibrator.assign_groups([0.08, 0.5, 0.9])

    # ... and so must prediction driven by missing rates.
    with pytest.raises(ValueError, match=expected_message):
        calibrator.predict_sets(
            positive_probabilities=[0.9, 0.45, 0.45],
            missing_rates=[0.08, 0.5, 0.9],
        )
|
|
|
|
def test_label_conditional_conformal_uses_label_specific_thresholds() -> None:
    """Each label with enough examples gets its own calibrated threshold."""
    labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
    probabilities = [0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35]

    calibrator = LabelConditionalConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=labels,
        calibration_positive_probabilities=probabilities,
    )
    sets = calibrator.predict_sets([0.18, 0.53, 0.72])

    assert calibrator.global_threshold_ == pytest.approx(0.45)
    assert calibrator.label_thresholds_[0] == pytest.approx(0.45)
    assert calibrator.label_thresholds_[1] == pytest.approx(0.5)
    assert sets == [{0}, {1}, {1}]
|
|
|
|
def test_label_conditional_conformal_falls_back_to_global_threshold() -> None:
    """Labels with fewer than ``min_group_size`` examples use the global threshold."""
    labels = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
    probabilities = [0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35]

    # Each label has six examples, below min_group_size=7.
    calibrator = LabelConditionalConformalClassifier(alpha=0.5, min_group_size=7)
    calibrator.fit(
        calibration_labels=labels,
        calibration_positive_probabilities=probabilities,
    )
    sets = calibrator.predict_sets([0.45, 0.55])

    assert calibrator.label_thresholds_ == {}
    assert calibrator.global_threshold_ == pytest.approx(0.45)
    assert sets == [{0}, {1}]
|
|
|
|
def test_label_conditional_missingness_aware_accepts_explicit_calibration_group_ids() -> None:
    """Explicit group ids produce per-(label, group) thresholds."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_group_ids=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    )

    assert calibrator.global_threshold_ == pytest.approx(0.45)
    assert calibrator.label_thresholds_[0] == pytest.approx(0.45)
    assert calibrator.label_thresholds_[1] == pytest.approx(0.5)
    # Keys are (label, group_id).
    expected_cells = {
        (0, 0): 0.2,
        (1, 0): 0.2,
        (0, 1): 0.45,
        (1, 1): 0.5,
        (0, 2): 0.8,
        (1, 2): 0.7,
    }
    for cell, expected in expected_cells.items():
        assert calibrator.label_group_thresholds_[cell] == pytest.approx(expected)
|
|
|
|
def test_label_conditional_missingness_aware_accepts_explicit_test_group_ids() -> None:
    """Test points can be routed to (label, group) cells by explicit group id."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )

    sets = calibrator.predict_sets(
        positive_probabilities=[0.18, 0.53, 0.72],
        test_group_ids=[0, 1, 2],
    )

    assert sets == [{0}, {1}, {0, 1}]
|
|
|
|
def test_label_conditional_missingness_aware_rejects_both_fit_group_inputs() -> None:
    """Passing both missing rates and group ids to ``fit`` is an error."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    fit_kwargs = {
        "calibration_labels": [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        "calibration_positive_probabilities": [0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        "calibration_missing_rates": [0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
        "calibration_group_ids": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    }
    expected_message = "Provide exactly one of calibration_missing_rates or calibration_group_ids"

    with pytest.raises(ValueError, match=expected_message):
        calibrator.fit(**fit_kwargs)
|
|
|
|
def test_label_conditional_missingness_aware_rejects_missing_rate_prediction_in_explicit_group_mode() -> None:
    """After fitting on group ids, routing by missing rate is rejected."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_group_ids=[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    )
    expected_message = "explicit-group mode requires explicit group ids at prediction time"

    # Group assignment from missing rates must fail ...
    with pytest.raises(ValueError, match=expected_message):
        calibrator.assign_groups([0.015, 0.515, 0.915])

    # ... and so must prediction driven by missing rates.
    with pytest.raises(ValueError, match=expected_message):
        calibrator.predict_sets(
            positive_probabilities=[0.18, 0.53, 0.72],
            missing_rates=[0.015, 0.515, 0.915],
        )
|
|
|
|
def test_conformalized_quantile_regressor_expands_intervals_with_global_threshold() -> None:
    """CQR widens every test interval by a single calibrated margin."""
    # NOTE(review): every calibration target (1.0) lies inside [lower, upper]
    # here, so the standard CQR score max(lower - y, y - upper) is 0 for all
    # four rows, which would give a threshold of 0. The asserted 1.5 equals
    # max(y - lower) instead. Confirm that the implementation's score function
    # is intended and not a sign error.
    calibrator = ConformalizedQuantileRegressor(alpha=0.25)
    calibrator.fit(
        calibration_targets=[1.0, 1.0, 1.0, 1.0],
        calibration_lower_quantiles=[1.0, 0.9, 0.0, -0.5],
        calibration_upper_quantiles=[1.0, 1.0, 1.0, 1.0],
    )

    lower, upper = calibrator.predict_intervals(
        lower_quantiles=[2.0, 4.0],
        upper_quantiles=[3.0, 5.0],
    )

    assert calibrator.threshold_ == pytest.approx(1.5)
    assert lower.tolist() == pytest.approx([0.5, 2.5])
    assert upper.tolist() == pytest.approx([4.5, 6.5])
|
|
|
|
def test_missingness_aware_conformalized_quantile_regressor_uses_group_specific_thresholds() -> None:
    """Grouped CQR widens each interval by its own group's margin."""
    # NOTE(review): under the standard CQR score max(lower - y, y - upper) all
    # calibration rows here score 0. The asserted 0.1 / 1.5 group thresholds
    # follow y - lower instead. Confirm the scoring convention.
    calibrator = MissingnessAwareConformalizedQuantileRegressor(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_targets=[1.0, 1.0, 1.0, 1.0],
        calibration_lower_quantiles=[1.0, 0.9, 0.0, -0.5],
        calibration_upper_quantiles=[1.0, 1.0, 1.0, 1.0],
        calibration_group_ids=[0, 0, 1, 1],
    )

    lower, upper = calibrator.predict_intervals(
        lower_quantiles=[2.0, 2.0],
        upper_quantiles=[3.0, 3.0],
        test_group_ids=[0, 1],
    )

    assert calibrator.global_threshold_ == pytest.approx(1.0)
    assert calibrator.group_thresholds_[0] == pytest.approx(0.1)
    assert calibrator.group_thresholds_[1] == pytest.approx(1.5)
    assert lower.tolist() == pytest.approx([1.9, 0.5])
    assert upper.tolist() == pytest.approx([3.1, 4.5])
|
|
|
|
def test_conformalized_quantile_regressor_trim_fraction_reduces_extreme_threshold() -> None:
    """Trimming the largest calibration scores removes a single outlier's influence."""
    base = ConformalizedQuantileRegressor(alpha=0.2)
    trimmed = ConformalizedQuantileRegressor(alpha=0.2, trim_fraction=0.2)

    # Four perfect predictions and one target far outside its interval.
    targets = [0.0, 0.0, 0.0, 0.0, 100.0]
    lower = [0.0] * 5
    upper = [0.0] * 5

    for regressor in (base, trimmed):
        regressor.fit(targets, lower, upper)

    assert trimmed.threshold_ < base.threshold_
    assert base.threshold_ == pytest.approx(100.0)
    assert trimmed.threshold_ == pytest.approx(0.0)
|
|
|
|
def test_conformalized_quantile_regressor_normalized_score_scales_with_interval_width() -> None:
    """With the normalized score, the margin is proportional to each interval's width."""
    # NOTE(review): both targets lie inside their intervals, so the standard
    # CQR score divided by width would be -0.5 for each row. The asserted
    # +0.5 matches (y - lower) / width instead. Confirm the scoring convention.
    calibrator = ConformalizedQuantileRegressor(alpha=0.5, score_function="normalized")
    calibrator.fit(
        calibration_targets=[2.0, 4.0],
        calibration_lower_quantiles=[1.0, 2.0],
        calibration_upper_quantiles=[3.0, 6.0],
    )

    # Test widths are 2 and 4, so margins are 0.5 * 2 and 0.5 * 4.
    lower, upper = calibrator.predict_intervals(
        lower_quantiles=[10.0, 10.0],
        upper_quantiles=[12.0, 14.0],
    )

    assert calibrator.threshold_ == pytest.approx(0.5)
    assert lower.tolist() == pytest.approx([9.0, 8.0])
    assert upper.tolist() == pytest.approx([13.0, 16.0])
|
|
|
|
def test_label_conditional_missingness_aware_uses_label_group_thresholds() -> None:
    """Missing-rate groups combined with labels give per-(label, group) thresholds."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=2)
    calibrator.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )

    sets = calibrator.predict_sets(
        positive_probabilities=[0.18, 0.53, 0.72],
        missing_rates=[0.015, 0.515, 0.915],
    )

    assert calibrator.global_threshold_ == pytest.approx(0.45)
    assert calibrator.label_thresholds_[0] == pytest.approx(0.45)
    assert calibrator.label_thresholds_[1] == pytest.approx(0.5)
    # Keys are (label, group_id).
    expected_cells = {
        (0, 0): 0.2,
        (1, 0): 0.2,
        (0, 1): 0.45,
        (1, 1): 0.5,
        (0, 2): 0.8,
        (1, 2): 0.7,
    }
    for cell, expected in expected_cells.items():
        assert calibrator.label_group_thresholds_[cell] == pytest.approx(expected)
    assert sets == [{0}, {1}, {0, 1}]
|
|
|
|
def test_label_conditional_missingness_aware_falls_back_to_label_thresholds() -> None:
    """Undersized (label, group) cells fall back to per-label thresholds."""
    calibrator = LabelConditionalMissingnessAwareConformalClassifier(alpha=0.5, min_group_size=3)
    calibrator.fit(
        calibration_labels=[0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
        calibration_positive_probabilities=[0.1, 0.2, 0.9, 0.8, 0.4, 0.45, 0.55, 0.5, 0.7, 0.8, 0.3, 0.35],
        calibration_missing_rates=[0.01, 0.02, 0.03, 0.04, 0.5, 0.51, 0.52, 0.53, 0.9, 0.91, 0.92, 0.93],
    )

    sets = calibrator.predict_sets(positive_probabilities=[0.53], missing_rates=[0.515])

    assert calibrator.label_group_thresholds_ == {}
    assert calibrator.label_thresholds_[0] == pytest.approx(0.45)
    assert calibrator.label_thresholds_[1] == pytest.approx(0.5)
    assert sets == [{1}]
|
|
|
|
def test_weighted_missingness_conformal_uses_local_similarity_weights() -> None:
    """Thresholds come from calibration points with similar missing rates."""
    calibrator = WeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    query_rates = [0.08, 0.8]

    thresholds = calibrator.thresholds_for_missing_rates(query_rates)
    sets = calibrator.predict_sets(
        positive_probabilities=[0.19, 0.45],
        missing_rates=query_rates,
    )

    assert calibrator.global_threshold_ == pytest.approx(0.55)
    assert thresholds == pytest.approx([0.15, 0.65])
    assert sets == [set(), {0, 1}]
|
|
|
|
def test_weighted_missingness_conformal_falls_back_to_global_threshold() -> None:
    """A query far from every calibration point falls back to the global threshold."""
    calibrator = WeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=1e-12)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    # 1.5 is outside the calibration range, and the bandwidth is tiny.
    far_rate = [1.5]

    thresholds = calibrator.thresholds_for_missing_rates(far_rate)
    sets = calibrator.predict_sets(positive_probabilities=[0.45], missing_rates=far_rate)

    assert thresholds == pytest.approx([0.55])
    assert sets == [{0, 1}]
|
|
|
|
def test_shrunk_weighted_missingness_conformal_lambda_zero_matches_standard() -> None:
    """``shrinkage_lambda=0`` ignores locality and reproduces the global threshold."""
    calibrator = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=0.0)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    query_rates = [0.08, 0.8]

    thresholds = calibrator.thresholds_for_missing_rates(query_rates)
    sets = calibrator.predict_sets(
        positive_probabilities=[0.45, 0.45],
        missing_rates=query_rates,
    )

    assert thresholds == pytest.approx([0.55, 0.55])
    assert sets == [{0, 1}, {0, 1}]
|
|
|
|
def test_shrunk_weighted_missingness_conformal_lambda_one_matches_weighted() -> None:
    """``shrinkage_lambda=1`` reproduces the fully local weighted thresholds."""
    calibrator = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=1.0)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    query_rates = [0.08, 0.8]

    thresholds = calibrator.thresholds_for_missing_rates(query_rates)
    sets = calibrator.predict_sets(
        positive_probabilities=[0.19, 0.45],
        missing_rates=query_rates,
    )

    assert thresholds == pytest.approx([0.15, 0.65])
    assert sets == [set(), {0, 1}]
|
|
|
|
def test_shrunk_weighted_missingness_conformal_intermediate_lambda_softens_local_threshold() -> None:
    """An intermediate ``shrinkage_lambda`` moves the local threshold toward the global one."""
    calibrator = ShrunkWeightedMissingnessConformalClassifier(alpha=0.5, bandwidth=0.1, shrinkage_lambda=0.5)
    calibrator.fit(
        calibration_labels=[0, 1, 1, 0, 1, 0],
        calibration_positive_probabilities=[0.1, 0.85, 0.8, 0.55, 0.35, 0.9],
        calibration_missing_rates=[0.05, 0.1, 0.2, 0.7, 0.8, 0.95],
    )
    query_rates = [0.08, 0.8]

    thresholds = calibrator.thresholds_for_missing_rates(query_rates)
    sets = calibrator.predict_sets(
        positive_probabilities=[0.19, 0.45],
        missing_rates=query_rates,
    )

    # 0.2 lies between the fully local 0.15 and the global 0.55.
    assert thresholds == pytest.approx([0.2, 0.65])
    assert sets == [{0}, {0, 1}]
|
|
|
|
def test_propensity_weighted_conformal_logistic_reweights_toward_test_like_missingness() -> None:
    """A logistic propensity model upweights calibration rows that resemble the test rows."""
    # All test rows have feature 1.0, matching the last two calibration rows.
    calibrator = PropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="logistic",
        random_state=0,
    ).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65],
        calibration_features=[[0.0], [0.0], [1.0], [1.0]],
        test_features=[[1.0], [1.0], [1.0], [1.0]],
    )

    sets = calibrator.predict_sets([0.18, 0.92])

    assert calibrator.global_threshold_ == pytest.approx(0.8)
    assert calibrator.threshold_ < calibrator.global_threshold_
    assert sets == [{0}, {1}]
    summary = calibrator.diagnostics_summary()
    assert summary["propensity_model"] == "logistic"
    assert summary["importance_weight_max"] > summary["importance_weight_min"]
|
|
|
|
def test_propensity_weighted_conformal_tree_reweights_toward_test_like_missingness() -> None:
    """A tree propensity model also upweights calibration rows that resemble the test rows."""
    calibrator = PropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="tree",
        random_state=0,
    ).fit(
        calibration_labels=[0, 0, 1, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65],
        calibration_features=[[0.0], [0.0], [1.0], [1.0]],
        test_features=[[1.0], [1.0], [1.0], [1.0]],
    )

    sets = calibrator.predict_sets([0.18, 0.92])

    assert calibrator.global_threshold_ == pytest.approx(0.8)
    assert calibrator.threshold_ < calibrator.global_threshold_
    assert sets == [{0}, {1}]
    summary = calibrator.diagnostics_summary()
    assert summary["propensity_model"] == "tree"
|
|
|
|
def test_propensity_weighted_conformal_rejects_empty_test_features() -> None:
    """Fitting without any test features raises a ``ValueError``."""
    calibrator = PropensityWeightedConformalClassifier(alpha=0.5, propensity_model="logistic")
    empty_test_features: list = []

    with pytest.raises(ValueError, match="must be non-empty"):
        calibrator.fit(
            calibration_labels=[0, 1],
            calibration_positive_probabilities=[0.2, 0.8],
            calibration_features=[[0.0], [1.0]],
            test_features=empty_test_features,
        )
|
|
|
|
def test_grouped_propensity_weighted_conformal_reweights_within_group() -> None:
    """Propensity reweighting is applied per group; small groups fall back."""
    calibrator = GroupedPropensityWeightedConformalClassifier(
        alpha=0.5,
        propensity_model="logistic",
        random_state=0,
        min_group_size=2,
    ).fit(
        calibration_labels=[0, 0, 1, 1, 0, 1],
        calibration_positive_probabilities=[0.80, 0.90, 0.70, 0.65, 0.20, 0.85],
        calibration_group_ids=[0, 0, 0, 0, 1, 1],
        calibration_features=[[0.0], [0.0], [1.0], [1.0], [0.0], [0.0]],
        test_group_ids=[0, 0, 0],
        test_features=[[1.0], [1.0], [1.0]],
    )

    sets = calibrator.predict_sets(
        positive_probabilities=[0.18, 0.92],
        test_group_ids=[0, 0],
    )

    unweighted_g0 = calibrator.group_unweighted_thresholds_[0]
    assert unweighted_g0 == pytest.approx(0.8)
    assert calibrator.group_weighted_thresholds_[0] < unweighted_g0
    assert sets == [{0}, {1}]
    summary = calibrator.diagnostics_summary()
    assert summary["reweighted_group_count"] == 1
    assert summary["fallback_group_count"] == 1
|
|
|
|
def test_prediction_set_metrics_cover_coverage_size_and_abstention() -> None:
    """Coverage, mean set size and abstention rate on a small hand-checked example."""
    labels = [1, 1, 0]
    sets = [{1}, {0, 1}, {0}]

    # Every set contains its label; set sizes are 1, 2 and 1.
    assert empirical_coverage(sets, labels) == pytest.approx(1.0)
    assert average_set_size(sets) == pytest.approx(4 / 3)
    # NOTE(review): 1/3 is consistent with counting the one full set {0, 1} as an abstention.
    assert abstention_rate(sets) == pytest.approx(1 / 3)
|
|