# medimind-api / main.py
# Commit d0c827a by Manikantaperla: "initial medimind backend"
import logging
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from model import QAPipeline
from retriever import MedicalRetriever
# Absolute directory of this module. Artifact paths are anchored here rather
# than to the current working directory.
BASE_DIR = Path(__file__).resolve().parent
# Retriever artifact read at startup via MedicalRetriever.load.
# NOTE(review): the ".pkl" suffix suggests a pickle; only ship trusted artifacts.
RETRIEVER_PATH = BASE_DIR.joinpath("artifacts", "retriever.pkl")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared retriever and QA pipeline once, before serving requests.

    Both objects are attached to ``app.state`` (as ``retriever`` and
    ``pipeline``) so request handlers can reach them. The pipeline is also given
    a ``retriever_ref`` attribute pointing at the same retriever instance.
    Nothing is cleaned up after the ``yield``; the objects live until the
    process exits.
    """
    shared = app.state
    # MedicalRetriever.load is given the artifact path as a plain string.
    retriever = MedicalRetriever.load(str(RETRIEVER_PATH))
    shared.retriever = retriever
    pipeline = QAPipeline()
    shared.pipeline = pipeline
    # Give the pipeline a handle on the very same retriever object.
    pipeline.retriever_ref = retriever
    print("Server ready")
    yield
# ASGI application; startup work is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan)

# CORS is fully open: any origin, HTTP method, and request header is accepted.
# allow_credentials is not passed here, so credentialed cross-origin requests
# are not enabled.
# NOTE(review): restrict allow_origins if this API should only serve known
# frontends.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class QuestionRequest(BaseModel):
    """JSON body accepted by ``POST /predict``: one free-text question."""

    # Raw user question; the /predict handler strips surrounding whitespace and
    # rejects it with HTTP 400 if nothing remains.
    question: str
@app.get("/health")
async def health() -> dict:
    """Liveness probe reporting the model label and loaded corpus size.

    Returns a dict with ``status`` (always "ok"), a fixed ``model`` label, and
    ``passages``, the number of entries in the retriever's ``corpus``.
    """
    corpus = app.state.retriever.corpus
    return dict(status="ok", model="bert+flan-t5", passages=len(corpus))
@app.post("/predict")
async def predict(payload: QuestionRequest) -> dict:
    """Answer a question using passages fetched by the shared retriever.

    The question is stripped of surrounding whitespace. The 5 top passages from
    ``app.state.retriever.retrieve`` are passed with the question to
    ``app.state.pipeline.answer``. Its result is returned as the response body.

    Raises:
        HTTPException: 400 if the question is empty after stripping.

    Any exception raised during retrieval or answering is logged with its
    traceback and converted into an HTTP 500 JSON response of the form
    ``{"error": "<exception message>"}``.
    """
    question = payload.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    # NOTE(review): retrieval and generation run synchronously inside this async
    # handler. They therefore block the event loop, including /health, while a
    # request is being answered. Consider offloading to a worker thread once
    # MedicalRetriever and QAPipeline are confirmed safe for concurrent use.
    try:
        passages = app.state.retriever.retrieve(question, top_k=5)
        result = app.state.pipeline.answer(question, passages)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server logs instead of
        # discarding it. The question text itself is deliberately not logged.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): the raw exception message is still returned to the
        # client for backward compatibility; it may expose internal details.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return result
if __name__ == "__main__":
    # Direct invocation (`python main.py`): listen on every interface, port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)