""" Explanation generation tests. Combines: - Basic explanation generation with HHEM detection - Evidence quality gate validation - Post-generation verification loop - Cold-start handling tests Usage: python scripts/explanation.py # All tests python scripts/explanation.py --section basic # Basic generation only python scripts/explanation.py --section gate # Quality gate only python scripts/explanation.py --section verify # Verification only python scripts/explanation.py --section cold # Cold-start only Run from project root. """ import argparse import numpy as np from sage.core import AggregationMethod, ProductScore, RetrievedChunk from sage.config import ( LLM_PROVIDER, get_logger, log_banner, log_section, ) from sage.services.retrieval import get_candidates logger = get_logger(__name__) TOP_K_PRODUCTS = 3 PRODUCTS_PER_QUERY = 2 # ============================================================================ # SECTION: Basic Explanation Generation # ============================================================================ def run_basic_tests(): """Test basic explanation generation and HHEM detection.""" from sage.services import get_explanation_services log_banner(logger, "BASIC EXPLANATION TESTS") logger.info("Using LLM provider: %s", LLM_PROVIDER) test_queries = [ "wireless headphones with good noise cancellation", "laptop charger that works with MacBook", "USB hub with multiple ports", ] # Get recommendations log_section(logger, "1. GETTING RECOMMENDATIONS") query_results = {} for query in test_queries: products = get_candidates( query=query, k=TOP_K_PRODUCTS, min_rating=4.0, aggregation=AggregationMethod.MAX, ) query_results[query] = products logger.info('Query: "%s"', query) logger.info(" Found %d products", len(products)) # Generate explanations log_section(logger, "2. GENERATING EXPLANATIONS") explainer, detector = get_explanation_services() all_explanations = [] for query, products in query_results.items(): logger.info('--- Query: "%s" ---', query) for product in products[:PRODUCTS_PER_QUERY]: result = explainer.generate_explanation(query, product) all_explanations.append(result) logger.info("Product: %s", result.product_id) logger.info("Explanation: %s", result.explanation) # Run HHEM log_section(logger, "3. HHEM HALLUCINATION DETECTION") hhem_results = [ detector.check_explanation(expl.evidence_texts, expl.explanation) for expl in all_explanations ] for expl, result in zip(all_explanations, hhem_results, strict=True): status = "GROUNDED" if not result.is_hallucinated else "HALLUCINATED" logger.info("[%s] Score: %.3f - %s", status, result.score, expl.product_id) scores = [r.score for r in hhem_results] n_hall = sum(1 for r in hhem_results if r.is_hallucinated) logger.info("Summary: %d total, %d hallucinated", len(hhem_results), n_hall) logger.info("Mean HHEM: %.3f", np.mean(scores)) # Streaming test log_section(logger, "4. STREAMING TEST") if query_results: test_query = list(query_results.keys())[0] test_product = query_results[test_query][0] logger.info('Query: "%s"', test_query) logger.info("Streaming: ") try: stream = explainer.generate_explanation_stream(test_query, test_product) chunks = list(stream) logger.info("".join(chunks)) streamed_result = stream.get_complete_result() hhem = detector.check_explanation( streamed_result.evidence_texts, streamed_result.explanation ) logger.info("HHEM Score: %.3f", hhem.score) except ValueError as e: logger.info("Quality gate refused streaming: %s", e) log_banner(logger, "BASIC TESTS COMPLETE") # ============================================================================ # SECTION: Evidence Quality Gate # ============================================================================ def create_mock_product( n_chunks: int, tokens_per_chunk: int = 100, product_score: float = 0.85 ) -> ProductScore: """Create a mock ProductScore for testing.""" chunks = [ RetrievedChunk( text="x" * (tokens_per_chunk * 4), score=product_score - i * 0.01, product_id="TEST_PRODUCT", rating=4.5, review_id=f"review_{i}", ) for i in range(n_chunks) ] return ProductScore( product_id="TEST_PRODUCT", score=product_score, chunk_count=n_chunks, avg_rating=4.5, evidence=chunks, ) def run_quality_gate_tests(): """Test the evidence quality gate.""" from sage.core.evidence import check_evidence_quality, generate_refusal_message from sage.services.faithfulness import is_refusal from sage.config import ( MIN_EVIDENCE_CHUNKS, MIN_EVIDENCE_TOKENS, MIN_RETRIEVAL_SCORE, ) log_banner(logger, "EVIDENCE QUALITY GATE TESTS") log_section(logger, "1. QUALITY CHECK FUNCTION") test_cases = [ (3, 100, 0.85, True, "sufficient"), (1, 100, 0.85, False, "insufficient_chunks"), (2, 10, 0.85, False, "insufficient_tokens"), (3, 100, 0.5, False, "low_relevance"), ] for n_chunks, tok, score, expected, reason in test_cases: product = create_mock_product(n_chunks, tok, score) quality = check_evidence_quality(product) status = "PASS" if quality.is_sufficient == expected else "FAIL" logger.info( "[%s] %d chunks, %d tok, score=%.2f -> %s", status, n_chunks, tok, score, reason, ) assert quality.is_sufficient == expected log_section(logger, "2. REFUSAL GENERATION") query = "wireless headphones" for n_chunks, tok, score in [(1, 100, 0.85), (2, 10, 0.85), (3, 100, 0.5)]: product = create_mock_product(n_chunks, tok, score) quality = check_evidence_quality(product) refusal = generate_refusal_message(query, quality) detected = is_refusal(refusal) logger.info( "[%s] Refusal detected for %s", "PASS" if detected else "FAIL", quality.failure_reason, ) assert detected logger.info( "Thresholds: chunks=%d, tokens=%d, score=%.2f", MIN_EVIDENCE_CHUNKS, MIN_EVIDENCE_TOKENS, MIN_RETRIEVAL_SCORE, ) log_banner(logger, "QUALITY GATE TESTS COMPLETE") # ============================================================================ # SECTION: Verification Loop # ============================================================================ def run_verification_tests(): """Test the post-generation verification loop.""" from sage.core.verification import ( extract_quotes, verify_quote_in_evidence, verify_explanation, ) log_banner(logger, "VERIFICATION LOOP TESTS") log_section(logger, "1. QUOTE EXTRACTION") test_texts = [ ('Said "amazing sound" and "great value".', 2), ('Noted "excellent battery life".', 1), ('It was "ok" but "amazing quality".', 1), # "ok" filtered ("No quotes.", 0), ] for text, expected in test_texts: quotes = extract_quotes(text) status = "PASS" if len(quotes) == expected else "FAIL" logger.info("[%s] %d quotes: %s...", status, len(quotes), text[:40]) log_section(logger, "2. QUOTE VERIFICATION") evidence = [ "This product has amazing sound quality.", "Great value for the money.", "Battery lasts about 8 hours.", ] verify_cases = [ ("amazing sound quality", True), ("AMAZING SOUND QUALITY", True), ("Battery lasts about 8 hours", True), ("terrible product", False), ] for quote, expected in verify_cases: result = verify_quote_in_evidence(quote, evidence) status = "PASS" if result.found == expected else "FAIL" logger.info("[%s] '%s'", status, quote) log_section(logger, "3. FULL VERIFICATION") explanation = 'Praise for "amazing sound quality" and "great value for the money".' result = verify_explanation(explanation, evidence) logger.info("Found: %d, Missing: %d", result.quotes_found, result.quotes_missing) logger.info("All verified: %s", result.all_verified) log_banner(logger, "VERIFICATION TESTS COMPLETE") # ============================================================================ # SECTION: Cold-Start # ============================================================================ def run_cold_start_tests(): """Test cold-start handling.""" from sage.services.cold_start import ( recommend_cold_start_user, get_warmup_level, get_content_weight, hybrid_recommend, ) from sage.core import UserPreferences from sage.services.cold_start import preferences_to_query log_banner(logger, "COLD-START HANDLING TESTS") # Try to load splits for warm user tests (optional) train_df = None user_counts = {} try: from sage.data import load_splits train_df, _, _ = load_splits() user_counts = train_df.groupby("user_id").size().to_dict() logger.info("Loaded splits: %d training users", len(user_counts)) except FileNotFoundError: logger.info("Splits not available - warm user tests will be skipped") # Test warmup levels log_section(logger, "1. WARMUP LEVEL DETECTION") test_counts = [0, 1, 3, 5, 10] for count in test_counts: level = get_warmup_level(count) weight = get_content_weight(count) logger.info( " %d interactions: level=%s, content_weight=%.1f", count, level, weight ) # Test preferences to query log_section(logger, "2. PREFERENCES TO QUERY") prefs = UserPreferences( categories=["headphones", "audio"], budget="medium", priorities=["quality", "durability"], use_cases="travel", ) query = preferences_to_query(prefs) logger.info("Preferences: %s", prefs) logger.info('Query: "%s"', query) # Test cold-start recommendations log_section(logger, "3. COLD-START RECOMMENDATIONS") logger.info("Preference-based (cold user):") recs = recommend_cold_start_user( preferences=prefs, top_k=5, min_rating=4.0, ) logger.info("Got %d recommendations", len(recs)) for r in recs[:3]: logger.info( " %s: score=%.3f, rating=%.1f", r.product_id, r.score, r.avg_rating ) logger.info("Query-based (cold user):") recs = recommend_cold_start_user( query="wireless earbuds for running", top_k=5, ) logger.info("Got %d recommendations", len(recs)) for r in recs[:3]: logger.info(" %s: score=%.3f", r.product_id, r.score) # Test hybrid recommend log_section(logger, "4. HYBRID RECOMMENDATIONS") # Cold user (no history) logger.info("Cold user (0 interactions):") recs = hybrid_recommend( query="noise cancelling headphones", user_history=None, preferences=prefs, top_k=3, ) for r in recs: logger.info(" %s: score=%.3f", r.product_id, r.score) # Find a warm user (only if splits available) if train_df is not None: warm_users = [u for u, c in user_counts.items() if c >= 5] if warm_users: warm_user = warm_users[0] user_history = train_df[train_df["user_id"] == warm_user].to_dict("records") logger.info("Warm user (%d interactions):", len(user_history)) recs = hybrid_recommend( query="similar products", user_history=user_history, top_k=3, ) for r in recs: logger.info(" %s: score=%.3f", r.product_id, r.score) else: logger.info("Skipping warm user test (no splits)") log_banner(logger, "COLD-START TESTS COMPLETE") # ============================================================================ # Main # ============================================================================ def main(): parser = argparse.ArgumentParser(description="Run explanation tests") parser.add_argument( "--section", "-s", choices=["all", "basic", "gate", "verify", "cold"], default="all", help="Which section to run", ) args = parser.parse_args() if args.section in ("all", "basic"): run_basic_tests() if args.section in ("all", "gate"): run_quality_gate_tests() if args.section in ("all", "verify"): run_verification_tests() if args.section in ("all", "cold"): run_cold_start_tests() if __name__ == "__main__": main()