"""SqlCompiler — IR → (SQL string, named-params dict). Identifiers (table / column names) come from the catalog and are quoted verbatim — they were verified by the IR validator against the catalog, so injection through identifiers is not possible at this layer. Values from filter clauses are ALWAYS parameterized. The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders (`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params dict on a sync SQLAlchemy engine. Joins (KM-652 T4): a column_id resolves to its owning table across the base table + any joined tables, so every reference is emitted table-qualified (`"table"."col"`). With `joins=[]` the output is identical to the single-table form. v1 supports the Postgres dialect only (Supabase = Postgres). """ from __future__ import annotations from dataclasses import dataclass, field from typing import Any from ...catalog.models import Catalog, Column, Source, Table from ..ir.models import ( AggSelect, ColumnSelect, FilterClause, OrderByClause, QueryIR, SelectItem, ) from .base import BaseCompiler # Hard ceiling on rows returned to the agent layer. Every compiled query is # bounded by this even when the IR sets no limit, so an unbounded SELECT can never # stream an entire user table over the wire / into memory. The executor caps to # `row_cap` and flags truncation. MAX_RESULT_ROWS = 10_000 # A resolved column reference: the table that owns it + the column itself. ColRef = tuple[Table, Column] @dataclass class CompiledSql: sql: str params: dict[str, Any] = field(default_factory=dict) row_cap: int = MAX_RESULT_ROWS # executor caps rows to this; flags truncation class SqlCompilerError(Exception): pass _NULLARY_OPS = frozenset({"is_null", "is_not_null"}) _LIST_OPS = frozenset({"in", "not_in"}) _COMPARISON_OPS = frozenset({"=", "!=", "<", "<=", ">", ">="}) class SqlCompiler(BaseCompiler): """Deterministic IR → Postgres SQL. No LLM.""" def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None: if dialect not in {"postgres", "supabase"}: raise SqlCompilerError( f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}" ) self._catalog = catalog self._dialect = dialect def compile(self, ir: QueryIR) -> CompiledSql: base_table, cols_by_id = self._lookup(ir) params: dict[str, Any] = {} param_seq = [0] select_clause, select_aliases = self._build_select(ir.select, cols_by_id) from_clause = self._build_from(base_table, ir, cols_by_id) where_clause = self._build_where(ir.filters, cols_by_id, params, param_seq) groupby_clause = self._build_groupby(ir.group_by, cols_by_id) orderby_clause = self._build_orderby(ir.order_by, cols_by_id, select_aliases) limit_clause, row_cap = self._build_limit(ir.limit) parts: list[str] = [select_clause, from_clause] for clause in (where_clause, groupby_clause, orderby_clause, limit_clause): if clause: parts.append(clause) return CompiledSql(sql=" ".join(parts), params=params, row_cap=row_cap) # ------------------------------------------------------------------ # Catalog lookup — column_id -> (owning Table, Column) across base + joins # ------------------------------------------------------------------ def _lookup(self, ir: QueryIR) -> tuple[Table, dict[str, ColRef]]: source = next( (s for s in self._catalog.sources if s.source_id == ir.source_id), None ) if source is None: raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog") base_table = self._find_table(source, ir.table_id) cols_by_id: dict[str, ColRef] = { c.column_id: (base_table, c) for c in base_table.columns } for j in ir.joins: target = self._find_table(source, j.target_table_id) for c in target.columns: cols_by_id[c.column_id] = (target, c) return base_table, cols_by_id @staticmethod def _find_table(source: Source, table_id: str) -> Table: table = next((t for t in source.tables if t.table_id == table_id), None) if table is None: raise SqlCompilerError( f"table_id {table_id!r} not in source {source.source_id!r}" ) return table # ------------------------------------------------------------------ # Identifier quoting # ------------------------------------------------------------------ @staticmethod def _qident(name: str) -> str: """Postgres-style double-quoted identifier with embedded-quote escape.""" return '"' + name.replace('"', '""') + '"' def _qcol(self, ref: ColRef) -> str: table, col = ref return f"{self._qident(table.name)}.{self._qident(col.name)}" # ------------------------------------------------------------------ # Clauses # ------------------------------------------------------------------ def _build_select( self, items: list[SelectItem], cols_by_id: dict[str, ColRef] ) -> tuple[str, set[str]]: if not items: raise SqlCompilerError("select clause cannot be empty") parts: list[str] = [] aliases: set[str] = set() for i, item in enumerate(items): expr, alias = self._select_item(item, cols_by_id, i) if alias: parts.append(f"{expr} AS {self._qident(alias)}") aliases.add(alias) else: parts.append(expr) return "SELECT " + ", ".join(parts), aliases def _select_item( self, item: SelectItem, cols_by_id: dict[str, ColRef], index: int ) -> tuple[str, str | None]: if isinstance(item, ColumnSelect): ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return self._qcol(ref), item.alias if not isinstance(item, AggSelect): raise SqlCompilerError( f"select[{index}]: unknown SelectItem kind {type(item).__name__}" ) return self._compile_agg(item, cols_by_id, index), item.alias def _compile_agg( self, item: AggSelect, cols_by_id: dict[str, ColRef], index: int ) -> str: if item.fn == "count_distinct": if item.column_id is None: raise SqlCompilerError( f"select[{index}].fn=count_distinct requires column_id" ) ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return f"COUNT(DISTINCT {self._qcol(ref)})" if item.column_id is None: if item.fn != "count": raise SqlCompilerError( f"select[{index}].fn={item.fn!r} requires column_id " "(only 'count' may omit it for COUNT(*))" ) return "COUNT(*)" ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]") return f"{item.fn.upper()}({self._qcol(ref)})" def _build_from( self, base_table: Table, ir: QueryIR, cols_by_id: dict[str, ColRef] ) -> str: sql = f"FROM {self._qident(base_table.name)}" for i, j in enumerate(ir.joins): left = self._require_col(cols_by_id, j.left_column_id, f"joins[{i}].left") right = self._require_col(cols_by_id, j.right_column_id, f"joins[{i}].right") target_table = right[0] keyword = "INNER JOIN" if j.type == "inner" else "LEFT JOIN" sql += ( f" {keyword} {self._qident(target_table.name)} " f"ON {self._qcol(left)} = {self._qcol(right)}" ) return sql def _build_where( self, filters: list[FilterClause], cols_by_id: dict[str, ColRef], params: dict[str, Any], param_seq: list[int], ) -> str: if not filters: return "" parts = [ self._compile_filter(f, cols_by_id, params, param_seq, index=i) for i, f in enumerate(filters) ] return "WHERE " + " AND ".join(parts) def _compile_filter( self, f: FilterClause, cols_by_id: dict[str, ColRef], params: dict[str, Any], param_seq: list[int], index: int, ) -> str: ref = self._require_col(cols_by_id, f.column_id, f"filters[{index}]") col_ref = self._qcol(ref) op = f.op if op == "is_null": return f"{col_ref} IS NULL" if op == "is_not_null": return f"{col_ref} IS NOT NULL" if op in _LIST_OPS: if not isinstance(f.value, list) or not f.value: raise SqlCompilerError( f"filters[{index}]: op {op!r} requires a non-empty list value" ) placeholders = [ ":" + self._next_param(params, param_seq, v) for v in f.value ] sql_op = "IN" if op == "in" else "NOT IN" return f"{col_ref} {sql_op} ({', '.join(placeholders)})" if op == "between": if not isinstance(f.value, list) or len(f.value) != 2: raise SqlCompilerError( f"filters[{index}]: op 'between' requires a list of two values" ) lo = self._next_param(params, param_seq, f.value[0]) hi = self._next_param(params, param_seq, f.value[1]) return f"{col_ref} BETWEEN :{lo} AND :{hi}" if op == "like": p = self._next_param(params, param_seq, f.value) return f"{col_ref} LIKE :{p}" if op in _COMPARISON_OPS: p = self._next_param(params, param_seq, f.value) return f"{col_ref} {op} :{p}" # Should not reach here — IRValidator already filters disallowed ops raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}") def _build_groupby( self, group_by: list[str], cols_by_id: dict[str, ColRef] ) -> str: if not group_by: return "" parts = [ self._qcol(self._require_col(cols_by_id, col_id, f"group_by[{i}]")) for i, col_id in enumerate(group_by) ] return "GROUP BY " + ", ".join(parts) def _build_orderby( self, order_by: list[OrderByClause], cols_by_id: dict[str, ColRef], select_aliases: set[str], ) -> str: if not order_by: return "" parts: list[str] = [] for i, ob in enumerate(order_by): if ob.column_id in cols_by_id: ref = self._qcol(cols_by_id[ob.column_id]) elif ob.column_id in select_aliases: ref = self._qident(ob.column_id) else: raise SqlCompilerError( f"order_by[{i}].column_id: {ob.column_id!r} not in query " "columns or select aliases" ) parts.append(f"{ref} {ob.dir.upper()}") return "ORDER BY " + ", ".join(parts) def _build_limit(self, limit: int | None) -> tuple[str, int]: """Return (LIMIT clause, row_cap). Always bounded. An explicit IR limit is honored exactly (capped at MAX_RESULT_ROWS). When the IR has no limit we still emit `LIMIT MAX_RESULT_ROWS + 1` — the extra row lets the executor tell "exactly the cap" from "more rows existed" and flag truncation. """ if limit is None: return f"LIMIT {MAX_RESULT_ROWS + 1}", MAX_RESULT_ROWS row_cap = min(int(limit), MAX_RESULT_ROWS) return f"LIMIT {row_cap}", row_cap # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @staticmethod def _next_param( params: dict[str, Any], param_seq: list[int], value: Any ) -> str: name = f"p_{param_seq[0]}" param_seq[0] += 1 params[name] = value return name @staticmethod def _require_col( cols_by_id: dict[str, ColRef], col_id: str, where: str ) -> ColRef: ref = cols_by_id.get(col_id) if ref is None: raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in query tables") return ref