Use canonical is_refusal() from faithfulness module
Browse files- scripts/sanity_checks.py +146 -101
scripts/sanity_checks.py
CHANGED
|
@@ -33,6 +33,7 @@ from sage.config import (
|
|
| 33 |
log_banner,
|
| 34 |
log_section,
|
| 35 |
)
|
|
|
|
| 36 |
from sage.services.retrieval import get_candidates
|
| 37 |
|
| 38 |
if TYPE_CHECKING:
|
|
@@ -152,22 +153,31 @@ CONFLICT_PHRASES = frozenset(
|
|
| 152 |
]
|
| 153 |
)
|
| 154 |
|
| 155 |
-
# Phrases indicating graceful refusal
|
| 156 |
-
REFUSAL_PHRASES = frozenset(
|
| 157 |
-
[
|
| 158 |
-
"cannot",
|
| 159 |
-
"can't",
|
| 160 |
-
"unable",
|
| 161 |
-
"no evidence",
|
| 162 |
-
"insufficient",
|
| 163 |
-
"limited",
|
| 164 |
-
]
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
# Thresholds
|
| 168 |
COMBINED_HHEM_THRESHOLD = 0.5
|
| 169 |
KEY_TERM_THRESHOLD = 0.3
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
# ============================================================================
|
| 173 |
# SECTION: Spot-Check
|
|
@@ -178,12 +188,17 @@ def _generate_spot_samples(
|
|
| 178 |
explainer: Explainer, detector: HallucinationDetector, max_samples: int = 10
|
| 179 |
) -> Iterator[tuple]:
|
| 180 |
"""Generate spot-check samples, yielding (query, hhem_result, explanation_result)."""
|
| 181 |
-
for query in EVALUATION_QUERIES[:
|
| 182 |
products = get_candidates(
|
| 183 |
-
query=query,
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
-
for product in products[:
|
| 186 |
-
result = explainer.generate_explanation(
|
|
|
|
|
|
|
| 187 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
| 188 |
yield query, hhem, result
|
| 189 |
max_samples -= 1
|
|
@@ -199,22 +214,31 @@ def run_spot_check(explainer: Explainer, detector: HallucinationDetector) -> Non
|
|
| 199 |
for i, (query, hhem, result) in enumerate(
|
| 200 |
_generate_spot_samples(explainer, detector), 1
|
| 201 |
):
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
logger.info(
|
| 205 |
-
"HHEM: %.3f (
|
| 206 |
-
hhem.score,
|
| 207 |
-
"PASS" if not hhem.is_hallucinated else "FAIL",
|
| 208 |
)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
logger.info(' "%s..."', ev[:100])
|
| 212 |
-
logger.info("EXPLANATION:")
|
| 213 |
-
logger.info(" %s", result.explanation)
|
| 214 |
-
results.append({"query": query, "hhem_score": hhem.score})
|
| 215 |
-
|
| 216 |
-
scores = [r["hhem_score"] for r in results]
|
| 217 |
-
logger.info("SUMMARY: %d samples, mean HHEM: %.3f", len(results), np.mean(scores))
|
| 218 |
|
| 219 |
|
| 220 |
# ============================================================================
|
|
@@ -267,61 +291,70 @@ def run_adversarial_tests(
|
|
| 267 |
results = []
|
| 268 |
for case in ADVERSARIAL_CASES:
|
| 269 |
log_section(logger, case["name"])
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
result
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
log_summary_counts(
|
| 327 |
results,
|
|
@@ -347,7 +380,7 @@ EMPTY_CONTEXT_CASES = [
|
|
| 347 |
},
|
| 348 |
{"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
|
| 349 |
{
|
| 350 |
-
"name": "
|
| 351 |
"query": "wireless mouse",
|
| 352 |
"evidence": "Muy bueno el producto.",
|
| 353 |
},
|
|
@@ -357,24 +390,31 @@ EMPTY_CONTEXT_CASES = [
|
|
| 357 |
def run_empty_context_tests(
|
| 358 |
explainer: Explainer, detector: HallucinationDetector
|
| 359 |
) -> None:
|
| 360 |
-
"""Test
|
| 361 |
log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70)
|
| 362 |
del detector # Passed for interface consistency but unused (refusals bypass HHEM)
|
| 363 |
|
| 364 |
results = []
|
| 365 |
for case in EMPTY_CONTEXT_CASES:
|
| 366 |
log_section(logger, case["name"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
-
|
| 369 |
-
product = make_test_product([chunk], score=0.3)
|
| 370 |
-
result = explainer.generate_explanation(case["query"], product, max_evidence=1)
|
| 371 |
-
|
| 372 |
-
graceful = contains_any_phrase(result.explanation, REFUSAL_PHRASES)
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
logger.info(
|
| 380 |
"SUMMARY: %d/%d refused gracefully",
|
|
@@ -419,13 +459,18 @@ def run_calibration_check(
|
|
| 419 |
samples: list[CalibrationSample] = []
|
| 420 |
logger.info("Generating samples...")
|
| 421 |
|
| 422 |
-
for query in EVALUATION_QUERIES[:
|
| 423 |
products = get_candidates(
|
| 424 |
-
query=query,
|
|
|
|
|
|
|
|
|
|
| 425 |
)
|
| 426 |
-
for product in products[:
|
| 427 |
try:
|
| 428 |
-
result = explainer.generate_explanation(
|
|
|
|
|
|
|
| 429 |
hhem = detector.check_explanation(
|
| 430 |
result.evidence_texts, result.explanation
|
| 431 |
)
|
|
@@ -443,8 +488,8 @@ def run_calibration_check(
|
|
| 443 |
logger.debug("Error generating sample", exc_info=True)
|
| 444 |
|
| 445 |
logger.info("Samples: %d", len(samples))
|
| 446 |
-
if len(samples) <
|
| 447 |
-
logger.warning("Not enough samples")
|
| 448 |
return
|
| 449 |
|
| 450 |
# Extract arrays for correlation analysis
|
|
|
|
| 33 |
log_banner,
|
| 34 |
log_section,
|
| 35 |
)
|
| 36 |
+
from sage.services.faithfulness import is_refusal
|
| 37 |
from sage.services.retrieval import get_candidates
|
| 38 |
|
| 39 |
if TYPE_CHECKING:
|
|
|
|
| 153 |
]
|
| 154 |
)
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
# Thresholds
|
| 157 |
COMBINED_HHEM_THRESHOLD = 0.5
|
| 158 |
KEY_TERM_THRESHOLD = 0.3
|
| 159 |
|
| 160 |
+
# Spot-check limits
|
| 161 |
+
SPOT_CHECK_QUERY_LIMIT = 5
|
| 162 |
+
SPOT_CHECK_CANDIDATES_K = 2
|
| 163 |
+
SPOT_CHECK_MIN_RATING = 4.0
|
| 164 |
+
SPOT_CHECK_PRODUCTS_LIMIT = 2
|
| 165 |
+
SPOT_CHECK_MAX_EVIDENCE = 3
|
| 166 |
+
EVIDENCE_PREVIEW_COUNT = 2
|
| 167 |
+
EVIDENCE_PREVIEW_LENGTH = 100
|
| 168 |
+
|
| 169 |
+
# Calibration limits
|
| 170 |
+
CALIBRATION_QUERY_LIMIT = 15
|
| 171 |
+
CALIBRATION_CANDIDATES_K = 5
|
| 172 |
+
CALIBRATION_MIN_RATING = 3.0
|
| 173 |
+
CALIBRATION_PRODUCTS_LIMIT = 2
|
| 174 |
+
CALIBRATION_MAX_EVIDENCE = 3
|
| 175 |
+
MIN_CALIBRATION_SAMPLES = 5
|
| 176 |
+
|
| 177 |
+
# Empty context test settings (low values to trigger quality gate)
|
| 178 |
+
EMPTY_CONTEXT_SCORE = 0.3
|
| 179 |
+
EMPTY_CONTEXT_RATING = 3.0
|
| 180 |
+
|
| 181 |
|
| 182 |
# ============================================================================
|
| 183 |
# SECTION: Spot-Check
|
|
|
|
| 188 |
explainer: Explainer, detector: HallucinationDetector, max_samples: int = 10
|
| 189 |
) -> Iterator[tuple]:
|
| 190 |
"""Generate spot-check samples, yielding (query, hhem_result, explanation_result)."""
|
| 191 |
+
for query in EVALUATION_QUERIES[:SPOT_CHECK_QUERY_LIMIT]:
|
| 192 |
products = get_candidates(
|
| 193 |
+
query=query,
|
| 194 |
+
k=SPOT_CHECK_CANDIDATES_K,
|
| 195 |
+
min_rating=SPOT_CHECK_MIN_RATING,
|
| 196 |
+
aggregation=AggregationMethod.MAX,
|
| 197 |
)
|
| 198 |
+
for product in products[:SPOT_CHECK_PRODUCTS_LIMIT]:
|
| 199 |
+
result = explainer.generate_explanation(
|
| 200 |
+
query, product, max_evidence=SPOT_CHECK_MAX_EVIDENCE
|
| 201 |
+
)
|
| 202 |
hhem = detector.check_explanation(result.evidence_texts, result.explanation)
|
| 203 |
yield query, hhem, result
|
| 204 |
max_samples -= 1
|
|
|
|
| 214 |
for i, (query, hhem, result) in enumerate(
|
| 215 |
_generate_spot_samples(explainer, detector), 1
|
| 216 |
):
|
| 217 |
+
try:
|
| 218 |
+
log_section(logger, f"SAMPLE {i}")
|
| 219 |
+
logger.info('Query: "%s"', query)
|
| 220 |
+
logger.info(
|
| 221 |
+
"HHEM: %.3f (%s)",
|
| 222 |
+
hhem.score,
|
| 223 |
+
"PASS" if not hhem.is_hallucinated else "FAIL",
|
| 224 |
+
)
|
| 225 |
+
logger.info("EVIDENCE:")
|
| 226 |
+
for ev in result.evidence_texts[:EVIDENCE_PREVIEW_COUNT]:
|
| 227 |
+
logger.info(' "%s..."', ev[:EVIDENCE_PREVIEW_LENGTH])
|
| 228 |
+
logger.info("EXPLANATION:")
|
| 229 |
+
logger.info(" %s", result.explanation)
|
| 230 |
+
results.append({"query": query, "hhem_score": hhem.score})
|
| 231 |
+
except Exception:
|
| 232 |
+
logger.warning("Skipping sample %d due to error", i, exc_info=True)
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
if results:
|
| 236 |
+
scores = [r["hhem_score"] for r in results]
|
| 237 |
logger.info(
|
| 238 |
+
"SUMMARY: %d samples, mean HHEM: %.3f", len(results), np.mean(scores)
|
|
|
|
|
|
|
| 239 |
)
|
| 240 |
+
else:
|
| 241 |
+
logger.warning("SUMMARY: No samples collected")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
|
| 244 |
# ============================================================================
|
|
|
|
| 291 |
results = []
|
| 292 |
for case in ADVERSARIAL_CASES:
|
| 293 |
log_section(logger, case["name"])
|
| 294 |
+
try:
|
| 295 |
+
chunks = [
|
| 296 |
+
make_test_chunk(
|
| 297 |
+
case["positive"], score=0.9, rating=5.0, review_id="pos"
|
| 298 |
+
),
|
| 299 |
+
make_test_chunk(
|
| 300 |
+
case["negative"], score=0.85, rating=1.0, review_id="neg"
|
| 301 |
+
),
|
| 302 |
+
]
|
| 303 |
+
product = make_test_product(chunks)
|
| 304 |
+
result = explainer.generate_explanation(
|
| 305 |
+
case["query"], product, max_evidence=2
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Faithfulness check: explanation is grounded in combined evidence
|
| 309 |
+
hhem_combined = detector.check_explanation(
|
| 310 |
+
result.evidence_texts, result.explanation
|
| 311 |
+
)
|
| 312 |
+
is_grounded = hhem_combined.score >= COMBINED_HHEM_THRESHOLD
|
| 313 |
+
|
| 314 |
+
# Content reference check: does explanation reference BOTH pieces?
|
| 315 |
+
pos_ratio = compute_term_overlap(result.explanation, case["positive"])
|
| 316 |
+
neg_ratio = compute_term_overlap(result.explanation, case["negative"])
|
| 317 |
+
references_positive = pos_ratio >= KEY_TERM_THRESHOLD
|
| 318 |
+
references_negative = neg_ratio >= KEY_TERM_THRESHOLD
|
| 319 |
+
references_both = references_positive and references_negative
|
| 320 |
+
|
| 321 |
+
# Keyword check: uses explicit conflict language
|
| 322 |
+
keyword_ack = contains_any_phrase(result.explanation, CONFLICT_PHRASES)
|
| 323 |
+
|
| 324 |
+
# Overall: grounded + references both + uses conflict language
|
| 325 |
+
full_ack = is_grounded and references_both and keyword_ack
|
| 326 |
+
|
| 327 |
+
logger.info("Explanation: %s", result.explanation)
|
| 328 |
+
logger.info(
|
| 329 |
+
"HHEM combined: %.3f (%s)",
|
| 330 |
+
hhem_combined.score,
|
| 331 |
+
"grounded" if is_grounded else "HALLUCINATED",
|
| 332 |
+
)
|
| 333 |
+
logger.info(
|
| 334 |
+
"References positive: %.0f%% of terms (%s)",
|
| 335 |
+
pos_ratio * 100,
|
| 336 |
+
yn(references_positive),
|
| 337 |
+
)
|
| 338 |
+
logger.info(
|
| 339 |
+
"References negative: %.0f%% of terms (%s)",
|
| 340 |
+
neg_ratio * 100,
|
| 341 |
+
yn(references_negative),
|
| 342 |
+
)
|
| 343 |
+
logger.info("Uses conflict language: %s", yn(keyword_ack))
|
| 344 |
+
logger.info("FULL ACKNOWLEDGMENT: %s", "PASS" if full_ack else "FAIL")
|
| 345 |
+
|
| 346 |
+
results.append(
|
| 347 |
+
{
|
| 348 |
+
"case": case["name"],
|
| 349 |
+
"grounded": is_grounded,
|
| 350 |
+
"references_both": references_both,
|
| 351 |
+
"keyword_ack": keyword_ack,
|
| 352 |
+
"full_ack": full_ack,
|
| 353 |
+
}
|
| 354 |
+
)
|
| 355 |
+
except Exception:
|
| 356 |
+
logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
|
| 357 |
+
continue
|
| 358 |
|
| 359 |
log_summary_counts(
|
| 360 |
results,
|
|
|
|
| 380 |
},
|
| 381 |
{"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
|
| 382 |
{
|
| 383 |
+
"name": "Minimal_NonEnglish",
|
| 384 |
"query": "wireless mouse",
|
| 385 |
"evidence": "Muy bueno el producto.",
|
| 386 |
},
|
|
|
|
| 390 |
def run_empty_context_tests(
|
| 391 |
explainer: Explainer, detector: HallucinationDetector
|
| 392 |
) -> None:
|
| 393 |
+
"""Test quality gate refusal on insufficient evidence."""
|
| 394 |
log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70)
|
| 395 |
del detector # Passed for interface consistency but unused (refusals bypass HHEM)
|
| 396 |
|
| 397 |
results = []
|
| 398 |
for case in EMPTY_CONTEXT_CASES:
|
| 399 |
log_section(logger, case["name"])
|
| 400 |
+
try:
|
| 401 |
+
chunk = make_test_chunk(
|
| 402 |
+
case["evidence"], score=EMPTY_CONTEXT_SCORE, rating=EMPTY_CONTEXT_RATING
|
| 403 |
+
)
|
| 404 |
+
product = make_test_product([chunk], score=EMPTY_CONTEXT_SCORE)
|
| 405 |
+
result = explainer.generate_explanation(
|
| 406 |
+
case["query"], product, max_evidence=1
|
| 407 |
+
)
|
| 408 |
|
| 409 |
+
graceful = is_refusal(result.explanation)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
logger.info("Explanation: %s", result.explanation)
|
| 412 |
+
logger.info("Graceful refusal: %s", yn(graceful))
|
| 413 |
|
| 414 |
+
results.append({"case": case["name"], "graceful": graceful})
|
| 415 |
+
except Exception:
|
| 416 |
+
logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
|
| 417 |
+
continue
|
| 418 |
|
| 419 |
logger.info(
|
| 420 |
"SUMMARY: %d/%d refused gracefully",
|
|
|
|
| 459 |
samples: list[CalibrationSample] = []
|
| 460 |
logger.info("Generating samples...")
|
| 461 |
|
| 462 |
+
for query in EVALUATION_QUERIES[:CALIBRATION_QUERY_LIMIT]:
|
| 463 |
products = get_candidates(
|
| 464 |
+
query=query,
|
| 465 |
+
k=CALIBRATION_CANDIDATES_K,
|
| 466 |
+
min_rating=CALIBRATION_MIN_RATING,
|
| 467 |
+
aggregation=AggregationMethod.MAX,
|
| 468 |
)
|
| 469 |
+
for product in products[:CALIBRATION_PRODUCTS_LIMIT]:
|
| 470 |
try:
|
| 471 |
+
result = explainer.generate_explanation(
|
| 472 |
+
query, product, max_evidence=CALIBRATION_MAX_EVIDENCE
|
| 473 |
+
)
|
| 474 |
hhem = detector.check_explanation(
|
| 475 |
result.evidence_texts, result.explanation
|
| 476 |
)
|
|
|
|
| 488 |
logger.debug("Error generating sample", exc_info=True)
|
| 489 |
|
| 490 |
logger.info("Samples: %d", len(samples))
|
| 491 |
+
if len(samples) < MIN_CALIBRATION_SAMPLES:
|
| 492 |
+
logger.warning("Not enough samples (need %d)", MIN_CALIBRATION_SAMPLES)
|
| 493 |
return
|
| 494 |
|
| 495 |
# Extract arrays for correlation analysis
|