| from fastapi import FastAPI, HTTPException |
| from contextlib import asynccontextmanager |
|
|
| from config import settings |
| from rag_service import preload, rag_query, state |
| from schemas import QueryRequest, QueryResponse |
|
|
|
|
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """FastAPI lifespan hook.

    Code before the ``yield`` runs once at application startup; here that is
    a single call to ``preload()``. Nothing is done after the ``yield``, so
    shutdown performs no explicit cleanup.
    """
    # preload() is a plain synchronous call and therefore runs on the event
    # loop thread; that is acceptable because no requests are served until
    # startup has finished.
    preload()
    yield None
|
|
|
|
# The ASGI application; `lifespan` runs preload() once before requests are served.
app = FastAPI(lifespan=lifespan, title=settings.app_title)
|
|
|
|
@app.get("/")
def root():
    """Report service liveness plus model/device diagnostics.

    Returns a dict with:
      - ``message``: fixed liveness string.
      - ``device``: ``state.device`` as stored by the service.
      - ``model_runtime_device``: device of the model's first parameter, or
        ``None`` if no model is loaded or the model has no parameters.
      - ``model_dtype``: ``str(state.model_dtype)``, or ``None`` when unset.
      - ``startup_timing``: ``state.startup_timing`` as stored by the service.
    """
    # Read the model once so the None check and the parameter lookup below
    # operate on the same object even if the shared state is reassigned.
    model = state.model
    model_runtime_device = None
    if model is not None:
        # NOTE(review): assumes a torch nn.Module-like object whose
        # parameters() returns an iterator of tensors with a `.device`.
        # next() with a default avoids StopIteration on a parameter-less model;
        # a StopIteration escaping a sync endpoint does not surface as a clean
        # HTTP error (asyncio futures refuse to carry StopIteration).
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            model_runtime_device = str(first_param.device)
    model_dtype = state.model_dtype
    return {
        "message": "RAG API is running",
        "device": state.device,
        "model_runtime_device": model_runtime_device,
        # Report null rather than the misleading string "None" when unset,
        # matching how model_runtime_device is reported.
        "model_dtype": None if model_dtype is None else str(model_dtype),
        "startup_timing": state.startup_timing,
    }
|
|
|
|
@app.post("/query", response_model=QueryResponse)
def query(payload: QueryRequest):
    """Answer a question with the RAG pipeline.

    Responds with HTTP 503 while any of ``state.index``,
    ``state.embedding_model`` or ``state.model`` is still ``None``.
    Otherwise it forwards ``payload.question`` and ``payload.k`` to
    ``rag_query`` and wraps the returned mapping in a ``QueryResponse``.
    """
    # All three components must be loaded before a query can be served;
    # the checks short-circuit in the same order as the original condition.
    pipeline_ready = (
        state.index is not None
        and state.embedding_model is not None
        and state.model is not None
    )
    if not pipeline_ready:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")
    answer = rag_query(payload.question, k=payload.k)
    return QueryResponse(**answer)