# User-Modeling-Agent / scripts/verify_persona.py
# Uploaded by Israelbliz ("Upload scripts", commit bad8292).
"""Quick verification — load processed data, build a persona, print it.

Run after `python data/prepare_data.py` to confirm the persona engine works
on your real Amazon data before we plug in the LLM agents.

Usage:
    python scripts/verify_persona.py                # auto-select a cross-domain user
    python scripts/verify_persona.py --enrich       # also run LLM enrichment
    python scripts/verify_persona.py --user <id>    # specific user
"""
from __future__ import annotations
import argparse
import logging
import pandas as pd
from core.config import settings
from core.persona import PersonaEngine
# Root logging config for this CLI script: INFO and above, with timestamps,
# so messages emitted while building/enriching the persona are visible.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script."""
    ap = argparse.ArgumentParser(
        description="Build a persona from processed review data and print it."
    )
    ap.add_argument("--user", type=str, default=None,
                    help="Specific user_id; otherwise picks a cross-domain user with the most reviews")
    ap.add_argument("--enrich", action="store_true",
                    help="Also run LLM qualitative enrichment (requires OPENAI_API_KEY)")
    return ap.parse_args()


def _load_train_reviews() -> pd.DataFrame:
    """Load ``reviews.parquet`` from the processed dir and keep only the train split.

    Raises:
        SystemExit: if the parquet file is missing, or it has no rows whose
            ``split`` column equals ``"train"``.
    """
    reviews_path = settings.processed_dir / "reviews.parquet"
    if not reviews_path.exists():
        raise SystemExit(
            f"Could not find {reviews_path}. Run `python data/prepare_data.py` first."
        )
    reviews = pd.read_parquet(reviews_path)
    train = reviews[reviews["split"] == "train"]
    if train.empty:
        # Without this guard, auto-selection below fails with an opaque
        # "argmax of an empty sequence" error.
        raise SystemExit(
            f"{reviews_path} has no rows with split == 'train'. "
            "Re-run `python data/prepare_data.py`."
        )
    return train


def _select_user(train: pd.DataFrame):
    """Return the ``user_id`` with the most training reviews, preferring cross-domain users.

    A cross-domain user has reviews in at least two distinct ``domain`` values.
    If no such user exists, fall back to the most active user overall. The
    choice, and which rule produced it, is printed.
    """
    counts = (
        train.groupby("user_id")
        .agg(n=("rating", "size"), d=("domain", "nunique"))
        .reset_index()
    )
    # groupby sorts by user_id, so a stable sort keeps ties in user_id order
    # and makes the pick reproducible.
    cross = counts[counts["d"] >= 2].sort_values("n", ascending=False, kind="stable")
    if cross.empty:
        user_id = train["user_id"].value_counts().idxmax()
        print(f"No cross-domain users found; auto-selected most active user: {user_id}\n")
    else:
        user_id = cross.iloc[0]["user_id"]
        print(f"Auto-selected cross-domain user: {user_id}\n")
    return user_id


def _build_engine(enrich: bool) -> PersonaEngine:
    """Create a PersonaEngine; in quantitative-only mode, skip its LLM init."""
    if enrich:
        return PersonaEngine()
    # Quantitative-only mode: bypass __init__ to skip LLM initialisation, and set
    # by hand the attributes the quantitative path is expected to read.
    # NOTE(review): these values duplicate PersonaEngine's configuration and
    # must be kept in sync with its __init__ — confirm against core/persona.py.
    engine = PersonaEngine.__new__(PersonaEngine)
    engine.top_terms_k = 20
    engine.history_samples_k = 8
    engine.llm = None  # type: ignore
    return engine


def _print_persona(persona, max_samples: int = 3, snippet_chars: int = 150) -> None:
    """Print the persona's prompt block followed by its first few history samples.

    Each sample must support ``["text"]``, ``["rating"]`` and ``["domain"]``.
    Non-string text (e.g. None/NaN) is shown as an empty snippet. Whitespace
    is collapsed to single spaces, and "..." is appended only when the snippet
    was actually truncated.
    """
    rule = "=" * 70
    print(rule)
    print(persona.to_prompt_block())
    print(rule)

    samples = persona.history_samples
    print(f"\nHistory samples ({len(samples)}):")
    for i, s in enumerate(samples[:max_samples], 1):
        raw = s["text"]
        flat = " ".join(raw.split()) if isinstance(raw, str) else ""
        snippet = flat[:snippet_chars]
        ellipsis = "..." if len(flat) > snippet_chars else ""
        print(f"  [{i}] {s['rating']}★ ({s['domain']}): {snippet}{ellipsis}")


def main():
    """Load training reviews, build one user's persona (optionally LLM-enriched), and print it."""
    args = _parse_args()
    train = _load_train_reviews()
    print(f"Loaded {len(train):,} training reviews across {train['user_id'].nunique():,} users\n")

    if args.user:
        user_id = args.user
        if not (train["user_id"] == user_id).any():
            logging.getLogger(__name__).warning(
                "user_id %r not found in the training split; the persona may be empty",
                user_id,
            )
    else:
        user_id = _select_user(train)

    engine = _build_engine(args.enrich)
    persona = engine.from_dataframe(user_id, train)
    if args.enrich:
        print("Running LLM enrichment...\n")
        persona = engine.enrich(persona)

    _print_persona(persona)


if __name__ == "__main__":
    main()