File size: 6,414 Bytes
a971a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""End-to-end test of the Task B recommender on real data.

Picks a user, generates k=10 recommendations, and reports whether any of
the user's actual held-out test items appeared in the top-10 (Hit Rate).

This is the first time you'll see the Hit Rate signal — the same metric
judges score on. Real data, real test, real number.

Usage:
    python -m scripts.test_task_b
    python -m scripts.test_task_b --user <user_id>
    python -m scripts.test_task_b --cross-domain     # recommend in domains they haven't tried
    python -m scripts.test_task_b --cold-start        # use a synthetic Naija persona
"""
from __future__ import annotations

import argparse
import logging

import pandas as pd

from core.config import settings
from core.persona import PersonaEngine, UserPersona
from core.nigerian import naija_persona_examples
from task_b_recommender.agent import RecommendationAgent, detect_mode

# Configure root logging once at import so INFO-level messages are shown
# with a timestamp and level prefix.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)


def _pick_cross_domain_user(train: pd.DataFrame, test: pd.DataFrame) -> str:
    """Pick a user with multi-domain history AND held-out test items."""
    users_in_test = set(test["user_id"])
    counts = (train.groupby("user_id")
                   .agg(n=("rating", "size"), d=("domain", "nunique"))
                   .reset_index())
    counts = counts[counts["user_id"].isin(users_in_test)]
    counts = counts[counts["d"] >= 2]
    if counts.empty:
        raise SystemExit("No cross-domain user has test reviews.")
    return counts.nlargest(1, "n").iloc[0]["user_id"]


def main() -> None:
    """Run one end-to-end Task B recommendation and print a Hit Rate report.

    Parses the command-line flags, loads the processed reviews, builds a
    persona (a real user's, or a synthetic cold-start one), asks
    ``RecommendationAgent`` for ``--k`` recommendations, and prints them.
    When the user has held-out test rows, each recommendation whose
    ``item_id`` matches a held-out ``parent_asin`` is flagged as a hit and a
    Hit Rate@k summary is printed.

    Raises:
        SystemExit: if ``--k`` is not positive (raised by ``argparse``), if
            the processed reviews file is missing, or if no cross-domain
            user with test reviews can be auto-selected.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None)
    ap.add_argument("--cross-domain", action="store_true",
                    help="Recommend in domains the user has NOT engaged with")
    ap.add_argument("--cold-start", action="store_true",
                    help="Use a synthetic Naija demo persona instead of a real user")
    ap.add_argument("--k", type=int, default=10)
    args = ap.parse_args()
    if args.k < 1:
        # A non-positive k makes "top-k" and Hit Rate@k meaningless.
        ap.error("--k must be a positive integer")

    reviews_path = settings.processed_dir / "reviews.parquet"
    if not reviews_path.exists():
        raise SystemExit("Run `python data/prepare_data.py` first.")

    reviews = pd.read_parquet(reviews_path)
    train = reviews[reviews["split"] == "train"]
    test = reviews[reviews["split"] == "test"]

    # ── Build persona ──────────────────────────────────────────────────────
    if args.cold_start:
        demo = naija_persona_examples()[0]   # Tunde — Lagos software engineer
        print(f"Cold-start persona: {demo['name']}\n")
        print(f"Description: {demo['description']}\n")
        persona = UserPersona(
            user_id="custom_cold_start",
            n_reviews=0, avg_rating=4.0, std_rating=0.5,
            avg_review_length=80.0, std_review_length=20.0,
            verified_rate=1.0, domains=["Books"], n_domains=1,
            rating_distribution={4: 0.6, 5: 0.3, 3: 0.1},
            top_terms=[],
            tone="", preferred_themes=demo["stated_preferences"],
            common_complaints=demo["deal_breakers"],
            voice_one_liner=demo["description"],
            history_samples=[],
        )
        # A synthetic persona has no held-out items, so no evaluation below.
        held_out = pd.DataFrame()
    else:
        user_id = args.user or _pick_cross_domain_user(train, test)
        if not args.user:
            print(f"Auto-selected cross-domain user: {user_id}\n")

        print(f"Building persona for {user_id}...")
        engine = PersonaEngine()
        # The persona is built from training rows only; test rows stay
        # held out for the hit check.
        persona = engine.from_dataframe(user_id, train)
        persona = engine.enrich(persona)

        held_out = test[test["user_id"] == user_id]
        print(f"User has {len(held_out)} held-out test items")

    # ── Generate recommendations ───────────────────────────────────────────
    mode = detect_mode(persona, requested_cross_domain=args.cross_domain)
    print(f"\nMode: {mode}")
    print(f"Generating {args.k} recommendations...\n")

    # use_review_enrichment=False → pure HyDE for cold-start, no aggregator.
    # This keeps the rerank prompt small and isolates HyDE's effect.
    agent = RecommendationAgent(use_review_enrichment=False)
    recs = agent.run(persona, k=args.k, cross_domain=args.cross_domain)

    # ── Display ────────────────────────────────────────────────────────────
    print("=" * 70)
    print("PERSONA")
    print("=" * 70)
    if persona.n_reviews > 0:
        print(f"User: {persona.user_id}")
        print(f"Avg rating: {persona.avg_rating:.2f}  Reviews: {persona.n_reviews}")
        print(f"Domains: {', '.join(persona.domains)}")
        print(f"Voice: {persona.voice_one_liner}")
    else:
        print(f"Cold-start: {persona.voice_one_liner}")

    print("\n" + "=" * 70)
    print(f"TOP {args.k} RECOMMENDATIONS")
    print("=" * 70)
    if not recs:
        print("(no recommendations returned)")
    # An empty frame has no "parent_asin" column, so guard before indexing.
    held_out_ids = set(held_out["parent_asin"]) if not held_out.empty else set()
    hits = 0
    for r in recs:
        is_hit = r.item_id in held_out_ids
        marker = "  🎯 HIT" if is_hit else ""
        if is_hit:
            hits += 1
        print(f"\n#{r.rank}  [{r.domain}]  {r.title[:80]}{marker}")
        print(f"     Why: {r.reasoning}")

    # ── Hit Rate report ────────────────────────────────────────────────────
    if not held_out.empty:
        # For a single user, Hit Rate@k is binary: 1 if any held-out item
        # appeared in the list, else 0.
        hit_rate = 1.0 if hits > 0 else 0.0
        print("\n" + "=" * 70)
        print("EVALUATION (vs held-out test items)")
        print("=" * 70)
        print(f"User's actual held-out items: {len(held_out)}")
        print(f"Hits in top-{args.k}:           {hits}")
        print(f"Hit Rate@{args.k}:              {hit_rate:.2f}")
        if hits > 0:
            print(f"\n🎉 The system recommended {hits} item(s) the user actually engaged with.")
        else:
            # The parenthesis opened on the first line is closed on the second.
            print("\n(No hits this run — held-out items are tiny needles in a 64k-item haystack.")
            print(" Real evaluation will average across 500 users to get a stable score.)")

    print()


# Run the end-to-end check only when this file is executed as a script
# (e.g. ``python -m scripts.test_task_b``), not when it is imported.
if __name__ == "__main__":
    main()