# Response-Quality-Assessment / tests / test_calibration.py
# Author: Ryoya Awano
# deploy: fix MedLFQA Marginal mode sample matching (commit 19fc84f)
"""Phase 1 unit tests β€” no OpenAI API required.
Tests pure calculation functions in src/calibration/ and demo/inference_api.py
using synthetic data only.
"""
import math
import pytest
from src.calibration.utils import get_r_score, compute_threshold, split_group
from src.calibration.conformal import SplitConformalCalibration
from demo.inference_api import apply_threshold, SubclaimResult, Subclaim, FilteredResult
# ── Fixtures ──────────────────────────────────────────────────────────────────
def _make_subclaim(score: float, noise: float, annotation: str) -> dict:
"""Build a subclaim dict in the format produced by main.py."""
return {
"subclaim": "dummy text",
"scores": {"relavance": score, "noise": noise},
"annotations": {"gpt": annotation},
}
def _make_entry(subclaims: list[dict], group: str = "default") -> dict:
return {"subclaims": subclaims, "groups": [group]}
# ── get_r_score ───────────────────────────────────────────────────────────────
class TestGetRScore:
    """Tests for get_r_score: the highest threshold at which factuality drops below `a`."""

    def test_returns_threshold_where_factuality_drops(self):
        # 2 correct (S), 1 incorrect (I); sorted scores descending: 0.9, 0.6, 0.3
        entry = _make_entry([
            _make_subclaim(0.9, 0.0, "S"),
            _make_subclaim(0.6, 0.0, "S"),
            _make_subclaim(0.3, 0.0, "I"),
        ])
        # a=1.0 means ALL retained subclaims must be correct.
        # threshold=0.9: only the 0.9 subclaim is kept -> 1/1 correct -> OK
        # threshold=0.6: 0.9 and 0.6 are kept -> 2/2 correct -> OK
        # threshold=0.3: all three kept -> 2/3 correct < 1.0 -> drops here
        r = get_r_score(entry, "relavance", a=1.0)
        assert r == pytest.approx(0.3)

    def test_returns_minus_one_when_always_safe(self):
        # All subclaims are correct; factuality never drops below any a < 1,
        # so the sentinel -1 is returned.
        entry = _make_entry([
            _make_subclaim(0.8, 0.0, "S"),
            _make_subclaim(0.5, 0.0, "S"),
        ])
        r = get_r_score(entry, "relavance", a=0.9)
        assert r == -1

    def test_noise_is_included_in_score(self):
        # score=0.5, noise=0.2 -> effective score=0.7
        # score=0.4, noise=0.1 -> effective score=0.5
        # threshold_set = [0.7, 0.5]
        # at threshold=0.7: only first kept -> 1/1 correct (S) -> OK
        # at threshold=0.5: both kept -> 1st=S, 2nd=I -> 1/2 < 1.0 -> drops at 0.5
        entry = _make_entry([
            _make_subclaim(0.5, 0.2, "S"),
            _make_subclaim(0.4, 0.1, "I"),
        ])
        r = get_r_score(entry, "relavance", a=1.0)
        assert r == pytest.approx(0.5)

    def test_caching(self):
        # Repeated calls are stable, and the result is memoized on the entry
        # under a key derived from (a, confidence_method).
        entry = _make_entry([_make_subclaim(0.7, 0.0, "S")])
        r1 = get_r_score(entry, "relavance", a=0.9)
        r2 = get_r_score(entry, "relavance", a=0.9)
        assert r1 == r2
        # Fixed: was an f-string with no placeholders (ruff F541).
        assert "r_score_0.9_relavance" in entry
# ── compute_threshold ─────────────────────────────────────────────────────────
class TestComputeThreshold:
    """Tests for compute_threshold: the conformal quantile over r-scores."""

    def _make_calib_data(self, r_scores: list[float]) -> list[dict]:
        """Create calibration entries where get_r_score returns predictable values.

        Each entry holds a single incorrect ("I") subclaim whose score equals
        the desired r-score: with a=1.0 factuality drops at the first (only)
        threshold, so get_r_score returns exactly that score.
        """
        # Fixed: this helper existed but was never used; tests below now call
        # it instead of duplicating its body inline.
        return [_make_entry([_make_subclaim(r, 0.0, "I")]) for r in r_scores]

    def test_quantile_formula(self):
        # r_scores = [0.1, 0.2, 0.3, 0.4, 0.5], n=5, alpha=0.1
        # quantile_target_index = ceil((5+1)*(1-0.1)) = ceil(5.4) = 6
        # sorted r_scores index 6-1=5 (1-indexed) -> largest = 0.5
        entries = self._make_calib_data([0.1, 0.2, 0.3, 0.4, 0.5])
        q = compute_threshold(alpha=0.1, calibration_data=entries, a=1.0, confidence_method="relavance")
        assert q == pytest.approx(0.5)

    def test_alpha_zero_returns_maximum(self):
        # alpha=0 -> ceil((n+1)*1) = n+1 -> last index -> max value
        entries = self._make_calib_data([0.2, 0.8, 0.5])
        q = compute_threshold(alpha=0.0, calibration_data=entries, a=1.0, confidence_method="relavance")
        assert q == pytest.approx(0.8)

    def test_float_rounding_does_not_cause_index_error(self):
        # Verify common floating-point alpha values (0.1 = 0.09999...) don't crash
        entries = self._make_calib_data([0.3, 0.6, 0.9])
        for alpha in [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]:
            compute_threshold(alpha=alpha, calibration_data=entries, a=1.0, confidence_method="relavance")
# ── split_group ───────────────────────────────────────────────────────────────
class TestSplitGroup:
    """Tests for split_group: per-group calibration/test partitioning."""

    def _make_grouped_data(self) -> list[dict]:
        """Return 6 entries labelled group A followed by 4 labelled group B."""
        # Fixed: the original loops bound an unused index `i` (ruff B007) and
        # built the list with manual appends; comprehensions are idiomatic.
        data = [_make_entry([_make_subclaim(0.5, 0.0, "S")], group="A") for _ in range(6)]
        data += [_make_entry([_make_subclaim(0.5, 0.0, "S")], group="B") for _ in range(4)]
        return data

    def test_split_ratio(self):
        data = self._make_grouped_data()
        calib, test = split_group(data, calibrate_range=0.5)
        # Group A: 6 entries -> 3 calib, 3 test
        # Group B: 4 entries -> 2 calib, 2 test
        assert len(calib["A"]) == 3
        assert len(calib["B"]) == 2
        assert len(test) == 5  # 3 + 2

    def test_no_overlap_between_calib_and_test(self):
        data = self._make_grouped_data()
        calib, test = split_group(data, calibrate_range=0.5)
        calib_entries = calib["A"] + calib["B"]
        # Use id() to check object identity -- no entry appears in both sets
        calib_ids = {id(e) for e in calib_entries}
        test_ids = {id(e) for e in test}
        assert calib_ids.isdisjoint(test_ids)

    def test_all_groups_present_in_calib_keys(self):
        data = self._make_grouped_data()
        calib, _ = split_group(data)
        assert set(calib.keys()) == {"A", "B"}
# ── _evaluate_conformal_correctness (SplitConformalCalibration) ───────────────
class TestEvaluateConformalCorrectness:
    """Tests for SplitConformalCalibration._evaluate_conformal_correctness."""

    def setup_method(self):
        # Fresh calibrator per test; the dataset name is arbitrary here.
        self.cal = SplitConformalCalibration("test_dataset")

    def test_all_kept_and_correct(self):
        # threshold very low -> everything kept, all S -> 100% correct, 0% removed
        data = [
            _make_entry([
                _make_subclaim(0.8, 0.0, "S"),
                _make_subclaim(0.7, 0.0, "S"),
            ])
        ]
        correctness, removed = self.cal._evaluate_conformal_correctness(
            data, threshold=-1.0, a=1.0, confidence_method="relavance"
        )
        assert correctness == pytest.approx(1.0)
        assert removed == pytest.approx(0.0)

    def test_all_removed(self):
        # threshold very high -> everything removed, retained_cnt=0 -> fallback 1.0
        data = [
            _make_entry([
                _make_subclaim(0.3, 0.0, "S"),
                _make_subclaim(0.4, 0.0, "I"),
            ])
        ]
        correctness, removed = self.cal._evaluate_conformal_correctness(
            data, threshold=999.0, a=0.9, confidence_method="relavance"
        )
        assert correctness == pytest.approx(1.0)  # fallback: 0/0 -> 1.0
        assert removed == pytest.approx(1.0)

    def test_partial_removal(self):
        # entry: scores [0.9S, 0.4I], threshold=0.5
        # kept: [0.9S] -> 1/1 correct >= a=0.9 -> correctly_retained=True
        # removed: 1/2 = 0.5
        data = [
            _make_entry([
                _make_subclaim(0.9, 0.0, "S"),
                _make_subclaim(0.4, 0.0, "I"),
            ])
        ]
        correctness, removed = self.cal._evaluate_conformal_correctness(
            data, threshold=0.5, a=0.9, confidence_method="relavance"
        )
        assert correctness == pytest.approx(1.0)
        assert removed == pytest.approx(0.5)

    def test_noise_included_in_comparison(self):
        # score=0.3, noise=0.3 -> effective 0.6 >= threshold=0.5 -> kept
        data = [_make_entry([_make_subclaim(0.3, 0.3, "S")])]
        correctness, removed = self.cal._evaluate_conformal_correctness(
            data, threshold=0.5, a=1.0, confidence_method="relavance"
        )
        assert removed == pytest.approx(0.0)  # kept, not removed
        # Fixed: `correctness` was assigned but never asserted (ruff F841).
        # The single retained subclaim is "S" -> 1/1 correct >= a=1.0.
        assert correctness == pytest.approx(1.0)
# ── apply_threshold ───────────────────────────────────────────────────────────
class TestApplyThreshold:
    """Tests for apply_threshold: retains subclaims whose score >= q_hat."""

    def _make_result(self, scores: list[float]) -> SubclaimResult:
        """Build a SubclaimResult whose subclaims carry *scores* in order."""
        claims = [
            Subclaim(id=idx, text=f"claim {idx}", score=value)
            for idx, value in enumerate(scores)
        ]
        return SubclaimResult(
            query="test query",
            dataset="pop_qa",
            mode="marginal",
            group="default",
            scoring_method="relevance",
            rag_answer="test answer",
            retrieved_docs=[],
            subclaims=claims,
        )

    def test_keep_count(self):
        filtered = apply_threshold(self._make_result([0.9, 0.6, 0.3]), q_hat=0.5)
        assert filtered["keep_count"] == 2
        assert filtered["remove_count"] == 1

    def test_boundary_is_inclusive(self):
        # A score exactly equal to q_hat must be retained (comparison is >=).
        filtered = apply_threshold(self._make_result([0.5]), q_hat=0.5)
        assert filtered["keep_count"] == 1

    def test_all_kept(self):
        filtered = apply_threshold(self._make_result([0.8, 0.9]), q_hat=0.0)
        assert filtered["keep_count"] == 2
        assert filtered["remove_count"] == 0

    def test_all_removed(self):
        filtered = apply_threshold(self._make_result([0.1, 0.2]), q_hat=1.0)
        assert filtered["keep_count"] == 0
        assert filtered["remove_count"] == 2

    def test_subclaims_list_passed_through(self):
        result = self._make_result([0.7, 0.4])
        filtered = apply_threshold(result, q_hat=0.5)
        # Identity check: the same list object, not a copy.
        assert filtered["subclaims"] is result["subclaims"]

    def test_q_hat_stored(self):
        filtered = apply_threshold(self._make_result([0.7]), q_hat=0.42)
        assert filtered["q_hat"] == pytest.approx(0.42)