sofhiaazzhr committed · Commit 29efec6 · 2 parent(s): 240251c948d6dd

Merge branch 'dev_new' of https://huggingface.co/spaces/DataEyond/Agentic-Service-Data-Eyond into dev_new

src/query/executors/db.py DELETED
@@ -1,32 +0,0 @@
- """Executor for registered database sources (source_type="database").
-
- Flow:
- 1. Group RetrievalResult chunks by database_client_id.
- 2. For each client: decrypt creds -> connect -> SELECT relevant columns FROM table LIMIT n.
- 3. Return QueryResult per (client_id, table_name).
- """
-
- from src.middlewares.logging import get_logger
- from src.query.base import BaseExecutor, QueryResult
- from src.rag.base import RetrievalResult
-
- logger = get_logger("db_executor")
-
-
- class DbExecutor(BaseExecutor):
-     async def execute(
-         self,
-         results: list[RetrievalResult],
-         user_id: str,
-         limit: int = 100,
-     ) -> list[QueryResult]:
-         # TODO: implement
-         # 1. filter results where source_type == "database"
-         # 2. group by (database_client_id, table_name) -> list of column_names
-         # 3. per group: look up DatabaseClient, decrypt creds, connect via db_pipeline_service
-         # 4. SELECT <columns> FROM <table> LIMIT limit
-         # 5. return QueryResult per group
-         raise NotImplementedError
-
-
- db_executor = DbExecutor()
src/query/executors/db_executor.py ADDED
@@ -0,0 +1,334 @@
+ """Executor for registered database sources (source_type="database").
+
+ Flow per (client_id, question):
+ 1. Collect all relevant (table_name, column_name) pairs from retrieval results.
+ 2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
+ 3. Build a schema context string and send to LLM → structured SQLQuery output.
+ 4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
+ 5. Execute on the user's DB via engine_scope + asyncio.to_thread.
+ 6. Return QueryResult per client_id (may span multiple tables via JOINs).
+
+ Supported db_types: postgres, supabase, mysql.
+ Other types are skipped with a warning — they do not raise.
+ """
+
+ import asyncio
+ from collections import defaultdict
+ from typing import Any
+
+ import sqlglot
+ import sqlglot.expressions as exp
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_openai import AzureChatOpenAI
+ from sqlalchemy import text
+ from sqlalchemy.ext.asyncio import AsyncSession
+
+ from src.config.settings import settings
+ from src.database_client.database_client_service import database_client_service
+ from src.db.postgres.connection import _pgvector_engine
+ from src.middlewares.logging import get_logger
+ from src.models.sql_query import SQLQuery
+ from src.pipeline.db_pipeline import db_pipeline_service
+ from src.query.base import BaseExecutor, QueryResult
+ from src.rag.base import RetrievalResult
+ from src.utils.db_credential_encryption import decrypt_credentials_dict
+
+ logger = get_logger("db_executor")
+
+ _SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
+ _MAX_RETRIES = 3
+ _MAX_LIMIT = 500
+
+ _SQL_SYSTEM_PROMPT = """\
+ You are a SQL data analyst working with a user's database.
+ Generate a single SQL SELECT statement that answers the user's question.
+
+ Rules:
+ - ONLY reference tables and columns listed in the schema below. Do not invent names.
+ - Always include a LIMIT clause (max {limit}).
+ - Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
+ - Prefer explicit JOINs over subqueries when combining tables.
+ - For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
+ - For date filtering, use standard SQL date functions appropriate for the dialect.
+
+ Schema:
+ {schema}
+
+ {error_section}"""
+
+
+ class DbExecutor(BaseExecutor):
+     def __init__(self) -> None:
+         self._llm = AzureChatOpenAI(
+             azure_deployment=settings.azureai_deployment_name_4o,
+             openai_api_version=settings.azureai_api_version_4o,
+             azure_endpoint=settings.azureai_endpoint_url_4o,
+             api_key=settings.azureai_api_key_4o,
+             temperature=0,
+         )
+         self._prompt = ChatPromptTemplate.from_messages([
+             ("system", _SQL_SYSTEM_PROMPT),
+             ("human", "{question}"),
+         ])
+         self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)
+
+     # ------------------------------------------------------------------
+     # Public interface
+     # ------------------------------------------------------------------
+
+     async def execute(
+         self,
+         results: list[RetrievalResult],
+         user_id: str,
+         db: AsyncSession,
+         limit: int = 100,
+     ) -> list[QueryResult]:
+         db_results = [r for r in results if r.source_type == "database"]
+         if not db_results:
+             return []
+
+         # Group by client_id — one SQL generation + execution pass per client
+         by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
+         for r in db_results:
+             client_id = r.metadata.get("database_client_id", "")
+             if client_id:
+                 by_client[client_id].append(r)
+             else:
+                 logger.warning("db result missing database_client_id, skipping")
+
+         query_results: list[QueryResult] = []
+         for client_id, client_results in by_client.items():
+             try:
+                 qr = await self._execute_for_client(client_id, client_results, user_id, db, limit)
+                 if qr:
+                     query_results.append(qr)
+             except Exception as e:
+                 logger.error("db executor failed for client", client_id=client_id, error=str(e))
+
+         return query_results
+
+     # ------------------------------------------------------------------
+     # Per-client execution
+     # ------------------------------------------------------------------
+
+     async def _execute_for_client(
+         self,
+         client_id: str,
+         results: list[RetrievalResult],
+         user_id: str,
+         db: AsyncSession,
+         limit: int,
+     ) -> QueryResult | None:
+         client = await database_client_service.get(db, client_id)
+         if not client:
+             logger.warning("database client not found", client_id=client_id)
+             return None
+         if client.user_id != user_id:
+             logger.warning("client ownership mismatch", client_id=client_id)
+             return None
+         if client.db_type not in _SUPPORTED_DB_TYPES:
+             logger.warning("unsupported db_type for query execution", db_type=client.db_type)
+             return None
+
+         # Distinct table names from retrieval results
+         table_names = list({
+             r.metadata.get("data", {}).get("table_name")
+             for r in results
+             if r.metadata.get("data", {}).get("table_name")
+         })
+
+         full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
+         if not full_schema:
+             logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
+             return None
+
+         schema_ctx = self._build_schema_context(full_schema)
+         question = self._extract_question(results)
+         capped_limit = min(limit, _MAX_LIMIT)
+
+         # SQL generation with retry
+         validated_sql: str | None = None
+         prev_error: str = ""
+         for attempt in range(_MAX_RETRIES):
+             error_section = f"Previous attempt failed: {prev_error}\nFix the issue above." if prev_error else ""
+             try:
+                 result: SQLQuery = await self._chain.ainvoke({
+                     "schema": schema_ctx,
+                     "limit": capped_limit,
+                     "error_section": error_section,
+                     "question": question,
+                 })
+                 sql = result.sql.strip()
+                 validation_error = self._validate(sql, full_schema, capped_limit)
+                 if validation_error:
+                     prev_error = validation_error
+                     logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
+                     continue
+                 validated_sql = sql
+                 logger.info("sql generated", attempt=attempt + 1, reasoning=result.reasoning)
+                 break
+             except Exception as e:
+                 prev_error = str(e)
+                 logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)
+
+         if not validated_sql:
+             logger.error("sql generation failed after retries", client_id=client_id)
+             return None
+
+         # Execute on user's DB
+         creds = decrypt_credentials_dict(client.credentials)
+         with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
+             rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)
+
+         column_types = {
+             col["name"]: col["type"]
+             for cols in full_schema.values()
+             for col in cols
+         }
+         columns = list(rows[0].keys()) if rows else []
+
+         return QueryResult(
+             source_type="database",
+             source_id=client_id,
+             table_or_file=", ".join(table_names),
+             columns=columns,
+             rows=rows,
+             row_count=len(rows),
+             metadata={
+                 "db_type": client.db_type,
+                 "client_name": client.name,
+                 "sql": validated_sql,
+                 "column_types": {c: column_types.get(c, "unknown") for c in columns},
+             },
+         )
+
+     # ------------------------------------------------------------------
+     # Schema helpers
+     # ------------------------------------------------------------------
+
+     async def _fetch_full_schema(
+         self,
+         client_id: str,
+         table_names: list[str],
+         user_id: str,
+     ) -> dict[str, list[dict[str, Any]]]:
+         """Fetch ALL column chunks for the given tables from PGVector.
+
+         Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
+                                "foreign_key": ..., "content": ...}]}
+         """
+         placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
+         sql = text(f"""
+             SELECT lpe.cmetadata, lpe.document
+             FROM langchain_pg_embedding lpe
+             JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
+             WHERE lpc.name = 'document_embeddings'
+               AND lpe.cmetadata->>'user_id' = :user_id
+               AND lpe.cmetadata->>'source_type' = 'database'
+               AND lpe.cmetadata->>'database_client_id' = :client_id
+               AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
+             ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
+         """)
+
+         params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
+         for i, name in enumerate(table_names):
+             params[f"t{i}"] = name
+
+         async with _pgvector_engine.connect() as conn:
+             result = await conn.execute(sql, params)
+             rows = result.fetchall()
+
+         schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
+         for row in rows:
+             data = row.cmetadata.get("data", {})
+             table = data.get("table_name")
+             if table:
+                 schema[table].append({
+                     "name": data.get("column_name", ""),
+                     "type": data.get("column_type", ""),
+                     "is_primary_key": data.get("is_primary_key", False),
+                     "foreign_key": data.get("foreign_key"),
+                     "content": row.document,  # chunk text includes top values / samples
+                 })
+         return dict(schema)
+
+     def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
+         lines: list[str] = []
+         for table, columns in schema.items():
+             lines.append(f"Table: {table}")
+             for col in columns:
+                 flags = []
+                 if col["is_primary_key"]:
+                     flags.append("PRIMARY KEY")
+                 if col["foreign_key"]:
+                     flags.append(f"FK -> {col['foreign_key']}")
+                 flag_str = f" [{', '.join(flags)}]" if flags else ""
+                 lines.append(f" - {col['name']} {col['type']}{flag_str}")
+                 # Include sample/top-values line from chunk content if present
+                 for line in col["content"].splitlines():
+                     if line.startswith(("Top values:", "Sample values:")):
+                         lines.append(f" {line}")
+                         break
+             lines.append("")
+         return "\n".join(lines).strip()
+
+     def _extract_question(self, results: list[RetrievalResult]) -> str:
+         # The search_query rewritten by the orchestrator is not in RetrievalResult —
+         # the content field carries schema descriptions. Return a generic fallback;
+         # callers that have the original question should pass it explicitly.
+         # TODO: thread the original user question through to execute() when wiring into the agent.
+         return "Answer the user's data question using the schema provided."
+
+     # ------------------------------------------------------------------
+     # Guardrails
+     # ------------------------------------------------------------------
+
+     def _validate(self, sql: str, schema: dict[str, list[dict]], limit: int) -> str:
+         """Return an error string if validation fails, empty string if OK."""
+         # Layer 1: sqlglot parse + SELECT-only check
+         try:
+             parsed = sqlglot.parse_one(sql)
+         except sqlglot.errors.ParseError as e:
+             return f"SQL parse error: {e}"
+
+         if not isinstance(parsed, exp.Select):
+             return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
+
+         # Check for DML inside CTEs
+         for cte in parsed.find_all(exp.With):
+             for node in cte.find_all((exp.Insert, exp.Update, exp.Delete)):
+                 return f"DML ({type(node).__name__}) inside CTE is not allowed."
+
+         # Layer 2: schema grounding — table names
+         known_tables = {t.lower() for t in schema}
+         for tbl in parsed.find_all(exp.Table):
+             name = tbl.name.lower()
+             if name and name not in known_tables:
+                 return f"Unknown table '{tbl.name}'. Only use tables from the schema."
+
+         # Layer 3: LIMIT enforcement (inject if missing — done before execution)
+         return ""
+
+     # ------------------------------------------------------------------
+     # SQL execution
+     # ------------------------------------------------------------------
+
+     def _enforce_limit(self, sql: str, limit: int) -> str:
+         """Inject or cap LIMIT using sqlglot AST manipulation."""
+         parsed = sqlglot.parse_one(sql)
+         existing = parsed.find(exp.Limit)
+         if existing:
+             current = int(existing.expression.this)
+             if current > limit:
+                 existing.expression.set("this", str(limit))
+         else:
+             parsed = parsed.limit(limit)
+         return parsed.sql()
+
+     def _run_sql(self, engine: Any, sql: str) -> list[dict]:
+         with engine.connect() as conn:
+             result = conn.execute(text(sql))
+             return [dict(row) for row in result.mappings()]
+
+
+ db_executor = DbExecutor()
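
For context, here is a minimal sketch (not part of this commit) of how the new executor might be called once retrieval results and an AsyncSession are available. Only db_executor.execute() and the QueryResult fields shown come from this change; the surrounding function and its call site are illustrative assumptions.

# Illustrative only — caller-side wiring assumed, not taken from this commit.
from sqlalchemy.ext.asyncio import AsyncSession

from src.query.base import QueryResult
from src.query.executors.db_executor import db_executor
from src.rag.base import RetrievalResult


async def run_db_queries(
    results: list[RetrievalResult],  # retrieval hits; database chunks carry database_client_id metadata
    user_id: str,
    db: AsyncSession,                # app session used to look up DatabaseClient rows
) -> list[QueryResult]:
    query_results = await db_executor.execute(results=results, user_id=user_id, db=db, limit=100)
    for qr in query_results:
        # metadata carries the validated SQL plus db_type / client_name / column_types
        print(qr.metadata["sql"], qr.row_count, qr.columns)
    return query_results
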
src/query/{executor.py → query_executor.py} RENAMED
File without changes
uv.lock CHANGED
@@ -66,6 +66,7 @@ dependencies = [
  { name = "spacy" },
  { name = "sqlalchemy", extra = ["asyncio"] },
  { name = "sqlalchemy-bigquery" },
+ { name = "sqlglot" },
  { name = "sse-starlette" },
  { name = "starlette" },
  { name = "structlog" },
@@ -149,6 +150,7 @@ requires-dist = [
  { name = "spacy", specifier = "==3.8.3" },
  { name = "sqlalchemy", extras = ["asyncio"], specifier = "==2.0.36" },
  { name = "sqlalchemy-bigquery", specifier = ">=1.11.0" },
+ { name = "sqlglot", specifier = ">=25.0.0" },
  { name = "sse-starlette", specifier = "==2.1.3" },
  { name = "starlette", specifier = "==0.41.3" },
  { name = "structlog", specifier = "==24.4.0" },
@@ -3221,6 +3223,15 @@ wheels = [
  { url = "https://files.pythonhosted.org/packages/c0/87/11e6de00ef7949bb8ea06b55304a1a4911c329fdf0d9882b464db240c2c5/sqlalchemy_bigquery-1.16.0-py3-none-any.whl", hash = "sha256:0fe7634cd954f3e74f5e2db6d159f9e5ee87a47fbe8d52eac3cd3bb3dadb3a77", size = 40615, upload-time = "2025-11-06T01:35:39.358Z" },
  ]

+ [[package]]
+ name = "sqlglot"
+ version = "30.6.0"
+ source = { registry = "https://pypi.org/simple" }
+ sdist = { url = "https://files.pythonhosted.org/packages/3c/66/6ece15f197874e56c76e1d0269cebf284ba992a80dfadca9d1972fdf7edf/sqlglot-30.6.0.tar.gz", hash = "sha256:246d34d39927422a50a3fa155f37b2f6346fba85f1a755b13c941eb32ef93361", size = 5835307, upload-time = "2026-04-20T20:11:08.164Z" }
+ wheels = [
+     { url = "https://files.pythonhosted.org/packages/dc/e7/64fe971cbca33a0446b06f4a5ff8e3fa4a1dbd0a039ceabcc3e6cf4087a9/sqlglot-30.6.0-py3-none-any.whl", hash = "sha256:e005fc2f47994f90d7d8df341f1cbe937518497b0b7b1507d4c03c4c9dfd2778", size = 673920, upload-time = "2026-04-20T20:11:05.758Z" },
+ ]
+
  [[package]]
  name = "srsly"
  version = "2.5.3"
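
The sqlglot entry added to the lockfile backs the executor's guardrail layer. Below is a standalone sketch of that pattern (reject non-SELECT statements, inject or cap LIMIT); it mirrors the intent of DbExecutor._validate and _enforce_limit but is an illustrative snippet, not the project's code verbatim.

# Standalone illustration of the sqlglot guardrail pattern; only sqlglot is required.
import sqlglot
import sqlglot.expressions as exp

MAX_LIMIT = 500


def guard_select(sql: str, limit: int = MAX_LIMIT) -> str:
    """Return a LIMIT-capped SELECT, or raise ValueError for anything else."""
    parsed = sqlglot.parse_one(sql)
    if not isinstance(parsed, exp.Select):
        raise ValueError(f"Only SELECT statements are allowed, got {type(parsed).__name__}")
    existing = parsed.find(exp.Limit)
    if existing is None:
        parsed = parsed.limit(limit)                 # inject a LIMIT when the query has none
    elif int(existing.expression.this) > limit:
        existing.expression.set("this", str(limit))  # cap an oversized LIMIT
    return parsed.sql()


print(guard_select("SELECT id, total FROM orders WHERE status = 'paid'"))
# prints the same query with "LIMIT 500" appended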