Spaces:
Running
Running
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from model import QAPipeline | |
| from retriever import MedicalRetriever | |
# Absolute directory that contains this module.
BASE_DIR = Path(__file__).resolve().parent
# Location of the serialized retriever artifact, relative to this module.
RETRIEVER_PATH = BASE_DIR.joinpath("artifacts", "retriever.pkl")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the retriever and QA pipeline once and attach them to ``app.state``.

    The ``asynccontextmanager`` decorator turns this async generator into the
    async context manager that ``FastAPI(lifespan=...)`` expects. Everything
    before ``yield`` runs during startup; there is no teardown work after it.

    Args:
        app: Application whose ``state`` receives ``retriever`` and
            ``pipeline``.

    Yields:
        None. Control returns to the server while the application is running.
    """
    # NOTE(review): the ".pkl" suffix suggests a pickle-based artifact. If so,
    # loading it can execute arbitrary code -- only ship trusted artifacts.
    retriever = MedicalRetriever.load(str(RETRIEVER_PATH))
    app.state.retriever = retriever

    pipeline = QAPipeline()
    app.state.pipeline = pipeline
    # Give the pipeline a handle on the same retriever instance the app uses.
    pipeline.retriever_ref = retriever

    print("Server ready")
    yield
# Application instance; startup work is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan)

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class QuestionRequest(BaseModel):
    """Request body carrying a single question string."""

    # Question text exactly as submitted by the client.
    question: str
async def health() -> dict:
    """Return a status payload that includes the retriever's corpus size."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm it is registered with the application.
    corpus = app.state.retriever.corpus
    return {
        "status": "ok",
        "model": "bert+flan-t5",
        "passages": len(corpus),
    }
async def predict(payload: QuestionRequest) -> dict:
    """Answer a question from the passages the retriever returns for it.

    Args:
        payload: Request body holding the question text.

    Returns:
        The value produced by ``app.state.pipeline.answer`` on success. If
        retrieval or answering raises, a ``JSONResponse`` with status 500 and
        an ``{"error": ...}`` body is returned instead.

    Raises:
        HTTPException: With status 400 when the question is empty after
            surrounding whitespace is stripped.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view -- confirm it is registered with the application.
    import logging

    question = payload.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        passages = app.state.retriever.retrieve(question, top_k=5)
        return app.state.pipeline.answer(question, passages)
    except Exception as e:
        # Keep the traceback server-side; it was previously discarded, which
        # left failures with no record to diagnose from.
        logging.getLogger(__name__).exception("Prediction failed")
        # Response shape is kept as-is for existing clients.
        # NOTE(review): str(e) exposes internal error text to the caller --
        # consider a generic message once clients no longer depend on it.
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # Run the server directly when this file is executed as a script,
    # listening on every interface.
    # NOTE(review): port 7860 is presumably what the hosting platform
    # expects -- confirm before changing.
    uvicorn.run(app, port=7860, host="0.0.0.0")