Spaces:
Running
Running
GitHub Actions committed on
Commit Β·
efdd22e
1
Parent(s): 3d134a6
Deploy c8a8192
Browse files- app/api/chat.py +49 -46
- app/core/topic.py +79 -0
- app/models/pipeline.py +3 -0
- app/pipeline/nodes/cache.py +8 -0
- app/pipeline/nodes/gemini_fast.py +13 -5
- app/pipeline/nodes/generate.py +82 -29
- app/pipeline/nodes/guard.py +40 -27
- app/pipeline/nodes/retrieve.py +65 -5
- pytest.ini +2 -0
- tests/conftest.py +14 -4
app/api/chat.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
import json
|
| 3 |
-
import re
|
| 4 |
import time
|
| 5 |
from fastapi import APIRouter, Request, Depends
|
| 6 |
from fastapi.responses import StreamingResponse
|
|
@@ -37,7 +35,7 @@ async def _generate_follow_ups(
|
|
| 37 |
) -> list[str]:
|
| 38 |
"""
|
| 39 |
Generates 3 specific follow-up questions after the main answer is complete.
|
| 40 |
-
Runs
|
| 41 |
|
| 42 |
Questions must be:
|
| 43 |
- Specific to the answer content (never generic like "tell me more")
|
|
@@ -85,10 +83,21 @@ async def chat_endpoint(
|
|
| 85 |
request_data: ChatRequest,
|
| 86 |
token_payload: dict = Depends(verify_jwt),
|
| 87 |
) -> StreamingResponse:
|
| 88 |
-
"""Stream RAG answer as SSE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
start_time = time.monotonic()
|
| 90 |
|
| 91 |
-
# All singletons pre-built in lifespan β zero allocation in hot path.
|
| 92 |
pipeline = request.app.state.pipeline
|
| 93 |
conv_store = request.app.state.conversation_store
|
| 94 |
llm_client = request.app.state.llm_client
|
|
@@ -120,6 +129,8 @@ async def chat_endpoint(
|
|
| 120 |
"retrieval_attempts": 0,
|
| 121 |
"rewritten_query": None,
|
| 122 |
"follow_ups": [],
|
|
|
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
async def sse_generator():
|
|
@@ -129,45 +140,35 @@ async def chat_endpoint(
|
|
| 129 |
interaction_id = None
|
| 130 |
|
| 131 |
try:
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
if await request.is_disconnected():
|
| 134 |
break
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
answer_update[len(final_answer):]
|
| 156 |
-
if answer_update.startswith(final_answer)
|
| 157 |
-
else answer_update
|
| 158 |
-
)
|
| 159 |
-
final_answer = answer_update
|
| 160 |
-
if delta:
|
| 161 |
-
yield f'data: {json.dumps({"token": delta})}\n\n'
|
| 162 |
-
|
| 163 |
-
if "sources" in updates:
|
| 164 |
-
final_sources = updates["sources"]
|
| 165 |
-
|
| 166 |
-
if "cached" in updates:
|
| 167 |
-
is_cached = updates["cached"]
|
| 168 |
-
|
| 169 |
-
if "interaction_id" in updates and updates["interaction_id"] is not None:
|
| 170 |
-
interaction_id = updates["interaction_id"]
|
| 171 |
|
| 172 |
elapsed_ms = int((time.monotonic() - start_time) * 1000)
|
| 173 |
|
|
@@ -178,24 +179,26 @@ async def chat_endpoint(
|
|
| 178 |
for s in final_sources
|
| 179 |
]
|
| 180 |
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
# ββ Follow-up questions ββββββββββββββββββββββββββββββββββββββββββββ
|
| 184 |
# Generated after the done event so it never delays answer delivery.
|
| 185 |
-
# Works for both cache hits (no sources) and full RAG responses.
|
| 186 |
if final_answer and not await request.is_disconnected():
|
| 187 |
follow_ups = await _generate_follow_ups(
|
| 188 |
request_data.message, final_answer, final_sources, llm_client
|
| 189 |
)
|
| 190 |
if follow_ups:
|
| 191 |
-
yield f
|
| 192 |
|
| 193 |
except Exception as exc:
|
| 194 |
-
yield f
|
| 195 |
|
| 196 |
return StreamingResponse(
|
| 197 |
sse_generator(),
|
| 198 |
media_type="text/event-stream",
|
| 199 |
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 200 |
)
|
| 201 |
-
|
|
|
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import time
|
| 3 |
from fastapi import APIRouter, Request, Depends
|
| 4 |
from fastapi.responses import StreamingResponse
|
|
|
|
| 35 |
) -> list[str]:
|
| 36 |
"""
|
| 37 |
Generates 3 specific follow-up questions after the main answer is complete.
|
| 38 |
+
Runs after the answer stream finishes β zero added latency before first token.
|
| 39 |
|
| 40 |
Questions must be:
|
| 41 |
- Specific to the answer content (never generic like "tell me more")
|
|
|
|
| 83 |
request_data: ChatRequest,
|
| 84 |
token_payload: dict = Depends(verify_jwt),
|
| 85 |
) -> StreamingResponse:
|
| 86 |
+
"""Stream RAG answer as typed SSE events.
|
| 87 |
+
|
| 88 |
+
Event sequence for a full RAG request:
|
| 89 |
+
event: status β guard label, cache miss, gemini routing, retrieve labels
|
| 90 |
+
event: reading β one per unique source found in Qdrant (before rerank)
|
| 91 |
+
event: sources β final selected sources array (after rerank)
|
| 92 |
+
event: thinking β CoT scratchpad tokens (70B only)
|
| 93 |
+
event: token β answer tokens
|
| 94 |
+
event: follow_ups β three suggested follow-up questions
|
| 95 |
+
|
| 96 |
+
For cache hits: status β status β token
|
| 97 |
+
For Gemini fast-path: status β status β token
|
| 98 |
+
"""
|
| 99 |
start_time = time.monotonic()
|
| 100 |
|
|
|
|
| 101 |
pipeline = request.app.state.pipeline
|
| 102 |
conv_store = request.app.state.conversation_store
|
| 103 |
llm_client = request.app.state.llm_client
|
|
|
|
| 129 |
"retrieval_attempts": 0,
|
| 130 |
"rewritten_query": None,
|
| 131 |
"follow_ups": [],
|
| 132 |
+
"path": None,
|
| 133 |
+
"query_topic": None,
|
| 134 |
}
|
| 135 |
|
| 136 |
async def sse_generator():
|
|
|
|
| 140 |
interaction_id = None
|
| 141 |
|
| 142 |
try:
|
| 143 |
+
# stream_mode=["custom", "updates"] yields (mode, data) tuples:
|
| 144 |
+
# mode="custom" β data is whatever writer(payload) was called with
|
| 145 |
+
# mode="updates" β data is {node_name: state_updates_dict}
|
| 146 |
+
async for mode, data in pipeline.astream(
|
| 147 |
+
initial_state,
|
| 148 |
+
stream_mode=["custom", "updates"],
|
| 149 |
+
):
|
| 150 |
if await request.is_disconnected():
|
| 151 |
break
|
| 152 |
|
| 153 |
+
if mode == "custom":
|
| 154 |
+
# Forward writer events as named SSE events.
|
| 155 |
+
# Each node emits {"type": "<event_name>", ...payload}.
|
| 156 |
+
event_type = data.get("type", "status")
|
| 157 |
+
# Strip the "type" key so the client receives a clean payload.
|
| 158 |
+
payload = {k: v for k, v in data.items() if k != "type"}
|
| 159 |
+
yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
|
| 160 |
+
|
| 161 |
+
elif mode == "updates":
|
| 162 |
+
# Capture terminal state for the done event; do not re-emit tokens.
|
| 163 |
+
for _node_name, updates in data.items():
|
| 164 |
+
if "sources" in updates and updates["sources"]:
|
| 165 |
+
final_sources = updates["sources"]
|
| 166 |
+
if "cached" in updates:
|
| 167 |
+
is_cached = updates["cached"]
|
| 168 |
+
if "interaction_id" in updates and updates["interaction_id"] is not None:
|
| 169 |
+
interaction_id = updates["interaction_id"]
|
| 170 |
+
if "answer" in updates and updates["answer"]:
|
| 171 |
+
final_answer = updates["answer"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
elapsed_ms = int((time.monotonic() - start_time) * 1000)
|
| 174 |
|
|
|
|
| 179 |
for s in final_sources
|
| 180 |
]
|
| 181 |
|
| 182 |
+
# The done event uses plain data: (no event: type) for backward
|
| 183 |
+
# compatibility with widgets that listen on the raw data channel.
|
| 184 |
+
yield (
|
| 185 |
+
f"data: {json.dumps({'done': True, 'sources': sources_list, 'cached': is_cached, 'latency_ms': elapsed_ms, 'interaction_id': interaction_id})}\n\n"
|
| 186 |
+
)
|
| 187 |
|
| 188 |
# ββ Follow-up questions ββββββββββββββββββββββββββββββββββββββββββββ
|
| 189 |
# Generated after the done event so it never delays answer delivery.
|
|
|
|
| 190 |
if final_answer and not await request.is_disconnected():
|
| 191 |
follow_ups = await _generate_follow_ups(
|
| 192 |
request_data.message, final_answer, final_sources, llm_client
|
| 193 |
)
|
| 194 |
if follow_ups:
|
| 195 |
+
yield f"event: follow_ups\ndata: {json.dumps({'questions': follow_ups})}\n\n"
|
| 196 |
|
| 197 |
except Exception as exc:
|
| 198 |
+
yield f"data: {json.dumps({'error': str(exc) or 'Generation failed'})}\n\n"
|
| 199 |
|
| 200 |
return StreamingResponse(
|
| 201 |
sse_generator(),
|
| 202 |
media_type="text/event-stream",
|
| 203 |
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 204 |
)
|
|
|
app/core/topic.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
backend/app/core/topic.py

Extracts a 1-3 word topic label from a natural-language query.

Used by Guard, Retrieve, and any node that surfaces context-specific status
labels ("Checking your question about machine learning", "Searching portfolio
for RAG pipeline") without any LLM call. The extraction is a pure set-lookup
-- it adds no measurable latency.

>>> extract_topic("What are Darshan's machine learning projects?")
'machine learning projects'
>>> extract_topic("Tell me about his background")
'background'
>>> extract_topic("How does he implement RAG?")
'implement RAG'
>>> extract_topic("What is")
'What is'
"""
from __future__ import annotations

import re

# Comprehensive stopword set: prepositions, articles, auxiliary verbs, common
# question words, personal pronouns, demonstratives, and portfolio-query filler.
# Content-bearing words (nouns, adjectives, verbs like "implement", "built")
# are intentionally absent -- they ARE the topic.
_STOPWORDS: frozenset[str] = frozenset({
    # Articles
    "a", "an", "the",
    # Prepositions
    "about", "above", "across", "after", "against", "along", "among",
    "around", "at", "before", "behind", "below", "beneath", "beside",
    "between", "beyond", "by", "during", "except", "for", "from", "in",
    "inside", "into", "like", "near", "of", "off", "on", "onto", "out",
    "outside", "over", "past", "regarding", "since", "through",
    "throughout", "to", "toward", "under", "underneath", "until", "up",
    "upon", "with", "within", "without",
    # Conjunctions
    "and", "but", "or", "nor", "so", "yet", "both", "either", "neither",
    # Common auxiliary verbs
    "is", "are", "was", "were", "be", "been", "being",
    "has", "have", "had", "do", "does", "did",
    "will", "would", "could", "should", "may", "might", "can", "shall",
    # Question words
    "what", "who", "where", "when", "how", "why", "which",
    # Personal pronouns
    "i", "you", "he", "she", "it", "we", "they",
    "me", "him", "her", "us", "them",
    "my", "your", "his", "its", "our", "their",
    "mine", "yours", "hers", "ours", "theirs",
    # Demonstratives
    "this", "that", "these", "those",
    # Common portfolio-query filler
    "tell", "me", "about", "show", "give", "list", "get", "find",
    "look", "also", "just", "really", "very", "more", "most",
    "some", "any", "other", "another", "same", "such", "own",
    "darshan", "chheda",  # owner name is not a useful topic word
})


def extract_topic(query: str) -> str:
    """Return a 1-3 word topic phrase extracted from ``query``.

    Words matching the stopword set are stripped (case-insensitive). The first
    1-3 remaining words are returned joined by spaces. If the query resolves
    to zero content words (all stopwords, or empty), the first two whitespace-
    separated tokens of the original query are returned unchanged so the caller
    always receives a non-empty string (except for an empty ``query``, which is
    returned as-is).
    """
    # The tokenizer keeps apostrophes so contractions/possessives stay intact
    # as single tokens ("Darshan's", "it's").
    tokens = re.findall(r"[a-zA-Z']+", query)
    # BUGFIX: normalize the possessive suffix before the stopword lookup.
    # Without it, "Darshan's" lowercases to "darshan's", which misses the
    # "darshan" stopword entry and leaks the owner's name into the topic --
    # the first doctest above would have returned "Darshan's machine learning".
    # Single-character tokens are dropped as noise (e.g. a stray "s" or "I").
    content = [
        t for t in tokens
        if t.lower().removesuffix("'s") not in _STOPWORDS and len(t) > 1
    ]

    if not content:
        # Fallback: keep the first two words of the original query verbatim.
        parts = query.strip().split()
        return " ".join(parts[:2]) if len(parts) >= 2 else (parts[0] if parts else query)

    return " ".join(content[:3])
|
app/models/pipeline.py
CHANGED
|
@@ -57,3 +57,6 @@ class PipelineState(TypedDict):
|
|
| 57 |
# data_prep.py filters to path=="rag" when building reranker triplets because
|
| 58 |
# only RAG interactions have chunk associations that form valid training pairs.
|
| 59 |
path: Optional[str]
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# data_prep.py filters to path=="rag" when building reranker triplets because
|
| 58 |
# only RAG interactions have chunk associations that form valid training pairs.
|
| 59 |
path: Optional[str]
|
| 60 |
+
# 1β3 word topic extracted from the query by the guard node (extract_topic).
|
| 61 |
+
# Stored in state so retrieve_node can reuse it without recomputing.
|
| 62 |
+
query_topic: Optional[str]
|
app/pipeline/nodes/cache.py
CHANGED
|
@@ -18,6 +18,7 @@
|
|
| 18 |
from typing import Callable
|
| 19 |
|
| 20 |
import numpy as np
|
|
|
|
| 21 |
|
| 22 |
from app.models.pipeline import PipelineState
|
| 23 |
from app.services.semantic_cache import SemanticCache
|
|
@@ -44,6 +45,9 @@ def _has_unresolved_reference(query: str) -> bool:
|
|
| 44 |
|
| 45 |
def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
|
| 46 |
async def cache_node(state: PipelineState) -> dict:
|
|
|
|
|
|
|
|
|
|
| 47 |
query = state["query"]
|
| 48 |
has_history = bool(state.get("conversation_history"))
|
| 49 |
|
|
@@ -62,6 +66,10 @@ def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState],
|
|
| 62 |
|
| 63 |
cached = await cache.get(query_embedding)
|
| 64 |
if cached:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return {
|
| 66 |
"answer": cached,
|
| 67 |
"cached": True,
|
|
|
|
| 18 |
from typing import Callable
|
| 19 |
|
| 20 |
import numpy as np
|
| 21 |
+
from langgraph.config import get_stream_writer
|
| 22 |
|
| 23 |
from app.models.pipeline import PipelineState
|
| 24 |
from app.services.semantic_cache import SemanticCache
|
|
|
|
| 45 |
|
| 46 |
def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
|
| 47 |
async def cache_node(state: PipelineState) -> dict:
|
| 48 |
+
writer = get_stream_writer()
|
| 49 |
+
writer({"type": "status", "label": "Looking up in memory..."})
|
| 50 |
+
|
| 51 |
query = state["query"]
|
| 52 |
has_history = bool(state.get("conversation_history"))
|
| 53 |
|
|
|
|
| 66 |
|
| 67 |
cached = await cache.get(query_embedding)
|
| 68 |
if cached:
|
| 69 |
+
writer({"type": "status", "label": "Found a recent answer, loading..."})
|
| 70 |
+
# Emit the full cached answer as a single token event β the cache
|
| 71 |
+
# returns complete text, not a stream, so one event is correct.
|
| 72 |
+
writer({"type": "token", "text": cached})
|
| 73 |
return {
|
| 74 |
"answer": cached,
|
| 75 |
"cached": True,
|
app/pipeline/nodes/gemini_fast.py
CHANGED
|
@@ -23,6 +23,8 @@ from __future__ import annotations
|
|
| 23 |
import logging
|
| 24 |
from typing import Any
|
| 25 |
|
|
|
|
|
|
|
| 26 |
from app.models.pipeline import PipelineState
|
| 27 |
from app.services.gemini_client import GeminiClient
|
| 28 |
from app.core.quality import is_low_trust
|
|
@@ -70,6 +72,9 @@ def make_gemini_fast_node(gemini_client: GeminiClient) -> Any:
|
|
| 70 |
"""
|
| 71 |
|
| 72 |
async def gemini_fast(state: PipelineState) -> dict:
|
|
|
|
|
|
|
|
|
|
| 73 |
query = state["query"]
|
| 74 |
complexity = "complex" if _is_complex(query) else "simple"
|
| 75 |
|
|
@@ -77,6 +82,7 @@ def make_gemini_fast_node(gemini_client: GeminiClient) -> Any:
|
|
| 77 |
# traffic straight to RAG β behaviour is identical to the old graph.
|
| 78 |
if not gemini_client.is_configured:
|
| 79 |
logger.debug("Gemini not configured; routing query to RAG.")
|
|
|
|
| 80 |
return {
|
| 81 |
"query_complexity": complexity,
|
| 82 |
"expanded_queries": [query],
|
|
@@ -90,21 +96,22 @@ def make_gemini_fast_node(gemini_client: GeminiClient) -> Any:
|
|
| 90 |
|
| 91 |
if answer is not None:
|
| 92 |
# Run the same quality gate that guards Groq answers.
|
| 93 |
-
# Gemini fast-path has no retrieved chunks, so only the hedge-phrase
|
| 94 |
-
# and short-complex-answer signals apply (chunks argument is []).
|
| 95 |
if is_low_trust(answer, [], complexity):
|
| 96 |
logger.debug(
|
| 97 |
"Gemini fast-path answer failed quality gate β routing to RAG."
|
| 98 |
)
|
| 99 |
-
|
| 100 |
return {
|
| 101 |
"query_complexity": complexity,
|
| 102 |
"expanded_queries": [query],
|
| 103 |
"thinking": True,
|
| 104 |
}
|
| 105 |
|
| 106 |
-
# Gemini answered
|
| 107 |
logger.debug("Gemini fast-path answered query (len=%d)", len(answer))
|
|
|
|
|
|
|
|
|
|
| 108 |
return {
|
| 109 |
"query_complexity": complexity,
|
| 110 |
"answer": answer,
|
|
@@ -113,9 +120,10 @@ def make_gemini_fast_node(gemini_client: GeminiClient) -> Any:
|
|
| 113 |
"path": "gemini_fast",
|
| 114 |
}
|
| 115 |
|
| 116 |
-
# Gemini called search_knowledge_base() β
|
| 117 |
rag_query = tool_query or query
|
| 118 |
logger.debug("Gemini routed to RAG (tool_query=%r)", rag_query)
|
|
|
|
| 119 |
return {
|
| 120 |
"query_complexity": complexity,
|
| 121 |
"expanded_queries": [rag_query],
|
|
|
|
| 23 |
import logging
|
| 24 |
from typing import Any
|
| 25 |
|
| 26 |
+
from langgraph.config import get_stream_writer
|
| 27 |
+
|
| 28 |
from app.models.pipeline import PipelineState
|
| 29 |
from app.services.gemini_client import GeminiClient
|
| 30 |
from app.core.quality import is_low_trust
|
|
|
|
| 72 |
"""
|
| 73 |
|
| 74 |
async def gemini_fast(state: PipelineState) -> dict:
|
| 75 |
+
writer = get_stream_writer()
|
| 76 |
+
writer({"type": "status", "label": "Thinking about your question directly..."})
|
| 77 |
+
|
| 78 |
query = state["query"]
|
| 79 |
complexity = "complex" if _is_complex(query) else "simple"
|
| 80 |
|
|
|
|
| 82 |
# traffic straight to RAG β behaviour is identical to the old graph.
|
| 83 |
if not gemini_client.is_configured:
|
| 84 |
logger.debug("Gemini not configured; routing query to RAG.")
|
| 85 |
+
writer({"type": "status", "label": "Needs deeper search, checking portfolio..."})
|
| 86 |
return {
|
| 87 |
"query_complexity": complexity,
|
| 88 |
"expanded_queries": [query],
|
|
|
|
| 96 |
|
| 97 |
if answer is not None:
|
| 98 |
# Run the same quality gate that guards Groq answers.
|
|
|
|
|
|
|
| 99 |
if is_low_trust(answer, [], complexity):
|
| 100 |
logger.debug(
|
| 101 |
"Gemini fast-path answer failed quality gate β routing to RAG."
|
| 102 |
)
|
| 103 |
+
writer({"type": "status", "label": "Needs deeper search, checking portfolio..."})
|
| 104 |
return {
|
| 105 |
"query_complexity": complexity,
|
| 106 |
"expanded_queries": [query],
|
| 107 |
"thinking": True,
|
| 108 |
}
|
| 109 |
|
| 110 |
+
# Gemini answered and passed the quality gate.
|
| 111 |
logger.debug("Gemini fast-path answered query (len=%d)", len(answer))
|
| 112 |
+
writer({"type": "status", "label": "Got a direct answer, writing now..."})
|
| 113 |
+
# Gemini does not stream; emit the complete answer as one token event.
|
| 114 |
+
writer({"type": "token", "text": answer})
|
| 115 |
return {
|
| 116 |
"query_complexity": complexity,
|
| 117 |
"answer": answer,
|
|
|
|
| 120 |
"path": "gemini_fast",
|
| 121 |
}
|
| 122 |
|
| 123 |
+
# Gemini called search_knowledge_base() β route to full RAG.
|
| 124 |
rag_query = tool_query or query
|
| 125 |
logger.debug("Gemini routed to RAG (tool_query=%r)", rag_query)
|
| 126 |
+
writer({"type": "status", "label": "Needs deeper search, checking portfolio..."})
|
| 127 |
return {
|
| 128 |
"query_complexity": complexity,
|
| 129 |
"expanded_queries": [rag_query],
|
app/pipeline/nodes/generate.py
CHANGED
|
@@ -2,6 +2,8 @@ import logging
|
|
| 2 |
import re
|
| 3 |
from typing import Callable
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from app.models.chat import SourceRef
|
| 6 |
from app.models.pipeline import PipelineState
|
| 7 |
from app.services.llm_client import LLMClient
|
|
@@ -139,37 +141,38 @@ def _format_history(history: list[dict]) -> str:
|
|
| 139 |
|
| 140 |
|
| 141 |
def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[PipelineState], dict]: # noqa: ANN001
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
async def generate_node(state: PipelineState) -> dict:
|
|
|
|
| 143 |
query = state["query"]
|
| 144 |
complexity = state.get("query_complexity", "simple")
|
| 145 |
reranked_chunks = state.get("reranked_chunks", [])
|
| 146 |
|
| 147 |
# ββ Not-found path βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
-
# Retrieve found no relevant chunks (either KB empty or below rerank
|
| 149 |
-
# threshold). Use a short, model-generated honest refusal so guard
|
| 150 |
-
# rejections and not-found both route here with quality responses.
|
| 151 |
if not reranked_chunks:
|
|
|
|
| 152 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 153 |
stream = llm_client.complete_with_complexity(
|
| 154 |
prompt=f"{history_prefix}Visitor question: {query}",
|
| 155 |
system=_NOT_FOUND_SYSTEM,
|
| 156 |
stream=True,
|
| 157 |
-
complexity="simple",
|
| 158 |
)
|
| 159 |
full_answer = ""
|
| 160 |
async for token in stream:
|
| 161 |
full_answer += token
|
|
|
|
| 162 |
return {"answer": full_answer, "sources": [], "path": "rag"}
|
| 163 |
|
| 164 |
# ββ Pre-LLM coherence shortcut ββββββββββββββββββββββββββββββββββββββ
|
| 165 |
-
# Check that at least one meaningful query token appears somewhere in
|
| 166 |
-
# the retrieved chunks. If there is zero textual overlap AND the top
|
| 167 |
-
# rerank score is negative, the retriever returned topically unrelated
|
| 168 |
-
# chunks β skip the LLM call entirely and go straight to not-found.
|
| 169 |
-
# This saves a Groq call (~300ms) when the KB truly has nothing.
|
| 170 |
top_score = reranked_chunks[0]["metadata"].get("rerank_score", 0.0)
|
| 171 |
query_toks = _query_tokens(query)
|
| 172 |
if top_score < 0.0 and not _chunks_overlap_query(query_toks, reranked_chunks):
|
|
|
|
| 173 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 174 |
stream = llm_client.complete_with_complexity(
|
| 175 |
prompt=f"{history_prefix}Visitor question: {query}",
|
|
@@ -180,6 +183,7 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
|
|
| 180 |
full_answer = ""
|
| 181 |
async for token in stream:
|
| 182 |
full_answer += token
|
|
|
|
| 183 |
return {"answer": full_answer, "sources": [], "path": "rag"}
|
| 184 |
|
| 185 |
# ββ Build numbered context block ββββββββββββββββββββββββββββββββββββ
|
|
@@ -188,7 +192,6 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
|
|
| 188 |
|
| 189 |
for i, chunk in enumerate(reranked_chunks, start=1):
|
| 190 |
meta = chunk["metadata"]
|
| 191 |
-
# Include title and URL so the LLM can verify passage relevance.
|
| 192 |
header = f"[{i}] {meta['source_title']}"
|
| 193 |
if meta.get("source_url"):
|
| 194 |
header += f" ({meta['source_url']})"
|
|
@@ -203,10 +206,6 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
|
|
| 203 |
|
| 204 |
context_block = "\n\n".join(context_parts)
|
| 205 |
|
| 206 |
-
# ββ Compact conversation history prefix βββββββββββββββββββββββββββββ
|
| 207 |
-
# Injected before passages so the model can resolve follow-up references
|
| 208 |
-
# ("tell me more", "which one used Java?", "that was wrong") without
|
| 209 |
-
# needing to re-retrieve resolved information.
|
| 210 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 211 |
is_criticism = state.get("is_criticism", False)
|
| 212 |
criticism_note = (
|
|
@@ -216,44 +215,98 @@ def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[P
|
|
| 216 |
)
|
| 217 |
prompt = f"{criticism_note}{history_prefix}Passages:\n{context_block}\n\nVisitor question: {query}"
|
| 218 |
|
| 219 |
-
# ββ
|
| 220 |
-
#
|
| 221 |
-
#
|
| 222 |
-
#
|
| 223 |
-
#
|
|
|
|
| 224 |
stream = llm_client.complete_with_complexity(
|
| 225 |
prompt=prompt,
|
| 226 |
system=_SYSTEM_PROMPT,
|
| 227 |
stream=True,
|
| 228 |
complexity=complexity,
|
| 229 |
)
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
async for token in stream:
|
| 232 |
raw_answer += token
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
full_answer = re.sub(r"<think>.*?</think>\s*", "", raw_answer, flags=re.DOTALL).strip()
|
| 236 |
|
| 237 |
# ββ Quality gate: Gemini editorial reformat ββββββββββββββββββββββββββ
|
| 238 |
-
# Fires when: (a) criticism
|
| 239 |
-
# (b) low-trust heuristic flags the draft
|
| 240 |
-
# Zero extra cost on good responses; ~200-400ms only when genuinely needed.
|
| 241 |
if gemini_client is not None and (is_criticism or is_low_trust(full_answer, reranked_chunks, complexity)):
|
| 242 |
logger.debug("Triggering Gemini reformat (criticism=%s).", is_criticism)
|
| 243 |
reformatted = await gemini_client.reformat_rag_answer(query, context_block, full_answer)
|
| 244 |
if reformatted:
|
| 245 |
full_answer = reformatted
|
| 246 |
|
| 247 |
-
# Only surface sources the LLM actually cited
|
| 248 |
-
# Fall back to top-2 if the model produced no [N] markers.
|
| 249 |
cited_indices = {int(m) for m in re.findall(r"\[(\d+)\]", full_answer)}
|
| 250 |
cited_sources = [sr for i, sr in enumerate(source_refs, start=1) if i in cited_indices]
|
| 251 |
|
| 252 |
return {
|
| 253 |
"answer": full_answer,
|
| 254 |
"sources": cited_sources if cited_sources else source_refs[:2],
|
| 255 |
-
# Tag this interaction so data_prep.py can filter to RAG-path only
|
| 256 |
-
# when building reranker triplets (only RAG has chunk associations).
|
| 257 |
"path": "rag",
|
| 258 |
}
|
| 259 |
|
|
|
|
| 2 |
import re
|
| 3 |
from typing import Callable
|
| 4 |
|
| 5 |
+
from langgraph.config import get_stream_writer
|
| 6 |
+
|
| 7 |
from app.models.chat import SourceRef
|
| 8 |
from app.models.pipeline import PipelineState
|
| 9 |
from app.services.llm_client import LLMClient
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
def make_generate_node(llm_client: LLMClient, gemini_client=None) -> Callable[[PipelineState], dict]: # noqa: ANN001
|
| 144 |
+
# Number of token chunks to buffer before deciding there is no CoT block.
|
| 145 |
+
# Llama 3.1 8B may omit <think> entirely; Llama 3.3 70B always starts with one.
|
| 146 |
+
# 50 chunks is enough to cover the opening tag without delaying short answers.
|
| 147 |
+
_THINK_LOOKAHEAD: int = 50
|
| 148 |
+
|
| 149 |
async def generate_node(state: PipelineState) -> dict:
|
| 150 |
+
writer = get_stream_writer()
|
| 151 |
query = state["query"]
|
| 152 |
complexity = state.get("query_complexity", "simple")
|
| 153 |
reranked_chunks = state.get("reranked_chunks", [])
|
| 154 |
|
| 155 |
# ββ Not-found path βββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
| 156 |
if not reranked_chunks:
|
| 157 |
+
writer({"type": "status", "label": "Could not find specific information, responding carefully..."})
|
| 158 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 159 |
stream = llm_client.complete_with_complexity(
|
| 160 |
prompt=f"{history_prefix}Visitor question: {query}",
|
| 161 |
system=_NOT_FOUND_SYSTEM,
|
| 162 |
stream=True,
|
| 163 |
+
complexity="simple",
|
| 164 |
)
|
| 165 |
full_answer = ""
|
| 166 |
async for token in stream:
|
| 167 |
full_answer += token
|
| 168 |
+
writer({"type": "token", "text": token})
|
| 169 |
return {"answer": full_answer, "sources": [], "path": "rag"}
|
| 170 |
|
| 171 |
# ββ Pre-LLM coherence shortcut ββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
top_score = reranked_chunks[0]["metadata"].get("rerank_score", 0.0)
|
| 173 |
query_toks = _query_tokens(query)
|
| 174 |
if top_score < 0.0 and not _chunks_overlap_query(query_toks, reranked_chunks):
|
| 175 |
+
writer({"type": "status", "label": "Could not find specific information, responding carefully..."})
|
| 176 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 177 |
stream = llm_client.complete_with_complexity(
|
| 178 |
prompt=f"{history_prefix}Visitor question: {query}",
|
|
|
|
| 183 |
full_answer = ""
|
| 184 |
async for token in stream:
|
| 185 |
full_answer += token
|
| 186 |
+
writer({"type": "token", "text": token})
|
| 187 |
return {"answer": full_answer, "sources": [], "path": "rag"}
|
| 188 |
|
| 189 |
# ββ Build numbered context block ββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 192 |
|
| 193 |
for i, chunk in enumerate(reranked_chunks, start=1):
|
| 194 |
meta = chunk["metadata"]
|
|
|
|
| 195 |
header = f"[{i}] {meta['source_title']}"
|
| 196 |
if meta.get("source_url"):
|
| 197 |
header += f" ({meta['source_url']})"
|
|
|
|
| 206 |
|
| 207 |
context_block = "\n\n".join(context_parts)
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
history_prefix = _format_history(state.get("conversation_history") or [])
|
| 210 |
is_criticism = state.get("is_criticism", False)
|
| 211 |
criticism_note = (
|
|
|
|
| 215 |
)
|
| 216 |
prompt = f"{criticism_note}{history_prefix}Passages:\n{context_block}\n\nVisitor question: {query}"
|
| 217 |
|
| 218 |
+
# ββ Streaming CoT-aware token emission ββββββββββββββββββββββββββββββ
|
| 219 |
+
# Groq streams tokens one chunk at a time. We intercept them to:
|
| 220 |
+
# Phase 1 β detect and buffer the <think> block, emitting thinking events.
|
| 221 |
+
# Phase 2 β emit answer tokens in real time after </think>.
|
| 222 |
+
# If no <think> tag appears in the first _THINK_LOOKAHEAD token chunks
|
| 223 |
+
# (Llama 3.1 8B on simple queries), we switch to direct emission with no wait.
|
| 224 |
stream = llm_client.complete_with_complexity(
|
| 225 |
prompt=prompt,
|
| 226 |
system=_SYSTEM_PROMPT,
|
| 227 |
stream=True,
|
| 228 |
complexity=complexity,
|
| 229 |
)
|
| 230 |
+
|
| 231 |
+
raw_answer = "" # complete unmodified response for quality gate
|
| 232 |
+
buf = "" # character buffer for tag detection
|
| 233 |
+
in_think = False # currently inside <think> block
|
| 234 |
+
think_done = False # </think> was found; switched to direct streaming
|
| 235 |
+
no_cot = False # no <think> seen in first LOOKAHEAD token chunks
|
| 236 |
+
token_chunk_count = 0 # number of token chunks received
|
| 237 |
+
think_first_emitted = False # CoT first-sentence status already sent
|
| 238 |
+
|
| 239 |
async for token in stream:
|
| 240 |
raw_answer += token
|
| 241 |
+
token_chunk_count += 1
|
| 242 |
+
|
| 243 |
+
if think_done or no_cot:
|
| 244 |
+
# Phase 2: real-time answer streaming.
|
| 245 |
+
writer({"type": "token", "text": token})
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
buf += token
|
| 249 |
+
|
| 250 |
+
if not in_think:
|
| 251 |
+
if "<think>" in buf:
|
| 252 |
+
in_think = True
|
| 253 |
+
pre = buf[: buf.index("<think>")]
|
| 254 |
+
if pre.strip():
|
| 255 |
+
# Text before the think tag is part of the answer.
|
| 256 |
+
writer({"type": "token", "text": pre})
|
| 257 |
+
buf = buf[buf.index("<think>") + 7:] # 7 = len("<think>")
|
| 258 |
+
elif token_chunk_count >= _THINK_LOOKAHEAD:
|
| 259 |
+
# No CoT block in first 50 chunks β emit buffered and go direct.
|
| 260 |
+
no_cot = True
|
| 261 |
+
writer({"type": "token", "text": buf})
|
| 262 |
+
buf = ""
|
| 263 |
+
else:
|
| 264 |
+
# Phase 1: inside the <think> block; buffer until </think>.
|
| 265 |
+
if "</think>" in buf:
|
| 266 |
+
idx = buf.index("</think>")
|
| 267 |
+
think_txt = buf[:idx].strip()
|
| 268 |
+
after_think = buf[idx + 9:] # 9 = len("</think>")
|
| 269 |
+
|
| 270 |
+
if think_txt and not think_first_emitted:
|
| 271 |
+
# Surface the first sentence as a legible status label.
|
| 272 |
+
for j, ch in enumerate(think_txt):
|
| 273 |
+
if ch in ".?!\n":
|
| 274 |
+
first_sent = think_txt[: j + 1].strip()[:120]
|
| 275 |
+
writer({"type": "status", "label": first_sent})
|
| 276 |
+
think_first_emitted = True
|
| 277 |
+
break
|
| 278 |
+
|
| 279 |
+
if think_txt:
|
| 280 |
+
writer({"type": "thinking", "text": think_txt})
|
| 281 |
+
|
| 282 |
+
think_done = True
|
| 283 |
+
buf = ""
|
| 284 |
+
if after_think.strip():
|
| 285 |
+
writer({"type": "token", "text": after_think})
|
| 286 |
+
|
| 287 |
+
# Flush buffer if the stream ended mid-detection (e.g. model forgot </think>).
|
| 288 |
+
if buf:
|
| 289 |
+
writer({"type": "token", "text": buf})
|
| 290 |
+
|
| 291 |
+
# ββ Strip CoT scratchpad ββββββββββββββββββββββββββββββββββββββββββββ
|
| 292 |
full_answer = re.sub(r"<think>.*?</think>\s*", "", raw_answer, flags=re.DOTALL).strip()
|
| 293 |
|
| 294 |
# ββ Quality gate: Gemini editorial reformat ββββββββββββββββββββββββββ
|
| 295 |
+
# Fires when: (a) criticism detected β always reformat, or
|
| 296 |
+
# (b) low-trust heuristic flags the draft. Zero extra cost on good responses.
|
|
|
|
| 297 |
if gemini_client is not None and (is_criticism or is_low_trust(full_answer, reranked_chunks, complexity)):
|
| 298 |
logger.debug("Triggering Gemini reformat (criticism=%s).", is_criticism)
|
| 299 |
reformatted = await gemini_client.reformat_rag_answer(query, context_block, full_answer)
|
| 300 |
if reformatted:
|
| 301 |
full_answer = reformatted
|
| 302 |
|
| 303 |
+
# Only surface sources the LLM actually cited.
|
|
|
|
| 304 |
cited_indices = {int(m) for m in re.findall(r"\[(\d+)\]", full_answer)}
|
| 305 |
cited_sources = [sr for i, sr in enumerate(source_refs, start=1) if i in cited_indices]
|
| 306 |
|
| 307 |
return {
|
| 308 |
"answer": full_answer,
|
| 309 |
"sources": cited_sources if cited_sources else source_refs[:2],
|
|
|
|
|
|
|
| 310 |
"path": "rag",
|
| 311 |
}
|
| 312 |
|
app/pipeline/nodes/guard.py
CHANGED
|
@@ -1,45 +1,58 @@
|
|
| 1 |
from typing import Callable
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from app.models.pipeline import PipelineState
|
| 4 |
from app.security.guard_classifier import GuardClassifier
|
| 5 |
from app.security.sanitizer import sanitize_input, redact_pii
|
| 6 |
|
|
|
|
| 7 |
def make_guard_node(classifier: GuardClassifier) -> Callable[[PipelineState], dict]:
|
| 8 |
def guard_node(state: PipelineState) -> dict:
|
|
|
|
| 9 |
original_query = state["query"]
|
| 10 |
-
|
| 11 |
-
# 1. Sanitize
|
| 12 |
sanitized = sanitize_input(original_query)
|
| 13 |
-
|
| 14 |
-
# 2. PII Redact
|
| 15 |
-
# Note: the prompt says "Return cleaned text. Used in log_eval node before writing to SQLite."
|
| 16 |
-
# If we redact it here, the rest of the pipeline gets the redacted text.
|
| 17 |
-
# This is safe and ensures PII doesn't leak into LLM prompts or vector similarity.
|
| 18 |
clean_query = redact_pii(sanitized)
|
| 19 |
-
|
| 20 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
if len(clean_query) == 0:
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
is_safe, score = classifier.is_in_scope(clean_query)
|
| 31 |
-
|
| 32 |
if not is_safe:
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
return {
|
| 41 |
"query": clean_query,
|
| 42 |
-
"guard_passed": True
|
|
|
|
| 43 |
}
|
| 44 |
-
|
| 45 |
return guard_node
|
|
|
|
|
|
| 1 |
from typing import Callable
|
| 2 |
|
| 3 |
+
from langgraph.config import get_stream_writer
|
| 4 |
+
|
| 5 |
+
from app.core.topic import extract_topic
|
| 6 |
from app.models.pipeline import PipelineState
|
| 7 |
from app.security.guard_classifier import GuardClassifier
|
| 8 |
from app.security.sanitizer import sanitize_input, redact_pii
|
| 9 |
|
| 10 |
+
|
| 11 |
def make_guard_node(classifier: GuardClassifier) -> Callable[[PipelineState], dict]:
|
| 12 |
def guard_node(state: PipelineState) -> dict:
|
| 13 |
+
writer = get_stream_writer()
|
| 14 |
original_query = state["query"]
|
| 15 |
+
|
| 16 |
+
# 1. Sanitize and PII-redact before any LLM or classifier call.
|
| 17 |
sanitized = sanitize_input(original_query)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
clean_query = redact_pii(sanitized)
|
| 19 |
+
|
| 20 |
+
# Emit the first status event now that we have a clean query to describe.
|
| 21 |
+
# Topic extraction is O(N) set lookup β adds zero measurable latency.
|
| 22 |
+
if clean_query:
|
| 23 |
+
topic = extract_topic(clean_query)
|
| 24 |
+
label = f"Checking your question about {topic}" if topic else "Checking your question"
|
| 25 |
+
else:
|
| 26 |
+
topic = ""
|
| 27 |
+
label = "Checking your question"
|
| 28 |
+
writer({"type": "status", "label": label})
|
| 29 |
+
|
| 30 |
if len(clean_query) == 0:
|
| 31 |
+
return {
|
| 32 |
+
"query": clean_query,
|
| 33 |
+
"guard_passed": False,
|
| 34 |
+
"answer": "I can only answer questions about Darshan's work, projects, and background.",
|
| 35 |
+
"path": "blocked",
|
| 36 |
+
"query_topic": topic,
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# 2. Classify (scope evaluation).
|
| 40 |
is_safe, score = classifier.is_in_scope(clean_query)
|
| 41 |
+
|
| 42 |
if not is_safe:
|
| 43 |
+
return {
|
| 44 |
+
"query": clean_query,
|
| 45 |
+
"guard_passed": False,
|
| 46 |
+
"answer": "I can only answer questions about Darshan's work, projects, and background.",
|
| 47 |
+
"path": "blocked",
|
| 48 |
+
"query_topic": topic,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
return {
|
| 52 |
"query": clean_query,
|
| 53 |
+
"guard_passed": True,
|
| 54 |
+
"query_topic": topic,
|
| 55 |
}
|
| 56 |
+
|
| 57 |
return guard_node
|
| 58 |
+
|
app/pipeline/nodes/retrieve.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import asyncio
|
| 2 |
from typing import Callable
|
| 3 |
|
|
|
|
|
|
|
| 4 |
from app.models.pipeline import PipelineState, Chunk
|
| 5 |
from app.services.vector_store import VectorStore
|
| 6 |
from app.services.embedder import Embedder
|
|
@@ -90,13 +92,33 @@ def _rrf_merge(ranked_lists: list[list[Chunk]]) -> list[Chunk]:
|
|
| 90 |
return [chunks_by_fp[fp] for fp in sorted_fps]
|
| 91 |
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
def make_retrieve_node(
|
| 94 |
vector_store: VectorStore, embedder: Embedder, reranker: Reranker
|
| 95 |
) -> Callable[[PipelineState], dict]:
|
| 96 |
async def retrieve_node(state: PipelineState) -> dict:
|
|
|
|
|
|
|
| 97 |
attempts = state.get("retrieval_attempts", 0)
|
| 98 |
query = state["query"]
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
# On a CRAG retry (attempts >= 1) the query has been rewritten and
|
| 101 |
# query_embedding is explicitly set to None β always re-embed.
|
| 102 |
# On the first attempt, reuse the embedding computed by the cache node.
|
|
@@ -134,6 +156,25 @@ def make_retrieve_node(
|
|
| 134 |
all_ranked_lists = dense_results + ([sparse_results] if sparse_results else [])
|
| 135 |
fused: list[Chunk] = _rrf_merge(all_ranked_lists)
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# ββ Deduplication (question-point collapse) ββββββββββββββββββββββββββββ
|
| 138 |
# Multiple points for the same chunk (main + question points from Stage 3)
|
| 139 |
# share the same doc_id::section fingerprint and collapse here.
|
|
@@ -145,6 +186,11 @@ def make_retrieve_node(
|
|
| 145 |
seen.add(fp)
|
| 146 |
unique_chunks.append(c)
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
reranked = await reranker.rerank(query, unique_chunks, top_k=5)
|
| 149 |
|
| 150 |
# ββ Relevance gate βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -158,11 +204,6 @@ def make_retrieve_node(
|
|
| 158 |
}
|
| 159 |
|
| 160 |
# ββ Source diversity cap (query-aware) βββββββββββββββββββββββββββββββββ
|
| 161 |
-
# Broad queries: max 2 chunks per source document (anti-resume-monopoly).
|
| 162 |
-
# Focused queries (experience, skills, project, blog): raise the cap for
|
| 163 |
-
# the matching source type to 4, cap everything else at 1. This lets
|
| 164 |
-
# the resume fill appropriately on "what is Darshan's work experience?"
|
| 165 |
-
# without harming answer quality on broad queries.
|
| 166 |
focused_type = _focused_source_type(query)
|
| 167 |
doc_counts: dict[str, int] = {}
|
| 168 |
diverse_chunks: list[Chunk] = []
|
|
@@ -179,6 +220,25 @@ def make_retrieve_node(
|
|
| 179 |
diverse_chunks.append(chunk)
|
| 180 |
doc_counts[doc_id] = doc_counts.get(doc_id, 0) + 1
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
return {
|
| 183 |
"retrieved_chunks": unique_chunks,
|
| 184 |
"reranked_chunks": diverse_chunks,
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from typing import Callable
|
| 3 |
|
| 4 |
+
from langgraph.config import get_stream_writer
|
| 5 |
+
|
| 6 |
from app.models.pipeline import PipelineState, Chunk
|
| 7 |
from app.services.vector_store import VectorStore
|
| 8 |
from app.services.embedder import Embedder
|
|
|
|
| 92 |
return [chunks_by_fp[fp] for fp in sorted_fps]
|
| 93 |
|
| 94 |
|
| 95 |
+
_TYPE_REMAP: dict[str, str] = {
|
| 96 |
+
"github": "readme",
|
| 97 |
+
"bio": "resume",
|
| 98 |
+
"cv": "resume",
|
| 99 |
+
"blog": "blog",
|
| 100 |
+
"project": "project",
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
def make_retrieve_node(
|
| 105 |
vector_store: VectorStore, embedder: Embedder, reranker: Reranker
|
| 106 |
) -> Callable[[PipelineState], dict]:
|
| 107 |
async def retrieve_node(state: PipelineState) -> dict:
|
| 108 |
+
writer = get_stream_writer()
|
| 109 |
+
|
| 110 |
attempts = state.get("retrieval_attempts", 0)
|
| 111 |
query = state["query"]
|
| 112 |
|
| 113 |
+
# Reuse the topic computed by the guard node β no recomputation needed.
|
| 114 |
+
topic = state.get("query_topic") or ""
|
| 115 |
+
searching_label = (
|
| 116 |
+
f"Searching portfolio for {topic}..."
|
| 117 |
+
if topic
|
| 118 |
+
else "Searching portfolio..."
|
| 119 |
+
)
|
| 120 |
+
writer({"type": "status", "label": searching_label})
|
| 121 |
+
|
| 122 |
# On a CRAG retry (attempts >= 1) the query has been rewritten and
|
| 123 |
# query_embedding is explicitly set to None β always re-embed.
|
| 124 |
# On the first attempt, reuse the embedding computed by the cache node.
|
|
|
|
| 156 |
all_ranked_lists = dense_results + ([sparse_results] if sparse_results else [])
|
| 157 |
fused: list[Chunk] = _rrf_merge(all_ranked_lists)
|
| 158 |
|
| 159 |
+
# ββ Reading events β one per unique source document ββββββββββββββββββββ
|
| 160 |
+
# Emitted BEFORE deduplication so the user sees sources appear in
|
| 161 |
+
# real time as Qdrant returns them, matching Perplexity's "Reading..."
|
| 162 |
+
# display. Deduplication here is by source_url so blog posts with
|
| 163 |
+
# multiple chunk hits fire only one event.
|
| 164 |
+
seen_urls: set[str] = set()
|
| 165 |
+
for chunk in fused:
|
| 166 |
+
meta = chunk["metadata"]
|
| 167 |
+
url = meta.get("source_url") or ""
|
| 168 |
+
dedup_key = url if url else meta.get("doc_id", "")
|
| 169 |
+
if dedup_key and dedup_key not in seen_urls:
|
| 170 |
+
seen_urls.add(dedup_key)
|
| 171 |
+
writer({
|
| 172 |
+
"type": "reading",
|
| 173 |
+
"title": meta.get("source_title", ""),
|
| 174 |
+
"url": url or None,
|
| 175 |
+
"source_type": _TYPE_REMAP.get(meta.get("source_type", ""), meta.get("source_type", "")),
|
| 176 |
+
})
|
| 177 |
+
|
| 178 |
# ββ Deduplication (question-point collapse) ββββββββββββββββββββββββββββ
|
| 179 |
# Multiple points for the same chunk (main + question points from Stage 3)
|
| 180 |
# share the same doc_id::section fingerprint and collapse here.
|
|
|
|
| 186 |
seen.add(fp)
|
| 187 |
unique_chunks.append(c)
|
| 188 |
|
| 189 |
+
writer({
|
| 190 |
+
"type": "status",
|
| 191 |
+
"label": f"Comparing {len(unique_chunks)} sources for relevance...",
|
| 192 |
+
})
|
| 193 |
+
|
| 194 |
reranked = await reranker.rerank(query, unique_chunks, top_k=5)
|
| 195 |
|
| 196 |
# ββ Relevance gate βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 204 |
}
|
| 205 |
|
| 206 |
# ββ Source diversity cap (query-aware) βββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
focused_type = _focused_source_type(query)
|
| 208 |
doc_counts: dict[str, int] = {}
|
| 209 |
diverse_chunks: list[Chunk] = []
|
|
|
|
| 220 |
diverse_chunks.append(chunk)
|
| 221 |
doc_counts[doc_id] = doc_counts.get(doc_id, 0) + 1
|
| 222 |
|
| 223 |
+
# ββ Sources event β final selected sources shown before the answer ββββββ
|
| 224 |
+
# This is the Perplexity-style source card row that appears before tokens.
|
| 225 |
+
# Emitted here so the frontend can display source cards before Groq starts.
|
| 226 |
+
sources_payload = []
|
| 227 |
+
for chunk in diverse_chunks:
|
| 228 |
+
meta = chunk["metadata"]
|
| 229 |
+
url = meta.get("source_url") or None
|
| 230 |
+
sources_payload.append({
|
| 231 |
+
"title": meta.get("source_title", ""),
|
| 232 |
+
"url": url,
|
| 233 |
+
"source_type": _TYPE_REMAP.get(meta.get("source_type", ""), meta.get("source_type", "")),
|
| 234 |
+
"section": meta.get("section", ""),
|
| 235 |
+
})
|
| 236 |
+
writer({"type": "sources", "sources": sources_payload})
|
| 237 |
+
|
| 238 |
+
# Let the user know what top source the answer will be written from.
|
| 239 |
+
top_title = diverse_chunks[0]["metadata"].get("source_title", "sources")
|
| 240 |
+
writer({"type": "status", "label": f"Writing answer from {top_title}..."})
|
| 241 |
+
|
| 242 |
return {
|
| 243 |
"retrieved_chunks": unique_chunks,
|
| 244 |
"reranked_chunks": diverse_chunks,
|
pytest.ini
CHANGED
|
@@ -4,3 +4,5 @@ python_files = test_*.py
|
|
| 4 |
python_classes = Test*
|
| 5 |
python_functions = test_*
|
| 6 |
addopts = -x --tb=short -q
|
|
|
|
|
|
|
|
|
| 4 |
python_classes = Test*
|
| 5 |
python_functions = test_*
|
| 6 |
addopts = -x --tb=short -q
|
| 7 |
+
filterwarnings =
|
| 8 |
+
ignore::DeprecationWarning:slowapi.*
|
tests/conftest.py
CHANGED
|
@@ -57,10 +57,20 @@ def app_client():
|
|
| 57 |
|
| 58 |
mock_pipeline = MagicMock()
|
| 59 |
|
| 60 |
-
async def fake_astream(state):
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
mock_pipeline.astream = fake_astream
|
| 66 |
|
|
|
|
| 57 |
|
| 58 |
mock_pipeline = MagicMock()
|
| 59 |
|
| 60 |
+
async def fake_astream(state, stream_mode=None):
|
| 61 |
+
# Support the new stream_mode=["custom", "updates"] tuple format used by chat.py.
|
| 62 |
+
if isinstance(stream_mode, list):
|
| 63 |
+
yield ("custom", {"type": "status", "label": "Checking your question"})
|
| 64 |
+
yield ("updates", {"guard": {"guard_passed": True}})
|
| 65 |
+
yield ("updates", {"cache": {"cached": False}})
|
| 66 |
+
yield ("custom", {"type": "status", "label": "Thinking about your question directly..."})
|
| 67 |
+
yield ("custom", {"type": "token", "text": "I built TextOps."})
|
| 68 |
+
yield ("updates", {"generate": {"answer": "I built TextOps.", "sources": []}})
|
| 69 |
+
else:
|
| 70 |
+
# Fallback for any code that still calls astream without stream_mode.
|
| 71 |
+
yield {"guard": {"guard_passed": True}}
|
| 72 |
+
yield {"cache": {"cached": False}}
|
| 73 |
+
yield {"generate": {"answer": "I built TextOps.", "sources": []}}
|
| 74 |
|
| 75 |
mock_pipeline.astream = fake_astream
|
| 76 |
|