vxa8502 committed
Commit ff5a9a1 · 1 Parent(s): 1ea55b9

Use canonical is_refusal() from faithfulness module
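
This swaps the script's ad-hoc refusal detection (a module-local REFUSAL_PHRASES frozenset checked with contains_any_phrase) for the shared helper in sage.services.faithfulness, so the sanity checks and the faithfulness service agree on what counts as a refusal. The helper's body is not part of this diff; a minimal sketch of the shape it presumably has, assuming it wraps a phrase list like the one this commit deletes (only the call shape, is_refusal(text) -> bool, is visible in the diff):

    # Hypothetical sketch; the real implementation lives in
    # sage/services/faithfulness.py and is not shown in this commit.
    REFUSAL_PHRASES = frozenset(
        ["cannot", "can't", "unable", "no evidence", "insufficient", "limited"]
    )

    def is_refusal(explanation: str) -> bool:
        """Return True if the explanation reads as a graceful refusal."""
        text = explanation.lower()
        return any(phrase in text for phrase in REFUSAL_PHRASES)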

Files changed (1): scripts/sanity_checks.py (+146 -101)
scripts/sanity_checks.py CHANGED
@@ -33,6 +33,7 @@ from sage.config import (
     log_banner,
     log_section,
 )
+from sage.services.faithfulness import is_refusal
 from sage.services.retrieval import get_candidates
 
 if TYPE_CHECKING:
@@ -152,22 +153,31 @@ CONFLICT_PHRASES = frozenset(
     ]
 )
 
-# Phrases indicating graceful refusal
-REFUSAL_PHRASES = frozenset(
-    [
-        "cannot",
-        "can't",
-        "unable",
-        "no evidence",
-        "insufficient",
-        "limited",
-    ]
-)
-
 # Thresholds
 COMBINED_HHEM_THRESHOLD = 0.5
 KEY_TERM_THRESHOLD = 0.3
 
+# Spot-check limits
+SPOT_CHECK_QUERY_LIMIT = 5
+SPOT_CHECK_CANDIDATES_K = 2
+SPOT_CHECK_MIN_RATING = 4.0
+SPOT_CHECK_PRODUCTS_LIMIT = 2
+SPOT_CHECK_MAX_EVIDENCE = 3
+EVIDENCE_PREVIEW_COUNT = 2
+EVIDENCE_PREVIEW_LENGTH = 100
+
+# Calibration limits
+CALIBRATION_QUERY_LIMIT = 15
+CALIBRATION_CANDIDATES_K = 5
+CALIBRATION_MIN_RATING = 3.0
+CALIBRATION_PRODUCTS_LIMIT = 2
+CALIBRATION_MAX_EVIDENCE = 3
+MIN_CALIBRATION_SAMPLES = 5
+
+# Empty context test settings (low values to trigger quality gate)
+EMPTY_CONTEXT_SCORE = 0.3
+EMPTY_CONTEXT_RATING = 3.0
+
 
 # ============================================================================
 # SECTION: Spot-Check
@@ -178,12 +188,17 @@ def _generate_spot_samples(
     explainer: Explainer, detector: HallucinationDetector, max_samples: int = 10
 ) -> Iterator[tuple]:
     """Generate spot-check samples, yielding (query, hhem_result, explanation_result)."""
-    for query in EVALUATION_QUERIES[:5]:
+    for query in EVALUATION_QUERIES[:SPOT_CHECK_QUERY_LIMIT]:
         products = get_candidates(
-            query=query, k=2, min_rating=4.0, aggregation=AggregationMethod.MAX
+            query=query,
+            k=SPOT_CHECK_CANDIDATES_K,
+            min_rating=SPOT_CHECK_MIN_RATING,
+            aggregation=AggregationMethod.MAX,
         )
-        for product in products[:2]:
-            result = explainer.generate_explanation(query, product, max_evidence=3)
+        for product in products[:SPOT_CHECK_PRODUCTS_LIMIT]:
+            result = explainer.generate_explanation(
+                query, product, max_evidence=SPOT_CHECK_MAX_EVIDENCE
+            )
             hhem = detector.check_explanation(result.evidence_texts, result.explanation)
             yield query, hhem, result
             max_samples -= 1
@@ -199,22 +214,31 @@ def run_spot_check(explainer: Explainer, detector: HallucinationDetector) -> None:
     for i, (query, hhem, result) in enumerate(
         _generate_spot_samples(explainer, detector), 1
     ):
-        log_section(logger, f"SAMPLE {i}")
-        logger.info('Query: "%s"', query)
+        try:
+            log_section(logger, f"SAMPLE {i}")
+            logger.info('Query: "%s"', query)
+            logger.info(
+                "HHEM: %.3f (%s)",
+                hhem.score,
+                "PASS" if not hhem.is_hallucinated else "FAIL",
+            )
+            logger.info("EVIDENCE:")
+            for ev in result.evidence_texts[:EVIDENCE_PREVIEW_COUNT]:
+                logger.info(' "%s..."', ev[:EVIDENCE_PREVIEW_LENGTH])
+            logger.info("EXPLANATION:")
+            logger.info(" %s", result.explanation)
+            results.append({"query": query, "hhem_score": hhem.score})
+        except Exception:
+            logger.warning("Skipping sample %d due to error", i, exc_info=True)
+            continue
+
+    if results:
+        scores = [r["hhem_score"] for r in results]
         logger.info(
-            "HHEM: %.3f (%s)",
-            hhem.score,
-            "PASS" if not hhem.is_hallucinated else "FAIL",
+            "SUMMARY: %d samples, mean HHEM: %.3f", len(results), np.mean(scores)
         )
-        logger.info("EVIDENCE:")
-        for ev in result.evidence_texts[:2]:
-            logger.info(' "%s..."', ev[:100])
-        logger.info("EXPLANATION:")
-        logger.info(" %s", result.explanation)
-        results.append({"query": query, "hhem_score": hhem.score})
-
-    scores = [r["hhem_score"] for r in results]
-    logger.info("SUMMARY: %d samples, mean HHEM: %.3f", len(results), np.mean(scores))
+    else:
+        logger.warning("SUMMARY: No samples collected")
 
 
 # ============================================================================
@@ -267,61 +291,70 @@ def run_adversarial_tests(
     results = []
     for case in ADVERSARIAL_CASES:
         log_section(logger, case["name"])
-
-        chunks = [
-            make_test_chunk(case["positive"], score=0.9, rating=5.0, review_id="pos"),
-            make_test_chunk(case["negative"], score=0.85, rating=1.0, review_id="neg"),
-        ]
-        product = make_test_product(chunks)
-        result = explainer.generate_explanation(case["query"], product, max_evidence=2)
-
-        # Faithfulness check: explanation is grounded in combined evidence
-        hhem_combined = detector.check_explanation(
-            result.evidence_texts, result.explanation
-        )
-        is_grounded = hhem_combined.score >= COMBINED_HHEM_THRESHOLD
-
-        # Content reference check: does explanation reference BOTH pieces?
-        pos_ratio = compute_term_overlap(result.explanation, case["positive"])
-        neg_ratio = compute_term_overlap(result.explanation, case["negative"])
-        references_positive = pos_ratio >= KEY_TERM_THRESHOLD
-        references_negative = neg_ratio >= KEY_TERM_THRESHOLD
-        references_both = references_positive and references_negative
-
-        # Keyword check: uses explicit conflict language
-        keyword_ack = contains_any_phrase(result.explanation, CONFLICT_PHRASES)
-
-        # Overall: grounded + references both + uses conflict language
-        full_ack = is_grounded and references_both and keyword_ack
-
-        logger.info("Explanation: %s", result.explanation)
-        logger.info(
-            "HHEM combined: %.3f (%s)",
-            hhem_combined.score,
-            "grounded" if is_grounded else "HALLUCINATED",
-        )
-        logger.info(
-            "References positive: %.0f%% of terms (%s)",
-            pos_ratio * 100,
-            yn(references_positive),
-        )
-        logger.info(
-            "References negative: %.0f%% of terms (%s)",
-            neg_ratio * 100,
-            yn(references_negative),
-        )
-        logger.info("Uses conflict language: %s", yn(keyword_ack))
-        logger.info("FULL ACKNOWLEDGMENT: %s", "PASS" if full_ack else "FAIL")
-
-        results.append(
-            {
-                "case": case["name"],
-                "grounded": is_grounded,
-                "references_both": references_both,
-                "keyword_ack": keyword_ack,
-                "full_ack": full_ack,
-            }
-        )
+        try:
+            chunks = [
+                make_test_chunk(
+                    case["positive"], score=0.9, rating=5.0, review_id="pos"
+                ),
+                make_test_chunk(
+                    case["negative"], score=0.85, rating=1.0, review_id="neg"
+                ),
+            ]
+            product = make_test_product(chunks)
+            result = explainer.generate_explanation(
+                case["query"], product, max_evidence=2
+            )
+
+            # Faithfulness check: explanation is grounded in combined evidence
+            hhem_combined = detector.check_explanation(
+                result.evidence_texts, result.explanation
+            )
+            is_grounded = hhem_combined.score >= COMBINED_HHEM_THRESHOLD
+
+            # Content reference check: does explanation reference BOTH pieces?
+            pos_ratio = compute_term_overlap(result.explanation, case["positive"])
+            neg_ratio = compute_term_overlap(result.explanation, case["negative"])
+            references_positive = pos_ratio >= KEY_TERM_THRESHOLD
+            references_negative = neg_ratio >= KEY_TERM_THRESHOLD
+            references_both = references_positive and references_negative
+
+            # Keyword check: uses explicit conflict language
+            keyword_ack = contains_any_phrase(result.explanation, CONFLICT_PHRASES)
+
+            # Overall: grounded + references both + uses conflict language
+            full_ack = is_grounded and references_both and keyword_ack
+
+            logger.info("Explanation: %s", result.explanation)
+            logger.info(
+                "HHEM combined: %.3f (%s)",
+                hhem_combined.score,
+                "grounded" if is_grounded else "HALLUCINATED",
+            )
+            logger.info(
+                "References positive: %.0f%% of terms (%s)",
+                pos_ratio * 100,
+                yn(references_positive),
+            )
+            logger.info(
+                "References negative: %.0f%% of terms (%s)",
+                neg_ratio * 100,
+                yn(references_negative),
+            )
+            logger.info("Uses conflict language: %s", yn(keyword_ack))
+            logger.info("FULL ACKNOWLEDGMENT: %s", "PASS" if full_ack else "FAIL")
+
+            results.append(
+                {
+                    "case": case["name"],
+                    "grounded": is_grounded,
+                    "references_both": references_both,
+                    "keyword_ack": keyword_ack,
+                    "full_ack": full_ack,
+                }
+            )
+        except Exception:
+            logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
+            continue
 
     log_summary_counts(
         results,
@@ -347,7 +380,7 @@ EMPTY_CONTEXT_CASES = [
     },
    {"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
     {
-        "name": "Foreign",
+        "name": "Minimal_NonEnglish",
         "query": "wireless mouse",
         "evidence": "Muy bueno el producto.",
     },
@@ -357,24 +390,31 @@ EMPTY_CONTEXT_CASES = [
 def run_empty_context_tests(
     explainer: Explainer, detector: HallucinationDetector
 ) -> None:
-    """Test graceful refusal with irrelevant evidence."""
+    """Test quality gate refusal on insufficient evidence."""
     log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70)
     del detector  # Passed for interface consistency but unused (refusals bypass HHEM)
 
     results = []
     for case in EMPTY_CONTEXT_CASES:
         log_section(logger, case["name"])
+        try:
+            chunk = make_test_chunk(
+                case["evidence"], score=EMPTY_CONTEXT_SCORE, rating=EMPTY_CONTEXT_RATING
+            )
+            product = make_test_product([chunk], score=EMPTY_CONTEXT_SCORE)
+            result = explainer.generate_explanation(
+                case["query"], product, max_evidence=1
+            )
 
-        chunk = make_test_chunk(case["evidence"], score=0.3, rating=3.0)
-        product = make_test_product([chunk], score=0.3)
-        result = explainer.generate_explanation(case["query"], product, max_evidence=1)
-
-        graceful = contains_any_phrase(result.explanation, REFUSAL_PHRASES)
+            graceful = is_refusal(result.explanation)
 
-        logger.info("Explanation: %s", result.explanation)
-        logger.info("Graceful refusal: %s", yn(graceful))
+            logger.info("Explanation: %s", result.explanation)
+            logger.info("Graceful refusal: %s", yn(graceful))
 
-        results.append({"case": case["name"], "graceful": graceful})
+            results.append({"case": case["name"], "graceful": graceful})
+        except Exception:
+            logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
+            continue
 
     logger.info(
         "SUMMARY: %d/%d refused gracefully",
@@ -419,13 +459,18 @@ def run_calibration_check(
     samples: list[CalibrationSample] = []
     logger.info("Generating samples...")
 
-    for query in EVALUATION_QUERIES[:15]:
+    for query in EVALUATION_QUERIES[:CALIBRATION_QUERY_LIMIT]:
         products = get_candidates(
-            query=query, k=5, min_rating=3.0, aggregation=AggregationMethod.MAX
+            query=query,
+            k=CALIBRATION_CANDIDATES_K,
+            min_rating=CALIBRATION_MIN_RATING,
+            aggregation=AggregationMethod.MAX,
         )
-        for product in products[:2]:
+        for product in products[:CALIBRATION_PRODUCTS_LIMIT]:
             try:
-                result = explainer.generate_explanation(query, product, max_evidence=3)
+                result = explainer.generate_explanation(
+                    query, product, max_evidence=CALIBRATION_MAX_EVIDENCE
+                )
                 hhem = detector.check_explanation(
                     result.evidence_texts, result.explanation
                 )
@@ -443,8 +488,8 @@
                 logger.debug("Error generating sample", exc_info=True)
 
     logger.info("Samples: %d", len(samples))
-    if len(samples) < 5:
-        logger.warning("Not enough samples")
+    if len(samples) < MIN_CALIBRATION_SAMPLES:
+        logger.warning("Not enough samples (need %d)", MIN_CALIBRATION_SAMPLES)
         return
 
     # Extract arrays for correlation analysis
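
For reviewers: the behavioral change above reduces to a single line in run_empty_context_tests; the rest of the diff is constant extraction and per-case error hardening. Before and after, as applied in this commit:

    # before: module-local phrase matching
    graceful = contains_any_phrase(result.explanation, REFUSAL_PHRASES)

    # after: canonical helper shared with the faithfulness module
    graceful = is_refusal(result.explanation)

Keeping a single definition means any future change to refusal wording in the faithfulness module is picked up by these sanity checks automatically.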