sofhiaazzhr committed on
Commit
5f86993
·
1 Parent(s): b4df8b1

[NOTICKET][doc] add sheet-level leg and RRF voting for tabular retrieval

Browse files
Files changed (1) hide show
  1. src/rag/retrievers/schema.py +83 -29
src/rag/retrievers/schema.py CHANGED
@@ -214,32 +214,82 @@ class SchemaRetriever(BaseRetriever):
214
  for row in rows
215
  ]
216
 
217
- @staticmethod
218
- def _chunk_key(r: RetrievalResult) -> tuple:
219
- """Stable identity for dedup/RRF.
220
-
221
- Includes filename, sheet_name, and chunk_level so that column chunks
222
- and sheet chunks for the same file/sheet don't collide, and column
223
- chunks with the same name across different files (e.g. `id` in two CSVs)
224
- are kept distinct.
 
 
 
 
 
 
 
 
 
 
225
  """
226
- d = r.metadata.get("data", {})
227
- return (
228
- d.get("table_name"),
229
- d.get("column_name"),
230
- d.get("filename"),
231
- d.get("sheet_name"),
232
- r.metadata.get("chunk_level"),
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
- def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
236
- """Deduplicate by chunk identity, keeping highest score per unique key."""
237
- seen: dict[tuple, RetrievalResult] = {}
238
- for r in results:
239
- key = self._chunk_key(r)
240
- if key not in seen or r.score > seen[key].score:
241
- seen[key] = r
242
- return sorted(seen.values(), key=lambda r: r.score, reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
  def _rank_db_tables(
245
  self,
@@ -325,8 +375,11 @@ class SchemaRetriever(BaseRetriever):
325
  ranked table set via _fetch_full_schema — the column chunks returned
326
  here are intentionally NOT used as the schema source, only for voting.
327
 
328
- Tabular (CSV/XLSX) chunks remain at column/sheet level since they have
329
- no table-level chunks.
 
 
 
330
  """
331
  embedding = await self._embed_query(query)
332
  db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
@@ -338,17 +391,18 @@ class SchemaRetriever(BaseRetriever):
338
  )
339
 
340
  db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
341
- tabular_final = self._dedup(tabular_results + sheet_results)[:k]
342
 
343
- results = db_ranked + tabular_final
344
  logger.info(
345
  "schema retrieval",
346
  count=len(results),
347
  db_tables_ranked=len(db_ranked),
348
  db_cols=len(db_col_results),
349
  db_tables=len(db_tbl_results),
350
- tabular=len(tabular_results),
351
  tabular_sheets=len(sheet_results),
 
352
  fts=len(fts_results),
353
  )
354
  return results
 
214
  for row in rows
215
  ]
216
 
217
+ def _rank_tabular_sheets(
218
+ self,
219
+ sheet_results: list[RetrievalResult],
220
+ column_results: list[RetrievalResult],
221
+ top_k: int,
222
+ k_rrf: int = 60,
223
+ ) -> list[RetrievalResult]:
224
+ """Rank tabular sheets by RRF across two voting legs:
225
+ L1 (primary): sheet-chunk cosine score
226
+ L2 (vote): best column-chunk position per (doc_id, sheet_name)
227
+
228
+ Returns top-k sheet-level RetrievalResults. The full column list of
229
+ each sheet is already in the sheet chunk's data.column_names from
230
+ ingestion, so downstream tabular_executor can read full sheet context.
231
+
232
+ For sheets surfaced by column votes but missing a sheet chunk (rare —
233
+ ingestion always creates one), a minimal stub is returned and
234
+ tabular_executor falls back to reading columns from the parquet.
235
  """
236
+ # L1: sheets indexed by (doc_id, sheet_name) from sheet chunks
237
+ sheet_index: dict[tuple, RetrievalResult] = {}
238
+ sheet_ranked: list[tuple] = []
239
+ for r in sheet_results:
240
+ d = r.metadata.get("data", {})
241
+ key = (d.get("document_id"), d.get("sheet_name"))
242
+ if key[0] and key not in sheet_index:
243
+ sheet_index[key] = r
244
+ sheet_ranked.append(key)
245
+
246
+ # L2: sheets ranked by first-appearance in column-chunk results
247
+ col_sheet_ranked: list[tuple] = []
248
+ seen: set[tuple] = set()
249
+ for r in column_results:
250
+ d = r.metadata.get("data", {})
251
+ key = (d.get("document_id"), d.get("sheet_name"))
252
+ if key[0] and key not in seen:
253
+ col_sheet_ranked.append(key)
254
+ seen.add(key)
255
+
256
+ # RRF over (doc_id, sheet_name) across the two legs
257
+ rrf_scores: dict[tuple, float] = {}
258
+ for ranked_list in [sheet_ranked, col_sheet_ranked]:
259
+ for rank, key in enumerate(ranked_list):
260
+ rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
261
+
262
+ top_sheets = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)[:top_k]
263
 
264
+ results: list[RetrievalResult] = []
265
+ for key in top_sheets:
266
+ if key in sheet_index:
267
+ r = sheet_index[key]
268
+ r.score = rrf_scores[key]
269
+ results.append(r)
270
+ else:
271
+ # Surfaced by column votes only — build stub from a representative
272
+ # column result so tabular_executor can group correctly.
273
+ doc_id, sheet_name = key
274
+ rep = next(
275
+ (r for r in column_results
276
+ if r.metadata.get("data", {}).get("document_id") == doc_id
277
+ and r.metadata.get("data", {}).get("sheet_name") == sheet_name),
278
+ None,
279
+ )
280
+ if rep is None:
281
+ continue
282
+ stub_data = dict(rep.metadata.get("data", {}))
283
+ stub_data.pop("column_name", None)
284
+ stub_data.pop("column_type", None)
285
+ results.append(RetrievalResult(
286
+ content=f"Sheet: {stub_data.get('filename', '')}"
287
+ + (f" / sheet: {sheet_name}" if sheet_name else ""),
288
+ metadata={**rep.metadata, "data": stub_data, "chunk_level": "sheet"},
289
+ score=rrf_scores[key],
290
+ source_type="document",
291
+ ))
292
+ return results
293
 
294
  def _rank_db_tables(
295
  self,
 
375
  ranked table set via _fetch_full_schema — the column chunks returned
376
  here are intentionally NOT used as the schema source, only for voting.
377
 
378
+ Tabular (CSV/XLSX) sheets are ranked via RRF across two legs:
379
+ L1: sheet-chunk cosine
380
+ L2: column-chunk votes (best position per sheet)
381
+ Returns sheet-level RetrievalResults so tabular_executor receives
382
+ full sheet context (all columns) rather than fragmented column hits.
383
  """
384
  embedding = await self._embed_query(query)
385
  db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
 
391
  )
392
 
393
  db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
394
+ tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k)
395
 
396
+ results = db_ranked + tabular_ranked
397
  logger.info(
398
  "schema retrieval",
399
  count=len(results),
400
  db_tables_ranked=len(db_ranked),
401
  db_cols=len(db_col_results),
402
  db_tables=len(db_tbl_results),
403
+ tabular_cols=len(tabular_results),
404
  tabular_sheets=len(sheet_results),
405
+ tabular_ranked=len(tabular_ranked),
406
  fts=len(fts_results),
407
  )
408
  return results