Spaces:
Running
Running
| """Quick end-to-end test of the Task A agent on real data. | |
| Picks a user from the training set, picks one of their held-out test | |
| reviews (which is real ground truth we know), generates a predicted | |
| rating + review for that item, and prints both side by side. | |
| Usage: | |
| python -m scripts.test_task_a | |
| python -m scripts.test_task_a --user <user_id> | |
| python -m scripts.test_task_a --naija | |
| python -m scripts.test_task_a --user <user_id> --naija | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import pandas as pd | |
| from core.config import settings | |
| from core.persona import PersonaEngine | |
| from task_a_user_modeling.agent import ImpersonationAgent, ItemInput | |
# Surface INFO-level progress messages (timestamp + level) while the script runs.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def _text_or_empty(value) -> str:
    """Return ``str(value)``, or ``""`` when *value* is a missing scalar.

    A plain ``str()`` turns None/NaN/pd.NA into the literal strings
    "None"/"nan"/"<NA>", which are truthy and would be shown to the user and
    fed to the agent as if they were real text. Non-scalar values (e.g. a
    list of categories) are stringified unchanged.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value)


def _pick_user(requested, train: pd.DataFrame, test: pd.DataFrame):
    """Return *requested* if given, else the busiest cross-domain user.

    "Cross-domain" means the user has training reviews in at least two
    domains; only users that also have at least one test review qualify.
    Exits the script when no such user exists.
    """
    if requested:
        return requested

    per_user = (
        train.groupby("user_id")
        .agg(n=("rating", "size"), d=("domain", "nunique"))
        .reset_index()
    )
    users_with_test = set(test["user_id"])
    candidates = per_user[
        (per_user["d"] >= 2) & per_user["user_id"].isin(users_with_test)
    ]
    if candidates.empty:
        raise SystemExit("No cross-domain user has test reviews. Try --user <id>")

    user_id = candidates.nlargest(1, "n").iloc[0]["user_id"]
    print(f"Auto-selected cross-domain user: {user_id}\n")
    return user_id


def _build_item(test_review: pd.Series, items: pd.DataFrame) -> "ItemInput":
    """Build the agent's item input for the item reviewed in *test_review*.

    Uses the row from *items* when one exists; otherwise falls back to the
    review's own title with empty description/categories.
    """
    target_item_id = test_review["parent_asin"]
    item_meta = items[items["parent_asin"] == target_item_id]

    if item_meta.empty:
        print(f"WARN: no item metadata for {target_item_id}; using review title only")
        return ItemInput(
            parent_asin=target_item_id,
            title=_text_or_empty(test_review.get("title", "")),
            description="",
            categories="",
            domain=test_review["domain"],
        )

    meta = item_meta.iloc[0]
    avg = meta.get("average_rating")
    return ItemInput(
        parent_asin=target_item_id,
        title=_text_or_empty(meta.get("title", "")),
        # Cap the description so the prompt built from it stays bounded.
        description=_text_or_empty(meta.get("description", ""))[:1500],
        categories=_text_or_empty(meta.get("categories", "")),
        domain=test_review["domain"],
        average_rating=float(avg) if pd.notna(avg) else None,
    )


def _print_report(user_id, persona, item, result, test_review, naija: bool) -> None:
    """Print persona, target item, prediction, ground truth and deltas."""
    rule = "=" * 70

    def section(title: str, leading_blank: bool = True) -> None:
        # Every section after the first is separated by a blank line.
        print(("\n" if leading_blank else "") + rule)
        print(title)
        print(rule)

    section("PERSONA SUMMARY", leading_blank=False)
    print(f"User: {user_id}")
    print(f"Avg rating: {persona.avg_rating:.2f} | Tone: {persona.tone}")
    print(f"Voice: {persona.voice_one_liner}")

    section("TARGET ITEM")
    print(f"Domain: {item.domain}")
    print(f"Title: {item.title}")
    if item.description:
        print(f"Description: {item.description[:300]}...")

    section(f"AI-GENERATED PREDICTION {'(Naija mode)' if naija else ''}")
    print(f"Rating: {result.rating}★")
    print(f"Reasoning: {result.reasoning}")
    print(f"\nReview:\n{result.review}")

    # A missing ground-truth text is shown as empty (0 words), not as "nan".
    truth_text = _text_or_empty(test_review["text"])
    section("GROUND TRUTH (what the user actually wrote)")
    print(f"Rating: {test_review['rating']}★")
    print(f"\nReview:\n{truth_text}")

    section("DELTA")
    rating_delta = abs(result.rating - float(test_review["rating"]))
    print(f"Rating absolute error: {rating_delta:.1f} stars")
    print(f"Generated review length: {len(result.review.split())} words")
    print(f"Ground truth length: {len(truth_text.split())} words")


def main():
    """Generate a review for one held-out item and compare it to ground truth.

    Command-line flags:
        --user   user_id to impersonate; when omitted, the cross-domain user
                 with the most training reviews (and at least one test
                 review) is chosen.
        --naija  pass ``naija_mode=True`` to the agent.

    Exits via ``SystemExit`` when the processed parquet files are missing,
    when no eligible user can be auto-selected, or when the chosen user has
    no test reviews.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None,
                    help="Specific user_id; else picks a cross-domain user with the most reviews")
    ap.add_argument("--naija", action="store_true",
                    help="Apply Nigerian English style transfer to the generated review")
    args = ap.parse_args()

    reviews_path = settings.processed_dir / "reviews.parquet"
    items_path = settings.processed_dir / "items.parquet"
    if not (reviews_path.exists() and items_path.exists()):
        raise SystemExit("Run `python data/prepare_data.py` first.")

    reviews = pd.read_parquet(reviews_path)
    items = pd.read_parquet(items_path)
    train = reviews[reviews["split"] == "train"]
    test = reviews[reviews["split"] == "test"]

    user_id = _pick_user(args.user, train, test)

    # The user's first test review is the target; its rating/text are the ground truth.
    user_test = test[test["user_id"] == user_id]
    if user_test.empty:
        raise SystemExit(f"User {user_id} has no test reviews — try a different user.")
    test_review = user_test.iloc[0]

    item = _build_item(test_review, items)

    # Persona is built from training reviews only, then enriched.
    print(f"Building persona for {user_id}...")
    engine = PersonaEngine()
    persona = engine.from_dataframe(user_id, train)
    persona = engine.enrich(persona)

    print(f"\nGenerating review for item: {item.title[:80]}...\n")
    agent = ImpersonationAgent()
    result = agent.run(persona, item, naija_mode=args.naija)

    _print_report(user_id, persona, item, result, test_review, args.naija)
# Run the comparison only when executed as a script
# (e.g. `python -m scripts.test_task_a`), never on import.
if __name__ == "__main__":
    main()