Sage / scripts /demo.py
vxa8502's picture
Replace EDA with production Qdrant queries
66926c8
"""
Demo script showing the complete recommendation + explanation pipeline.
For each recommended product, outputs:
- Product recommendation with relevance score
- Explanation grounded in specific review quotes
- Confidence score (HHEM hallucination detection)
- Evidence source IDs for traceability
Usage:
python scripts/demo.py
python scripts/demo.py --query "wireless earbuds for running"
"""
import argparse
import json
from sage.core import AggregationMethod
from sage.config import FAITHFULNESS_TARGET, get_logger, log_banner, log_section
from sage.services.retrieval import get_candidates
logger = get_logger(__name__)
def _truncate_for_display(text: str, limit: int = 200) -> str:
    """Return *text* as-is if it fits in *limit* chars, else its head plus "..."."""
    return (text[:limit] + "...") if len(text) > limit else text


def _log_recommendation(product, explanation_result, hhem_result, evidence):
    """Log one recommendation: product info, explanation, confidence, evidence.

    Args:
        product: Candidate exposing ``product_id``, ``score``, ``avg_rating``
            and ``chunk_count``.
        explanation_result: Explainer output exposing ``explanation``.
        hhem_result: Detector output exposing ``score``, ``is_hallucinated``
            and ``threshold``.
        evidence: List of ``(evidence_id, evidence_text)`` pairs.
    """
    # Product info
    logger.info("Product ID: %s", product.product_id)
    logger.info("Relevance Score: %.3f", product.score)
    logger.info("Average Rating: %.1f/5.0 stars", product.avg_rating)
    logger.info("Evidence Chunks: %d", product.chunk_count)

    # Explanation text
    log_section(logger, "EXPLANATION")
    logger.info(explanation_result.explanation)

    # Confidence, labelled by the detector's hallucination verdict
    grounded = "GROUNDED" if not hhem_result.is_hallucinated else "NEEDS REVIEW"
    log_section(logger, "CONFIDENCE")
    logger.info("HHEM Score: %.3f (%s)", hhem_result.score, grounded)
    logger.info("Threshold: %s", hhem_result.threshold)

    # Evidence traceability; long quotes are shortened for display only
    log_section(logger, "EVIDENCE SOURCES")
    for ev_id, ev_text in evidence:
        logger.info("[%s]:", ev_id)
        logger.info(' "%s"', _truncate_for_display(ev_text))


def _build_result(rank, product, explanation_result, hhem_result, evidence):
    """Assemble the dict record for one recommendation.

    Numeric scores are passed through ``float()`` before rounding so the
    record holds built-in floats. NOTE(review): this guards against the
    services returning NumPy-style scalars, which ``json.dumps`` may reject;
    the actual return types are not visible from this file - confirm.
    The evidence text is stored in full (no display truncation).
    """
    return {
        "rank": rank,
        "product_id": product.product_id,
        "relevance_score": round(float(product.score), 3),
        "avg_rating": round(float(product.avg_rating), 1),
        "explanation": explanation_result.explanation,
        "confidence": {
            "hhem_score": round(float(hhem_result.score), 3),
            "is_grounded": not hhem_result.is_hallucinated,
            "threshold": hhem_result.threshold,
        },
        "evidence_sources": [
            {"id": ev_id, "text": ev_text} for ev_id, ev_text in evidence
        ],
    }


def demo_recommendation(
    query: str, top_k: int = 3, max_evidence: int = 3
) -> dict | None:
    """
    Run full recommendation pipeline and display results in demo format.

    Candidates are fetched with ``get_candidates`` (``min_rating=4.0``,
    ``AggregationMethod.MAX``). For each candidate an explanation is
    generated, checked by the hallucination detector, logged, and collected.

    Args:
        query: Free-text search query.
        top_k: Number of candidates to request (passed as ``k``).
        max_evidence: Passed through to ``explainer.generate_explanation``.

    Returns:
        ``{"query": ..., "recommendations": [...]}`` intended for JSON
        serialization, or ``None`` when no candidates are found.

    Raises:
        ValueError: If the explainer returns a different number of evidence
            IDs and evidence texts (enforced by ``zip(..., strict=True)``).
    """
    log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
    logger.info('Query: "%s"', query)

    # Get candidates
    products = get_candidates(
        query=query,
        k=top_k,
        min_rating=4.0,
        aggregation=AggregationMethod.MAX,
    )
    if not products:
        logger.warning("No products found matching query")
        return None

    # Initialize services. Imported here rather than at module top, so the
    # import only runs once candidates exist.
    from sage.services import get_explanation_services

    explainer, detector = get_explanation_services()

    results = []
    for rank, product in enumerate(products, 1):
        # Banner goes out before generation starts.
        log_banner(logger, f"RECOMMENDATION #{rank}", width=70)

        # Generate explanation
        explanation_result = explainer.generate_explanation(
            query=query,
            product=product,
            max_evidence=max_evidence,
        )

        # Check faithfulness
        hhem_result = detector.check_explanation(
            evidence_texts=explanation_result.evidence_texts,
            explanation=explanation_result.explanation,
        )

        # Pair IDs with texts once; reused for both display and the record.
        evidence = list(
            zip(
                explanation_result.evidence_ids,
                explanation_result.evidence_texts,
                strict=True,
            )
        )

        _log_recommendation(product, explanation_result, hhem_result, evidence)
        results.append(
            _build_result(rank, product, explanation_result, hhem_result, evidence)
        )

    # Summary
    log_banner(logger, "DEMO SUMMARY", width=70)
    grounded_count = sum(1 for r in results if r["confidence"]["is_grounded"])
    logger.info("Products recommended: %d", len(results))
    logger.info("Grounded explanations: %d/%d", grounded_count, len(results))
    logger.info("Faithfulness target: %s", FAITHFULNESS_TARGET)

    return {
        "query": query,
        "recommendations": results,
    }
def main():
    """Parse CLI flags, run the demo, and optionally print the result as JSON.

    Flags: ``--query/-q`` (search text), ``--top-k/-k`` (number of products),
    ``--json`` (also print the result dict to stdout as indented JSON).
    """
    # (flags, keyword options) for each CLI argument, registered in order.
    argument_specs = (
        (
            ("--query", "-q"),
            {
                "type": str,
                "default": "wireless earbuds for running",
                "help": "Query to demonstrate",
            },
        ),
        (
            ("--top-k", "-k"),
            {
                "type": int,
                "default": 1,
                "help": "Number of products to recommend (default: 1)",
            },
        ),
        (
            ("--json",),
            {
                "action": "store_true",
                "help": "Output as JSON instead of formatted text",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Demo recommendation pipeline")
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    opts = cli.parse_args()

    payload = demo_recommendation(opts.query, top_k=opts.top_k)

    # Nothing to dump unless JSON was requested and the pipeline returned data.
    if not (opts.json and payload):
        return

    log_section(logger, "JSON OUTPUT")
    print(json.dumps(payload, indent=2))
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()