Commit ·
a205d0c
1
Parent(s): 23eeb2d
[NOTICKET][db] add sheet-level retrieval and focus LLM schema context to retrieved columns
Browse files
- src/query/executors/tabular.py +4 -0
- src/rag/retrievers/schema.py +47 -7
- src/rag/router.py +2 -1
src/query/executors/tabular.py
CHANGED
|
@@ -252,6 +252,10 @@ class TabularExecutor(BaseExecutor):
|
|
| 252 |
) -> QueryResult | None:
|
| 253 |
try:
|
| 254 |
df = await download_parquet(user_id, doc_id, sheet_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
df_result = await self._query_with_agent(df, question, limit)
|
| 256 |
|
| 257 |
table_label = info["filename"]
|
|
|
|
| 252 |
) -> QueryResult | None:
|
| 253 |
try:
|
| 254 |
df = await download_parquet(user_id, doc_id, sheet_name)
|
| 255 |
+
if info["columns"]:
|
| 256 |
+
valid_cols = [c for c in info["columns"] if c in df.columns]
|
| 257 |
+
if valid_cols:
|
| 258 |
+
df = df[valid_cols]
|
| 259 |
df_result = await self._query_with_agent(df, question, limit)
|
| 260 |
|
| 261 |
table_label = info["filename"]
|
src/rag/retrievers/schema.py
CHANGED
|
@@ -2,8 +2,8 @@
|
|
| 2 |
columns stored as source_type="document" with file_type in ("csv","xlsx").
|
| 3 |
|
| 4 |
Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
|
| 5 |
-
+ tabular) and PostgreSQL full-text search (DB columns only).
|
| 6 |
-
once, fans out
|
| 7 |
|
| 8 |
The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
|
| 9 |
a recall signal for multi-table questions: when a relevant table's columns
|
|
@@ -127,6 +127,7 @@ class SchemaRetriever(BaseRetriever):
|
|
| 127 |
WHERE lpc.name = 'document_embeddings'
|
| 128 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 129 |
AND lpe.cmetadata->>'source_type' = 'document'
|
|
|
|
| 130 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 131 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 132 |
ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
|
|
@@ -147,6 +148,41 @@ class SchemaRetriever(BaseRetriever):
|
|
| 147 |
for row in rows
|
| 148 |
]
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 151 |
"""Full-text search over DB schema chunks using PostgreSQL tsvector."""
|
| 152 |
sql = text("""
|
|
@@ -182,9 +218,10 @@ class SchemaRetriever(BaseRetriever):
|
|
| 182 |
def _chunk_key(r: RetrievalResult) -> tuple:
|
| 183 |
"""Stable identity for dedup/RRF.
|
| 184 |
|
| 185 |
-
Includes filename and
|
| 186 |
-
|
| 187 |
-
|
|
|
|
| 188 |
"""
|
| 189 |
d = r.metadata.get("data", {})
|
| 190 |
return (
|
|
@@ -192,6 +229,7 @@ class SchemaRetriever(BaseRetriever):
|
|
| 192 |
d.get("column_name"),
|
| 193 |
d.get("filename"),
|
| 194 |
d.get("sheet_name"),
|
|
|
|
| 195 |
)
|
| 196 |
|
| 197 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
|
@@ -291,15 +329,16 @@ class SchemaRetriever(BaseRetriever):
|
|
| 291 |
no table-level chunks.
|
| 292 |
"""
|
| 293 |
embedding = await self._embed_query(query)
|
| 294 |
-
db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
|
| 295 |
self._search_db(embedding, user_id, k),
|
| 296 |
self._search_db_tables(embedding, user_id, k),
|
| 297 |
self._search_tabular(embedding, user_id, k),
|
| 298 |
self._search_fts_db(query, user_id, k * 4),
|
|
|
|
| 299 |
)
|
| 300 |
|
| 301 |
db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
|
| 302 |
-
tabular_final = self._dedup(tabular_results)[:k]
|
| 303 |
|
| 304 |
results = db_ranked + tabular_final
|
| 305 |
logger.info(
|
|
@@ -309,6 +348,7 @@ class SchemaRetriever(BaseRetriever):
|
|
| 309 |
db_cols=len(db_col_results),
|
| 310 |
db_tables=len(db_tbl_results),
|
| 311 |
tabular=len(tabular_results),
|
|
|
|
| 312 |
fts=len(fts_results),
|
| 313 |
)
|
| 314 |
return results
|
|
|
|
| 2 |
columns stored as source_type="document" with file_type in ("csv","xlsx").
|
| 3 |
|
| 4 |
Strategy: hybrid_bm25 — RRF merge of dense cosine search (DB columns + DB tables
|
| 5 |
+
+ tabular columns + tabular sheets) and PostgreSQL full-text search (DB columns only).
|
| 6 |
+
Embeds the query once, fans out five legs in parallel.
|
| 7 |
|
| 8 |
The DB-tables leg surfaces table-level summary chunks (chunk_level='table') as
|
| 9 |
a recall signal for multi-table questions: when a relevant table's columns
|
|
|
|
| 127 |
WHERE lpc.name = 'document_embeddings'
|
| 128 |
AND lpe.cmetadata->>'user_id' = :user_id
|
| 129 |
AND lpe.cmetadata->>'source_type' = 'document'
|
| 130 |
+
AND lpe.cmetadata->>'chunk_level' = 'column'
|
| 131 |
AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
|
| 132 |
OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
|
| 133 |
ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
|
|
|
|
| 148 |
for row in rows
|
| 149 |
]
|
| 150 |
|
| 151 |
+
async def _search_tabular_sheets(
    self, embedding: list[float], user_id: str, k: int
) -> list[RetrievalResult]:
    """Leg 5: dense search over sheet-level summary chunks from CSV/XLSX files.

    Queries the ``document_embeddings`` collection, keeping only rows whose
    metadata has the given ``user_id``, ``source_type='document'``,
    ``chunk_level='sheet'`` and a ``file_type`` of ``csv`` or ``xlsx``.

    Args:
        embedding: Query embedding vector.
        user_id: Owner whose chunks are searched (bound as ``:user_id``).
        k: Maximum number of rows to return (bound as ``:k`` in ``LIMIT``).

    Returns:
        At most ``k`` results ordered by ascending ``<=>`` distance to the
        query embedding, each scored as ``1.0 - distance`` and tagged with
        ``source_type="document"``.

    Raises:
        TypeError, ValueError: If an element of ``embedding`` is not numeric.
    """
    # The vector literal is interpolated into the SQL text (it appears in both
    # the SELECT list and the ORDER BY), so every component is coerced through
    # float() first: nothing but numeric literals can reach the statement.
    emb_str = "[" + ",".join(str(float(x)) for x in embedding) + "]"

    sql = text(f"""
        SELECT lpe.document, lpe.cmetadata,
               1.0 - (lpe.embedding <=> '{emb_str}'::vector) AS score
        FROM langchain_pg_embedding lpe
        JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
        WHERE lpc.name = 'document_embeddings'
          AND lpe.cmetadata->>'user_id' = :user_id
          AND lpe.cmetadata->>'source_type' = 'document'
          AND lpe.cmetadata->>'chunk_level' = 'sheet'
          AND (lpe.cmetadata->'data'->>'file_type' = 'csv'
               OR lpe.cmetadata->'data'->>'file_type' = 'xlsx')
        ORDER BY lpe.embedding <=> '{emb_str}'::vector ASC
        LIMIT :k
    """)

    # user_id and k travel as bound parameters; only the vector is inlined.
    async with _pgvector_engine.connect() as conn:
        result = await conn.execute(sql, {"user_id": user_id, "k": k})
        rows = result.fetchall()

    return [
        RetrievalResult(
            content=row.document,
            metadata=row.cmetadata,
            score=float(row.score),
            source_type="document",
        )
        for row in rows
    ]
|
| 185 |
+
|
| 186 |
async def _search_fts_db(self, query: str, user_id: str, k: int) -> list[RetrievalResult]:
|
| 187 |
"""Full-text search over DB schema chunks using PostgreSQL tsvector."""
|
| 188 |
sql = text("""
|
|
|
|
| 218 |
def _chunk_key(r: RetrievalResult) -> tuple:
|
| 219 |
"""Stable identity for dedup/RRF.
|
| 220 |
|
| 221 |
+
Includes filename, sheet_name, and chunk_level so that column chunks
|
| 222 |
+
and sheet chunks for the same file/sheet don't collide, and column
|
| 223 |
+
chunks with the same name across different files (e.g. `id` in two CSVs)
|
| 224 |
+
are kept distinct.
|
| 225 |
"""
|
| 226 |
d = r.metadata.get("data", {})
|
| 227 |
return (
|
|
|
|
| 229 |
d.get("column_name"),
|
| 230 |
d.get("filename"),
|
| 231 |
d.get("sheet_name"),
|
| 232 |
+
r.metadata.get("chunk_level"),
|
| 233 |
)
|
| 234 |
|
| 235 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
|
|
|
| 329 |
no table-level chunks.
|
| 330 |
"""
|
| 331 |
embedding = await self._embed_query(query)
|
| 332 |
+
db_col_results, db_tbl_results, tabular_results, fts_results, sheet_results = await asyncio.gather(
|
| 333 |
self._search_db(embedding, user_id, k),
|
| 334 |
self._search_db_tables(embedding, user_id, k),
|
| 335 |
self._search_tabular(embedding, user_id, k),
|
| 336 |
self._search_fts_db(query, user_id, k * 4),
|
| 337 |
+
self._search_tabular_sheets(embedding, user_id, k),
|
| 338 |
)
|
| 339 |
|
| 340 |
db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
|
| 341 |
+
tabular_final = self._dedup(tabular_results + sheet_results)[:k]
|
| 342 |
|
| 343 |
results = db_ranked + tabular_final
|
| 344 |
logger.info(
|
|
|
|
| 348 |
db_cols=len(db_col_results),
|
| 349 |
db_tables=len(db_tbl_results),
|
| 350 |
tabular=len(tabular_results),
|
| 351 |
+
tabular_sheets=len(sheet_results),
|
| 352 |
fts=len(fts_results),
|
| 353 |
)
|
| 354 |
return results
|
src/rag/router.py
CHANGED
|
@@ -25,7 +25,7 @@ SourceHint = Literal["document", "schema", "both"]
|
|
| 25 |
|
| 26 |
def _result_dedup_key(r: RetrievalResult) -> tuple:
|
| 27 |
"""Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
|
| 28 |
-
tabular columns vs prose chunks vs sheet-level
|
| 29 |
data = r.metadata.get("data", {})
|
| 30 |
return (
|
| 31 |
r.source_type,
|
|
@@ -34,6 +34,7 @@ def _result_dedup_key(r: RetrievalResult) -> tuple:
|
|
| 34 |
data.get("filename"),
|
| 35 |
data.get("sheet_name"),
|
| 36 |
data.get("chunk_index"), # disambiguates multiple prose chunks per doc
|
|
|
|
| 37 |
)
|
| 38 |
|
| 39 |
|
|
|
|
| 25 |
|
| 26 |
def _result_dedup_key(r: RetrievalResult) -> tuple:
|
| 27 |
"""Cross-retriever dedup key — distinguishes DB columns vs DB tables vs
|
| 28 |
+
tabular columns vs prose chunks vs sheet-level chunks."""
|
| 29 |
data = r.metadata.get("data", {})
|
| 30 |
return (
|
| 31 |
r.source_type,
|
|
|
|
| 34 |
data.get("filename"),
|
| 35 |
data.get("sheet_name"),
|
| 36 |
data.get("chunk_index"), # disambiguates multiple prose chunks per doc
|
| 37 |
+
r.metadata.get("chunk_level"), # distinguishes sheet vs column chunks
|
| 38 |
)
|
| 39 |
|
| 40 |
|