Spaces:
Running
Running
File size: 7,778 Bytes
"""Task A service — FastAPI wrapper around the Impersonation agent.

Exposes:
    POST /generate
        Body:    { user_id: str | persona: {...}, item: {...}, naija_mode: bool }
        Returns: { rating, review, reasoning, used_history_count, naija_mode, user_id }
    GET /health
        Returns: { status: "ok", service: "task_a_user_modeling" }
    GET /users  (helper)
        Returns: list of sample user_ids the judges can try

Run locally:
    uvicorn task_a_user_modeling.main:app --host 0.0.0.0 --port 8001 --reload
"""
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Optional
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from core.config import settings
from core.llm import LLMClient
from core.persona import PersonaEngine, UserPersona
from task_a_user_modeling.agent import (
ImpersonationAgent, ItemInput, ImpersonationResult,
)
# Configure logging once at import time; the format prefixes every record
# with a timestamp and its level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# Application object that the route decorators below register against.
# The title uses a real em dash; the previous text carried a mis-encoded
# character that would show up verbatim in the generated API docs.
app = FastAPI(
    title="NaijaTaste AI — Task A: Behavioral Review Simulator",
    description=(
        "Generates ratings and reviews in the voice of a specific user, given "
        "their review history and a target item. Optional Nigerian English mode."
    ),
    version="0.1.0",
)
# ──────────────────────────────────────────────────────────────────────────────
# Lazy-loaded resources
# ──────────────────────────────────────────────────────────────────────────────
@lru_cache(maxsize=1)
def _load_reviews() -> pd.DataFrame:
    """Read the processed reviews table from disk.

    The file is expected at ``settings.processed_dir / "reviews.parquet"``.
    A successful load is memoised by ``lru_cache``, so later calls return the
    same DataFrame object without touching the disk again.

    Raises:
        RuntimeError: if the parquet file does not exist.
    """
    reviews_path = settings.processed_dir / "reviews.parquet"
    if reviews_path.exists():
        frame = pd.read_parquet(reviews_path)
        log.info(f"Loaded {len(frame):,} reviews from {reviews_path}")
        return frame
    raise RuntimeError(
        f"Reviews file not found at {reviews_path}. Run `python data/prepare_data.py` first."
    )
@lru_cache(maxsize=1)
def _persona_engine() -> PersonaEngine:
    """Return the shared PersonaEngine, constructing it on the first call."""
    engine = PersonaEngine()
    return engine
@lru_cache(maxsize=1)
def _agent() -> ImpersonationAgent:
    """Return the shared ImpersonationAgent, constructing it on the first call."""
    agent = ImpersonationAgent()
    return agent
@lru_cache(maxsize=512)
def _build_persona(user_id: str, enrich: bool = True) -> UserPersona:
    """Derive a persona for ``user_id`` from the training split of the reviews.

    The base persona comes from ``PersonaEngine.from_dataframe``. When
    ``enrich`` is true and the persona has history samples, it is passed
    through ``PersonaEngine.enrich`` before being returned.

    Results are memoised per ``(user_id, enrich)`` pair (up to 512 entries),
    so repeated requests for the same user skip the rebuild.
    """
    all_reviews = _load_reviews()
    train_rows = all_reviews[all_reviews["split"] == "train"]
    builder = _persona_engine()
    base = builder.from_dataframe(user_id, train_rows)
    # Nothing to enrich from when enrichment is off or there is no history.
    if not (enrich and base.history_samples):
        return base
    return builder.enrich(base)
# ──────────────────────────────────────────────────────────────────────────────
# Request / response schemas
# ──────────────────────────────────────────────────────────────────────────────
class PersonaInput(BaseModel):
    """Caller-supplied persona, used in place of a ``user_id`` lookup.

    Only ``description`` is required; the numeric fields have bounded
    defaults and the list fields default to empty lists.
    """

    # Required free-text portrait of the user.
    description: str = Field(
        description="Free-text description of the user (cold-start friendly)",
    )
    # Typical star rating, constrained to the 1-5 range.
    avg_rating: float = Field(4.0, ge=1.0, le=5.0)
    # Typical review length, constrained to 10-2000 (unit not stated here).
    avg_review_length: int = Field(80, ge=10, le=2000)
    preferred_themes: list[str] = Field(default_factory=list)
    common_complaints: list[str] = Field(default_factory=list)
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    The persona can come from ``user_id`` or from an inline ``persona``.
    Both are optional at the schema level; the endpoint itself rejects
    requests that supply both or neither.
    """

    user_id: Optional[str] = Field(
        default=None,
        description="If set, the persona is built from this user's review history",
    )
    # The description uses a real em dash; the previous text carried a
    # mis-encoded character that leaked into the generated API schema.
    persona: Optional[PersonaInput] = Field(
        default=None,
        description="Alternative to user_id — supply a free-text persona for cold-start",
    )
    # Target item to review (required).
    item: ItemInput
    naija_mode: bool = Field(
        default=False,
        description="If true, rewrite the generated review in Nigerian English register",
    )
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``.

    Every field except ``user_id`` is copied from the agent's result.
    """

    rating: float
    review: str
    reasoning: str
    used_history_count: int
    naija_mode: bool
    # Echo of the requested user_id; stays None for inline (cold-start) personas.
    user_id: Optional[str] = None
# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: return a fixed status payload naming this service."""
    payload = {"status": "ok", "service": "task_a_user_modeling"}
    return payload
@app.get("/users")
def list_users(limit: int = 20):
    """List sample user_ids with rich histories, for judges to try.

    Only the training split is considered. Users must have reviews in at
    least two domains; of those, the ``limit`` users with the most reviews
    are returned together with their review count, domain count and mean
    rating (rounded to two decimals).
    """
    all_reviews = _load_reviews()
    train_rows = all_reviews[all_reviews["split"] == "train"]

    # One summary row per user.
    per_user = train_rows.groupby("user_id").agg(
        n_reviews=("rating", "size"),
        n_domains=("domain", "nunique"),
        avg_rating=("rating", "mean"),
    )
    per_user = per_user.reset_index()

    # Prefer cross-domain users
    cross_domain = per_user[per_user["n_domains"] >= 2]
    top = cross_domain.nlargest(limit, "n_reviews")

    users = []
    for _, row in top.iterrows():
        users.append(
            {
                "user_id": row["user_id"],
                "n_reviews": int(row["n_reviews"]),
                "n_domains": int(row["n_domains"]),
                "avg_rating": round(float(row["avg_rating"]), 2),
            }
        )
    return {"users": users}
def _cold_start_persona(spec: PersonaInput, domain) -> UserPersona:
    """Build a synthetic, history-less persona from a caller-supplied PersonaInput."""
    return UserPersona(
        user_id="custom_cold_start",
        n_reviews=0,
        avg_rating=spec.avg_rating,
        std_rating=0.5,
        avg_review_length=float(spec.avg_review_length),
        std_review_length=20.0,
        verified_rate=1.0,
        domains=[domain],
        n_domains=1,
        # All probability mass on the star value nearest the stated average.
        # NOTE: round() rounds halves to the even integer (4.5 -> 4).
        rating_distribution={int(round(spec.avg_rating)): 1.0},
        top_terms=[],
        tone="",
        preferred_themes=spec.preferred_themes,
        common_complaints=spec.common_complaints,
        voice_one_liner=spec.description,
        history_samples=[],
    )


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a rating + review impersonating the given user.

    Exactly one of ``req.user_id`` / ``req.persona`` must be provided. With a
    ``user_id`` the persona is built (and enriched) from review history;
    with an inline ``persona`` a synthetic cold-start persona is used.

    Raises:
        HTTPException(400): both or neither of user_id / persona were given.
        HTTPException(404): building the persona for user_id raised ValueError.
    """
    if req.user_id and req.persona:
        raise HTTPException(status_code=400, detail="Provide either user_id OR persona, not both")
    if not req.user_id and not req.persona:
        raise HTTPException(status_code=400, detail="Provide either user_id or persona")

    if req.user_id:
        try:
            persona = _build_persona(req.user_id, enrich=True)
        except ValueError as exc:
            # Presumably raised when the user has no training rows; verify
            # against PersonaEngine.from_dataframe. Chain the cause so the
            # original traceback is kept in the logs.
            raise HTTPException(
                status_code=404,
                detail=f"user_id {req.user_id!r} not found in training data",
            ) from exc
        used_user_id = req.user_id
    else:
        # Build a synthetic persona from the free-text input
        persona = _cold_start_persona(req.persona, req.item.domain)
        used_user_id = None

    result: ImpersonationResult = _agent().run(persona, req.item, naija_mode=req.naija_mode)
    return GenerateResponse(
        rating=result.rating,
        review=result.review,
        reasoning=result.reasoning,
        used_history_count=result.used_history_count,
        naija_mode=result.naija_mode,
        user_id=used_user_id,
    )
|