|
|
""" |
|
|
Demo script showing the complete recommendation + explanation pipeline. |
|
|
|
|
|
For each recommended product, outputs: |
|
|
- Product recommendation with relevance score |
|
|
- Explanation grounded in specific review quotes |
|
|
- Confidence score (HHEM hallucination detection) |
|
|
- Evidence source IDs for traceability |
|
|
|
|
|
Usage: |
|
|
python scripts/demo.py |
|
|
python scripts/demo.py --query "wireless earbuds for running" |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
|
|
|
from sage.core import AggregationMethod |
|
|
from sage.config import FAITHFULNESS_TARGET, get_logger, log_banner, log_section |
|
|
from sage.services.retrieval import get_candidates |
|
|
|
|
|
# Module-level logger used by demo_recommendation() and main() for all demo output.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
def demo_recommendation(
    query: str, top_k: int = 3, max_evidence: int = 3
) -> dict | None:
    """
    Run the full recommendation pipeline and display results in demo format.

    Retrieves candidate products for ``query``, generates an explanation for
    each one, checks that explanation against its evidence with the
    hallucination detector, and logs a formatted report per product followed
    by a summary.

    Args:
        query: Query string passed to retrieval and to the explainer.
        top_k: Number of candidates requested from retrieval (passed as ``k``).
        max_evidence: Forwarded to the explainer as ``max_evidence``.

    Returns:
        ``{"query": ..., "recommendations": [...]}`` with one entry per
        product, suitable for JSON serialization, or ``None`` when retrieval
        returns no products.

    Raises:
        ValueError: If an explanation's evidence IDs and evidence texts differ
            in length (raised by ``zip(..., strict=True)``).
    """
    log_banner(logger, "SAGE RECOMMENDATION DEMO", width=70)
    logger.info('Query: "%s"', query)

    products = get_candidates(
        query=query,
        k=top_k,
        min_rating=4.0,
        aggregation=AggregationMethod.MAX,
    )

    if not products:
        logger.warning("No products found matching query")
        return None

    # Imported only after the early return above, so the explanation services
    # are not touched when there is nothing to explain.
    # NOTE(review): presumably deferred because these services are expensive
    # to load — confirm before moving the import to module level.
    from sage.services import get_explanation_services

    explainer, detector = get_explanation_services()

    results = []

    for rank, product in enumerate(products, 1):
        log_banner(logger, f"RECOMMENDATION #{rank}", width=70)

        explanation_result = explainer.generate_explanation(
            query=query,
            product=product,
            max_evidence=max_evidence,
        )

        hhem_result = detector.check_explanation(
            evidence_texts=explanation_result.evidence_texts,
            explanation=explanation_result.explanation,
        )
        is_grounded = not hhem_result.is_hallucinated

        logger.info("Product ID: %s", product.product_id)
        logger.info("Relevance Score: %.3f", product.score)
        logger.info("Average Rating: %.1f/5.0 stars", product.avg_rating)
        logger.info("Evidence Chunks: %d", product.chunk_count)

        log_section(logger, "EXPLANATION")
        # Pass the text as an argument rather than as the format string so the
        # logging call never depends on the content of the generated text.
        logger.info("%s", explanation_result.explanation)

        log_section(logger, "CONFIDENCE")
        logger.info(
            "HHEM Score: %.3f (%s)",
            hhem_result.score,
            "GROUNDED" if is_grounded else "NEEDS REVIEW",
        )
        logger.info("Threshold: %s", hhem_result.threshold)

        log_section(logger, "EVIDENCE SOURCES")
        # Pair IDs with texts once; strict=True surfaces a length mismatch as
        # ValueError instead of silently dropping evidence.
        evidence = list(
            zip(
                explanation_result.evidence_ids,
                explanation_result.evidence_texts,
                strict=True,
            )
        )
        for ev_id, ev_text in evidence:
            # Keep the log readable: show at most 200 characters per quote.
            display_text = (ev_text[:200] + "...") if len(ev_text) > 200 else ev_text
            logger.info("[%s]:", ev_id)
            logger.info(' "%s"', display_text)

        # float() keeps the payload JSON-serializable even if a service hands
        # back a non-builtin numeric scalar (e.g. a NumPy float32, which
        # round() would preserve and json.dumps would reject).
        results.append(
            {
                "rank": rank,
                "product_id": product.product_id,
                "relevance_score": round(float(product.score), 3),
                "avg_rating": round(float(product.avg_rating), 1),
                "explanation": explanation_result.explanation,
                "confidence": {
                    "hhem_score": round(float(hhem_result.score), 3),
                    "is_grounded": is_grounded,
                    "threshold": hhem_result.threshold,
                },
                "evidence_sources": [
                    {"id": ev_id, "text": ev_text} for ev_id, ev_text in evidence
                ],
            }
        )

    log_banner(logger, "DEMO SUMMARY", width=70)
    grounded_count = sum(1 for r in results if r["confidence"]["is_grounded"])
    logger.info("Products recommended: %d", len(results))
    logger.info("Grounded explanations: %d/%d", grounded_count, len(results))
    logger.info("Faithfulness target: %s", FAITHFULNESS_TARGET)

    return {
        "query": query,
        "recommendations": results,
    }
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser(description="Demo recommendation pipeline") |
|
|
parser.add_argument( |
|
|
"--query", |
|
|
"-q", |
|
|
type=str, |
|
|
default="wireless earbuds for running", |
|
|
help="Query to demonstrate", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--top-k", |
|
|
"-k", |
|
|
type=int, |
|
|
default=1, |
|
|
help="Number of products to recommend (default: 1)", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--json", |
|
|
action="store_true", |
|
|
help="Output as JSON instead of formatted text", |
|
|
) |
|
|
args = parser.parse_args() |
|
|
|
|
|
result = demo_recommendation(args.query, top_k=args.top_k) |
|
|
|
|
|
if args.json and result: |
|
|
log_section(logger, "JSON OUTPUT") |
|
|
print(json.dumps(result, indent=2)) |
|
|
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|