import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()

@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Stream a RAG answer as NDJSON: one "status" event, then "token" events, then "done"."""
    req_start = time.perf_counter()
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))
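
    # Fail fast: 503 until startup has populated shared state, 400 for an empty query.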
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start
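
    # Pull the shared components out of app state; chunk_lookup is optional.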
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start
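
    # Map the requested model name onto a loaded model instance.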
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start
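
    # Hybrid retrieval over the index; reranking and MMR diversification are
    # controlled per-request by the payload.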
    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # The same retrieval diagnostics are reported in both the "status" and
    # "done" events, so build the payload once.
    retrieval_debug = {
        "requested_chunking_technique": payload.chunking_technique,
        "requested_top_k": payload.top_k,
        "requested_final_k": payload.final_k,
        "returned_context_count": len(contexts),
        "chunk_score": chunk_score,
        "use_mmr": payload.use_mmr,
        "lambda_param": payload.lambda_param,
    }

    def stream_events():
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
            # Emit a status event before the first token so clients get the
            # retrieval diagnostics up front.
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": retrieval_debug,
                }
            )
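
            # Relay tokens as they arrive, recording time-to-first-token.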
            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})
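
            # Stream finished: trim any incomplete tail from the answer and
            # attach full metadata for the retrieved chunks.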
            inference_time = time.perf_counter() - inference_start
            answer = rag_engine.truncate_incomplete_tail("".join(answer_parts))
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)

            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": retrieval_debug,
                }
            )

            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"chunk_score={chunk_score:.4f} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            # The HTTP status is already committed once streaming starts, so
            # failures are reported in-band as an "error" event.
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})
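
    # "X-Accel-Buffering: no" keeps nginx-style proxies from buffering the
    # response, so tokens reach the client as soon as they are yielded.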
    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )