"""HTTP entry point for the RAG API.

Exposes a status endpoint (``GET /``) and a question-answering endpoint
(``POST /query``) on top of the shared ``state`` object and helpers
imported from ``rag_service``.
"""

from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException

from config import settings
from rag_service import preload, rag_query, state
from schemas import QueryRequest, QueryResponse


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Run ``preload()`` once before the application starts serving.

    Nothing is executed after the ``yield``, so there is no shutdown work.
    NOTE(review): ``preload`` is presumably what populates ``state``
    (index / embedding model / model) -- confirm in ``rag_service``.
    """
    preload()
    yield


app = FastAPI(title=settings.app_title, lifespan=lifespan)


@app.get("/")
def root():
    """Return a status payload describing the currently loaded model.

    ``model_runtime_device`` is the device of the model's first parameter,
    or ``None`` when no model is loaded or the model exposes no parameters.
    """
    model_runtime_device = None
    if state.model is not None:
        # Supply a default to next(): a model that yields no parameters
        # would otherwise raise StopIteration and turn this status
        # endpoint into a 500.
        first_param = next(state.model.parameters(), None)
        if first_param is not None:
            model_runtime_device = str(first_param.device)

    return {
        "message": "RAG API is running",
        "device": state.device,
        "model_runtime_device": model_runtime_device,
        "model_dtype": str(state.model_dtype),
        "startup_timing": state.startup_timing,
    }


@app.post("/query", response_model=QueryResponse)
def query(payload: QueryRequest) -> QueryResponse:
    """Answer ``payload.question`` via ``rag_query`` with ``k=payload.k``.

    Raises:
        HTTPException: 503 when ``state.index``, ``state.embedding_model``
            or ``state.model`` is still ``None``.
    """
    required = (state.index, state.embedding_model, state.model)
    if any(component is None for component in required):
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # rag_query's result is unpacked as keyword arguments, so it is expected
    # to be a mapping whose keys match QueryResponse's fields.
    result = rag_query(payload.question, k=payload.k)
    return QueryResponse(**result)