| from __future__ import annotations |
|
|
| import asyncio |
| import logging |
| import os |
| import time |
| from contextlib import asynccontextmanager |
| from pathlib import Path |
|
|
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel, ConfigDict, Field |
|
|
| from app.recommendation_pipeline import RecommendationService |
| from app.shared_models import warm_shared_weights |
| from app.user_modeling import UserModelingService |
|
|
# Root logging is configured at INFO level at import time; this module logs
# under its own dotted name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from the `.env` file located one directory above the
# directory that contains this module.
load_dotenv(Path(__file__).resolve().parents[1].joinpath(".env"))
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Optionally warm models before the application starts serving.

    Environment variables read:

    * ``SKIP_STARTUP_PREWARM`` -- ``1``/``true``/``yes`` (case-insensitive)
      skips prewarming entirely.
    * ``STARTUP_PREWARM`` -- what to warm (default ``all``):

      - ``none``/``off``/``0``/``skip``: nothing;
      - ``all``/``both``/``*``: shared weights and both services;
      - ``user_modeling``/``task_a``/``task1``/``1``: shared weights and the
        user-modeling service;
      - ``recommendation``/``task_b``/``task2``/``2``: shared weights and the
        recommendation service;
      - anything else: shared weights only, and a warning is logged.

    Any exception raised while warming is logged and swallowed, so the
    application still starts.

    Args:
        app: The FastAPI application; not used by this function.

    Yields:
        None, once the startup work has finished (or was skipped).
    """
    skip_flag = os.environ.get("SKIP_STARTUP_PREWARM", "").strip().lower()
    if skip_flag in ("1", "true", "yes"):
        yield
        return

    nothing = {"none", "off", "0", "skip"}
    everything = {"all", "both", "*"}
    user_modeling_modes = everything | {"user_modeling", "task_a", "task1", "1"}
    recommendation_modes = everything | {"recommendation", "task_b", "task2", "2"}

    mode = os.environ.get("STARTUP_PREWARM", "all").strip().lower()
    logger.info("Startup prewarm (STARTUP_PREWARM=%s) …", mode)
    if mode not in nothing | user_modeling_modes | recommendation_modes:
        # A typo such as "al" would otherwise quietly degrade to warming the
        # shared weights only; make that visible in the logs.
        logger.warning(
            "Unrecognised STARTUP_PREWARM=%r; warming shared weights only", mode
        )
    try:
        if mode not in nothing:
            warm_shared_weights()
        if mode in user_modeling_modes:
            user_modeling_service().warm()
        if mode in recommendation_modes:
            recommendation_service().warm()
        logger.info("Startup prewarm complete.")
    except Exception:
        # Best effort: startup continues even when warming fails.
        logger.exception(
            "Startup prewarm failed — first requests may be slow; set SKIP_STARTUP_PREWARM=1 to disable"
        )
    yield
|
|
|
|
# The ASGI application; `lifespan` runs the optional startup prewarm.
app = FastAPI(
    lifespan=lifespan,
    title="DSN X BCT — User modeling & Recommendation",
    version="1.0",
    description=(
        "Task 1 (User modeling): persona + product -> rating & review. "
        "Task 2 (Recommendation): persona -> personalised ranked items."
    ),
)

# Module-level slots for the two service instances. They start empty and are
# filled on first use by `user_modeling_service()` / `recommendation_service()`.
_um: UserModelingService | None = None
_rec: RecommendationService | None = None
|
|
|
|
def user_modeling_service() -> UserModelingService:
    """Return the module-wide ``UserModelingService``, building it on first call."""
    global _um
    if _um is not None:
        return _um
    _um = UserModelingService()
    return _um
|
|
|
|
def recommendation_service() -> RecommendationService:
    """Return the module-wide ``RecommendationService``, building it on first call."""
    global _rec
    if _rec is not None:
        return _rec
    _rec = RecommendationService()
    return _rec
|
|
|
|
class UserModelingRequest(BaseModel):
    # Request body for the user-modeling endpoints. `persona` and `product`
    # are free-text descriptions (the sample below uses "key: value" lines);
    # `include_raw` is forwarded to the service call unchanged.
    persona: str
    product: str
    include_raw: bool = False

    # Sample payload attached to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "user_id: stub_user_0\n"
                        "display_name: Jordan\n"
                        "yelping_since: 2018-06-01\n"
                        "reviews_written: 42\n"
                        "average_stars_across_reviews: 3.9\n"
                        "useful_votes_given: 12\n"
                        "funny_votes_given: 2\n"
                        "cool_votes_given: 4\n"
                        "fans: 1\n"
                        "elite_years: none\n"
                        "friend_count: 28"
                    ),
                    "product": (
                        "name: Riverfront Ramen\n"
                        "categories: Restaurants, Japanese, Ramen\n"
                        "location: Portland, OR\n"
                        "business_avg_stars: 4.1\n"
                        "business_review_count: 256\n"
                        "is_open: 1"
                    ),
                    "include_raw": False,
                }
            ]
        }
    )
|
|
|
|
class UserModelingResponse(BaseModel):
    # Response body for the user-modeling endpoints; built by validating the
    # mapping returned from the service's `generate` call.
    task: str
    agent_steps: list[str]
    rag_snippets_used: int = Field(default=0)
    # No default: the key must be present, but its value may be null.
    stars: int | None
    review: str
    parse_ok: bool
    # Optional; presumably populated only when the request set
    # `include_raw` -- TODO confirm against the service.
    raw: str | None = Field(default=None)
|
|
|
|
@app.post(
    "/user-modeling",
    tags=["Task 1 — User modeling"],
    response_model=UserModelingResponse,
)
@app.post("/task-1", tags=["Task 1 — User modeling"], response_model=UserModelingResponse)
async def user_modeling(req: UserModelingRequest) -> UserModelingResponse:
    # Runs the service's blocking `generate` call in a worker thread and
    # validates its result into the response model.
    # Error mapping: RuntimeError -> HTTP 503; any other exception is logged
    # with its traceback and reported as HTTP 500.
    started = time.perf_counter()
    logger.info("POST /user-modeling started")
    try:
        generate = user_modeling_service().generate
        payload = await asyncio.to_thread(
            generate, req.persona, req.product, include_raw=req.include_raw
        )
        elapsed = time.perf_counter() - started
        logger.info("POST /user-modeling finished in %.2fs", elapsed)
        return UserModelingResponse.model_validate(payload)
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("user_modeling")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
class ChatTurn(BaseModel):
    # One entry of `RecommendationRequest.chat_history`.
    # Both fields are required strings.
    role: str
    content: str
|
|
|
|
class RecommendationRequest(BaseModel):
    # Request body for the recommendation endpoints.
    persona: str
    # Optional filters passed through to the service as `city` / `state`.
    city: str | None = Field(default=None)
    state: str | None = Field(default=None)
    # Earlier conversation turns; defaults to a fresh empty list per request.
    chat_history: list[ChatTurn] = Field(default_factory=list)
    # Accepted range 5..200, default 20.
    top_k_retrieval: int = Field(default=20, ge=5, le=200)
    # Accepted range 1..25, default 5.
    top_n_final: int = Field(default=5, ge=1, le=25)

    # Sample payload attached to the generated JSON schema.
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "persona": (
                        "display_name: Morgan\n"
                        "reviews_written: 120\n"
                        "average_stars_across_reviews: 4.1\n"
                        "elite_years: none"
                    ),
                    "city": None,
                    "state": None,
                    "chat_history": [],
                    "top_k_retrieval": 30,
                    "top_n_final": 5,
                }
            ]
        }
    )
|
|
|
|
class RecommendationRank(BaseModel):
    # One entry of `RecommendationResponse.recommendations`.
    # All three fields are required.
    business_id: str
    rank: int
    rationale: str
|
|
|
|
class RecommendationResponse(BaseModel):
    # Response body for the recommendation endpoints; built by validating the
    # mapping returned from the service's `recommend` call.
    task: str
    agent_steps: list[str]
    candidates_considered: int
    recommendations: list[RecommendationRank]
|
|
|
|
@app.post("/recommendation", tags=["Task 2 — Recommendation"], response_model=RecommendationResponse)
@app.post("/task-2", tags=["Task 2 — Recommendation"], response_model=RecommendationResponse)
async def recommendation(req: RecommendationRequest) -> RecommendationResponse:
    # Runs the service's blocking `recommend` call in a worker thread and
    # validates its result into the response model.
    # Error mapping: FileNotFoundError / RuntimeError -> HTTP 503; any other
    # exception is logged with its traceback and reported as HTTP 500.
    started = time.perf_counter()
    logger.info("POST /recommendation started (top_k=%s top_n=%s)", req.top_k_retrieval, req.top_n_final)
    try:
        service = recommendation_service()
        # The service receives plain dicts rather than ChatTurn models.
        options = {
            "city": req.city,
            "state": req.state,
            "chat_history": [turn.model_dump() for turn in req.chat_history],
            "top_k_retrieval": req.top_k_retrieval,
            "top_n_final": req.top_n_final,
        }
        ranked = await asyncio.to_thread(service.recommend, req.persona, **options)
        elapsed = time.perf_counter() - started
        logger.info("POST /recommendation finished in %.2fs", elapsed)
        return RecommendationResponse.model_validate(ranked)
    except (FileNotFoundError, RuntimeError) as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("recommendation")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
@app.post("/task_a", response_model=UserModelingResponse, include_in_schema=False)
async def task_a_legacy(req: UserModelingRequest) -> UserModelingResponse:
    # Alias route excluded from the OpenAPI schema; delegates to `user_modeling`.
    response = await user_modeling(req)
    return response
|
|
|
|
@app.post("/task_b", response_model=RecommendationResponse, include_in_schema=False)
async def task_b_legacy(req: RecommendationRequest) -> RecommendationResponse:
    # Alias route excluded from the OpenAPI schema; delegates to `recommendation`.
    response = await recommendation(req)
    return response
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    # Health endpoint: returns a constant payload and does not touch the services.
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.get("/")
def root() -> dict[str, str | list[str]]:
    # Index endpoint: names the challenge and lists the public routes per task.
    index: dict[str, str | list[str]] = {"challenge": "DSN X BCT LLM Agent Challenge"}
    index["task_1_user_modeling"] = ["POST /user-modeling", "POST /task-1"]
    index["task_2_recommendation"] = ["POST /recommendation", "POST /task-2"]
    return index
|
|
|
|
if __name__ == "__main__":
    # Direct execution: serve `app` with uvicorn on all interfaces. The port
    # comes from the PORT environment variable and defaults to 8080.
    import uvicorn

    port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, port=port, host="0.0.0.0")
|
|