sofhiaazzhr committed on
Commit
a205d0c
·
1 Parent(s): 23eeb2d

[NOTICKET][db] add sheet-level retrieval and focus LLM schema context to retrieved columns

Browse files
src/query/executors/tabular.py CHANGED
@@ -252,6 +252,10 @@ class TabularExecutor(BaseExecutor):
252
  ) -> QueryResult | None:
253
  try:
254
  df = await download_parquet(user_id, doc_id, sheet_name)
 
 
 
 
255
  df_result = await self._query_with_agent(df, question, limit)
256
 
257
  table_label = info["filename"]
 
252
  ) -> QueryResult | None:
253
  try:
254
  df = await download_parquet(user_id, doc_id, sheet_name)
255
+ if info["columns"]:
256
+ valid_cols = [c for c in info["columns"] if c in df.columns]
257
+ if valid_cols:
258
+ df = df[valid_cols]
259
  df_result = await self._query_with_agent(df, question, limit)
260
 
261
  table_label = info["filename"]
src/rag/retrievers/schema.py CHANGED
@@ -2,8 +2,8 @@
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
  Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
5
- + tabular) and PostgreSQL full-text search (DB columns only). Embeds the query
6
- once, fans out four legs in parallel.
7
 
8
  The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
9
  a recall signal for multi-table questions: when a relevant table's columns
@@ -127,6 +127,7 @@ class SchemaRetriever(BaseRetriever):
127
  WHERE lpc.name = 'document_embeddings'
128
  AND lpe.cmetadata->>'user_id' = :user_id
129
  AND lpe.cmetadata->>'source_type' = 'document'
 
130
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
131
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
132
  ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
@@ -147,6 +148,41 @@ class SchemaRetriever(BaseRetriever):
147
  for row in rows
148
  ]
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
151
  """Full-text search over DB schema chunks using PostgreSQL tsvector."""
152
  sql = text("""
@@ -182,9 +218,10 @@ class SchemaRetriever(BaseRetriever):
182
  def _chunk_key(r: RetrievalResult) -> tuple:
183
  """Stable identity for dedup/RRF.
184
 
185
- Includes filename and sheet_name so that tabular column chunks with
186
- the same column name across different files (e.g. `id` in two CSVs)
187
- and future sheet-level chunks across XLSX sheets don't collide.
 
188
  """
189
  d = r.metadata.get("data", {})
190
  return (
@@ -192,6 +229,7 @@ class SchemaRetriever(BaseRetriever):
192
  d.get("column_name"),
193
  d.get("filename"),
194
  d.get("sheet_name"),
 
195
  )
196
 
197
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
@@ -291,15 +329,16 @@ class SchemaRetriever(BaseRetriever):
291
  no table-level chunks.
292
  """
293
  embedding = await self._embed_query(query)
294
- db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
295
  self._search_db(embedding, user_id, k),
296
  self._search_db_tables(embedding, user_id, k),
297
  self._search_tabular(embedding, user_id, k),
298
  self._search_fts_db(query, user_id, k * 4),
 
299
  )
300
 
301
  db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
302
- tabular_final = self._dedup(tabular_results)[:k]
303
 
304
  results = db_ranked + tabular_final
305
  logger.info(
@@ -309,6 +348,7 @@ class SchemaRetriever(BaseRetriever):
309
  db_cols=len(db_col_results),
310
  db_tables=len(db_tbl_results),
311
  tabular=len(tabular_results),
 
312
  fts=len(fts_results),
313
  )
314
  return results
 
2
  columns stored as source_type="document" with file_type in ("csv","xlsx").
3
 
4
  Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
5
+ + tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only).
6
+ Embeds the query once, fans out five legs in parallel.
7
 
8
  The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
9
  a recall signal for multi-table questions: when a relevant table's columns
 
127
  WHERE lpc.name = 'document_embeddings'
128
  AND lpe.cmetadata->>'user_id' = :user_id
129
  AND lpe.cmetadata->>'source_type' = 'document'
130
+ AND lpe.cmetadata->>'chunk_level' = 'column'
131
  AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
132
  OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
133
  ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
 
148
  for row in rows
149
  ]
150
 
151
+ async def _search_tabular_sheets(
152
+ self, embedding: list[float], user_id: str, k: int
153
+ ) -> list[RetrievalResult]:
154
+ """Leg 5: sheet-level summary chunks from CSV/XLSX files."""
155
+ emb_str = "[" + ",".join(str(x) for x in embedding) + "]"
156
+
157
+ sql = text(f"""
158
+ SELECT lpe.document, lpe.cmetadata,
159
+ 1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
160
+ FROM langchain_pg_embedding lpe
161
+ JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
162
+ WHERE lpc.name = 'document_embeddings'
163
+ AND lpe.cmetadata->>'user_id' = :user_id
164
+ AND lpe.cmetadata->>'source_type' = 'document'
165
+ AND lpe.cmetadata->>'chunk_level' = 'sheet'
166
+ AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
167
+ OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
168
+ ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
169
+ LIMIT :k
170
+ """)
171
+
172
+ async with _pgvector_engine.connect() as conn:
173
+ result = await conn.execute(sql, {"user_id": user_id, "k": k})
174
+ rows = result.fetchall()
175
+
176
+ return [
177
+ RetrievalResult(
178
+ content=row.document,
179
+ metadata=row.cmetadata,
180
+ score=float(row.score),
181
+ source_type="document",
182
+ )
183
+ for row in rows
184
+ ]
185
+
186
  async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
187
  """Full-text search over DB schema chunks using PostgreSQL tsvector."""
188
  sql = text("""
 
218
  def _chunk_key(r: RetrievalResult) -> tuple:
219
  """Stable identity for dedup/RRF.
220
 
221
+ Includes filename, sheet_name, and chunk_level so that column chunks
222
+ and sheet chunks for the same file/sheet don't collide, and column
223
+ chunks with the same name across different files (e.g. `id` in two CSVs)
224
+ are kept distinct.
225
  """
226
  d = r.metadata.get("data", {})
227
  return (
 
229
  d.get("column_name"),
230
  d.get("filename"),
231
  d.get("sheet_name"),
232
+ r.metadata.get("chunk_level"),
233
  )
234
 
235
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
 
329
  no table-level chunks.
330
  """
331
  embedding = await self._embed_query(query)
332
+ db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
333
  self._search_db(embedding, user_id, k),
334
  self._search_db_tables(embedding, user_id, k),
335
  self._search_tabular(embedding, user_id, k),
336
  self._search_fts_db(query, user_id, k * 4),
337
+ self._search_tabular_sheets(embedding, user_id, k),
338
  )
339
 
340
  db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
341
+ tabular_final = self._dedup(tabular_results + sheet_results)[:k]
342
 
343
  results = db_ranked + tabular_final
344
  logger.info(
 
348
  db_cols=len(db_col_results),
349
  db_tables=len(db_tbl_results),
350
  tabular=len(tabular_results),
351
+ tabular_sheets=len(sheet_results),
352
  fts=len(fts_results),
353
  )
354
  return results
src/rag/router.py CHANGED
@@ -25,7 +25,7 @@ SourceHint = Literal["document", "schema", "both"]
25
 
26
  def _result_dedup_key(r: RetrievalResult) -> tuple:
27
  """Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
28
- tabular columns vs prose chunks vs sheet-level (future)."""
29
  data = r.metadata.get("data", {})
30
  return (
31
  r.source_type,
@@ -34,6 +34,7 @@ def _result_dedup_key(r: RetrievalResult) -> tuple:
34
  data.get("filename"),
35
  data.get("sheet_name"),
36
  data.get("chunk_index"), # disambiguates multiple prose chunks per doc
 
37
  )
38
 
39
 
 
25
 
26
  def _result_dedup_key(r: RetrievalResult) -> tuple:
27
  """Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
28
+ tabular columns vs prose chunks vs sheet-level chunks."""
29
  data = r.metadata.get("data", {})
30
  return (
31
  r.source_type,
 
34
  data.get("filename"),
35
  data.get("sheet_name"),
36
  data.get("chunk_index"), # disambiguates multiple prose chunks per doc
37
+ r.metadata.get("chunk_level"), # distinguishes sheet vs column chunks
38
  )
39
 
40