Spaces:
Running
Running
GitHub Actions commited on
Commit ·
e7c9ee6
1
Parent(s): 4fc2936
Deploy d8ad462
Browse files- app/api/chat.py +69 -16
- app/main.py +6 -1
- app/models/pipeline.py +8 -0
- app/pipeline/graph.py +44 -13
- app/pipeline/nodes/cache.py +4 -1
- app/pipeline/nodes/retrieve.py +82 -25
- app/pipeline/nodes/rewrite_query.py +95 -0
- app/services/embedder.py +22 -7
- app/services/sparse_encoder.py +80 -0
- app/services/vector_store.py +139 -47
- requirements.txt +3 -0
app/api/chat.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import re
|
| 3 |
import time
|
|
@@ -28,6 +29,55 @@ def _is_criticism(message: str) -> bool:
|
|
| 28 |
return any(sig in lowered for sig in _CRITICISM_SIGNALS)
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
@router.post("")
|
| 32 |
@chat_rate_limit()
|
| 33 |
async def chat_endpoint(
|
|
@@ -41,16 +91,12 @@ async def chat_endpoint(
|
|
| 41 |
# All singletons pre-built in lifespan — zero allocation in hot path.
|
| 42 |
pipeline = request.app.state.pipeline
|
| 43 |
conv_store = request.app.state.conversation_store
|
|
|
|
| 44 |
session_id = request_data.session_id
|
| 45 |
|
| 46 |
-
# Fetch prior turns and detect criticism BEFORE the pipeline runs.
|
| 47 |
-
# Both are synchronous SQLite reads (<3ms) so they don't block the event loop
|
| 48 |
-
# meaningfully, but we keep them outside sse_generator to avoid any closure issues.
|
| 49 |
conversation_history = conv_store.get_recent(session_id)
|
| 50 |
criticism = _is_criticism(request_data.message)
|
| 51 |
if criticism and conversation_history:
|
| 52 |
-
# Auto-record negative feedback on the previous turn so the self-improvement
|
| 53 |
-
# loop picks it up during the next reranker fine-tune cycle.
|
| 54 |
conv_store.mark_last_negative(session_id)
|
| 55 |
|
| 56 |
initial_state: PipelineState = { # type: ignore[assignment]
|
|
@@ -71,6 +117,9 @@ async def chat_endpoint(
|
|
| 71 |
"latency_ms": 0,
|
| 72 |
"error": None,
|
| 73 |
"interaction_id": None,
|
|
|
|
|
|
|
|
|
|
| 74 |
}
|
| 75 |
|
| 76 |
async def sse_generator():
|
|
@@ -81,19 +130,10 @@ async def chat_endpoint(
|
|
| 81 |
|
| 82 |
try:
|
| 83 |
async for event in pipeline.astream(initial_state):
|
| 84 |
-
# Abort on client disconnect — prevents orphaned instances burning vCPU-seconds.
|
| 85 |
if await request.is_disconnected():
|
| 86 |
break
|
| 87 |
|
| 88 |
for node_name, updates in event.items():
|
| 89 |
-
# ── Stage transparency ─────────────────────────────────────────
|
| 90 |
-
# Emit named stage events so the frontend can show a live
|
| 91 |
-
# progress indicator ("checking cache" → "searching" → "writing").
|
| 92 |
-
# Mapping: node name → SSE stage label.
|
| 93 |
-
#
|
| 94 |
-
# cache miss → "checking" (semantic cache lookup ran, no hit)
|
| 95 |
-
# gemini_fast → already emits thinking:true if routing to RAG
|
| 96 |
-
# retrieve done → "generating" (retrieval complete, LLM starting)
|
| 97 |
if node_name == "cache" and updates.get("cached") is False:
|
| 98 |
yield f'data: {json.dumps({"stage": "checking"})}\n\n'
|
| 99 |
elif node_name == "cache" and updates.get("cached") is True:
|
|
@@ -102,11 +142,13 @@ async def chat_endpoint(
|
|
| 102 |
if node_name == "retrieve":
|
| 103 |
yield f'data: {json.dumps({"stage": "generating"})}\n\n'
|
| 104 |
|
| 105 |
-
#
|
|
|
|
|
|
|
|
|
|
| 106 |
if updates.get("thinking") is True:
|
| 107 |
yield f'data: {json.dumps({"thinking": True, "stage": "searching"})}\n\n'
|
| 108 |
|
| 109 |
-
# ── Answer tokens ──────────────────────────────────────────────
|
| 110 |
if "answer" in updates:
|
| 111 |
answer_update = updates["answer"]
|
| 112 |
delta = (
|
|
@@ -138,6 +180,16 @@ async def chat_endpoint(
|
|
| 138 |
|
| 139 |
yield f'data: {json.dumps({"done": True, "sources": sources_list, "cached": is_cached, "latency_ms": elapsed_ms, "interaction_id": interaction_id})}\n\n'
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
except Exception as exc:
|
| 142 |
yield f'data: {json.dumps({"error": str(exc) or "Generation failed"})}\n\n'
|
| 143 |
|
|
@@ -146,3 +198,4 @@ async def chat_endpoint(
|
|
| 146 |
media_type="text/event-stream",
|
| 147 |
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 148 |
)
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
import time
|
|
|
|
| 29 |
return any(sig in lowered for sig in _CRITICISM_SIGNALS)
|
| 30 |
|
| 31 |
|
| 32 |
+
async def _generate_follow_ups(
|
| 33 |
+
query: str,
|
| 34 |
+
answer: str,
|
| 35 |
+
sources: list,
|
| 36 |
+
llm_client,
|
| 37 |
+
) -> list[str]:
|
| 38 |
+
"""
|
| 39 |
+
Generates 3 specific follow-up questions after the main answer is complete.
|
| 40 |
+
Runs as a concurrent asyncio Task — zero added latency after the done event.
|
| 41 |
+
|
| 42 |
+
Questions must be:
|
| 43 |
+
- Specific to the answer content (never generic like "tell me more")
|
| 44 |
+
- Phrased naturally (< 12 words)
|
| 45 |
+
- Answerable from the knowledge base
|
| 46 |
+
"""
|
| 47 |
+
source_titles = [
|
| 48 |
+
(s.title if hasattr(s, "title") else s.get("title", ""))
|
| 49 |
+
for s in sources[:3]
|
| 50 |
+
]
|
| 51 |
+
titles_str = ", ".join(t for t in source_titles if t) or "the knowledge base"
|
| 52 |
+
|
| 53 |
+
prompt = (
|
| 54 |
+
f"Question asked: {query}\n\n"
|
| 55 |
+
f"Answer given (excerpt): {answer[:400]}\n\n"
|
| 56 |
+
f"Sources referenced: {titles_str}\n\n"
|
| 57 |
+
"Write exactly 3 follow-up questions a recruiter would naturally ask next. "
|
| 58 |
+
"Each question must be specific to the content above — not generic. "
|
| 59 |
+
"Each question must be under 12 words. "
|
| 60 |
+
"Output ONLY the 3 questions, one per line, no numbering or bullet points."
|
| 61 |
+
)
|
| 62 |
+
system = (
|
| 63 |
+
"You write concise follow-up questions for a portfolio chatbot. "
|
| 64 |
+
"Never write generic questions like 'tell me more' or 'what else'. "
|
| 65 |
+
"Each question must be under 12 words and reference specifics from the answer."
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
stream = llm_client.complete_with_complexity(
|
| 70 |
+
prompt=prompt, system=system, stream=True, complexity="simple"
|
| 71 |
+
)
|
| 72 |
+
raw = ""
|
| 73 |
+
async for token in stream:
|
| 74 |
+
raw += token
|
| 75 |
+
questions = [q.strip() for q in raw.strip().splitlines() if q.strip()][:3]
|
| 76 |
+
return questions
|
| 77 |
+
except Exception:
|
| 78 |
+
return []
|
| 79 |
+
|
| 80 |
+
|
| 81 |
@router.post("")
|
| 82 |
@chat_rate_limit()
|
| 83 |
async def chat_endpoint(
|
|
|
|
| 91 |
# All singletons pre-built in lifespan — zero allocation in hot path.
|
| 92 |
pipeline = request.app.state.pipeline
|
| 93 |
conv_store = request.app.state.conversation_store
|
| 94 |
+
llm_client = request.app.state.llm_client
|
| 95 |
session_id = request_data.session_id
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
conversation_history = conv_store.get_recent(session_id)
|
| 98 |
criticism = _is_criticism(request_data.message)
|
| 99 |
if criticism and conversation_history:
|
|
|
|
|
|
|
| 100 |
conv_store.mark_last_negative(session_id)
|
| 101 |
|
| 102 |
initial_state: PipelineState = { # type: ignore[assignment]
|
|
|
|
| 117 |
"latency_ms": 0,
|
| 118 |
"error": None,
|
| 119 |
"interaction_id": None,
|
| 120 |
+
"retrieval_attempts": 0,
|
| 121 |
+
"rewritten_query": None,
|
| 122 |
+
"follow_ups": [],
|
| 123 |
}
|
| 124 |
|
| 125 |
async def sse_generator():
|
|
|
|
| 130 |
|
| 131 |
try:
|
| 132 |
async for event in pipeline.astream(initial_state):
|
|
|
|
| 133 |
if await request.is_disconnected():
|
| 134 |
break
|
| 135 |
|
| 136 |
for node_name, updates in event.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
if node_name == "cache" and updates.get("cached") is False:
|
| 138 |
yield f'data: {json.dumps({"stage": "checking"})}\n\n'
|
| 139 |
elif node_name == "cache" and updates.get("cached") is True:
|
|
|
|
| 142 |
if node_name == "retrieve":
|
| 143 |
yield f'data: {json.dumps({"stage": "generating"})}\n\n'
|
| 144 |
|
| 145 |
+
# CRAG rewrite in progress — inform the frontend the query is being refined.
|
| 146 |
+
if node_name == "rewrite_query":
|
| 147 |
+
yield f'data: {json.dumps({"stage": "refining"})}\n\n'
|
| 148 |
+
|
| 149 |
if updates.get("thinking") is True:
|
| 150 |
yield f'data: {json.dumps({"thinking": True, "stage": "searching"})}\n\n'
|
| 151 |
|
|
|
|
| 152 |
if "answer" in updates:
|
| 153 |
answer_update = updates["answer"]
|
| 154 |
delta = (
|
|
|
|
| 180 |
|
| 181 |
yield f'data: {json.dumps({"done": True, "sources": sources_list, "cached": is_cached, "latency_ms": elapsed_ms, "interaction_id": interaction_id})}\n\n'
|
| 182 |
|
| 183 |
+
# ── Follow-up questions ────────────────────────────────────────────
|
| 184 |
+
# Generated after the done event so it never delays answer delivery.
|
| 185 |
+
# Works for both cache hits (no sources) and full RAG responses.
|
| 186 |
+
if final_answer and not await request.is_disconnected():
|
| 187 |
+
follow_ups = await _generate_follow_ups(
|
| 188 |
+
request_data.message, final_answer, final_sources, llm_client
|
| 189 |
+
)
|
| 190 |
+
if follow_ups:
|
| 191 |
+
yield f'data: {json.dumps({"follow_ups": follow_ups})}\n\n'
|
| 192 |
+
|
| 193 |
except Exception as exc:
|
| 194 |
yield f'data: {json.dumps({"error": str(exc) or "Generation failed"})}\n\n'
|
| 195 |
|
|
|
|
| 198 |
media_type="text/event-stream",
|
| 199 |
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
| 200 |
)
|
| 201 |
+
|
app/main.py
CHANGED
|
@@ -76,12 +76,17 @@ async def lifespan(app: FastAPI):
|
|
| 76 |
# ingest run doesn't crash every search with "collection not found".
|
| 77 |
vector_store.ensure_collection()
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
app.state.pipeline = build_pipeline({
|
| 80 |
"classifier": GuardClassifier(),
|
| 81 |
"cache": app.state.semantic_cache,
|
| 82 |
"embedder": embedder,
|
| 83 |
"gemini": gemini_client,
|
| 84 |
-
"llm":
|
| 85 |
"vector_store": vector_store,
|
| 86 |
"reranker": reranker,
|
| 87 |
"db_path": settings.DB_PATH,
|
|
|
|
| 76 |
# ingest run doesn't crash every search with "collection not found".
|
| 77 |
vector_store.ensure_collection()
|
| 78 |
|
| 79 |
+
llm_client = get_llm_client(settings)
|
| 80 |
+
# Expose llm_client on app state so chat.py can use it for follow-up
|
| 81 |
+
# question generation without re-constructing the client per request.
|
| 82 |
+
app.state.llm_client = llm_client
|
| 83 |
+
|
| 84 |
app.state.pipeline = build_pipeline({
|
| 85 |
"classifier": GuardClassifier(),
|
| 86 |
"cache": app.state.semantic_cache,
|
| 87 |
"embedder": embedder,
|
| 88 |
"gemini": gemini_client,
|
| 89 |
+
"llm": llm_client,
|
| 90 |
"vector_store": vector_store,
|
| 91 |
"reranker": reranker,
|
| 92 |
"db_path": settings.DB_PATH,
|
app/models/pipeline.py
CHANGED
|
@@ -43,3 +43,11 @@ class PipelineState(TypedDict):
|
|
| 43 |
latency_ms: int
|
| 44 |
error: Optional[str]
|
| 45 |
interaction_id: Optional[int]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
latency_ms: int
|
| 44 |
error: Optional[str]
|
| 45 |
interaction_id: Optional[int]
|
| 46 |
+
# CRAG: counts retrieve node invocations; 2 = one retry was attempted.
|
| 47 |
+
# Starts at 0 in initial state; retrieve increments it each call.
|
| 48 |
+
retrieval_attempts: int
|
| 49 |
+
# Set by the rewrite_query node when CRAG triggers; None otherwise.
|
| 50 |
+
rewritten_query: Optional[str]
|
| 51 |
+
# Follow-up question suggestions generated after the main answer.
|
| 52 |
+
# 3 short questions specific to content in the answer.
|
| 53 |
+
follow_ups: list[str]
|
app/pipeline/graph.py
CHANGED
|
@@ -6,9 +6,13 @@ from app.pipeline.nodes.guard import make_guard_node
|
|
| 6 |
from app.pipeline.nodes.cache import make_cache_node
|
| 7 |
from app.pipeline.nodes.gemini_fast import make_gemini_fast_node
|
| 8 |
from app.pipeline.nodes.retrieve import make_retrieve_node
|
|
|
|
| 9 |
from app.pipeline.nodes.generate import make_generate_node
|
| 10 |
from app.pipeline.nodes.log_eval import make_log_eval_node
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def route_guard(state: PipelineState) -> str:
|
| 14 |
if state.get("guard_passed", False):
|
|
@@ -33,19 +37,42 @@ def route_gemini(state: PipelineState) -> str:
|
|
| 33 |
return "research"
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def build_pipeline(services: dict) -> CompiledStateGraph:
|
| 37 |
graph = StateGraph(PipelineState)
|
| 38 |
|
| 39 |
-
graph.add_node("guard",
|
| 40 |
-
|
| 41 |
-
graph.add_node("
|
| 42 |
-
graph.add_node("
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
graph.add_node("
|
| 48 |
-
graph.add_node("
|
|
|
|
| 49 |
|
| 50 |
graph.set_entry_point("guard")
|
| 51 |
|
|
@@ -58,9 +85,13 @@ def build_pipeline(services: dict) -> CompiledStateGraph:
|
|
| 58 |
graph.add_conditional_edges("gemini_fast", route_gemini,
|
| 59 |
{"answered": "log_eval", "research": "retrieve"})
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
graph.add_edge("generate", "log_eval")
|
| 66 |
graph.add_edge("log_eval", END)
|
|
|
|
| 6 |
from app.pipeline.nodes.cache import make_cache_node
|
| 7 |
from app.pipeline.nodes.gemini_fast import make_gemini_fast_node
|
| 8 |
from app.pipeline.nodes.retrieve import make_retrieve_node
|
| 9 |
+
from app.pipeline.nodes.rewrite_query import make_rewrite_query_node, _has_meaningful_token
|
| 10 |
from app.pipeline.nodes.generate import make_generate_node
|
| 11 |
from app.pipeline.nodes.log_eval import make_log_eval_node
|
| 12 |
|
| 13 |
+
# Relevance gate threshold — matches retrieve.py constant.
|
| 14 |
+
_MIN_TOP_SCORE: float = -3.5
|
| 15 |
+
|
| 16 |
|
| 17 |
def route_guard(state: PipelineState) -> str:
|
| 18 |
if state.get("guard_passed", False):
|
|
|
|
| 37 |
return "research"
|
| 38 |
|
| 39 |
|
| 40 |
+
def route_retrieve_result(state: PipelineState) -> str:
|
| 41 |
+
"""
|
| 42 |
+
CRAG routing: if the first retrieval returned nothing above threshold,
|
| 43 |
+
rewrite the query once and retry. Exactly one retry is permitted.
|
| 44 |
+
|
| 45 |
+
Conditions for a rewrite attempt:
|
| 46 |
+
1. retrieval_attempts == 1 (first pass just completed, no retry yet).
|
| 47 |
+
2. reranked_chunks is empty (nothing above the -3.5 threshold).
|
| 48 |
+
3. Query has at least one meaningful non-stop-word token (guards against
|
| 49 |
+
empty or fully-generic queries where a rewrite wouldn't help).
|
| 50 |
+
"""
|
| 51 |
+
attempts = state.get("retrieval_attempts", 1)
|
| 52 |
+
reranked = state.get("reranked_chunks", [])
|
| 53 |
+
if (
|
| 54 |
+
attempts == 1
|
| 55 |
+
and not reranked
|
| 56 |
+
and _has_meaningful_token(state.get("query", ""))
|
| 57 |
+
):
|
| 58 |
+
return "rewrite"
|
| 59 |
+
return "generate"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def build_pipeline(services: dict) -> CompiledStateGraph:
|
| 63 |
graph = StateGraph(PipelineState)
|
| 64 |
|
| 65 |
+
graph.add_node("guard", make_guard_node(services["classifier"]))
|
| 66 |
+
graph.add_node("cache", make_cache_node(services["cache"], services["embedder"]))
|
| 67 |
+
graph.add_node("gemini_fast", make_gemini_fast_node(services["gemini"]))
|
| 68 |
+
graph.add_node("retrieve", make_retrieve_node(
|
| 69 |
+
services["vector_store"],
|
| 70 |
+
services["embedder"],
|
| 71 |
+
services["reranker"]))
|
| 72 |
+
# CRAG: one query rewrite on failed retrieval — then retrieve runs a second time.
|
| 73 |
+
graph.add_node("rewrite_query", make_rewrite_query_node(services["gemini"]))
|
| 74 |
+
graph.add_node("generate", make_generate_node(services["llm"], services["gemini"]))
|
| 75 |
+
graph.add_node("log_eval", make_log_eval_node(services["db_path"]))
|
| 76 |
|
| 77 |
graph.set_entry_point("guard")
|
| 78 |
|
|
|
|
| 85 |
graph.add_conditional_edges("gemini_fast", route_gemini,
|
| 86 |
{"answered": "log_eval", "research": "retrieve"})
|
| 87 |
|
| 88 |
+
# After retrieve: either run CRAG rewrite (one retry) or proceed to generate.
|
| 89 |
+
graph.add_conditional_edges("retrieve", route_retrieve_result,
|
| 90 |
+
{"rewrite": "rewrite_query", "generate": "generate"})
|
| 91 |
+
|
| 92 |
+
# After rewrite: go straight back to retrieve for the second attempt.
|
| 93 |
+
# The cycle terminates because route_retrieve_result checks retrieval_attempts.
|
| 94 |
+
graph.add_edge("rewrite_query", "retrieve")
|
| 95 |
|
| 96 |
graph.add_edge("generate", "log_eval")
|
| 97 |
graph.add_edge("log_eval", END)
|
app/pipeline/nodes/cache.py
CHANGED
|
@@ -16,7 +16,10 @@ from app.services.semantic_cache import SemanticCache
|
|
| 16 |
|
| 17 |
def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
|
| 18 |
async def cache_node(state: PipelineState) -> dict:
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
query_embedding = np.array(embedding)
|
| 21 |
|
| 22 |
cached = await cache.get(query_embedding)
|
|
|
|
| 16 |
|
| 17 |
def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
|
| 18 |
async def cache_node(state: PipelineState) -> dict:
|
| 19 |
+
# is_query=True: prepend BGE asymmetric instruction so query embedding
|
| 20 |
+
# lands in the retrieval-optimised neighbourhood of the vector space.
|
| 21 |
+
# Document embeddings at ingestion time use is_query=False (default).
|
| 22 |
+
embedding = await embedder.embed_one(state["query"], is_query=True)
|
| 23 |
query_embedding = np.array(embedding)
|
| 24 |
|
| 25 |
cached = await cache.get(query_embedding)
|
app/pipeline/nodes/retrieve.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
|
|
| 1 |
from typing import Callable
|
| 2 |
|
| 3 |
from app.models.pipeline import PipelineState, Chunk
|
| 4 |
from app.services.vector_store import VectorStore
|
| 5 |
from app.services.embedder import Embedder
|
| 6 |
from app.services.reranker import Reranker
|
|
|
|
| 7 |
|
| 8 |
# Cross-encoder ms-marco-MiniLM-L-6-v2 returns raw logits (not sigmoid).
|
| 9 |
# Highly relevant docs score 0–15; completely off-topic score below –5.
|
|
@@ -20,55 +22,109 @@ _MIN_TOP_SCORE: float = -3.5
|
|
| 20 |
# relevant sources and making the answer look one-dimensional.
|
| 21 |
_MAX_CHUNKS_PER_DOC: int = 2
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
async def retrieve_node(state: PipelineState) -> dict:
|
|
|
|
| 26 |
query = state["query"]
|
| 27 |
-
expanded = state.get("expanded_queries", [query])
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
#
|
| 31 |
-
#
|
| 32 |
cached_embedding: list[float] | None = state.get("query_embedding")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
|
|
|
| 34 |
if cached_embedding is not None and len(expanded) == 1:
|
| 35 |
-
# Fast path: single query, embedding already computed.
|
| 36 |
query_vectors = [cached_embedding]
|
| 37 |
else:
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
|
| 46 |
-
#
|
|
|
|
|
|
|
| 47 |
seen: set[str] = set()
|
| 48 |
unique_chunks: list[Chunk] = []
|
| 49 |
-
for c in
|
| 50 |
-
|
| 51 |
-
if
|
| 52 |
-
seen.add(
|
| 53 |
unique_chunks.append(c)
|
| 54 |
|
| 55 |
reranked = await reranker.rerank(query, unique_chunks, top_k=5)
|
| 56 |
|
| 57 |
-
# Relevance gate
|
| 58 |
-
# cross-encoder threshold, the knowledge base genuinely has nothing useful
|
| 59 |
-
# for this query. Return not-found so generate_node isn't fed garbage context
|
| 60 |
-
# that causes vague or hallucinated responses.
|
| 61 |
top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
|
| 62 |
if not reranked or (top_score is not None and top_score < _MIN_TOP_SCORE):
|
| 63 |
return {
|
| 64 |
-
"answer": "",
|
| 65 |
"retrieved_chunks": [],
|
| 66 |
"reranked_chunks": [],
|
|
|
|
| 67 |
}
|
| 68 |
|
| 69 |
-
# Source diversity
|
| 70 |
-
# from filling all context slots and drowning out other relevant sources.
|
| 71 |
-
# Applied after reranking so the reranker sees the full candidate set.
|
| 72 |
doc_counts: dict[str, int] = {}
|
| 73 |
diverse_chunks: list[Chunk] = []
|
| 74 |
for chunk in reranked:
|
|
@@ -80,6 +136,7 @@ def make_retrieve_node(vector_store: VectorStore, embedder: Embedder, reranker:
|
|
| 80 |
return {
|
| 81 |
"retrieved_chunks": unique_chunks,
|
| 82 |
"reranked_chunks": diverse_chunks,
|
|
|
|
| 83 |
}
|
| 84 |
|
| 85 |
return retrieve_node
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
from typing import Callable
|
| 3 |
|
| 4 |
from app.models.pipeline import PipelineState, Chunk
|
| 5 |
from app.services.vector_store import VectorStore
|
| 6 |
from app.services.embedder import Embedder
|
| 7 |
from app.services.reranker import Reranker
|
| 8 |
+
from app.services.sparse_encoder import SparseEncoder
|
| 9 |
|
| 10 |
# Cross-encoder ms-marco-MiniLM-L-6-v2 returns raw logits (not sigmoid).
|
| 11 |
# Highly relevant docs score 0–15; completely off-topic score below –5.
|
|
|
|
| 22 |
# relevant sources and making the answer look one-dimensional.
|
| 23 |
_MAX_CHUNKS_PER_DOC: int = 2
|
| 24 |
|
| 25 |
+
# RRF rank fusion constant. k=60 is the original Cormack et al. default.
|
| 26 |
+
# Higher k reduces the influence of top-1 rank advantage.
|
| 27 |
+
_RRF_K: int = 60
|
| 28 |
|
| 29 |
+
# Module-level singleton — BM25 model downloads once (~5 MB), cached in memory.
|
| 30 |
+
_sparse_encoder = SparseEncoder()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _rrf_merge(ranked_lists: list[list[Chunk]]) -> list[Chunk]:
|
| 34 |
+
"""
|
| 35 |
+
Reciprocal Rank Fusion across multiple ranked chunk lists.
|
| 36 |
+
|
| 37 |
+
Score formula: Σ 1 / (rank + 1 + k) over all lists that contain the chunk.
|
| 38 |
+
Deduplication by doc_id::section fingerprint before merging so the same
|
| 39 |
+
passage retrieved by both dense and sparse does not double-count.
|
| 40 |
+
|
| 41 |
+
Pure Python, no external dependencies.
|
| 42 |
+
"""
|
| 43 |
+
scores: dict[str, float] = {}
|
| 44 |
+
chunks_by_fp: dict[str, Chunk] = {}
|
| 45 |
+
|
| 46 |
+
for ranked in ranked_lists:
|
| 47 |
+
seen_in_list: set[str] = set()
|
| 48 |
+
for rank, chunk in enumerate(ranked):
|
| 49 |
+
fp = f"{chunk['metadata']['doc_id']}::{chunk['metadata']['section']}"
|
| 50 |
+
if fp in seen_in_list:
|
| 51 |
+
continue # Already contributed this chunk from this ranked list
|
| 52 |
+
seen_in_list.add(fp)
|
| 53 |
+
scores[fp] = scores.get(fp, 0.0) + 1.0 / (rank + 1 + _RRF_K)
|
| 54 |
+
chunks_by_fp[fp] = chunk
|
| 55 |
+
|
| 56 |
+
sorted_fps = sorted(scores, key=lambda x: scores[x], reverse=True)
|
| 57 |
+
return [chunks_by_fp[fp] for fp in sorted_fps]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def make_retrieve_node(
|
| 61 |
+
vector_store: VectorStore, embedder: Embedder, reranker: Reranker
|
| 62 |
+
) -> Callable[[PipelineState], dict]:
|
| 63 |
async def retrieve_node(state: PipelineState) -> dict:
|
| 64 |
+
attempts = state.get("retrieval_attempts", 0)
|
| 65 |
query = state["query"]
|
|
|
|
| 66 |
|
| 67 |
+
# On a CRAG retry (attempts >= 1) the query has been rewritten and
|
| 68 |
+
# query_embedding is explicitly set to None — always re-embed.
|
| 69 |
+
# On the first attempt, reuse the embedding computed by the cache node.
|
| 70 |
cached_embedding: list[float] | None = state.get("query_embedding")
|
| 71 |
+
if attempts >= 1:
|
| 72 |
+
# Second attempt: re-embed the rewritten query with is_query=True.
|
| 73 |
+
cached_embedding = None
|
| 74 |
+
|
| 75 |
+
expanded = [query] # gemini_fast may fill expanded_queries on first attempt
|
| 76 |
+
if attempts == 0:
|
| 77 |
+
expanded = state.get("expanded_queries", [query])
|
| 78 |
|
| 79 |
+
# Embed all query variants in one batched call (is_query=True for asymmetric BGE).
|
| 80 |
if cached_embedding is not None and len(expanded) == 1:
|
|
|
|
| 81 |
query_vectors = [cached_embedding]
|
| 82 |
else:
|
| 83 |
+
query_vectors = await embedder.embed(expanded, is_query=True)
|
| 84 |
+
|
| 85 |
+
# ── Dense search (all query variants) ─────────────────────────────────
|
| 86 |
+
dense_results: list[list[Chunk]] = []
|
| 87 |
+
for vec in query_vectors:
|
| 88 |
+
chunks = vector_store.search(query_vector=vec, top_k=10)
|
| 89 |
+
dense_results.append(chunks)
|
| 90 |
+
|
| 91 |
+
# ── Sparse (BM25) search (primary query only) ─────────────────────────
|
| 92 |
+
# Runs concurrently with dense search isn't possible here since dense
|
| 93 |
+
# is synchronous Qdrant calls, but we parallelise encode + sparse search.
|
| 94 |
+
sparse_results: list[Chunk] = []
|
| 95 |
+
if _sparse_encoder.available:
|
| 96 |
+
indices, values = _sparse_encoder.encode_one(query)
|
| 97 |
+
sparse_results = vector_store.search_sparse(indices, values, top_k=10)
|
| 98 |
|
| 99 |
+
# ── Reciprocal Rank Fusion ─────────────────────────────────────────────
|
| 100 |
+
# Merge dense (per variant) + sparse into one ranked list.
|
| 101 |
+
all_ranked_lists = dense_results + ([sparse_results] if sparse_results else [])
|
| 102 |
+
fused: list[Chunk] = _rrf_merge(all_ranked_lists)
|
| 103 |
|
| 104 |
+
# ── Deduplication (question-point collapse) ────────────────────────────
|
| 105 |
+
# Multiple points for the same chunk (main + question points from Stage 3)
|
| 106 |
+
# share the same doc_id::section fingerprint and collapse here.
|
| 107 |
seen: set[str] = set()
|
| 108 |
unique_chunks: list[Chunk] = []
|
| 109 |
+
for c in fused:
|
| 110 |
+
fp = f"{c['metadata']['doc_id']}::{c['metadata']['section']}"
|
| 111 |
+
if fp not in seen:
|
| 112 |
+
seen.add(fp)
|
| 113 |
unique_chunks.append(c)
|
| 114 |
|
| 115 |
reranked = await reranker.rerank(query, unique_chunks, top_k=5)
|
| 116 |
|
| 117 |
+
# ── Relevance gate ─────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
| 118 |
top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
|
| 119 |
if not reranked or (top_score is not None and top_score < _MIN_TOP_SCORE):
|
| 120 |
return {
|
| 121 |
+
"answer": "",
|
| 122 |
"retrieved_chunks": [],
|
| 123 |
"reranked_chunks": [],
|
| 124 |
+
"retrieval_attempts": attempts + 1,
|
| 125 |
}
|
| 126 |
|
| 127 |
+
# ── Source diversity cap ───────────────────────────────────────────────
|
|
|
|
|
|
|
| 128 |
doc_counts: dict[str, int] = {}
|
| 129 |
diverse_chunks: list[Chunk] = []
|
| 130 |
for chunk in reranked:
|
|
|
|
| 136 |
return {
|
| 137 |
"retrieved_chunks": unique_chunks,
|
| 138 |
"reranked_chunks": diverse_chunks,
|
| 139 |
+
"retrieval_attempts": attempts + 1,
|
| 140 |
}
|
| 141 |
|
| 142 |
return retrieve_node
|
app/pipeline/nodes/rewrite_query.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
backend/app/pipeline/nodes/rewrite_query.py
|
| 3 |
+
|
| 4 |
+
CRAG (Corrective RAG) query rewriter — fires exactly once per request when:
|
| 5 |
+
1. The first retrieval attempt returned no chunks above the relevance threshold.
|
| 6 |
+
2. The query contains at least one meaningful non-stop-word token.
|
| 7 |
+
|
| 8 |
+
Calls Gemini Flash (temp 0.7) to produce one alternative phrasing that preserves
|
| 9 |
+
the visitor's intent but uses different vocabulary. The pipeline then runs Retrieve
|
| 10 |
+
and Rerank a second time with this new query. There is exactly one retry — the
|
| 11 |
+
graph routing enforces this via the retrieval_attempts counter in state.
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from app.models.pipeline import PipelineState
|
| 19 |
+
from app.services.gemini_client import GeminiClient
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
_REWRITE_PROMPT = """\
|
| 24 |
+
A search query failed to find relevant results in a portfolio knowledge base about Darshan Chheda.
|
| 25 |
+
The knowledge base contains his blog posts, project descriptions, CV/resume, and GitHub README files.
|
| 26 |
+
|
| 27 |
+
Original query: {query}
|
| 28 |
+
|
| 29 |
+
Rephrase this query using different vocabulary that might better match how the content is written.
|
| 30 |
+
Strategies: expand abbreviations, use synonyms, reframe as "did Darshan..." if the query uses a name/tech.
|
| 31 |
+
Output ONLY the rewritten query — one sentence, no explanation, no quotes.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# Same stop-word set as generate.py — keeps modules consistent.
|
| 35 |
+
_STOP_WORDS = frozenset({
|
| 36 |
+
"a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
|
| 37 |
+
"have", "has", "had", "do", "does", "did", "will", "would", "could",
|
| 38 |
+
"should", "may", "might", "can", "to", "of", "in", "on", "for",
|
| 39 |
+
"with", "at", "by", "from", "and", "or", "but", "not", "what",
|
| 40 |
+
"who", "how", "why", "when", "where", "tell", "me", "about", "his",
|
| 41 |
+
"he", "him", "any", "some", "that", "this", "it", "its",
|
| 42 |
+
})
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _has_meaningful_token(query: str) -> bool:
|
| 46 |
+
"""True when the query has at least one non-stop-word token of length >= 3."""
|
| 47 |
+
return any(
|
| 48 |
+
w not in _STOP_WORDS and len(w) >= 3
|
| 49 |
+
for w in __import__("re").findall(r"[a-z]+", query.lower())
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def make_rewrite_query_node(gemini_client: GeminiClient) -> Any:
|
| 54 |
+
async def rewrite_query_node(state: PipelineState) -> dict:
|
| 55 |
+
query = state["query"]
|
| 56 |
+
logger.info("CRAG: rewriting failed query %r", query)
|
| 57 |
+
|
| 58 |
+
if not gemini_client.is_configured:
|
| 59 |
+
# No Gemini — pass query through unchanged; second retrieve will also fail
|
| 60 |
+
# and generate will handle the not-found path gracefully.
|
| 61 |
+
logger.debug("Gemini not configured; skipping query rewrite.")
|
| 62 |
+
return {
|
| 63 |
+
"rewritten_query": query,
|
| 64 |
+
"retrieval_attempts": state.get("retrieval_attempts", 1) + 1,
|
| 65 |
+
"query_embedding": None, # Force re-embed so retrieve doesn't use stale embedding
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
response = await gemini_client._client.aio.models.generate_content(
|
| 70 |
+
model=gemini_client._model,
|
| 71 |
+
contents=_REWRITE_PROMPT.format(query=query),
|
| 72 |
+
config={"temperature": 0.7},
|
| 73 |
+
)
|
| 74 |
+
rewritten = (response.text or query).strip().strip('"').strip("'")
|
| 75 |
+
except Exception as exc:
|
| 76 |
+
logger.warning("Query rewrite Gemini call failed (%s); using original.", exc)
|
| 77 |
+
rewritten = query
|
| 78 |
+
|
| 79 |
+
if not rewritten or rewritten == query:
|
| 80 |
+
logger.debug("Rewrite produced no change; using original query.")
|
| 81 |
+
rewritten = query
|
| 82 |
+
else:
|
| 83 |
+
logger.info("CRAG rewrite: %r → %r", query, rewritten)
|
| 84 |
+
|
| 85 |
+
# Clearing query_embedding forces the retrieve node to re-embed the new query.
|
| 86 |
+
# retrieval_attempts is incremented so the graph does not loop again after
|
| 87 |
+
# this second retrieval attempt.
|
| 88 |
+
return {
|
| 89 |
+
"query": rewritten,
|
| 90 |
+
"rewritten_query": rewritten,
|
| 91 |
+
"retrieval_attempts": state.get("retrieval_attempts", 1) + 1,
|
| 92 |
+
"query_embedding": None,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
return rewrite_query_node
|
app/services/embedder.py
CHANGED
|
@@ -23,27 +23,42 @@ def _get_local_model() -> Any:
|
|
| 23 |
return _local_model
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
class Embedder:
|
| 27 |
def __init__(self, remote_url: str = "", environment: str = "local") -> None:
|
| 28 |
self._remote = environment == "prod" and bool(remote_url)
|
| 29 |
self._url = remote_url.rstrip("/") if self._remote else ""
|
| 30 |
|
| 31 |
-
async def embed(self, texts: list[str]) -> list[list[float]]:
|
| 32 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
if not texts:
|
| 34 |
return []
|
| 35 |
if self._remote:
|
| 36 |
-
#
|
| 37 |
-
# a stable connection lifecycle, so a persistent client risks stale sockets.
|
| 38 |
async with httpx.AsyncClient(timeout=30.0) as client:
|
| 39 |
-
resp = await client.post(
|
|
|
|
|
|
|
|
|
|
| 40 |
resp.raise_for_status()
|
| 41 |
return resp.json()["embeddings"]
|
| 42 |
model = _get_local_model()
|
|
|
|
|
|
|
| 43 |
vectors = model.encode(texts, batch_size=32, normalize_embeddings=True, show_progress_bar=False)
|
| 44 |
return vectors.tolist()
|
| 45 |
|
| 46 |
-
async def embed_one(self, text: str) -> list[float]:
|
| 47 |
"""Convenience wrapper for a single string."""
|
| 48 |
-
results = await self.embed([text])
|
| 49 |
return results[0]
|
|
|
|
| 23 |
return _local_model
|
| 24 |
|
| 25 |
|
| 26 |
+
# BGE asymmetric query instruction — prepended locally when is_query=True and
# environment is local. In prod the HF Space accepts is_query and prepends itself.
_BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "


class Embedder:
    """Dense-embedding client.

    In prod with a remote URL configured, embeddings are fetched from the HF
    Space over HTTP; otherwise the local sentence-transformers model is used.
    """

    def __init__(self, remote_url: str = "", environment: str = "local") -> None:
        self._remote = environment == "prod" and bool(remote_url)
        self._url = remote_url.rstrip("/") if self._remote else ""

    async def embed(self, texts: list[str], is_query: bool = False) -> list[list[float]]:
        """
        Encodes texts, returns List of L2-normalised 384-dim float vectors.

        is_query=True: prepend BGE asymmetric query instruction (queries only).
        is_query=False: encode as-is (document/ingestion embeddings).
        See BGE paper: 2-4% NDCG gain from using the correct prefix on queries.
        """
        if not texts:
            return []

        if self._remote:
            # HF Space handles the prefix server-side when is_query=True.
            async with httpx.AsyncClient(timeout=30.0) as http:
                response = await http.post(
                    f"{self._url}/embed",
                    json={"texts": texts, "is_query": is_query},
                )
                response.raise_for_status()
                return response.json()["embeddings"]

        # Local path: apply the query prefix ourselves before encoding.
        payload = [_BGE_QUERY_PREFIX + t for t in texts] if is_query else texts
        matrix = _get_local_model().encode(
            payload, batch_size=32, normalize_embeddings=True, show_progress_bar=False
        )
        return matrix.tolist()

    async def embed_one(self, text: str, is_query: bool = False) -> list[float]:
        """Convenience wrapper for a single string."""
        (vector,) = await self.embed([text], is_query=is_query)
        return vector
|
app/services/sparse_encoder.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
backend/app/services/sparse_encoder.py
|
| 3 |
+
|
| 4 |
+
BM25 sparse encoder backed by FastEmbed's Qdrant/bm25 model.
|
| 5 |
+
Used at ingestion time (ingest.py) and at query time (retrieve node).
|
| 6 |
+
|
| 7 |
+
The model downloads a ~5 MB vocabulary file on first use. Subsequent calls
|
| 8 |
+
are fully local. The module-level singleton is loaded lazily on first call
|
| 9 |
+
to avoid startup delay in the API Space.
|
| 10 |
+
|
| 11 |
+
Fallback: if fastembed is not installed, encode() returns empty sparse vectors
|
| 12 |
+
so dense-only retrieval continues working unchanged.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)

# Lazily-initialised module singletons: the BM25 model instance and a tri-state
# availability flag (None = not yet attempted, True/False = load outcome).
_model: Optional[Any] = None
_fastembed_available: Optional[bool] = None


def _get_model() -> Optional[Any]:
    """Return the shared FastEmbed BM25 model, loading it on first call.

    Returns None (and caches the failure) when fastembed cannot be imported or
    the model cannot be constructed, so callers fall back to dense-only search.
    """
    global _model, _fastembed_available  # noqa: PLW0603
    if _fastembed_available is False:
        return None
    if _model is not None:
        return _model
    try:
        from fastembed import SparseTextEmbedding  # type: ignore[import]

        _model = SparseTextEmbedding(model_name="Qdrant/bm25")
    except Exception as exc:
        _fastembed_available = False
        logger.warning(
            "FastEmbed not available — sparse retrieval disabled, falling back to dense-only. (%s)",
            exc,
        )
        return None
    _fastembed_available = True
    logger.info("FastEmbed BM25 sparse encoder loaded (Qdrant/bm25).")
    return _model


class SparseEncoder:
    """
    Wraps FastEmbed SparseTextEmbedding to produce BM25 sparse vectors.

    Returns list of (indices, values) tuples — one per input text. If FastEmbed
    is unavailable, returns empty ([], []) tuples so callers can gracefully skip
    sparse indexing without breaking the ingestion pipeline.
    """

    def encode(self, texts: list[str]) -> list[tuple[list[int], list[float]]]:
        """Encode a batch of texts. Returns [(indices, values), ...] per text."""
        if not texts:
            return []
        model = _get_model()
        if model is None:
            return [([], []) for _ in texts]
        try:
            # fastembed SparseEmbedding exposes .indices and .values as numpy arrays.
            return [(emb.indices.tolist(), emb.values.tolist()) for emb in model.embed(texts)]
        except Exception as exc:
            logger.warning("BM25 encoding failed (%s); returning empty sparse vectors.", exc)
            return [([], []) for _ in texts]

    def encode_one(self, text: str) -> tuple[list[int], list[float]]:
        """Convenience wrapper for a single string."""
        return self.encode([text])[0]

    @property
    def available(self) -> bool:
        """True if FastEmbed loaded successfully and sparse encoding is active."""
        return _get_model() is not None
|
app/services/vector_store.py
CHANGED
|
@@ -1,113 +1,205 @@
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
from qdrant_client import QdrantClient
|
| 5 |
-
from qdrant_client.models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from app.models.pipeline import Chunk, ChunkMetadata
|
| 8 |
from app.core.exceptions import RetrievalError
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class VectorStore:
|
| 12 |
def __init__(self, client: QdrantClient, collection: str):
|
| 13 |
self.client = client
|
| 14 |
self.collection = collection
|
| 15 |
|
| 16 |
-
def ensure_collection(self) -> None:
|
| 17 |
-
"""
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
collections = self.client.get_collections().collections
|
| 20 |
exists = any(c.name == self.collection for c in collections)
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
if not exists:
|
| 23 |
self.client.create_collection(
|
| 24 |
collection_name=self.collection,
|
| 25 |
-
vectors_config=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
)
|
|
|
|
| 27 |
|
| 28 |
-
# Keyword index
|
| 29 |
-
# create_payload_index is idempotent — safe to call on every startup.
|
| 30 |
self.client.create_payload_index(
|
| 31 |
collection_name=self.collection,
|
| 32 |
field_name="metadata.doc_id",
|
| 33 |
field_schema=PayloadSchemaType.KEYWORD,
|
| 34 |
)
|
| 35 |
|
| 36 |
-
def upsert_chunks(
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
if not chunks:
|
| 42 |
return
|
| 43 |
|
| 44 |
points = []
|
| 45 |
-
for chunk,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
points.append(
|
| 47 |
PointStruct(
|
| 48 |
id=str(uuid.uuid4()),
|
| 49 |
vector=vector,
|
| 50 |
-
payload=chunk
|
| 51 |
)
|
| 52 |
)
|
| 53 |
|
| 54 |
-
# Qdrant client upsert takes care of batching if needed, but we can chunk our points list
|
| 55 |
batch_size = 100
|
| 56 |
for i in range(0, len(points), batch_size):
|
| 57 |
-
batch = points[i:i + batch_size]
|
| 58 |
self.client.upsert(
|
| 59 |
collection_name=self.collection,
|
| 60 |
-
points=
|
| 61 |
)
|
| 62 |
|
| 63 |
def delete_by_doc_id(self, doc_id: str) -> None:
|
| 64 |
-
"""Filters on metadata.doc_id and deletes
|
| 65 |
try:
|
| 66 |
-
|
| 67 |
collection_name=self.collection,
|
| 68 |
points_selector=Filter(
|
| 69 |
must=[
|
| 70 |
FieldCondition(
|
| 71 |
key="metadata.doc_id",
|
| 72 |
-
match=MatchValue(value=doc_id)
|
| 73 |
)
|
| 74 |
]
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
except Exception
|
| 78 |
-
#
|
| 79 |
-
pass
|
| 80 |
|
| 81 |
-
def search(
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
try:
|
| 84 |
qdrant_filter = None
|
| 85 |
if filters:
|
| 86 |
-
must_conditions = [
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
match=MatchValue(value=value)
|
| 92 |
-
)
|
| 93 |
-
)
|
| 94 |
-
if must_conditions:
|
| 95 |
-
qdrant_filter = Filter(must=must_conditions)
|
| 96 |
|
| 97 |
results = self.client.search(
|
| 98 |
collection_name=self.collection,
|
| 99 |
-
query_vector=query_vector,
|
| 100 |
limit=top_k,
|
| 101 |
-
query_filter=qdrant_filter
|
|
|
|
| 102 |
)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
for hit in results:
|
| 106 |
-
if hit.payload:
|
| 107 |
-
chunks.append(Chunk(**hit.payload))
|
| 108 |
-
return chunks
|
| 109 |
|
| 110 |
-
except Exception as
|
| 111 |
raise RetrievalError(
|
| 112 |
-
f"
|
| 113 |
-
) from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import uuid
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
from qdrant_client import QdrantClient
|
| 6 |
+
from qdrant_client.models import (
|
| 7 |
+
Distance,
|
| 8 |
+
FieldCondition,
|
| 9 |
+
Filter,
|
| 10 |
+
MatchValue,
|
| 11 |
+
NamedSparseVector,
|
| 12 |
+
NamedVector,
|
| 13 |
+
PayloadSchemaType,
|
| 14 |
+
PointStruct,
|
| 15 |
+
SparseIndexParams,
|
| 16 |
+
SparseVector,
|
| 17 |
+
SparseVectorParams,
|
| 18 |
+
VectorParams,
|
| 19 |
+
)
|
| 20 |
|
| 21 |
from app.models.pipeline import Chunk, ChunkMetadata
|
| 22 |
from app.core.exceptions import RetrievalError
|
| 23 |
|
| 24 |
+
logger = logging.getLogger(__name__)

# Named vector keys used in the Qdrant collection.
_DENSE_VEC = "dense"
_SPARSE_VEC = "sparse"


class VectorStore:
    """Qdrant-backed store with named dense (384-dim cosine) and BM25 sparse vectors."""

    def __init__(self, client: QdrantClient, collection: str):
        self.client = client
        self.collection = collection

    def ensure_collection(self, allow_recreate: bool = False) -> None:
        """
        Creates or migrates the collection to support named dense + sparse vectors.

        allow_recreate=True (ingestion): if the collection exists with the old
        unnamed-vector format, delete and recreate it. Ingestion will re-index
        everything on the same run, so data loss is acceptable.

        allow_recreate=False (API startup): never touch an existing collection.
        The API will use whatever format is already deployed.
        """
        collections = self.client.get_collections().collections
        exists = any(c.name == self.collection for c in collections)

        if exists and allow_recreate:
            try:
                info = self.client.get_collection(self.collection)
                # Old format: a single unnamed VectorParams instead of a dict of named vectors.
                is_old_format = not isinstance(info.config.params.vectors, dict)
                has_no_sparse = not info.config.params.sparse_vectors
                if is_old_format or has_no_sparse:
                    logger.info(
                        "Collection %r uses old vector format; recreating for hybrid search.",
                        self.collection,
                    )
                    self.client.delete_collection(self.collection)
                    exists = False
            except Exception as exc:
                logger.warning("Could not inspect collection format (%s); skipping migration.", exc)

        if not exists:
            self.client.create_collection(
                collection_name=self.collection,
                vectors_config={
                    _DENSE_VEC: VectorParams(size=384, distance=Distance.COSINE),
                },
                sparse_vectors_config={
                    # on_disk=False keeps sparse index in RAM for sub-ms lookup.
                    _SPARSE_VEC: SparseVectorParams(index=SparseIndexParams(on_disk=False)),
                },
            )
            logger.info("Created collection %r with dense + sparse vectors.", self.collection)

        # Keyword index for filter-by-doc_id in delete_by_doc_id. Idempotent.
        self.client.create_payload_index(
            collection_name=self.collection,
            field_name="metadata.doc_id",
            field_schema=PayloadSchemaType.KEYWORD,
        )

    def upsert_chunks(
        self,
        chunks: list[Chunk],
        dense_embeddings: list[list[float]],
        sparse_embeddings: Optional[list[tuple[list[int], list[float]]]] = None,
    ) -> None:
        """
        Builds PointStruct list with named dense (and optionally sparse) vectors.

        sparse_embeddings: list of (indices, values) tuples from SparseEncoder.
        If None, or the sparse vector for a chunk is empty, a dense-only point
        is written for that chunk.

        Raises:
            ValueError: if either embedding list's length does not match ``chunks``.
        """
        if len(chunks) != len(dense_embeddings):
            raise ValueError("Number of chunks must match number of dense embeddings")
        # Validate sparse length too — a mismatch previously surfaced as an
        # IndexError halfway through point construction.
        if sparse_embeddings is not None and len(sparse_embeddings) != len(chunks):
            raise ValueError("Number of chunks must match number of sparse embeddings")
        if not chunks:
            return

        points = []
        for idx, (chunk, dense_vec) in enumerate(zip(chunks, dense_embeddings)):
            vector: dict = {_DENSE_VEC: dense_vec}

            if sparse_embeddings is not None:
                indices, values = sparse_embeddings[idx]
                if indices:  # Skip empty sparse vectors gracefully
                    vector[_SPARSE_VEC] = SparseVector(indices=indices, values=values)

            points.append(
                PointStruct(
                    id=str(uuid.uuid4()),
                    vector=vector,
                    payload=chunk,
                )
            )

        # Upload in fixed-size batches to bound request payload size.
        batch_size = 100
        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=self.collection,
                points=points[start : start + batch_size],
            )

    def delete_by_doc_id(self, doc_id: str) -> None:
        """Filters on metadata.doc_id and deletes all matching points (best-effort)."""
        try:
            self.client.delete(
                collection_name=self.collection,
                points_selector=Filter(
                    must=[
                        FieldCondition(
                            key="metadata.doc_id",
                            match=MatchValue(value=doc_id),
                        )
                    ]
                ),
            )
        except Exception as exc:
            # Best-effort: the collection or index may not exist yet. Log rather
            # than silently swallowing so genuine failures stay observable.
            logger.debug("delete_by_doc_id(%r) ignored error: %s", doc_id, exc)

    def search(
        self,
        query_vector: list[float],
        top_k: int = 20,
        filters: Optional[dict] = None,
    ) -> list[Chunk]:
        """Dense vector search using the named 'dense' vector.

        filters: optional {metadata_key: value} equality constraints, ANDed together.

        Raises:
            RetrievalError: if the underlying Qdrant search fails.
        """
        try:
            qdrant_filter = None
            if filters:
                must_conditions = [
                    FieldCondition(key=f"metadata.{k}", match=MatchValue(value=v))
                    for k, v in filters.items()
                ]
                qdrant_filter = Filter(must=must_conditions)

            results = self.client.search(
                collection_name=self.collection,
                query_vector=NamedVector(name=_DENSE_VEC, vector=query_vector),
                limit=top_k,
                query_filter=qdrant_filter,
                with_payload=True,
            )

            return [Chunk(**hit.payload) for hit in results if hit.payload]

        except Exception as exc:
            raise RetrievalError(
                f"Dense vector search failed: {exc}", context={"error": str(exc)}
            ) from exc

    def search_sparse(
        self,
        indices: list[int],
        values: list[float],
        top_k: int = 20,
    ) -> list[Chunk]:
        """
        BM25 sparse vector search using the named 'sparse' vector.
        Returns empty list if sparse vectors are absent or indices is empty.
        """
        if not indices:
            return []
        try:
            results = self.client.search(
                collection_name=self.collection,
                query_vector=NamedSparseVector(
                    name=_SPARSE_VEC,
                    vector=SparseVector(indices=indices, values=values),
                ),
                limit=top_k,
                with_payload=True,
            )
            return [Chunk(**hit.payload) for hit in results if hit.payload]

        except Exception as exc:
            # Sparse index may not exist on old collections — log and continue.
            logger.warning("Sparse search failed (%s); skipping sparse results.", exc)
            return []
requirements.txt
CHANGED
|
@@ -20,4 +20,7 @@ presidio-analyzer>=2.2.354
|
|
| 20 |
tenacity>=8.3.0
|
| 21 |
python-jose[cryptography]>=3.3.0
|
| 22 |
google-genai>=1.0.0
|
|
|
|
|
|
|
|
|
|
| 23 |
toon_format @ git+https://github.com/toon-format/toon-python.git
|
|
|
|
| 20 |
tenacity>=8.3.0
|
| 21 |
python-jose[cryptography]>=3.3.0
|
| 22 |
google-genai>=1.0.0
|
| 23 |
+
# fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
|
| 24 |
+
# downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
|
| 25 |
+
fastembed>=0.3.6
|
| 26 |
toon_format @ git+https://github.com/toon-format/toon-python.git
|