from __future__ import annotations

import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, Field

from app.recommendation_pipeline import RecommendationService
from app.shared_models import warm_shared_weights
from app.user_modeling import UserModelingService

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the .env file that sits one directory above this module's directory.
load_dotenv(Path(__file__).resolve().parents[1] / ".env")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Optionally pre-warm models before the application starts serving.

    Environment variables read here:

    * ``SKIP_STARTUP_PREWARM`` -- when it is ``1``/``true``/``yes``
      (case-insensitive, surrounding whitespace ignored) nothing is warmed.
    * ``STARTUP_PREWARM`` -- selects what to warm; defaults to ``all``.

    Any exception raised while warming is logged and swallowed so the
    application still starts.
    """
    skip_flag = os.environ.get("SKIP_STARTUP_PREWARM", "").strip().lower()
    if skip_flag in ("1", "true", "yes"):
        yield
        return

    mode = os.environ.get("STARTUP_PREWARM", "all").strip().lower()
    logger.info("Startup prewarm (STARTUP_PREWARM=%s) …", mode)

    # Decide up front which pieces the selected mode asks for.
    wants_shared = mode not in ("none", "off", "0", "skip")
    wants_user_modeling = mode in (
        "all", "both", "*", "user_modeling", "task_a", "task1", "1",
    )
    wants_recommendation = mode in (
        "all", "both", "*", "recommendation", "task_b", "task2", "2",
    )

    try:
        if wants_shared:
            warm_shared_weights()
        if wants_user_modeling:
            user_modeling_service().warm()
        if wants_recommendation:
            recommendation_service().warm()
        logger.info("Startup prewarm complete.")
    except Exception:
        # Best effort: a failed prewarm must not prevent the app from starting.
        logger.exception(
            "Startup prewarm failed — first requests may be slow; set SKIP_STARTUP_PREWARM=1 to disable"
        )
    yield


app = FastAPI(
    title="DSN X BCT — User modeling & Recommendation",
    description=(
        "Task 1 (User modeling): persona + product -> rating & review. "
        "Task 2 (Recommendation): persona -> personalised ranked items."
    ),
    version="1.0",
    lifespan=lifespan,
)

# Lazily created, process-wide service instances (see the accessors below).
_um: UserModelingService | None = None
_rec: RecommendationService | None = None


def user_modeling_service() -> UserModelingService:
    """Return the shared ``UserModelingService``, creating it on first use."""
    global _um
    if _um is not None:
        return _um
    _um = UserModelingService()
    return _um


def recommendation_service() -> RecommendationService:
    """Return the shared ``RecommendationService``, creating it on first use."""
    global _rec
    if _rec is not None:
        return _rec
    _rec = RecommendationService()
    return _rec


# Request body for the user-modeling endpoints: free-text persona and product
# descriptions plus a flag that is forwarded to the service as ``include_raw``.
# (Documented with comments rather than docstrings so the generated OpenAPI
# schema is left untouched.)
class UserModelingRequest(BaseModel):
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "user_id: stub_user_0\ndisplay_name: Jordan\nyelping_since: 2018-06-01\n"
                        "reviews_written: 42\naverage_stars_across_reviews: 3.9\nuseful_votes_given: 12\n"
                        "funny_votes_given: 2\ncool_votes_given: 4\nfans: 1\nelite_years: none\nfriend_count: 28"
                    ),
                    "product": (
                        "name: Riverfront Ramen\ncategories: Restaurants, Japanese, Ramen\n"
                        "location: Portland, OR\nbusiness_avg_stars: 4.1\nbusiness_review_count: 256\nis_open: 1"
                    ),
                    "include_raw": False,
                }
            ]
        }
    )

    persona: str
    product: str
    include_raw: bool = False


# Response body for the user-modeling endpoints. ``rag_snippets_used`` and
# ``raw`` are optional in the validated payload; every other field is required.
class UserModelingResponse(BaseModel):
    task: str
    agent_steps: list[str]
    rag_snippets_used: int = 0
    stars: int | None
    review: str
    parse_ok: bool
    raw: str | None = None


# Handler shared by POST /user-modeling and POST /task-1.
# The service's ``generate`` call runs in a worker thread via
# ``asyncio.to_thread``; its result is validated into ``UserModelingResponse``.
# RuntimeError is reported as HTTP 503, any other exception as HTTP 500.
@app.post(
    "/user-modeling",
    tags=["Task 1 — User modeling"],
    response_model=UserModelingResponse,
)
@app.post("/task-1", tags=["Task 1 — User modeling"], response_model=UserModelingResponse)
async def user_modeling(req: UserModelingRequest) -> UserModelingResponse:
    started = time.perf_counter()
    logger.info("POST /user-modeling started")
    try:
        service = user_modeling_service()
        payload = await asyncio.to_thread(
            service.generate,
            req.persona,
            req.product,
            include_raw=req.include_raw,
        )
        elapsed = time.perf_counter() - started
        logger.info("POST /user-modeling finished in %.2fs", elapsed)
        return UserModelingResponse.model_validate(payload)
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("user_modeling")
        # NOTE(review): the exception text is returned to the client as the
        # 500 detail — confirm that exposing internal messages is acceptable.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# A single chat-history entry: who spoke (``role``) and what was said.
class ChatTurn(BaseModel):
    role: str
    content: str
# Request body for the recommendation endpoints. ``top_k_retrieval`` must be
# within 5..200 (default 20) and ``top_n_final`` within 1..25 (default 5).
# (Documented with comments rather than docstrings so the generated OpenAPI
# schema is left untouched.)
class RecommendationRequest(BaseModel):
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "display_name: Morgan\nreviews_written: 120\n"
                        "average_stars_across_reviews: 4.1\nelite_years: none"
                    ),
                    "city": None,
                    "state": None,
                    "chat_history": [],
                    "top_k_retrieval": 30,
                    "top_n_final": 5,
                }
            ]
        }
    )

    persona: str
    city: str | None = None
    state: str | None = None
    chat_history: list[ChatTurn] = Field(default_factory=list)
    top_k_retrieval: int = Field(20, ge=5, le=200)
    top_n_final: int = Field(5, ge=1, le=25)


# One ranked entry in a recommendation response.
class RecommendationRank(BaseModel):
    business_id: str
    rank: int
    rationale: str


# Response body for the recommendation endpoints.
class RecommendationResponse(BaseModel):
    task: str
    agent_steps: list[str]
    candidates_considered: int
    recommendations: list[RecommendationRank]


# Handler shared by POST /recommendation and POST /task-2.
# The service's ``recommend`` call runs in a worker thread via
# ``asyncio.to_thread``; its result is validated into
# ``RecommendationResponse``. FileNotFoundError and RuntimeError are reported
# as HTTP 503, any other exception as HTTP 500.
@app.post("/recommendation", tags=["Task 2 — Recommendation"], response_model=RecommendationResponse)
@app.post("/task-2", tags=["Task 2 — Recommendation"], response_model=RecommendationResponse)
async def recommendation(req: RecommendationRequest) -> RecommendationResponse:
    started = time.perf_counter()
    logger.info(
        "POST /recommendation started (top_k=%s top_n=%s)",
        req.top_k_retrieval,
        req.top_n_final,
    )
    try:
        service = recommendation_service()
        # Plain dicts are handed to the service instead of pydantic models.
        turns = [turn.model_dump() for turn in req.chat_history]
        call_kwargs = {
            "city": req.city,
            "state": req.state,
            "chat_history": turns,
            "top_k_retrieval": req.top_k_retrieval,
            "top_n_final": req.top_n_final,
        }
        payload = await asyncio.to_thread(service.recommend, req.persona, **call_kwargs)
        elapsed = time.perf_counter() - started
        logger.info("POST /recommendation finished in %.2fs", elapsed)
        return RecommendationResponse.model_validate(payload)
    except (FileNotFoundError, RuntimeError) as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("recommendation")
        # NOTE(review): the exception text is returned to the client as the
        # 500 detail — confirm that exposing internal messages is acceptable.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# Legacy alias for the user-modeling endpoint; hidden from the OpenAPI schema.
@app.post("/task_a", include_in_schema=False, response_model=UserModelingResponse)
async def task_a_legacy(req: UserModelingRequest) -> UserModelingResponse:
    result = await user_modeling(req)
    return result


# Legacy alias for the recommendation endpoint; hidden from the OpenAPI schema.
@app.post("/task_b", include_in_schema=False, response_model=RecommendationResponse)
async def task_b_legacy(req: RecommendationRequest) -> RecommendationResponse:
    result = await recommendation(req)
    return result


# Health probe: always answers with a fixed payload.
@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "ok"}


# Index route listing the challenge name and the public task endpoints.
@app.get("/")
def root() -> dict[str, str | list[str]]:
    listing: dict[str, str | list[str]] = {
        "challenge": "DSN X BCT LLM Agent Challenge",
        "task_1_user_modeling": ["POST /user-modeling", "POST /task-1"],
        "task_2_recommendation": ["POST /recommendation", "POST /task-2"],
    }
    return listing


if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    # PORT comes from the environment and falls back to 8080.
    port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, host="0.0.0.0", port=port)