# per_new / main.py
# Uploaded by Pujan-Dev ("Upload 8 files"), commit 128b0a8 (verified).
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from config import settings
from rag_service import preload, rag_query, state
from schemas import QueryRequest, QueryResponse
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Application lifespan hook: warm up RAG resources before serving.

    Everything before the ``yield`` runs once at startup; control returns
    after the ``yield`` at shutdown, where nothing is currently released.
    Note that ``preload()`` is a plain synchronous call, so the event loop
    is blocked until it finishes.
    """
    # Startup phase.
    preload()
    # Serving phase; no teardown work follows.
    yield None


# The lifespan hook is wired in here so preload() runs before any request.
app = FastAPI(title=settings.app_title, lifespan=lifespan)
@app.get("/")
def root():
    """Report service status plus basic model/runtime details.

    Returns a JSON object with:
      - ``message``: a fixed liveness string,
      - ``device``: ``state.device`` as stored in the shared RAG state,
      - ``model_runtime_device``: the device of the model's first parameter,
        or ``None`` if no model is loaded or it has no parameters,
      - ``model_dtype``: ``str(state.model_dtype)``, or ``None`` if unset,
      - ``startup_timing``: ``state.startup_timing`` as recorded in the state.
    """
    model_runtime_device = None
    if state.model is not None:
        # Supply a default to next() so a parameter-less model does not raise
        # StopIteration (which would otherwise surface as an HTTP 500).
        first_param = next(iter(state.model.parameters()), None)
        if first_param is not None:
            model_runtime_device = str(first_param.device)

    # Report an unset dtype as JSON null, consistent with
    # model_runtime_device, rather than the literal string "None".
    model_dtype = None if state.model_dtype is None else str(state.model_dtype)

    return {
        "message": "RAG API is running",
        "device": state.device,
        "model_runtime_device": model_runtime_device,
        "model_dtype": model_dtype,
        "startup_timing": state.startup_timing,
    }
@app.post("/query", response_model=QueryResponse)
def query(payload: QueryRequest):
    """Answer ``payload.question`` through the RAG pipeline.

    Retrieves ``payload.k`` results via ``rag_query`` and wraps its returned
    mapping in a ``QueryResponse``.

    Raises:
        HTTPException: 503 while the index, embedding model, or model held
            in the shared state is still missing.
    """
    # Same short-circuit order as a chain of `is None` checks.
    ready = (
        state.index is not None
        and state.embedding_model is not None
        and state.model is not None
    )
    if not ready:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    answer = rag_query(payload.question, k=payload.k)
    return QueryResponse(**answer)