File size: 16,748 Bytes
4ef165a
bbe01fe
26b51db
bbe01fe
 
 
 
 
 
 
 
 
 
 
9563e4a
 
 
 
 
 
8da917e
9563e4a
65543f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
1d47e3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c9ee6
 
 
 
 
 
 
 
efdd22e
e7c9ee6
d1766f7
 
 
 
 
e7c9ee6
d1766f7
 
 
 
 
 
 
 
 
e7c9ee6
 
d1766f7
 
 
 
 
 
 
e7c9ee6
 
 
 
 
d1766f7
 
 
 
e7c9ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ef165a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
 
 
 
 
efdd22e
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
65543f1
e7c9ee6
65543f1
 
 
4ef165a
65543f1
 
 
bbe01fe
4ef165a
 
 
 
 
 
 
 
 
 
 
 
26b51db
 
 
 
 
 
 
 
 
 
 
4ef165a
 
 
 
9563e4a
4ef165a
 
 
 
 
26b51db
 
 
 
9563e4a
26b51db
 
 
bbe01fe
 
 
 
26b51db
 
 
 
 
bbe01fe
 
 
 
 
 
 
8c8aea8
65543f1
 
bbe01fe
 
 
e7c9ee6
 
 
efdd22e
 
4ef165a
 
 
 
 
 
 
 
 
 
0da0699
 
26b51db
 
bbe01fe
 
 
 
 
 
 
 
 
9563e4a
 
 
efdd22e
 
 
9563e4a
efdd22e
 
9563e4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
9563e4a
 
bbe01fe
 
efdd22e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
 
 
1d47e3c
 
 
26b51db
bbe01fe
 
 
 
 
 
 
efdd22e
 
 
 
 
bbe01fe
e7c9ee6
 
 
 
 
 
 
efdd22e
e7c9ee6
4ef165a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbe01fe
efdd22e
bbe01fe
 
 
 
9563e4a
 
 
 
 
bbe01fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
import asyncio
import json
import re
import time
from fastapi import APIRouter, Request, Depends
from fastapi.responses import StreamingResponse

from app.models.chat import ChatRequest
from app.models.pipeline import PipelineState
from app.security.rate_limiter import chat_rate_limit
from app.security.jwt_auth import verify_jwt

router = APIRouter()

# Keep-alive interval for SSE when upstream nodes are still working.
# Prevents edge/proxy idle timeouts on long retrieval/generation turns.
_SSE_HEARTBEAT_SECONDS: float = 10.0

# Query pre-processing budgets must stay low to avoid delaying first byte.
# Both tasks start concurrently with Guard; these caps bound how long the
# endpoint waits on each before falling back to the raw query.
_DECONTEXT_TIMEOUT_SECONDS: float = 0.35
_EXPANSION_TIMEOUT_SECONDS: float = 0.60

# Phrases a visitor uses when telling the bot it gave a wrong answer.
# Matched via substring test on the lowercased raw message before any LLM
# call — a cheap, deterministic scan over this fixed set (not O(1), but
# negligible for chat-length input).
_CRITICISM_SIGNALS: frozenset[str] = frozenset({
    "that's wrong", "thats wrong", "you're wrong", "youre wrong",
    "not right", "wrong answer", "you got it wrong", "that is wrong",
    "that's incorrect", "you're incorrect", "thats incorrect", "youre incorrect",
    "fix that", "fix your answer", "actually no", "no that's", "no thats",
    "that was wrong", "your answer was wrong", "wrong information",
    "incorrect information", "that's not right", "thats not right",
})


def _is_criticism(message: str) -> bool:
    """Return True when the message contains a known wrong-answer complaint phrase."""
    text = message.lower()
    for phrase in _CRITICISM_SIGNALS:
        if phrase in text:
            return True
    return False


def _filter_sources_by_citations(answer: str, sources: list) -> list:
    """
    Keep only sources explicitly cited in answer text.

    If sources are already pre-filtered upstream (e.g. generate node returned
    only cited sources from original indices), citation numbers may no longer
    match local list positions. In that case, keep the original list unchanged.
    """
    if not answer or not sources:
        return sources

    cited_nums = {int(m) for m in re.findall(r"\[(\d+)\]", answer)}
    if not cited_nums:
        return sources

    max_cited = max(cited_nums)
    if max_cited > len(sources):
        return sources

    return [s for i, s in enumerate(sources, start=1) if i in cited_nums]


async def _generate_follow_ups(
    query: str,
    answer: str,
    sources: list,
    llm_client,
) -> list[str]:
    """
    Generates 3 specific follow-up questions after the main answer is complete.
    Runs after the answer stream finishes β€” zero added latency before first token.

    Questions MUST:
    - Be grounded in the source documents that were actually retrieved (not hypothetical).
    - Lead the visitor deeper into content the knowledge base ALREADY contains.
    - Never venture into topics not covered by the retrieved sources (no hallucinated follow-ups).
    - Be specific (< 12 words, no generic "tell me more" style).
    """
    # Collect source titles AND types so the LLM knows what was actually retrieved.
    source_info = []
    for s in sources[:4]:
        title = s.title if hasattr(s, "title") else s.get("title", "")
        src_type = s.source_type if hasattr(s, "source_type") else s.get("source_type", "")
        if title:
            source_info.append(f"{title} ({src_type})" if src_type else title)

    sources_str = "\n".join(f"- {si}" for si in source_info) if source_info else "- (no specific sources)"

    prompt = (
        f"Visitor's question: {query}\n\n"
        f"Answer given (excerpt): {answer[:500]}\n\n"
        f"Sources that were retrieved and cited in the answer:\n{sources_str}\n\n"
        "Write exactly 3 follow-up questions the visitor would logically ask NEXT, "
        "based ONLY on what was found in the sources above. "
        "Each question must be clearly answerable from the retrieved sources β€” "
        "do NOT invent topics that are not present in the sources listed. "
        "Each question must be under 12 words. "
        "Output ONLY the 3 questions, one per line, no numbering or bullet points."
    )
    system = (
        "You write concise follow-up questions for a portfolio chatbot. "
        "CRITICAL RULE: every question you write must be answerable from the source documents listed. "
        "Never invent follow-ups about topics, projects, or facts not mentioned in the retrieved sources. "
        "Never write generic questions like 'tell me more' or 'what else can you tell me'. "
        "Each question must be under 12 words and reference specifics from the answer and sources."
    )

    try:
        stream = llm_client.complete_with_complexity(
            prompt=prompt, system=system, stream=True, complexity="simple"
        )
        raw = ""
        async for token in stream:
            raw += token
        questions = [q.strip() for q in raw.strip().splitlines() if q.strip()][:3]
        return questions
    except Exception:
        return []


async def _update_summary_async(
    conv_store,
    gemini_client,
    session_id: str,
    previous_summary: str | None,
    query: str,
    answer: str,
    processing_api_key: str | None,
) -> None:
    """
    Triggered post-response to update the rolling conversation summary.
    Failures are silently swallowed β€” summary is best-effort context, not critical.
    """
    try:
        new_summary = await gemini_client.update_conversation_summary(
            previous_summary=previous_summary or "",
            new_turn_q=query,
            new_turn_a=answer[:600],  # cap answer chars sent to Gemini
            processing_api_key=processing_api_key,
        )
        if new_summary:
            conv_store.set_summary(session_id, new_summary)
    except Exception:
        pass


@router.post("")
@chat_rate_limit()
async def chat_endpoint(
    request: Request,
    request_data: ChatRequest,
    # Not read in the body: Depends(verify_jwt) runs for its auth side effect.
    token_payload: dict = Depends(verify_jwt),
) -> StreamingResponse:
    """Stream RAG answer as typed SSE events.

    Event sequence for a full RAG request:
        event: status   β€” guard label, cache miss, gemini routing, retrieve labels
        event: reading  β€” one per unique source found in Qdrant (before rerank)
        event: sources  β€” final selected sources array (after rerank)
        event: thinking β€” CoT scratchpad tokens (70B only)
        event: token    β€” answer tokens
        event: follow_ups β€” three suggested follow-up questions

    For cache hits: status β†’ status β†’ token
    For Gemini fast-path: status β†’ status β†’ token
    """
    # Wall-clock anchor for the latency_ms reported in the terminal done event.
    start_time = time.monotonic()

    pipeline = request.app.state.pipeline
    conv_store = request.app.state.conversation_store
    llm_client = request.app.state.llm_client
    session_id = request_data.session_id

    conversation_history = conv_store.get_recent(session_id)
    conversation_summary = conv_store.get_summary(session_id)
    criticism = _is_criticism(request_data.message)
    # Only mark a previous turn negative when there is a turn to mark.
    if criticism and conversation_history:
        conv_store.mark_last_negative(session_id)

    # Stage 2: decontextualize the query concurrently with Guard when we have a
    # rolling summary. Reference-heavy queries like "tell me more about that project"
    # embed poorly; a self-contained rewrite fixes retrieval without added latency
    # because Gemini Flash runs while Guard is classifying the query.
    gemini_client = getattr(request.app.state, "gemini_client", None)
    decontextualized_query: str | None = None
    decontext_task: asyncio.Task | None = None
    if conversation_summary and gemini_client and gemini_client.is_configured:
        decontext_task = asyncio.create_task(
            gemini_client.decontextualize_query(request_data.message, conversation_summary)
        )

    # Bug 4: concurrent query expansion β€” starts at request entry so it runs
    # while Guard, Cache, and Gemini-fast-path execute.  Result is ready before
    # the Retrieve node needs it (800 ms budget).  Gemini uses the TOON context
    # to generate canonical name forms (for BM25) and semantic expansions (for
    # dense multi-search).  Falls back to empty if Gemini unavailable or slow.
    expansion_task: asyncio.Task | None = None
    if gemini_client and gemini_client.is_configured:
        expansion_task = asyncio.create_task(
            gemini_client.expand_query(request_data.message)
        )

    # Await decontextualization result before the pipeline begins (retrieve node
    # will use it if present; Guard runs first so the latency is masked).
    if decontext_task is not None:
        try:
            result = await asyncio.wait_for(decontext_task, timeout=_DECONTEXT_TIMEOUT_SECONDS)
            # Ignore a rewrite that is just the original message echoed back.
            if result and result.strip().lower() != request_data.message.strip().lower():
                decontextualized_query = result.strip()
        except Exception:
            pass  # Decontextualization is best-effort; fall back to raw query.

    # Await expansion result β€” 800 ms budget so Guard+Cache latency is fully masked.
    expansion_result: dict | None = None
    if expansion_task is not None:
        try:
            expansion_result = await asyncio.wait_for(expansion_task, timeout=_EXPANSION_TIMEOUT_SECONDS)
        except Exception:
            pass  # Expansion is best-effort; retriever falls back to raw query.

    initial_state: PipelineState = {  # type: ignore[assignment]
        "query": request_data.message,
        "session_id": request_data.session_id,
        "query_complexity": "simple",
        # Bug 4: seed expanded_queries with Gemini semantic expansions so the
        # retrieve node issues one dense search per expansion (up to 3 extras).
        # operator.add in PipelineState merges these with any queries added later
        # (e.g. the rag_query from gemini_fast routing to RAG).
        "expanded_queries": (expansion_result or {}).get("semantic_expansions", []),
        "retrieved_chunks": [],
        "reranked_chunks": [],
        "answer": "",
        "sources": [],
        "cached": False,
        "cache_key": None,
        "guard_passed": False,
        "thinking": False,
        "conversation_history": conversation_history,
        "is_criticism": criticism,
        "latency_ms": 0,
        "error": None,
        "interaction_id": None,
        "retrieval_attempts": 0,
        "rewritten_query": None,
        "follow_ups": [],
        "path": None,
        "query_topic": None,
        # Stage 1: follow-up bypass for Gemini fast-path
        "is_followup": request_data.is_followup,
        # Stage 2: progressive history summarisation
        "conversation_summary": conversation_summary or None,
        "decontextualized_query": decontextualized_query,
        # Stage 3: SELF-RAG critic scores (populated by generate node)
        "critic_groundedness": None,
        "critic_completeness": None,
        "critic_specificity": None,
        "critic_quality": None,
        # Fix 1: enumeration classifier β€” populated by enumerate_query node
        "is_enumeration_query": False,
        # Bug 4: query expansion β€” canonical name forms for BM25 union search.
        "query_canonical_forms": (expansion_result or {}).get("canonical_forms", []),
    }

    async def sse_generator():
        # Terminal-state accumulators harvested from "updates" stream items;
        # they feed the done event emitted after the pipeline finishes.
        final_sources = []
        is_cached = False
        final_answer = ""
        interaction_id = None

        try:
            # Emit an early event so clients/proxies receive first bytes quickly.
            yield f"event: status\ndata: {json.dumps({'label': 'Starting response...'})}\n\n"

            # stream_mode=["custom", "updates"] yields (mode, data) tuples:
            #   mode="custom"  β†’ data is whatever writer(payload) was called with
            #   mode="updates" β†’ data is {node_name: state_updates_dict}
            stream_iter = pipeline.astream(
                initial_state,
                stream_mode=["custom", "updates"],
            ).__aiter__()
            # Drive the iterator through an explicit task so heartbeat pings can
            # interleave with long-running node steps.
            next_item_task: asyncio.Task | None = asyncio.create_task(stream_iter.__anext__())

            while True:
                try:
                    # shield() keeps the in-flight __anext__ alive when wait_for
                    # times out for a heartbeat; wait_for alone would cancel it.
                    mode, data = await asyncio.wait_for(
                        asyncio.shield(next_item_task),
                        timeout=_SSE_HEARTBEAT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    # No item within the heartbeat window: either the client is
                    # gone (stop and cancel the pending fetch) or we ping.
                    if await request.is_disconnected():
                        if not next_item_task.done():
                            next_item_task.cancel()
                        break
                    yield f"event: ping\ndata: {json.dumps({'ts': int(time.time())})}\n\n"
                    continue
                except StopAsyncIteration:
                    break

                # Immediately queue the next fetch so pipeline work overlaps
                # with the SSE write below.
                next_item_task = asyncio.create_task(stream_iter.__anext__())

                if await request.is_disconnected():
                    if not next_item_task.done():
                        next_item_task.cancel()
                    break

                if mode == "custom":
                    # Forward writer events as named SSE events.
                    # Each node emits {"type": "<event_name>", ...payload}.
                    event_type = data.get("type", "status")
                    # Strip the "type" key so the client receives a clean payload.
                    payload = {k: v for k, v in data.items() if k != "type"}
                    yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"

                elif mode == "updates":
                    # Capture terminal state for the done event; do not re-emit tokens.
                    for _node_name, updates in data.items():
                        if "sources" in updates and updates["sources"]:
                            final_sources = updates["sources"]
                        if "cached" in updates:
                            is_cached = updates["cached"]
                        if "interaction_id" in updates and updates["interaction_id"] is not None:
                            interaction_id = updates["interaction_id"]
                        if "answer" in updates and updates["answer"]:
                            final_answer = updates["answer"]

            elapsed_ms = int((time.monotonic() - start_time) * 1000)

            # Citation-index filtering safety net for paths that return full
            # source lists. No-op when sources are already citation-filtered.
            final_sources = _filter_sources_by_citations(final_answer, final_sources)

            # Sources may be pydantic v2 models, v1 models, or plain dicts.
            sources_list = [
                s.model_dump() if hasattr(s, "model_dump")
                else s.dict() if hasattr(s, "dict")
                else s
                for s in final_sources
            ]

            # The done event uses plain data: (no event: type) for backward
            # compatibility with widgets that listen on the raw data channel.
            yield (
                f"data: {json.dumps({'done': True, 'sources': sources_list, 'cached': is_cached, 'latency_ms': elapsed_ms, 'interaction_id': interaction_id})}\n\n"
            )

            # ── Follow-up questions ────────────────────────────────────────────
            # Generated after the done event so it never delays answer delivery.
            if final_answer and not await request.is_disconnected():
                follow_ups = await _generate_follow_ups(
                    request_data.message, final_answer, final_sources, llm_client
                )
                if follow_ups:
                    yield f"event: follow_ups\ndata: {json.dumps({'questions': follow_ups})}\n\n"

            # Stage 2: update rolling summary asynchronously β€” fired after the
            # response is fully delivered so it adds zero latency to the turn.
            if final_answer and gemini_client and gemini_client.is_configured:
                processing_key = getattr(
                    request.app.state, "gemini_processing_api_key", None
                )
                # NOTE(review): the task handle is discarded; the event loop
                # holds only a weak reference to running tasks, so this update
                # could be garbage-collected mid-flight. Consider retaining a
                # reference - TODO confirm lifetime is covered elsewhere.
                asyncio.create_task(
                    _update_summary_async(
                        conv_store=conv_store,
                        gemini_client=gemini_client,
                        session_id=session_id,
                        previous_summary=conversation_summary,
                        query=request_data.message,
                        answer=final_answer,
                        processing_api_key=processing_key,
                    )
                )

        except Exception as exc:
            # Boundary handler: surface any pipeline failure to the widget as a
            # terminal error payload on the raw data channel.
            yield f"data: {json.dumps({'error': str(exc) or 'Generation failed'})}\n\n"

    return StreamingResponse(
        sse_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx response buffering for SSE
            "Connection": "keep-alive",
        },
    )