import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


# All routes below are registered on this APIRouter; the router itself
# is mounted in api.py.
@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    req_start = time.perf_counter()
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))
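
    # Reject requests early: the service must be initialized and the
    # query non-empty before any heavy work starts.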
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start
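
    # Pull the shared retriever, index, generator, and model registry
    # from the application state populated at startup.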
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start
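
    # Resolve the requested model name to a loaded model instance.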
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start
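
    # Run hybrid retrieval before the response starts streaming, so a
    # retrieval failure can still surface as a plain HTTP error.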
    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start
    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")
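
    # Generator producing NDJSON events: one "status" event, one "token"
    # event per generated token, then a final "done" (or "error") event.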
    def stream_events():
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
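            # Announce that inference is starting and echo the retrieval
            # parameters so the client can display them immediately.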
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "chunk_score": chunk_score,
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )
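
            # Forward tokens as they arrive, recording time to first token.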
            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})
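
            # Assemble the final answer, trimming any incomplete tail left
            # by the max-token cutoff, and map contexts back to source chunks.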
            inference_time = time.perf_counter() - inference_start
            answer = rag_engine.truncate_incomplete_tail("".join(answer_parts))
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "chunk_score": chunk_score,
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )
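
            # One-line timing summary covering every stage of the request.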
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"chunk_score={chunk_score:.4f} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            # Headers are already sent once streaming starts, so failures
            # here are reported as a structured "error" event rather than
            # an HTTP error status.
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})
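
    # Stream NDJSON with caching and proxy buffering disabled
    # (X-Accel-Buffering is the nginx buffering switch) so each event
    # reaches the client as soon as it is yielded.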
    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )