# User-Modeling-Agent / scripts / test_task_a.py
# Israelbliz's picture
# Upload scripts
# bad8292 verified
"""Quick end-to-end test of the Task A agent on real data.
Picks a user from the training set, picks one of their held-out test
reviews (which is real ground truth we know), generates a predicted
rating + review for that item, and prints both side by side.
Usage:
python -m scripts.test_task_a
python -m scripts.test_task_a --user <user_id>
python -m scripts.test_task_a --naija
python -m scripts.test_task_a --user <user_id> --naija
"""
from __future__ import annotations
import argparse
import logging
import pandas as pd
from core.config import settings
from core.persona import PersonaEngine
from task_a_user_modeling.agent import ImpersonationAgent, ItemInput
# Configure the root logger at import time: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
def _as_text(value) -> str:
    """Return *value* as text, mapping missing values to an empty string.

    A bare ``str()`` turns ``None``/``NaN`` cells into the literal words
    "None"/"nan", which would then be printed and handed to the agent as if
    they were real metadata. Non-scalar values (e.g. list-valued columns) are
    still rendered with ``str()``.
    """
    if value is None:
        return ""
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value)


def _select_user(requested: "str | None", train: pd.DataFrame, test: pd.DataFrame) -> str:
    """Return the user to evaluate.

    Uses *requested* when given. Otherwise picks, among training users with
    reviews in at least two domains who also have test reviews, the one with
    the most training reviews.

    Raises:
        SystemExit: if no user is requested and no cross-domain user has test reviews.
    """
    if requested:
        return requested

    per_user = (
        train.groupby("user_id")
        .agg(n=("rating", "size"), d=("domain", "nunique"))
        .reset_index()
    )
    candidates = per_user[per_user["d"] >= 2]
    # Keep only users who also have test reviews; without one there is no
    # ground truth to compare the prediction against.
    users_with_test = set(test["user_id"])
    candidates = candidates[candidates["user_id"].isin(users_with_test)]
    if candidates.empty:
        raise SystemExit("No cross-domain user has test reviews. Try --user <id>")

    user_id = candidates.nlargest(1, "n").iloc[0]["user_id"]
    print(f"Auto-selected cross-domain user: {user_id}\n")
    return user_id


def _build_item(test_review: pd.Series, items: pd.DataFrame) -> "ItemInput":
    """Build the agent's ``ItemInput`` for the item reviewed in *test_review*.

    Item metadata is looked up by ``parent_asin``. When no metadata row exists,
    a warning is printed and only the review's own title is used.
    """
    target_item_id = test_review["parent_asin"]
    item_meta = items[items["parent_asin"] == target_item_id]

    if item_meta.empty:
        print(f"WARN: no item metadata for {target_item_id}; using review title only")
        return ItemInput(
            parent_asin=target_item_id,
            title=_as_text(test_review.get("title", "")),
            description="",
            categories="",
            domain=test_review["domain"],
        )

    meta = item_meta.iloc[0]
    avg = meta.get("average_rating")
    return ItemInput(
        parent_asin=target_item_id,
        title=_as_text(meta.get("title", "")),
        # Cap the description passed to the agent at 1500 characters.
        description=_as_text(meta.get("description", ""))[:1500],
        categories=_as_text(meta.get("categories", "")),
        domain=test_review["domain"],
        average_rating=float(avg) if pd.notna(avg) else None,
    )


def _print_banner(title: str, *, first: bool = False) -> None:
    """Print a section heading framed by 70-character rules.

    Every banner except the first is preceded by a blank line.
    """
    rule = "=" * 70
    print(rule if first else "\n" + rule)
    print(title)
    print(rule)


def main():
    """Run one Task A prediction and print it next to the ground truth.

    Steps: parse ``--user`` / ``--naija``, load the processed reviews and
    items, choose a user and their first held-out test review, build the
    target item and the user's persona from training reviews, run the
    impersonation agent, and print persona, item, prediction, ground truth
    and the rating / length deltas.

    Raises:
        SystemExit: if the processed parquet files are missing, no suitable
            user can be auto-selected, or the chosen user has no test reviews.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None,
                    help="Specific user_id; else picks a cross-domain user with the most reviews")
    ap.add_argument("--naija", action="store_true",
                    help="Apply Nigerian English style transfer to the generated review")
    args = ap.parse_args()

    reviews_path = settings.processed_dir / "reviews.parquet"
    items_path = settings.processed_dir / "items.parquet"
    if not reviews_path.exists() or not items_path.exists():
        raise SystemExit("Run `python data/prepare_data.py` first.")

    reviews = pd.read_parquet(reviews_path)
    items = pd.read_parquet(items_path)
    train = reviews[reviews["split"] == "train"]
    test = reviews[reviews["split"] == "test"]

    # Pick a user
    user_id = _select_user(args.user, train, test)

    # Pick a test review for this user (the first one serves as ground truth)
    user_test = test[test["user_id"] == user_id]
    if user_test.empty:
        raise SystemExit(f"User {user_id} has no test reviews — try a different user.")
    test_review = user_test.iloc[0]

    # Look up item metadata
    item = _build_item(test_review, items)

    # Build persona (with LLM enrichment) from training reviews only, so the
    # held-out review is not part of the persona's input data.
    print(f"Building persona for {user_id}...")
    engine = PersonaEngine()
    persona = engine.from_dataframe(user_id, train)
    persona = engine.enrich(persona)

    # Run the agent
    print(f"\nGenerating review for item: {item.title[:80]}...\n")
    agent = ImpersonationAgent()
    result = agent.run(persona, item, naija_mode=args.naija)

    # Print side-by-side comparison with ground truth
    _print_banner("PERSONA SUMMARY", first=True)
    print(f"User: {user_id}")
    print(f"Avg rating: {persona.avg_rating:.2f} | Tone: {persona.tone}")
    print(f"Voice: {persona.voice_one_liner}")

    _print_banner("TARGET ITEM")
    print(f"Domain: {item.domain}")
    print(f"Title: {item.title}")
    if item.description:
        # Only show an ellipsis when the preview actually cut something off.
        ellipsis = "..." if len(item.description) > 300 else ""
        print(f"Description: {item.description[:300]}{ellipsis}")

    # Build the heading without a trailing space when Naija mode is off.
    _print_banner("AI-GENERATED PREDICTION" + (" (Naija mode)" if args.naija else ""))
    print(f"Rating: {result.rating}★")
    print(f"Reasoning: {result.reasoning}")
    print(f"\nReview:\n{result.review}")

    _print_banner("GROUND TRUTH (what the user actually wrote)")
    print(f"Rating: {test_review['rating']}★")
    print(f"\nReview:\n{test_review['text']}")

    _print_banner("DELTA")
    rating_delta = abs(result.rating - float(test_review["rating"]))
    print(f"Rating absolute error: {rating_delta:.1f} stars")
    print(f"Generated review length: {len(result.review.split())} words")
    print(f"Ground truth length: {len(str(test_review['text']).split())} words")
# Script entry point: run the comparison only when this module is executed
# directly (e.g. `python -m scripts.test_task_a`), not when it is imported.
if __name__ == "__main__":
    main()