Spaces:
Running
Running
| """Task A service — FastAPI wrapper around the Impersonation agent. | |
| Exposes: | |
| POST /generate | |
| Body: { user_id: str | persona: {...}, item: {...}, naija_mode: bool } | |
| Returns: { rating, review, reasoning, used_history_count, naija_mode } | |
| GET /health | |
| Returns: { status: "ok" } | |
| GET /users (helper) | |
| Returns: list of sample user_ids the judges can try | |
| Run locally: | |
| uvicorn task_a_user_modeling.main:app --host 0.0.0.0 --port 8001 --reload | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from functools import lru_cache | |
| from typing import Optional | |
| import pandas as pd | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, Field | |
| from core.config import settings | |
| from core.llm import LLMClient | |
| from core.persona import PersonaEngine, UserPersona | |
| from task_a_user_modeling.agent import ( | |
| ImpersonationAgent, ItemInput, ImpersonationResult, | |
| ) | |
# Configure root logging once, at import time, with a timestamped format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# ASGI application object (served as `task_a_user_modeling.main:app`).
# The title previously contained a mis-decoded em dash ("β"); it is restored
# here so the generated OpenAPI docs render correctly.
app = FastAPI(
    title="NaijaTaste AI — Task A: Behavioral Review Simulator",
    description=(
        "Generates ratings and reviews in the voice of a specific user, given "
        "their review history and a target item. Optional Nigerian English mode."
    ),
    version="0.1.0",
)
| # ────────────────────────────────────────────────────────────────────────────── | |
| # Lazy-loaded resources | |
| # ────────────────────────────────────────────────────────────────────────────── | |
@lru_cache(maxsize=1)
def _load_reviews() -> pd.DataFrame:
    """Load the processed reviews table, reading the parquet file only once.

    The result is memoised, so every caller receives the *same* DataFrame
    object; treat it as read-only (filtering with a boolean mask is fine,
    in-place mutation is not).

    Returns:
        The contents of ``<settings.processed_dir>/reviews.parquet``.

    Raises:
        RuntimeError: if the parquet file does not exist. A raised exception
            is not cached, so the next call retries the load.
    """
    path = settings.processed_dir / "reviews.parquet"
    if not path.exists():
        raise RuntimeError(
            f"Reviews file not found at {path}. Run `python data/prepare_data.py` first."
        )
    df = pd.read_parquet(path)
    # Lazy %-formatting: the message is only built if INFO is enabled.
    log.info("Loaded %s reviews from %s", f"{len(df):,}", path)
    return df
@lru_cache(maxsize=1)
def _persona_engine() -> PersonaEngine:
    """Return the process-wide PersonaEngine, constructing it on first use.

    Memoised so the engine is built once instead of on every request.
    NOTE(review): assumes a PersonaEngine instance can be reused across
    requests — confirm against core.persona.
    """
    return PersonaEngine()
@lru_cache(maxsize=1)
def _agent() -> ImpersonationAgent:
    """Return the process-wide ImpersonationAgent, constructing it on first use.

    Memoised so the agent is built once instead of on every request.
    NOTE(review): assumes an ImpersonationAgent instance can be reused across
    requests — confirm against task_a_user_modeling.agent.
    """
    return ImpersonationAgent()
@lru_cache(maxsize=512)
def _build_persona(user_id: str, enrich: bool = True) -> UserPersona:
    """Build (and optionally enrich) a persona for a user from training reviews.

    Results are memoised per ``(user_id, enrich)`` pair, so repeated calls for
    the same user return the same UserPersona object without rebuilding it.
    The cache is bounded so an unbounded stream of distinct user_ids cannot
    grow memory without limit. Callers must not mutate the returned persona.

    Args:
        user_id: Identifier passed to ``PersonaEngine.from_dataframe``.
        enrich: When true and the persona has history samples, the persona is
            passed through ``PersonaEngine.enrich``.

    Returns:
        The persona built from rows whose ``split`` column equals ``"train"``.

    Raises:
        Whatever ``_load_reviews`` or the persona engine raise. The /generate
        handler treats ``ValueError`` as "unknown user" — presumably raised by
        ``from_dataframe``; confirm against core.persona. Exceptions are not
        cached.
    """
    reviews = _load_reviews()
    # Only the training split may inform the persona.
    train = reviews[reviews["split"] == "train"]
    engine = _persona_engine()
    persona = engine.from_dataframe(user_id, train)
    # Enrichment is skipped when there is no history to enrich from.
    if enrich and persona.history_samples:
        persona = engine.enrich(persona)
    return persona
| # ────────────────────────────────────────────────────────────────────────────── | |
| # Request / response schemas | |
| # ────────────────────────────────────────────────────────────────────────────── | |
class PersonaInput(BaseModel):
    """Persona described directly by the caller rather than looked up by user_id."""

    # Required free-text portrait of the user.
    description: str = Field(
        ...,
        description="Free-text description of the user (cold-start friendly)",
    )
    # Typical rating; validated to lie within [1.0, 5.0].
    avg_rating: float = Field(4.0, ge=1.0, le=5.0)
    # Typical review length; validated to lie within [10, 2000].
    avg_review_length: int = Field(80, ge=10, le=2000)
    # Both lists default to a fresh empty list per instance.
    preferred_themes: list[str] = Field(default_factory=list)
    common_complaints: list[str] = Field(default_factory=list)
class GenerateRequest(BaseModel):
    """Request body for review generation.

    Exactly one of ``user_id`` / ``persona`` is expected. The model itself
    accepts both as optional; the exclusivity check is performed by the
    ``generate`` handler.
    """

    # Identifier of a user whose review history drives the persona.
    user_id: Optional[str] = Field(
        default=None,
        description="If set, the persona is built from this user's review history",
    )
    # Caller-described persona for users without history.
    # The description previously contained a mis-decoded em dash ("β").
    persona: Optional[PersonaInput] = Field(
        default=None,
        description="Alternative to user_id — supply a free-text persona for cold-start",
    )
    # Target item to be rated and reviewed.
    item: ItemInput
    naija_mode: bool = Field(
        default=False,
        description="If true, rewrite the generated review in Nigerian English register",
    )
class GenerateResponse(BaseModel):
    """Response body returned by review generation."""

    # Generated rating, review text, and the agent's reasoning.
    rating: float
    review: str
    reasoning: str
    # Copied from the agent result's `used_history_count`.
    used_history_count: int
    # Copied from the agent result's `naija_mode`.
    naija_mode: bool
    # The request's user_id; None when a free-text persona was supplied.
    user_id: Optional[str] = Field(default=None)
| # ────────────────────────────────────────────────────────────────────────────── | |
| # Endpoints | |
| # ────────────────────────────────────────────────────────────────────────────── | |
@app.get("/health")
def health():
    """Liveness probe: return a static status payload.

    Registered as ``GET /health`` so the route advertised in the module
    docstring is actually served.
    """
    return {"status": "ok", "service": "task_a_user_modeling"}
@app.get("/users")
def list_users(limit: int = 20):
    """Return a sample of user_ids that have rich histories, for judges to try.

    Registered as ``GET /users`` so the route advertised in the module
    docstring is actually served.

    Args:
        limit: Maximum number of users to return (query parameter).

    Returns:
        ``{"users": [...]}`` where each entry carries ``user_id``,
        ``n_reviews``, ``n_domains`` and ``avg_rating`` (rounded to 2 places),
        ordered by review count, largest first.
    """
    reviews = _load_reviews()
    train = reviews[reviews["split"] == "train"]
    # Per-user aggregates over the training split.
    counts = (
        train.groupby("user_id")
        .agg(
            n_reviews=("rating", "size"),
            n_domains=("domain", "nunique"),
            avg_rating=("rating", "mean"),
        )
        .reset_index()
    )
    # Only users with reviews in at least two domains are offered.
    pick = counts[counts["n_domains"] >= 2].nlargest(limit, "n_reviews")
    return {
        "users": [
            {
                "user_id": row["user_id"],
                # Cast to built-in numbers so the payload is plain JSON.
                "n_reviews": int(row["n_reviews"]),
                "n_domains": int(row["n_domains"]),
                "avg_rating": round(float(row["avg_rating"]), 2),
            }
            for _, row in pick.iterrows()
        ]
    }
def _cold_start_persona(spec: PersonaInput, domain) -> UserPersona:
    """Build a history-less persona from caller-supplied free-text input."""
    return UserPersona(
        user_id="custom_cold_start",
        n_reviews=0,
        avg_rating=spec.avg_rating,
        std_rating=0.5,
        avg_review_length=float(spec.avg_review_length),
        std_review_length=20.0,
        verified_rate=1.0,
        domains=[domain],
        n_domains=1,
        # All probability mass on the rating nearest the stated average.
        rating_distribution={int(round(spec.avg_rating)): 1.0},
        top_terms=[],
        tone="",
        preferred_themes=spec.preferred_themes,
        common_complaints=spec.common_complaints,
        voice_one_liner=spec.description,
        history_samples=[],
    )


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a rating + review impersonating the given user.

    Registered as ``POST /generate`` so the route advertised in the module
    docstring is actually served.

    Args:
        req: Request carrying exactly one of ``user_id`` / ``persona``, the
            target item, and the ``naija_mode`` flag.

    Returns:
        A GenerateResponse populated from the agent's result. ``user_id`` is
        echoed back when the persona came from history, otherwise ``None``.

    Raises:
        HTTPException(400): both or neither of ``user_id`` / ``persona`` given.
        HTTPException(404): persona building raised ``ValueError`` for the
            supplied ``user_id``.
    """
    if req.user_id and req.persona:
        raise HTTPException(400, "Provide either user_id OR persona, not both")
    if not req.user_id and not req.persona:
        raise HTTPException(400, "Provide either user_id or persona")

    if req.user_id:
        try:
            persona = _build_persona(req.user_id, enrich=True)
        except ValueError as err:
            # NOTE(review): any ValueError raised while building or enriching
            # the persona is reported as "not found" — confirm no other step
            # raises ValueError. The cause is chained for debugging.
            raise HTTPException(
                404, f"user_id {req.user_id!r} not found in training data"
            ) from err
        used_user_id = req.user_id
    else:
        # Build a synthetic persona from the free-text input.
        persona = _cold_start_persona(req.persona, req.item.domain)
        used_user_id = None

    agent = _agent()
    result: ImpersonationResult = agent.run(persona, req.item, naija_mode=req.naija_mode)
    return GenerateResponse(
        rating=result.rating,
        review=result.review,
        reasoning=result.reasoning,
        used_history_count=result.used_history_count,
        naija_mode=result.naija_mode,
        user_id=used_user_id,
    )