Rifqi Hafizuddin commited on
Commit
40925b4
·
1 Parent(s): be9bbd9

[KM-507] now only uses hybrid (cosine and bm25)

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +25 -189
src/rag/retrievers/schema.py CHANGED
@@ -1,23 +1,14 @@
1
  """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
- Multiple retrieval strategies are exposed for benchmarking. The active strategy
5
- used by the router is `retrieve()`, which dispatches to ACTIVE_STRATEGY.
6
- Change ACTIVE_STRATEGY at module level to switch without touching the router.
7
 
8
- All strategies embed the query exactly once, then fan out to parallel SQL legs.
9
-
10
- Vector distance strategies:
11
- dense_no_threshold — cosine (<=>), no score floor, always returns k chunks
12
- dense_dot — inner product (<#>), equivalent to cosine for normalized embeddings
13
- dense_l2 — L2/euclidean (<->), monotonic with cosine on unit-sphere vectors
14
- hybrid — RRF merge of dense + FTS (database + tabular)
15
- hybrid_bm25 — RRF merge of dense + FTS (database only)
16
  """
17
 
18
  import asyncio
19
- import time
20
- from typing import Literal
21
 
22
  from sqlalchemy import text
23
 
@@ -30,9 +21,6 @@ logger = get_logger("schema_retriever")
30
 
31
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
32
 
33
- Strategy = Literal["dense_no_threshold", "dense_dot", "dense_l2", "hybrid", "hybrid_bm25"]
34
- ACTIVE_STRATEGY: Strategy = "hybrid_bm25"
35
-
36
 
37
  class SchemaRetriever(BaseRetriever):
38
  def __init__(self):
@@ -46,26 +34,20 @@ class SchemaRetriever(BaseRetriever):
46
  return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)
47
 
48
  async def _search_db(
49
- self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
50
  ) -> list[RetrievalResult]:
51
- """Vector search over database chunks. Accepts a pre-computed embedding."""
52
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
53
 
54
- if operator == "<#>":
55
- score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
56
- elif operator == "<->":
57
- score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
58
- else:
59
- score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
60
-
61
  sql = text(f"""
62
- SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
63
  FROM langchain_pg_embedding lpe
64
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
65
  WHERE lpc.name = 'document_embeddings'
66
  AND lpe.cmetadata->>'user_id' = :user_id
67
  AND lpe.cmetadata->>'source_type' = 'database'
68
- ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
69
  LIMIT :k
70
  """)
71
 
@@ -84,20 +66,14 @@ class SchemaRetriever(BaseRetriever):
84
  ]
85
 
86
  async def _search_tabular(
87
- self, embedding: list[float], user_id: str, k: int, operator: str = "<=>"
88
  ) -> list[RetrievalResult]:
89
- """Vector search over tabular document chunks. Accepts a pre-computed embedding."""
90
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
91
 
92
- if operator == "<#>":
93
- score_sql = f"(lpe.embedding <#> '{emb_str}'::vector) * -1"
94
- elif operator == "<->":
95
- score_sql = f"1.0 / (1.0 + (lpe.embedding <-> '{emb_str}'::vector))"
96
- else:
97
- score_sql = f"1.0 - (lpe.embedding <=> '{emb_str}'::vector)"
98
-
99
  sql = text(f"""
100
- SELECT lpe.document, lpe.cmetadata, {score_sql} AS score
 
101
  FROM langchain_pg_embedding lpe
102
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
103
  WHERE lpc.name = 'document_embeddings'
@@ -105,7 +81,7 @@ class SchemaRetriever(BaseRetriever):
105
  AND lpe.cmetadata->>'source_type' = 'document'
106
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
107
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
108
- ORDER BY lpe.embedding {operator} '{emb_str}'::vector ASC
109
  LIMIT :k
110
  """)
111
 
@@ -113,55 +89,18 @@ class SchemaRetriever(BaseRetriever):
113
  result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
114
  rows = result.fetchall()
115
 
116
- results = []
117
- for row in rows:
118
- results.append(
119
- RetrievalResult(
120
- content=row.document,
121
- metadata=row.cmetadata,
122
- score=float(row.score),
123
- source_type="document",
124
- )
125
- )
126
- if len(results) >= k:
127
- break
128
- return results
129
-
130
- async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
131
- """Full-text search over DB schema chunks using PostgreSQL tsvector.
132
-
133
- Requires GIN index on langchain_pg_embedding.document (created by init_db.py).
134
- """
135
- sql = text("""
136
- SELECT lpe.document, lpe.cmetadata,
137
- ts_rank(to_tsvector('english', lpe.document),
138
- plainto_tsquery('english', :query)) AS rank
139
- FROM langchain_pg_embedding lpe
140
- JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
141
- WHERE lpc.name = 'document_embeddings'
142
- AND lpe.cmetadata->>'user_id' = :user_id
143
- AND lpe.cmetadata->>'source_type' = 'database'
144
- AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
145
- ORDER BY rank DESC
146
- LIMIT :k
147
- """)
148
-
149
- async with _pgvector_engine.connect() as conn:
150
- result = await conn.execute(sql, {"query": query, "user_id": user_id, "k": k})
151
- rows = result.fetchall()
152
-
153
  return [
154
  RetrievalResult(
155
  content=row.document,
156
  metadata=row.cmetadata,
157
- score=float(row.rank),
158
- source_type="database",
159
  )
160
  for row in rows
161
  ]
162
 
163
- async def _search_fts_tabular(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
164
- """Full-text search over tabular document chunks using PostgreSQL tsvector."""
165
  sql = text("""
166
  SELECT lpe.document, lpe.cmetadata,
167
  ts_rank(to_tsvector('english', lpe.document),
@@ -170,9 +109,7 @@ class SchemaRetriever(BaseRetriever):
170
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
171
  WHERE lpc.name = 'document_embeddings'
172
  AND lpe.cmetadata->>'user_id' = :user_id
173
- AND lpe.cmetadata->>'source_type' = 'document'
174
- AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
175
- OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
176
  AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
177
  ORDER BY rank DESC
178
  LIMIT :k
@@ -187,7 +124,7 @@ class SchemaRetriever(BaseRetriever):
187
  content=row.document,
188
  metadata=row.cmetadata,
189
  score=float(row.rank),
190
- source_type="document",
191
  )
192
  for row in rows
193
  ]
@@ -228,66 +165,11 @@ class SchemaRetriever(BaseRetriever):
228
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
229
 
230
  # ------------------------------------------------------------------
231
- # Named strategiesone embed call each, legs run in parallel
232
  # ------------------------------------------------------------------
233
 
234
- async def dense_no_threshold(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
235
- """Cosine similarity, no score cutoff always returns k chunks."""
236
- embedding = await self._embed_query(query)
237
- db_results, tabular_results = await asyncio.gather(
238
- self._search_db(embedding, user_id, k),
239
- self._search_tabular(embedding, user_id, k),
240
- )
241
- return self._dedup(db_results + tabular_results)[:k]
242
-
243
- async def dense_dot(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
244
- """Inner product similarity (<#>).
245
-
246
- For L2-normalized embeddings (OpenAI), ranking is identical to cosine.
247
- Score = raw inner product (not bounded to [0,1]).
248
- """
249
- embedding = await self._embed_query(query)
250
- db_results, tabular_results = await asyncio.gather(
251
- self._search_db(embedding, user_id, k, "<#>"),
252
- self._search_tabular(embedding, user_id, k, "<#>"),
253
- )
254
- return self._dedup(db_results + tabular_results)[:k]
255
-
256
- async def dense_l2(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
257
- """L2 (Euclidean) distance similarity (<->).
258
-
259
- For L2-normalized embeddings (OpenAI), ranking order matches cosine.
260
- Score = 1 / (1 + l2_distance), bounded to (0, 1].
261
- """
262
- embedding = await self._embed_query(query)
263
- db_results, tabular_results = await asyncio.gather(
264
- self._search_db(embedding, user_id, k, "<->"),
265
- self._search_tabular(embedding, user_id, k, "<->"),
266
- )
267
- return self._dedup(db_results + tabular_results)[:k]
268
-
269
- async def hybrid(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
270
- """RRF merge of dense + FTS over both database and tabular sources.
271
-
272
- Embeds once, then runs all four legs (dense db, dense tabular, fts db,
273
- fts tabular) in a single asyncio.gather.
274
- """
275
- embedding = await self._embed_query(query)
276
- db_results, tabular_results, fts_db, fts_tabular = await asyncio.gather(
277
- self._search_db(embedding, user_id, k),
278
- self._search_tabular(embedding, user_id, k),
279
- self._search_fts_db(query, user_id, k * 4),
280
- self._search_fts_tabular(query, user_id, k * 4),
281
- )
282
- dense = self._dedup(db_results + tabular_results)[:k]
283
- fts_all = self._dedup(fts_db + fts_tabular)
284
- return self._rrf_merge(dense, fts_all, top_k=k)
285
-
286
- async def hybrid_bm25(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
287
- """RRF merge of dense + FTS (database chunks only).
288
-
289
- Embeds once, then runs dense db, dense tabular, and fts db legs in parallel.
290
- """
291
  embedding = await self._embed_query(query)
292
  db_results, tabular_results, fts_results = await asyncio.gather(
293
  self._search_db(embedding, user_id, k),
@@ -295,55 +177,9 @@ class SchemaRetriever(BaseRetriever):
295
  self._search_fts_db(query, user_id, k * 4),
296
  )
297
  dense = self._dedup(db_results + tabular_results)[:k]
298
- return self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
299
-
300
- # ------------------------------------------------------------------
301
- # Public interface — called by the router
302
- # ------------------------------------------------------------------
303
-
304
- async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
305
- strategy_fn = getattr(self, ACTIVE_STRATEGY)
306
- results = await strategy_fn(query, user_id, k)
307
- logger.info("schema retrieval", strategy=ACTIVE_STRATEGY, count=len(results))
308
  return results
309
 
310
 
311
- # ------------------------------------------------------------------
312
- # Benchmark helper — import in test scripts
313
- # ------------------------------------------------------------------
314
-
315
- async def benchmark(
316
- query: str,
317
- user_id: str,
318
- k: int = 5,
319
- strategies: list[Strategy] | None = None,
320
- ) -> dict[str, dict]:
321
- """Run multiple strategies against the same query and return timing + results."""
322
- retriever = SchemaRetriever()
323
- targets: list[Strategy] = strategies or [
324
- "dense_no_threshold",
325
- "dense_dot",
326
- "dense_l2",
327
- "hybrid",
328
- "hybrid_bm25",
329
- ]
330
- report: dict[str, dict] = {}
331
-
332
- for name in targets:
333
- fn = getattr(retriever, name)
334
- t0 = time.perf_counter()
335
- chunks = await fn(query, user_id, k)
336
- elapsed_ms = round((time.perf_counter() - t0) * 1000)
337
-
338
- total_chars = sum(len(r.content) for r in chunks)
339
- report[name] = {
340
- "chunks": len(chunks),
341
- "estimated_tokens": total_chars // 4,
342
- "elapsed_ms": elapsed_ms,
343
- "results": chunks,
344
- }
345
-
346
- return report
347
-
348
-
349
  schema_retriever = SchemaRetriever()
 
1
  """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
+ Strategy: hybrid_bm25 RRF merge of dense cosine search (DB + tabular) and
5
+ PostgreSQL full-text search (DB only). Embeds the query once, fans out the
6
+ three legs in parallel.
7
 
8
+ FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py).
 
 
 
 
 
 
 
9
  """
10
 
11
  import asyncio
 
 
12
 
13
  from sqlalchemy import text
14
 
 
21
 
22
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
23
 
 
 
 
24
 
25
  class SchemaRetriever(BaseRetriever):
26
  def __init__(self):
 
34
  return await asyncio.to_thread(self.vector_store.embeddings.embed_query, query)
35
 
36
  async def _search_db(
37
+ self, embedding: list[float], user_id: str, k: int
38
  ) -> list[RetrievalResult]:
39
+ """Cosine vector search over database chunks."""
40
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
41
 
 
 
 
 
 
 
 
42
  sql = text(f"""
43
+ SELECT lpe.document, lpe.cmetadata,
44
+ 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
45
  FROM langchain_pg_embedding lpe
46
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
47
  WHERE lpc.name = 'document_embeddings'
48
  AND lpe.cmetadata->>'user_id' = :user_id
49
  AND lpe.cmetadata->>'source_type' = 'database'
50
+ ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
51
  LIMIT :k
52
  """)
53
 
 
66
  ]
67
 
68
  async def _search_tabular(
69
+ self, embedding: list[float], user_id: str, k: int
70
  ) -> list[RetrievalResult]:
71
+ """Cosine vector search over tabular document chunks (csv/xlsx)."""
72
  emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
73
 
 
 
 
 
 
 
 
74
  sql = text(f"""
75
+ SELECT lpe.document, lpe.cmetadata,
76
+ 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
77
  FROM langchain_pg_embedding lpe
78
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
79
  WHERE lpc.name = 'document_embeddings'
 
81
  AND lpe.cmetadata->>'source_type' = 'document'
82
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
83
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
84
+ ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
85
  LIMIT :k
86
  """)
87
 
 
89
  result = await conn.execute(sql, {"user_id": user_id, "k": k * 4})
90
  rows = result.fetchall()
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  return [
93
  RetrievalResult(
94
  content=row.document,
95
  metadata=row.cmetadata,
96
+ score=float(row.score),
97
+ source_type="document",
98
  )
99
  for row in rows
100
  ]
101
 
102
+ async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
103
+ """Full-text search over DB schema chunks using PostgreSQL tsvector."""
104
  sql = text("""
105
  SELECT lpe.document, lpe.cmetadata,
106
  ts_rank(to_tsvector('english', lpe.document),
 
109
  JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
110
  WHERE lpc.name = 'document_embeddings'
111
  AND lpe.cmetadata->>'user_id' = :user_id
112
+ AND lpe.cmetadata->>'source_type' = 'database'
 
 
113
  AND to_tsvector('english', lpe.document) @@ plainto_tsquery('english', :query)
114
  ORDER BY rank DESC
115
  LIMIT :k
 
124
  content=row.document,
125
  metadata=row.cmetadata,
126
  score=float(row.rank),
127
+ source_type="database",
128
  )
129
  for row in rows
130
  ]
 
165
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
166
 
167
  # ------------------------------------------------------------------
168
+ # Public interfacecalled by the router
169
  # ------------------------------------------------------------------
170
 
171
+ async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
172
+ """RRF merge of dense (DB + tabular) and FTS (DB only)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  embedding = await self._embed_query(query)
174
  db_results, tabular_results, fts_results = await asyncio.gather(
175
  self._search_db(embedding, user_id, k),
 
177
  self._search_fts_db(query, user_id, k * 4),
178
  )
179
  dense = self._dedup(db_results + tabular_results)[:k]
180
+ results = self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
181
+ logger.info("schema retrieval", count=len(results))
 
 
 
 
 
 
 
 
182
  return results
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  schema_retriever = SchemaRetriever()