import logging
import time
from typing import Any

from fastapi import APIRouter, HTTPException

from backend.schemas import PredictRequest, PredictResponse
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

# Module-level logger instead of print(): timing lines flow through the
# application's logging configuration (handlers, levels, formatters).
logger = logging.getLogger(__name__)

router = APIRouter()


@router.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run the full RAG pipeline for a single query.

    Pipeline stages (each individually timed for observability):
    validate the request, pull shared components from application state,
    resolve the requested model, retrieve context chunks, generate an
    answer, and map the retrieved contexts back to chunk metadata.

    Args:
        payload: Validated request body (query, model choice, and all
            retrieval/generation knobs).

    Returns:
        PredictResponse with the resolved model name, generated answer,
        raw contexts, and enriched retrieved-chunk records.

    Raises:
        HTTPException: 503 if application state is not initialized,
            400 if the query is empty after stripping whitespace,
            404 if retrieval produced no context chunks.
    """
    req_start = time.perf_counter()

    # --- request prechecks --------------------------------------------------
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # --- pull shared components out of application state --------------------
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    # chunk_lookup is optional in state; fall back to an empty mapping.
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    # --- resolve the requested model ----------------------------------------
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # --- hybrid retrieval ----------------------------------------------------
    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # --- answer generation ---------------------------------------------------
    inference_start = time.perf_counter()
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=payload.temperature)
    inference_time = time.perf_counter() - inference_start

    # --- map retrieved contexts back to chunk metadata -----------------------
    mapping_start = time.perf_counter()
    retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start
    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info(
        "Predict timing | model=%s | mode=%s | rerank=%s | use_mmr=%s | "
        "lambda=%.2f | temp=%.2f | chunking=%s | "
        "top_k=%s | final_k=%s | returned=%d | "
        "chunk_score=%.4f | "
        "precheck=%.3fs | "
        "state_access=%.3fs | model_resolve=%.3fs | "
        "retrieval=%.3fs | inference=%.3fs | "
        "context_map=%.3fs | total=%.3fs",
        model_name,
        payload.mode,
        payload.rerank_strategy,
        payload.use_mmr,
        payload.lambda_param,
        payload.temperature,
        payload.chunking_technique,
        payload.top_k,
        payload.final_k,
        len(contexts),
        chunk_score,
        precheck_time,
        state_access_time,
        model_resolve_time,
        retrieval_time,
        inference_time,
        mapping_time,
        total_time,
    )

    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )