Rifqi Hafizuddin commited on
Commit ·
bb29492
1
Parent(s): 00aa61d
[NOTICKET] now retrieve db tables first, then get column from the obtained tables. reduce k to 5
Browse files- src/api/v1/chat.py +3 -3
- src/query/executors/db_executor.py +38 -4
- src/rag/retrievers/schema.py +88 -23
src/api/v1/chat.py
CHANGED
|
@@ -190,11 +190,11 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 190 |
|
| 191 |
if intent_result is None:
|
| 192 |
# Step 2: Launch retrieval and history loading in parallel, then run orchestrator.
|
| 193 |
-
# k=
|
| 194 |
# tables — db_executor's FK expansion is one-hop and cannot bridge
|
| 195 |
# 2-hop gaps (e.g. customers -> order_items -> products) on its own.
|
| 196 |
retrieval_task = asyncio.create_task(
|
| 197 |
-
retriever.retrieve(request.message, request.user_id, db, k=
|
| 198 |
)
|
| 199 |
history_task = asyncio.create_task(
|
| 200 |
load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
|
|
@@ -222,7 +222,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
|
|
| 222 |
query=search_query,
|
| 223 |
user_id=request.user_id,
|
| 224 |
db=db,
|
| 225 |
-
k=
|
| 226 |
source_hint=intent_result.get("source_hint", "both"),
|
| 227 |
)
|
| 228 |
else:
|
|
|
|
| 190 |
|
| 191 |
if intent_result is None:
|
| 192 |
# Step 2: Launch retrieval and history loading in parallel, then run orchestrator.
|
| 193 |
+
# k=5
|
| 194 |
# tables — db_executor's FK expansion is one-hop and cannot bridge
|
| 195 |
# 2-hop gaps (e.g. customers -> order_items -> products) on its own.
|
| 196 |
retrieval_task = asyncio.create_task(
|
| 197 |
+
retriever.retrieve(request.message, request.user_id, db, k=5)
|
| 198 |
)
|
| 199 |
history_task = asyncio.create_task(
|
| 200 |
load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
|
|
|
|
| 222 |
query=search_query,
|
| 223 |
user_id=request.user_id,
|
| 224 |
db=db,
|
| 225 |
+
k=5,
|
| 226 |
source_hint=intent_result.get("source_hint", "both"),
|
| 227 |
)
|
| 228 |
else:
|
src/query/executors/db_executor.py
CHANGED
|
@@ -193,7 +193,12 @@ class DbExecutor(BaseExecutor):
|
|
| 193 |
})
|
| 194 |
sql = result.sql.strip()
|
| 195 |
allowed_tables = set(full_schema) | set(related_schema)
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
if validation_error:
|
| 198 |
prev_error = validation_error
|
| 199 |
prev_reasoning = result.reasoning
|
|
@@ -559,11 +564,21 @@ class DbExecutor(BaseExecutor):
|
|
| 559 |
# Guardrails
|
| 560 |
# ------------------------------------------------------------------
|
| 561 |
|
| 562 |
-
def _validate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
"""Return an error string if validation fails, empty string if OK.
|
| 564 |
|
| 565 |
`allowed_tables` is the union of hit-table names and FK-related table
|
| 566 |
names — both are legal targets for SELECT/JOIN.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
"""
|
| 568 |
# Layer 1: sqlglot parse + SELECT-only check
|
| 569 |
try:
|
|
@@ -580,12 +595,31 @@ class DbExecutor(BaseExecutor):
|
|
| 580 |
|
| 581 |
# Layer 2: schema grounding — table names
|
| 582 |
known_tables = {t.lower() for t in allowed_tables}
|
|
|
|
| 583 |
for tbl in parsed.find_all(exp.Table):
|
| 584 |
name = tbl.name.lower()
|
| 585 |
if name and name not in known_tables:
|
| 586 |
return f"Unknown table '{tbl.name}'. Only use tables from the schema."
|
| 587 |
-
|
| 588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
return ""
|
| 590 |
|
| 591 |
# ------------------------------------------------------------------
|
|
|
|
| 193 |
})
|
| 194 |
sql = result.sql.strip()
|
| 195 |
allowed_tables = set(full_schema) | set(related_schema)
|
| 196 |
+
column_map: dict[str, set[str]] = {
|
| 197 |
+
t: {c["name"] for c in cols} for t, cols in full_schema.items()
|
| 198 |
+
}
|
| 199 |
+
for t, info in related_schema.items():
|
| 200 |
+
column_map[t] = set(info.get("column_names") or [])
|
| 201 |
+
validation_error = self._validate(sql, allowed_tables, capped_limit, column_map)
|
| 202 |
if validation_error:
|
| 203 |
prev_error = validation_error
|
| 204 |
prev_reasoning = result.reasoning
|
|
|
|
| 564 |
# Guardrails
|
| 565 |
# ------------------------------------------------------------------
|
| 566 |
|
| 567 |
+
def _validate(
|
| 568 |
+
self,
|
| 569 |
+
sql: str,
|
| 570 |
+
allowed_tables: set[str],
|
| 571 |
+
limit: int,
|
| 572 |
+
column_map: dict[str, set[str]] | None = None,
|
| 573 |
+
) -> str:
|
| 574 |
"""Return an error string if validation fails, empty string if OK.
|
| 575 |
|
| 576 |
`allowed_tables` is the union of hit-table names and FK-related table
|
| 577 |
names — both are legal targets for SELECT/JOIN.
|
| 578 |
+
|
| 579 |
+
`column_map` maps table_name → set of valid column names. When provided,
|
| 580 |
+
any qualified table.column reference not found in the map triggers a retry
|
| 581 |
+
with an informative error so the LLM can self-correct without hallucinating.
|
| 582 |
"""
|
| 583 |
# Layer 1: sqlglot parse + SELECT-only check
|
| 584 |
try:
|
|
|
|
| 595 |
|
| 596 |
# Layer 2: schema grounding — table names
|
| 597 |
known_tables = {t.lower() for t in allowed_tables}
|
| 598 |
+
alias_to_table: dict[str, str] = {}
|
| 599 |
for tbl in parsed.find_all(exp.Table):
|
| 600 |
name = tbl.name.lower()
|
| 601 |
if name and name not in known_tables:
|
| 602 |
return f"Unknown table '{tbl.name}'. Only use tables from the schema."
|
| 603 |
+
alias = (tbl.alias or tbl.name).lower()
|
| 604 |
+
alias_to_table[alias] = name
|
| 605 |
+
|
| 606 |
+
# Layer 3: column grounding — qualified references only (table.column)
|
| 607 |
+
if column_map:
|
| 608 |
+
normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
|
| 609 |
+
for col_node in parsed.find_all(exp.Column):
|
| 610 |
+
tbl_ref = col_node.table
|
| 611 |
+
if not tbl_ref:
|
| 612 |
+
continue # unqualified — skip, can't resolve without full alias tracking
|
| 613 |
+
tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
|
| 614 |
+
col_name = col_node.name.lower()
|
| 615 |
+
if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
|
| 616 |
+
available = ", ".join(sorted(normalized_map[tbl_name]))
|
| 617 |
+
return (
|
| 618 |
+
f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
|
| 619 |
+
f"Available columns: {available}."
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
# Layer 4: LIMIT enforcement (inject if missing — done before execution)
|
| 623 |
return ""
|
| 624 |
|
| 625 |
# ------------------------------------------------------------------
|
src/rag/retrievers/schema.py
CHANGED
|
@@ -194,26 +194,6 @@ class SchemaRetriever(BaseRetriever):
|
|
| 194 |
d.get("sheet_name"),
|
| 195 |
)
|
| 196 |
|
| 197 |
-
def _rrf_merge(
|
| 198 |
-
self,
|
| 199 |
-
*ranked_lists: list[RetrievalResult],
|
| 200 |
-
k_rrf: int = 60,
|
| 201 |
-
top_k: int = 5,
|
| 202 |
-
) -> list[RetrievalResult]:
|
| 203 |
-
"""Reciprocal Rank Fusion — combines ranked lists using rank positions only."""
|
| 204 |
-
scores: dict[tuple, float] = {}
|
| 205 |
-
index: dict[tuple, RetrievalResult] = {}
|
| 206 |
-
|
| 207 |
-
for ranked in ranked_lists:
|
| 208 |
-
for rank, result in enumerate(ranked):
|
| 209 |
-
key = self._chunk_key(result)
|
| 210 |
-
scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 211 |
-
if key not in index or result.score > index[key].score:
|
| 212 |
-
index[key] = result
|
| 213 |
-
|
| 214 |
-
merged = sorted(index.values(), key=lambda r: scores[self._chunk_key(r)], reverse=True)
|
| 215 |
-
return merged[:top_k]
|
| 216 |
-
|
| 217 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
| 218 |
"""Deduplicate by chunk identity, keeping highest score per unique key."""
|
| 219 |
seen: dict[tuple, RetrievalResult] = {}
|
|
@@ -223,12 +203,93 @@ class SchemaRetriever(BaseRetriever):
|
|
| 223 |
seen[key] = r
|
| 224 |
return sorted(seen.values(), key=lambda r: r.score, reverse=True)
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# ------------------------------------------------------------------
|
| 227 |
# Public interface — called by the router
|
| 228 |
# ------------------------------------------------------------------
|
| 229 |
|
| 230 |
async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 231 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
embedding = await self._embed_query(query)
|
| 233 |
db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
|
| 234 |
self._search_db(embedding, user_id, k),
|
|
@@ -236,11 +297,15 @@ class SchemaRetriever(BaseRetriever):
|
|
| 236 |
self._search_tabular(embedding, user_id, k),
|
| 237 |
self._search_fts_db(query, user_id, k * 4),
|
| 238 |
)
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
| 241 |
logger.info(
|
| 242 |
"schema retrieval",
|
| 243 |
count=len(results),
|
|
|
|
| 244 |
db_cols=len(db_col_results),
|
| 245 |
db_tables=len(db_tbl_results),
|
| 246 |
tabular=len(tabular_results),
|
|
|
|
| 194 |
d.get("sheet_name"),
|
| 195 |
)
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
|
| 198 |
"""Deduplicate by chunk identity, keeping highest score per unique key."""
|
| 199 |
seen: dict[tuple, RetrievalResult] = {}
|
|
|
|
| 203 |
seen[key] = r
|
| 204 |
return sorted(seen.values(), key=lambda r: r.score, reverse=True)
|
| 205 |
|
| 206 |
+
def _rank_db_tables(
|
| 207 |
+
self,
|
| 208 |
+
tbl_results: list[RetrievalResult],
|
| 209 |
+
col_results: list[RetrievalResult],
|
| 210 |
+
fts_results: list[RetrievalResult],
|
| 211 |
+
top_k: int,
|
| 212 |
+
k_rrf: int = 60,
|
| 213 |
+
) -> list[RetrievalResult]:
|
| 214 |
+
"""Rank DB tables by RRF across three legs:
|
| 215 |
+
L1 (primary): table-summary chunk similarity
|
| 216 |
+
L2 (vote): best column-chunk position per table
|
| 217 |
+
L3 (vote): best FTS position per table
|
| 218 |
+
|
| 219 |
+
Returns top-k table-chunk RetrievalResults. For tables surfaced by
|
| 220 |
+
L2/L3 but missing a table chunk, a minimal stub is returned so that
|
| 221 |
+
db_executor._fetch_full_schema can seed off data.table_name.
|
| 222 |
+
"""
|
| 223 |
+
# L1: tables ranked by table-chunk cosine score
|
| 224 |
+
tbl_index: dict[str, RetrievalResult] = {}
|
| 225 |
+
tbl_ranked: list[str] = []
|
| 226 |
+
for r in tbl_results:
|
| 227 |
+
tname = r.metadata.get("data", {}).get("table_name")
|
| 228 |
+
if tname and tname not in tbl_index:
|
| 229 |
+
tbl_index[tname] = r
|
| 230 |
+
tbl_ranked.append(tname)
|
| 231 |
+
|
| 232 |
+
# L2: tables ranked by first-appearance in column-chunk list (best col score)
|
| 233 |
+
col_table_ranked: list[str] = []
|
| 234 |
+
seen: set[str] = set()
|
| 235 |
+
for r in col_results:
|
| 236 |
+
tname = r.metadata.get("data", {}).get("table_name")
|
| 237 |
+
if tname and tname not in seen:
|
| 238 |
+
col_table_ranked.append(tname)
|
| 239 |
+
seen.add(tname)
|
| 240 |
+
|
| 241 |
+
# L3: tables ranked by first-appearance in FTS list
|
| 242 |
+
fts_table_ranked: list[str] = []
|
| 243 |
+
seen = set()
|
| 244 |
+
for r in fts_results:
|
| 245 |
+
tname = r.metadata.get("data", {}).get("table_name")
|
| 246 |
+
if tname and tname not in seen:
|
| 247 |
+
fts_table_ranked.append(tname)
|
| 248 |
+
seen.add(tname)
|
| 249 |
+
|
| 250 |
+
# RRF over table names across the three legs
|
| 251 |
+
rrf_scores: dict[str, float] = {}
|
| 252 |
+
for ranked_list in [tbl_ranked, col_table_ranked, fts_table_ranked]:
|
| 253 |
+
for rank, tname in enumerate(ranked_list):
|
| 254 |
+
rrf_scores[tname] = rrf_scores.get(tname, 0.0) + 1.0 / (k_rrf + rank + 1)
|
| 255 |
+
|
| 256 |
+
top_tables = sorted(rrf_scores, key=lambda t: rrf_scores[t], reverse=True)[:top_k]
|
| 257 |
+
|
| 258 |
+
results: list[RetrievalResult] = []
|
| 259 |
+
for tname in top_tables:
|
| 260 |
+
if tname in tbl_index:
|
| 261 |
+
r = tbl_index[tname]
|
| 262 |
+
r.score = rrf_scores[tname]
|
| 263 |
+
results.append(r)
|
| 264 |
+
else:
|
| 265 |
+
# Surfaced by column/FTS votes with no table chunk — minimal stub
|
| 266 |
+
results.append(RetrievalResult(
|
| 267 |
+
content=f"Table: {tname}",
|
| 268 |
+
metadata={"data": {"table_name": tname}, "source_type": "database"},
|
| 269 |
+
score=rrf_scores[tname],
|
| 270 |
+
source_type="database",
|
| 271 |
+
))
|
| 272 |
+
return results
|
| 273 |
+
|
| 274 |
# ------------------------------------------------------------------
|
| 275 |
# Public interface — called by the router
|
| 276 |
# ------------------------------------------------------------------
|
| 277 |
|
| 278 |
async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
|
| 279 |
+
"""Table-first retrieval for DB sources; chunk-level for tabular.
|
| 280 |
+
|
| 281 |
+
DB tables are ranked via RRF across three legs:
|
| 282 |
+
L1 (primary): table-summary chunk similarity
|
| 283 |
+
L2 (vote): top-K column-chunk cosine, grouped by table
|
| 284 |
+
L3 (vote): top-K FTS column hits, grouped by table
|
| 285 |
+
|
| 286 |
+
db_executor downstream fetches the full per-column schema for the
|
| 287 |
+
ranked table set via _fetch_full_schema — the column chunks returned
|
| 288 |
+
here are intentionally NOT used as the schema source, only for voting.
|
| 289 |
+
|
| 290 |
+
Tabular (CSV/XLSX) chunks remain at column/sheet level since they have
|
| 291 |
+
no table-level chunks.
|
| 292 |
+
"""
|
| 293 |
embedding = await self._embed_query(query)
|
| 294 |
db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
|
| 295 |
self._search_db(embedding, user_id, k),
|
|
|
|
| 297 |
self._search_tabular(embedding, user_id, k),
|
| 298 |
self._search_fts_db(query, user_id, k * 4),
|
| 299 |
)
|
| 300 |
+
|
| 301 |
+
db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
|
| 302 |
+
tabular_final = self._dedup(tabular_results)[:k]
|
| 303 |
+
|
| 304 |
+
results = db_ranked + tabular_final
|
| 305 |
logger.info(
|
| 306 |
"schema retrieval",
|
| 307 |
count=len(results),
|
| 308 |
+
db_tables_ranked=len(db_ranked),
|
| 309 |
db_cols=len(db_col_results),
|
| 310 |
db_tables=len(db_tbl_results),
|
| 311 |
tabular=len(tabular_results),
|