Recommendation-Agent / scripts /test_task_b.py
Israelbliz's picture
Upload scripts
a971a56 verified
"""End-to-end test of the Task B recommender on real data.
Picks a user, generates k=10 recommendations, and reports whether any of
the user's actual held-out test items appeared in the top-10 (Hit Rate).
This is the first time you'll see the Hit Rate signal — the same metric the
judges score on. Real data, real test, real number.
Usage:
python -m scripts.test_task_b
python -m scripts.test_task_b --user <user_id>
python -m scripts.test_task_b --cross-domain # recommend in domains they haven't tried
python -m scripts.test_task_b --cold-start # use a synthetic Naija persona
"""
from __future__ import annotations
import argparse
import logging
import pandas as pd
from core.config import settings
from core.persona import PersonaEngine, UserPersona
from core.nigerian import naija_persona_examples
from task_b_recommender.agent import RecommendationAgent, detect_mode
# Configure the root logger at import time so INFO-level records that propagate
# to it are printed with a timestamp and level tag while this script runs.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def _pick_cross_domain_user(train: pd.DataFrame, test: pd.DataFrame) -> str:
    """Return the most active train user who spans 2+ domains and has test rows.

    Eligibility: the user appears in ``test`` and has reviews in at least two
    distinct ``domain`` values in ``train``. Among eligible users, the one with
    the most train reviews wins; ties go to the first user in ``groupby``
    (sorted user_id) order.

    Raises:
        SystemExit: if no user satisfies both conditions.
    """
    test_users = set(test["user_id"])
    per_user = (
        train.groupby("user_id")
        .agg(n=("rating", "size"), d=("domain", "nunique"))
        .reset_index()
    )
    # One combined mask instead of two successive filters; same rows, same order.
    eligible = per_user[per_user["user_id"].isin(test_users) & (per_user["d"] >= 2)]
    if eligible.empty:
        raise SystemExit("No cross-domain user has test reviews.")
    # idxmax picks the first row holding the maximum review count.
    top_label = eligible["n"].idxmax()
    return eligible.loc[top_label, "user_id"]
def _parse_args() -> argparse.Namespace:
    """Parse and validate this script's command-line options."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None,
                    help="user_id to evaluate (default: auto-select a cross-domain user)")
    ap.add_argument("--cross-domain", action="store_true",
                    help="Recommend in domains the user has NOT engaged with")
    ap.add_argument("--cold-start", action="store_true",
                    help="Use a synthetic Naija demo persona instead of a real user")
    ap.add_argument("--k", type=int, default=10,
                    help="number of recommendations to request (default: 10)")
    args = ap.parse_args()
    # A non-positive k requests no recommendations and makes Hit Rate@k meaningless.
    if args.k < 1:
        ap.error("--k must be a positive integer")
    return args


def _load_splits() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load the processed reviews parquet and return its (train, test) splits.

    Raises:
        SystemExit: if the processed reviews file has not been generated yet.
    """
    reviews_path = settings.processed_dir / "reviews.parquet"
    if not reviews_path.exists():
        raise SystemExit("Run `python data/prepare_data.py` first.")
    reviews = pd.read_parquet(reviews_path)
    return reviews[reviews["split"] == "train"], reviews[reviews["split"] == "test"]


def _cold_start_persona() -> UserPersona:
    """Build a history-free persona from the first Naija demo example."""
    demo = naija_persona_examples()[0]  # Tunde — Lagos software engineer
    print(f"Cold-start persona: {demo['name']}\n")
    print(f"Description: {demo['description']}\n")
    # n_reviews=0 is what _print_persona uses to recognise a cold-start persona.
    return UserPersona(
        user_id="custom_cold_start",
        n_reviews=0, avg_rating=4.0, std_rating=0.5,
        avg_review_length=80.0, std_review_length=20.0,
        verified_rate=1.0, domains=["Books"], n_domains=1,
        rating_distribution={4: 0.6, 5: 0.3, 3: 0.1},
        top_terms=[],
        tone="", preferred_themes=demo["stated_preferences"],
        common_complaints=demo["deal_breakers"],
        voice_one_liner=demo["description"],
        history_samples=[],
    )


def _real_user_persona(
    user_arg: str | None, train: pd.DataFrame, test: pd.DataFrame
) -> tuple[UserPersona, pd.DataFrame]:
    """Build an enriched persona for a real user and return it with their test rows.

    If ``user_arg`` is empty, a cross-domain user is auto-selected via
    ``_pick_cross_domain_user`` (which may raise ``SystemExit``).
    """
    user_id = user_arg or _pick_cross_domain_user(train, test)
    if not user_arg:
        print(f"Auto-selected cross-domain user: {user_id}\n")
    print(f"Building persona for {user_id}...")
    engine = PersonaEngine()
    persona = engine.from_dataframe(user_id, train)
    persona = engine.enrich(persona)
    held_out = test[test["user_id"] == user_id]
    print(f"User has {len(held_out)} held-out test items")
    return persona, held_out


def _print_section(title: str, *, leading_blank: bool = True) -> None:
    """Print ``title`` framed by 70-character rules, optionally after a blank line."""
    rule = "=" * 70
    print(("\n" if leading_blank else "") + rule)
    print(title)
    print(rule)


def _print_persona(persona: UserPersona) -> None:
    """Print a summary of the persona; cold-start personas get a one-line voice only."""
    _print_section("PERSONA", leading_blank=False)
    if persona.n_reviews > 0:
        print(f"User: {persona.user_id}")
        print(f"Avg rating: {persona.avg_rating:.2f} Reviews: {persona.n_reviews}")
        print(f"Domains: {', '.join(persona.domains)}")
        print(f"Voice: {persona.voice_one_liner}")
    else:
        print(f"Cold-start: {persona.voice_one_liner}")


def _print_recommendations(recs, held_out_ids: set, k: int) -> int:
    """Print the ranked recommendations, marking held-out hits; return the hit count.

    ``recs`` is whatever ``RecommendationAgent.run`` returned; each entry is
    expected to expose ``rank``, ``domain``, ``title``, ``item_id`` and ``reasoning``.
    """
    _print_section(f"TOP {k} RECOMMENDATIONS")
    if not recs:
        print("(no recommendations returned)")
    hits = 0
    for rec in recs:
        is_hit = rec.item_id in held_out_ids
        if is_hit:
            hits += 1
        marker = " 🎯 HIT" if is_hit else ""
        print(f"\n#{rec.rank} [{rec.domain}] {rec.title[:80]}{marker}")
        print(f" Why: {rec.reasoning}")
    return hits


def _print_evaluation(n_held_out: int, hits: int, k: int) -> None:
    """Print the single-user Hit Rate@k: 1.0 if any top-k item was held out, else 0.0."""
    _print_section("EVALUATION (vs held-out test items)")
    print(f"User's actual held-out items: {n_held_out}")
    print(f"Hits in top-{k}: {hits}")
    hit_rate = 1.0 if hits > 0 else 0.0
    print(f"Hit Rate@{k}: {hit_rate:.2f}")
    if hits > 0:
        print(f"\n🎉 The system recommended {hits} item(s) the user actually engaged with.")
    else:
        print("\n(No hits this run — held-out items are tiny needles in a 64k-item haystack.")
        print(" Real evaluation will average across 500 users to get a stable score.)")


def main() -> None:
    """Run the Task B recommender for one persona and report Hit Rate@k.

    Loads the processed train/test splits and builds either a synthetic
    cold-start persona or a real user's persona. It then requests ``--k``
    recommendations and, for real users, scores them against that user's
    held-out test items.

    Raises:
        SystemExit: if the processed data is missing, ``--k`` is invalid, or no
            cross-domain user can be auto-selected.
    """
    args = _parse_args()
    train, test = _load_splits()

    # ── Build persona ──────────────────────────────────────────────────────
    if args.cold_start:
        if args.user:
            logging.warning("--user is ignored when --cold-start is set")
        persona = _cold_start_persona()
        held_out = pd.DataFrame()
    else:
        persona, held_out = _real_user_persona(args.user, train, test)

    # ── Generate recommendations ───────────────────────────────────────────
    mode = detect_mode(persona, requested_cross_domain=args.cross_domain)
    print(f"\nMode: {mode}")
    print(f"Generating {args.k} recommendations...\n")
    # use_review_enrichment=False → pure HyDE for cold-start, no aggregator.
    # This keeps the rerank prompt small and isolates HyDE's effect.
    # NOTE: the flag is passed for every mode in this script, not only cold-start.
    agent = RecommendationAgent(use_review_enrichment=False)
    recs = agent.run(persona, k=args.k, cross_domain=args.cross_domain)

    # ── Display ────────────────────────────────────────────────────────────
    _print_persona(persona)
    held_out_ids = set(held_out["parent_asin"]) if not held_out.empty else set()
    hits = _print_recommendations(recs, held_out_ids, args.k)

    # ── Hit Rate report (real users only; cold-start has no held-out items) ─
    if not held_out.empty:
        _print_evaluation(len(held_out), hits, args.k)
    print()
if __name__ == "__main__":
main()