"""Executor for registered database sources (source_type="database").
Flow per (client_id, question):
1. Collect all relevant (table_name, column_name) pairs from retrieval results.
2. Fetch the FULL schema for those tables from PGVector (not just top-k columns).
3. Build a schema context string and send to LLM β†’ structured SQLQuery output.
4. Validate via sqlglot: SELECT-only, schema-grounded, LIMIT enforced.
5. Execute on the user's DB via engine_scope + asyncio.to_thread.
6. Return QueryResult per client_id (may span multiple tables via JOINs).
Supported db_types: postgres, supabase, mysql.
Other types are skipped with a warning β€” they do not raise.
"""
import asyncio
from collections import defaultdict
from typing import Any
import sqlglot
import sqlglot.expressions as exp
import tiktoken
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import settings
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import _pgvector_engine
from src.middlewares.logging import get_logger
from src.models.sql_query import SQLQuery
from src.pipeline.db_pipeline import db_pipeline_service
from src.query.base import BaseExecutor, QueryResult
from src.rag.base import RetrievalResult
from src.utils.db_credential_encryption import decrypt_credentials_dict
# Structured logger for this module.
logger = get_logger("db_executor")
# Tokenizer used only to log prompt/completion token counts; not for truncation.
_enc = tiktoken.get_encoding("cl100k_base")
# db_types this executor can connect to; anything else is skipped with a warning.
_SUPPORTED_DB_TYPES = {"postgres", "supabase", "mysql"}
# Max SQL-generation attempts per client; validation errors feed back into the prompt.
_MAX_RETRIES = 3
# Hard ceiling on LIMIT regardless of the caller-supplied limit.
_MAX_LIMIT = 500
# System prompt template. {dialect}, {limit}, {schema} and {error_section} are
# filled per attempt by ChatPromptTemplate; the text itself is runtime behavior.
_SQL_SYSTEM_PROMPT = """\
You are a SQL data analyst working with a user's database.
Generate a single SQL SELECT statement that answers the user's question.
Database dialect: {dialect}
Rules:
- ONLY reference tables and columns listed in the schema below. Do not invent names.
- Always include a LIMIT clause (max {limit}).
- Do not use DELETE, UPDATE, INSERT, DROP, TRUNCATE, ALTER, CREATE, or any DDL.
- Prefer explicit JOINs over subqueries when combining tables.
- For aggregations, always alias the result column (e.g. COUNT(*) AS order_count).
- For date filtering, use dialect-appropriate functions ({dialect} syntax).
Schema:
{schema}
{error_section}"""
class DbExecutor(BaseExecutor):
    """Generate, validate, and execute SQL against a user's registered database.

    One pass per (client_id, question): fetch the full schema for the
    retrieved tables from PGVector, ask the LLM for a structured ``SQLQuery``,
    validate it with sqlglot (SELECT-only, schema-grounded), enforce a LIMIT,
    then run it on the user's database in a worker thread.
    """

    def __init__(self) -> None:
        # temperature=0 keeps generation deterministic, so retry attempts
        # differ only through the injected error feedback, not sampling noise.
        self._llm = AzureChatOpenAI(
            azure_deployment=settings.azureai_deployment_name_4o,
            openai_api_version=settings.azureai_api_version_4o,
            azure_endpoint=settings.azureai_endpoint_url_4o,
            api_key=settings.azureai_api_key_4o,
            temperature=0,
        )
        self._prompt = ChatPromptTemplate.from_messages([
            ("system", _SQL_SYSTEM_PROMPT),
            ("human", "{question}"),
        ])
        # with_structured_output guarantees a SQLQuery (sql + reasoning) object.
        self._chain = self._prompt | self._llm.with_structured_output(SQLQuery)

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    async def execute(
        self,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int = 100,
    ) -> list[QueryResult]:
        """Run `question` against every database client present in `results`.

        Returns one QueryResult per client that succeeds. Per-client failures
        are logged and skipped so one bad client cannot sink the others.
        """
        db_results = [r for r in results if r.source_type == "database"]
        if not db_results:
            return []
        # Group by client_id — one SQL generation + execution pass per client
        by_client: dict[str, list[RetrievalResult]] = defaultdict(list)
        for r in db_results:
            client_id = r.metadata.get("database_client_id", "")
            if client_id:
                by_client[client_id].append(r)
            else:
                logger.warning("db result missing database_client_id, skipping")
        query_results: list[QueryResult] = []
        for client_id, client_results in by_client.items():
            try:
                qr = await self._execute_for_client(client_id, client_results, user_id, db, question, limit)
                if qr:
                    query_results.append(qr)
            except Exception as e:
                # Isolate failures per client; remaining clients still run.
                logger.error("db executor failed for client", client_id=client_id, error=str(e))
        return query_results

    # ------------------------------------------------------------------
    # Per-client execution
    # ------------------------------------------------------------------
    async def _execute_for_client(
        self,
        client_id: str,
        results: list[RetrievalResult],
        user_id: str,
        db: AsyncSession,
        question: str,
        limit: int,
    ) -> QueryResult | None:
        """Full pipeline for one client: schema fetch -> LLM -> validate -> run.

        Returns None (after logging) on any non-exceptional dead end:
        unknown client, ownership mismatch, unsupported db_type, empty
        schema, or SQL generation failing all retries.
        """
        client = await database_client_service.get(db, client_id)
        if not client:
            logger.warning("database client not found", client_id=client_id)
            return None
        if client.user_id != user_id:
            # Ownership check: never run queries against another user's client.
            logger.warning("client ownership mismatch", client_id=client_id)
            return None
        if client.db_type not in _SUPPORTED_DB_TYPES:
            logger.warning("unsupported db_type for query execution", db_type=client.db_type)
            return None
        # Distinct table names from retrieval results, expanded via FK relationships.
        # Sorted so logs and the QueryResult `table_or_file` field are deterministic.
        table_names = sorted({
            r.metadata.get("data", {}).get("table_name")
            for r in results
            if r.metadata.get("data", {}).get("table_name")
        })
        table_names = await self._expand_with_fk_tables(client_id, user_id, table_names)
        full_schema = await self._fetch_full_schema(client_id, table_names, user_id)
        if not full_schema:
            logger.warning("no schema found in vector store", client_id=client_id, tables=table_names)
            return None
        schema_ctx = self._build_schema_context(full_schema)
        capped_limit = min(limit, _MAX_LIMIT)
        dialect = client.db_type
        # SQL generation with retry: validation/LLM errors are fed back into
        # the prompt via {error_section} on the next attempt.
        validated_sql: str | None = None
        prev_error: str = ""
        prev_reasoning: str = ""
        for attempt in range(_MAX_RETRIES):
            if prev_error:
                error_section = (
                    f"Previous attempt reasoning: {prev_reasoning}\n"
                    f"Previous attempt failed: {prev_error}\n"
                    "Fix the issue above."
                )
            else:
                error_section = ""
            try:
                # Approximate input size (excludes the fixed template text);
                # logged for cost observability only.
                prompt_text = schema_ctx + error_section + question
                input_tokens = len(_enc.encode(prompt_text))
                logger.info("sql generation input tokens", attempt=attempt + 1, tokens=input_tokens)
                result: SQLQuery = await self._chain.ainvoke({
                    "schema": schema_ctx,
                    "dialect": dialect,
                    "limit": capped_limit,
                    "error_section": error_section,
                    "question": question,
                })
                sql = result.sql.strip()
                validation_error = self._validate(sql, full_schema, capped_limit)
                if validation_error:
                    prev_error = validation_error
                    prev_reasoning = result.reasoning
                    logger.warning("sql validation failed", attempt=attempt + 1, error=validation_error)
                    continue
                validated_sql = self._enforce_limit(sql, capped_limit)
                output_tokens = len(_enc.encode(result.sql)) + len(_enc.encode(result.reasoning))
                logger.info(
                    "sql generated",
                    attempt=attempt + 1,
                    input_tokens=input_tokens,
                    output_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                    reasoning=result.reasoning,
                )
                break
            except Exception as e:
                prev_error = str(e)
                # Clear stale reasoning: it belongs to an earlier attempt and
                # would mislead the model if paired with this new error.
                prev_reasoning = ""
                logger.warning("sql generation error", attempt=attempt + 1, error=prev_error)
        if not validated_sql:
            logger.error("sql generation failed after retries", client_id=client_id)
            return None
        # Execute on user's DB. The sync driver call runs in a worker thread
        # so the event loop is not blocked.
        creds = decrypt_credentials_dict(client.credentials)
        with db_pipeline_service.engine_scope(client.db_type, creds) as engine:
            rows = await asyncio.to_thread(self._run_sql, engine, validated_sql)
        # NOTE: columns with the same name in different tables collapse here —
        # the last table wins for type reporting.
        column_types = {
            col["name"]: col["type"]
            for cols in full_schema.values()
            for col in cols
        }
        columns = list(rows[0].keys()) if rows else []
        return QueryResult(
            source_type="database",
            source_id=client_id,
            table_or_file=", ".join(table_names),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            metadata={
                "db_type": client.db_type,
                "client_name": client.name,
                "sql": validated_sql,
                "column_types": {c: column_types.get(c, "unknown") for c in columns},
            },
        )

    # ------------------------------------------------------------------
    # Schema helpers
    # ------------------------------------------------------------------
    async def _expand_with_fk_tables(
        self,
        client_id: str,
        user_id: str,
        table_names: list[str],
    ) -> list[str]:
        """Expand table_names with any tables FK-referenced by the retrieved tables.

        Prevents SQL generation failures when a required table (e.g. orders) wasn't
        returned by retrieval but is referenced via FK from a table that was
        (e.g. order_items.order_id -> orders.id).
        """
        if not table_names:
            return table_names
        # Bind each table name individually — never interpolate values into SQL.
        placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
        sql = text(f"""
            SELECT DISTINCT lpe.cmetadata->'data'->>'foreign_key' AS fk
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
              AND lpe.cmetadata->'data'->>'foreign_key' IS NOT NULL
        """)
        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
        for i, name in enumerate(table_names):
            params[f"t{i}"] = name
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()
        expanded = set(table_names)
        for row in rows:
            fk = row.fk  # format: "referred_table.referred_column"
            if fk:
                referred_table = fk.split(".")[0]
                expanded.add(referred_table)
        if expanded != set(table_names):
            logger.info(
                "expanded tables via FK",
                original=sorted(table_names),
                expanded=sorted(expanded),
            )
        return list(expanded)

    async def _fetch_full_schema(
        self,
        client_id: str,
        table_names: list[str],
        user_id: str,
    ) -> dict[str, list[dict[str, Any]]]:
        """Fetch ALL column chunks for the given tables from PGVector.

        Returns {table_name: [{"name": ..., "type": ..., "is_primary_key": ...,
        "foreign_key": ..., "content": ...}]}
        """
        # Guard: an empty table list would produce malformed SQL ("IN ()").
        if not table_names:
            return {}
        placeholders = ", ".join(f":t{i}" for i in range(len(table_names)))
        sql = text(f"""
            SELECT lpe.cmetadata, lpe.document
            FROM langchain_pg_embedding lpe
            JOIN langchain_pg_collection lpc ON lpe.collection_id = lpc.uuid
            WHERE lpc.name = 'document_embeddings'
              AND lpe.cmetadata->>'user_id' = :user_id
              AND lpe.cmetadata->>'source_type' = 'database'
              AND lpe.cmetadata->>'database_client_id' = :client_id
              AND lpe.cmetadata->'data'->>'table_name' IN ({placeholders})
            ORDER BY lpe.cmetadata->'data'->>'table_name', lpe.cmetadata->'data'->>'column_name'
        """)
        params: dict[str, Any] = {"user_id": user_id, "client_id": client_id}
        for i, name in enumerate(table_names):
            params[f"t{i}"] = name
        async with _pgvector_engine.connect() as conn:
            result = await conn.execute(sql, params)
            rows = result.fetchall()
        schema: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for row in rows:
            data = row.cmetadata.get("data", {})
            table = data.get("table_name")
            if table:
                schema[table].append({
                    "name": data.get("column_name", ""),
                    "type": data.get("column_type", ""),
                    "is_primary_key": data.get("is_primary_key", False),
                    "foreign_key": data.get("foreign_key"),
                    "content": row.document,  # chunk text includes top values / samples
                })
        return dict(schema)

    def _build_schema_context(self, schema: dict[str, list[dict[str, Any]]]) -> str:
        """Render the fetched schema as the plain-text {schema} prompt section."""
        lines: list[str] = []
        for table, columns in schema.items():
            lines.append(f"Table: {table}")
            for col in columns:
                flags = []
                if col["is_primary_key"]:
                    flags.append("PRIMARY KEY")
                if col["foreign_key"]:
                    flags.append(f"FK -> {col['foreign_key']}")
                flag_str = f" [{', '.join(flags)}]" if flags else ""
                lines.append(f"  - {col['name']} {col['type']}{flag_str}")
                # Include sample/top-values line from chunk content if present
                for line in col["content"].splitlines():
                    if line.startswith(("Top values:", "Sample values:")):
                        lines.append(f"    {line}")
                        break
            lines.append("")
        return "\n".join(lines).strip()

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------
    def _validate(self, sql: str, schema: dict[str, list[dict[str, Any]]], limit: int) -> str:
        """Return an error string if validation fails, empty string if OK.

        `limit` is part of the guardrail interface but enforcement happens in
        `_enforce_limit` (injection/capping), not here.
        """
        # Layer 1: sqlglot parse + SELECT-only check
        try:
            parsed = sqlglot.parse_one(sql)
        except sqlglot.errors.ParseError as e:
            return f"SQL parse error: {e}"
        if not isinstance(parsed, exp.Select):
            return f"Only SELECT statements are allowed. Got: {type(parsed).__name__}"
        # Check for DML anywhere in the AST (including writeable CTEs).
        # MERGE is included: it writes rows in dialects that support it.
        for node in parsed.find_all((exp.Insert, exp.Update, exp.Delete, exp.Merge)):
            return f"DML ({type(node).__name__}) is not allowed."
        # Layer 2: schema grounding — table names
        known_tables = {t.lower() for t in schema}
        for tbl in parsed.find_all(exp.Table):
            name = tbl.name.lower()
            if name and name not in known_tables:
                return f"Unknown table '{tbl.name}'. Only use tables from the schema."
        # Layer 3: LIMIT enforcement (inject if missing — done before execution)
        return ""

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------
    def _enforce_limit(self, sql: str, limit: int) -> str:
        """Inject or cap LIMIT using sqlglot AST manipulation.

        Unlike a naive ``int(...)`` on the LIMIT operand, this tolerates
        non-integer LIMIT expressions (e.g. ``LIMIT ALL`` or a parameter)
        by replacing them with the capped limit instead of raising.
        """
        parsed = sqlglot.parse_one(sql)
        existing = parsed.find(exp.Limit)
        if existing is None:
            return parsed.limit(limit).sql()
        try:
            current = int(existing.expression.this)
        except (TypeError, ValueError):
            current = None  # non-numeric LIMIT operand — force the cap
        if current is None or current > limit:
            existing.set("expression", exp.Literal.number(limit))
        return parsed.sql()

    def _run_sql(self, engine: Any, sql: str) -> list[dict]:
        """Execute validated SQL synchronously; called via asyncio.to_thread.

        Ensure the user DB connection is a read-only credential — sqlglot
        validation alone is not sufficient.
        """
        with engine.connect() as conn:
            result = conn.execute(text(sql))
            return [dict(row) for row in result.mappings()]
# Module-level singleton; importers share one executor (and one LLM chain).
db_executor = DbExecutor()