# ishaq101's picture
# feat/Knowledge & Data Tools (#3)
# 0721bb4
# Raw
# History Blame
# 12.5 kB
"""SqlCompiler — IR → (SQL string, named-params dict).
Identifiers (table / column names) come from the catalog and are emitted as
double-quoted Postgres identifiers, with any embedded `"` doubled. They were
also verified by the IR validator against the catalog, so injection through
identifiers is not possible at this layer.
Values from filter clauses are ALWAYS parameterized.
The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders
(`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params
dict on a sync SQLAlchemy engine.
Joins (KM-652 T4): a column_id resolves to its owning table across the base
table + any joined tables, so every reference is emitted table-qualified
(`"table"."col"`). With `joins=[]` the output is identical to the single-table
form. v1 supports the Postgres dialect only (Supabase = Postgres).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from ...catalog.models import Catalog, Column, Source, Table
from ..ir.models import (
AggSelect,
ColumnSelect,
FilterClause,
OrderByClause,
QueryIR,
SelectItem,
)
from .base import BaseCompiler
# Upper bound on how many rows any compiled query hands back to the agent
# layer. It applies even when the IR carries no limit, so a bare SELECT can
# never pull a whole user table across the wire or into memory. The executor
# trims results to `row_cap` and flags truncation.
MAX_RESULT_ROWS: int = 10_000

# A column reference resolved against the catalog: (owning table, column).
ColRef = tuple[Table, Column]
@dataclass
class CompiledSql:
    """Result of compiling one QueryIR.

    `sql` is the statement text using `:p_N` named placeholders, `params`
    maps each placeholder name to its bound value, and `row_cap` is the
    number of rows the executor keeps before flagging truncation.
    """

    sql: str
    params: dict[str, Any] = field(default_factory=dict)
    row_cap: int = field(default=MAX_RESULT_ROWS)
class SqlCompilerError(Exception):
    """Raised when a QueryIR cannot be turned into SQL."""
# Filter operators grouped by the shape of their right-hand side:
# no operand at all, a parenthesised list of bound values, or one bound value.
_NULLARY_OPS = frozenset(("is_null", "is_not_null"))
_LIST_OPS = frozenset(("in", "not_in"))
_COMPARISON_OPS = frozenset(("=", "!=", "<", "<=", ">", ">="))
class SqlCompiler(BaseCompiler):
    """Deterministic IR → Postgres SQL. No LLM.

    ``compile`` turns a ``QueryIR`` into a ``CompiledSql``: one SELECT
    statement whose filter values are bound as ``:p_N`` named parameters,
    plus the matching params dict and the executor's row cap. Every column
    reference is emitted table-qualified (``"table"."col"``) because a
    column_id may belong to the base table or to any joined table.

    The few fragments that cannot be parameterized (aggregate function name,
    sort direction, join type) are shape-checked here too. A malformed IR
    therefore raises ``SqlCompilerError`` rather than producing unsafe or
    silently-wrong SQL.
    """

    # IR join type -> SQL join keyword. Anything else is rejected instead of
    # silently degrading to a LEFT JOIN.
    _JOIN_KEYWORDS = {"inner": "INNER JOIN", "left": "LEFT JOIN"}
    # Accepted ORDER BY directions, compared after upper-casing the IR value.
    _SORT_DIRECTIONS = frozenset({"ASC", "DESC"})

    def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None:
        """Bind the compiler to ``catalog``.

        Raises:
            SqlCompilerError: if ``dialect`` is not 'postgres' or 'supabase'.
        """
        if dialect not in {"postgres", "supabase"}:
            raise SqlCompilerError(
                f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}"
            )
        self._catalog = catalog
        self._dialect = dialect

    def compile(self, ir: QueryIR) -> CompiledSql:
        """Compile ``ir`` into SQL text, bound params and a row cap.

        Clauses are assembled in SQL order (SELECT, FROM, WHERE, GROUP BY,
        ORDER BY, LIMIT). Empty optional clauses are omitted.

        Raises:
            SqlCompilerError: if the IR references an unknown source, table or
                column, or contains an unsupported construct.
        """
        base_table, cols_by_id = self._lookup(ir)
        params: dict[str, Any] = {}
        # Single-element list so helpers can advance the shared p_N counter.
        param_seq = [0]

        select_clause, select_aliases = self._build_select(ir.select, cols_by_id)
        clauses = [
            select_clause,
            self._build_from(base_table, ir, cols_by_id),
            self._build_where(ir.filters, cols_by_id, params, param_seq),
            self._build_groupby(ir.group_by, cols_by_id),
            self._build_orderby(ir.order_by, cols_by_id, select_aliases),
        ]
        limit_clause, row_cap = self._build_limit(ir.limit)
        clauses.append(limit_clause)

        sql = " ".join(clause for clause in clauses if clause)
        return CompiledSql(sql=sql, params=params, row_cap=row_cap)

    # ------------------------------------------------------------------
    # Catalog lookup — column_id -> (owning Table, Column) across base + joins
    # ------------------------------------------------------------------

    def _lookup(self, ir: QueryIR) -> tuple[Table, dict[str, ColRef]]:
        """Resolve the IR's tables and index every visible column by id.

        Returns the base table and a map ``column_id -> (owning table,
        column)``. The map covers the base table and every joined table. A
        later table's column overwrites an earlier entry with the same
        column_id.

        Raises:
            SqlCompilerError: if the source or any table is missing from the
                catalog. Also raised if the same table name would appear twice
                in FROM/JOIN: no table aliases are emitted, so Postgres would
                reject the statement.
        """
        source = self._find_source(ir.source_id)
        base_table = self._find_table(source, ir.table_id)
        cols_by_id: dict[str, ColRef] = {
            c.column_id: (base_table, c) for c in base_table.columns
        }
        used_names = {base_table.name}
        for i, j in enumerate(ir.joins):
            target = self._find_table(source, j.target_table_id)
            if target.name in used_names:
                raise SqlCompilerError(
                    f"joins[{i}]: table {target.name!r} already appears in the "
                    "query; self-joins / repeated joins need table aliases, "
                    "which v1 does not emit"
                )
            used_names.add(target.name)
            for c in target.columns:
                cols_by_id[c.column_id] = (target, c)
        return base_table, cols_by_id

    def _find_source(self, source_id: str) -> Source:
        """Return the catalog source whose id is ``source_id``.

        Raises:
            SqlCompilerError: if no such source exists.
        """
        source = next(
            (s for s in self._catalog.sources if s.source_id == source_id), None
        )
        if source is None:
            raise SqlCompilerError(f"source_id {source_id!r} not in catalog")
        return source

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the table ``table_id`` of ``source``.

        Raises:
            SqlCompilerError: if no such table exists in the source.
        """
        table = next((t for t in source.tables if t.table_id == table_id), None)
        if table is None:
            raise SqlCompilerError(
                f"table_id {table_id!r} not in source {source.source_id!r}"
            )
        return table

    # ------------------------------------------------------------------
    # Identifier quoting
    # ------------------------------------------------------------------

    @staticmethod
    def _qident(name: str) -> str:
        """Postgres-style double-quoted identifier with embedded-quote escape."""
        return '"' + name.replace('"', '""') + '"'

    def _qcol(self, ref: ColRef) -> str:
        """Render a resolved column as ``"table"."column"``."""
        table, col = ref
        return f"{self._qident(table.name)}.{self._qident(col.name)}"

    # ------------------------------------------------------------------
    # Clauses
    # ------------------------------------------------------------------

    def _build_select(
        self, items: list[SelectItem], cols_by_id: dict[str, ColRef]
    ) -> tuple[str, set[str]]:
        """Build the SELECT clause and collect the aliases it defines.

        The alias set is later used to let ORDER BY refer to aliased
        expressions.

        Raises:
            SqlCompilerError: if ``items`` is empty or any item is invalid.
        """
        if not items:
            raise SqlCompilerError("select clause cannot be empty")
        parts: list[str] = []
        aliases: set[str] = set()
        for i, item in enumerate(items):
            expr, alias = self._select_item(item, cols_by_id, i)
            if alias:
                parts.append(f"{expr} AS {self._qident(alias)}")
                aliases.add(alias)
            else:
                parts.append(expr)
        return "SELECT " + ", ".join(parts), aliases

    def _select_item(
        self, item: SelectItem, cols_by_id: dict[str, ColRef], index: int
    ) -> tuple[str, str | None]:
        """Render one select item, returning ``(expression, alias or None)``."""
        if isinstance(item, ColumnSelect):
            ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return self._qcol(ref), item.alias
        if not isinstance(item, AggSelect):
            raise SqlCompilerError(
                f"select[{index}]: unknown SelectItem kind {type(item).__name__}"
            )
        return self._compile_agg(item, cols_by_id, index), item.alias

    def _compile_agg(
        self, item: AggSelect, cols_by_id: dict[str, ColRef], index: int
    ) -> str:
        """Render one aggregate select expression.

        ``count_distinct`` becomes ``COUNT(DISTINCT col)``. ``count`` without a
        column becomes ``COUNT(*)``. Any other function is rendered as
        ``FN(col)`` with the name upper-cased.

        Raises:
            SqlCompilerError: if the function name is not a bare identifier, or
                a function other than 'count' has no column_id.
        """
        fn = item.fn
        # The function name is spliced into the SQL text (it cannot be a bound
        # parameter), so refuse anything that is not a bare ASCII identifier.
        if not (isinstance(fn, str) and fn.isascii() and fn.isidentifier()):
            raise SqlCompilerError(
                f"select[{index}].fn: invalid aggregate function {fn!r}"
            )
        if fn == "count_distinct":
            if item.column_id is None:
                raise SqlCompilerError(
                    f"select[{index}].fn=count_distinct requires column_id"
                )
            ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return f"COUNT(DISTINCT {self._qcol(ref)})"
        if item.column_id is None:
            if fn != "count":
                raise SqlCompilerError(
                    f"select[{index}].fn={fn!r} requires column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return "COUNT(*)"
        ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
        return f"{fn.upper()}({self._qcol(ref)})"

    def _build_from(
        self, base_table: Table, ir: QueryIR, cols_by_id: dict[str, ColRef]
    ) -> str:
        """Build ``FROM base [<JOIN> target ON left = right ...]``.

        The joined table is always the IR's ``target_table_id``, not whichever
        table happens to own the right-hand ON column. A join written as
        ``target.col = base.col`` therefore still joins the right table.

        Raises:
            SqlCompilerError: on an unsupported join type, or an unknown table
                or ON column.
        """
        source = self._find_source(ir.source_id)
        clauses = [f"FROM {self._qident(base_table.name)}"]
        for i, j in enumerate(ir.joins):
            keyword = self._JOIN_KEYWORDS.get(j.type)
            if keyword is None:
                raise SqlCompilerError(
                    f"joins[{i}].type: unsupported join type {j.type!r} "
                    f"(expected one of {sorted(self._JOIN_KEYWORDS)})"
                )
            target_table = self._find_table(source, j.target_table_id)
            left = self._require_col(cols_by_id, j.left_column_id, f"joins[{i}].left")
            right = self._require_col(
                cols_by_id, j.right_column_id, f"joins[{i}].right"
            )
            clauses.append(
                f"{keyword} {self._qident(target_table.name)} "
                f"ON {self._qcol(left)} = {self._qcol(right)}"
            )
        return " ".join(clauses)

    def _build_where(
        self,
        filters: list[FilterClause],
        cols_by_id: dict[str, ColRef],
        params: dict[str, Any],
        param_seq: list[int],
    ) -> str:
        """AND together all filters, or return "" when there are none."""
        if not filters:
            return ""
        parts = [
            self._compile_filter(f, cols_by_id, params, param_seq, index=i)
            for i, f in enumerate(filters)
        ]
        return "WHERE " + " AND ".join(parts)

    def _compile_filter(
        self,
        f: FilterClause,
        cols_by_id: dict[str, ColRef],
        params: dict[str, Any],
        param_seq: list[int],
        index: int,
    ) -> str:
        """Render one filter clause as a SQL predicate.

        Values are never inlined. Each one is stored in ``params`` under a
        fresh ``p_N`` name and referenced as ``:p_N``.

        * ``is_null`` / ``is_not_null`` -> ``IS [NOT] NULL`` (no value).
        * ``in`` / ``not_in`` -> ``[NOT] IN (:p_a, ...)``. Needs a non-empty
          list or tuple.
        * ``between`` -> ``BETWEEN :lo AND :hi``. Needs exactly two values.
        * ``like`` -> ``LIKE :p``.
        * comparison ops -> ``col <op> :p``. ``=`` / ``!=`` against ``None``
          become ``IS NULL`` / ``IS NOT NULL``.

        Raises:
            SqlCompilerError: on an unknown column, a malformed value, or an
                unhandled op.
        """
        ref = self._require_col(cols_by_id, f.column_id, f"filters[{index}]")
        col_ref = self._qcol(ref)
        op = f.op

        if op in _NULLARY_OPS:
            negation = "" if op == "is_null" else " NOT"
            return f"{col_ref} IS{negation} NULL"

        if op in _LIST_OPS:
            values = f.value
            if not isinstance(values, (list, tuple)) or not values:
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} requires a non-empty list value"
                )
            placeholders = ", ".join(
                ":" + self._next_param(params, param_seq, v) for v in values
            )
            sql_op = "IN" if op == "in" else "NOT IN"
            return f"{col_ref} {sql_op} ({placeholders})"

        if op == "between":
            values = f.value
            if not isinstance(values, (list, tuple)) or len(values) != 2:
                raise SqlCompilerError(
                    f"filters[{index}]: op 'between' requires a list of two values"
                )
            lo = self._next_param(params, param_seq, values[0])
            hi = self._next_param(params, param_seq, values[1])
            return f"{col_ref} BETWEEN :{lo} AND :{hi}"

        if op == "like":
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} LIKE :{p}"

        if op in _COMPARISON_OPS:
            if f.value is None and op in ("=", "!="):
                # `col = NULL` / `col != NULL` evaluate to NULL for every row
                # and would silently match nothing. Emit the null test the IR
                # evidently meant instead.
                negation = "" if op == "=" else " NOT"
                return f"{col_ref} IS{negation} NULL"
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} {op} :{p}"

        # Should not reach here — IRValidator already filters disallowed ops
        raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}")

    def _build_groupby(
        self, group_by: list[str], cols_by_id: dict[str, ColRef]
    ) -> str:
        """Build ``GROUP BY`` over table-qualified columns, or "" if none."""
        if not group_by:
            return ""
        parts = [
            self._qcol(self._require_col(cols_by_id, col_id, f"group_by[{i}]"))
            for i, col_id in enumerate(group_by)
        ]
        return "GROUP BY " + ", ".join(parts)

    def _build_orderby(
        self,
        order_by: list[OrderByClause],
        cols_by_id: dict[str, ColRef],
        select_aliases: set[str],
    ) -> str:
        """Build ``ORDER BY``, or "" if there are no sort keys.

        A key is resolved first as a column_id, then as a SELECT alias.

        Raises:
            SqlCompilerError: if a key matches neither, or its direction is not
                asc/desc (the direction is spliced into the SQL text).
        """
        if not order_by:
            return ""
        parts: list[str] = []
        for i, ob in enumerate(order_by):
            if ob.column_id in cols_by_id:
                ref = self._qcol(cols_by_id[ob.column_id])
            elif ob.column_id in select_aliases:
                ref = self._qident(ob.column_id)
            else:
                raise SqlCompilerError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not in query "
                    "columns or select aliases"
                )
            direction = ob.dir.upper()
            if direction not in self._SORT_DIRECTIONS:
                raise SqlCompilerError(
                    f"order_by[{i}].dir: expected 'asc' or 'desc', got {ob.dir!r}"
                )
            parts.append(f"{ref} {direction}")
        return "ORDER BY " + ", ".join(parts)

    def _build_limit(self, limit: int | None) -> tuple[str, int]:
        """Return (LIMIT clause, row_cap).

        Always bounded. An explicit IR limit up to MAX_RESULT_ROWS is honored
        exactly. Two cases instead emit `LIMIT MAX_RESULT_ROWS + 1` with
        row_cap = MAX_RESULT_ROWS: the IR has no limit, or it asks for more
        than MAX_RESULT_ROWS. The extra row lets the executor tell "exactly the
        cap" from "more rows existed" and flag truncation. Emitting
        `LIMIT MAX_RESULT_ROWS` for an over-cap request would make that
        truncation undetectable.

        Raises:
            SqlCompilerError: if the limit is negative (Postgres rejects a
                negative LIMIT).
        """
        if limit is None:
            return f"LIMIT {MAX_RESULT_ROWS + 1}", MAX_RESULT_ROWS
        requested = int(limit)
        if requested < 0:
            raise SqlCompilerError(f"limit must be >= 0, got {limit!r}")
        if requested > MAX_RESULT_ROWS:
            return f"LIMIT {MAX_RESULT_ROWS + 1}", MAX_RESULT_ROWS
        return f"LIMIT {requested}", requested

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _next_param(
        params: dict[str, Any], param_seq: list[int], value: Any
    ) -> str:
        """Store ``value`` under the next ``p_N`` name and return that name."""
        name = f"p_{param_seq[0]}"
        param_seq[0] += 1
        params[name] = value
        return name

    @staticmethod
    def _require_col(
        cols_by_id: dict[str, ColRef], col_id: str, where: str
    ) -> ColRef:
        """Resolve ``col_id``; ``where`` labels the IR location in errors.

        Raises:
            SqlCompilerError: if the column is not in any query table.
        """
        ref = cols_by_id.get(col_id)
        if ref is None:
            raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in query tables")
        return ref