import time
from typing import Any
from fastapi import APIRouter, HTTPException
from backend.schemas import PredictRequest, PredictResponse
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever
router = APIRouter()
@router.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run the full RAG pipeline for one query and return the answer.

    Pipeline: validate input -> pull shared components out of app state ->
    resolve the requested model -> hybrid retrieval -> LLM inference ->
    map contexts back to their source chunks. Each stage is individually
    timed and the breakdown is printed for ad-hoc latency inspection.

    Args:
        payload: Request body with the query plus retrieval/generation knobs
            (model, mode, rerank strategy, MMR settings, top_k/final_k,
            temperature, chunking technique).

    Returns:
        PredictResponse with the resolved model name, generated answer,
        raw contexts, and enriched retrieved-chunk metadata.

    Raises:
        HTTPException: 503 if shared state is uninitialized or missing a
            required component, 400 for an empty query, 404 when retrieval
            returns no context chunks.
    """
    req_start = time.perf_counter()

    # --- Input prechecks -------------------------------------------------
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # --- Shared-state access ---------------------------------------------
    # A partially initialized state dict previously escaped as an unhandled
    # KeyError (HTTP 500); report it as 503 like the empty-state case above.
    state_access_start = time.perf_counter()
    try:
        retriever: HybridRetriever = state["retriever"]
        index = state["index"]
        rag_engine: RAGGenerator = state["rag_engine"]
        models: dict[str, Any] = state["models"]
    except KeyError as exc:
        raise HTTPException(
            status_code=503,
            detail=f"Service state is missing component: {exc.args[0]!r}",
        ) from exc
    # chunk_lookup is optional; mapping degrades gracefully without it.
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    # --- Model resolution --------------------------------------------------
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # --- Hybrid retrieval --------------------------------------------------
    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # --- LLM inference -----------------------------------------------------
    inference_start = time.perf_counter()
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=payload.temperature)
    inference_time = time.perf_counter() - inference_start

    # --- Context -> chunk mapping -----------------------------------------
    mapping_start = time.perf_counter()
    retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start
    # NOTE(review): plain print keeps the original observable behavior;
    # consider switching to the logging module once app-level logging is
    # configured. chunk_score is assumed to be a float — confirm retriever
    # contract before changing this format string.
    print(
        f"Predict timing | model={model_name} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
        f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
        f"chunking={payload.chunking_technique} | "
        f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
        f"chunk_score={chunk_score:.4f} | "
        f"precheck={precheck_time:.3f}s | "
        f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
        f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )

    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )
|