# scripts/sanity_checks.py
"""
Pipeline sanity checks and calibration.
Combines:
- Manual spot-checks (explanation vs evidence)
- Adversarial tests (contradictory evidence)
- Empty context tests (graceful refusal)
- Calibration analysis (confidence vs faithfulness)
Usage:
python scripts/sanity_checks.py # All checks
python scripts/sanity_checks.py --section spot # Spot-checks only
python scripts/sanity_checks.py --section adversarial
python scripts/sanity_checks.py --section empty
python scripts/sanity_checks.py --section calibration
Run from project root.
"""
from __future__ import annotations
import argparse
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from sage.core import AggregationMethod, ProductScore, RetrievedChunk
from sage.config import (
EVALUATION_QUERIES,
get_logger,
log_banner,
log_section,
)
from sage.services.faithfulness import is_refusal
from sage.services.retrieval import get_candidates
if TYPE_CHECKING:
from sage.adapters.hhem import HallucinationDetector
from sage.services.explanation import Explainer
logger = get_logger(__name__)
# ============================================================================
# Shared Helpers
# ============================================================================
def yn(condition: bool) -> str:
    """Render a boolean as the log-friendly string "YES" or "NO"."""
    return "YES" if condition else "NO"
def count_matches(results: list[dict], key: str) -> int:
    """Return how many result dicts have a truthy value under `key`."""
    truthy = [r for r in results if r.get(key)]
    return len(truthy)
def extract_key_terms(text: str, min_length: int = 4) -> set[str]:
    """Extract lowercase words of min_length+ chars, stripped of punctuation.

    Args:
        text: Free text to tokenize (split on whitespace).
        min_length: Minimum length a word must have AFTER punctuation is
            stripped to count as a key term.

    Returns:
        Set of lowercase key terms.

    Fix: the length filter previously ran on the raw token, so punctuation
    counted toward the minimum (e.g. "cat," passed with only 3 letters).
    Strip first, then filter, matching the documented contract.
    """
    stripped = (w.strip(".,!?\"'") for w in text.lower().split())
    return {w for w in stripped if len(w) >= min_length}
def make_test_chunk(
    text: str,
    score: float = 0.85,
    rating: float = 3.0,
    review_id: str = "r1",
) -> RetrievedChunk:
    """Build a RetrievedChunk for testing with sensible defaults.

    The product_id is always the fixed placeholder "TEST".
    """
    return RetrievedChunk(
        text=text,
        score=score,
        product_id="TEST",
        rating=rating,
        review_id=review_id,
    )
def make_test_product(
    chunks: list[RetrievedChunk],
    product_id: str = "TEST",
    score: float = 0.85,
) -> ProductScore:
    """Build a ProductScore for testing with sensible defaults.

    avg_rating is the mean over chunks whose rating is not None; 0.0 when no
    chunk carries a rating.
    """
    rated = [chunk.rating for chunk in chunks if chunk.rating is not None]
    mean_rating = sum(rated) / len(rated) if rated else 0.0
    return ProductScore(
        product_id=product_id,
        score=score,
        chunk_count=len(chunks),
        avg_rating=mean_rating,
        evidence=chunks,
    )
def compute_term_overlap(text: str, reference: str) -> float:
    """Return the fraction of `reference`'s key terms found in `text`.

    Matching is a case-insensitive substring test against the whole text,
    so a term also matches inside longer words. Returns 0.0 when the
    reference yields no key terms.
    """
    terms = extract_key_terms(reference)
    if not terms:
        return 0.0
    haystack = text.lower()
    hits = sum(term in haystack for term in terms)
    return hits / len(terms)
def log_summary_counts(results: list[dict], metrics: list[tuple[str, str]]) -> None:
    """Log a "label matched/total" summary line for each (label, key) metric."""
    total = len(results)
    logger.info("SUMMARY:")
    for label, key in metrics:
        logger.info(" %s %d/%d", label, count_matches(results, key), total)
def contains_any_phrase(text: str, phrases: frozenset[str]) -> bool:
    """Return True if any phrase occurs in `text`, compared case-insensitively.

    Phrases are matched as plain substrings of the lowercased text.
    """
    haystack = text.lower()
    for phrase in phrases:
        if phrase in haystack:
            return True
    return False
# ============================================================================
# Constants
# ============================================================================
# Phrases indicating conflict acknowledgment in explanations.
# Matched as case-insensitive substrings by contains_any_phrase().
CONFLICT_PHRASES = frozenset(
    [
        # Contrast words
        "however",
        "but",
        "although",
        "while",
        "whereas",
        "yet",
        "nevertheless",
        "nonetheless",
        "on the other hand",
        "conversely",
        "in contrast",
        # Acknowledgment of mixed opinions
        "mixed",
        "varies",
        "some",
        "others",
        "both",
        "range",
        "conflicting",
        "contradictory",
        "inconsistent",
        "divided",
        "disagree",
        "differ",
        "not all",
        "opinions vary",
        "experiences differ",
    ]
)
# Thresholds
# Minimum HHEM score for an explanation to count as grounded in the combined
# evidence (adversarial tests).
COMBINED_HHEM_THRESHOLD = 0.5
# Minimum key-term overlap ratio for an explanation to count as "referencing"
# a piece of evidence (see compute_term_overlap).
KEY_TERM_THRESHOLD = 0.3
# Spot-check limits
SPOT_CHECK_QUERY_LIMIT = 5  # queries taken from EVALUATION_QUERIES
SPOT_CHECK_CANDIDATES_K = 2  # candidates retrieved per query
SPOT_CHECK_MIN_RATING = 4.0  # rating floor passed to get_candidates
SPOT_CHECK_PRODUCTS_LIMIT = 2  # products inspected per query
SPOT_CHECK_MAX_EVIDENCE = 3  # evidence chunks per explanation
EVIDENCE_PREVIEW_COUNT = 2  # evidence snippets logged per sample
EVIDENCE_PREVIEW_LENGTH = 100  # characters logged per snippet
# Calibration limits
CALIBRATION_QUERY_LIMIT = 15
CALIBRATION_CANDIDATES_K = 5
CALIBRATION_MIN_RATING = 3.0
CALIBRATION_PRODUCTS_LIMIT = 2
CALIBRATION_MAX_EVIDENCE = 3
MIN_CALIBRATION_SAMPLES = 5  # below this, correlation analysis is skipped
# Empty context test settings (low values to trigger quality gate)
EMPTY_CONTEXT_SCORE = 0.3
EMPTY_CONTEXT_RATING = 3.0
# ============================================================================
# SECTION: Spot-Check
# ============================================================================
def _generate_spot_samples(
    explainer: Explainer, detector: HallucinationDetector, max_samples: int = 10
) -> Iterator[tuple]:
    """Yield up to max_samples spot-check tuples (query, hhem_result, explanation_result).

    Iterates the first SPOT_CHECK_QUERY_LIMIT evaluation queries, retrieves
    candidates for each, and runs explanation + HHEM check on the top
    SPOT_CHECK_PRODUCTS_LIMIT products per query.
    """
    emitted = 0
    for query in EVALUATION_QUERIES[:SPOT_CHECK_QUERY_LIMIT]:
        candidates = get_candidates(
            query=query,
            k=SPOT_CHECK_CANDIDATES_K,
            min_rating=SPOT_CHECK_MIN_RATING,
            aggregation=AggregationMethod.MAX,
        )
        for candidate in candidates[:SPOT_CHECK_PRODUCTS_LIMIT]:
            explained = explainer.generate_explanation(
                query, candidate, max_evidence=SPOT_CHECK_MAX_EVIDENCE
            )
            verdict = detector.check_explanation(
                explained.evidence_texts, explained.explanation
            )
            yield query, verdict, explained
            emitted += 1
            if emitted >= max_samples:
                return
def run_spot_check(explainer: Explainer, detector: HallucinationDetector) -> None:
    """Manual spot-check: log each explanation next to its evidence and HHEM score."""
    log_banner(logger, "SPOT-CHECK: Manual Inspection", width=70)
    collected: list[dict] = []
    for idx, (query, hhem, result) in enumerate(
        _generate_spot_samples(explainer, detector), 1
    ):
        try:
            log_section(logger, f"SAMPLE {idx}")
            logger.info('Query: "%s"', query)
            verdict = "FAIL" if hhem.is_hallucinated else "PASS"
            logger.info("HHEM: %.3f (%s)", hhem.score, verdict)
            logger.info("EVIDENCE:")
            for snippet in result.evidence_texts[:EVIDENCE_PREVIEW_COUNT]:
                logger.info(' "%s..."', snippet[:EVIDENCE_PREVIEW_LENGTH])
            logger.info("EXPLANATION:")
            logger.info(" %s", result.explanation)
            collected.append({"query": query, "hhem_score": hhem.score})
        except Exception:
            # Best-effort: one bad sample should not abort the spot-check.
            logger.warning("Skipping sample %d due to error", idx, exc_info=True)
            continue
    if not collected:
        logger.warning("SUMMARY: No samples collected")
        return
    scores = [entry["hhem_score"] for entry in collected]
    logger.info(
        "SUMMARY: %d samples, mean HHEM: %.3f", len(collected), np.mean(scores)
    )
# ============================================================================
# SECTION: Adversarial Tests
# ============================================================================
# Test cases with contradictory evidence (must be ~50+ tokens each to pass quality gate)
# Each case carries: "name" (log label), "query" (user request), and a pair of
# deliberately conflicting reviews about the same product aspect ("positive"
# is paired with a 5.0 rating, "negative" with 1.0 in run_adversarial_tests).
ADVERSARIAL_CASES = [
    {
        "name": "Battery Contradiction",
        "query": "laptop with good battery",
        "positive": (
            "Battery life on this laptop is absolutely incredible. I consistently "
            "get 12 to 14 hours of use on a single charge, even with heavy browsing "
            "and video streaming. Perfect for long flights and working remotely "
            "without needing to find an outlet. Best battery I've ever had on a laptop."
        ),
        "negative": (
            "The battery on this laptop is terrible and a huge disappointment. "
            "I barely get 3 hours of use before needing to charge again. Even with "
            "brightness turned down and minimal apps running, it drains incredibly "
            "fast. Do not buy this if you need any kind of portable use."
        ),
    },
    {
        "name": "Build Quality Contradiction",
        "query": "durable headphones",
        "positive": (
            "These headphones have premium build quality with solid metal construction. "
            "I've dropped them multiple times on hard floors and they still work "
            "perfectly with no damage. The hinges are sturdy and the headband has "
            "survived being thrown in my bag daily for over a year."
        ),
        "negative": (
            "Build quality is cheap plastic throughout. The headband cracked after "
            "just two weeks of normal use and the ear cups feel flimsy. I baby my "
            "electronics and these still fell apart. Complete waste of money if "
            "you expect them to last more than a month."
        ),
    },
]
def run_adversarial_tests(
    explainer: Explainer, detector: HallucinationDetector
) -> None:
    """Test with contradictory evidence using semantic entailment.

    For each ADVERSARIAL_CASES entry, builds a product whose evidence is one
    strongly positive and one strongly negative review, generates an
    explanation, then checks that it (a) stays grounded in the combined
    evidence per HHEM, (b) references key terms from BOTH reviews, and
    (c) uses explicit conflict language from CONFLICT_PHRASES.
    """
    log_banner(logger, "ADVERSARIAL: Contradictory Evidence", width=70)
    results = []
    for case in ADVERSARIAL_CASES:
        log_section(logger, case["name"])
        try:
            # Two contradictory chunks: 5-star positive vs 1-star negative.
            chunks = [
                make_test_chunk(
                    case["positive"], score=0.9, rating=5.0, review_id="pos"
                ),
                make_test_chunk(
                    case["negative"], score=0.85, rating=1.0, review_id="neg"
                ),
            ]
            product = make_test_product(chunks)
            result = explainer.generate_explanation(
                case["query"], product, max_evidence=2
            )
            # Faithfulness check: explanation is grounded in combined evidence
            hhem_combined = detector.check_explanation(
                result.evidence_texts, result.explanation
            )
            is_grounded = hhem_combined.score >= COMBINED_HHEM_THRESHOLD
            # Content reference check: does explanation reference BOTH pieces?
            pos_ratio = compute_term_overlap(result.explanation, case["positive"])
            neg_ratio = compute_term_overlap(result.explanation, case["negative"])
            references_positive = pos_ratio >= KEY_TERM_THRESHOLD
            references_negative = neg_ratio >= KEY_TERM_THRESHOLD
            references_both = references_positive and references_negative
            # Keyword check: uses explicit conflict language
            keyword_ack = contains_any_phrase(result.explanation, CONFLICT_PHRASES)
            # Overall: grounded + references both + uses conflict language
            full_ack = is_grounded and references_both and keyword_ack
            logger.info("Explanation: %s", result.explanation)
            logger.info(
                "HHEM combined: %.3f (%s)",
                hhem_combined.score,
                "grounded" if is_grounded else "HALLUCINATED",
            )
            logger.info(
                "References positive: %.0f%% of terms (%s)",
                pos_ratio * 100,
                yn(references_positive),
            )
            logger.info(
                "References negative: %.0f%% of terms (%s)",
                neg_ratio * 100,
                yn(references_negative),
            )
            logger.info("Uses conflict language: %s", yn(keyword_ack))
            logger.info("FULL ACKNOWLEDGMENT: %s", "PASS" if full_ack else "FAIL")
            results.append(
                {
                    "case": case["name"],
                    "grounded": is_grounded,
                    "references_both": references_both,
                    "keyword_ack": keyword_ack,
                    "full_ack": full_ack,
                }
            )
        except Exception:
            # Best-effort: one failing case should not abort the remaining cases.
            logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
            continue
    log_summary_counts(
        results,
        [
            ("Grounded (HHEM): ", "grounded"),
            ("References both sides:", "references_both"),
            ("Uses conflict language:", "keyword_ack"),
            ("FULL ACKNOWLEDGMENT: ", "full_ack"),
        ],
    )
# ============================================================================
# SECTION: Empty Context Tests
# ============================================================================
# Test cases for empty/irrelevant context handling.
# Each case: "name" (log label), "query" (user request), and "evidence" (a
# single low-information review chunk intended to trip the quality gate so
# the explainer refuses rather than hallucinates).
EMPTY_CONTEXT_CASES = [
    {
        "name": "Irrelevant",
        "query": "quantum computing textbook",
        "evidence": "Great USB cable.",
    },
    {"name": "Minimal", "query": "high-quality camera lens", "evidence": "OK."},
    {
        "name": "Minimal_NonEnglish",
        "query": "wireless mouse",
        "evidence": "Muy bueno el producto.",
    },
]
def run_empty_context_tests(
    explainer: Explainer, detector: HallucinationDetector
) -> None:
    """Verify the quality gate refuses gracefully on insufficient evidence."""
    log_banner(logger, "EMPTY CONTEXT: Graceful Refusal", width=70)
    del detector  # Passed for interface consistency but unused (refusals bypass HHEM)
    outcomes: list[dict] = []
    for case in EMPTY_CONTEXT_CASES:
        log_section(logger, case["name"])
        try:
            evidence_chunk = make_test_chunk(
                case["evidence"],
                score=EMPTY_CONTEXT_SCORE,
                rating=EMPTY_CONTEXT_RATING,
            )
            candidate = make_test_product([evidence_chunk], score=EMPTY_CONTEXT_SCORE)
            generated = explainer.generate_explanation(
                case["query"], candidate, max_evidence=1
            )
            refused = is_refusal(generated.explanation)
            logger.info("Explanation: %s", generated.explanation)
            logger.info("Graceful refusal: %s", yn(refused))
            outcomes.append({"case": case["name"], "graceful": refused})
        except Exception:
            # Best-effort: keep running the remaining cases on error.
            logger.warning("Skipping case %s due to error", case["name"], exc_info=True)
            continue
    logger.info(
        "SUMMARY: %d/%d refused gracefully",
        count_matches(outcomes, "graceful"),
        len(outcomes),
    )
# ============================================================================
# SECTION: Calibration Check
# ============================================================================
@dataclass
class CalibrationSample:
    """One (query, product) explanation sample for calibration analysis."""

    # Evaluation query the sample was generated for.
    query: str
    # Product the explanation was generated about.
    product_id: str
    # Retrieval confidence score; used as the "confidence" axis and for tiering.
    retrieval_score: float
    # Number of evidence chunks backing the product.
    evidence_count: int
    # Mean review rating across the product's chunks.
    avg_rating: float
    # HHEM faithfulness score of the generated explanation.
    hhem_score: float
def _safe_corr(x: np.ndarray, y: np.ndarray) -> float:
"""Compute correlation, returning 0.0 if either array has zero variance."""
if np.std(x) == 0 or np.std(y) == 0:
return 0.0
return float(np.corrcoef(x, y)[0, 1])
def _tier_mean(samples: list[CalibrationSample]) -> float:
"""Compute mean HHEM score for a tier of samples."""
return float(np.mean([s.hhem_score for s in samples])) if samples else 0.0
def run_calibration_check(
    explainer: Explainer, detector: HallucinationDetector
) -> None:
    """Analyze confidence vs faithfulness correlation.

    Generates up to CALIBRATION_QUERY_LIMIT x CALIBRATION_PRODUCTS_LIMIT
    samples, then reports (a) correlation of the HHEM score with retrieval
    score and with evidence count, and (b) mean HHEM per retrieval-score tier
    (low/mid/high thirds). Skips the analysis entirely when fewer than
    MIN_CALIBRATION_SAMPLES samples were collected.
    """
    log_banner(logger, "CALIBRATION: Confidence vs Faithfulness", width=70)
    samples: list[CalibrationSample] = []
    logger.info("Generating samples...")
    for query in EVALUATION_QUERIES[:CALIBRATION_QUERY_LIMIT]:
        products = get_candidates(
            query=query,
            k=CALIBRATION_CANDIDATES_K,
            min_rating=CALIBRATION_MIN_RATING,
            aggregation=AggregationMethod.MAX,
        )
        for product in products[:CALIBRATION_PRODUCTS_LIMIT]:
            try:
                result = explainer.generate_explanation(
                    query, product, max_evidence=CALIBRATION_MAX_EVIDENCE
                )
                hhem = detector.check_explanation(
                    result.evidence_texts, result.explanation
                )
                samples.append(
                    CalibrationSample(
                        query=query,
                        product_id=product.product_id,
                        retrieval_score=product.score,
                        evidence_count=product.chunk_count,
                        avg_rating=product.avg_rating,
                        hhem_score=hhem.score,
                    )
                )
            except Exception:
                # Sample generation is best-effort: log at debug level and
                # keep collecting the remaining samples.
                logger.debug("Error generating sample", exc_info=True)
    logger.info("Samples: %d", len(samples))
    if len(samples) < MIN_CALIBRATION_SAMPLES:
        logger.warning("Not enough samples (need %d)", MIN_CALIBRATION_SAMPLES)
        return
    # Extract arrays for correlation analysis
    hhem_scores = np.array([s.hhem_score for s in samples])
    retrieval_scores = np.array([s.retrieval_score for s in samples])
    evidence_counts = np.array([s.evidence_count for s in samples])
    log_section(logger, "Correlations with HHEM")
    logger.info(
        " Retrieval score: r = %+.3f", _safe_corr(retrieval_scores, hhem_scores)
    )
    logger.info(
        " Evidence count: r = %+.3f", _safe_corr(evidence_counts, hhem_scores)
    )
    # Stratified analysis by confidence tier (sorted into thirds by
    # retrieval score; middle tier absorbs the rounding remainder).
    sorted_samples = sorted(samples, key=lambda s: s.retrieval_score)
    n = len(sorted_samples)
    tiers = [
        ("LOW ", sorted_samples[: n // 3]),
        ("MED ", sorted_samples[n // 3 : 2 * n // 3]),
        ("HIGH", sorted_samples[2 * n // 3 :]),
    ]
    log_section(logger, "HHEM by Confidence Tier")
    for name, tier in tiers:
        logger.info(" %s (n=%2d): %.3f", name, len(tier), _tier_mean(tier))
# ============================================================================
# Main
# ============================================================================
def main() -> None:
    """CLI entry point: parse --section and run the selected check(s)."""
    # Imported here so module import stays cheap for library use.
    from sage.adapters.hhem import HallucinationDetector
    from sage.services.explanation import Explainer

    parser = argparse.ArgumentParser(description="Run pipeline sanity checks")
    parser.add_argument(
        "--section",
        "-s",
        choices=["all", "spot", "adversarial", "empty", "calibration"],
        default="all",
        help="Which section to run",
    )
    args = parser.parse_args()

    # Initialize services once and share them across all sections.
    explainer = Explainer()
    detector = HallucinationDetector()

    # Dispatch table keeps section order fixed: spot, adversarial, empty, calibration.
    runners = [
        ("spot", run_spot_check),
        ("adversarial", run_adversarial_tests),
        ("empty", run_empty_context_tests),
        ("calibration", run_calibration_check),
    ]
    for section_name, runner in runners:
        if args.section in ("all", section_name):
            runner(explainer, detector)
    log_banner(logger, "SANITY CHECKS COMPLETE", width=70)
if __name__ == "__main__":
main()