Spaces:
Running
Running
"""End-to-end test of the Task B recommender on real data.

Picks a user, generates k recommendations (default k=10), and reports whether
any of the user's actual held-out test items appeared in the top-k (Hit Rate).

This is the first time you'll see the Hit Rate signal — the same metric
judges score on. Real data, real test, real number.

Usage:
    python -m scripts.test_task_b
    python -m scripts.test_task_b --user <user_id>
    python -m scripts.test_task_b --cross-domain   # recommend in domains they haven't tried
    python -m scripts.test_task_b --cold-start     # use a synthetic Naija persona
    python -m scripts.test_task_b --k 20           # change the list length (default 10)
"""
from __future__ import annotations

import argparse
import logging

import pandas as pd

from core.config import settings
from core.nigerian import naija_persona_examples
from core.persona import PersonaEngine, UserPersona
from task_b_recommender.agent import RecommendationAgent, detect_mode

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
| def _pick_cross_domain_user(train: pd.DataFrame, test: pd.DataFrame) -> str: | |
| """Pick a user with multi-domain history AND held-out test items.""" | |
| users_in_test = set(test["user_id"]) | |
| counts = (train.groupby("user_id") | |
| .agg(n=("rating", "size"), d=("domain", "nunique")) | |
| .reset_index()) | |
| counts = counts[counts["user_id"].isin(users_in_test)] | |
| counts = counts[counts["d"] >= 2] | |
| if counts.empty: | |
| raise SystemExit("No cross-domain user has test reviews.") | |
| return counts.nlargest(1, "n").iloc[0]["user_id"] | |
| def main(): | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--user", type=str, default=None) | |
| ap.add_argument("--cross-domain", action="store_true", | |
| help="Recommend in domains the user has NOT engaged with") | |
| ap.add_argument("--cold-start", action="store_true", | |
| help="Use a synthetic Naija demo persona instead of a real user") | |
| ap.add_argument("--k", type=int, default=10) | |
| args = ap.parse_args() | |
| reviews_path = settings.processed_dir / "reviews.parquet" | |
| if not reviews_path.exists(): | |
| raise SystemExit("Run `python data/prepare_data.py` first.") | |
| reviews = pd.read_parquet(reviews_path) | |
| train = reviews[reviews["split"] == "train"] | |
| test = reviews[reviews["split"] == "test"] | |
| # ββ Build persona ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if args.cold_start: | |
| demo = naija_persona_examples()[0] # Tunde β Lagos software engineer | |
| print(f"Cold-start persona: {demo['name']}\n") | |
| print(f"Description: {demo['description']}\n") | |
| persona = UserPersona( | |
| user_id="custom_cold_start", | |
| n_reviews=0, avg_rating=4.0, std_rating=0.5, | |
| avg_review_length=80.0, std_review_length=20.0, | |
| verified_rate=1.0, domains=["Books"], n_domains=1, | |
| rating_distribution={4: 0.6, 5: 0.3, 3: 0.1}, | |
| top_terms=[], | |
| tone="", preferred_themes=demo["stated_preferences"], | |
| common_complaints=demo["deal_breakers"], | |
| voice_one_liner=demo["description"], | |
| history_samples=[], | |
| ) | |
| held_out = pd.DataFrame() | |
| else: | |
| user_id = args.user or _pick_cross_domain_user(train, test) | |
| if not args.user: | |
| print(f"Auto-selected cross-domain user: {user_id}\n") | |
| print(f"Building persona for {user_id}...") | |
| engine = PersonaEngine() | |
| persona = engine.from_dataframe(user_id, train) | |
| persona = engine.enrich(persona) | |
| held_out = test[test["user_id"] == user_id] | |
| print(f"User has {len(held_out)} held-out test items") | |
| # ββ Generate recommendations βββββββββββββββββββββββββββββββββββββββββββ | |
| mode = detect_mode(persona, requested_cross_domain=args.cross_domain) | |
| print(f"\nMode: {mode}") | |
| print(f"Generating {args.k} recommendations...\n") | |
| # use_review_enrichment=False β pure HyDE for cold-start, no aggregator. | |
| # This keeps the rerank prompt small and isolates HyDE's effect. | |
| agent = RecommendationAgent(use_review_enrichment=False) | |
| recs = agent.run(persona, k=args.k, cross_domain=args.cross_domain) | |
| # ββ Display ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("=" * 70) | |
| print("PERSONA") | |
| print("=" * 70) | |
| if persona.n_reviews > 0: | |
| print(f"User: {persona.user_id}") | |
| print(f"Avg rating: {persona.avg_rating:.2f} Reviews: {persona.n_reviews}") | |
| print(f"Domains: {', '.join(persona.domains)}") | |
| print(f"Voice: {persona.voice_one_liner}") | |
| else: | |
| print(f"Cold-start: {persona.voice_one_liner}") | |
| print("\n" + "=" * 70) | |
| print(f"TOP {args.k} RECOMMENDATIONS") | |
| print("=" * 70) | |
| if not recs: | |
| print("(no recommendations returned)") | |
| held_out_ids = set(held_out["parent_asin"]) if not held_out.empty else set() | |
| hits = 0 | |
| for r in recs: | |
| is_hit = r.item_id in held_out_ids | |
| marker = " π― HIT" if is_hit else "" | |
| if is_hit: | |
| hits += 1 | |
| print(f"\n#{r.rank} [{r.domain}] {r.title[:80]}{marker}") | |
| print(f" Why: {r.reasoning}") | |
| # ββ Hit Rate report ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not held_out.empty: | |
| print("\n" + "=" * 70) | |
| print("EVALUATION (vs held-out test items)") | |
| print("=" * 70) | |
| print(f"User's actual held-out items: {len(held_out)}") | |
| print(f"Hits in top-{args.k}: {hits}") | |
| print(f"Hit Rate@{args.k}: {1.0 if hits > 0 else 0.0:.2f}") | |
| if hits > 0: | |
| print(f"\nπ The system recommended {hits} item(s) the user actually engaged with.") | |
| else: | |
| print(f"\n(No hits this run β held-out items are tiny needles in a 64k-item haystack.") | |
| print(f" Real evaluation will average across 500 users to get a stable score.)") | |
| print() | |
| if __name__ == "__main__": | |
| main() | |