File size: 7,778 Bytes
1539e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""Task A service β€” FastAPI wrapper around the Impersonation agent.

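Exposes:
    POST /generate
        Body: { user_id: str  OR  persona: {...}, item: {...}, naija_mode: bool }
              (exactly one of user_id / persona must be provided)
        Returns: { rating, review, reasoning, used_history_count, naija_mode, user_id }

    GET /health
        Returns: { status: "ok", service: "task_a_user_modeling" }

    GET /users?limit=20 (helper)
        Returns: { users: [{ user_id, n_reviews, n_domains, avg_rating }, ...] },
        a sample of users with rich review histories that the judges can try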

Run locally:
    uvicorn task_a_user_modeling.main:app --host 0.0.0.0 --port 8001 --reload
"""
from __future__ import annotations

import logging
from functools import lru_cache
from typing import Optional

import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from core.config import settings
from core.llm import LLMClient
from core.persona import PersonaEngine, UserPersona
from task_a_user_modeling.agent import (
    ImpersonationAgent, ItemInput, ImpersonationResult,
)

# Process-wide logging setup: INFO level with timestamps, shared by every log call in this module.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger(__name__)

# The ASGI application. title/description/version are published in the generated
# OpenAPI docs, so they must be clean UTF-8 text (the title previously contained
# the mojibake "β€”" instead of an em dash).
app = FastAPI(
    title="NaijaTaste AI — Task A: Behavioral Review Simulator",
    description=(
        "Generates ratings and reviews in the voice of a specific user, given "
        "their review history and a target item. Optional Nigerian English mode."
    ),
    version="0.1.0",
)


# ──────────────────────────────────────────────────────────────────────────────
# Lazy-loaded resources
# ──────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def _load_reviews() -> pd.DataFrame:
    """Read the processed reviews table from disk and memoize it for the process.

    Returns:
        The DataFrame read from ``<settings.processed_dir>/reviews.parquet``.

    Raises:
        RuntimeError: if the parquet file does not exist yet (the data
            preparation script has not been run).
    """
    parquet_path = settings.processed_dir / "reviews.parquet"
    if parquet_path.exists():
        frame = pd.read_parquet(parquet_path)
        log.info(f"Loaded {len(frame):,} reviews from {parquet_path}")
        return frame
    raise RuntimeError(
        f"Reviews file not found at {parquet_path}. Run `python data/prepare_data.py` first."
    )


@lru_cache(maxsize=1)
def _persona_engine() -> PersonaEngine:
    """Return the shared PersonaEngine, constructing it on the first call only."""
    engine = PersonaEngine()
    return engine


@lru_cache(maxsize=1)
def _agent() -> ImpersonationAgent:
    """Return the shared ImpersonationAgent, constructing it on the first call only."""
    impersonator = ImpersonationAgent()
    return impersonator


@lru_cache(maxsize=512)
def _build_persona(user_id: str, enrich: bool = True) -> UserPersona:
    """Build a persona for ``user_id`` from its training-split reviews.

    When ``enrich`` is true and the persona carries history samples, the
    persona is additionally passed through ``PersonaEngine.enrich`` (an LLM
    step). Results are memoized (up to 512 entries), so repeat requests for
    the same user do not redo the work.

    NOTE(review): lru_cache keys on the call's argument pattern, so
    ``_build_persona(u)`` and ``_build_persona(u, enrich=True)`` may occupy
    separate cache entries; call it consistently.
    """
    all_reviews = _load_reviews()
    train_rows = all_reviews[all_reviews["split"] == "train"]
    engine = _persona_engine()
    persona = engine.from_dataframe(user_id, train_rows)
    # Enrichment is skipped when disabled or when there are no samples to ground it in.
    if not (enrich and persona.history_samples):
        return persona
    return engine.enrich(persona)


# ──────────────────────────────────────────────────────────────────────────────
# Request / response schemas
# ──────────────────────────────────────────────────────────────────────────────

class PersonaInput(BaseModel):
    """Free-form persona supplied directly (instead of by user_id)."""

    # Required free-text voice description; `...` marks the field as having no default.
    description: str = Field(
        ...,
        description="Free-text description of the user (cold-start friendly)",
    )
    # Numeric priors for the synthetic persona; pydantic rejects out-of-range values.
    avg_rating: float = Field(4.0, ge=1.0, le=5.0)
    avg_review_length: int = Field(80, ge=10, le=2000)
    # Optional theme lists; default_factory gives every request its own empty list.
    preferred_themes: list[str] = Field(default_factory=list)
    common_complaints: list[str] = Field(default_factory=list)


class GenerateRequest(BaseModel):
    # Request body for POST /generate. Exactly one of `user_id` / `persona`
    # is expected; the endpoint answers 400 when both or neither are given.
    # (Comments rather than a class docstring so the published JSON schema
    # is not altered.)

    # Existing user whose review history seeds the persona.
    user_id: Optional[str] = Field(
        default=None,
        description="If set, the persona is built from this user's review history",
    )
    # Cold-start alternative to user_id. The description previously contained
    # the mojibake "β€”" instead of an em dash, which leaked into the OpenAPI docs.
    persona: Optional[PersonaInput] = Field(
        default=None,
        description="Alternative to user_id — supply a free-text persona for cold-start",
    )
    # The target item to rate and review.
    item: ItemInput
    naija_mode: bool = Field(
        default=False,
        description="If true, rewrite the generated review in Nigerian English register",
    )


class GenerateResponse(BaseModel):
    # Response body of POST /generate; fields mirror the ImpersonationResult
    # returned by the agent, plus the user_id that was requested.
    rating: float
    review: str
    reasoning: str
    used_history_count: int
    naija_mode: bool
    # None when the persona came from a free-text (cold-start) PersonaInput.
    user_id: Optional[str] = Field(default=None)


# ──────────────────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    # Liveness probe: returns a constant payload and touches no data or model resources.
    # (No docstring so the published OpenAPI operation stays unchanged.)
    payload = {"status": "ok", "service": "task_a_user_modeling"}
    return payload


@app.get("/users")
def list_users(limit: int = 20):
    """Return a sample of user_ids that have rich histories, for judges to try."""
    # (The docstring above is published as the OpenAPI description, so
    # implementation notes live in comments.)
    #
    # Ranking: cross-domain users (reviews in >= 2 domains) first, then by
    # review count. The previous version *filtered* to cross-domain users
    # only, contradicting its "prefer" intent. It returned an empty list
    # on single-domain data; single-domain users now fill any remaining slots.
    #
    # DataFrame.head(-k) would return all but the last k rows, so clamp.
    limit = max(limit, 0)

    reviews = _load_reviews()
    train = reviews[reviews["split"] == "train"]
    counts = (
        train.groupby("user_id")
        .agg(
            n_reviews=("rating", "size"),
            n_domains=("domain", "nunique"),
            avg_rating=("rating", "mean"),
        )
        .reset_index()
    )
    counts["cross_domain"] = counts["n_domains"] >= 2
    pick = counts.sort_values(["cross_domain", "n_reviews"], ascending=False).head(limit)

    # itertuples keeps each column's own dtype, unlike iterrows, which builds
    # one Series per row and can upcast mixed numeric columns.
    return {
        "users": [
            {
                "user_id": rec.user_id,
                "n_reviews": int(rec.n_reviews),
                "n_domains": int(rec.n_domains),
                "avg_rating": round(float(rec.avg_rating), 2),
            }
            for rec in pick.itertuples(index=False)
        ]
    }


def _persona_from_input(p: PersonaInput, item: ItemInput) -> UserPersona:
    """Build a synthetic, history-free persona from a free-text PersonaInput.

    Caller-supplied fields come from ``p``; spread statistics are fixed
    placeholders, and the persona is scoped to the target item's domain.
    """
    # Round half up (2.5 -> 3, 4.5 -> 5). Built-in round() rounds half to even
    # (2.5 -> 2 but 3.5 -> 4), which buckets .5 averages inconsistently.
    # avg_rating is validated to [1, 5], so int() truncation equals floor here
    # and the bucket is always a valid star value.
    star_bucket = int(p.avg_rating + 0.5)
    return UserPersona(
        user_id="custom_cold_start",
        n_reviews=0,
        avg_rating=p.avg_rating,
        std_rating=0.5,
        avg_review_length=float(p.avg_review_length),
        std_review_length=20.0,
        verified_rate=1.0,
        domains=[item.domain],
        n_domains=1,
        rating_distribution={star_bucket: 1.0},
        top_terms=[],
        tone="",
        preferred_themes=p.preferred_themes,
        common_complaints=p.common_complaints,
        voice_one_liner=p.description,
        history_samples=[],
    )


def _resolve_persona(req: GenerateRequest) -> tuple[UserPersona, Optional[str]]:
    """Pick the persona to impersonate for ``req``.

    Returns:
        ``(persona, user_id)``, where ``user_id`` is echoed back to the client
        and is ``None`` for a free-text (cold-start) persona.

    Raises:
        HTTPException: 400 unless exactly one of user_id / persona is set;
            404 when building the persona for user_id raises ValueError.
    """
    if req.user_id and req.persona:
        raise HTTPException(400, "Provide either user_id OR persona, not both")
    if not req.user_id and not req.persona:
        raise HTTPException(400, "Provide either user_id or persona")

    if not req.user_id:
        return _persona_from_input(req.persona, req.item), None

    try:
        persona = _build_persona(req.user_id, enrich=True)
    except ValueError as err:
        # NOTE(review): _build_persona also runs LLM enrichment. A ValueError
        # raised there (json.JSONDecodeError is a ValueError subclass) would be
        # reported as this 404 too; confirm engine.enrich handles its own
        # parse errors.
        raise HTTPException(
            404, f"user_id {req.user_id!r} not found in training data"
        ) from err
    return persona, req.user_id


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a rating + review impersonating the given user."""
    # (The docstring above is published as the OpenAPI description; keep
    # implementation notes in comments.)
    persona, used_user_id = _resolve_persona(req)

    agent = _agent()
    result: ImpersonationResult = agent.run(persona, req.item, naija_mode=req.naija_mode)

    return GenerateResponse(
        rating=result.rating,
        review=result.review,
        reasoning=result.reasoning,
        used_history_count=result.used_history_count,
        naija_mode=result.naija_mode,
        user_id=used_user_id,
    )