Rifqi Hafizuddin commited on
Commit
4150ba7
·
1 Parent(s): fc1239a

[KM-533] now also retrieves table level chunk

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +62 -7
src/rag/retrievers/schema.py CHANGED
@@ -1,9 +1,15 @@
1
  """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
- Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB + tabular) and
5
- PostgreSQL full-text search (DB only). Embeds the query once, fans out the
6
- three legs in parallel.
 
 
 
 
 
 
7
 
8
  FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py).
9
  """
@@ -20,6 +26,7 @@ from src.rag.base import BaseRetriever, RetrievalResult
20
  logger = get_logger("schema_retriever")
21
 
22
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
 
23
 
24
 
25
  class SchemaRetriever(BaseRetriever):
@@ -66,6 +73,46 @@ class SchemaRetriever(BaseRetriever):
66
  for row in rows
67
  ]
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  async def _search_tabular(
70
  self, embedding: list[float], user_id: str, k: int
71
  ) -> list[RetrievalResult]:
@@ -171,16 +218,24 @@ class SchemaRetriever(BaseRetriever):
171
  # ------------------------------------------------------------------
172
 
173
  async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
174
- """RRF merge of dense (DB + tabular) and FTS (DB only)."""
175
  embedding = await self._embed_query(query)
176
- db_results, tabular_results, fts_results = await asyncio.gather(
177
  self._search_db(embedding, user_id, k),
 
178
  self._search_tabular(embedding, user_id, k),
179
  self._search_fts_db(query, user_id, k * 4),
180
  )
181
- dense = self._dedup(db_results + tabular_results)[:k]
182
  results = self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
183
- logger.info("schema retrieval", count=len(results))
 
 
 
 
 
 
 
184
  return results
185
 
186
 
 
1
  """Schema retriever — handles DB schemas (source_type="database") and tabular file
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
+ Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
5
+ + tabular) and PostgreSQL full-text search (DB columns only). Embeds the query
6
+ once, fans out four legs in parallel.
7
+
8
+ The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
9
+ a recall signal for multi-table questions: when a relevant table's columns
10
+ don't individually win on similarity, the table chunk can still pull the table
11
+ into the hit set, where db_executor's downstream full-schema fetch picks up
12
+ the per-column detail.
13
 
14
  FTS requires a GIN index on langchain_pg_embedding.document (created by init_db.py).
15
  """
 
26
  logger = get_logger("schema_retriever")
27
 
28
  _TABULAR_FILE_TYPES = ("csv", "xlsx")
29
+ _TABLE_CHUNK_K_MULTIPLIER = 2 # how many table chunks to pull before RRF
30
 
31
 
32
  class SchemaRetriever(BaseRetriever):
 
73
  for row in rows
74
  ]
75
 
76
+ async def _search_db_tables(
77
+ self, embedding: list[float], user_id: str, k: int
78
+ ) -> list[RetrievalResult]:
79
+ """Cosine vector search over database TABLE-level chunks.
80
+
81
+ Recall channel for multi-table questions. The chunk's content is
82
+ discarded downstream — db_executor only consumes its `data.table_name`
83
+ to seed full-schema fetch.
84
+ """
85
+ emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
86
+
87
+ sql = text(f"""
88
+ SELECT lpe.document, lpe.cmetadata,
89
+ 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
90
+ FROM langchain_pg_embedding lpe
91
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
92
+ WHERE lpc.name = 'document_embeddings'
93
+ AND lpe.cmetadata->>'user_id' = :user_id
94
+ AND lpe.cmetadata->>'source_type' = 'database'
95
+ AND lpe.cmetadata->>'chunk_level' = 'table'
96
+ ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
97
+ LIMIT :k
98
+ """)
99
+
100
+ async with _pgvector_engine.connect() as conn:
101
+ result = await conn.execute(
102
+ sql, {"user_id": user_id, "k": k * _TABLE_CHUNK_K_MULTIPLIER}
103
+ )
104
+ rows = result.fetchall()
105
+
106
+ return [
107
+ RetrievalResult(
108
+ content=row.document,
109
+ metadata=row.cmetadata,
110
+ score=float(row.score),
111
+ source_type="database",
112
+ )
113
+ for row in rows
114
+ ]
115
+
116
  async def _search_tabular(
117
  self, embedding: list[float], user_id: str, k: int
118
  ) -> list[RetrievalResult]:
 
218
  # ------------------------------------------------------------------
219
 
220
  async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
221
+ """RRF merge of dense (DB columns + DB tables + tabular) and FTS (DB cols only)."""
222
  embedding = await self._embed_query(query)
223
+ db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
224
  self._search_db(embedding, user_id, k),
225
+ self._search_db_tables(embedding, user_id, k),
226
  self._search_tabular(embedding, user_id, k),
227
  self._search_fts_db(query, user_id, k * 4),
228
  )
229
+ dense = self._dedup(db_col_results + db_tbl_results + tabular_results)[:k]
230
  results = self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
231
+ logger.info(
232
+ "schema retrieval",
233
+ count=len(results),
234
+ db_cols=len(db_col_results),
235
+ db_tables=len(db_tbl_results),
236
+ tabular=len(tabular_results),
237
+ fts=len(fts_results),
238
+ )
239
  return results
240
 
241