File size: 5,232 Bytes
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
d507c32
66926c8
d507c32
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27e97ac
 
 
 
 
 
88519e8
 
 
 
bf39698
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27e97ac
 
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf39698
 
88519e8
 
 
 
 
bf39698
 
88519e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Demo script showing the complete recommendation + explanation pipeline.

For each recommended product, outputs:
- Product recommendation with relevance score
- Explanation grounded in specific review quotes
- Confidence score (HHEM hallucination detection)
- Evidence source IDs for traceability

Usage:
    python scripts/demo.py
    python scripts/demo.py --query "wireless earbuds for running"
"""

import argparse
import json

from sage.core import AggregationMethod
from sage.config import FAITHFULNESS_TARGET, get_logger, log_banner, log_section
from sage.services.retrieval import get_candidates

# Module-level logger shared by every function in this script; all demo
# output except the optional JSON dump is emitted through it.
logger = get_logger(__name__)


def demo_recommendation(
    query: str, top_k: int = 3, max_evidence: int = 3
) -> dict | None:
    """
    Run full recommendation pipeline and display results in demo format.

    For every candidate returned by ``get_candidates`` the function generates
    an explanation, checks it with the hallucination detector, logs a
    human-readable report and collects a JSON-friendly summary.

    Args:
        query: Free-text user query to recommend products for.
        top_k: Number of candidates requested from ``get_candidates`` (``k``).
        max_evidence: Forwarded as ``max_evidence`` to
            ``explainer.generate_explanation``.

    Returns:
        ``{"query": ..., "recommendations": [...]}`` with one entry per
        product, or ``None`` when no candidate matched the query.

    Raises:
        ValueError: If an explanation reports a different number of evidence
            IDs and evidence texts (raised by ``zip(..., strict=True)``).
    """
    log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
    logger.info('Query: "%s"', query)

    # Get candidates
    products = get_candidates(
        query=query,
        k=top_k,
        min_rating=4.0,
        aggregation=AggregationMethod.MAX,
    )

    if not products:
        logger.warning("No products found matching query")
        return None

    # Initialize services only once candidates exist.
    # NOTE(review): the function-scope import presumably defers loading the
    # explanation/detection models until they are needed — confirm.
    from sage.services import get_explanation_services

    explainer, detector = get_explanation_services()

    results = []

    for rank, product in enumerate(products, 1):
        log_banner(logger, f"RECOMMENDATION #{rank}", width=70)

        # Generate explanation
        explanation_result = explainer.generate_explanation(
            query=query,
            product=product,
            max_evidence=max_evidence,
        )

        # Check faithfulness
        hhem_result = detector.check_explanation(
            evidence_texts=explanation_result.evidence_texts,
            explanation=explanation_result.explanation,
        )
        is_grounded = not hhem_result.is_hallucinated

        # Display product info
        logger.info("Product ID: %s", product.product_id)
        logger.info("Relevance Score: %.3f", product.score)
        logger.info("Average Rating: %.1f/5.0 stars", product.avg_rating)
        logger.info("Evidence Chunks: %d", product.chunk_count)

        # Display explanation. The text is passed as an argument so that it
        # is always treated as data, never as a logging format string.
        log_section(logger, "EXPLANATION")
        logger.info("%s", explanation_result.explanation)

        # Display confidence
        log_section(logger, "CONFIDENCE")
        logger.info(
            "HHEM Score: %.3f (%s)",
            hhem_result.score,
            "GROUNDED" if is_grounded else "NEEDS REVIEW",
        )
        logger.info("Threshold: %s", hhem_result.threshold)

        # Display evidence traceability
        log_section(logger, "EVIDENCE SOURCES")
        # Pair IDs with texts once; strict=True makes a length mismatch fail
        # here, before any partially matched evidence is logged or exported.
        evidence = list(
            zip(
                explanation_result.evidence_ids,
                explanation_result.evidence_texts,
                strict=True,
            )
        )
        for ev_id, ev_text in evidence:
            # Truncate long evidence for display
            display_text = (
                ev_text[:200] + "..." if len(ev_text) > 200 else ev_text
            )
            logger.info("[%s]:", ev_id)
            logger.info('  "%s"', display_text)

        # Compile result (full, untruncated evidence text is kept here)
        results.append(
            {
                "rank": rank,
                "product_id": product.product_id,
                "relevance_score": round(product.score, 3),
                "avg_rating": round(product.avg_rating, 1),
                "explanation": explanation_result.explanation,
                "confidence": {
                    "hhem_score": round(hhem_result.score, 3),
                    "is_grounded": is_grounded,
                    "threshold": hhem_result.threshold,
                },
                "evidence_sources": [
                    {"id": ev_id, "text": ev_text}
                    for ev_id, ev_text in evidence
                ],
            }
        )

    # Summary
    log_banner(logger, "DEMO SUMMARY", width=70)
    grounded_count = sum(1 for r in results if r["confidence"]["is_grounded"])
    logger.info("Products recommended: %d", len(results))
    logger.info("Grounded explanations: %d/%d", grounded_count, len(results))
    logger.info("Faithfulness target: %s", FAITHFULNESS_TARGET)

    return {
        "query": query,
        "recommendations": results,
    }


def main(argv: list[str] | None = None) -> None:
    """
    Command-line entry point for the recommendation demo.

    Parses the options, runs ``demo_recommendation`` and, when ``--json`` is
    given and the pipeline produced a result, prints that result to stdout
    as indented JSON.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, so existing ``main()`` callers are
            unaffected.
    """
    parser = argparse.ArgumentParser(description="Demo recommendation pipeline")
    parser.add_argument(
        "--query",
        "-q",
        type=str,
        default="wireless earbuds for running",
        help="Query to demonstrate",
    )
    parser.add_argument(
        "--top-k",
        "-k",
        type=int,
        default=1,
        help="Number of products to recommend (default: 1)",
    )
    # Expose the pipeline's max_evidence knob; the default mirrors the
    # default of demo_recommendation, so omitting the flag changes nothing.
    parser.add_argument(
        "--max-evidence",
        "-e",
        type=int,
        default=3,
        help="Maximum evidence items used per explanation (default: 3)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output as JSON instead of formatted text",
    )
    args = parser.parse_args(argv)

    result = demo_recommendation(
        args.query,
        top_k=args.top_k,
        max_evidence=args.max_evidence,
    )

    # demo_recommendation returns None when nothing matched; in that case
    # there is nothing to serialize and the warning it logged is the output.
    if args.json and result:
        log_section(logger, "JSON OUTPUT")
        print(json.dumps(result, indent=2))

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()