File size: 13,173 Bytes
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
88519e8
 
66926c8
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
 
 
 
88519e8
 
bf39698
88519e8
 
 
 
d507c32
88519e8
 
 
bf39698
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
27e97ac
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
88519e8
 
66926c8
 
 
 
88519e8
66926c8
 
 
 
 
 
 
88519e8
 
 
 
 
 
 
 
bf39698
 
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
 
 
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
 
 
 
 
 
 
 
88519e8
 
 
 
 
 
 
 
 
 
bf39698
 
 
 
 
88519e8
 
 
 
bf39698
 
 
88519e8
 
 
 
 
 
 
 
bf39698
88519e8
 
bf39698
 
 
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
66926c8
 
 
 
 
88519e8
66926c8
 
 
 
 
88519e8
 
 
 
 
 
 
 
bf39698
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
bf39698
88519e8
 
 
 
 
 
 
 
 
 
 
 
bf39698
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66926c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88519e8
 
 
 
 
 
 
 
bf39698
88519e8
 
 
bf39698
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
"""
Explanation generation tests.

Combines:
- Basic explanation generation with HHEM detection
- Evidence quality gate validation
- Post-generation verification loop
- Cold-start handling tests

Usage:
    python scripts/explanation.py                  # All tests
    python scripts/explanation.py --section basic  # Basic generation only
    python scripts/explanation.py --section gate   # Quality gate only
    python scripts/explanation.py --section verify # Verification only
    python scripts/explanation.py --section cold   # Cold-start only

Run from project root.
"""

import argparse

import numpy as np

from sage.core import AggregationMethod, ProductScore, RetrievedChunk
from sage.config import (
    LLM_PROVIDER,
    get_logger,
    log_banner,
    log_section,
)
from sage.services.retrieval import get_candidates

logger = get_logger(__name__)

# How many candidate products to retrieve for each test query.
TOP_K_PRODUCTS = 3
# How many of those products get an explanation generated per query.
PRODUCTS_PER_QUERY = 2


# ============================================================================
# SECTION: Basic Explanation Generation
# ============================================================================


def run_basic_tests():
    """Exercise end-to-end explanation generation plus HHEM grounding checks.

    Flow: retrieve candidates for a few fixed queries, generate explanations
    for the top products, score each explanation with the HHEM hallucination
    detector, then smoke-test the streaming interface on one query/product.
    """
    from sage.services import get_explanation_services

    log_banner(logger, "BASIC EXPLANATION TESTS")
    logger.info("Using LLM provider: %s", LLM_PROVIDER)

    test_queries = [
        "wireless headphones with good noise cancellation",
        "laptop charger that works with MacBook",
        "USB hub with multiple ports",
    ]

    # Step 1: retrieve candidate products for every query.
    log_section(logger, "1. GETTING RECOMMENDATIONS")
    query_results = {}
    for q in test_queries:
        candidates = get_candidates(
            query=q,
            k=TOP_K_PRODUCTS,
            min_rating=4.0,
            aggregation=AggregationMethod.MAX,
        )
        query_results[q] = candidates
        logger.info('Query: "%s"', q)
        logger.info("  Found %d products", len(candidates))

    # Step 2: generate an explanation for the top products of each query.
    log_section(logger, "2. GENERATING EXPLANATIONS")
    explainer, detector = get_explanation_services()
    all_explanations = []

    for q, candidates in query_results.items():
        logger.info('--- Query: "%s" ---', q)
        for candidate in candidates[:PRODUCTS_PER_QUERY]:
            outcome = explainer.generate_explanation(q, candidate)
            all_explanations.append(outcome)
            logger.info("Product: %s", outcome.product_id)
            logger.info("Explanation: %s", outcome.explanation)

    # Step 3: score every explanation for grounding against its evidence.
    log_section(logger, "3. HHEM HALLUCINATION DETECTION")
    hhem_results = []
    for expl in all_explanations:
        hhem_results.append(
            detector.check_explanation(expl.evidence_texts, expl.explanation)
        )

    for expl, verdict in zip(all_explanations, hhem_results, strict=True):
        label = "HALLUCINATED" if verdict.is_hallucinated else "GROUNDED"
        logger.info("[%s] Score: %.3f - %s", label, verdict.score, expl.product_id)

    scores = [v.score for v in hhem_results]
    n_hall = sum(1 for v in hhem_results if v.is_hallucinated)
    logger.info("Summary: %d total, %d hallucinated", len(hhem_results), n_hall)
    logger.info("Mean HHEM: %.3f", np.mean(scores))

    # Step 4: streaming smoke test on the first query/product pair.
    log_section(logger, "4. STREAMING TEST")
    if query_results:
        first_query = next(iter(query_results))
        first_product = query_results[first_query][0]
        logger.info('Query: "%s"', first_query)
        logger.info("Streaming: ")

        try:
            stream = explainer.generate_explanation_stream(first_query, first_product)
            pieces = list(stream)
            logger.info("".join(pieces))

            final = stream.get_complete_result()
            hhem = detector.check_explanation(
                final.evidence_texts, final.explanation
            )
            logger.info("HHEM Score: %.3f", hhem.score)
        except ValueError as e:
            # The quality gate may legitimately refuse to stream.
            logger.info("Quality gate refused streaming: %s", e)

    log_banner(logger, "BASIC TESTS COMPLETE")


# ============================================================================
# SECTION: Evidence Quality Gate
# ============================================================================


def create_mock_product(
    n_chunks: int, tokens_per_chunk: int = 100, product_score: float = 0.85
) -> ProductScore:
    """Build a synthetic ProductScore with *n_chunks* dummy evidence chunks.

    Chunk scores decay slightly (0.01 per chunk) from *product_score* so the
    evidence list looks like a realistic ranked retrieval result.
    """
    evidence = []
    for idx in range(n_chunks):
        # Roughly 4 characters per token, so pad the text accordingly.
        evidence.append(
            RetrievedChunk(
                text="x" * (tokens_per_chunk * 4),
                score=product_score - idx * 0.01,
                product_id="TEST_PRODUCT",
                rating=4.5,
                review_id=f"review_{idx}",
            )
        )
    return ProductScore(
        product_id="TEST_PRODUCT",
        score=product_score,
        chunk_count=n_chunks,
        avg_rating=4.5,
        evidence=evidence,
    )


def run_quality_gate_tests():
    """Validate the evidence quality gate and refusal-message generation."""
    from sage.core.evidence import check_evidence_quality, generate_refusal_message
    from sage.services.faithfulness import is_refusal
    from sage.config import (
        MIN_EVIDENCE_CHUNKS,
        MIN_EVIDENCE_TOKENS,
        MIN_RETRIEVAL_SCORE,
    )

    log_banner(logger, "EVIDENCE QUALITY GATE TESTS")

    log_section(logger, "1. QUALITY CHECK FUNCTION")
    # Each case: (chunks, tokens per chunk, score, expected sufficiency, label).
    cases = [
        (3, 100, 0.85, True, "sufficient"),
        (1, 100, 0.85, False, "insufficient_chunks"),
        (2, 10, 0.85, False, "insufficient_tokens"),
        (3, 100, 0.5, False, "low_relevance"),
    ]

    for chunks, tokens, rel, want, label in cases:
        verdict = check_evidence_quality(create_mock_product(chunks, tokens, rel))
        outcome = "PASS" if verdict.is_sufficient == want else "FAIL"
        logger.info(
            "[%s] %d chunks, %d tok, score=%.2f -> %s",
            outcome,
            chunks,
            tokens,
            rel,
            label,
        )
        assert verdict.is_sufficient == want

    log_section(logger, "2. REFUSAL GENERATION")
    query = "wireless headphones"

    # Every failing evidence profile must yield a detectable refusal message.
    for chunks, tokens, rel in ((1, 100, 0.85), (2, 10, 0.85), (3, 100, 0.5)):
        verdict = check_evidence_quality(create_mock_product(chunks, tokens, rel))
        message = generate_refusal_message(query, verdict)
        detected = is_refusal(message)
        logger.info(
            "[%s] Refusal detected for %s",
            "PASS" if detected else "FAIL",
            verdict.failure_reason,
        )
        assert detected

    logger.info(
        "Thresholds: chunks=%d, tokens=%d, score=%.2f",
        MIN_EVIDENCE_CHUNKS,
        MIN_EVIDENCE_TOKENS,
        MIN_RETRIEVAL_SCORE,
    )
    log_banner(logger, "QUALITY GATE TESTS COMPLETE")


# ============================================================================
# SECTION: Verification Loop
# ============================================================================


def run_verification_tests():
    """Check quote extraction, quote lookup, and full-explanation verification."""
    from sage.core.verification import (
        extract_quotes,
        verify_quote_in_evidence,
        verify_explanation,
    )

    log_banner(logger, "VERIFICATION LOOP TESTS")

    log_section(logger, "1. QUOTE EXTRACTION")
    extraction_cases = [
        ('Said "amazing sound" and "great value".', 2),
        ('Noted "excellent battery life".', 1),
        ('It was "ok" but "amazing quality".', 1),  # too-short quote "ok" is filtered
        ("No quotes.", 0),
    ]
    for sample, want in extraction_cases:
        found = extract_quotes(sample)
        outcome = "PASS" if len(found) == want else "FAIL"
        logger.info("[%s] %d quotes: %s...", outcome, len(found), sample[:40])

    log_section(logger, "2. QUOTE VERIFICATION")
    evidence = [
        "This product has amazing sound quality.",
        "Great value for the money.",
        "Battery lasts about 8 hours.",
    ]

    # Lookup should be case-insensitive and reject quotes absent from evidence.
    lookup_cases = [
        ("amazing sound quality", True),
        ("AMAZING SOUND QUALITY", True),
        ("Battery lasts about 8 hours", True),
        ("terrible product", False),
    ]
    for quote, want in lookup_cases:
        hit = verify_quote_in_evidence(quote, evidence)
        outcome = "PASS" if hit.found == want else "FAIL"
        logger.info("[%s] '%s'", outcome, quote)

    log_section(logger, "3. FULL VERIFICATION")
    explanation = 'Praise for "amazing sound quality" and "great value for the money".'
    report = verify_explanation(explanation, evidence)
    logger.info("Found: %d, Missing: %d", report.quotes_found, report.quotes_missing)
    logger.info("All verified: %s", report.all_verified)

    log_banner(logger, "VERIFICATION TESTS COMPLETE")


# ============================================================================
# SECTION: Cold-Start
# ============================================================================


def run_cold_start_tests():
    """Exercise warmup detection, preference queries, and cold/hybrid recs."""
    from sage.services.cold_start import (
        recommend_cold_start_user,
        get_warmup_level,
        get_content_weight,
        hybrid_recommend,
    )
    from sage.core import UserPreferences
    from sage.services.cold_start import preferences_to_query

    log_banner(logger, "COLD-START HANDLING TESTS")

    # Warm-user tests need the train split; degrade gracefully without it.
    train_df = None
    user_counts = {}
    try:
        from sage.data import load_splits

        train_df, _, _ = load_splits()
        user_counts = train_df.groupby("user_id").size().to_dict()
        logger.info("Loaded splits: %d training users", len(user_counts))
    except FileNotFoundError:
        logger.info("Splits not available - warm user tests will be skipped")

    log_section(logger, "1. WARMUP LEVEL DETECTION")
    for n_interactions in (0, 1, 3, 5, 10):
        logger.info(
            "  %d interactions: level=%s, content_weight=%.1f",
            n_interactions,
            get_warmup_level(n_interactions),
            get_content_weight(n_interactions),
        )

    log_section(logger, "2. PREFERENCES TO QUERY")
    prefs = UserPreferences(
        categories=["headphones", "audio"],
        budget="medium",
        priorities=["quality", "durability"],
        use_cases="travel",
    )
    synthetic_query = preferences_to_query(prefs)
    logger.info("Preferences: %s", prefs)
    logger.info('Query: "%s"', synthetic_query)

    log_section(logger, "3. COLD-START RECOMMENDATIONS")
    logger.info("Preference-based (cold user):")
    recs = recommend_cold_start_user(
        preferences=prefs,
        top_k=5,
        min_rating=4.0,
    )
    logger.info("Got %d recommendations", len(recs))
    for rec in recs[:3]:
        logger.info(
            "  %s: score=%.3f, rating=%.1f", rec.product_id, rec.score, rec.avg_rating
        )

    logger.info("Query-based (cold user):")
    recs = recommend_cold_start_user(
        query="wireless earbuds for running",
        top_k=5,
    )
    logger.info("Got %d recommendations", len(recs))
    for rec in recs[:3]:
        logger.info("  %s: score=%.3f", rec.product_id, rec.score)

    log_section(logger, "4. HYBRID RECOMMENDATIONS")
    logger.info("Cold user (0 interactions):")
    for rec in hybrid_recommend(
        query="noise cancelling headphones",
        user_history=None,
        preferences=prefs,
        top_k=3,
    ):
        logger.info("  %s: score=%.3f", rec.product_id, rec.score)

    # Warm-user path needs at least one user with >= 5 training interactions.
    if train_df is not None:
        warm_users = [u for u, c in user_counts.items() if c >= 5]
        if warm_users:
            history = train_df[train_df["user_id"] == warm_users[0]].to_dict("records")
            logger.info("Warm user (%d interactions):", len(history))
            for rec in hybrid_recommend(
                query="similar products",
                user_history=history,
                top_k=3,
            ):
                logger.info("  %s: score=%.3f", rec.product_id, rec.score)
    else:
        logger.info("Skipping warm user test (no splits)")

    log_banner(logger, "COLD-START TESTS COMPLETE")


# ============================================================================
# Main
# ============================================================================


def main():
    """Parse the --section flag and dispatch to the selected test groups."""
    parser = argparse.ArgumentParser(description="Run explanation tests")
    parser.add_argument(
        "--section",
        "-s",
        choices=["all", "basic", "gate", "verify", "cold"],
        default="all",
        help="Which section to run",
    )
    section = parser.parse_args().section

    # Run each group whose key matches the requested section ("all" runs every
    # group, in the same fixed order as before).
    runners = [
        ("basic", run_basic_tests),
        ("gate", run_quality_gate_tests),
        ("verify", run_verification_tests),
        ("cold", run_cold_start_tests),
    ]
    for key, runner in runners:
        if section in ("all", key):
            runner()


if __name__ == "__main__":
    main()