| | """ |
| | Pipeline sanity checks and calibration. |
| | |
| | Combines: |
| | - Manual spot-checks (explanation vs evidence) |
| | - Adversarial tests (contradictory evidence) |
| | - Empty context tests (graceful refusal) |
| | - Calibration analysis (confidence vs faithfulness) |
| | |
| | Usage: |
| | python scripts/sanity_checks.py # All checks |
| | python scripts/sanity_checks.py --section spot # Spot-checks only |
| | python scripts/sanity_checks.py --section adversarial |
| | python scripts/sanity_checks.py --section empty |
| | python scripts/sanity_checks.py --section calibration |
| | |
| | Run from project root. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | from collections.abc import Iterator |
| | from dataclasses import dataclass |
| | from typing import TYPE_CHECKING |
| |
|
| | import numpy as np |
| |
|
| | from sage.core import AggregationMethod, ProductScore, RetrievedChunk |
| | from sage.config import ( |
| | EVALUATION_QUERIES, |
| | get_logger, |
| | log_banner, |
| | log_section, |
| | ) |
| | from sage.services.faithfulness import is_refusal |
| | from sage.services.retrieval import get_candidates |
| |
|
| | if TYPE_CHECKING: |
| | from sage.adapters.hhem import HallucinationDetector |
| | from sage.services.explanation import Explainer |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def yn(condition: bool) -> str: |
| | """Format boolean as YES/NO for logging.""" |
| | return "YES" if condition else "NO" |
| |
|
| |
|
| | def count_matches(results: list[dict], key: str) -> int: |
| | """Count results where key is truthy.""" |
| | return sum(1 for r in results if r.get(key)) |
| |
|
| |
|
| | def extract_key_terms(text: str, min_length: int = 4) -> set[str]: |
| | """Extract lowercase words of min_length+ chars, stripped of punctuation.""" |
| | words = text.lower().split() |
| | return {w.strip(".,!?\"'") for w in words if len(w) >= min_length} |
| |
|
| |
|
| | def make_test_chunk( |
| | text: str, |
| | score: float = 0.85, |
| | rating: float = 3.0, |
| | review_id: str = "r1", |
| | ) -> RetrievedChunk: |
| | """Create a RetrievedChunk for testing with sensible defaults.""" |
| | return RetrievedChunk( |
| | text=text, score=score, product_id="TEST", rating=rating, review_id=review_id |
| | ) |
| |
|
| |
|
| | def make_test_product( |
| | chunks: list[RetrievedChunk], |
| | product_id: str = "TEST", |
| | score: float = 0.85, |
| | ) -> ProductScore: |
| | """Create a ProductScore for testing with sensible defaults.""" |
| | ratings = [c.rating for c in chunks if c.rating is not None] |
| | return ProductScore( |
| | product_id=product_id, |
| | score=score, |
| | chunk_count=len(chunks), |
| | avg_rating=sum(ratings) / len(ratings) if ratings else 0.0, |
| | evidence=chunks, |
| | ) |
| |
|
| |
|
| | def compute_term_overlap(text: str, reference: str) -> float: |
| | """Compute fraction of key terms from reference that appear in text.""" |
| | ref_terms = extract_key_terms(reference) |
| | if not ref_terms: |
| | return 0.0 |
| | text_lower = text.lower() |
| | matches = sum(1 for t in ref_terms if t in text_lower) |
| | return matches / len(ref_terms) |
| |
|
| |
|
| | def log_summary_counts(results: list[dict], metrics: list[tuple[str, str]]) -> None: |
| | """Log summary counts for multiple metrics.""" |
| | logger.info("SUMMARY:") |
| | for label, key in metrics: |
| | logger.info(" %s %d/%d", label, count_matches(results, key), len(results)) |
| |
|
| |
|
| | def contains_any_phrase(text: str, phrases: frozenset[str]) -> bool: |
| | """Check if text contains any of the given phrases (case-insensitive).""" |
| | text_lower = text.lower() |
| | return any(phrase in text_lower for phrase in phrases) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | CONFLICT_PHRASES = frozenset( |
| | [ |
| | |
| | "however", |
| | "but", |
| | "although", |
| | "while", |
| | "whereas", |
| | "yet", |
| | "nevertheless", |
| | "nonetheless", |
| | "on the other hand", |
| | "conversely", |
| | "in contrast", |
| | |
| | "mixed", |
| | "varies", |
| | "some", |
| | "others", |
| | "both", |
| | "range", |
| | "conflicting", |
| | "contradictory", |
| | "inconsistent", |
| | "divided", |
| | "disagree", |
| | "differ", |
| | "not all", |
| | "opinions vary", |
| | "experiences differ", |
| | ] |
| | ) |
| |
|
| | |
| | COMBINED_HHEM_THRESHOLD = 0.5 |
| | KEY_TERM_THRESHOLD = 0.3 |
| |
|
| | |
| | SPOT_CHECK_QUERY_LIMIT = 5 |
| | SPOT_CHECK_CANDIDATES_K = 2 |
| | SPOT_CHECK_MIN_RATING = 4.0 |
| | SPOT_CHECK_PRODUCTS_LIMIT = 2 |
| | SPOT_CHECK_MAX_EVIDENCE = 3 |
| | EVIDENCE_PREVIEW_COUNT = 2 |
| | EVIDENCE_PREVIEW_LENGTH = 100 |
| |
|
| | |
| | CALIBRATION_QUERY_LIMIT = 15 |
| | CALIBRATION_CANDIDATES_K = 5 |
| | CALIBRATION_MIN_RATING = 3.0 |
| | CALIBRATION_PRODUCTS_LIMIT = 2 |
| | CALIBRATION_MAX_EVIDENCE = 3 |
| | MIN_CALIBRATION_SAMPLES = 5 |
| |
|
| | |
| | EMPTY_CONTEXT_SCORE = 0.3 |
| | EMPTY_CONTEXT_RATING = 3.0 |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def _generate_spot_samples( |
| | explainer: Explainer, detector: HallucinationDetector, max_samples: int = 10 |
| | ) -> Iterator[tuple]: |
| | """Generate spot-check samples, yielding (query, hhem_result, explanation_result).""" |
| | for query in EVALUATION_QUERIES[:SPOT_CHECK_QUERY_LIMIT]: |
| | products = get_candidates( |
| | query=query, |
| | k=SPOT_CHECK_CANDIDATES_K, |
| | min_rating=SPOT_CHECK_MIN_RATING, |
| | aggregation=AggregationMethod.MAX, |
| | ) |
| | for product in products[:SPOT_CHECK_PRODUCTS_LIMIT]: |
| | result = explainer.generate_explanation( |
| | query, product, max_evidence=SPOT_CHECK_MAX_EVIDENCE |
| | ) |
| | hhem = detector.check_explanation(result.evidence_texts, result.explanation) |
| | yield query, hhem, result |
| | max_samples -= 1 |
| | if max_samples <= 0: |
| | return |
| |
|
| |
|
| | def run_spot_check(explainer: Explainer, detector: HallucinationDetector) -> None: |
| | """Manual spot-check of explanations vs evidence.""" |
| | log_banner(logger, "SPOT-CHECK: Manual Inspection", width=70) |
| |
|
| | results = [] |
| | for i, (query, hhem, result) in enumerate( |
| | _generate_spot_samples(explainer, detector), 1 |
| | ): |
| | try: |
| | log_section(logger, f"SAMPLE {i}") |
| | logger.info('Query: "%s"', query) |
| | logger.info( |
| | "HHEM: %.3f (%s)", |
| | hhem.score, |
| | "PASS" if not hhem.is_hallucinated else "FAIL", |
| | ) |
| | logger.info("EVIDENCE:") |
| | for ev in result.evidence_texts[:EVIDENCE_PREVIEW_COUNT]: |
| | logger.info(' "%s..."', ev[:EVIDENCE_PREVIEW_LENGTH]) |
| | logger.info("EXPLANATION:") |
| | logger.info(" %s", result.explanation) |
| | results.append({"query": query, "hhem_score": hhem.score}) |
| | except Exception: |
| | logger.warning("Skipping sample %d due to error", i, exc_info=True) |
| | continue |
| |
|
| | if results: |
| | scores = [r["hhem_score"] for r in results] |
| | logger.info( |
| | "SUMMARY: %d samples, mean HHEM: %.3f", len(results), np.mean(scores) |
| | ) |
| | else: |
| | logger.warning("SUMMARY: No samples collected") |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | ADVERSARIAL_CASES = [ |
| | { |
| | "name": "Battery Contradiction", |
| | "query": "laptop with good battery", |
| | "positive": ( |
| | "Battery life on this laptop is absolutely incredible. I consistently " |
| | "get 12 to 14 hours of use on a single charge, even with heavy browsing " |
| | "and video streaming. Perfect for long flights and working remotely " |
| | "without needing to find an outlet. Best battery I've ever had on a laptop." |
| | ), |
| | "negative": ( |
| | "The battery on this laptop is terrible and a huge disappointment. " |
| | "I barely get 3 hours of use before needing to charge again. Even with " |
| | "brightness turned down and minimal apps running, it drains incredibly " |
| | "fast. Do not buy this if you need any kind of portable use." |
| | ), |
| | }, |
| | { |
| | "name": "Build Quality Contradiction", |
| | "query": "durable headphones", |
| | "positive": ( |
| | "These headphones have premium build quality with solid metal construction. " |
| | "I've dropped them multiple times on hard floors and they still work " |
| | "perfectly with no damage. The hinges are sturdy and the headband has " |
| | "survived being thrown in my bag daily for over a year." |
| | ), |
| | "negative": ( |
| | "Build quality is cheap plastic throughout. The headband cracked after " |
| | "just two weeks of normal use and the ear cups feel flimsy. I baby my " |
| | "electronics and these still fell apart. Complete waste of money if " |
| | "you expect them to last more than a month." |
| | ), |
| | }, |
| | ] |
| |
|
| |
|
| | def run_adversarial_tests( |
| | explainer: Explainer, detector: HallucinationDetector |
| | ) -> None: |
| | """Test with contradictory evidence using semantic entailment.""" |
| | log_banner(logger, "ADVERSARIAL: Contradictory Evidence", width=70) |
| |
|
| | results = [] |
| | for case in ADVERSARIAL_CASES: |
| | log_section(logger, case["name"]) |
| | try: |
| | chunks = [ |
| | make_test_chunk( |
| | case["positive"], score=0.9, rating=5.0, review_id="pos" |
| | ), |
| | make_test_chunk( |
| | case["negative"], score=0.85, rating=1.0, review_id="neg" |
| | ), |
| | ] |
| | product = make_test_product(chunks) |
| | result = explainer.generate_explanation( |
| | case["query"], product, max_evidence=2 |
| | ) |
| |
|
| | |
| | hhem_combined = detector.check_explanation( |
| | result.evidence_texts, result.explanation |
| | ) |
| | is_grounded = hhem_combined.score >= COMBINED_HHEM_THRESHOLD |
| |
|
| | |
| | pos_ratio = compute_term_overlap(result.explanation, case["positive"]) |
| | neg_ratio = compute_term_overlap(result.explanation, case["negative"]) |
| | references_positive = pos_ratio >= KEY_TERM_THRESHOLD |
| | references_negative = neg_ratio >= KEY_TERM_THRESHOLD |
| | references_both = references_positive and references_negative |
| |
|
| | |
| | keyword_ack = contains_any_phrase(result.explanation, CONFLICT_PHRASES) |
| |
|
| | |
| | full_ack = is_grounded and references_both and keyword_ack |
| |
|
| | logger.info("Explanation: %s", result.explanation) |
| | logger.info( |
| | "HHEM combined: %.3f (%s)", |
| | hhem_combined.score, |
| | "grounded" if is_grounded else "HALLUCINATED", |
| | ) |
| | logger.info( |
| | "References positive: %.0f%% of terms (%s)", |
| | pos_ratio * 100, |
| | yn(references_positive), |
| | ) |
| | logger.info( |
| | "References negative: %.0f%% of terms (%s)", |
| | neg_ratio * 100, |
| | yn(references_negative), |
| | ) |
| | logger.info("Uses conflict language: %s", yn(keyword_ack)) |
| | logger.info("FULL ACKNOWLEDGMENT: %s", "PASS" if full_ack else "FAIL") |
| |
|
| | results.append( |
| | { |
| | "case": case["name"], |
| | "grounded": is_grounded, |
| | "references_both": references_both, |
| | "keyword_ack": keyword_ack, |
| | "full_ack": full_ack, |
| | } |
| | ) |
| | except Exception: |
| | logger.warning("Skipping case %s due to error", case["name"], exc_info=True) |
| | continue |
| |
|
| | log_summary_counts( |
| | results, |
| | [ |
| | ("Grounded (HHEM): ", "grounded"), |
| | ("References both sides:", "references_both"), |
| | ("Uses conflict language:", "keyword_ack"), |
| | ("FULL ACKNOWLEDGMENT: ", "full_ack"), |
| | ], |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | |
| | EMPTY_CONTEXT_CASES = [ |
| | { |
| | "name": "Irrelevant", |
| | "query": "quantum computing textbook", |
| | "evidence": "Great USB cable.", |
| | }, |
| | {"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."}, |
| | { |
| | "name": "Minimal_NonEnglish", |
| | "query": "wireless mouse", |
| | "evidence": "Muy bueno el producto.", |
| | }, |
| | ] |
| |
|
| |
|
| | def run_empty_context_tests( |
| | explainer: Explainer, detector: HallucinationDetector |
| | ) -> None: |
| | """Test quality gate refusal on insufficient evidence.""" |
| | log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70) |
| | del detector |
| |
|
| | results = [] |
| | for case in EMPTY_CONTEXT_CASES: |
| | log_section(logger, case["name"]) |
| | try: |
| | chunk = make_test_chunk( |
| | case["evidence"], score=EMPTY_CONTEXT_SCORE, rating=EMPTY_CONTEXT_RATING |
| | ) |
| | product = make_test_product([chunk], score=EMPTY_CONTEXT_SCORE) |
| | result = explainer.generate_explanation( |
| | case["query"], product, max_evidence=1 |
| | ) |
| |
|
| | graceful = is_refusal(result.explanation) |
| |
|
| | logger.info("Explanation: %s", result.explanation) |
| | logger.info("Graceful refusal: %s", yn(graceful)) |
| |
|
| | results.append({"case": case["name"], "graceful": graceful}) |
| | except Exception: |
| | logger.warning("Skipping case %s due to error", case["name"], exc_info=True) |
| | continue |
| |
|
| | logger.info( |
| | "SUMMARY: %d/%d refused gracefully", |
| | count_matches(results, "graceful"), |
| | len(results), |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | @dataclass |
| | class CalibrationSample: |
| | query: str |
| | product_id: str |
| | retrieval_score: float |
| | evidence_count: int |
| | avg_rating: float |
| | hhem_score: float |
| |
|
| |
|
| | def _safe_corr(x: np.ndarray, y: np.ndarray) -> float: |
| | """Compute correlation, returning 0.0 if either array has zero variance.""" |
| | if np.std(x) == 0 or np.std(y) == 0: |
| | return 0.0 |
| | return float(np.corrcoef(x, y)[0, 1]) |
| |
|
| |
|
| | def _tier_mean(samples: list[CalibrationSample]) -> float: |
| | """Compute mean HHEM score for a tier of samples.""" |
| | return float(np.mean([s.hhem_score for s in samples])) if samples else 0.0 |
| |
|
| |
|
| | def run_calibration_check( |
| | explainer: Explainer, detector: HallucinationDetector |
| | ) -> None: |
| | """Analyze confidence vs faithfulness correlation.""" |
| | log_banner(logger, "CALIBRATION: Confidence vs Faithfulness", width=70) |
| |
|
| | samples: list[CalibrationSample] = [] |
| | logger.info("Generating samples...") |
| |
|
| | for query in EVALUATION_QUERIES[:CALIBRATION_QUERY_LIMIT]: |
| | products = get_candidates( |
| | query=query, |
| | k=CALIBRATION_CANDIDATES_K, |
| | min_rating=CALIBRATION_MIN_RATING, |
| | aggregation=AggregationMethod.MAX, |
| | ) |
| | for product in products[:CALIBRATION_PRODUCTS_LIMIT]: |
| | try: |
| | result = explainer.generate_explanation( |
| | query, product, max_evidence=CALIBRATION_MAX_EVIDENCE |
| | ) |
| | hhem = detector.check_explanation( |
| | result.evidence_texts, result.explanation |
| | ) |
| | samples.append( |
| | CalibrationSample( |
| | query=query, |
| | product_id=product.product_id, |
| | retrieval_score=product.score, |
| | evidence_count=product.chunk_count, |
| | avg_rating=product.avg_rating, |
| | hhem_score=hhem.score, |
| | ) |
| | ) |
| | except Exception: |
| | logger.debug("Error generating sample", exc_info=True) |
| |
|
| | logger.info("Samples: %d", len(samples)) |
| | if len(samples) < MIN_CALIBRATION_SAMPLES: |
| | logger.warning("Not enough samples (need %d)", MIN_CALIBRATION_SAMPLES) |
| | return |
| |
|
| | |
| | hhem_scores = np.array([s.hhem_score for s in samples]) |
| | retrieval_scores = np.array([s.retrieval_score for s in samples]) |
| | evidence_counts = np.array([s.evidence_count for s in samples]) |
| |
|
| | log_section(logger, "Correlations with HHEM") |
| | logger.info( |
| | " Retrieval score: r = %+.3f", _safe_corr(retrieval_scores, hhem_scores) |
| | ) |
| | logger.info( |
| | " Evidence count: r = %+.3f", _safe_corr(evidence_counts, hhem_scores) |
| | ) |
| |
|
| | |
| | sorted_samples = sorted(samples, key=lambda s: s.retrieval_score) |
| | n = len(sorted_samples) |
| | tiers = [ |
| | ("LOW ", sorted_samples[: n // 3]), |
| | ("MED ", sorted_samples[n // 3 : 2 * n // 3]), |
| | ("HIGH", sorted_samples[2 * n // 3 :]), |
| | ] |
| |
|
| | log_section(logger, "HHEM by Confidence Tier") |
| | for name, tier in tiers: |
| | logger.info(" %s (n=%2d): %.3f", name, len(tier), _tier_mean(tier)) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | def main() -> None: |
| | from sage.adapters.hhem import HallucinationDetector |
| | from sage.services.explanation import Explainer |
| |
|
| | parser = argparse.ArgumentParser(description="Run pipeline sanity checks") |
| | parser.add_argument( |
| | "--section", |
| | "-s", |
| | choices=["all", "spot", "adversarial", "empty", "calibration"], |
| | default="all", |
| | help="Which section to run", |
| | ) |
| | args = parser.parse_args() |
| |
|
| | |
| | explainer = Explainer() |
| | detector = HallucinationDetector() |
| |
|
| | if args.section in ("all", "spot"): |
| | run_spot_check(explainer, detector) |
| | if args.section in ("all", "adversarial"): |
| | run_adversarial_tests(explainer, detector) |
| | if args.section in ("all", "empty"): |
| | run_empty_context_tests(explainer, detector) |
| | if args.section in ("all", "calibration"): |
| | run_calibration_check(explainer, detector) |
| |
|
| | log_banner(logger, "SANITY CHECKS COMPLETE", width=70) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|