Israelbliz's picture
Upload task_a_user_modeling
1539e17 verified
"""Task A service — FastAPI wrapper around the Impersonation agent.
Exposes:
POST /generate
Body: { user_id: str | persona: {...}, item: {...}, naija_mode: bool }
Returns: { rating, review, reasoning, used_history_count, naija_mode, user_id }
GET /health
Returns: { status: "ok", service: "task_a_user_modeling" }
GET /users (helper)
Returns: list of sample user_ids the judges can try
Run locally:
uvicorn task_a_user_modeling.main:app --host 0.0.0.0 --port 8001 --reload
"""
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Optional
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from core.config import settings
from core.llm import LLMClient
from core.persona import PersonaEngine, UserPersona
from task_a_user_modeling.agent import (
ImpersonationAgent, ItemInput, ImpersonationResult,
)
# Configure root logging at import time so every record carries a timestamp
# and a level, then create this module's logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# The ASGI application object; the route decorators below register on it.
app = FastAPI(
    # The em dash is written as an escape: the previous literal had been
    # corrupted into the mojibake sequence "β€”", which showed up verbatim in
    # the generated API docs. An ASCII-only literal cannot be re-mangled.
    title="NaijaTaste AI \u2014 Task A: Behavioral Review Simulator",
    description=(
        "Generates ratings and reviews in the voice of a specific user, given "
        "their review history and a target item. Optional Nigerian English mode."
    ),
    version="0.1.0",
)
# ──────────────────────────────────────────────────────────────────────────────
# Lazy-loaded resources
# ──────────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=1)
def _load_reviews() -> pd.DataFrame:
    """Read ``reviews.parquet`` from ``settings.processed_dir``.

    A successful load is memoised by ``lru_cache``, so later calls reuse the
    same DataFrame instead of touching the disk again.

    Raises:
        RuntimeError: if the parquet file does not exist.
    """
    path = settings.processed_dir / "reviews.parquet"
    if path.exists():
        frame = pd.read_parquet(path)
        log.info(f"Loaded {len(frame):,} reviews from {path}")
        return frame
    # Missing file: point the operator at the script named in the message.
    raise RuntimeError(
        f"Reviews file not found at {path}. Run `python data/prepare_data.py` first."
    )
@lru_cache(maxsize=1)
def _persona_engine() -> PersonaEngine:
    """Return the shared ``PersonaEngine``, constructing it on the first call."""
    engine = PersonaEngine()
    return engine
@lru_cache(maxsize=1)
def _agent() -> ImpersonationAgent:
    """Return the shared ``ImpersonationAgent``, constructing it on the first call."""
    agent = ImpersonationAgent()
    return agent
@lru_cache(maxsize=512)
def _build_persona(user_id: str, enrich: bool = True) -> UserPersona:
    """Build a persona for ``user_id`` from the rows whose split is "train".

    Results are memoised per argument combination (up to 512 entries), so a
    repeated request for the same user skips both the DataFrame work and the
    enrichment call.

    Args:
        user_id: identifier handed to ``PersonaEngine.from_dataframe``.
        enrich: when true, and the persona has history samples, the persona is
            passed through ``PersonaEngine.enrich`` before being returned.

    Returns:
        The enriched persona when enrichment ran, otherwise the base persona.
    """
    all_reviews = _load_reviews()
    in_train = all_reviews["split"] == "train"
    train_rows = all_reviews[in_train]

    builder = _persona_engine()
    base = builder.from_dataframe(user_id, train_rows)

    # Without history samples enrichment is skipped, as it is when the caller
    # opts out.
    if not enrich or not base.history_samples:
        return base
    return builder.enrich(base)
# ──────────────────────────────────────────────────────────────────────────────
# Request / response schemas
# ──────────────────────────────────────────────────────────────────────────────
class PersonaInput(BaseModel):
    """Free-form persona supplied directly (instead of by user_id)."""

    # Required: the Ellipsis marks the field as having no default.
    description: str = Field(
        ...,
        description="Free-text description of the user (cold-start friendly)",
    )
    # Expected average rating, validated to lie within 1.0-5.0.
    avg_rating: float = Field(4.0, ge=1.0, le=5.0)
    # Expected review length, validated to lie within 10-2000.
    avg_review_length: int = Field(80, ge=10, le=2000)
    # Both lists default to a fresh empty list per instance.
    preferred_themes: list[str] = Field(default_factory=list)
    common_complaints: list[str] = Field(default_factory=list)
class GenerateRequest(BaseModel):
    # Request body for POST /generate. `user_id` and `persona` are two
    # alternative ways of identifying whose voice to imitate; both default to
    # None here and the endpoint itself rejects requests that set both or
    # neither.
    user_id: Optional[str] = Field(
        default=None,
        description="If set, the persona is built from this user's review history",
    )
    persona: Optional[PersonaInput] = Field(
        default=None,
        # The em dash is written as an escape: the previous literal had been
        # corrupted into the mojibake sequence "β€”", which leaked into the
        # published schema description.
        description="Alternative to user_id \u2014 supply a free-text persona for cold-start",
    )
    # The target item to be rated and reviewed (required).
    item: ItemInput
    naija_mode: bool = Field(
        default=False,
        description="If true, rewrite the generated review in Nigerian English register",
    )
class GenerateResponse(BaseModel):
    # Response body for POST /generate. The first five fields are copied
    # straight from the agent's result by the endpoint.
    rating: float
    review: str
    reasoning: str
    used_history_count: int
    naija_mode: bool
    # Echo of the requested user_id; stays None for free-text persona requests.
    user_id: Optional[str] = Field(default=None)
# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    # Static liveness payload: a fixed status plus this service's name.
    payload = {"status": "ok", "service": "task_a_user_modeling"}
    return payload
@app.get("/users")
def list_users(limit: int = 20):
    """Return a sample of user_ids that have rich histories, for judges to try."""

    def _as_entry(row):
        # Cast the aggregated values to plain Python numbers for the response.
        return {
            "user_id": row["user_id"],
            "n_reviews": int(row["n_reviews"]),
            "n_domains": int(row["n_domains"]),
            "avg_rating": round(float(row["avg_rating"]), 2),
        }

    all_reviews = _load_reviews()
    train_rows = all_reviews[all_reviews["split"] == "train"]

    # One row per user: review count, distinct-domain count and mean rating.
    per_user = train_rows.groupby("user_id").agg(
        n_reviews=("rating", "size"),
        n_domains=("domain", "nunique"),
        avg_rating=("rating", "mean"),
    )
    per_user = per_user.reset_index()

    # Prefer cross-domain users, then keep the `limit` most prolific of them.
    cross_domain = per_user[per_user["n_domains"] >= 2]
    top = cross_domain.nlargest(limit, "n_reviews")

    return {"users": [_as_entry(row) for _, row in top.iterrows()]}
def _cold_start_persona(spec: PersonaInput, item: ItemInput) -> UserPersona:
    """Build a synthetic, history-less persona from free-text persona input.

    Args:
        spec: the caller-supplied description, rating/length expectations and
            theme lists.
        item: the target item; its ``domain`` becomes the persona's only domain.

    Returns:
        A ``UserPersona`` with ``n_reviews=0`` and no ``history_samples``.
    """
    return UserPersona(
        user_id="custom_cold_start",
        n_reviews=0,
        avg_rating=spec.avg_rating,
        # Fixed spread values: there is no history to derive them from.
        std_rating=0.5,
        avg_review_length=float(spec.avg_review_length),
        std_review_length=20.0,
        verified_rate=1.0,
        domains=[item.domain],
        n_domains=1,
        # All probability mass sits on the rounded average rating.
        rating_distribution={int(round(spec.avg_rating)): 1.0},
        top_terms=[],
        tone="",
        preferred_themes=spec.preferred_themes,
        common_complaints=spec.common_complaints,
        # The free-text description stands in for the voice summary.
        voice_one_liner=spec.description,
        history_samples=[],
    )


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a rating + review impersonating the given user."""
    # Exactly one persona source is accepted: 400 when both or neither is set.
    if req.user_id and req.persona:
        raise HTTPException(400, "Provide either user_id OR persona, not both")
    if not req.user_id and not req.persona:
        raise HTTPException(400, "Provide either user_id or persona")

    if req.user_id:
        try:
            persona = _build_persona(req.user_id, enrich=True)
        except ValueError as exc:
            # NOTE(review): every ValueError raised while building the persona
            # is reported as "not found" -- confirm nothing else in that path
            # raises ValueError. Chaining keeps the original cause in tracebacks.
            raise HTTPException(
                404, f"user_id {req.user_id!r} not found in training data"
            ) from exc
        used_user_id = req.user_id
    else:
        # Build a synthetic persona from the free-text input.
        persona = _cold_start_persona(req.persona, req.item)
        used_user_id = None

    result: ImpersonationResult = _agent().run(persona, req.item, naija_mode=req.naija_mode)
    return GenerateResponse(
        rating=result.rating,
        review=result.review,
        reasoning=result.reasoning,
        used_history_count=result.used_history_count,
        naija_mode=result.naija_mode,
        user_id=used_user_id,
    )