import logging
import time
from typing import Any

from fastapi import APIRouter, HTTPException

from backend.schemas import PredictRequest, PredictResponse
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever
|
|
| router = APIRouter() |
|
|
|
|
| @router.post("/predict", response_model=PredictResponse) |
| def predict(payload: PredictRequest) -> PredictResponse: |
| req_start = time.perf_counter() |
|
|
| precheck_start = time.perf_counter() |
| if not state: |
| raise HTTPException(status_code=503, detail="Service not initialized yet") |
|
|
| query = payload.query.strip() |
| if not query: |
| raise HTTPException(status_code=400, detail="Query cannot be empty") |
| precheck_time = time.perf_counter() - precheck_start |
|
|
| state_access_start = time.perf_counter() |
| retriever: HybridRetriever = state["retriever"] |
| index = state["index"] |
| rag_engine: RAGGenerator = state["rag_engine"] |
| models: dict[str, Any] = state["models"] |
| chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {}) |
| state_access_time = time.perf_counter() - state_access_start |
|
|
| model_resolve_start = time.perf_counter() |
| model_name, model_instance = resolve_model(payload.model, models) |
| model_resolve_time = time.perf_counter() - model_resolve_start |
|
|
| retrieval_start = time.perf_counter() |
| contexts = retriever.search( |
| query, |
| index, |
| chunking_technique=payload.chunking_technique, |
| mode=payload.mode, |
| rerank_strategy=payload.rerank_strategy, |
| use_mmr=payload.use_mmr, |
| lambda_param=payload.lambda_param, |
| top_k=payload.top_k, |
| final_k=payload.final_k, |
| verbose=False, |
| ) |
| retrieval_time = time.perf_counter() - retrieval_start |
|
|
| if not contexts: |
| raise HTTPException(status_code=404, detail="No context chunks retrieved for this query") |
|
|
| inference_start = time.perf_counter() |
| answer = rag_engine.get_answer(model_instance, query, contexts, temperature=payload.temperature) |
| inference_time = time.perf_counter() - inference_start |
|
|
| mapping_start = time.perf_counter() |
| retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup) |
| mapping_time = time.perf_counter() - mapping_start |
|
|
| total_time = time.perf_counter() - req_start |
|
|
| print( |
| f"Predict timing | model={model_name} | mode={payload.mode} | " |
| f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | " |
| f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | " |
| f"chunking={payload.chunking_technique} | " |
| f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | " |
| f"precheck={precheck_time:.3f}s | " |
| f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | " |
| f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | " |
| f"context_map={mapping_time:.3f}s | total={total_time:.3f}s" |
| ) |
|
|
| return PredictResponse( |
| model=model_name, |
| answer=answer, |
| contexts=contexts, |
| retrieved_chunks=retrieved_chunks, |
| ) |
|
|
|
|
|
|