Spaces:
Running
Running
File size: 16,748 Bytes
4ef165a bbe01fe 26b51db bbe01fe 9563e4a 8da917e 9563e4a 65543f1 bbe01fe 1d47e3c e7c9ee6 efdd22e e7c9ee6 d1766f7 e7c9ee6 d1766f7 e7c9ee6 d1766f7 e7c9ee6 d1766f7 e7c9ee6 4ef165a bbe01fe efdd22e bbe01fe 65543f1 e7c9ee6 65543f1 4ef165a 65543f1 bbe01fe 4ef165a 26b51db 4ef165a 9563e4a 4ef165a 26b51db 9563e4a 26b51db bbe01fe 26b51db bbe01fe 8c8aea8 65543f1 bbe01fe e7c9ee6 efdd22e 4ef165a 0da0699 26b51db bbe01fe 9563e4a efdd22e 9563e4a efdd22e 9563e4a bbe01fe 9563e4a bbe01fe efdd22e bbe01fe 1d47e3c 26b51db bbe01fe efdd22e bbe01fe e7c9ee6 efdd22e e7c9ee6 4ef165a bbe01fe efdd22e bbe01fe 9563e4a bbe01fe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 
377 378 379 380 381 382 383 | import asyncio
import json
import re
import time
from fastapi import APIRouter, Request, Depends
from fastapi.responses import StreamingResponse
from app.models.chat import ChatRequest
from app.models.pipeline import PipelineState
from app.security.rate_limiter import chat_rate_limit
from app.security.jwt_auth import verify_jwt
# Router that the chat endpoint defined in this module is registered on.
router = APIRouter()
# Keep-alive interval for SSE when upstream nodes are still working.
# Prevents edge/proxy idle timeouts on long retrieval/generation turns.
# Used as the per-item wait timeout in the stream loop: every time it expires
# without a new pipeline item, a "ping" event is emitted instead.
_SSE_HEARTBEAT_SECONDS: float = 10.0
# Query pre-processing budgets must stay low to avoid delaying first byte.
# Both values are seconds and are passed to asyncio.wait_for when awaiting the
# decontextualization and query-expansion tasks respectively.
_DECONTEXT_TIMEOUT_SECONDS: float = 0.35
_EXPANSION_TIMEOUT_SECONDS: float = 0.60
# Phrases a visitor uses when telling the bot it gave a wrong answer.
# Matched as substrings of the lowercased raw message before any LLM call:
# a linear scan over a small fixed phrase list, so effectively free.
_CRITICISM_SIGNALS: frozenset[str] = frozenset({
    "that's wrong", "thats wrong", "you're wrong", "youre wrong",
    "not right", "wrong answer", "you got it wrong", "that is wrong",
    "that's incorrect", "you're incorrect", "thats incorrect", "youre incorrect",
    "fix that", "fix your answer", "actually no", "no that's", "no thats",
    "that was wrong", "your answer was wrong", "wrong information",
    "incorrect information", "that's not right", "thats not right",
})

# Typographic apostrophes (commonly inserted by mobile keyboards and
# autocorrect) mapped onto the ASCII apostrophe used in _CRITICISM_SIGNALS,
# so "that’s wrong" is detected exactly like "that's wrong".
_APOSTROPHE_TRANSLATION = str.maketrans({
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    "\u02bc": "'",  # modifier letter apostrophe
})


def _is_criticism(message: str) -> bool:
    """Return True when *message* contains a known "you got it wrong" phrase.

    The message is lowercased and its typographic apostrophes are folded to
    ASCII before each phrase in ``_CRITICISM_SIGNALS`` is tested as a
    substring. An empty or missing message is never criticism.
    """
    if not message:
        return False
    normalized = message.lower().translate(_APOSTROPHE_TRANSLATION)
    return any(signal in normalized for signal in _CRITICISM_SIGNALS)
def _filter_sources_by_citations(answer: str, sources: list) -> list:
    """
    Keep only sources explicitly cited in answer text.

    Citations are bracketed 1-based numbers such as ``[2]``, matched against
    the position of each entry in *sources*.

    The original list is returned unchanged when:
    - *answer* or *sources* is empty;
    - the answer contains no positive citation number. A bare ``[0]`` (for
      example from code like ``items[0]``) is not a citation and must not
      cause every source to be dropped;
    - the highest cited number exceeds ``len(sources)``. This happens when
      sources were already pre-filtered upstream (e.g. the generate node
      returned only cited sources from original indices), so citation numbers
      no longer match local list positions.
    """
    if not answer or not sources:
        return sources
    # Only positive numbers can refer to a 1-based source position.
    cited_nums = {
        num for num in map(int, re.findall(r"\[(\d+)\]", answer)) if num >= 1
    }
    if not cited_nums:
        return sources
    if max(cited_nums) > len(sources):
        return sources
    return [s for i, s in enumerate(sources, start=1) if i in cited_nums]
# Leading list markers ("1.", "2)", "-", "*", "•") that the model sometimes
# adds even though the prompt asks for bare questions.
_FOLLOW_UP_MARKER_RE = re.compile(r"^(?:[-*\u2022]+|\d+[.)])\s+")


async def _generate_follow_ups(
    query: str,
    answer: str,
    sources: list,
    llm_client,
) -> list[str]:
    """
    Generates up to 3 specific follow-up questions after the main answer is complete.

    Called once the answer has been streamed, so it adds no latency before the
    first token.

    Questions MUST:
    - Be grounded in the source documents that were actually retrieved (not hypothetical).
    - Lead the visitor deeper into content the knowledge base ALREADY contains.
    - Never venture into topics not covered by the retrieved sources (no hallucinated follow-ups).
    - Be specific (< 12 words, no generic "tell me more" style).

    Args:
        query: The visitor's original question.
        answer: The generated answer; only the first 500 characters are sent.
        sources: Retrieved sources, either objects exposing ``title`` /
            ``source_type`` attributes or dicts with those keys. Only the
            first four are listed in the prompt.
        llm_client: Client exposing ``complete_with_complexity(...)`` which
            returns an async iterator of text tokens.

    Returns:
        At most three non-empty question strings with any leading list
        markers removed, or an empty list if generation fails for any reason
        (follow-ups are best-effort).
    """
    # Collect source titles AND types so the LLM knows what was actually retrieved.
    source_info = []
    for s in sources[:4]:
        title = s.title if hasattr(s, "title") else s.get("title", "")
        src_type = s.source_type if hasattr(s, "source_type") else s.get("source_type", "")
        if title:
            source_info.append(f"{title} ({src_type})" if src_type else title)
    sources_str = "\n".join(f"- {si}" for si in source_info) if source_info else "- (no specific sources)"

    prompt = (
        f"Visitor's question: {query}\n\n"
        f"Answer given (excerpt): {answer[:500]}\n\n"
        f"Sources that were retrieved and cited in the answer:\n{sources_str}\n\n"
        "Write exactly 3 follow-up questions the visitor would logically ask NEXT, "
        "based ONLY on what was found in the sources above. "
        "Each question must be clearly answerable from the retrieved sources \u2014 "
        "do NOT invent topics that are not present in the sources listed. "
        "Each question must be under 12 words. "
        "Output ONLY the 3 questions, one per line, no numbering or bullet points."
    )
    system = (
        "You write concise follow-up questions for a portfolio chatbot. "
        "CRITICAL RULE: every question you write must be answerable from the source documents listed. "
        "Never invent follow-ups about topics, projects, or facts not mentioned in the retrieved sources. "
        "Never write generic questions like 'tell me more' or 'what else can you tell me'. "
        "Each question must be under 12 words and reference specifics from the answer and sources."
    )

    try:
        stream = llm_client.complete_with_complexity(
            prompt=prompt, system=system, stream=True, complexity="simple"
        )
        pieces = []
        async for token in stream:
            pieces.append(token)
        raw = "".join(pieces)

        questions = []
        for line in raw.splitlines():
            # Drop numbering/bullets the model added despite the instructions.
            cleaned = _FOLLOW_UP_MARKER_RE.sub("", line.strip()).strip()
            if cleaned:
                questions.append(cleaned)
        return questions[:3]
    except Exception:
        # Best-effort feature: any LLM/stream failure yields no follow-ups.
        return []
async def _update_summary_async(
conv_store,
gemini_client,
session_id: str,
previous_summary: str | None,
query: str,
answer: str,
processing_api_key: str | None,
) -> None:
"""
Triggered post-response to update the rolling conversation summary.
Failures are silently swallowed β summary is best-effort context, not critical.
"""
try:
new_summary = await gemini_client.update_conversation_summary(
previous_summary=previous_summary or "",
new_turn_q=query,
new_turn_a=answer[:600], # cap answer chars sent to Gemini
processing_api_key=processing_api_key,
)
if new_summary:
conv_store.set_summary(session_id, new_summary)
except Exception:
pass
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be
# garbage-collected before it finishes. Tasks remove themselves when done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


@router.post("")
@chat_rate_limit()
async def chat_endpoint(
    request: Request,
    request_data: ChatRequest,
    token_payload: dict = Depends(verify_jwt),
) -> StreamingResponse:
    """Stream RAG answer as typed SSE events.

    Event sequence for a full RAG request:
        event: status     — guard label, cache miss, gemini routing, retrieve labels
        event: reading    — one per unique source found in Qdrant (before rerank)
        event: sources    — final selected sources array (after rerank)
        event: thinking   — CoT scratchpad tokens (70B only)
        event: token      — answer tokens
        data: {...}       — untyped "done" payload (sources, cached, latency_ms,
                            interaction_id)
        event: follow_ups — suggested follow-up questions

    For cache hits:       status → status → token
    For Gemini fast-path: status → status → token

    An ``event: ping`` keep-alive is emitted whenever the pipeline produces
    nothing for ``_SSE_HEARTBEAT_SECONDS``. If the stream fails, a plain
    ``data: {"error": ...}`` message is sent instead of the done payload.

    Args:
        request: Incoming request; ``request.app.state`` provides the
            pipeline, conversation store and LLM/Gemini clients.
        request_data: Validated chat payload (message, session id, follow-up flag).
        token_payload: Result of the ``verify_jwt`` dependency; not read here,
            it only enforces authentication.

    Returns:
        A ``text/event-stream`` StreamingResponse.
    """
    start_time = time.monotonic()
    pipeline = request.app.state.pipeline
    conv_store = request.app.state.conversation_store
    llm_client = request.app.state.llm_client
    session_id = request_data.session_id
    conversation_history = conv_store.get_recent(session_id)
    conversation_summary = conv_store.get_summary(session_id)

    criticism = _is_criticism(request_data.message)
    if criticism and conversation_history:
        conv_store.mark_last_negative(session_id)

    # Stage 2: decontextualize the query when we have a rolling summary.
    # Reference-heavy queries like "tell me more about that project" embed
    # poorly; a self-contained rewrite fixes retrieval. The task is started
    # here so it overlaps with query expansion below; both are awaited, with
    # tight timeouts, before the pipeline is started.
    gemini_client = getattr(request.app.state, "gemini_client", None)
    decontextualized_query: str | None = None
    decontext_task: asyncio.Task | None = None
    if conversation_summary and gemini_client and gemini_client.is_configured:
        decontext_task = asyncio.create_task(
            gemini_client.decontextualize_query(request_data.message, conversation_summary)
        )

    # Bug 4: concurrent query expansion — started at request entry so it runs
    # alongside decontextualization. Gemini generates canonical name forms
    # (for BM25) and semantic expansions (for dense multi-search). Falls back
    # to empty if Gemini is unavailable or slower than
    # _EXPANSION_TIMEOUT_SECONDS.
    expansion_task: asyncio.Task | None = None
    if gemini_client and gemini_client.is_configured:
        expansion_task = asyncio.create_task(
            gemini_client.expand_query(request_data.message)
        )

    # Await the decontextualization result before the pipeline begins (the
    # initial state carries it). wait_for cancels the task on timeout.
    if decontext_task is not None:
        try:
            result = await asyncio.wait_for(decontext_task, timeout=_DECONTEXT_TIMEOUT_SECONDS)
            # Ignore rewrites that are identical to the raw message.
            if result and result.strip().lower() != request_data.message.strip().lower():
                decontextualized_query = result.strip()
        except Exception:
            pass  # Decontextualization is best-effort; fall back to raw query.

    # Await the expansion result, bounded by _EXPANSION_TIMEOUT_SECONDS. The
    # task has been running since before the decontextualization wait above.
    expansion_result: dict | None = None
    if expansion_task is not None:
        try:
            expansion_result = await asyncio.wait_for(expansion_task, timeout=_EXPANSION_TIMEOUT_SECONDS)
        except Exception:
            pass  # Expansion is best-effort; retriever falls back to raw query.

    initial_state: PipelineState = {  # type: ignore[assignment]
        "query": request_data.message,
        "session_id": request_data.session_id,
        "query_complexity": "simple",
        # Bug 4: seed expanded_queries with Gemini semantic expansions so the
        # retrieve node issues one dense search per expansion (up to 3 extras).
        # operator.add in PipelineState merges these with any queries added later
        # (e.g. the rag_query from gemini_fast routing to RAG).
        "expanded_queries": (expansion_result or {}).get("semantic_expansions", []),
        "retrieved_chunks": [],
        "reranked_chunks": [],
        "answer": "",
        "sources": [],
        "cached": False,
        "cache_key": None,
        "guard_passed": False,
        "thinking": False,
        "conversation_history": conversation_history,
        "is_criticism": criticism,
        "latency_ms": 0,
        "error": None,
        "interaction_id": None,
        "retrieval_attempts": 0,
        "rewritten_query": None,
        "follow_ups": [],
        "path": None,
        "query_topic": None,
        # Stage 1: follow-up bypass for Gemini fast-path
        "is_followup": request_data.is_followup,
        # Stage 2: progressive history summarisation
        "conversation_summary": conversation_summary or None,
        "decontextualized_query": decontextualized_query,
        # Stage 3: SELF-RAG critic scores (populated by generate node)
        "critic_groundedness": None,
        "critic_completeness": None,
        "critic_specificity": None,
        "critic_quality": None,
        # Fix 1: enumeration classifier — populated by enumerate_query node
        "is_enumeration_query": False,
        # Bug 4: query expansion — canonical name forms for BM25 union search.
        "query_canonical_forms": (expansion_result or {}).get("canonical_forms", []),
    }

    async def sse_generator():
        final_sources = []
        is_cached = False
        final_answer = ""
        interaction_id = None
        # Prefetch task for the next pipeline item; tracked outside the try so
        # the finally clause can always cancel it.
        next_item_task: asyncio.Task | None = None
        try:
            # Emit an early event so clients/proxies receive first bytes quickly.
            yield f"event: status\ndata: {json.dumps({'label': 'Starting response...'})}\n\n"

            # stream_mode=["custom", "updates"] yields (mode, data) tuples:
            #   mode="custom"  — data is whatever writer(payload) was called with
            #   mode="updates" — data is {node_name: state_updates_dict}
            stream_iter = pipeline.astream(
                initial_state,
                stream_mode=["custom", "updates"],
            ).__aiter__()
            next_item_task = asyncio.create_task(stream_iter.__anext__())

            while True:
                try:
                    # shield() keeps the pending __anext__ alive when the
                    # heartbeat timeout fires; only the wait is abandoned.
                    mode, data = await asyncio.wait_for(
                        asyncio.shield(next_item_task),
                        timeout=_SSE_HEARTBEAT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    if await request.is_disconnected():
                        if not next_item_task.done():
                            next_item_task.cancel()
                        break
                    yield f"event: ping\ndata: {json.dumps({'ts': int(time.time())})}\n\n"
                    continue
                except StopAsyncIteration:
                    break

                # Start fetching the next item while this one is handled.
                next_item_task = asyncio.create_task(stream_iter.__anext__())

                if await request.is_disconnected():
                    if not next_item_task.done():
                        next_item_task.cancel()
                    break

                if mode == "custom":
                    # Forward writer events as named SSE events.
                    # Each node emits {"type": "<event_name>", ...payload}.
                    event_type = data.get("type", "status")
                    # Strip the "type" key so the client receives a clean payload.
                    payload = {k: v for k, v in data.items() if k != "type"}
                    yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
                elif mode == "updates":
                    # Capture terminal state for the done event; do not re-emit tokens.
                    for _node_name, updates in data.items():
                        if "sources" in updates and updates["sources"]:
                            final_sources = updates["sources"]
                        if "cached" in updates:
                            is_cached = updates["cached"]
                        if "interaction_id" in updates and updates["interaction_id"] is not None:
                            interaction_id = updates["interaction_id"]
                        if "answer" in updates and updates["answer"]:
                            final_answer = updates["answer"]

            elapsed_ms = int((time.monotonic() - start_time) * 1000)

            # Citation-index filtering safety net for paths that return full
            # source lists. No-op when sources are already citation-filtered.
            final_sources = _filter_sources_by_citations(final_answer, final_sources)
            sources_list = [
                s.model_dump() if hasattr(s, "model_dump")
                else s.dict() if hasattr(s, "dict")
                else s
                for s in final_sources
            ]

            # The done event uses plain data: (no event: type) for backward
            # compatibility with widgets that listen on the raw data channel.
            done_payload = {
                "done": True,
                "sources": sources_list,
                "cached": is_cached,
                "latency_ms": elapsed_ms,
                "interaction_id": interaction_id,
            }
            yield f"data: {json.dumps(done_payload)}\n\n"

            # -- Follow-up questions --
            # Generated after the done event so it never delays answer delivery.
            if final_answer and not await request.is_disconnected():
                follow_ups = await _generate_follow_ups(
                    request_data.message, final_answer, final_sources, llm_client
                )
                if follow_ups:
                    yield f"event: follow_ups\ndata: {json.dumps({'questions': follow_ups})}\n\n"

            # Stage 2: update rolling summary asynchronously — fired after the
            # response is fully delivered so it adds zero latency to the turn.
            if final_answer and gemini_client and gemini_client.is_configured:
                processing_key = getattr(
                    request.app.state, "gemini_processing_api_key", None
                )
                summary_task = asyncio.create_task(
                    _update_summary_async(
                        conv_store=conv_store,
                        gemini_client=gemini_client,
                        session_id=session_id,
                        previous_summary=conversation_summary,
                        query=request_data.message,
                        answer=final_answer,
                        processing_api_key=processing_key,
                    )
                )
                # Hold a strong reference until the task completes so it
                # cannot be garbage-collected mid-run.
                _BACKGROUND_TASKS.add(summary_task)
                summary_task.add_done_callback(_BACKGROUND_TASKS.discard)
        except Exception as exc:
            yield f"data: {json.dumps({'error': str(exc) or 'Generation failed'})}\n\n"
        finally:
            # If the generator is closed/cancelled (client went away) or event
            # handling raised, a prefetched __anext__ task may still be
            # pending. Cancel it so the pipeline stops instead of running
            # orphaned in the background.
            if next_item_task is not None and not next_item_task.done():
                next_item_task.cancel()

    return StreamingResponse(
        sse_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
|