Rifqi Hafizuddin
fix: minor returned type if sql writes limit yang melebihi batas
b4df8b1
raw
history blame
27.7 kB
"""Executor for registered database sources (source_type="database").
Flow per (client_id, question):
1. Collect all relevant (table_name, column_name) pairs from retrieval results.
2. Fetch the FULL schema for those tables from PGVector (not just the top-k columns),
   plus an abbreviated schema for tables one foreign-key hop away.
3. Build a schema context string and send it to the LLM → structured SQLQuery output.
4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
5. Execute on the user's DB via engine_scope + asyncio.to_thread.
6. Return QueryResult per client_id (may span multiple tables via JOINs).
Supported db_types: postgres, supabase, mysql.
Other types are skipped with a warning — they do not raise.
"""
import asyncio
from collections import defaultdict
from typing import Any
import sqlglot
import sqlglot.expressions as exp
import tiktoken
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import settings
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.models.sql_query import SQLQuery
from src.pipeline.db_pipeline import db_pipeline_service
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult
from src.utils.db_credential_encryption import decrypt_credentials_dict
# ---- Tunables ---------------------------------------------------------
# db_type values this executor will generate and run SQL for.
_SUPPORTED_DB_TYPES = {"mysql", "postgres", "supabase"}
# Number of SQL generation attempts per client before giving up.
_MAX_RETRIES = 3
# Hard ceiling on the row LIMIT applied to any generated SQL.
_MAX_LIMIT = 500
# Maximum number of FK-related tables added to the schema context.
_FK_EXPANSION_MAX_TABLES = 5

# ---- Shared helpers ---------------------------------------------------
# Structured logger for this module.
logger = get_logger("db_executor")
# Tokenizer used only to estimate token counts for logging.
_enc = tiktoken.get_encoding("cl100k_base")
# System prompt for SQL generation. ChatPromptTemplate fills the placeholders:
#   {dialect}       - the client's db_type (one of _SUPPORTED_DB_TYPES)
#   {limit}         - the requested row limit, already capped at _MAX_LIMIT
#   {schema}        - schema context text built by DbExecutor._build_schema_context
#   {error_section} - feedback from the previous failed attempt, or "" on the first try
# The SELECT-only and LIMIT rules here are advisory for the LLM; they are
# enforced after generation by DbExecutor._validate / DbExecutor._enforce_limit.
_SQL_SYSTEM_PROMPT = """\
You are a SQL data analyst working with a user's database.
Generate a single SQL SELECT statement that answers the user's question.
Database dialect: {dialect}
Rules:
- ONLY reference tables and columns listed in the schema below. Do not invent names.
- Always include a LIMIT clause (max {limit}).
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
- Prefer explicit JOINs over subqueries when combining tables.
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
- For date filtering, use dialect-appropriate functions ({dialect} syntax).
Schema:
{schema}
{error_section}"""
class DbExecutor(BaseExecutor):
    """Answer questions against registered user databases via LLM-generated SQL.

    For each database client referenced by retrieval results, builds a schema
    context from PGVector, asks the LLM for a single SELECT (structured
    ``SQLQuery`` output), validates it with sqlglot, caps its LIMIT, and runs
    it on the client's database.
    """

    # Registered db_type -> sqlglot dialect used to parse and regenerate SQL.
    # Parsing/regenerating in sqlglot's default dialect can break MySQL SQL
    # (backtick identifiers, dialect-specific functions). Supabase is
    # PostgreSQL; sqlglot has no "supabase" dialect of its own.
    _SQLGLOT_DIALECTS: dict[str, str] = {
        "postgres": "postgres",
        "supabase": "postgres",
        "mysql": "mysql",
    }

    def __init__(self) -> None:
        # temperature=0 keeps SQL generation as deterministic as the model allows.
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        # prompt -> LLM constrained to the SQLQuery schema (sql + reasoning).
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run one SQL generation + execution pass per database client.

        Args:
            results: Retrieval hits. Only those with ``source_type == "database"``
                are used, grouped by ``metadata["database_client_id"]``.
            user_id: Requesting user; clients owned by someone else are skipped.
            db: Session used to load the registered database client records.
            question: Natural-language question forwarded to the LLM.
            limit: Requested row cap; clamped to ``_MAX_LIMIT`` per client.

        Returns:
            One ``QueryResult`` per client that produced a result. A failure for
            one client is logged and skipped so the remaining clients still run.
        """
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []

        # Group by client_id — one SQL generation + execution pass per client
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")

        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(
                    client_id, client_results, user_id, db, question, limit
                )
            except Exception as e:
                # Boundary catch: one failing client must not sink the others.
                logger.error("db executor failed for client", client_id=client_id, error=str(e))
                continue
            if qr:
                query_results.append(qr)
        return query_results

    # ------------------------------------------------------------------
    # Per-client execution
    # ------------------------------------------------------------------

    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int,
    ) -> QueryResult | None:
        """Generate, validate and execute SQL for a single database client.

        Returns None (after logging) when the client is missing, owned by a
        different user, of an unsupported db_type, has no schema in the vector
        store, or when no valid SQL is produced within ``_MAX_RETRIES`` attempts.
        Errors raised while running the final SQL propagate to the caller.
        """
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None

        # Hit tables = tables retrieval pointed at directly. Get full per-column
        # schema for these. Related tables (one FK hop away, both directions) are
        # fetched separately in abbreviated form to give the LLM enough context
        # to JOIN without paying the per-column profile token cost.
        # dict.fromkeys de-duplicates while keeping retrieval order, so the
        # table order (and `table_or_file`) is deterministic between runs.
        hit_tables = list(dict.fromkeys(
            table
            for r in results
            if (table := (r.metadata.get("data") or {}).get("table_name"))
        ))
        if not hit_tables:
            logger.warning("no table_name on any retrieval result", client_id=client_id)
            return None

        full_schema = await self._fetch_full_schema(client_id, hit_tables, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=hit_tables)
            return None

        related_tables = await self._find_related_tables(client_id, user_id, hit_tables)
        related_schema = (
            await self._fetch_abbreviated_schema(client_id, user_id, related_tables)
            if related_tables else {}
        )
        schema_ctx = self._build_schema_context(full_schema, related_schema)
        capped_limit = min(limit, _MAX_LIMIT)
        dialect = client.db_type
        sql_dialect = self._SQLGLOT_DIALECTS.get(client.db_type)

        # Grounding data for validation is identical on every attempt — build it once.
        allowed_tables = set(full_schema) | set(related_schema)
        column_map: dict[str, set[str]] = {
            t: {c["name"] for c in cols} for t, cols in full_schema.items()
        }
        for t, info in related_schema.items():
            names = info.get("column_names")
            # Only ground tables whose columns are known: an empty set would
            # reject every qualified column on that table with a useless
            # "Available columns: ." message.
            if names:
                column_map[t] = set(names)

        # SQL generation with retry; validation failures are fed back to the LLM.
        validated_sql: str | None = None
        prev_error = ""
        prev_reasoning = ""
        for attempt in range(_MAX_RETRIES):
            if prev_error:
                error_section = (
                    f"Previous attempt reasoning: {prev_reasoning}\n"
                    f"Previous attempt failed: {prev_error}\n"
                    "Fix the issue above."
                )
            else:
                error_section = ""

            try:
                # Approximate count: only the variable parts of the prompt.
                prompt_text = schema_ctx + error_section + question
                input_tokens = len(_enc.encode(prompt_text))
                logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "dialect": dialect,
                    "limit": capped_limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = (result.sql or "").strip()
                reasoning = result.reasoning or ""

                validation_error = self._validate(
                    sql, allowed_tables, capped_limit, column_map, dialect=sql_dialect
                )
                if validation_error:
                    prev_error = validation_error
                    prev_reasoning = reasoning
                    logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
                    continue

                validated_sql = self._enforce_limit(sql, capped_limit, dialect=sql_dialect)
                output_tokens = len(_enc.encode(sql)) + len(_enc.encode(reasoning))
                logger.info(
                    "sql generated",
                    attempt=attempt + 1,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    reasoning=reasoning,
                )
                break
            except Exception as e:
                prev_error = str(e)
                # Do not replay reasoning from an earlier, unrelated attempt.
                prev_reasoning = ""
                logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)

        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None

        # Execute on user's DB (blocking driver call runs in a worker thread).
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)

        # Declared types of hit-table columns. Result columns that are aliases,
        # expressions or come from related tables are reported as "unknown".
        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []
        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(hit_tables),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": validated_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )

    # ------------------------------------------------------------------
    # Schema helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _bind_in_list(prefix: str, names: list[str]) -> tuple[str, dict[str, Any]]:
        """Return ``(placeholders, params)`` for a parameterized SQL ``IN (...)`` list.

        Only generated placeholder names (``:t0, :t1, ...``) are interpolated
        into the SQL text; the table names themselves are always bound values.
        """
        params = {f"{prefix}{i}": name for i, name in enumerate(names)}
        placeholders = ", ".join(f":{key}" for key in params)
        return placeholders, params

    async def _find_related_tables(
        self,
        client_id: str,
        user_id: str,
        hit_tables: list[str],
    ) -> list[str]:
        """One-hop FK neighbours of `hit_tables`, both directions, excluding hits.

        Prefers chunk_level='table' rows; if none exist for the client (legacy
        ingest predating Phase 1), falls back to aggregating from column-chunk
        metadata. Returns [] when no FK metadata is available.
        Capped at _FK_EXPANSION_MAX_TABLES, ranked by edge count desc then
        table name asc. A warning is logged when the cap kicks in.
        """
        if not hit_tables:
            return []
        hit_set = set(hit_tables)
        # edge_counts[related_table] = number of FK edges connecting it to the hit set
        edge_counts: dict[str, int] = defaultdict(int)

        # ---- Primary path: table-level chunks ----
        sql = text("""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'table'
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
            table_rows = result.fetchall()

        if table_rows:
            for row in table_rows:
                data = row.cmetadata.get("data") or {}
                table = data.get("table_name")
                fks = data.get("foreign_keys") or []
                if not table:
                    continue
                if table in hit_set:
                    # Outgoing: this hit's FKs point at related tables
                    for fk in fks:
                        target = fk.get("target_table")
                        if target and target not in hit_set:
                            edge_counts[target] += 1
                else:
                    # Incoming: this non-hit table's FKs point into the hit set
                    for fk in fks:
                        if fk.get("target_table") in hit_set:
                            edge_counts[table] += 1
        else:
            # ---- Fallback: aggregate from column chunks ----
            # Column chunks store FKs as a "target_table.target_column" string.
            sql = text("""
                SELECT lpe.cmetadata->'data'->>'table_name' AS src_table,
                       lpe.cmetadata->'data'->>'foreign_key' AS fk
                FROM langchain_pg_embedding lpe
                JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
                WHERE lpc.name = 'document_embeddings'
                  AND lpe.cmetadata->>'user_id' = :user_id
                  AND lpe.cmetadata->>'source_type' = 'database'
                  AND lpe.cmetadata->>'database_client_id' = :client_id
                  AND lpe.cmetadata->>'chunk_level' = 'column'
                  AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
            """)
            async with _pgvector_engine.connect() as conn:
                result = await conn.execute(sql, {"user_id": user_id, "client_id": client_id})
                col_rows = result.fetchall()
            for row in col_rows:
                src = row.src_table
                fk = row.fk
                if not src or not fk:
                    continue
                target = fk.split(".", 1)[0]
                if src in hit_set and target and target not in hit_set:
                    edge_counts[target] += 1
                elif src not in hit_set and target in hit_set:
                    edge_counts[src] += 1

        if not edge_counts:
            return []

        ranked = sorted(edge_counts.items(), key=lambda kv: (-kv[1], kv[0]))
        if len(ranked) > _FK_EXPANSION_MAX_TABLES:
            logger.warning(
                "fk expansion cap hit",
                client_id=client_id,
                total=len(ranked),
                cap=_FK_EXPANSION_MAX_TABLES,
                dropped=[t for t, _ in ranked[_FK_EXPANSION_MAX_TABLES:]],
            )
            ranked = ranked[:_FK_EXPANSION_MAX_TABLES]
        related = [t for t, _ in ranked]
        logger.info("fk-related tables", hit=sorted(hit_set), related=related)
        return related

    async def _fetch_abbreviated_schema(
        self,
        client_id: str,
        user_id: str,
        table_names: list[str],
    ) -> dict[str, dict[str, Any]]:
        """Abbreviated schema: name, row_count, PK, FKs, column names — no profiles.

        Prefers chunk_level='table' rows. Falls back to aggregating column-chunk
        metadata when table chunks are missing for a given table_name.
        Returns {table_name: {"row_count": int|None, "primary_key": [str],
        "foreign_keys": [{column, target_table, target_column}],
        "column_names": [str]}}. Tables found in neither source are omitted.
        """
        if not table_names:
            return {}

        # Primary path: one row per table from chunk_level='table'
        placeholders, in_params = self._bind_in_list("t", table_names)
        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id, **in_params}
        sql_table = text(f"""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'table'
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql_table, params)
            t_rows = result.fetchall()

        out: dict[str, dict[str, Any]] = {}
        for row in t_rows:
            data = row.cmetadata.get("data") or {}
            tname = data.get("table_name")
            if not tname:
                continue
            out[tname] = {
                "row_count": data.get("row_count"),
                "primary_key": list(data.get("primary_key") or []),
                "foreign_keys": list(data.get("foreign_keys") or []),
                "column_names": list(data.get("column_names") or []),
            }

        # Fallback for tables with no table-chunk: aggregate column chunks
        missing = [t for t in table_names if t not in out]
        if not missing:
            return out

        placeholders_m, in_params_m = self._bind_in_list("m", missing)
        params_m: dict[str, Any] = {"user_id": user_id, "client_id": client_id, **in_params_m}
        sql_col = text(f"""
            SELECT lpe.cmetadata
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders_m})
            ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
        """)
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql_col, params_m)
            c_rows = result.fetchall()

        agg: dict[str, dict[str, Any]] = {
            t: {"row_count": None, "primary_key": [], "foreign_keys": [], "column_names": []}
            for t in missing
        }
        for row in c_rows:
            data = row.cmetadata.get("data") or {}
            tname = data.get("table_name")
            cname = data.get("column_name")
            if not tname or tname not in agg or not cname:
                continue
            bucket = agg[tname]
            bucket["column_names"].append(cname)
            if data.get("is_primary_key"):
                bucket["primary_key"].append(cname)
            fk = data.get("foreign_key")
            if fk:
                # Column chunks store FKs as "target_table.target_column".
                target_table, _, target_col = fk.partition(".")
                bucket["foreign_keys"].append({
                    "column": cname,
                    "target_table": target_table,
                    "target_column": target_col,
                })
        for t, v in agg.items():
            if v["column_names"]:
                out[t] = v
        return out

    async def _fetch_full_schema(
        self,
        client_id: str,
        table_names: list[str],
        user_id: str,
    ) -> dict[str, list[dict[str, Any]]]:
        """Fetch ALL column chunks for the given tables from PGVector.

        Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
        "foreign_key": ..., "content": ...}]}, ordered by table then column
        name. Returns {} for an empty `table_names` (an empty ``IN ()`` list
        is invalid SQL).
        """
        if not table_names:
            return {}

        placeholders, in_params = self._bind_in_list("t", table_names)
        sql = text(f"""
            SELECT lpe.cmetadata, lpe.document
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'chunk_level' = 'column'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
            ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
        """)
        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id, **in_params}
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()

        schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for row in rows:
            data = row.cmetadata.get("data") or {}
            table = data.get("table_name")
            if table:
                schema[table].append({
                    "name": data.get("column_name", ""),
                    "type": data.get("column_type", ""),
                    "is_primary_key": data.get("is_primary_key", False),
                    "foreign_key": data.get("foreign_key"),
                    "content": row.document or "",  # chunk text includes top values / samples
                })
        return dict(schema)

    def _build_schema_context(
        self,
        schema: dict[str, list[dict[str, Any]]],
        related_schema: dict[str, dict[str, Any]] | None = None,
    ) -> str:
        """Render the schema text given to the LLM.

        Each hit table lists its columns with type, PK/FK flags and the first
        "Top values:" / "Sample values:" line from the column chunk. It is
        followed by the abbreviated related-tables block, if any.
        """
        lines: list[str] = []
        for table, columns in schema.items():
            lines.append(f"Table: {table}")
            for col in columns:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PRIMARY KEY")
                if col["foreign_key"]:
                    flags.append(f"FK -> {col['foreign_key']}")
                flag_str = f" [{', '.join(flags)}]" if flags else ""
                lines.append(f"  - {col['name']} {col['type']}{flag_str}")
                # Include sample/top-values line from chunk content if present
                for line in (col["content"] or "").splitlines():
                    if line.startswith(("Top values:", "Sample values:")):
                        lines.append(f"    {line}")
                        break
            lines.append("")
        related_block = self._build_related_schema_block(related_schema or {})
        if related_block:
            lines.append(related_block)
        return "\n".join(lines).strip()

    def _build_related_schema_block(self, related_schema: dict[str, dict[str, Any]]) -> str:
        """Format the abbreviated FK-related-tables section. Empty string when no related."""
        if not related_schema:
            return ""
        lines: list[str] = ["Related tables (one hop via FK, abbreviated — use for JOINs only):"]
        for table, info in related_schema.items():
            row_count = info.get("row_count")
            header = f"- {table} ({row_count} rows)" if row_count is not None else f"- {table}"
            lines.append(header)
            pk = info.get("primary_key") or []
            lines.append(f"  Primary key: {', '.join(pk) if pk else '(none)'}")
            fks = info.get("foreign_keys") or []
            if fks:
                fk_strs = [
                    f"{fk.get('column')} -> {fk.get('target_table')}.{fk.get('target_column')}"
                    for fk in fks
                ]
                lines.append(f"  Foreign keys: {', '.join(fk_strs)}")
            else:
                lines.append("  Foreign keys: (none)")
            cols = info.get("column_names") or []
            lines.append(f"  Columns: {', '.join(cols)}")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------

    def _validate(
        self,
        sql: str,
        allowed_tables: set[str],
        limit: int,
        column_map: dict[str, set[str]] | None = None,
        dialect: str | None = None,
    ) -> str:
        """Return an error string if validation fails, empty string if OK.

        `allowed_tables` is the union of hit-table names and FK-related table
        names — both are legal targets for SELECT/JOIN. CTE names defined in
        the query itself are also accepted.
        `column_map` maps table_name → set of valid column names. When provided,
        any qualified table.column reference not found in the map triggers a retry
        with an informative error so the LLM can self-correct without hallucinating.
        `limit` is not checked here; `_enforce_limit` applies the cap afterwards.
        `dialect` is the sqlglot dialect used for parsing (None = sqlglot default).
        """
        # Layer 1: parse, single statement, SELECT-only.
        # sqlglot.parse_one would silently keep only the first of several
        # statements, so parse everything and reject extras explicitly.
        try:
            statements = [s for s in sqlglot.parse(sql, read=dialect) if s is not None]
        except sqlglot.errors.SqlglotError as e:  # ParseError and TokenError
            return f"SQL parse error: {e}"
        if not statements:
            return "SQL parse error: no SQL statement found."
        if len(statements) > 1:
            return "Only a single SQL statement is allowed."
        parsed = statements[0]
        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"

        # Check for DML anywhere in the AST (including writeable CTEs)
        dml = next(parsed.find_all(exp.Insert, exp.Update, exp.Delete), None)
        if dml is not None:
            return f"DML ({type(dml).__name__}) is not allowed."
        # SELECT ... INTO writes data (in PostgreSQL it creates a new table).
        if parsed.find(exp.Into) is not None:
            return "SELECT ... INTO is not allowed."

        # Layer 2: schema grounding — table names
        known_tables = {t.lower() for t in allowed_tables}
        # CTE names are query-local tables and legal FROM/JOIN targets.
        cte_names = {cte.alias.lower() for cte in parsed.find_all(exp.CTE) if cte.alias}
        alias_to_table: dict[str, str] = {}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables and name not in cte_names:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."
            alias = (tbl.alias or tbl.name).lower()
            alias_to_table[alias] = name

        # Layer 3: column grounding — qualified references only (table.column)
        if column_map:
            normalized_map = {t.lower(): {c.lower() for c in cols} for t, cols in column_map.items()}
            for col_node in parsed.find_all(exp.Column):
                tbl_ref = col_node.table
                if not tbl_ref:
                    continue  # unqualified — skip, can't resolve without full alias tracking
                if isinstance(col_node.this, exp.Star):
                    continue  # `t.*` is always valid; its name would read as "*"
                tbl_name = alias_to_table.get(tbl_ref.lower(), tbl_ref.lower())
                col_name = col_node.name.lower()
                if tbl_name in normalized_map and col_name not in normalized_map[tbl_name]:
                    available = ", ".join(sorted(normalized_map[tbl_name]))
                    return (
                        f"Column '{col_node.name}' does not exist on table '{tbl_name}'. "
                        f"Available columns: {available}."
                    )

        # Layer 4: LIMIT enforcement happens in _enforce_limit, before execution.
        return ""

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------

    @staticmethod
    def _literal_int(node: Any) -> int | None:
        """Return the value of a numeric integer literal node, else None."""
        if isinstance(node, exp.Literal) and not node.is_string:
            try:
                return int(node.this)
            except (TypeError, ValueError):
                return None
        return None

    def _enforce_limit(self, sql: str, limit: int, dialect: str | None = None) -> str:
        """Inject or cap the top-level LIMIT using sqlglot AST manipulation.

        Only the outermost query's LIMIT matters for the size of the result, so
        it is read from the root node's ``limit`` arg. ``find(exp.Limit)`` can
        return a subquery's LIMIT and leave the outer query unbounded.
        A non-literal LIMIT (expression, ``ALL``, ...) or any other row-cap
        node is replaced by ``LIMIT <limit>``. An existing OFFSET is kept.

        Returns the SQL regenerated in `dialect` (None = sqlglot default).
        """
        parsed = sqlglot.parse_one(sql, read=dialect)
        existing = parsed.args.get("limit")
        if isinstance(existing, exp.Limit):
            current = self._literal_int(existing.expression)
            if current is None or current > limit:
                # Edit in place so a MySQL-style `LIMIT offset, n` keeps its offset.
                existing.set("expression", exp.Literal.number(limit))
            return parsed.sql(dialect=dialect)
        return parsed.limit(limit).sql(dialect=dialect)

    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        """Execute `sql` on a synchronous engine and return every row as a dict.

        Blocking; `_execute_for_client` calls it through `asyncio.to_thread`.
        """
        # Ensure the user DB connection is a read-only credential — sqlglot validation alone is not sufficient.
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]
# Shared module-level instance. Constructing it here builds the LLM client and
# prompt chain once, at import time.
db_executor = DbExecutor()