File size: 2,801 Bytes
bad8292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Quick verification — load processed data, build a persona, print it.

Run after `python data/prepare_data.py` to confirm the persona engine works
on your real Amazon data before we plug in the LLM agents.

    python scripts/verify_persona.py
    python scripts/verify_persona.py --enrich       # also runs LLM enrichment
    python scripts/verify_persona.py --user <id>    # specific user
"""
from __future__ import annotations

import argparse
import logging

import pandas as pd

from core.config import settings
from core.persona import PersonaEngine

# Configure the root logger once at import so INFO-level messages emitted while
# the script runs are shown with a timestamp and level prefix.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)


def _pick_default_user(train: pd.DataFrame) -> tuple[object, bool]:
    """Choose the user to profile when ``--user`` is not given.

    Prefers the user with the most training reviews among users whose reviews
    span at least two distinct ``domain`` values. If no such user exists, falls
    back to the user with the most reviews overall. Ties go to the first user
    in the grouped (sorted-by-id) order.

    Args:
        train: Training reviews with ``user_id``, ``rating`` and ``domain``
            columns.

    Returns:
        ``(user_id, is_cross_domain)``: the chosen id and whether that user has
        reviews in two or more domains.

    Raises:
        SystemExit: If ``train`` has no rows with a non-null ``user_id``.
    """
    # Guard first: idxmax() on an empty series raises an opaque ValueError.
    if not train["user_id"].notna().any():
        raise SystemExit(
            "No training reviews with a user_id were found; cannot auto-select a user."
        )

    per_user = train.groupby("user_id").agg(
        n=("rating", "size"),
        d=("domain", "nunique"),
    )
    cross_domain = per_user[per_user["d"] >= 2]
    if cross_domain.empty:
        return per_user["n"].idxmax(), False
    return cross_domain["n"].idxmax(), True


def main() -> None:
    """Load the processed reviews, build one user's persona, and print it.

    Command-line flags:
        --user <id>  Profile this user instead of auto-selecting one.
        --enrich     Also run LLM qualitative enrichment (requires
                     OPENAI_API_KEY, per the flag's help text).

    Raises:
        SystemExit: If ``reviews.parquet`` is missing from the processed-data
            directory, or if auto-selection finds no users in the train split.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--user", type=str, default=None,
                    help="Specific user_id; otherwise picks a cross-domain user with the most reviews")
    ap.add_argument("--enrich", action="store_true",
                    help="Also run LLM qualitative enrichment (requires OPENAI_API_KEY)")
    args = ap.parse_args()

    reviews_path = settings.processed_dir / "reviews.parquet"
    if not reviews_path.exists():
        raise SystemExit(
            f"Could not find {reviews_path}. Run `python data/prepare_data.py` first."
        )

    reviews = pd.read_parquet(reviews_path)
    train = reviews[reviews["split"] == "train"]
    print(f"Loaded {len(train):,} training reviews across {train['user_id'].nunique():,} users\n")

    if args.user:
        user_id = args.user
    else:
        user_id, is_cross_domain = _pick_default_user(train)
        # Report which selection rule actually fired so the output is not misleading.
        if is_cross_domain:
            print(f"Auto-selected cross-domain user: {user_id}\n")
        else:
            print(f"No cross-domain users found; auto-selected most active user: {user_id}\n")

    if args.enrich:
        engine = PersonaEngine()
    else:
        # Quantitative-only mode — skip LLM init by bypassing __init__.
        # NOTE(review): the attributes below are set by hand in place of
        # __init__; confirm they still match what PersonaEngine.from_dataframe reads.
        engine = PersonaEngine.__new__(PersonaEngine)
        engine.top_terms_k = 20
        engine.history_samples_k = 8
        engine.llm = None  # type: ignore

    persona = engine.from_dataframe(user_id, train)

    if args.enrich:
        print("Running LLM enrichment...\n")
        persona = engine.enrich(persona)

    print("=" * 70)
    print(persona.to_prompt_block())
    print("=" * 70)
    print(f"\nHistory samples ({len(persona.history_samples)}):")
    for i, s in enumerate(persona.history_samples[:3], 1):
        text = s["text"]
        snippet = text[:150].replace("\n", " ")
        # Only mark the snippet as truncated when text was actually cut off.
        ellipsis = "..." if len(text) > 150 else ""
        print(f"  [{i}] {s['rating']}★ ({s['domain']}): {snippet}{ellipsis}")

# Script entry point: run the verification only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()