import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

# All routes below are registered on this APIRouter, which api.py includes
# into the main FastAPI application.
router = APIRouter()


@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Stream a RAG answer as NDJSON events: status, token*, then done (or error)."""
    req_start = time.perf_counter()
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))

    # Validate service readiness and input before touching any heavy machinery.
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval runs once, up front, so a failure here surfaces as a normal
    # HTTP error instead of a mid-stream one.
    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # The same debug payload rides on both the "status" and "done" events,
    # so build it once.
    retrieval_debug = {
        "requested_chunking_technique": payload.chunking_technique,
        "requested_top_k": payload.top_k,
        "requested_final_k": payload.final_k,
        "returned_context_count": len(contexts),
        "chunk_score": chunk_score,
        "use_mmr": payload.use_mmr,
        "lambda_param": payload.lambda_param,
    }

    def stream_events():
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": retrieval_debug,
                }
            )
            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})
            inference_time = time.perf_counter() - inference_start

            answer = rag_engine.truncate_incomplete_tail("".join(answer_parts))
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": retrieval_debug,
                }
            )
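
            # One timing line per request keeps the stage-by-stage latency
            # breakdown (precheck, state access, model resolve, retrieval,
            # first token, inference) greppable in stdout.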
            total_time = time.perf_counter() - req_start
            # -1 marks "no token was ever produced" in the log.
            first_token_s = first_token_latency if first_token_latency is not None else -1.0
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"chunk_score={chunk_score:.4f} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_s:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            # The response has already started streaming, so surface the error
            # as a final NDJSON event rather than an HTTP status code.
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            # Disable proxy buffering (e.g. nginx) so tokens flush to the
            # client as soon as they are yielded.
            "X-Accel-Buffering": "no",
        },
    )
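
# Example client (illustrative only): a minimal sketch of how a caller might
# consume this endpoint, assuming the service listens on http://localhost:8000
# (hypothetical) and that the other PredictRequest fields have defaults, so a
# body of {"query": ...} alone is valid. It uses the third-party `requests`
# library, which the app itself does not need, hence the commented-out form.
#
#   import json
#   import requests
#
#   def stream_prediction(query: str, base_url: str = "http://localhost:8000") -> str:
#       """Print tokens as they arrive and return the final answer."""
#       answer = ""
#       with requests.post(
#           f"{base_url}/predict/stream", json={"query": query}, stream=True
#       ) as resp:
#           resp.raise_for_status()
#           for line in resp.iter_lines():  # one NDJSON event per line
#               if not line:
#                   continue
#               event = json.loads(line)
#               if event["type"] == "token":
#                   print(event["token"], end="", flush=True)
#               elif event["type"] == "done":
#                   answer = event["answer"]
#               elif event["type"] == "error":
#                   raise RuntimeError(event["message"])
#       return answer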