Commit ·
5f86993
1
Parent(s): b4df8b1
[NOTICKET][doc] add sheet-level leg and RRF voting for tabular retrieval
Browse files- src/rag/retrievers/schema.py +83 -29
src/rag/retrievers/schema.py
CHANGED
|
@@ -214,32 +214,82 @@ class SchemaRetriever(BaseRetriever):
|
|
| 214 |
for row in rows
|
| 215 |
]
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
"""
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
d.get("
|
| 231 |
-
d.get("
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
def _rank_db_tables(
|
| 245 |
self,
|
|
@@ -325,8 +375,11 @@ class SchemaRetriever(BaseRetriever):
|
|
| 325 |
ranked table set via _fetch_full_schema — the column chunks returned
|
| 326 |
here are intentionally NOT used as the schema source, only for voting.
|
| 327 |
|
| 328 |
-
Tabular (CSV/XLSX)
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
| 330 |
"""
|
| 331 |
embedding = await self._embed_query(query)
|
| 332 |
db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
|
|
@@ -338,17 +391,18 @@ class SchemaRetriever(BaseRetriever):
|
|
| 338 |
)
|
| 339 |
|
| 340 |
db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
|
| 341 |
-
|
| 342 |
|
| 343 |
-
results = db_ranked +
|
| 344 |
logger.info(
|
| 345 |
"schema retrieval",
|
| 346 |
count=len(results),
|
| 347 |
db_tables_ranked=len(db_ranked),
|
| 348 |
db_cols=len(db_col_results),
|
| 349 |
db_tables=len(db_tbl_results),
|
| 350 |
-
|
| 351 |
tabular_sheets=len(sheet_results),
|
|
|
|
| 352 |
fts=len(fts_results),
|
| 353 |
)
|
| 354 |
return results
|
|
|
|
| 214 |
for row in rows
|
| 215 |
]
|
| 216 |
|
| 217 |
+
def _rank_tabular_sheets(
|
| 218 |
+
self,
|
| 219 |
+
sheet_results: list[RetrievalResult],
|
| 220 |
+
column_results: list[RetrievalResult],
|
| 221 |
+
top_k: int,
|
| 222 |
+
k_rrf: int = 60,
|
| 223 |
+
) -> list[RetrievalResult]:
|
| 224 |
+
"""Rank tabular sheets by RRF across two voting legs:
|
| 225 |
+
L1 (primary): sheet-chunk cosine score
|
| 226 |
+
L2 (vote): best column-chunk position per (doc_id, sheet_name)
|
| 227 |
+
|
| 228 |
+
Returns top-k sheet-level RetrievalResults. The full column list of
|
| 229 |
+
each sheet is already in the sheet chunk's data.column_names from
|
| 230 |
+
ingestion, so downstream tabular_executor can read full sheet context.
|
| 231 |
+
|
| 232 |
+
For sheets surfaced by column votes but missing a sheet chunk (rare —
|
| 233 |
+
ingestion always creates one), a minimal stub is returned and
|
| 234 |
+
tabular_executor falls back to reading columns from the parquet.
|
| 235 |
"""
|
| 236 |
+
# L1: sheets indexed by (doc_id, sheet_name) from sheet chunks
|
| 237 |
+
sheet_index: dict[tuple, RetrievalResult] = {}
|
| 238 |
+
sheet_ranked: list[tuple] = []
|
| 239 |
+
for r in sheet_results:
|
| 240 |
+
d = r.metadata.get("data", {})
|
| 241 |
+
key = (d.get("document_id"), d.get("sheet_name"))
|
| 242 |
+
if key[0] and key not in sheet_index:
|
| 243 |
+
sheet_index[key] = r
|
| 244 |
+
sheet_ranked.append(key)
|
| 245 |
+
|
| 246 |
+
# L2: sheets ranked by first-appearance in column-chunk results
|
| 247 |
+
col_sheet_ranked: list[tuple] = []
|
| 248 |
+
seen: set[tuple] = set()
|
| 249 |
+
for r in column_results:
|
| 250 |
+
d = r.metadata.get("data", {})
|
| 251 |
+
key = (d.get("document_id"), d.get("sheet_name"))
|
| 252 |
+
if key[0] and key not in seen:
|
| 253 |
+
col_sheet_ranked.append(key)
|
| 254 |
+
seen.add(key)
|
| 255 |
+
|
| 256 |
+
# RRF over (doc_id, sheet_name) across the two legs
|
| 257 |
+
rrf_scores: dict[tuple, float] = {}
|
| 258 |
+
for ranked_list in [sheet_ranked, col_sheet_ranked]:
|
| 259 |
+
for rank, key in enumerate(ranked_list):
|
| 260 |
+
rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 261 |
+
|
| 262 |
+
top_sheets = sorted(rrf_scores, key=lambda k: rrf_scores[k], reverse=True)[:top_k]
|
| 263 |
|
| 264 |
+
results: list[RetrievalResult] = []
|
| 265 |
+
for key in top_sheets:
|
| 266 |
+
if key in sheet_index:
|
| 267 |
+
r = sheet_index[key]
|
| 268 |
+
r.score = rrf_scores[key]
|
| 269 |
+
results.append(r)
|
| 270 |
+
else:
|
| 271 |
+
# Surfaced by column votes only — build stub from a representative
|
| 272 |
+
# column result so tabular_executor can group correctly.
|
| 273 |
+
doc_id, sheet_name = key
|
| 274 |
+
rep = next(
|
| 275 |
+
(r for r in column_results
|
| 276 |
+
if r.metadata.get("data", {}).get("document_id") == doc_id
|
| 277 |
+
and r.metadata.get("data", {}).get("sheet_name") == sheet_name),
|
| 278 |
+
None,
|
| 279 |
+
)
|
| 280 |
+
if rep is None:
|
| 281 |
+
continue
|
| 282 |
+
stub_data = dict(rep.metadata.get("data", {}))
|
| 283 |
+
stub_data.pop("column_name", None)
|
| 284 |
+
stub_data.pop("column_type", None)
|
| 285 |
+
results.append(RetrievalResult(
|
| 286 |
+
content=f"Sheet: {stub_data.get('filename', '')}"
|
| 287 |
+
+ (f" / sheet: {sheet_name}" if sheet_name else ""),
|
| 288 |
+
metadata={**rep.metadata, "data": stub_data, "chunk_level": "sheet"},
|
| 289 |
+
score=rrf_scores[key],
|
| 290 |
+
source_type="document",
|
| 291 |
+
))
|
| 292 |
+
return results
|
| 293 |
|
| 294 |
def _rank_db_tables(
|
| 295 |
self,
|
|
|
|
| 375 |
ranked table set via _fetch_full_schema — the column chunks returned
|
| 376 |
here are intentionally NOT used as the schema source, only for voting.
|
| 377 |
|
| 378 |
+
Tabular (CSV/XLSX) sheets are ranked via RRF across two legs:
|
| 379 |
+
L1: sheet-chunk cosine
|
| 380 |
+
L2: column-chunk votes (best position per sheet)
|
| 381 |
+
Returns sheet-level RetrievalResults so tabular_executor receives
|
| 382 |
+
full sheet context (all columns) rather than fragmented column hits.
|
| 383 |
"""
|
| 384 |
embedding = await self._embed_query(query)
|
| 385 |
db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
|
|
|
|
| 391 |
)
|
| 392 |
|
| 393 |
db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
|
| 394 |
+
tabular_ranked = self._rank_tabular_sheets(sheet_results, tabular_results, top_k=k)
|
| 395 |
|
| 396 |
+
results = db_ranked + tabular_ranked
|
| 397 |
logger.info(
|
| 398 |
"schema retrieval",
|
| 399 |
count=len(results),
|
| 400 |
db_tables_ranked=len(db_ranked),
|
| 401 |
db_cols=len(db_col_results),
|
| 402 |
db_tables=len(db_tbl_results),
|
| 403 |
+
tabular_cols=len(tabular_results),
|
| 404 |
tabular_sheets=len(sheet_results),
|
| 405 |
+
tabular_ranked=len(tabular_ranked),
|
| 406 |
fts=len(fts_results),
|
| 407 |
)
|
| 408 |
return results
|