File size: 1,100 Bytes
128b0a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager

from config import settings
from rag_service import preload, rag_query, state
from schemas import QueryRequest, QueryResponse


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Startup/shutdown hook passed to ``FastAPI(lifespan=...)``.

    Code before the ``yield`` runs once at application startup. ``preload()``
    is called synchronously (it is not awaited), so it blocks the event loop
    until it returns. Control then passes to the running application. Nothing
    is done after the ``yield``, so there is no shutdown cleanup.

    Args:
        _app: The FastAPI application instance; unused.
    """
    # Presumably populates the shared ``state`` used by the endpoints -- verify in rag_service.
    preload()
    yield None


# ASGI application; ``lifespan`` runs ``preload()`` once at startup.
app = FastAPI(
    title=settings.app_title,
    lifespan=lifespan,
)


@app.get("/")
def root():
    """Status endpoint reporting what the service has loaded.

    Returns:
        dict with these keys:

        - ``message``: fixed string.
        - ``device``: ``state.device``.
        - ``model_runtime_device``: the device of the model's first parameter,
          as a string. It is ``None`` when no model is loaded or when the
          model has no parameters.
        - ``model_dtype``: ``str(state.model_dtype)``. If ``state.model_dtype``
          is ``None`` this is the string ``"None"``, not JSON null.
        - ``startup_timing``: ``state.startup_timing``.
    """
    # Read once so the None-check and the use below see the same object.
    model = state.model
    model_runtime_device = None
    if model is not None:
        # Presumably a torch ``nn.Module``. Pass a default to ``next`` because a
        # parameter-less module would otherwise raise StopIteration. That would
        # escape from the threadpool running this sync endpoint as an opaque
        # server error.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            model_runtime_device = str(first_param.device)
    return {
        "message": "RAG API is running",
        "device": state.device,
        "model_runtime_device": model_runtime_device,
        "model_dtype": str(state.model_dtype),
        "startup_timing": state.startup_timing,
    }


@app.post("/query", response_model=QueryResponse)
def query(payload: QueryRequest):
    """Run ``rag_query`` for ``payload.question``, forwarding ``payload.k`` as ``k``.

    The mapping returned by ``rag_query`` is expanded as keyword arguments
    into ``QueryResponse``.

    Raises:
        HTTPException: status 503 while any of ``state.index``,
            ``state.embedding_model`` or ``state.model`` is still ``None``.
    """
    # The generator is evaluated lazily and in order, so it stops at the first
    # missing component, the same short-circuit as an ``or`` chain.
    not_ready = any(
        getattr(state, component) is None
        for component in ("index", "embedding_model", "model")
    )
    if not_ready:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")
    response_fields = rag_query(payload.question, k=payload.k)
    return QueryResponse(**response_fields)