Rifqi Hafizuddin commited on
Commit ·
f86da27
1
Parent(s): a25febe
[KM-516][KM-517] add new feature; ai can now see table & column names that have fk relationship with retrieved result
Browse files
src/query/executors/db_executor.py
CHANGED
|
@@ -137,12 +137,13 @@ class DbExecutor(BaseExecutor):
|
|
| 137 |
logger.warning("unsupported db_type for query execution", db_type=client.db_type)
|
| 138 |
return None
|
| 139 |
|
| 140 |
-
# Distinct table names from retrieval results
|
| 141 |
table_names = list({
|
| 142 |
r.metadata.get("data", {}).get("table_name")
|
| 143 |
for r in results
|
| 144 |
if r.metadata.get("data", {}).get("table_name")
|
| 145 |
})
|
|
|
|
| 146 |
|
| 147 |
full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
|
| 148 |
if not full_schema:
|
|
@@ -235,6 +236,58 @@ class DbExecutor(BaseExecutor):
|
|
| 235 |
# Schema helpers
|
| 236 |
# ------------------------------------------------------------------
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
async def _fetch_full_schema(
|
| 239 |
self,
|
| 240 |
client_id: str,
|
|
|
|
| 137 |
logger.warning("unsupported db_type for query execution", db_type=client.db_type)
|
| 138 |
return None
|
| 139 |
|
| 140 |
+
# Distinct table names from retrieval results, expanded via FK relationships
|
| 141 |
table_names = list({
|
| 142 |
r.metadata.get("data", {}).get("table_name")
|
| 143 |
for r in results
|
| 144 |
if r.metadata.get("data", {}).get("table_name")
|
| 145 |
})
|
| 146 |
+
table_names = await self._expand_with_fk_tables(client_id, user_id, table_names)
|
| 147 |
|
| 148 |
full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
|
| 149 |
if not full_schema:
|
|
|
|
| 236 |
# Schema helpers
|
| 237 |
# ------------------------------------------------------------------
|
| 238 |
|
| 239 |
+
async def _expand_with_fk_tables(
|
| 240 |
+
self,
|
| 241 |
+
client_id: str,
|
| 242 |
+
user_id: str,
|
| 243 |
+
table_names: list[str],
|
| 244 |
+
) -> list[str]:
|
| 245 |
+
"""Expand table_names with any tables FK-referenced by the retrieved tables.
|
| 246 |
+
|
| 247 |
+
Prevents SQL generation failures when a required table (e.g. orders) wasn't
|
| 248 |
+
returned by retrieval but is referenced via FK from a table that was
|
| 249 |
+
(e.g. order_items.order_id -> orders.id).
|
| 250 |
+
"""
|
| 251 |
+
if not table_names:
|
| 252 |
+
return table_names
|
| 253 |
+
|
| 254 |
+
placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
|
| 255 |
+
sql = text(f"""
|
| 256 |
+
SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
|
| 257 |
+
FROM langchain_pg_embedding lpe
|
| 258 |
+
JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
|
| 259 |
+
WHERE lpc.name = 'document_embeddings'
|
| 260 |
+
AND lpe.cmetadata->>'user_id' = :user_id
|
| 261 |
+
AND lpe.cmetadata->>'source_type' = 'database'
|
| 262 |
+
AND lpe.cmetadata->>'database_client_id' = :client_id
|
| 263 |
+
AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
|
| 264 |
+
AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
|
| 265 |
+
""")
|
| 266 |
+
|
| 267 |
+
params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
|
| 268 |
+
for i, name in enumerate(table_names):
|
| 269 |
+
params[f"t{i}"] = name
|
| 270 |
+
|
| 271 |
+
async with _pgvector_engine.connect() as conn:
|
| 272 |
+
result = await conn.execute(sql, params)
|
| 273 |
+
rows = result.fetchall()
|
| 274 |
+
|
| 275 |
+
expanded = set(table_names)
|
| 276 |
+
for row in rows:
|
| 277 |
+
fk = row.fk # format: "referred_table.referred_column"
|
| 278 |
+
if fk:
|
| 279 |
+
referred_table = fk.split(".")[0]
|
| 280 |
+
expanded.add(referred_table)
|
| 281 |
+
|
| 282 |
+
if expanded != set(table_names):
|
| 283 |
+
logger.info(
|
| 284 |
+
"expanded tables via FK",
|
| 285 |
+
original=sorted(table_names),
|
| 286 |
+
expanded=sorted(expanded),
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
return list(expanded)
|
| 290 |
+
|
| 291 |
async def _fetch_full_schema(
|
| 292 |
self,
|
| 293 |
client_id: str,
|