"""FastAPI service exposing ``/health`` and ``/predict`` question-answering endpoints."""

import logging
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from model import QAPipeline
from retriever import MedicalRetriever

logger = logging.getLogger(__name__)

# Resolve artifacts relative to this file so the server can be started from
# any working directory.
BASE_DIR = Path(__file__).resolve().parent
RETRIEVER_PATH = BASE_DIR / "artifacts" / "retriever.pkl"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the retriever and QA pipeline once at startup.

    Both objects are stored on ``app.state`` so request handlers can reuse
    them. The pipeline is also handed a reference to the retriever through
    its ``retriever_ref`` attribute.
    """
    app.state.retriever = MedicalRetriever.load(str(RETRIEVER_PATH))
    app.state.pipeline = QAPipeline()
    app.state.pipeline.retriever_ref = app.state.retriever
    print("Server ready")
    yield


app = FastAPI(lifespan=lifespan)

# NOTE(review): CORS is fully open (any origin, method and header); confirm
# this is intended for the deployment target.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class QuestionRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # Question text supplied by the client.
    question: str


@app.get("/health")
async def health() -> dict:
    """Return a status payload including the size of the retriever corpus."""
    passages_count = len(app.state.retriever.corpus)
    return {"status": "ok", "model": "bert+flan-t5", "passages": passages_count}


@app.post("/predict")
async def predict(payload: QuestionRequest) -> dict:
    """Answer a question using retrieved passages.

    The question is stripped of surrounding whitespace, the top five
    passages are retrieved for it, and both are passed to the QA pipeline.

    Returns:
        Whatever ``QAPipeline.answer`` returns on success. On any failure in
        retrieval or answering, a ``JSONResponse`` with status 500 and body
        ``{"error": <exception text>}``.

    Raises:
        HTTPException: 400 when the question is empty after stripping.
    """
    question = payload.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    # NOTE(review): retrieval and answering run directly on the event loop;
    # if they are CPU-heavy they will block other requests while running.
    try:
        passages = app.state.retriever.retrieve(question, top_k=5)
        result = app.state.pipeline.answer(question, passages)
    except Exception as e:
        # Record the traceback server-side; previously the failure was only
        # visible in the client's response body.
        logger.exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})
    else:
        return result


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)