GitHub Actions commited on
Commit
e7c9ee6
·
1 Parent(s): 4fc2936

Deploy d8ad462

Browse files
app/api/chat.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import re
3
  import time
@@ -28,6 +29,55 @@ def _is_criticism(message: str) -> bool:
28
  return any(sig in lowered for sig in _CRITICISM_SIGNALS)
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @router.post("")
32
  @chat_rate_limit()
33
  async def chat_endpoint(
@@ -41,16 +91,12 @@ async def chat_endpoint(
41
  # All singletons pre-built in lifespan — zero allocation in hot path.
42
  pipeline = request.app.state.pipeline
43
  conv_store = request.app.state.conversation_store
 
44
  session_id = request_data.session_id
45
 
46
- # Fetch prior turns and detect criticism BEFORE the pipeline runs.
47
- # Both are synchronous SQLite reads (<3ms) so they don't block the event loop
48
- # meaningfully, but we keep them outside sse_generator to avoid any closure issues.
49
  conversation_history = conv_store.get_recent(session_id)
50
  criticism = _is_criticism(request_data.message)
51
  if criticism and conversation_history:
52
- # Auto-record negative feedback on the previous turn so the self-improvement
53
- # loop picks it up during the next reranker fine-tune cycle.
54
  conv_store.mark_last_negative(session_id)
55
 
56
  initial_state: PipelineState = { # type: ignore[assignment]
@@ -71,6 +117,9 @@ async def chat_endpoint(
71
  "latency_ms": 0,
72
  "error": None,
73
  "interaction_id": None,
 
 
 
74
  }
75
 
76
  async def sse_generator():
@@ -81,19 +130,10 @@ async def chat_endpoint(
81
 
82
  try:
83
  async for event in pipeline.astream(initial_state):
84
- # Abort on client disconnect — prevents orphaned instances burning vCPU-seconds.
85
  if await request.is_disconnected():
86
  break
87
 
88
  for node_name, updates in event.items():
89
- # ── Stage transparency ─────────────────────────────────────────
90
- # Emit named stage events so the frontend can show a live
91
- # progress indicator ("checking cache" → "searching" → "writing").
92
- # Mapping: node name → SSE stage label.
93
- #
94
- # cache miss → "checking" (semantic cache lookup ran, no hit)
95
- # gemini_fast → already emits thinking:true if routing to RAG
96
- # retrieve done → "generating" (retrieval complete, LLM starting)
97
  if node_name == "cache" and updates.get("cached") is False:
98
  yield f'data: {json.dumps({"stage": "checking"})}\n\n'
99
  elif node_name == "cache" and updates.get("cached") is True:
@@ -102,11 +142,13 @@ async def chat_endpoint(
102
  if node_name == "retrieve":
103
  yield f'data: {json.dumps({"stage": "generating"})}\n\n'
104
 
105
- # Gemini signalled it needs the knowledge base.
 
 
 
106
  if updates.get("thinking") is True:
107
  yield f'data: {json.dumps({"thinking": True, "stage": "searching"})}\n\n'
108
 
109
- # ── Answer tokens ──────────────────────────────────────────────
110
  if "answer" in updates:
111
  answer_update = updates["answer"]
112
  delta = (
@@ -138,6 +180,16 @@ async def chat_endpoint(
138
 
139
  yield f'data: {json.dumps({"done": True, "sources": sources_list, "cached": is_cached, "latency_ms": elapsed_ms, "interaction_id": interaction_id})}\n\n'
140
 
 
 
 
 
 
 
 
 
 
 
141
  except Exception as exc:
142
  yield f'data: {json.dumps({"error": str(exc) or "Generation failed"})}\n\n'
143
 
@@ -146,3 +198,4 @@ async def chat_endpoint(
146
  media_type="text/event-stream",
147
  headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
148
  )
 
 
1
+ import asyncio
2
  import json
3
  import re
4
  import time
 
29
  return any(sig in lowered for sig in _CRITICISM_SIGNALS)
30
 
31
 
32
+ async def _generate_follow_ups(
33
+ query: str,
34
+ answer: str,
35
+ sources: list,
36
+ llm_client,
37
+ ) -> list[str]:
38
+ """
39
+ Generates 3 specific follow-up questions after the main answer is complete.
40
+ Runs as a concurrent asyncio Task — zero added latency after the done event.
41
+
42
+ Questions must be:
43
+ - Specific to the answer content (never generic like "tell me more")
44
+ - Phrased naturally (< 12 words)
45
+ - Answerable from the knowledge base
46
+ """
47
+ source_titles = [
48
+ (s.title if hasattr(s, "title") else s.get("title", ""))
49
+ for s in sources[:3]
50
+ ]
51
+ titles_str = ", ".join(t for t in source_titles if t) or "the knowledge base"
52
+
53
+ prompt = (
54
+ f"Question asked: {query}\n\n"
55
+ f"Answer given (excerpt): {answer[:400]}\n\n"
56
+ f"Sources referenced: {titles_str}\n\n"
57
+ "Write exactly 3 follow-up questions a recruiter would naturally ask next. "
58
+ "Each question must be specific to the content above — not generic. "
59
+ "Each question must be under 12 words. "
60
+ "Output ONLY the 3 questions, one per line, no numbering or bullet points."
61
+ )
62
+ system = (
63
+ "You write concise follow-up questions for a portfolio chatbot. "
64
+ "Never write generic questions like 'tell me more' or 'what else'. "
65
+ "Each question must be under 12 words and reference specifics from the answer."
66
+ )
67
+
68
+ try:
69
+ stream = llm_client.complete_with_complexity(
70
+ prompt=prompt, system=system, stream=True, complexity="simple"
71
+ )
72
+ raw = ""
73
+ async for token in stream:
74
+ raw += token
75
+ questions = [q.strip() for q in raw.strip().splitlines() if q.strip()][:3]
76
+ return questions
77
+ except Exception:
78
+ return []
79
+
80
+
81
  @router.post("")
82
  @chat_rate_limit()
83
  async def chat_endpoint(
 
91
  # All singletons pre-built in lifespan — zero allocation in hot path.
92
  pipeline = request.app.state.pipeline
93
  conv_store = request.app.state.conversation_store
94
+ llm_client = request.app.state.llm_client
95
  session_id = request_data.session_id
96
 
 
 
 
97
  conversation_history = conv_store.get_recent(session_id)
98
  criticism = _is_criticism(request_data.message)
99
  if criticism and conversation_history:
 
 
100
  conv_store.mark_last_negative(session_id)
101
 
102
  initial_state: PipelineState = { # type: ignore[assignment]
 
117
  "latency_ms": 0,
118
  "error": None,
119
  "interaction_id": None,
120
+ "retrieval_attempts": 0,
121
+ "rewritten_query": None,
122
+ "follow_ups": [],
123
  }
124
 
125
  async def sse_generator():
 
130
 
131
  try:
132
  async for event in pipeline.astream(initial_state):
 
133
  if await request.is_disconnected():
134
  break
135
 
136
  for node_name, updates in event.items():
 
 
 
 
 
 
 
 
137
  if node_name == "cache" and updates.get("cached") is False:
138
  yield f'data: {json.dumps({"stage": "checking"})}\n\n'
139
  elif node_name == "cache" and updates.get("cached") is True:
 
142
  if node_name == "retrieve":
143
  yield f'data: {json.dumps({"stage": "generating"})}\n\n'
144
 
145
+ # CRAG rewrite in progress — inform the frontend the query is being refined.
146
+ if node_name == "rewrite_query":
147
+ yield f'data: {json.dumps({"stage": "refining"})}\n\n'
148
+
149
  if updates.get("thinking") is True:
150
  yield f'data: {json.dumps({"thinking": True, "stage": "searching"})}\n\n'
151
 
 
152
  if "answer" in updates:
153
  answer_update = updates["answer"]
154
  delta = (
 
180
 
181
  yield f'data: {json.dumps({"done": True, "sources": sources_list, "cached": is_cached, "latency_ms": elapsed_ms, "interaction_id": interaction_id})}\n\n'
182
 
183
+ # ── Follow-up questions ────────────────────────────────────────────
184
+ # Generated after the done event so it never delays answer delivery.
185
+ # Works for both cache hits (no sources) and full RAG responses.
186
+ if final_answer and not await request.is_disconnected():
187
+ follow_ups = await _generate_follow_ups(
188
+ request_data.message, final_answer, final_sources, llm_client
189
+ )
190
+ if follow_ups:
191
+ yield f'data: {json.dumps({"follow_ups": follow_ups})}\n\n'
192
+
193
  except Exception as exc:
194
  yield f'data: {json.dumps({"error": str(exc) or "Generation failed"})}\n\n'
195
 
 
198
  media_type="text/event-stream",
199
  headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
200
  )
201
+
app/main.py CHANGED
@@ -76,12 +76,17 @@ async def lifespan(app: FastAPI):
76
  # ingest run doesn't crash every search with "collection not found".
77
  vector_store.ensure_collection()
78
 
 
 
 
 
 
79
  app.state.pipeline = build_pipeline({
80
  "classifier": GuardClassifier(),
81
  "cache": app.state.semantic_cache,
82
  "embedder": embedder,
83
  "gemini": gemini_client,
84
- "llm": get_llm_client(settings),
85
  "vector_store": vector_store,
86
  "reranker": reranker,
87
  "db_path": settings.DB_PATH,
 
76
  # ingest run doesn't crash every search with "collection not found".
77
  vector_store.ensure_collection()
78
 
79
+ llm_client = get_llm_client(settings)
80
+ # Expose llm_client on app state so chat.py can use it for follow-up
81
+ # question generation without re-constructing the client per request.
82
+ app.state.llm_client = llm_client
83
+
84
  app.state.pipeline = build_pipeline({
85
  "classifier": GuardClassifier(),
86
  "cache": app.state.semantic_cache,
87
  "embedder": embedder,
88
  "gemini": gemini_client,
89
+ "llm": llm_client,
90
  "vector_store": vector_store,
91
  "reranker": reranker,
92
  "db_path": settings.DB_PATH,
app/models/pipeline.py CHANGED
@@ -43,3 +43,11 @@ class PipelineState(TypedDict):
43
  latency_ms: int
44
  error: Optional[str]
45
  interaction_id: Optional[int]
 
 
 
 
 
 
 
 
 
43
  latency_ms: int
44
  error: Optional[str]
45
  interaction_id: Optional[int]
46
+ # CRAG: counts retrieve node invocations; 2 = one retry was attempted.
47
+ # Starts at 0 in initial state; retrieve increments it each call.
48
+ retrieval_attempts: int
49
+ # Set by the rewrite_query node when CRAG triggers; None otherwise.
50
+ rewritten_query: Optional[str]
51
+ # Follow-up question suggestions generated after the main answer.
52
+ # 3 short questions specific to content in the answer.
53
+ follow_ups: list[str]
app/pipeline/graph.py CHANGED
@@ -6,9 +6,13 @@ from app.pipeline.nodes.guard import make_guard_node
6
  from app.pipeline.nodes.cache import make_cache_node
7
  from app.pipeline.nodes.gemini_fast import make_gemini_fast_node
8
  from app.pipeline.nodes.retrieve import make_retrieve_node
 
9
  from app.pipeline.nodes.generate import make_generate_node
10
  from app.pipeline.nodes.log_eval import make_log_eval_node
11
 
 
 
 
12
 
13
  def route_guard(state: PipelineState) -> str:
14
  if state.get("guard_passed", False):
@@ -33,19 +37,42 @@ def route_gemini(state: PipelineState) -> str:
33
  return "research"
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def build_pipeline(services: dict) -> CompiledStateGraph:
37
  graph = StateGraph(PipelineState)
38
 
39
- graph.add_node("guard", make_guard_node(services["classifier"]))
40
- # Cache node embeds the query; gemini_fast and retrieve reuse that embedding.
41
- graph.add_node("cache", make_cache_node(services["cache"], services["embedder"]))
42
- graph.add_node("gemini_fast", make_gemini_fast_node(services["gemini"]))
43
- graph.add_node("retrieve", make_retrieve_node(
44
- services["vector_store"],
45
- services["embedder"],
46
- services["reranker"]))
47
- graph.add_node("generate", make_generate_node(services["llm"], services["gemini"]))
48
- graph.add_node("log_eval", make_log_eval_node(services["db_path"]))
 
49
 
50
  graph.set_entry_point("guard")
51
 
@@ -58,9 +85,13 @@ def build_pipeline(services: dict) -> CompiledStateGraph:
58
  graph.add_conditional_edges("gemini_fast", route_gemini,
59
  {"answered": "log_eval", "research": "retrieve"})
60
 
61
- # Always route retrieve generate. generate handles empty chunks with a
62
- # clean "not in knowledge base" response; no need for a separate not_found edge.
63
- graph.add_edge("retrieve", "generate")
 
 
 
 
64
 
65
  graph.add_edge("generate", "log_eval")
66
  graph.add_edge("log_eval", END)
 
6
  from app.pipeline.nodes.cache import make_cache_node
7
  from app.pipeline.nodes.gemini_fast import make_gemini_fast_node
8
  from app.pipeline.nodes.retrieve import make_retrieve_node
9
+ from app.pipeline.nodes.rewrite_query import make_rewrite_query_node, _has_meaningful_token
10
  from app.pipeline.nodes.generate import make_generate_node
11
  from app.pipeline.nodes.log_eval import make_log_eval_node
12
 
13
+ # Relevance gate threshold — matches retrieve.py constant.
14
+ _MIN_TOP_SCORE: float = -3.5
15
+
16
 
17
  def route_guard(state: PipelineState) -> str:
18
  if state.get("guard_passed", False):
 
37
  return "research"
38
 
39
 
40
+ def route_retrieve_result(state: PipelineState) -> str:
41
+ """
42
+ CRAG routing: if the first retrieval returned nothing above threshold,
43
+ rewrite the query once and retry. Exactly one retry is permitted.
44
+
45
+ Conditions for a rewrite attempt:
46
+ 1. retrieval_attempts == 1 (first pass just completed, no retry yet).
47
+ 2. reranked_chunks is empty (nothing above the -3.5 threshold).
48
+ 3. Query has at least one meaningful non-stop-word token (guards against
49
+ empty or fully-generic queries where a rewrite wouldn't help).
50
+ """
51
+ attempts = state.get("retrieval_attempts", 1)
52
+ reranked = state.get("reranked_chunks", [])
53
+ if (
54
+ attempts == 1
55
+ and not reranked
56
+ and _has_meaningful_token(state.get("query", ""))
57
+ ):
58
+ return "rewrite"
59
+ return "generate"
60
+
61
+
62
  def build_pipeline(services: dict) -> CompiledStateGraph:
63
  graph = StateGraph(PipelineState)
64
 
65
+ graph.add_node("guard", make_guard_node(services["classifier"]))
66
+ graph.add_node("cache", make_cache_node(services["cache"], services["embedder"]))
67
+ graph.add_node("gemini_fast", make_gemini_fast_node(services["gemini"]))
68
+ graph.add_node("retrieve", make_retrieve_node(
69
+ services["vector_store"],
70
+ services["embedder"],
71
+ services["reranker"]))
72
+ # CRAG: one query rewrite on failed retrieval — then retrieve runs a second time.
73
+ graph.add_node("rewrite_query", make_rewrite_query_node(services["gemini"]))
74
+ graph.add_node("generate", make_generate_node(services["llm"], services["gemini"]))
75
+ graph.add_node("log_eval", make_log_eval_node(services["db_path"]))
76
 
77
  graph.set_entry_point("guard")
78
 
 
85
  graph.add_conditional_edges("gemini_fast", route_gemini,
86
  {"answered": "log_eval", "research": "retrieve"})
87
 
88
+ # After retrieve: either run CRAG rewrite (one retry) or proceed to generate.
89
+ graph.add_conditional_edges("retrieve", route_retrieve_result,
90
+ {"rewrite": "rewrite_query", "generate": "generate"})
91
+
92
+ # After rewrite: go straight back to retrieve for the second attempt.
93
+ # The cycle terminates because route_retrieve_result checks retrieval_attempts.
94
+ graph.add_edge("rewrite_query", "retrieve")
95
 
96
  graph.add_edge("generate", "log_eval")
97
  graph.add_edge("log_eval", END)
app/pipeline/nodes/cache.py CHANGED
@@ -16,7 +16,10 @@ from app.services.semantic_cache import SemanticCache
16
 
17
  def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
18
  async def cache_node(state: PipelineState) -> dict:
19
- embedding = await embedder.embed_one(state["query"])
 
 
 
20
  query_embedding = np.array(embedding)
21
 
22
  cached = await cache.get(query_embedding)
 
16
 
17
  def make_cache_node(cache: SemanticCache, embedder) -> Callable[[PipelineState], dict]:
18
  async def cache_node(state: PipelineState) -> dict:
19
+ # is_query=True: prepend BGE asymmetric instruction so query embedding
20
+ # lands in the retrieval-optimised neighbourhood of the vector space.
21
+ # Document embeddings at ingestion time use is_query=False (default).
22
+ embedding = await embedder.embed_one(state["query"], is_query=True)
23
  query_embedding = np.array(embedding)
24
 
25
  cached = await cache.get(query_embedding)
app/pipeline/nodes/retrieve.py CHANGED
@@ -1,9 +1,11 @@
 
1
  from typing import Callable
2
 
3
  from app.models.pipeline import PipelineState, Chunk
4
  from app.services.vector_store import VectorStore
5
  from app.services.embedder import Embedder
6
  from app.services.reranker import Reranker
 
7
 
8
  # Cross-encoder ms-marco-MiniLM-L-6-v2 returns raw logits (not sigmoid).
9
  # Highly relevant docs score 0–15; completely off-topic score below –5.
@@ -20,55 +22,109 @@ _MIN_TOP_SCORE: float = -3.5
20
  # relevant sources and making the answer look one-dimensional.
21
  _MAX_CHUNKS_PER_DOC: int = 2
22
 
 
 
 
23
 
24
- def make_retrieve_node(vector_store: VectorStore, embedder: Embedder, reranker: Reranker) -> Callable[[PipelineState], dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  async def retrieve_node(state: PipelineState) -> dict:
 
26
  query = state["query"]
27
- expanded = state.get("expanded_queries", [query])
28
 
29
- # Reuse the embedding computed by cache_node the first element of
30
- # expanded_queries is always the original query. Avoids a duplicate
31
- # HTTP call to the embedder Space (~200-400ms saved per request).
32
  cached_embedding: list[float] | None = state.get("query_embedding")
 
 
 
 
 
 
 
33
 
 
34
  if cached_embedding is not None and len(expanded) == 1:
35
- # Fast path: single query, embedding already computed.
36
  query_vectors = [cached_embedding]
37
  else:
38
- # Multi-query or no cached embedding — embed all at once in one call.
39
- query_vectors = await embedder.embed(expanded)
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- all_chunks: list[Chunk] = []
42
- for vector in query_vectors:
43
- chunks = vector_store.search(query_vector=vector, top_k=10)
44
- all_chunks.extend(chunks)
45
 
46
- # Deduplicate by doc_id + section before reranking.
 
 
47
  seen: set[str] = set()
48
  unique_chunks: list[Chunk] = []
49
- for c in all_chunks:
50
- fingerprint = f"{c['metadata']['doc_id']}::{c['metadata']['section']}"
51
- if fingerprint not in seen:
52
- seen.add(fingerprint)
53
  unique_chunks.append(c)
54
 
55
  reranked = await reranker.rerank(query, unique_chunks, top_k=5)
56
 
57
- # Relevance gate: if the highest-scoring chunk doesn't meet the minimum
58
- # cross-encoder threshold, the knowledge base genuinely has nothing useful
59
- # for this query. Return not-found so generate_node isn't fed garbage context
60
- # that causes vague or hallucinated responses.
61
  top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
62
  if not reranked or (top_score is not None and top_score < _MIN_TOP_SCORE):
63
  return {
64
- "answer": "", # empty — generate_node will produce the "not found" reply
65
  "retrieved_chunks": [],
66
  "reranked_chunks": [],
 
67
  }
68
 
69
- # Source diversity: cap chunks per doc to prevent one verbose document
70
- # from filling all context slots and drowning out other relevant sources.
71
- # Applied after reranking so the reranker sees the full candidate set.
72
  doc_counts: dict[str, int] = {}
73
  diverse_chunks: list[Chunk] = []
74
  for chunk in reranked:
@@ -80,6 +136,7 @@ def make_retrieve_node(vector_store: VectorStore, embedder: Embedder, reranker:
80
  return {
81
  "retrieved_chunks": unique_chunks,
82
  "reranked_chunks": diverse_chunks,
 
83
  }
84
 
85
  return retrieve_node
 
1
+ import asyncio
2
  from typing import Callable
3
 
4
  from app.models.pipeline import PipelineState, Chunk
5
  from app.services.vector_store import VectorStore
6
  from app.services.embedder import Embedder
7
  from app.services.reranker import Reranker
8
+ from app.services.sparse_encoder import SparseEncoder
9
 
10
  # Cross-encoder ms-marco-MiniLM-L-6-v2 returns raw logits (not sigmoid).
11
  # Highly relevant docs score 0–15; completely off-topic score below –5.
 
22
  # relevant sources and making the answer look one-dimensional.
23
  _MAX_CHUNKS_PER_DOC: int = 2
24
 
25
+ # RRF rank fusion constant. k=60 is the original Cormack et al. default.
26
+ # Higher k reduces the influence of top-1 rank advantage.
27
+ _RRF_K: int = 60
28
 
29
+ # Module-level singleton BM25 model downloads once (~5 MB), cached in memory.
30
+ _sparse_encoder = SparseEncoder()
31
+
32
+
33
+ def _rrf_merge(ranked_lists: list[list[Chunk]]) -> list[Chunk]:
34
+ """
35
+ Reciprocal Rank Fusion across multiple ranked chunk lists.
36
+
37
+ Score formula: Σ 1 / (rank + 1 + k) over all lists that contain the chunk.
38
+ Deduplication by doc_id::section fingerprint before merging so the same
39
+ passage retrieved by both dense and sparse does not double-count.
40
+
41
+ Pure Python, no external dependencies.
42
+ """
43
+ scores: dict[str, float] = {}
44
+ chunks_by_fp: dict[str, Chunk] = {}
45
+
46
+ for ranked in ranked_lists:
47
+ seen_in_list: set[str] = set()
48
+ for rank, chunk in enumerate(ranked):
49
+ fp = f"{chunk['metadata']['doc_id']}::{chunk['metadata']['section']}"
50
+ if fp in seen_in_list:
51
+ continue # Already contributed this chunk from this ranked list
52
+ seen_in_list.add(fp)
53
+ scores[fp] = scores.get(fp, 0.0) + 1.0 / (rank + 1 + _RRF_K)
54
+ chunks_by_fp[fp] = chunk
55
+
56
+ sorted_fps = sorted(scores, key=lambda x: scores[x], reverse=True)
57
+ return [chunks_by_fp[fp] for fp in sorted_fps]
58
+
59
+
60
+ def make_retrieve_node(
61
+ vector_store: VectorStore, embedder: Embedder, reranker: Reranker
62
+ ) -> Callable[[PipelineState], dict]:
63
  async def retrieve_node(state: PipelineState) -> dict:
64
+ attempts = state.get("retrieval_attempts", 0)
65
  query = state["query"]
 
66
 
67
+ # On a CRAG retry (attempts >= 1) the query has been rewritten and
68
+ # query_embedding is explicitly set to None always re-embed.
69
+ # On the first attempt, reuse the embedding computed by the cache node.
70
  cached_embedding: list[float] | None = state.get("query_embedding")
71
+ if attempts >= 1:
72
+ # Second attempt: re-embed the rewritten query with is_query=True.
73
+ cached_embedding = None
74
+
75
+ expanded = [query] # gemini_fast may fill expanded_queries on first attempt
76
+ if attempts == 0:
77
+ expanded = state.get("expanded_queries", [query])
78
 
79
+ # Embed all query variants in one batched call (is_query=True for asymmetric BGE).
80
  if cached_embedding is not None and len(expanded) == 1:
 
81
  query_vectors = [cached_embedding]
82
  else:
83
+ query_vectors = await embedder.embed(expanded, is_query=True)
84
+
85
+ # ── Dense search (all query variants) ─────────────────────────────────
86
+ dense_results: list[list[Chunk]] = []
87
+ for vec in query_vectors:
88
+ chunks = vector_store.search(query_vector=vec, top_k=10)
89
+ dense_results.append(chunks)
90
+
91
+ # ── Sparse (BM25) search (primary query only) ─────────────────────────
92
+ # Runs concurrently with dense search isn't possible here since dense
93
+ # is synchronous Qdrant calls, but we parallelise encode + sparse search.
94
+ sparse_results: list[Chunk] = []
95
+ if _sparse_encoder.available:
96
+ indices, values = _sparse_encoder.encode_one(query)
97
+ sparse_results = vector_store.search_sparse(indices, values, top_k=10)
98
 
99
+ # ── Reciprocal Rank Fusion ─────────────────────────────────────────────
100
+ # Merge dense (per variant) + sparse into one ranked list.
101
+ all_ranked_lists = dense_results + ([sparse_results] if sparse_results else [])
102
+ fused: list[Chunk] = _rrf_merge(all_ranked_lists)
103
 
104
+ # ── Deduplication (question-point collapse) ────────────────────────────
105
+ # Multiple points for the same chunk (main + question points from Stage 3)
106
+ # share the same doc_id::section fingerprint and collapse here.
107
  seen: set[str] = set()
108
  unique_chunks: list[Chunk] = []
109
+ for c in fused:
110
+ fp = f"{c['metadata']['doc_id']}::{c['metadata']['section']}"
111
+ if fp not in seen:
112
+ seen.add(fp)
113
  unique_chunks.append(c)
114
 
115
  reranked = await reranker.rerank(query, unique_chunks, top_k=5)
116
 
117
+ # ── Relevance gate ─────────────────────────────────────────────────────
 
 
 
118
  top_score = reranked[0]["metadata"].get("rerank_score", 0.0) if reranked else None
119
  if not reranked or (top_score is not None and top_score < _MIN_TOP_SCORE):
120
  return {
121
+ "answer": "",
122
  "retrieved_chunks": [],
123
  "reranked_chunks": [],
124
+ "retrieval_attempts": attempts + 1,
125
  }
126
 
127
+ # ── Source diversity cap ───────────────────────────────────────────────
 
 
128
  doc_counts: dict[str, int] = {}
129
  diverse_chunks: list[Chunk] = []
130
  for chunk in reranked:
 
136
  return {
137
  "retrieved_chunks": unique_chunks,
138
  "reranked_chunks": diverse_chunks,
139
+ "retrieval_attempts": attempts + 1,
140
  }
141
 
142
  return retrieve_node
app/pipeline/nodes/rewrite_query.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ backend/app/pipeline/nodes/rewrite_query.py
3
+
4
+ CRAG (Corrective RAG) query rewriter — fires exactly once per request when:
5
+ 1. The first retrieval attempt returned no chunks above the relevance threshold.
6
+ 2. The query contains at least one meaningful non-stop-word token.
7
+
8
+ Calls Gemini Flash (temp 0.7) to produce one alternative phrasing that preserves
9
+ the visitor's intent but uses different vocabulary. The pipeline then runs Retrieve
10
+ and Rerank a second time with this new query. There is exactly one retry — the
11
+ graph routing enforces this via the retrieval_attempts counter in state.
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from typing import Any
17
+
18
+ from app.models.pipeline import PipelineState
19
+ from app.services.gemini_client import GeminiClient
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ _REWRITE_PROMPT = """\
24
+ A search query failed to find relevant results in a portfolio knowledge base about Darshan Chheda.
25
+ The knowledge base contains his blog posts, project descriptions, CV/resume, and GitHub README files.
26
+
27
+ Original query: {query}
28
+
29
+ Rephrase this query using different vocabulary that might better match how the content is written.
30
+ Strategies: expand abbreviations, use synonyms, reframe as "did Darshan..." if the query uses a name/tech.
31
+ Output ONLY the rewritten query — one sentence, no explanation, no quotes.
32
+ """
33
+
34
+ # Same stop-word set as generate.py — keeps modules consistent.
35
+ _STOP_WORDS = frozenset({
36
+ "a", "an", "the", "is", "are", "was", "were", "be", "been", "being",
37
+ "have", "has", "had", "do", "does", "did", "will", "would", "could",
38
+ "should", "may", "might", "can", "to", "of", "in", "on", "for",
39
+ "with", "at", "by", "from", "and", "or", "but", "not", "what",
40
+ "who", "how", "why", "when", "where", "tell", "me", "about", "his",
41
+ "he", "him", "any", "some", "that", "this", "it", "its",
42
+ })
43
+
44
+
45
+ def _has_meaningful_token(query: str) -> bool:
46
+ """True when the query has at least one non-stop-word token of length >= 3."""
47
+ return any(
48
+ w not in _STOP_WORDS and len(w) >= 3
49
+ for w in __import__("re").findall(r"[a-z]+", query.lower())
50
+ )
51
+
52
+
53
+ def make_rewrite_query_node(gemini_client: GeminiClient) -> Any:
54
+ async def rewrite_query_node(state: PipelineState) -> dict:
55
+ query = state["query"]
56
+ logger.info("CRAG: rewriting failed query %r", query)
57
+
58
+ if not gemini_client.is_configured:
59
+ # No Gemini — pass query through unchanged; second retrieve will also fail
60
+ # and generate will handle the not-found path gracefully.
61
+ logger.debug("Gemini not configured; skipping query rewrite.")
62
+ return {
63
+ "rewritten_query": query,
64
+ "retrieval_attempts": state.get("retrieval_attempts", 1) + 1,
65
+ "query_embedding": None, # Force re-embed so retrieve doesn't use stale embedding
66
+ }
67
+
68
+ try:
69
+ response = await gemini_client._client.aio.models.generate_content(
70
+ model=gemini_client._model,
71
+ contents=_REWRITE_PROMPT.format(query=query),
72
+ config={"temperature": 0.7},
73
+ )
74
+ rewritten = (response.text or query).strip().strip('"').strip("'")
75
+ except Exception as exc:
76
+ logger.warning("Query rewrite Gemini call failed (%s); using original.", exc)
77
+ rewritten = query
78
+
79
+ if not rewritten or rewritten == query:
80
+ logger.debug("Rewrite produced no change; using original query.")
81
+ rewritten = query
82
+ else:
83
+ logger.info("CRAG rewrite: %r → %r", query, rewritten)
84
+
85
+ # Clearing query_embedding forces the retrieve node to re-embed the new query.
86
+ # retrieval_attempts is incremented so the graph does not loop again after
87
+ # this second retrieval attempt.
88
+ return {
89
+ "query": rewritten,
90
+ "rewritten_query": rewritten,
91
+ "retrieval_attempts": state.get("retrieval_attempts", 1) + 1,
92
+ "query_embedding": None,
93
+ }
94
+
95
+ return rewrite_query_node
app/services/embedder.py CHANGED
@@ -23,27 +23,42 @@ def _get_local_model() -> Any:
23
  return _local_model
24
 
25
 
 
 
 
 
 
26
  class Embedder:
27
  def __init__(self, remote_url: str = "", environment: str = "local") -> None:
28
  self._remote = environment == "prod" and bool(remote_url)
29
  self._url = remote_url.rstrip("/") if self._remote else ""
30
 
31
- async def embed(self, texts: list[str]) -> list[list[float]]:
32
- """Encodes texts, returns List of L2-normalised 384-dim float vectors."""
 
 
 
 
 
 
33
  if not texts:
34
  return []
35
  if self._remote:
36
- # Use a fresh async client per call — HF Spaces does not guarantee
37
- # a stable connection lifecycle, so a persistent client risks stale sockets.
38
  async with httpx.AsyncClient(timeout=30.0) as client:
39
- resp = await client.post(f"{self._url}/embed", json={"texts": texts})
 
 
 
40
  resp.raise_for_status()
41
  return resp.json()["embeddings"]
42
  model = _get_local_model()
 
 
43
  vectors = model.encode(texts, batch_size=32, normalize_embeddings=True, show_progress_bar=False)
44
  return vectors.tolist()
45
 
46
- async def embed_one(self, text: str) -> list[float]:
47
  """Convenience wrapper for a single string."""
48
- results = await self.embed([text])
49
  return results[0]
 
23
  return _local_model
24
 
25
 
26
+ # BGE asymmetric query instruction — prepended locally when is_query=True and
27
+ # environment is local. In prod the HF Space accepts is_query and prepends itself.
28
+ _BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "
29
+
30
+
31
  class Embedder:
32
  def __init__(self, remote_url: str = "", environment: str = "local") -> None:
33
  self._remote = environment == "prod" and bool(remote_url)
34
  self._url = remote_url.rstrip("/") if self._remote else ""
35
 
36
+ async def embed(self, texts: list[str], is_query: bool = False) -> list[list[float]]:
37
+ """
38
+ Encodes texts, returns List of L2-normalised 384-dim float vectors.
39
+
40
+ is_query=True: prepend BGE asymmetric query instruction (queries only).
41
+ is_query=False: encode as-is (document/ingestion embeddings).
42
+ See BGE paper: 2-4% NDCG gain from using the correct prefix on queries.
43
+ """
44
  if not texts:
45
  return []
46
  if self._remote:
47
+ # HF Space handles the prefix server-side when is_query=True.
 
48
  async with httpx.AsyncClient(timeout=30.0) as client:
49
+ resp = await client.post(
50
+ f"{self._url}/embed",
51
+ json={"texts": texts, "is_query": is_query},
52
+ )
53
  resp.raise_for_status()
54
  return resp.json()["embeddings"]
55
  model = _get_local_model()
56
+ if is_query:
57
+ texts = [_BGE_QUERY_PREFIX + t for t in texts]
58
  vectors = model.encode(texts, batch_size=32, normalize_embeddings=True, show_progress_bar=False)
59
  return vectors.tolist()
60
 
61
+ async def embed_one(self, text: str, is_query: bool = False) -> list[float]:
62
  """Convenience wrapper for a single string."""
63
+ results = await self.embed([text], is_query=is_query)
64
  return results[0]
app/services/sparse_encoder.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
backend/app/services/sparse_encoder.py

BM25 sparse encoding via FastEmbed's Qdrant/bm25 model.

Used by ingestion (ingest.py) and by the retrieve node at query time. The
model pulls a ~5 MB vocabulary file on first use and runs fully local
afterwards; it is held as a lazily-loaded module singleton so the API Space
pays no startup cost.

If fastembed is not installed, encode() yields empty sparse vectors and
dense-only retrieval keeps working unchanged.
"""
from __future__ import annotations

import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Lazy singleton state: _fastembed_available is a tri-state flag
# (None = untried, True = loaded, False = unavailable).
_model: Optional[Any] = None
_fastembed_available: Optional[bool] = None


def _get_model() -> Optional[Any]:
    """Load the BM25 model on first call; None when fastembed is unavailable."""
    global _model, _fastembed_available  # noqa: PLW0603
    if _fastembed_available is False:
        return None
    if _model is None:
        try:
            from fastembed import SparseTextEmbedding  # type: ignore[import]

            _model = SparseTextEmbedding(model_name="Qdrant/bm25")
            _fastembed_available = True
            logger.info("FastEmbed BM25 sparse encoder loaded (Qdrant/bm25).")
        except Exception as exc:
            _fastembed_available = False
            logger.warning(
                "FastEmbed not available — sparse retrieval disabled, falling back to dense-only. (%s)",
                exc,
            )
            return None
    return _model


class SparseEncoder:
    """
    BM25 sparse-vector encoder with graceful dense-only fallback.

    encode() maps each text to an (indices, values) tuple. When FastEmbed is
    unavailable — or encoding fails at runtime — it returns ([], []) per text
    so callers can skip sparse indexing without breaking ingestion.
    """

    def encode(self, texts: list[str]) -> list[tuple[list[int], list[float]]]:
        """Encode a batch; returns one (indices, values) tuple per input text."""
        if not texts:
            return []
        model = _get_model()
        if model is None:
            return [([], []) for _ in texts]
        try:
            # fastembed SparseEmbedding exposes .indices and .values as numpy arrays.
            return [
                (emb.indices.tolist(), emb.values.tolist())
                for emb in model.embed(texts)
            ]
        except Exception as exc:
            logger.warning("BM25 encoding failed (%s); returning empty sparse vectors.", exc)
            return [([], []) for _ in texts]

    def encode_one(self, text: str) -> tuple[list[int], list[float]]:
        """Convenience wrapper for a single string."""
        return self.encode([text])[0]

    @property
    def available(self) -> bool:
        """True if FastEmbed loaded successfully and sparse encoding is active."""
        return _get_model() is not None
app/services/vector_store.py CHANGED
@@ -1,113 +1,205 @@
 
1
  import uuid
2
  from typing import Optional
3
 
4
  from qdrant_client import QdrantClient
5
- from qdrant_client.models import PointStruct, VectorParams, Distance, Filter, FieldCondition, MatchValue, PayloadSchemaType
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from app.models.pipeline import Chunk, ChunkMetadata
8
  from app.core.exceptions import RetrievalError
9
 
 
 
 
 
 
 
10
 
11
  class VectorStore:
12
  def __init__(self, client: QdrantClient, collection: str):
13
  self.client = client
14
  self.collection = collection
15
 
16
- def ensure_collection(self) -> None:
17
- """Creates collection with vectors size=384, distance=Cosine if it does not exist.
18
- Also ensures payload index on metadata.doc_id exists for efficient dedup deletes."""
 
 
 
 
 
 
 
 
19
  collections = self.client.get_collections().collections
20
  exists = any(c.name == self.collection for c in collections)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  if not exists:
23
  self.client.create_collection(
24
  collection_name=self.collection,
25
- vectors_config=VectorParams(size=384, distance=Distance.COSINE),
 
 
 
 
 
 
 
 
26
  )
 
27
 
28
- # Keyword index allows filter-by-doc_id in delete_by_doc_id.
29
- # create_payload_index is idempotent — safe to call on every startup.
30
  self.client.create_payload_index(
31
  collection_name=self.collection,
32
  field_name="metadata.doc_id",
33
  field_schema=PayloadSchemaType.KEYWORD,
34
  )
35
 
36
- def upsert_chunks(self, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
37
- """Builds PointStruct list and calls client.upsert. Batch size 100."""
38
- if len(chunks) != len(embeddings):
39
- raise ValueError("Number of chunks must match number of embeddings")
40
-
 
 
 
 
 
 
 
 
 
41
  if not chunks:
42
  return
43
 
44
  points = []
45
- for chunk, vector in zip(chunks, embeddings):
 
 
 
 
 
 
 
 
 
46
  points.append(
47
  PointStruct(
48
  id=str(uuid.uuid4()),
49
  vector=vector,
50
- payload=chunk
51
  )
52
  )
53
 
54
- # Qdrant client upsert takes care of batching if needed, but we can chunk our points list
55
  batch_size = 100
56
  for i in range(0, len(points), batch_size):
57
- batch = points[i:i + batch_size]
58
  self.client.upsert(
59
  collection_name=self.collection,
60
- points=batch
61
  )
62
 
63
  def delete_by_doc_id(self, doc_id: str) -> None:
64
- """Filters on metadata.doc_id and deletes. Called before upsert for incremental updates."""
65
  try:
66
- self.client.delete(
67
  collection_name=self.collection,
68
  points_selector=Filter(
69
  must=[
70
  FieldCondition(
71
  key="metadata.doc_id",
72
- match=MatchValue(value=doc_id)
73
  )
74
  ]
75
- )
76
- )
77
- except Exception as e:
78
- # Qdrant raises if index or something missing, but in setup we might just proceed
79
- pass
80
 
81
- def search(self, query_vector: list[float], top_k: int = 20, filters: Optional[dict] = None) -> list[Chunk]:
82
- """Returns chunks with metadata populated from payload."""
 
 
 
 
 
83
  try:
84
  qdrant_filter = None
85
  if filters:
86
- must_conditions = []
87
- for key, value in filters.items():
88
- must_conditions.append(
89
- FieldCondition(
90
- key=f"metadata.{key}",
91
- match=MatchValue(value=value)
92
- )
93
- )
94
- if must_conditions:
95
- qdrant_filter = Filter(must=must_conditions)
96
 
97
  results = self.client.search(
98
  collection_name=self.collection,
99
- query_vector=query_vector,
100
  limit=top_k,
101
- query_filter=qdrant_filter
 
102
  )
103
 
104
- chunks = []
105
- for hit in results:
106
- if hit.payload:
107
- chunks.append(Chunk(**hit.payload))
108
- return chunks
109
 
110
- except Exception as e:
111
  raise RetrievalError(
112
- f"Vector search failed: {e}", context={"error": str(e)}
113
- ) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
  import uuid
3
  from typing import Optional
4
 
5
  from qdrant_client import QdrantClient
6
+ from qdrant_client.models import (
7
+ Distance,
8
+ FieldCondition,
9
+ Filter,
10
+ MatchValue,
11
+ NamedSparseVector,
12
+ NamedVector,
13
+ PayloadSchemaType,
14
+ PointStruct,
15
+ SparseIndexParams,
16
+ SparseVector,
17
+ SparseVectorParams,
18
+ VectorParams,
19
+ )
20
 
21
  from app.models.pipeline import Chunk, ChunkMetadata
22
  from app.core.exceptions import RetrievalError
23
 
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Named vector keys used in the Qdrant collection.
27
+ _DENSE_VEC = "dense"
28
+ _SPARSE_VEC = "sparse"
29
+
30
 
31
  class VectorStore:
32
  def __init__(self, client: QdrantClient, collection: str):
33
  self.client = client
34
  self.collection = collection
35
 
36
+ def ensure_collection(self, allow_recreate: bool = False) -> None:
37
+ """
38
+ Creates or migrates the collection to support named dense + sparse vectors.
39
+
40
+ allow_recreate=True (ingestion): if the collection exists with the old
41
+ unnamed-vector format, delete and recreate it. Ingestion will re-index
42
+ everything on the same run, so data loss is acceptable.
43
+
44
+ allow_recreate=False (API startup): never touch an existing collection.
45
+ The API will use whatever format is already deployed.
46
+ """
47
  collections = self.client.get_collections().collections
48
  exists = any(c.name == self.collection for c in collections)
49
 
50
+ if exists and allow_recreate:
51
+ try:
52
+ info = self.client.get_collection(self.collection)
53
+ is_old_format = not isinstance(info.config.params.vectors, dict)
54
+ has_no_sparse = not info.config.params.sparse_vectors
55
+ if is_old_format or has_no_sparse:
56
+ logger.info(
57
+ "Collection %r uses old vector format; recreating for hybrid search.",
58
+ self.collection,
59
+ )
60
+ self.client.delete_collection(self.collection)
61
+ exists = False
62
+ except Exception as exc:
63
+ logger.warning("Could not inspect collection format (%s); skipping migration.", exc)
64
+
65
  if not exists:
66
  self.client.create_collection(
67
  collection_name=self.collection,
68
+ vectors_config={
69
+ _DENSE_VEC: VectorParams(size=384, distance=Distance.COSINE),
70
+ },
71
+ sparse_vectors_config={
72
+ # on_disk=False keeps sparse index in RAM for sub-ms lookup.
73
+ _SPARSE_VEC: SparseVectorParams(
74
+ index=SparseIndexParams(on_disk=False)
75
+ ),
76
+ },
77
  )
78
+ logger.info("Created collection %r with dense + sparse vectors.", self.collection)
79
 
80
+ # Keyword index for filter-by-doc_id in delete_by_doc_id. Idempotent.
 
81
  self.client.create_payload_index(
82
  collection_name=self.collection,
83
  field_name="metadata.doc_id",
84
  field_schema=PayloadSchemaType.KEYWORD,
85
  )
86
 
87
+ def upsert_chunks(
88
+ self,
89
+ chunks: list[Chunk],
90
+ dense_embeddings: list[list[float]],
91
+ sparse_embeddings: Optional[list[tuple[list[int], list[float]]]] = None,
92
+ ) -> None:
93
+ """
94
+ Builds PointStruct list with named dense (and optionally sparse) vectors.
95
+
96
+ sparse_embeddings: list of (indices, values) tuples from SparseEncoder.
97
+ If None or empty sparse vector for a chunk, dense-only point is used.
98
+ """
99
+ if len(chunks) != len(dense_embeddings):
100
+ raise ValueError("Number of chunks must match number of dense embeddings")
101
  if not chunks:
102
  return
103
 
104
  points = []
105
+ for i, (chunk, dense_vec) in enumerate(zip(chunks, dense_embeddings)):
106
+ vector: dict = {_DENSE_VEC: dense_vec}
107
+
108
+ if sparse_embeddings is not None:
109
+ indices, values = sparse_embeddings[i]
110
+ if indices: # Skip empty sparse vectors gracefully
111
+ vector[_SPARSE_VEC] = SparseVector(
112
+ indices=indices, values=values
113
+ )
114
+
115
  points.append(
116
  PointStruct(
117
  id=str(uuid.uuid4()),
118
  vector=vector,
119
+ payload=chunk,
120
  )
121
  )
122
 
 
123
  batch_size = 100
124
  for i in range(0, len(points), batch_size):
 
125
  self.client.upsert(
126
  collection_name=self.collection,
127
+ points=points[i : i + batch_size],
128
  )
129
 
130
  def delete_by_doc_id(self, doc_id: str) -> None:
131
+ """Filters on metadata.doc_id and deletes all matching points."""
132
  try:
133
+ self.client.delete(
134
  collection_name=self.collection,
135
  points_selector=Filter(
136
  must=[
137
  FieldCondition(
138
  key="metadata.doc_id",
139
+ match=MatchValue(value=doc_id),
140
  )
141
  ]
142
+ ),
143
+ )
144
+ except Exception:
145
+ pass # Safe to ignore collection or index may not exist yet
 
146
 
147
+ def search(
148
+ self,
149
+ query_vector: list[float],
150
+ top_k: int = 20,
151
+ filters: Optional[dict] = None,
152
+ ) -> list[Chunk]:
153
+ """Dense vector search using the named 'dense' vector."""
154
  try:
155
  qdrant_filter = None
156
  if filters:
157
+ must_conditions = [
158
+ FieldCondition(key=f"metadata.{k}", match=MatchValue(value=v))
159
+ for k, v in filters.items()
160
+ ]
161
+ qdrant_filter = Filter(must=must_conditions)
 
 
 
 
 
162
 
163
  results = self.client.search(
164
  collection_name=self.collection,
165
+ query_vector=NamedVector(name=_DENSE_VEC, vector=query_vector),
166
  limit=top_k,
167
+ query_filter=qdrant_filter,
168
+ with_payload=True,
169
  )
170
 
171
+ return [Chunk(**hit.payload) for hit in results if hit.payload]
 
 
 
 
172
 
173
+ except Exception as exc:
174
  raise RetrievalError(
175
+ f"Dense vector search failed: {exc}", context={"error": str(exc)}
176
+ ) from exc
177
+
178
+ def search_sparse(
179
+ self,
180
+ indices: list[int],
181
+ values: list[float],
182
+ top_k: int = 20,
183
+ ) -> list[Chunk]:
184
+ """
185
+ BM25 sparse vector search using the named 'sparse' vector.
186
+ Returns empty list if sparse vectors are absent or indices is empty.
187
+ """
188
+ if not indices:
189
+ return []
190
+ try:
191
+ results = self.client.search(
192
+ collection_name=self.collection,
193
+ query_vector=NamedSparseVector(
194
+ name=_SPARSE_VEC,
195
+ vector=SparseVector(indices=indices, values=values),
196
+ ),
197
+ limit=top_k,
198
+ with_payload=True,
199
+ )
200
+ return [Chunk(**hit.payload) for hit in results if hit.payload]
201
+
202
+ except Exception as exc:
203
+ # Sparse index may not exist on old collections — log and continue.
204
+ logger.warning("Sparse search failed (%s); skipping sparse results.", exc)
205
+ return []
requirements.txt CHANGED
@@ -20,4 +20,7 @@ presidio-analyzer>=2.2.354
20
  tenacity>=8.3.0
21
  python-jose[cryptography]>=3.3.0
22
  google-genai>=1.0.0
 
 
 
23
  toon_format @ git+https://github.com/toon-format/toon-python.git
 
20
  tenacity>=8.3.0
21
  python-jose[cryptography]>=3.3.0
22
  google-genai>=1.0.0
23
+ # fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
24
+ # downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
25
+ fastembed>=0.3.6
26
  toon_format @ git+https://github.com/toon-format/toon-python.git