Rifqi Hafizuddin committed on
Commit
f86da27
·
1 Parent(s): a25febe

[KM-516][KM-517] Add new feature: the AI can now see the names of tables and columns that have a foreign-key (FK) relationship with the retrieved results

Browse files
Files changed (1) hide show
  1. src/query/executors/db_executor.py +54 -1
src/query/executors/db_executor.py CHANGED
@@ -137,12 +137,13 @@ class DbExecutor(BaseExecutor):
137
  logger.warning("unsupported db_type for query execution", db_type=client.db_type)
138
  return None
139
 
140
- # Distinct table names from retrieval results
141
  table_names = list({
142
  r.metadata.get("data", {}).get("table_name")
143
  for r in results
144
  if r.metadata.get("data", {}).get("table_name")
145
  })
 
146
 
147
  full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
148
  if not full_schema:
@@ -235,6 +236,58 @@ class DbExecutor(BaseExecutor):
235
  # Schema helpers
236
  # ------------------------------------------------------------------
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  async def _fetch_full_schema(
239
  self,
240
  client_id: str,
 
137
  logger.warning("unsupported db_type for query execution", db_type=client.db_type)
138
  return None
139
 
140
+ # Distinct table names from retrieval results, expanded via FK relationships
141
  table_names = list({
142
  r.metadata.get("data", {}).get("table_name")
143
  for r in results
144
  if r.metadata.get("data", {}).get("table_name")
145
  })
146
+ table_names = await self._expand_with_fk_tables(client_id, user_id, table_names)
147
 
148
  full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
149
  if not full_schema:
 
236
  # Schema helpers
237
  # ------------------------------------------------------------------
238
 
239
async def _expand_with_fk_tables(
    self,
    client_id: str,
    user_id: str,
    table_names: list[str],
) -> list[str]:
    """Expand ``table_names`` with the tables their foreign keys reference.

    Prevents SQL generation failures when a required table (e.g. ``orders``)
    wasn't returned by retrieval but is referenced via FK from a table that
    was (e.g. ``order_items.order_id -> orders.id``).

    Only the references stored for the given tables are added; the tables
    discovered this way are not themselves expanded further.

    Args:
        client_id: Value matched against ``cmetadata->>'database_client_id'``.
        user_id: Value matched against ``cmetadata->>'user_id'``.
        table_names: Table names obtained from the retrieval results.

    Returns:
        The requested tables plus every FK-referenced table, de-duplicated
        and sorted so the result is stable between calls. An empty input is
        returned unchanged without touching the database.
    """
    if not table_names:
        return table_names

    requested = sorted(set(table_names))

    # Only the generated placeholder names (:t0, :t1, ...) are interpolated
    # into the statement; the table names themselves are bound parameters.
    placeholders = ", ".join(f":t{i}" for i in range(len(requested)))
    sql = text(f"""
        SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
        FROM langchain_pg_embedding lpe
        JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
        WHERE lpc.name = 'document_embeddings'
        AND lpe.cmetadata->>'user_id' = :user_id
        AND lpe.cmetadata->>'source_type' = 'database'
        AND lpe.cmetadata->>'database_client_id' = :client_id
        AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
        AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
    """)

    params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
    params.update({f"t{i}": name for i, name in enumerate(requested)})

    async with _pgvector_engine.connect() as conn:
        result = await conn.execute(sql, params)
        rows = result.fetchall()

    expanded = set(requested)
    for row in rows:
        # Stored format: "referred_table.referred_column" -- keep the table part.
        # Blank or malformed values (e.g. ".id") must not yield an empty name.
        referred_table = (row.fk or "").partition(".")[0].strip()
        if referred_table:
            expanded.add(referred_table)

    if expanded.difference(requested):
        logger.info(
            "expanded tables via FK",
            original=requested,
            expanded=sorted(expanded),
        )

    return sorted(expanded)
+
291
  async def _fetch_full_schema(
292
  self,
293
  client_id: str,