# DSN / app/main.py
# (repository page header residue: nexusbert — commit "push", 4b0eec9)
from __future__ import annotations
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from app.recommendation_pipeline import RecommendationService
from app.shared_models import warm_shared_weights
from app.user_modeling import UserModelingService
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the lifespan hook and the request handlers below.
logger = logging.getLogger(__name__)
# Load variables from the ".env" file that sits in the parent of this module's
# directory (parents[1] of the resolved file path). This runs at import time,
# before the lifespan hook reads SKIP_STARTUP_PREWARM / STARTUP_PREWARM.
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run optional model pre-warming before the application starts serving.

    Environment variables read:
        SKIP_STARTUP_PREWARM: when "1", "true" or "yes" (case-insensitive,
            surrounding whitespace ignored) no warm-up is attempted.
        STARTUP_PREWARM: selects what to warm; defaults to "all".

    Any exception raised while warming is logged and swallowed so that
    startup still proceeds.
    """
    skip_flag = os.environ.get("SKIP_STARTUP_PREWARM", "").strip().lower()
    if skip_flag not in ("1", "true", "yes"):
        selected = os.environ.get("STARTUP_PREWARM", "all").strip().lower()
        logger.info("Startup prewarm (STARTUP_PREWARM=%s) …", selected)

        disabled_modes = ("none", "off", "0", "skip")
        user_modeling_modes = ("all", "both", "*", "user_modeling", "task_a", "task1", "1")
        recommendation_modes = ("all", "both", "*", "recommendation", "task_b", "task2", "2")

        try:
            # Shared weights are warmed for every mode that is not an explicit "off" value.
            if selected not in disabled_modes:
                warm_shared_weights()
            if selected in user_modeling_modes:
                user_modeling_service().warm()
            if selected in recommendation_modes:
                recommendation_service().warm()
            logger.info("Startup prewarm complete.")
        except Exception:
            # Best effort: a failed warm-up must not prevent the app from starting.
            logger.exception(
                "Startup prewarm failed — first requests may be slow; set SKIP_STARTUP_PREWARM=1 to disable"
            )
    yield
# Application object that every route decorator below attaches to. The
# `lifespan` context manager defined in this module is passed in here, so its
# prewarm logic is tied to this application instance.
app = FastAPI(
    title="DSN X BCT — User modeling & Recommendation",
    description=(
        "Task 1 (User modeling): persona + product -> rating & review. "
        "Task 2 (Recommendation): persona -> personalised ranked items."
    ),
    version="1.0",
    lifespan=lifespan,
)
# Module-level slots for the lazily built service instances. Each stays None
# until its accessor (user_modeling_service / recommendation_service) is first
# called, after which the same instance is reused.
_um: UserModelingService | None = None
_rec: RecommendationService | None = None
def user_modeling_service() -> UserModelingService:
    """Return the shared UserModelingService, constructing it on first use."""
    global _um
    service = _um
    if service is None:
        # First call: build the instance and remember it in the module-level slot.
        service = _um = UserModelingService()
    return service
def recommendation_service() -> RecommendationService:
    """Return the shared RecommendationService, constructing it on first use."""
    global _rec
    service = _rec
    if service is None:
        # First call: build the instance and remember it in the module-level slot.
        service = _rec = RecommendationService()
    return service
class UserModelingRequest(BaseModel):
    # Request body accepted by the user-modeling routes.
    # NOTE: kept as comments rather than a docstring so the generated JSON
    # schema for this model is not altered.

    # Free-text persona; the schema example uses "key: value" lines.
    persona: str
    # Free-text product description; the schema example uses "key: value" lines.
    product: str
    # Forwarded to the service's generate() call as include_raw.
    include_raw: bool = False

    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "user_id: stub_user_0\n"
                        "display_name: Jordan\n"
                        "yelping_since: 2018-06-01\n"
                        "reviews_written: 42\n"
                        "average_stars_across_reviews: 3.9\n"
                        "useful_votes_given: 12\n"
                        "funny_votes_given: 2\n"
                        "cool_votes_given: 4\n"
                        "fans: 1\n"
                        "elite_years: none\n"
                        "friend_count: 28"
                    ),
                    "product": (
                        "name: Riverfront Ramen\n"
                        "categories: Restaurants, Japanese, Ramen\n"
                        "location: Portland, OR\n"
                        "business_avg_stars: 4.1\n"
                        "business_review_count: 256\n"
                        "is_open: 1"
                    ),
                    "include_raw": False,
                }
            ]
        }
    )
class UserModelingResponse(BaseModel):
    # Response body of the user-modeling routes; the handler validates the
    # service's return value against this model.
    # NOTE: comments (not a docstring) so the generated schema is unchanged.

    task: str
    agent_steps: list[str]
    rag_snippets_used: int = Field(default=0)
    # Required but nullable: the key must be present, its value may be None.
    stars: int | None
    review: str
    parse_ok: bool
    # Optional; presumably populated only when include_raw was requested — TODO confirm.
    raw: str | None = Field(default=None)
@app.post("/user-modeling", tags=["Task 1 — User modeling"], response_model=UserModelingResponse)
@app.post(
    "/task-1",
    tags=["Task 1 — User modeling"],
    response_model=UserModelingResponse,
)
async def user_modeling(req: UserModelingRequest) -> UserModelingResponse:
    # Task 1 handler, registered under both /user-modeling and /task-1.
    # The service's generate() runs in a worker thread via asyncio.to_thread.
    # RuntimeError is reported as 503; any other exception is logged and
    # reported as 500. (Comments instead of a docstring so the OpenAPI
    # description of the route is not changed.)
    started_at = time.perf_counter()
    logger.info("POST /user-modeling started")
    try:
        generate = user_modeling_service().generate
        payload = await asyncio.to_thread(
            generate, req.persona, req.product, include_raw=req.include_raw
        )
        elapsed = time.perf_counter() - started_at
        logger.info("POST /user-modeling finished in %.2fs", elapsed)
        # Validation stays inside the try so a malformed payload maps to 500.
        return UserModelingResponse.model_validate(payload)
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("user_modeling")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
class ChatTurn(BaseModel):
    # One entry of RecommendationRequest.chat_history; the recommendation
    # handler converts each turn to a plain dict before calling the service.
    # NOTE: comments (not a docstring) so the generated schema is unchanged.

    role: str
    content: str
class RecommendationRequest(BaseModel):
    # Request body accepted by the recommendation routes.
    # NOTE: comments (not a docstring) so the generated schema is unchanged.

    # Free-text persona; the schema example uses "key: value" lines.
    persona: str
    # Optional values forwarded unchanged to the service's recommend() call.
    city: str | None = None
    state: str | None = None
    # Prior conversation turns; defaults to a fresh empty list per request.
    chat_history: list[ChatTurn] = Field(default_factory=list)
    # Default 20, validated to lie in [5, 200].
    top_k_retrieval: int = Field(default=20, ge=5, le=200)
    # Default 5, validated to lie in [1, 25].
    top_n_final: int = Field(default=5, ge=1, le=25)

    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "display_name: Morgan\n"
                        "reviews_written: 120\n"
                        "average_stars_across_reviews: 4.1\n"
                        "elite_years: none"
                    ),
                    "city": None,
                    "state": None,
                    "chat_history": [],
                    "top_k_retrieval": 30,
                    "top_n_final": 5,
                }
            ]
        }
    )
class RecommendationRank(BaseModel):
    # One element of RecommendationResponse.recommendations.
    # NOTE: comments (not a docstring) so the generated schema is unchanged.

    business_id: str
    rank: int
    rationale: str
class RecommendationResponse(BaseModel):
    # Response body of the recommendation routes; the handler validates the
    # service's return value against this model.
    # NOTE: comments (not a docstring) so the generated schema is unchanged.

    task: str
    agent_steps: list[str]
    candidates_considered: int
    recommendations: list[RecommendationRank]
@app.post(
    "/recommendation",
    tags=["Task 2 — Recommendation"],
    response_model=RecommendationResponse,
)
@app.post(
    "/task-2",
    tags=["Task 2 — Recommendation"],
    response_model=RecommendationResponse,
)
async def recommendation(req: RecommendationRequest) -> RecommendationResponse:
    # Task 2 handler, registered under both /recommendation and /task-2.
    # The service's recommend() runs in a worker thread via asyncio.to_thread.
    # FileNotFoundError and RuntimeError are reported as 503; any other
    # exception is logged and reported as 500. (Comments instead of a
    # docstring so the OpenAPI description of the route is not changed.)
    started_at = time.perf_counter()
    logger.info("POST /recommendation started (top_k=%s top_n=%s)", req.top_k_retrieval, req.top_n_final)
    try:
        svc = recommendation_service()
        # The service receives plain dicts rather than ChatTurn models.
        turns = [turn.model_dump() for turn in req.chat_history]
        call_kwargs = {
            "city": req.city,
            "state": req.state,
            "chat_history": turns,
            "top_k_retrieval": req.top_k_retrieval,
            "top_n_final": req.top_n_final,
        }
        payload = await asyncio.to_thread(svc.recommend, req.persona, **call_kwargs)
        elapsed = time.perf_counter() - started_at
        logger.info("POST /recommendation finished in %.2fs", elapsed)
        # Validation stays inside the try so a malformed payload maps to 500.
        return RecommendationResponse.model_validate(payload)
    except (FileNotFoundError, RuntimeError) as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("recommendation")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post(
    "/task_a",
    include_in_schema=False,
    response_model=UserModelingResponse,
)
async def task_a_legacy(req: UserModelingRequest) -> UserModelingResponse:
    """Legacy alias (hidden from the schema) that delegates to user_modeling."""
    response = await user_modeling(req)
    return response
@app.post(
    "/task_b",
    include_in_schema=False,
    response_model=RecommendationResponse,
)
async def task_b_legacy(req: RecommendationRequest) -> RecommendationResponse:
    """Legacy alias (hidden from the schema) that delegates to recommendation."""
    response = await recommendation(req)
    return response
@app.get("/health")
def health() -> dict[str, str]:
    # Liveness probe: always answers with a fixed payload and touches no service.
    return dict(status="ok")
@app.get("/")
def root() -> dict[str, str | list[str]]:
    # Static index listing the public routes for each task.
    index: dict[str, str | list[str]] = {"challenge": "DSN X BCT LLM Agent Challenge"}
    index["task_1_user_modeling"] = ["POST /user-modeling", "POST /task-1"]
    index["task_2_recommendation"] = ["POST /recommendation", "POST /task-2"]
    return index
if __name__ == "__main__":
    # Script entry point: serve `app` with uvicorn on every interface, using
    # the PORT environment variable when set and 8080 otherwise.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))