# NLP-RAG/backend/routes/predict_stream.py
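"""Streaming prediction endpoint.

POST /predict/stream validates the request, runs hybrid retrieval, and
streams the generated answer back as newline-delimited JSON (NDJSON):
a "status" event with retrieval debug info, one "token" event per
generated token, and a terminal "done" event with the full answer and
retrieved chunks (or an "error" event if generation fails mid-stream).

Example request (host and port are illustrative, and the remaining
PredictRequest fields are assumed to have server-side defaults):

    curl -N http://localhost:8000/predict/stream -H 'Content-Type: application/json' -d '{"query": "What is hybrid retrieval?"}'
"""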
import os
import time
from typing import Any
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

# Each route module defines an APIRouter object like this one; api.py
# includes them all into the main FastAPI app.
router = APIRouter()


@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
req_start = time.perf_counter()
stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))
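
    # Cheap prechecks before any heavy work: the app state must be
    # initialized and the query must be non-empty.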
precheck_start = time.perf_counter()
if not state:
raise HTTPException(status_code=503, detail="Service not initialized yet")
query = payload.query.strip()
if not query:
raise HTTPException(status_code=400, detail="Query cannot be empty")
precheck_time = time.perf_counter() - precheck_start
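
    # Pull the shared retrieval/generation components out of app state.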
state_access_start = time.perf_counter()
retriever: HybridRetriever = state["retriever"]
index = state["index"]
rag_engine: RAGGenerator = state["rag_engine"]
models: dict[str, Any] = state["models"]
chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
state_access_time = time.perf_counter() - state_access_start
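
    # Resolve the requested model name to a loaded model instance.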
model_resolve_start = time.perf_counter()
model_name, model_instance = resolve_model(payload.model, models)
model_resolve_time = time.perf_counter() - model_resolve_start
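
    # Hybrid retrieval: gather context chunks for the query using the
    # requested chunking, reranking, and MMR settings.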
retrieval_start = time.perf_counter()
contexts, chunk_score = retriever.search(
query,
index,
chunking_technique=payload.chunking_technique,
mode=payload.mode,
rerank_strategy=payload.rerank_strategy,
use_mmr=payload.use_mmr,
lambda_param=payload.lambda_param,
top_k=payload.top_k,
final_k=payload.final_k,
verbose=False,
)
retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # Debug payload echoed in both the "status" and "done" events.
    retrieval_debug = {
        "requested_chunking_technique": payload.chunking_technique,
        "requested_top_k": payload.top_k,
        "requested_final_k": payload.final_k,
        "returned_context_count": len(contexts),
        "chunk_score": chunk_score,
        "use_mmr": payload.use_mmr,
        "lambda_param": payload.lambda_param,
    }
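
    # Generator yielding NDJSON events, one JSON object per line:
    #   {"type": "status", ...}  -> sent once before generation starts
    #   {"type": "token", ...}   -> one per generated token
    #   {"type": "done", ...}    -> full answer plus retrieved chunks
    #   {"type": "error", ...}   -> emitted if generation fails mid-stream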
def stream_events():
inference_start = time.perf_counter()
        first_token_latency: float | None = None
answer_parts: list[str] = []
try:
yield to_ndjson(
{
"type": "status",
"stage": "inference_start",
"model": model_name,
"retrieval_s": round(retrieval_time, 3),
"retrieval_debug": {
"requested_chunking_technique": payload.chunking_technique,
"requested_top_k": payload.top_k,
"requested_final_k": payload.final_k,
"returned_context_count": len(contexts),
"chunk_score": chunk_score,
"use_mmr": payload.use_mmr,
"lambda_param": payload.lambda_param,
},
}
)
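
            # Stream tokens as they arrive, recording time-to-first-token.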
for token in rag_engine.get_answer_stream(
model_instance,
query,
contexts,
temperature=payload.temperature,
max_tokens=stream_max_tokens,
):
if first_token_latency is None:
first_token_latency = time.perf_counter() - inference_start
answer_parts.append(token)
yield to_ndjson({"type": "token", "token": token})
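            # Generation finished: record inference time, trim any incomplete
            # trailing fragment, and attach source-chunk metadata.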
inference_time = time.perf_counter() - inference_start
answer = rag_engine.truncate_incomplete_tail("".join(answer_parts))
retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
yield to_ndjson(
{
"type": "done",
"model": model_name,
"answer": answer,
"contexts": contexts,
"retrieved_chunks": retrieved_chunks,
"retrieval_debug": {
"requested_chunking_technique": payload.chunking_technique,
"requested_top_k": payload.top_k,
"requested_final_k": payload.final_k,
"returned_context_count": len(contexts),
"chunk_score": chunk_score,
"use_mmr": payload.use_mmr,
"lambda_param": payload.lambda_param,
},
}
)
total_time = time.perf_counter() - req_start
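            # One-line timing summary (stdout) covering every request stage.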
print(
f"Predict stream timing | model={model_name} | mode={payload.mode} | "
f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
f"chunking={payload.chunking_technique} | "
f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
f"chunk_score={chunk_score:.4f} | "
f"precheck={precheck_time:.3f}s | "
f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
f"max_tokens={stream_max_tokens}"
)
except Exception as exc:
total_time = time.perf_counter() - req_start
print(
f"Predict stream error | model={model_name} | mode={payload.mode} | "
f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
)
yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})
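
    # NDJSON keeps each event on its own line; "X-Accel-Buffering: no"
    # disables buffering in nginx-style reverse proxies so tokens reach
    # the client as soon as they are yielded.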
return StreamingResponse(
stream_events(),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)