"""
Demo script showing the complete recommendation + explanation pipeline.

For each recommended product, outputs:
- Product recommendation with relevance score
- Explanation grounded in specific review quotes
- Confidence score (HHEM hallucination detection)
- Evidence source IDs for traceability

Usage:
    python scripts/demo.py
    python scripts/demo.py --query "wireless earbuds for running"
"""
import argparse
import json
from sage.core import AggregationMethod
from sage.config import FAITHFULNESS_TARGET, get_logger, log_banner, log_section
from sage.services.retrieval import get_candidates
# Module-level logger, obtained from the project's get_logger helper and keyed by this module's name.
logger = get_logger(__name__)
def demo_recommendation(
    query: str, top_k: int = 3, max_evidence: int = 3
) -> dict | None:
    """
    Run the full recommendation pipeline and display results in demo format.

    Retrieves candidate products for ``query``, generates an explanation for
    each one, checks that explanation against its evidence with the
    hallucination detector, and logs the product info, explanation,
    confidence and evidence sources.

    Args:
        query: Free-text query to retrieve candidates for.
        top_k: Number of candidates requested from retrieval.
        max_evidence: Forwarded to the explainer as ``max_evidence``.

    Returns:
        A dict suitable for JSON serialization with the keys ``"query"`` and
        ``"recommendations"`` (one entry per product, in rank order), or
        ``None`` when retrieval returns no candidates.

    Raises:
        ValueError: If an explanation's evidence IDs and evidence texts have
            different lengths (enforced by ``zip(..., strict=True)``).
    """
    log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
    logger.info('Query: "%s"', query)

    # Get candidates
    products = get_candidates(
        query=query,
        k=top_k,
        min_rating=4.0,
        aggregation=AggregationMethod.MAX,
    )
    if not products:
        logger.warning("No products found matching query")
        return None

    # Initialize services. Imported locally (as in the original), so the
    # import only happens once there is at least one candidate to explain.
    from sage.services import get_explanation_services

    explainer, detector = get_explanation_services()

    results = []
    for rank, product in enumerate(products, 1):
        log_banner(logger, f"RECOMMENDATION #{rank}", width=70)

        # Generate explanation
        explanation_result = explainer.generate_explanation(
            query=query,
            product=product,
            max_evidence=max_evidence,
        )

        # Check faithfulness
        hhem_result = detector.check_explanation(
            evidence_texts=explanation_result.evidence_texts,
            explanation=explanation_result.explanation,
        )

        # Display product info
        logger.info("Product ID: %s", product.product_id)
        logger.info("Relevance Score: %.3f", product.score)
        logger.info("Average Rating: %.1f/5.0 stars", product.avg_rating)
        logger.info("Evidence Chunks: %d", product.chunk_count)

        # Display explanation with grounding status
        log_section(logger, "EXPLANATION")
        logger.info(explanation_result.explanation)

        # Display confidence
        grounded = "GROUNDED" if not hhem_result.is_hallucinated else "NEEDS REVIEW"
        log_section(logger, "CONFIDENCE")
        logger.info("HHEM Score: %.3f (%s)", hhem_result.score, grounded)
        logger.info("Threshold: %s", hhem_result.threshold)

        # Display evidence traceability
        log_section(logger, "EVIDENCE SOURCES")
        # Pair IDs with texts once; the same pairs feed both the log output
        # and the result dict. strict=True surfaces a length mismatch as a
        # ValueError before any partial evidence is shown.
        evidence = list(
            zip(
                explanation_result.evidence_ids,
                explanation_result.evidence_texts,
                strict=True,
            )
        )
        for ev_id, ev_text in evidence:
            # Truncate long evidence for display only; the result dict below
            # keeps the full text.
            display_text = (ev_text[:200] + "...") if len(ev_text) > 200 else ev_text
            logger.info("[%s]:", ev_id)
            logger.info(' "%s"', display_text)

        # Compile result
        results.append(
            {
                "rank": rank,
                "product_id": product.product_id,
                "relevance_score": round(product.score, 3),
                "avg_rating": round(product.avg_rating, 1),
                "explanation": explanation_result.explanation,
                "confidence": {
                    "hhem_score": round(hhem_result.score, 3),
                    "is_grounded": not hhem_result.is_hallucinated,
                    "threshold": hhem_result.threshold,
                },
                "evidence_sources": [
                    {"id": ev_id, "text": ev_text} for ev_id, ev_text in evidence
                ],
            }
        )

    # Summary
    log_banner(logger, "DEMO SUMMARY", width=70)
    grounded_count = sum(1 for r in results if r["confidence"]["is_grounded"])
    logger.info("Products recommended: %d", len(results))
    logger.info("Grounded explanations: %d/%d", grounded_count, len(results))
    logger.info("Faithfulness target: %s", FAITHFULNESS_TARGET)

    return {
        "query": query,
        "recommendations": results,
    }
def _build_parser():
    """Create the command-line parser for the demo script."""
    cli = argparse.ArgumentParser(description="Demo recommendation pipeline")
    cli.add_argument(
        "--query",
        "-q",
        type=str,
        default="wireless earbuds for running",
        help="Query to demonstrate",
    )
    cli.add_argument(
        "--top-k",
        "-k",
        type=int,
        default=1,
        help="Number of products to recommend (default: 1)",
    )
    cli.add_argument(
        "--json",
        action="store_true",
        help="Output as JSON instead of formatted text",
    )
    return cli


def main():
    """
    Command-line entry point for the demo.

    Parses the command-line options, runs ``demo_recommendation`` with the
    requested query and ``top_k``, and, when ``--json`` was given and the
    pipeline produced a result, prints that result to stdout as indented JSON.
    """
    options = _build_parser().parse_args()
    payload = demo_recommendation(options.query, top_k=options.top_k)

    # Nothing more to emit unless JSON was requested and a result exists.
    if not (options.json and payload):
        return

    log_section(logger, "JSON OUTPUT")
    print(json.dumps(payload, indent=2))


if __name__ == "__main__":
    main()
|