Paramjit Singh committed on
Commit
4defd96
·
unverified ·
2 Parent(s): 793ad4fcd64a8b

Merge pull request #251 from Exodus2004/feat/issue-114-hybrid-search

Browse files
backend/app/rag/bm25.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BM25 Keyword Search implementation using rank_bm25.
3
+ Stores a BM25 index per document to allow easy updates and deletions.
4
+ """
5
+ import os
6
+ import glob
7
+ import pickle
8
+ import logging
9
+ from typing import List, Dict, Any, Optional
10
+
11
+ from rank_bm25 import BM25Okapi
12
+ from app.config import get_settings
13
+
14
+ logger = logging.getLogger(__name__)
15
+ settings = get_settings()
16
+
17
def get_bm25_dir(user_id: str) -> str:
    """Return the directory holding a user's BM25 indexes, creating it if absent.

    Hyphens in ``user_id`` are replaced with underscores and the result is
    used as a folder name under ``<CHROMA_PERSIST_DIR>/bm25``.
    """
    folder_name = "_".join(user_id.split("-"))
    index_dir = os.path.join(settings.CHROMA_PERSIST_DIR, "bm25", folder_name)
    # exist_ok=True keeps repeated lookups for the same user from raising.
    os.makedirs(index_dir, exist_ok=True)
    return index_dir
23
+
24
def get_bm25_path(user_id: str, document_id: str) -> str:
    """Return the pickle file path that stores one document's BM25 index.

    The file is named ``<document_id>.pkl`` and lives in the directory
    returned by ``get_bm25_dir(user_id)`` (which creates that directory).
    """
    user_dir = get_bm25_dir(user_id)
    index_filename = f"{document_id}.pkl"
    return os.path.join(user_dir, index_filename)
27
+
28
def tokenize(text: str) -> List[str]:
    """Split text into lowercase word tokens for BM25 indexing and querying.

    Tokens are maximal runs of word characters (letters, digits, underscore),
    so surrounding punctuation is dropped. Plain whitespace splitting would
    keep "taxes", "taxes." and "taxes?" as three unrelated terms, which makes
    exact keyword matching miss most natural-language queries.

    Args:
        text: Raw chunk or query text.

    Returns:
        The lowercase tokens in order of appearance; an empty list when the
        text contains no word characters.
    """
    # Function-scope import: `re` is not among this module's top-level imports.
    import re

    return re.findall(r"\w+", text.lower())
32
+
33
def store_bm25_index(chunks: List[Dict[str, Any]], document_id: str, filename: str, user_id: str):
    """Build a BM25 index for one document's chunks and persist it as a pickle.

    The pickle holds a dict with two keys: ``"bm25"`` (the ``BM25Okapi``
    object built from the tokenized chunk texts) and ``"chunks"`` (one dict
    per chunk with ``text``, ``filename``, ``document_id`` and ``page``).

    Args:
        chunks: Chunk dicts; each must have a ``"text"`` key and may have a
            ``"page"`` key (defaults to 1 when absent).
        document_id: Identifier used for the index file name and stored on
            every chunk.
        filename: Source file name stored on every chunk.
        user_id: Owner of the index; selects the target directory.

    Returns:
        None. Nothing is written when ``chunks`` is empty. I/O failures are
        logged and swallowed, so indexing stays best-effort for the caller.
    """
    if not chunks:
        return

    bm25 = BM25Okapi([tokenize(chunk["text"]) for chunk in chunks])

    # Format chunks to match vectorstore output
    formatted_chunks = [
        {
            "text": chunk["text"],
            "filename": filename,
            "document_id": document_id,
            "page": chunk.get("page", 1),
        }
        for chunk in chunks
    ]
    data = {"bm25": bm25, "chunks": formatted_chunks}

    path = get_bm25_path(user_id, document_id)
    # Write to a sibling temp file and swap it in, so an interrupted dump can
    # never leave a truncated pickle at the final path. The ".tmp" suffix also
    # keeps the partial file out of "*.pkl" globs.
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(data, f)
        os.replace(tmp_path, path)
        logger.info(f"Stored BM25 index for document {document_id}")
    except Exception as e:
        logger.error(f"Failed to store BM25 index for {document_id}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: the temp file may never have been created.
            pass
66
+
67
+ def _query_single_index(path: str, tokenized_query: List[str], top_k: int) -> List[Dict[str, Any]]:
68
+ """Query a single BM25 index file."""
69
+ if not os.path.exists(path):
70
+ return []
71
+
72
+ try:
73
+ with open(path, "rb") as f:
74
+ data = pickle.load(f)
75
+ except Exception as e:
76
+ logger.error(f"Failed to load BM25 index from {path}: {e}")
77
+ return []
78
+
79
+ bm25 = data["bm25"]
80
+ chunks = data["chunks"]
81
+
82
+ scores = bm25.get_scores(tokenized_query)
83
+
84
+ # Get top_k indices sorted by score
85
+ top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
86
+
87
+ results = []
88
+ for i in top_indices:
89
+ if scores[i] > 0:
90
+ chunk = chunks[i].copy()
91
+ # Normalize BM25 score to 0-1 range roughly, or just keep raw.
92
+ # BM25 scores are usually > 0, often 1-10.
93
+ # We keep the raw score for now, RRF will handle the ranking.
94
+ chunk["score"] = float(scores[i])
95
+ results.append(chunk)
96
+
97
+ return results
98
+
99
def query_bm25(
    query: str,
    user_id: str,
    document_id: Optional[str] = None,
    top_k: int = 10,
) -> List[Dict[str, Any]]:
    """Query a user's BM25 index(es) for the chunks most relevant to ``query``.

    Args:
        query: Free-text search query.
        user_id: Owner whose indexes are searched.
        document_id: When given, only that document's index is searched;
            otherwise every ``*.pkl`` index in the user's directory is.
        top_k: Maximum number of chunks to return.

    Returns:
        Chunk dicts with a ``"score"`` key, highest score first, at most
        ``top_k`` long. Empty when the query has no tokens or nothing matches.
    """
    # A blank query cannot match anything; return before touching the
    # filesystem (directory creation, unpickling every index).
    if not query or not query.strip():
        return []

    tokenized_query = tokenize(query)
    if not tokenized_query:
        return []

    if document_id:
        path = get_bm25_path(user_id, document_id)
        return _query_single_index(path, tokenized_query, top_k)

    # If no document_id, query all documents for this user
    user_dir = get_bm25_dir(user_id)
    all_results = []
    for path in glob.glob(os.path.join(user_dir, "*.pkl")):
        # Each index contributes up to top_k hits; the global cut is below.
        all_results.extend(_query_single_index(path, tokenized_query, top_k))

    # Sort all results by score and take top_k
    all_results.sort(key=lambda x: x["score"], reverse=True)
    return all_results[:top_k]
125
+
126
def delete_bm25_index(document_id: str, user_id: str):
    """Remove the BM25 index file of a single document, if it exists.

    A failed removal is logged as a warning rather than raised.
    """
    index_file = get_bm25_path(user_id, document_id)
    if not os.path.exists(index_file):
        return

    try:
        os.remove(index_file)
        logger.info(f"Deleted BM25 index for document {document_id}")
    except Exception as exc:
        logger.warning(f"Error deleting BM25 index: {exc}")
135
+
136
def delete_user_bm25_indexes(user_id: str):
    """Remove every ``*.pkl`` BM25 index of a user, then the directory itself.

    Any failure is logged as a warning rather than raised. ``os.rmdir`` only
    removes empty directories, so files other than ``*.pkl`` left in the
    folder end up on that warning path.
    """
    index_dir = get_bm25_dir(user_id)
    if not os.path.exists(index_dir):
        return

    try:
        index_files = glob.glob(os.path.join(index_dir, "*.pkl"))
        for index_file in index_files:
            os.remove(index_file)
        os.rmdir(index_dir)
        logger.info(f"Deleted BM25 directory for user {user_id}")
    except Exception as exc:
        logger.warning(f"Error deleting BM25 directory for user {user_id}: {exc}")
backend/app/rag/retriever.py CHANGED
@@ -1,10 +1,18 @@
1
  """
2
- Two-stage retrieval: ChromaDB similarity search + cross-encoder reranking.
3
  """
4
  import json
5
  import logging
6
  import re
7
  from typing import List, Dict, Any, Optional
 
 
 
 
 
 
 
 
8
  from app.config import get_settings
9
  from app.rag.embeddings import embed_query
10
  from app.rag.tracing import trace_function
@@ -35,6 +43,42 @@ def get_reranker():
35
  return _reranker if _reranker != "disabled" else None
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def transform_query(query: str) -> List[str]:
39
  """Rewrite a user question into multiple retrieval-friendly search queries."""
40
  original_query = query.strip()
@@ -183,28 +227,43 @@ def retrieve(
183
  ) -> List[Dict[str, Any]]:
184
  """
185
  Two-stage retrieval pipeline:
186
- 1. ChromaDB similarity search (top-K broad)
187
  2. Cross-encoder reranking (top-K refined)
188
 
189
  Returns chunks with confidence scores.
190
  """
191
- # ── Stage 1: Query transformation + embedding search ─────────────
192
- candidates = []
193
- for search_query in transform_query(query):
194
- query_vector = embed_query(search_query)
195
- candidates.extend(
196
- query_chunks(
197
- query_embedding=query_vector,
198
- user_id=user_id,
199
- document_id=document_id,
200
- top_k=settings.TOP_K_RETRIEVAL,
201
- )
202
- )
203
 
204
- if not candidates:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  return []
206
 
207
- candidates = _merge_candidates(candidates)
208
 
209
  # ── Stage 2: Cross-encoder reranking ─────────────
210
  reranker = get_reranker()
@@ -223,8 +282,9 @@ def retrieve(
223
  candidates.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
224
 
225
  except Exception as e:
226
- logger.warning(f"Reranking failed, using embedding scores: {e}")
227
 
 
228
  candidates.sort(key=lambda x: x.get("rerank_score", x.get("score", 0)), reverse=True)
229
 
230
  # ── Take top-K after reranking ───────────────────
 
1
  """
2
+ Two-stage retrieval: Hybrid Ensemble (ChromaDB + BM25) + cross-encoder reranking.
3
  """
4
  import json
5
  import logging
6
  import re
7
  from typing import List, Dict, Any, Optional
8
+
9
+ # In LangChain 1.3.2+, EnsembleRetriever moved to langchain_classic (imported by langchain_community)
10
+ from langchain_classic.retrievers import EnsembleRetriever
11
+ from langchain_core.retrievers import BaseRetriever
12
+ from langchain_core.documents import Document as LangchainDocument
13
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
14
+ from pydantic import Field
15
+
16
  from app.config import get_settings
17
  from app.rag.embeddings import embed_query
18
  from app.rag.tracing import trace_function
 
43
  return _reranker if _reranker != "disabled" else None
44
 
45
 
46
class CustomVectorRetriever(BaseRetriever):
    """LangChain retriever that wraps the embedding search done by ``query_chunks``."""

    # Search scope and result-size settings, passed straight to query_chunks.
    user_id: str = Field(description="User ID")
    document_id: Optional[str] = Field(default=None, description="Document ID")
    top_k: int = Field(default=10, description="Top K results")

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[LangchainDocument]:
        """Embed ``query``, fetch matching chunks, and wrap each as a document.

        Each chunk dict becomes the document's metadata and its ``"text"``
        value becomes the page content.
        """
        hits = query_chunks(
            query_embedding=embed_query(query),
            user_id=self.user_id,
            document_id=self.document_id,
            top_k=self.top_k,
        )
        documents = []
        for hit in hits:
            documents.append(LangchainDocument(page_content=hit["text"], metadata=hit))
        return documents
62
+
63
+
64
class CustomBM25Retriever(BaseRetriever):
    """LangChain retriever that wraps the keyword search done by ``query_bm25``."""

    # Search scope and result-size settings, passed straight to query_bm25.
    user_id: str = Field(description="User ID")
    document_id: Optional[str] = Field(default=None, description="Document ID")
    top_k: int = Field(default=10, description="Top K results")

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[LangchainDocument]:
        """Run a BM25 query and wrap each returned chunk as a document.

        Each chunk dict becomes the document's metadata and its ``"text"``
        value becomes the page content.
        """
        # Imported at call time, as in the original code.
        from app.rag.bm25 import query_bm25

        hits = query_bm25(
            query=query,
            user_id=self.user_id,
            document_id=self.document_id,
            top_k=self.top_k,
        )
        documents = []
        for hit in hits:
            documents.append(LangchainDocument(page_content=hit["text"], metadata=hit))
        return documents
80
+
81
+
82
  def transform_query(query: str) -> List[str]:
83
  """Rewrite a user question into multiple retrieval-friendly search queries."""
84
  original_query = query.strip()
 
227
  ) -> List[Dict[str, Any]]:
228
  """
229
  Two-stage retrieval pipeline:
230
+ 1. Hybrid Search (Vector + BM25 via EnsembleRetriever with RRF) with Query Transformation
231
  2. Cross-encoder reranking (top-K refined)
232
 
233
  Returns chunks with confidence scores.
234
  """
235
+ # ── Stage 1: Hybrid Search with Query Transformation ─────────────
236
+ vector_retriever = CustomVectorRetriever(
237
+ user_id=user_id,
238
+ document_id=document_id,
239
+ top_k=settings.TOP_K_RETRIEVAL,
240
+ )
241
+
242
+ bm25_retriever = CustomBM25Retriever(
243
+ user_id=user_id,
244
+ document_id=document_id,
245
+ top_k=settings.TOP_K_RETRIEVAL,
246
+ )
247
 
248
+ ensemble_retriever = EnsembleRetriever(
249
+ retrievers=[vector_retriever, bm25_retriever],
250
+ weights=[0.6, 0.4]
251
+ )
252
+
253
+ all_candidates = []
254
+ for search_query in transform_query(query):
255
+ docs = ensemble_retriever.invoke(search_query)
256
+ for i, doc in enumerate(docs):
257
+ chunk = doc.metadata.copy()
258
+ # Preserve a mock score based on rank for fallback if reranker fails
259
+ # We use 1.0/(i+1) as a base RRF-like score
260
+ chunk["score"] = 1.0 / (i + 1)
261
+ all_candidates.append(chunk)
262
+
263
+ if not all_candidates:
264
  return []
265
 
266
+ candidates = _merge_candidates(all_candidates)
267
 
268
  # ── Stage 2: Cross-encoder reranking ─────────────
269
  reranker = get_reranker()
 
282
  candidates.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
283
 
284
  except Exception as e:
285
+ logger.warning(f"Reranking failed, using hybrid scores: {e}")
286
 
287
+ # Ensure candidates are sorted by best available score
288
  candidates.sort(key=lambda x: x.get("rerank_score", x.get("score", 0)), reverse=True)
289
 
290
  # ── Take top-K after reranking ───────────────────
backend/app/rag/vectorstore.py CHANGED
@@ -49,12 +49,19 @@ def store_chunks(
49
  user_id: str,
50
  ) -> int:
51
  """
52
- Embed and store document chunks in ChromaDB.
53
  Returns the number of chunks stored.
54
  """
55
  if not chunks:
56
  return 0
57
 
 
 
 
 
 
 
 
58
  # Generate captions for any extracted images before embedding
59
  try:
60
  from app.rag.vision import generate_captions_for_chunks
@@ -178,6 +185,12 @@ def delete_document_chunks(document_id: str, user_id: str):
178
  client = get_chroma_client()
179
  collection_name = get_collection_name(user_id)
180
 
 
 
 
 
 
 
181
  try:
182
  collection = client.get_collection(name=collection_name)
183
  # Get all IDs for this document
@@ -197,6 +210,12 @@ def delete_user_collection(user_id: str):
197
  client = get_chroma_client()
198
  collection_name = get_collection_name(user_id)
199
 
 
 
 
 
 
 
200
  try:
201
  client.delete_collection(name=collection_name)
202
  logger.info(f"Deleted collection {collection_name}")
 
49
  user_id: str,
50
  ) -> int:
51
  """
52
+ Embed and store document chunks in ChromaDB, and build a local BM25 index.
53
  Returns the number of chunks stored.
54
  """
55
  if not chunks:
56
  return 0
57
 
58
+ # Build and store BM25 index
59
+ from app.rag.bm25 import store_bm25_index
60
+ try:
61
+ store_bm25_index(chunks, document_id, filename, user_id)
62
+ except Exception as e:
63
+ logger.error(f"Could not build BM25 index: {e}")
64
+
65
  # Generate captions for any extracted images before embedding
66
  try:
67
  from app.rag.vision import generate_captions_for_chunks
 
185
  client = get_chroma_client()
186
  collection_name = get_collection_name(user_id)
187
 
188
+ try:
189
+ from app.rag.bm25 import delete_bm25_index
190
+ delete_bm25_index(document_id, user_id)
191
+ except Exception as e:
192
+ logger.warning(f"Error deleting BM25 index: {e}")
193
+
194
  try:
195
  collection = client.get_collection(name=collection_name)
196
  # Get all IDs for this document
 
210
  client = get_chroma_client()
211
  collection_name = get_collection_name(user_id)
212
 
213
+ try:
214
+ from app.rag.bm25 import delete_user_bm25_indexes
215
+ delete_user_bm25_indexes(user_id)
216
+ except Exception as e:
217
+ logger.warning(f"Error deleting user BM25 indexes: {e}")
218
+
219
  try:
220
  client.delete_collection(name=collection_name)
221
  logger.info(f"Deleted collection {collection_name}")
backend/requirements.txt CHANGED
@@ -30,10 +30,12 @@ python-docx
30
 
31
  # LangChain & RAG
32
  langchain
 
33
  langchain-community
34
  langchain-huggingface
35
  langchain-text-splitters
36
  langsmith
 
37
 
38
  # Embeddings & ML
39
  sentence-transformers
 
30
 
31
  # LangChain & RAG
32
  langchain
33
+ langchain-classic
34
  langchain-community
35
  langchain-huggingface
36
  langchain-text-splitters
37
  langsmith
38
+ rank-bm25
39
 
40
  # Embeddings & ML
41
  sentence-transformers
backend/tests/test_documents.py CHANGED
@@ -95,6 +95,7 @@ def test_ingest_document_builds_and_saves_graph(db_session, monkeypatch, tmp_pat
95
 
96
  def test_delete_document_removes_knowledge_graph(client, auth_headers, ready_document, monkeypatch):
97
  deleted = {}
 
98
 
99
  monkeypatch.setattr("app.routes.documents.delete_document_chunks", lambda **kwargs: None)
100
  monkeypatch.setattr(
@@ -105,9 +106,9 @@ def test_delete_document_removes_knowledge_graph(client, auth_headers, ready_doc
105
  )
106
 
107
  response = client.delete(
108
- f"/api/v1/documents/{ready_document.id}",
109
  headers=auth_headers,
110
  )
111
 
112
  assert response.status_code == 200
113
- assert deleted["document_id"] == ready_document.id
 
95
 
96
  def test_delete_document_removes_knowledge_graph(client, auth_headers, ready_document, monkeypatch):
97
  deleted = {}
98
+ doc_id = ready_document.id
99
 
100
  monkeypatch.setattr("app.routes.documents.delete_document_chunks", lambda **kwargs: None)
101
  monkeypatch.setattr(
 
106
  )
107
 
108
  response = client.delete(
109
+ f"/api/v1/documents/{doc_id}",
110
  headers=auth_headers,
111
  )
112
 
113
  assert response.status_code == 200
114
+ assert deleted["document_id"] == doc_id
backend/tests/test_graphrag_agent.py CHANGED
@@ -34,7 +34,7 @@ def test_generate_answer_appends_graph_context_without_changing_sources(monkeypa
34
  }
35
  ]
36
 
37
- monkeypatch.setattr(agent, "get_llm_client", lambda: client)
38
  monkeypatch.setattr(agent, "retrieve", lambda **kwargs: chunks)
39
  monkeypatch.setattr(
40
  agent,
@@ -66,7 +66,7 @@ def test_generate_answer_stream_appends_graph_context(monkeypatch):
66
  captured["messages"] = messages
67
  return iter([])
68
 
69
- monkeypatch.setattr(agent, "get_llm_client", lambda: StreamingClient())
70
  monkeypatch.setattr(
71
  agent,
72
  "retrieve",
 
34
  }
35
  ]
36
 
37
+ monkeypatch.setattr(agent, "get_llm_client", lambda hf_token=None: client)
38
  monkeypatch.setattr(agent, "retrieve", lambda **kwargs: chunks)
39
  monkeypatch.setattr(
40
  agent,
 
66
  captured["messages"] = messages
67
  return iter([])
68
 
69
+ monkeypatch.setattr(agent, "get_llm_client", lambda hf_token=None: StreamingClient())
70
  monkeypatch.setattr(
71
  agent,
72
  "retrieve",
backend/tests/test_retriever.py CHANGED
@@ -72,6 +72,6 @@ def test_retrieve_fans_out_transformed_queries_and_merges_duplicates(monkeypatch
72
  chunks = retriever.retrieve("How do taxes and healthcare work?", user_id="user-1")
73
 
74
  assert searched_queries == ["embedding:taxes", "embedding:healthcare"]
75
- assert [chunk["id"] for chunk in chunks] == ["shared", "healthcare", "taxes"]
76
- assert chunks[0]["score"] == 0.9
77
  assert chunks[0]["confidence"] == 100.0
 
72
  chunks = retriever.retrieve("How do taxes and healthcare work?", user_id="user-1")
73
 
74
  assert searched_queries == ["embedding:taxes", "embedding:healthcare"]
75
+ assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"]
76
+ assert chunks[0]["score"] == 1.0
77
  assert chunks[0]["confidence"] == 100.0