Rifqi Hafizuddin committed on
Commit
bb29492
·
1 Parent(s): 00aa61d

[NOTICKET] Retrieve DB tables first, then get columns from the obtained tables; reduce k to 5

Browse files
src/api/v1/chat.py CHANGED
@@ -190,11 +190,11 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
190
 
191
  if intent_result is None:
192
  # Step 2: Launch retrieval and history loading in parallel, then run orchestrator.
193
- # k=10 (not the wrapper default of 5) so the merged top-k spans more
194
  # tables — db_executor's FK expansion is one-hop and cannot bridge
195
  # 2-hop gaps (e.g. customers -> order_items -> products) on its own.
196
  retrieval_task = asyncio.create_task(
197
- retriever.retrieve(request.message, request.user_id, db, k=10)
198
  )
199
  history_task = asyncio.create_task(
200
  load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
@@ -222,7 +222,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
222
  query=search_query,
223
  user_id=request.user_id,
224
  db=db,
225
- k=10,
226
  source_hint=intent_result.get("source_hint", "both"),
227
  )
228
  else:
 
190
 
191
  if intent_result is None:
192
  # Step 2: Launch retrieval and history loading in parallel, then run orchestrator.
193
+ # k=5: retrieval is now table-first, so the top-k already covers distinct
194
  # tables — db_executor's FK expansion is one-hop and cannot bridge
195
  # 2-hop gaps (e.g. customers -> order_items -> products) on its own.
196
  retrieval_task = asyncio.create_task(
197
+ retriever.retrieve(request.message, request.user_id, db, k=5)
198
  )
199
  history_task = asyncio.create_task(
200
  load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
 
222
  query=search_query,
223
  user_id=request.user_id,
224
  db=db,
225
+ k=5,
226
  source_hint=intent_result.get("source_hint", "both"),
227
  )
228
  else:
src/query/executors/db_executor.py CHANGED
@@ -193,7 +193,12 @@ class DbExecutor(BaseExecutor):
193
  })
194
  sql = result.sql.strip()
195
  allowed_tables = set(full_schema) | set(related_schema)
196
- validation_error = self._validate(sql, allowed_tables, capped_limit)
 
 
 
 
 
197
  if validation_error:
198
  prev_error = validation_error
199
  prev_reasoning = result.reasoning
@@ -559,11 +564,21 @@ class DbExecutor(BaseExecutor):
559
  # Guardrails
560
  # ------------------------------------------------------------------
561
 
562
- def _validate(self, sql: str, allowed_tables: set[str], limit: int) -> str:
 
 
 
 
 
 
563
  """Return an error string if validation fails, empty string if OK.
564
 
565
  `allowed_tables` is the union of hit-table names and FK-related table
566
  names — both are legal targets for SELECT/JOIN.
 
 
 
 
567
  """
568
  # Layer 1: sqlglot parse + SELECT-only check
569
  try:
@@ -580,12 +595,31 @@ class DbExecutor(BaseExecutor):
580
 
581
  # Layer 2: schema grounding — table names
582
  known_tables = {t.lower() for t in allowed_tables}
 
583
  for tbl in parsed.find_all(exp.Table):
584
  name = tbl.name.lower()
585
  if name and name not in known_tables:
586
  return f"Unknown table '{tbl.name}'. Only use tables from the schema."
587
-
588
- # Layer 3: LIMIT enforcement (inject if missing — done before execution)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  return ""
590
 
591
  # ------------------------------------------------------------------
 
193
  })
194
  sql = result.sql.strip()
195
  allowed_tables = set(full_schema) | set(related_schema)
196
+ column_map: dict[str, set[str]] = {
197
+ t: {c["name"] for c in cols} for t, cols in full_schema.items()
198
+ }
199
+ for t, info in related_schema.items():
200
+ column_map[t] = set(info.get("column_names") or [])
201
+ validation_error = self._validate(sql, allowed_tables, capped_limit, column_map)
202
  if validation_error:
203
  prev_error = validation_error
204
  prev_reasoning = result.reasoning
 
564
  # Guardrails
565
  # ------------------------------------------------------------------
566
 
567
+ def _validate(
568
+ self,
569
+ sql: str,
570
+ allowed_tables: set[str],
571
+ limit: int,
572
+ column_map: dict[str, set[str]] | None = None,
573
+ ) -> str:
574
  """Return an error string if validation fails, empty string if OK.
575
 
576
  `allowed_tables` is the union of hit-table names and FK-related table
577
  names — both are legal targets for SELECT/JOIN.
578
+
579
+ `column_map` maps table_name → set of valid column names. When provided,
580
+ any qualified table.column reference not found in the map triggers a retry
581
+ with an informative error so the LLM can self-correct without hallucinating.
582
  """
583
  # Layer 1: sqlglot parse + SELECT-only check
584
  try:
 
595
 
596
  # Layer 2: schema grounding — table names
597
  known_tables = {t.lower() for t in allowed_tables}
598
+ alias_to_table: dict[str, str] = {}
599
  for tbl in parsed.find_all(exp.Table):
600
  name = tbl.name.lower()
601
  if name and name not in known_tables:
602
  return f"Unknown table '{tbl.name}'. Only use tables from the schema."
603
+ alias = (tbl.alias or tbl.name).lower()
604
+ alias_to_table[alias] = name
605
+
606
+ # Layer 3: column grounding — qualified references only (table.column)
607
+ if column_map:
608
+ normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
609
+ for col_node in parsed.find_all(exp.Column):
610
+ tbl_ref = col_node.table
611
+ if not tbl_ref:
612
+ continue # unqualified — skip, can't resolve without full alias tracking
613
+ tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
614
+ col_name = col_node.name.lower()
615
+ if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
616
+ available = ", ".join(sorted(normalized_map[tbl_name]))
617
+ return (
618
+ f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
619
+ f"Available columns: {available}."
620
+ )
621
+
622
+ # Layer 4: LIMIT enforcement (inject if missing — done before execution)
623
  return ""
624
 
625
  # ------------------------------------------------------------------
src/rag/retrievers/schema.py CHANGED
@@ -194,26 +194,6 @@ class SchemaRetriever(BaseRetriever):
194
  d.get("sheet_name"),
195
  )
196
 
197
- def _rrf_merge(
198
- self,
199
- *ranked_lists: list[RetrievalResult],
200
- k_rrf: int = 60,
201
- top_k: int = 5,
202
- ) -> list[RetrievalResult]:
203
- """Reciprocal Rank Fusion — combines ranked lists using rank positions only."""
204
- scores: dict[tuple, float] = {}
205
- index: dict[tuple, RetrievalResult] = {}
206
-
207
- for ranked in ranked_lists:
208
- for rank, result in enumerate(ranked):
209
- key = self._chunk_key(result)
210
- scores[key] = scores.get(key, 0.0) + 1.0 / (k_rrf + rank + 1)
211
- if key not in index or result.score > index[key].score:
212
- index[key] = result
213
-
214
- merged = sorted(index.values(), key=lambda r: scores[self._chunk_key(r)], reverse=True)
215
- return merged[:top_k]
216
-
217
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
218
  """Deduplicate by chunk identity, keeping highest score per unique key."""
219
  seen: dict[tuple, RetrievalResult] = {}
@@ -223,12 +203,93 @@ class SchemaRetriever(BaseRetriever):
223
  seen[key] = r
224
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  # ------------------------------------------------------------------
227
  # Public interface — called by the router
228
  # ------------------------------------------------------------------
229
 
230
  async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
231
- """RRF merge of dense (DB columns + DB tables + tabular) and FTS (DB cols only)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  embedding = await self._embed_query(query)
233
  db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
234
  self._search_db(embedding, user_id, k),
@@ -236,11 +297,15 @@ class SchemaRetriever(BaseRetriever):
236
  self._search_tabular(embedding, user_id, k),
237
  self._search_fts_db(query, user_id, k * 4),
238
  )
239
- dense = self._dedup(db_col_results + db_tbl_results + tabular_results)[:k]
240
- results = self._rrf_merge(dense, self._dedup(fts_results), top_k=k)
 
 
 
241
  logger.info(
242
  "schema retrieval",
243
  count=len(results),
 
244
  db_cols=len(db_col_results),
245
  db_tables=len(db_tbl_results),
246
  tabular=len(tabular_results),
 
194
  d.get("sheet_name"),
195
  )
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def _dedup(self, results: list[RetrievalResult]) -> list[RetrievalResult]:
198
  """Deduplicate by chunk identity, keeping highest score per unique key."""
199
  seen: dict[tuple, RetrievalResult] = {}
 
203
  seen[key] = r
204
  return sorted(seen.values(), key=lambda r: r.score, reverse=True)
205
 
206
+ def _rank_db_tables(
207
+ self,
208
+ tbl_results: list[RetrievalResult],
209
+ col_results: list[RetrievalResult],
210
+ fts_results: list[RetrievalResult],
211
+ top_k: int,
212
+ k_rrf: int = 60,
213
+ ) -> list[RetrievalResult]:
214
+ """Rank DB tables by RRF across three legs:
215
+ L1 (primary): table-summary chunk similarity
216
+ L2 (vote): best column-chunk position per table
217
+ L3 (vote): best FTS position per table
218
+
219
+ Returns top-k table-chunk RetrievalResults. For tables surfaced by
220
+ L2/L3 but missing a table chunk, a minimal stub is returned so that
221
+ db_executor._fetch_full_schema can seed off data.table_name.
222
+ """
223
+ # L1: tables ranked by table-chunk cosine score
224
+ tbl_index: dict[str, RetrievalResult] = {}
225
+ tbl_ranked: list[str] = []
226
+ for r in tbl_results:
227
+ tname = r.metadata.get("data", {}).get("table_name")
228
+ if tname and tname not in tbl_index:
229
+ tbl_index[tname] = r
230
+ tbl_ranked.append(tname)
231
+
232
+ # L2: tables ranked by first-appearance in column-chunk list (best col score)
233
+ col_table_ranked: list[str] = []
234
+ seen: set[str] = set()
235
+ for r in col_results:
236
+ tname = r.metadata.get("data", {}).get("table_name")
237
+ if tname and tname not in seen:
238
+ col_table_ranked.append(tname)
239
+ seen.add(tname)
240
+
241
+ # L3: tables ranked by first-appearance in FTS list
242
+ fts_table_ranked: list[str] = []
243
+ seen = set()
244
+ for r in fts_results:
245
+ tname = r.metadata.get("data", {}).get("table_name")
246
+ if tname and tname not in seen:
247
+ fts_table_ranked.append(tname)
248
+ seen.add(tname)
249
+
250
+ # RRF over table names across the three legs
251
+ rrf_scores: dict[str, float] = {}
252
+ for ranked_list in [tbl_ranked, col_table_ranked, fts_table_ranked]:
253
+ for rank, tname in enumerate(ranked_list):
254
+ rrf_scores[tname] = rrf_scores.get(tname, 0.0) + 1.0 / (k_rrf + rank + 1)
255
+
256
+ top_tables = sorted(rrf_scores, key=lambda t: rrf_scores[t], reverse=True)[:top_k]
257
+
258
+ results: list[RetrievalResult] = []
259
+ for tname in top_tables:
260
+ if tname in tbl_index:
261
+ r = tbl_index[tname]
262
+ r.score = rrf_scores[tname]
263
+ results.append(r)
264
+ else:
265
+ # Surfaced by column/FTS votes with no table chunk — minimal stub
266
+ results.append(RetrievalResult(
267
+ content=f"Table: {tname}",
268
+ metadata={"data": {"table_name": tname}, "source_type": "database"},
269
+ score=rrf_scores[tname],
270
+ source_type="database",
271
+ ))
272
+ return results
273
+
274
  # ------------------------------------------------------------------
275
  # Public interface — called by the router
276
  # ------------------------------------------------------------------
277
 
278
  async def retrieve(self, query: str, user_id: str, k: int = 5) -> list[RetrievalResult]:
279
+ """Table-first retrieval for DB sources; chunk-level for tabular.
280
+
281
+ DB tables are ranked via RRF across three legs:
282
+ L1 (primary): table-summary chunk similarity
283
+ L2 (vote): top-K column-chunk cosine, grouped by table
284
+ L3 (vote): top-K FTS column hits, grouped by table
285
+
286
+ db_executor downstream fetches the full per-column schema for the
287
+ ranked table set via _fetch_full_schema — the column chunks returned
288
+ here are intentionally NOT used as the schema source, only for voting.
289
+
290
+ Tabular (CSV/XLSX) chunks remain at column/sheet level since they have
291
+ no table-level chunks.
292
+ """
293
  embedding = await self._embed_query(query)
294
  db_col_results, db_tbl_results, tabular_results, fts_results = await asyncio.gather(
295
  self._search_db(embedding, user_id, k),
 
297
  self._search_tabular(embedding, user_id, k),
298
  self._search_fts_db(query, user_id, k * 4),
299
  )
300
+
301
+ db_ranked = self._rank_db_tables(db_tbl_results, db_col_results, fts_results, top_k=k)
302
+ tabular_final = self._dedup(tabular_results)[:k]
303
+
304
+ results = db_ranked + tabular_final
305
  logger.info(
306
  "schema retrieval",
307
  count=len(results),
308
+ db_tables_ranked=len(db_ranked),
309
  db_cols=len(db_col_results),
310
  db_tables=len(db_tbl_results),
311
  tabular=len(tabular_results),