| """SqlCompiler — IR → (SQL string, named-params dict). |
| |
| Identifiers (table / column names) come from the catalog and are quoted |
| verbatim — they were verified by the IR validator against the catalog, |
| so injection through identifiers is not possible at this layer. |
| Values from filter clauses are ALWAYS parameterized. |
| |
| The output `CompiledSql.sql` uses SQLAlchemy-style named placeholders |
| (`:p_0, :p_1, ...`) so it can be executed via `text(sql)` with a params |
| dict on a sync SQLAlchemy engine. |
| |
| Joins (KM-652 T4): a column_id resolves to its owning table across the base |
| table + any joined tables, so every reference is emitted table-qualified |
| (`"table"."col"`). With `joins=[]` the output is identical to the single-table |
| form. v1 supports the Postgres dialect only (Supabase = Postgres). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Any |
|
|
| from ...catalog.models import Catalog, Column, Source, Table |
| from ..ir.models import ( |
| AggSelect, |
| ColumnSelect, |
| FilterClause, |
| OrderByClause, |
| QueryIR, |
| SelectItem, |
| ) |
| from .base import BaseCompiler |
|
|
| |
| |
| |
| |
# Upper bound on the number of rows a compiled query is allowed to return.
# Used as the default `CompiledSql.row_cap` and as the ceiling applied to
# any explicit IR limit.
MAX_RESULT_ROWS: int = 10_000
|
|
| |
# A resolved column reference: the column paired with the table that owns
# it, so the column can always be emitted table-qualified.
ColRef = tuple[Table, Column]
|
|
|
|
@dataclass
class CompiledSql:
    """Result of compiling a QueryIR: SQL text plus its bound parameters."""

    # SQL text containing `:p_N` named placeholders.
    sql: str
    # Placeholder name -> bound value; a fresh empty dict per instance.
    params: dict[str, Any] = field(default_factory=dict)
    # Maximum number of rows the query is meant to return.
    row_cap: int = MAX_RESULT_ROWS
|
|
|
|
class SqlCompilerError(Exception):
    """Raised when an IR cannot be compiled into SQL."""
|
|
|
|
| _NULLARY_OPS = frozenset({"is_null", "is_not_null"}) |
| _LIST_OPS = frozenset({"in", "not_in"}) |
| _COMPARISON_OPS = frozenset({"=", "!=", "<", "<=", ">", ">="}) |
|
|
|
|
class SqlCompiler(BaseCompiler):
    """Deterministic IR → Postgres SQL. No LLM.

    Table and column names are emitted as double-quoted identifiers and
    every filter value is bound as a named parameter (``:p_0``, ``:p_1``,
    ...). The few IR strings that are written into the SQL text itself
    (comparison operators, aggregate function names, sort directions) are
    checked before being interpolated.
    """

    # Sort directions that may be written verbatim into ORDER BY.
    _ORDER_DIRECTIONS = frozenset({"ASC", "DESC"})

    def __init__(self, catalog: Catalog, dialect: str = "postgres") -> None:
        """Store the catalog and dialect used for compilation.

        Args:
            catalog: Catalog searched for sources, tables and columns.
            dialect: Either ``"postgres"`` or ``"supabase"``.

        Raises:
            SqlCompilerError: If ``dialect`` is anything else.
        """
        if dialect not in {"postgres", "supabase"}:
            raise SqlCompilerError(
                f"only 'postgres' / 'supabase' supported in v1, got {dialect!r}"
            )
        self._catalog = catalog
        self._dialect = dialect

    def compile(self, ir: QueryIR) -> CompiledSql:
        """Compile ``ir`` into SQL text, bound parameters and a row cap.

        Clauses are built in SELECT, FROM, WHERE, GROUP BY, ORDER BY, LIMIT
        order; empty optional clauses are left out of the final string.

        Raises:
            SqlCompilerError: If the IR references a source, table or column
                that cannot be resolved, or contains a clause this compiler
                does not accept.
        """
        base_table, cols_by_id = self._lookup(ir)
        params: dict[str, Any] = {}
        # One-element list so the helpers can share and advance the counter.
        param_seq = [0]

        select_clause, select_aliases = self._build_select(ir.select, cols_by_id)
        from_clause = self._build_from(base_table, ir, cols_by_id)
        where_clause = self._build_where(ir.filters, cols_by_id, params, param_seq)
        groupby_clause = self._build_groupby(ir.group_by, cols_by_id)
        orderby_clause = self._build_orderby(ir.order_by, cols_by_id, select_aliases)
        limit_clause, row_cap = self._build_limit(ir.limit)

        optional = (where_clause, groupby_clause, orderby_clause, limit_clause)
        parts: list[str] = [select_clause, from_clause]
        parts.extend(clause for clause in optional if clause)

        return CompiledSql(sql=" ".join(parts), params=params, row_cap=row_cap)

    def _lookup(self, ir: QueryIR) -> tuple[Table, dict[str, ColRef]]:
        """Resolve the IR's base table and index every reachable column.

        Returns:
            The base table and a mapping of column_id -> (owning table,
            column) covering the base table and every joined table. When two
            tables expose the same column_id the later join wins.

        Raises:
            SqlCompilerError: If the source or any table is not in the catalog.
        """
        source = next(
            (s for s in self._catalog.sources if s.source_id == ir.source_id), None
        )
        if source is None:
            raise SqlCompilerError(f"source_id {ir.source_id!r} not in catalog")
        base_table = self._find_table(source, ir.table_id)
        cols_by_id: dict[str, ColRef] = {
            c.column_id: (base_table, c) for c in base_table.columns
        }
        for j in ir.joins:
            target = self._find_table(source, j.target_table_id)
            cols_by_id.update((c.column_id, (target, c)) for c in target.columns)
        return base_table, cols_by_id

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the table with ``table_id`` from ``source``.

        Raises:
            SqlCompilerError: If the source has no such table.
        """
        for table in source.tables:
            if table.table_id == table_id:
                return table
        raise SqlCompilerError(
            f"table_id {table_id!r} not in source {source.source_id!r}"
        )

    @staticmethod
    def _qident(name: str) -> str:
        """Postgres-style double-quoted identifier with embedded-quote escape."""
        escaped = name.replace('"', '""')
        return f'"{escaped}"'

    def _qcol(self, ref: ColRef) -> str:
        """Render a column reference as ``"table"."column"``."""
        table, col = ref
        return f"{self._qident(table.name)}.{self._qident(col.name)}"

    def _build_select(
        self, items: list[SelectItem], cols_by_id: dict[str, ColRef]
    ) -> tuple[str, set[str]]:
        """Build the SELECT clause.

        Returns:
            The clause text and the set of aliases it declares (used later to
            let ORDER BY refer to an alias).

        Raises:
            SqlCompilerError: If ``items`` is empty or an item is invalid.
        """
        if not items:
            raise SqlCompilerError("select clause cannot be empty")
        rendered: list[str] = []
        aliases: set[str] = set()
        for i, item in enumerate(items):
            expr, alias = self._select_item(item, cols_by_id, i)
            if not alias:
                rendered.append(expr)
                continue
            rendered.append(f"{expr} AS {self._qident(alias)}")
            aliases.add(alias)
        return "SELECT " + ", ".join(rendered), aliases

    def _select_item(
        self, item: SelectItem, cols_by_id: dict[str, ColRef], index: int
    ) -> tuple[str, str | None]:
        """Compile one select item into ``(expression, alias)``.

        Raises:
            SqlCompilerError: If the item is neither a ColumnSelect nor an
                AggSelect, or references an unknown column.
        """
        if isinstance(item, ColumnSelect):
            ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return self._qcol(ref), item.alias
        if isinstance(item, AggSelect):
            return self._compile_agg(item, cols_by_id, index), item.alias
        raise SqlCompilerError(
            f"select[{index}]: unknown SelectItem kind {type(item).__name__}"
        )

    def _compile_agg(
        self, item: AggSelect, cols_by_id: dict[str, ColRef], index: int
    ) -> str:
        """Compile an aggregate select item into a SQL expression.

        ``count_distinct`` becomes ``COUNT(DISTINCT col)``; ``count`` without
        a column becomes ``COUNT(*)``; anything else becomes ``FN(col)`` with
        the function name upper-cased.

        Raises:
            SqlCompilerError: If a required column_id is missing or unknown,
                or the function name is not a bare ASCII identifier.
        """
        fn = item.fn
        if fn == "count_distinct":
            if item.column_id is None:
                raise SqlCompilerError(
                    f"select[{index}].fn=count_distinct requires column_id"
                )
            ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
            return f"COUNT(DISTINCT {self._qcol(ref)})"
        if item.column_id is None:
            if fn != "count":
                raise SqlCompilerError(
                    f"select[{index}].fn={fn!r} requires column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return "COUNT(*)"
        ref = self._require_col(cols_by_id, item.column_id, f"select[{index}]")
        # The function name is written into the SQL text unquoted, so refuse
        # anything that is not a plain name (no spaces, parentheses, quotes).
        if not (fn.isascii() and fn.isidentifier()):
            raise SqlCompilerError(
                f"select[{index}].fn={fn!r} is not a valid aggregate function name"
            )
        return f"{fn.upper()}({self._qcol(ref)})"

    def _build_from(
        self, base_table: Table, ir: QueryIR, cols_by_id: dict[str, ColRef]
    ) -> str:
        """Build the FROM clause, including one JOIN per entry in ``ir.joins``.

        A join of type ``"inner"`` is emitted as INNER JOIN; every other type
        is emitted as LEFT JOIN. The joined table is the one owning the
        join's right-hand column.

        Raises:
            SqlCompilerError: If a join column is unknown, or the right-hand
                column does not belong to the join's target table.
        """
        pieces = [f"FROM {self._qident(base_table.name)}"]
        for i, j in enumerate(ir.joins):
            left = self._require_col(cols_by_id, j.left_column_id, f"joins[{i}].left")
            right = self._require_col(
                cols_by_id, j.right_column_id, f"joins[{i}].right"
            )
            target_table = right[0]
            # The JOIN names the right column's table; make sure that really
            # is the declared target, otherwise the wrong table is joined.
            if target_table.table_id != j.target_table_id:
                raise SqlCompilerError(
                    f"joins[{i}].right.column_id: {j.right_column_id!r} does not "
                    f"belong to target table {j.target_table_id!r}"
                )
            keyword = "INNER JOIN" if j.type == "inner" else "LEFT JOIN"
            pieces.append(
                f"{keyword} {self._qident(target_table.name)} "
                f"ON {self._qcol(left)} = {self._qcol(right)}"
            )
        return " ".join(pieces)

    def _build_where(
        self,
        filters: list[FilterClause],
        cols_by_id: dict[str, ColRef],
        params: dict[str, Any],
        param_seq: list[int],
    ) -> str:
        """Build the WHERE clause by AND-ing every filter; "" when none.

        Bound values are added to ``params`` as a side effect.
        """
        if not filters:
            return ""
        conditions = [
            self._compile_filter(f, cols_by_id, params, param_seq, index=i)
            for i, f in enumerate(filters)
        ]
        return "WHERE " + " AND ".join(conditions)

    def _compile_filter(
        self,
        f: FilterClause,
        cols_by_id: dict[str, ColRef],
        params: dict[str, Any],
        param_seq: list[int],
        index: int,
    ) -> str:
        """Compile one filter into a SQL condition with bound parameters.

        Raises:
            SqlCompilerError: If the column is unknown, the value has the
                wrong shape for the operator, or the operator is unsupported.
        """
        ref = self._require_col(cols_by_id, f.column_id, f"filters[{index}]")
        col_ref = self._qcol(ref)
        op = f.op

        if op == "is_null":
            return f"{col_ref} IS NULL"
        if op == "is_not_null":
            return f"{col_ref} IS NOT NULL"

        if op in _LIST_OPS:
            if not isinstance(f.value, list) or not f.value:
                raise SqlCompilerError(
                    f"filters[{index}]: op {op!r} requires a non-empty list value"
                )
            # One placeholder per list element, bound in list order.
            placeholders = [
                ":" + self._next_param(params, param_seq, v) for v in f.value
            ]
            sql_op = "IN" if op == "in" else "NOT IN"
            return f"{col_ref} {sql_op} ({', '.join(placeholders)})"

        if op == "between":
            if not isinstance(f.value, list) or len(f.value) != 2:
                raise SqlCompilerError(
                    f"filters[{index}]: op 'between' requires a list of two values"
                )
            lo = self._next_param(params, param_seq, f.value[0])
            hi = self._next_param(params, param_seq, f.value[1])
            return f"{col_ref} BETWEEN :{lo} AND :{hi}"

        if op == "like":
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} LIKE :{p}"

        if op in _COMPARISON_OPS:
            # Safe to interpolate: op is one of the whitelisted operators.
            p = self._next_param(params, param_seq, f.value)
            return f"{col_ref} {op} :{p}"

        # Anything not matched above is not supported by this compiler.
        raise SqlCompilerError(f"filters[{index}]: unhandled op {op!r}")

    def _build_groupby(
        self, group_by: list[str], cols_by_id: dict[str, ColRef]
    ) -> str:
        """Build the GROUP BY clause from column ids; "" when none.

        Raises:
            SqlCompilerError: If a column id is unknown.
        """
        if not group_by:
            return ""
        columns = [
            self._qcol(self._require_col(cols_by_id, col_id, f"group_by[{i}]"))
            for i, col_id in enumerate(group_by)
        ]
        return "GROUP BY " + ", ".join(columns)

    def _build_orderby(
        self,
        order_by: list[OrderByClause],
        cols_by_id: dict[str, ColRef],
        select_aliases: set[str],
    ) -> str:
        """Build the ORDER BY clause; "" when none.

        Each entry's ``column_id`` is resolved as a query column first and,
        failing that, as a select alias.

        Raises:
            SqlCompilerError: If an entry matches neither a column nor an
                alias, or its direction is not asc/desc (any letter case).
        """
        if not order_by:
            return ""
        terms: list[str] = []
        for i, ob in enumerate(order_by):
            if ob.column_id in cols_by_id:
                ref = self._qcol(cols_by_id[ob.column_id])
            elif ob.column_id in select_aliases:
                ref = self._qident(ob.column_id)
            else:
                raise SqlCompilerError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not in query "
                    "columns or select aliases"
                )
            # The direction is written into the SQL text unquoted, so only
            # the two known keywords are accepted.
            direction = ob.dir.upper()
            if direction not in self._ORDER_DIRECTIONS:
                raise SqlCompilerError(
                    f"order_by[{i}].dir: {ob.dir!r} must be 'asc' or 'desc'"
                )
            terms.append(f"{ref} {direction}")
        return "ORDER BY " + ", ".join(terms)

    def _build_limit(self, limit: int | None) -> tuple[str, int]:
        """Return (LIMIT clause, row_cap).

        Always bounded. An explicit IR limit is honored exactly (capped at
        MAX_RESULT_ROWS). When the IR has no limit we still emit
        `LIMIT MAX_RESULT_ROWS + 1` — the extra row lets the executor tell
        "exactly the cap" from "more rows existed" and flag truncation.

        Raises:
            SqlCompilerError: If ``limit`` is negative, which Postgres would
                only reject later at execution time.
        """
        if limit is None:
            return f"LIMIT {MAX_RESULT_ROWS + 1}", MAX_RESULT_ROWS
        requested = int(limit)
        if requested < 0:
            raise SqlCompilerError(f"limit must be non-negative, got {limit!r}")
        row_cap = min(requested, MAX_RESULT_ROWS)
        return f"LIMIT {row_cap}", row_cap

    @staticmethod
    def _next_param(
        params: dict[str, Any], param_seq: list[int], value: Any
    ) -> str:
        """Bind ``value`` under the next ``p_N`` name and return that name.

        ``param_seq[0]`` holds the next index and is incremented in place;
        ``params`` receives the new name -> value entry.
        """
        name = f"p_{param_seq[0]}"
        param_seq[0] += 1
        params[name] = value
        return name

    @staticmethod
    def _require_col(
        cols_by_id: dict[str, ColRef], col_id: str, where: str
    ) -> ColRef:
        """Return the reference for ``col_id``.

        Args:
            cols_by_id: Column index built by ``_lookup``.
            col_id: Column id to resolve.
            where: IR location (e.g. ``"filters[0]"``) used in the error text.

        Raises:
            SqlCompilerError: If ``col_id`` is not in ``cols_by_id``.
        """
        ref = cols_by_id.get(col_id)
        if ref is None:
            raise SqlCompilerError(f"{where}.column_id: {col_id!r} not in query tables")
        return ref
|
|