Merge branch 'dev_new' of https://huggingface.co/spaces/DataEyond/Agentic-Service-Data-Eyond into dev_new
- src/query/executors/db.py +0 -32
- src/query/executors/db_executor.py +334 -0
- src/query/{executor.py → query_executor.py} +0 -0
- uv.lock +11 -0
src/query/executors/db.py
DELETED
@@ -1,32 +0,0 @@
```python
"""Executor for registered database sources (source_type="database").

Flow:
1. Group RetrievalResult chunks by database_client_id.
2. For each client: decrypt creds -> connect -> SELECT relevant columns FROM table LIMIT n.
3. Return QueryResult per (client_id, table_name).
"""

from src.middlewares.logging import get_logger
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult

logger = get_logger("db_executor")


class DbExecutor(BaseExecutor):
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        # TODO: implement
        # 1. filter results where source_type == "database"
        # 2. group by (database_client_id, table_name) -> list of column_names
        # 3. per group: look up DatabaseClient, decrypt creds, connect via db_pipeline_service
        # 4. SELECT <columns> FROM <table> LIMIT limit
        # 5. return QueryResult per group
        raise NotImplementedError


db_executor = DbExecutor()
```
src/query/executors/db_executor.py
ADDED
@@ -0,0 +1,334 @@
```python
"""Executor for registered database sources (source_type="database").

Flow per (client_id, question):
1. Collect all relevant (table_name, column_name) pairs from retrieval results.
2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
3. Build a schema context string and send to LLM → structured SQLQuery output.
4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
5. Execute on the user's DB via engine_scope + asyncio.to_thread.
6. Return QueryResult per client_id (may span multiple tables via JOINs).

Supported db_types: postgres, supabase, mysql.
Other types are skipped with a warning — they do not raise.
"""

import asyncio
from collections import defaultdict
from typing import Any

import sqlglot
import sqlglot.expressions as exp
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from src.config.settings import settings
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.models.sql_query import SQLQuery
from src.pipeline.db_pipeline import db_pipeline_service
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult
from src.utils.db_credential_encryption import decrypt_credentials_dict

logger = get_logger("db_executor")

_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
_MAX_RETRIES = 3
_MAX_LIMIT = 500

_SQL_SYSTEM_PROMPT = """\
You are a SQL data analyst working with a user's database.
Generate a single SQL SELECT statement that answers the user's question.

Rules:
- ONLY reference tables and columns listed in the schema below. Do not invent names.
- Always include a LIMIT clause (max {limit}).
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
- Prefer explicit JOINs over subqueries when combining tables.
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
- For date filtering, use standard SQL date functions appropriate for the dialect.

Schema:
{schema}

{error_section}"""


class DbExecutor(BaseExecutor):
    def __init__(self) -> None:
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        limit: int = 100,
    ) -> list[QueryResult]:
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []

        # Group by client_id — one SQL generation + execution pass per client
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")

        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(client_id, client_results, user_id, db, limit)
                if qr:
                    query_results.append(qr)
            except Exception as e:
                logger.error("db executor failed for client", client_id=client_id, error=str(e))

        return query_results

    # ------------------------------------------------------------------
    # Per-client execution
    # ------------------------------------------------------------------

    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        limit: int,
    ) -> QueryResult | None:
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None

        # Distinct table names from retrieval results
        table_names = list({
            r.metadata.get("data", {}).get("table_name")
            for r in results
            if r.metadata.get("data", {}).get("table_name")
        })

        full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
            return None

        schema_ctx = self._build_schema_context(full_schema)
        question = self._extract_question(results)
        capped_limit = min(limit, _MAX_LIMIT)

        # SQL generation with retry
        validated_sql: str | None = None
        prev_error: str = ""
        for attempt in range(_MAX_RETRIES):
            error_section = f"Previous attempt failed: {prev_error}\nFix the issue above." if prev_error else ""
            try:
                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "limit": capped_limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = result.sql.strip()
                validation_error = self._validate(sql, full_schema, capped_limit)
                if validation_error:
                    prev_error = validation_error
                    logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
                    continue
                validated_sql = sql
                logger.info("sql generated", attempt=attempt + 1, reasoning=result.reasoning)
                break
            except Exception as e:
                prev_error = str(e)
                logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)

        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None

        # Execute on user's DB; inject or cap LIMIT first (see _enforce_limit)
        final_sql = self._enforce_limit(validated_sql, capped_limit)
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, final_sql)

        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []

        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(table_names),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": final_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )

    # ------------------------------------------------------------------
    # Schema helpers
    # ------------------------------------------------------------------

    async def _fetch_full_schema(
        self,
        client_id: str,
        table_names: list[str],
        user_id: str,
    ) -> dict[str, list[dict[str, Any]]]:
        """Fetch ALL column chunks for the given tables from PGVector.

        Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
                               "foreign_key": ..., "content": ...}]}
        """
        if not table_names:
            return {}  # avoid building an invalid empty IN () clause

        placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
        sql = text(f"""
            SELECT lpe.cmetadata, lpe.document
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
            ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
        """)

        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
        for i, name in enumerate(table_names):
            params[f"t{i}"] = name

        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()

        schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for row in rows:
            data = row.cmetadata.get("data", {})
            table = data.get("table_name")
            if table:
                schema[table].append({
                    "name": data.get("column_name", ""),
                    "type": data.get("column_type", ""),
                    "is_primary_key": data.get("is_primary_key", False),
                    "foreign_key": data.get("foreign_key"),
                    "content": row.document,  # chunk text includes top values / samples
                })
        return dict(schema)

    def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
        lines: list[str] = []
        for table, columns in schema.items():
            lines.append(f"Table: {table}")
            for col in columns:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PRIMARY KEY")
                if col["foreign_key"]:
                    flags.append(f"FK -> {col['foreign_key']}")
                flag_str = f" [{', '.join(flags)}]" if flags else ""
                lines.append(f"  - {col['name']} {col['type']}{flag_str}")
                # Include sample/top-values line from chunk content if present
                for line in col["content"].splitlines():
                    if line.startswith(("Top values:", "Sample values:")):
                        lines.append(f"    {line}")
                        break
            lines.append("")
        return "\n".join(lines).strip()

    def _extract_question(self, results: list[RetrievalResult]) -> str:
        # The search_query rewritten by the orchestrator is not in RetrievalResult —
        # the content field carries schema descriptions. Return a generic fallback;
        # callers that have the original question should pass it explicitly.
        # TODO: thread the original user question through to execute() when wiring into the agent.
        return "Answer the user's data question using the schema provided."

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------

    def _validate(self, sql: str, schema: dict[str, list[dict]], limit: int) -> str:
        """Return an error string if validation fails, empty string if OK."""
        # Layer 1: sqlglot parse + SELECT-only check
        try:
            parsed = sqlglot.parse_one(sql)
        except sqlglot.errors.ParseError as e:
            return f"SQL parse error: {e}"

        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

        # Check for DML inside CTEs
        for cte in parsed.find_all(exp.With):
            for node in cte.find_all(exp.Insert, exp.Update, exp.Delete):
                return f"DML ({type(node).__name__}) inside CTE is not allowed."

        # Layer 2: schema grounding — table names
        known_tables = {t.lower() for t in schema}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."

        # Layer 3: LIMIT enforcement is handled by _enforce_limit(), called right before execution
        return ""

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------

    def _enforce_limit(self, sql: str, limit: int) -> str:
        """Inject or cap LIMIT using sqlglot AST manipulation."""
        parsed = sqlglot.parse_one(sql)
        existing = parsed.find(exp.Limit)
        if existing:
            current = int(existing.expression.this)
            if current > limit:
                existing.expression.set("this", str(limit))
        else:
            parsed = parsed.limit(limit)
        return parsed.sql()

    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]


db_executor = DbExecutor()
```
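The `SQLQuery` model imported from `src.models.sql_query` is not part of this diff; judging from `with_structured_output(SQLQuery)` and the `result.sql` / `result.reasoning` accesses above, it is presumably a small Pydantic schema along these lines (a sketch inferred from usage, not the committed definition):

```python
# Hypothetical sketch of src/models/sql_query.py, inferred from usage in
# db_executor.py above; the committed model may carry extra fields.
from pydantic import BaseModel, Field


class SQLQuery(BaseModel):
    """Structured output produced by the SQL-generation chain."""

    sql: str = Field(description="A single SQL SELECT statement answering the question.")
    reasoning: str = Field(description="Short explanation of how the query answers it.")
```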
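The sqlglot guardrails are easy to exercise in isolation. The snippet below is a self-contained check that mirrors the same sqlglot calls `_validate` and `_enforce_limit` make (no app imports), useful for confirming the SELECT-only, schema-grounding, and LIMIT rules behave as described:

```python
# Standalone exercise of the sqlglot guardrails used by DbExecutor.
import sqlglot
import sqlglot.expressions as exp

known_tables = {"orders", "customers"}

for sql in (
    "SELECT id FROM orders",            # valid SELECT, missing LIMIT
    "SELECT * FROM secrets LIMIT 10",   # references an unknown table
    "DELETE FROM orders",               # not a SELECT at all
):
    parsed = sqlglot.parse_one(sql)
    if not isinstance(parsed, exp.Select):
        print(f"rejected (not SELECT): {sql}")
        continue
    unknown = {t.name.lower() for t in parsed.find_all(exp.Table)} - known_tables
    if unknown:
        print(f"rejected (unknown tables {unknown}): {sql}")
        continue
    if parsed.find(exp.Limit) is None:
        parsed = parsed.limit(100)      # inject LIMIT, as _enforce_limit does
    print(f"accepted: {parsed.sql()}")
```

Expected behavior: the first statement is accepted with `LIMIT 100` injected; the other two are rejected.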
src/query/{executor.py → query_executor.py}
RENAMED
File without changes
uv.lock
CHANGED
```diff
@@ -66,6 +66,7 @@ dependencies = [
     { name = "spacy" },
     { name = "sqlalchemy", extra = ["asyncio"] },
     { name = "sqlalchemy-bigquery" },
+    { name = "sqlglot" },
     { name = "sse-starlette" },
     { name = "starlette" },
     { name = "structlog" },
@@ -149,6 +150,7 @@ requires-dist = [
     { name = "spacy", specifier = "==3.8.3" },
     { name = "sqlalchemy", extras = ["asyncio"], specifier = "==2.0.36" },
     { name = "sqlalchemy-bigquery", specifier = ">=1.11.0" },
+    { name = "sqlglot", specifier = ">=25.0.0" },
     { name = "sse-starlette", specifier = "==2.1.3" },
     { name = "starlette", specifier = "==0.41.3" },
     { name = "structlog", specifier = "==24.4.0" },
@@ -3221,6 +3223,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c0/87/11e6de00ef7949bb8ea06b55304a1a4911c329fdf0d9882b464db240c2c5/sqlalchemy_bigquery-1.16.0-py3-none-any.whl", hash = "sha256:0fe7634cd954f3e74f5e2db6d159f9e5ee87a47fbe8d52eac3cd3bb3dadb3a77", size = 40615, upload-time = "2025-11-06T01:35:39.358Z" },
 ]
 
+[[package]]
+name = "sqlglot"
+version = "30.6.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/3c/66/6ece15f197874e56c76e1d0269cebf284ba992a80dfadca9d1972fdf7edf/sqlglot-30.6.0.tar.gz", hash = "sha256:246d34d39927422a50a3fa155f37b2f6346fba85f1a755b13c941eb32ef93361", size = 5835307, upload-time = "2026-04-20T20:11:08.164Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/dc/e7/64fe971cbca33a0446b06f4a5ff8e3fa4a1dbd0a039ceabcc3e6cf4087a9/sqlglot-30.6.0-py3-none-any.whl", hash = "sha256:e005fc2f47994f90d7d8df341f1cbe937518497b0b7b1507d4c03c4c9dfd2778", size = 673920, upload-time = "2026-04-20T20:11:05.758Z" },
+]
+
 [[package]]
 name = "srsly"
 version = "2.5.3"
```