"""Quick end-to-end test of the Task A agent on real data.

Picks a user from the training set, picks one of their held-out test
reviews (which is real ground truth we know), generates a predicted
rating + review for that item, and prints both side by side.

Usage:
    python -m scripts.test_task_a
    python -m scripts.test_task_a --user <user_id>
    python -m scripts.test_task_a --naija
    python -m scripts.test_task_a --user <user_id> --naija
"""
from __future__ import annotations
import argparse
import logging
import pandas as pd
from core.config import settings
from core.persona import PersonaEngine
from task_a_user_modeling.agent import ImpersonationAgent, ItemInput
# Root logger: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def _as_text(value) -> str:
    """Return *value* as a string, mapping missing values to ``""``.

    A ``.get(key, "")`` default only applies when the key is absent. A column
    that exists but holds ``None``/``NaN`` would otherwise be stringified to
    the literal text ``"None"``/``"nan"`` and leak into the item prompt.
    """
    # is_scalar guards pd.isna: on list-likes it returns an array, not a bool.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value)


def _select_user(train: pd.DataFrame, test: pd.DataFrame, requested):
    """Return *requested* if given, else auto-pick a cross-domain user.

    The auto-pick considers users whose training reviews span at least two
    domains and who also appear in the test split, and takes the one with the
    most training reviews.

    Raises:
        SystemExit: if no cross-domain user has any test review.
    """
    if requested:
        return requested

    cross_users = (
        train.groupby("user_id")
        .agg(n=("rating", "size"), d=("domain", "nunique"))
        .reset_index()
    )
    cross_users = cross_users[cross_users["d"] >= 2]
    # Only users with a held-out review can be compared against ground truth.
    users_with_test = set(test["user_id"])
    cross_users = cross_users[cross_users["user_id"].isin(users_with_test)]
    if cross_users.empty:
        raise SystemExit("No cross-domain user has test reviews. Try --user <id>")

    user_id = cross_users.nlargest(1, "n").iloc[0]["user_id"]
    print(f"Auto-selected cross-domain user: {user_id}\n")
    return user_id


def _build_item(test_review: pd.Series, items: pd.DataFrame) -> "ItemInput":
    """Build the ``ItemInput`` for the item that *test_review* refers to.

    Uses the first matching row of *items* when one exists; otherwise warns
    and falls back to the review's own title with empty description and
    categories.
    """
    target_item_id = test_review["parent_asin"]
    item_meta = items[items["parent_asin"] == target_item_id]

    if item_meta.empty:
        print(f"WARN: no item metadata for {target_item_id}; using review title only")
        return ItemInput(
            parent_asin=target_item_id,
            title=_as_text(test_review.get("title")),
            description="",
            categories="",
            domain=test_review["domain"],
        )

    meta = item_meta.iloc[0]
    avg = meta.get("average_rating")
    return ItemInput(
        parent_asin=target_item_id,
        title=_as_text(meta.get("title")),
        # Cap the description at 1500 characters, as the original code did.
        description=_as_text(meta.get("description"))[:1500],
        categories=_as_text(meta.get("categories")),
        domain=test_review["domain"],
        average_rating=float(avg) if pd.notna(avg) else None,
    )


def _print_report(user_id, persona, item, result, test_review, naija: bool) -> None:
    """Print persona, target item, prediction, ground truth and deltas."""
    rule = "=" * 70

    print(rule)
    print("PERSONA SUMMARY")
    print(rule)
    print(f"User: {user_id}")
    print(f"Avg rating: {persona.avg_rating:.2f} | Tone: {persona.tone}")
    print(f"Voice: {persona.voice_one_liner}")

    print("\n" + rule)
    print("TARGET ITEM")
    print(rule)
    print(f"Domain: {item.domain}")
    print(f"Title: {item.title}")
    if item.description:
        print(f"Description: {item.description[:300]}...")

    print("\n" + rule)
    print(f"AI-GENERATED PREDICTION {'(Naija mode)' if naija else ''}")
    print(rule)
    print(f"Rating: {result.rating}★")
    print(f"Reasoning: {result.reasoning}")
    print(f"\nReview:\n{result.review}")

    print("\n" + rule)
    print("GROUND TRUTH (what the user actually wrote)")
    print(rule)
    print(f"Rating: {test_review['rating']}★")
    print(f"\nReview:\n{test_review['text']}")

    print("\n" + rule)
    print("DELTA")
    print(rule)
    rating_delta = abs(result.rating - float(test_review["rating"]))
    print(f"Rating absolute error: {rating_delta:.1f} stars")
    print(f"Generated review length: {len(result.review.split())} words")
    print(f"Ground truth length: {len(str(test_review['text']).split())} words")


def main():
    """Run one Task A prediction and print it next to the ground truth.

    Parses ``--user`` and ``--naija``, loads ``reviews.parquet`` and
    ``items.parquet`` from ``settings.processed_dir``, selects a user and
    their first test-split review, builds and enriches a persona from the
    training split, runs ``ImpersonationAgent`` on the reviewed item, and
    prints the comparison.

    Raises:
        SystemExit: if either parquet file is missing, if no cross-domain
            user has test reviews (auto-pick only), or if the chosen user
            has no test reviews.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None,
                    help="Specific user_id; else picks a cross-domain user with the most reviews")
    ap.add_argument("--naija", action="store_true",
                    help="Apply Nigerian English style transfer to the generated review")
    args = ap.parse_args()

    reviews_path = settings.processed_dir / "reviews.parquet"
    items_path = settings.processed_dir / "items.parquet"
    if not reviews_path.exists() or not items_path.exists():
        raise SystemExit("Run `python data/prepare_data.py` first.")

    reviews = pd.read_parquet(reviews_path)
    items = pd.read_parquet(items_path)
    train = reviews[reviews["split"] == "train"]
    test = reviews[reviews["split"] == "test"]

    user_id = _select_user(train, test, args.user)

    # The user's first held-out review serves as the ground truth.
    user_test = test[test["user_id"] == user_id]
    if user_test.empty:
        raise SystemExit(f"User {user_id} has no test reviews — try a different user.")
    test_review = user_test.iloc[0]

    item = _build_item(test_review, items)

    # Build persona (with LLM enrichment) from training reviews only.
    print(f"Building persona for {user_id}...")
    engine = PersonaEngine()
    persona = engine.from_dataframe(user_id, train)
    persona = engine.enrich(persona)

    print(f"\nGenerating review for item: {item.title[:80]}...\n")
    agent = ImpersonationAgent()
    result = agent.run(persona, item, naija_mode=args.naija)

    _print_report(user_id, persona, item, result, test_review, args.naija)
# Script entry point: call main() only when this module is executed directly
# (e.g. `python -m scripts.test_task_a`), not when it is imported.
if __name__ == "__main__":
    main()
|