import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


# All routes in this module are registered on the APIRouter above, which
# is included by the FastAPI app in api.py.
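#
# A minimal sketch of how that mounting might look in api.py (the app
# setup and this module's import path are assumptions, not verified
# against api.py):
#
#     from fastapi import FastAPI
#
#     from backend.routes.predict_stream import router as predict_router
#
#     app = FastAPI()
#     app.include_router(predict_router)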


@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
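    """Stream a RAG answer for ``payload.query`` as NDJSON events.

    The response is a sequence of newline-delimited JSON objects: a
    ``status`` event once retrieval has finished, one ``token`` event per
    generated token, and a final ``done`` event carrying the assembled
    answer and retrieved chunks (or an ``error`` event if streaming fails).
    """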
    req_start = time.perf_counter()
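    # Cap on tokens generated per request; tune by setting the
    # STREAM_MAX_TOKENS environment variable (e.g. STREAM_MAX_TOKENS=800)
    # before starting the server.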
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))

    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    retrieval_start = time.perf_counter()
    contexts, chunk_score = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    def stream_events():
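        # Wire protocol: one `status` line, zero or more `token` lines, then
        # exactly one `done` or `error` line. Each line is a standalone JSON
        # object, e.g. {"type": "token", "token": "Hi"}.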
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "chunk_score": chunk_score,
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})

            inference_time = time.perf_counter() - inference_start
            answer = rag_engine.truncate_incomplete_tail("".join(answer_parts))
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)

            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "chunk_score": chunk_score,
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"chunk_score={chunk_score:.4f} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
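

# A sketch of how a client might consume this endpoint. The base URL and the
# httpx dependency are assumptions for illustration, and the payload relies on
# the other PredictRequest fields having defaults:
#
#     import json
#
#     import httpx
#
#     with httpx.stream(
#         "POST",
#         "http://localhost:8000/predict/stream",
#         json={"query": "What is hybrid retrieval?"},
#         timeout=None,
#     ) as response:
#         for line in response.iter_lines():
#             if not line:
#                 continue
#             event = json.loads(line)
#             if event["type"] == "token":
#                 print(event["token"], end="", flush=True)
#             elif event["type"] == "done":
#                 print()
#             elif event["type"] == "error":
#                 raise RuntimeError(event["message"])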