"""IRValidator — checks a QueryIR against a user's catalog.

See ARCHITECTURE.md §7 for the validation rules. On failure, the planner
is re-prompted with the error context (max 3 retries) — error messages
must therefore be specific enough that the LLM can self-correct.
"""
from ...catalog.models import Catalog, Column, Source, Table
from .models import QueryIR
from .operators import (
ALLOWED_AGG_FNS,
ALLOWED_FILTER_OPS,
LIMIT_HARD_CAP,
TYPE_COMPATIBILITY,
)
_NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"})
class IRValidationError(Exception):
    """Raised by IRValidator when a QueryIR violates a validation rule."""
class IRValidator:
    """Check a QueryIR against a catalog; raise IRValidationError on the first violation.

    Checks run in this order:
      1. ``source_id`` names a source in the catalog
      2. ``table_id`` names a table of that source
      3. select items: referenced columns exist in the table; an agg ``fn``
         is in ALLOWED_AGG_FNS, and only ``count`` may omit ``column_id``
      4. filters: the column exists, ``op`` is in ALLOWED_FILTER_OPS, and
         ``value_type`` is permitted by TYPE_COMPATIBILITY for the column's
         ``data_type`` (not checked for ops in _NULLARY_FILTER_OPS)
      5. every group_by column exists in the table
      6. every order_by entry names a table column or a select alias
      7. limit, when set, is positive and ≤ LIMIT_HARD_CAP

    Error messages name the offending IR path and list the known/allowed
    values.
    """

    def validate(self, ir: QueryIR, catalog: Catalog) -> None:
        """Validate ``ir`` against ``catalog``.

        Returns None when every check passes.

        Raises:
            IRValidationError: on the first rule that fails.
        """
        source = self._find_source(catalog, ir.source_id)
        table = self._find_table(source, ir.table_id)
        columns_by_id: dict[str, Column] = {c.column_id: c for c in table.columns}

        # Aliases are gathered while walking the select list so that
        # order_by may refer to them later.
        select_aliases: set[str] = set()
        for i, item in enumerate(ir.select):
            self._check_select_item(columns_by_id, item, f"select[{i}]")
            if item.alias:
                select_aliases.add(item.alias)

        for i, f in enumerate(ir.filters):
            self._check_filter(columns_by_id, f, f"filters[{i}]")

        for i, col_id in enumerate(ir.group_by):
            self._require_column(columns_by_id, col_id, f"group_by[{i}]")

        for i, ob in enumerate(ir.order_by):
            if ob.column_id in columns_by_id or ob.column_id in select_aliases:
                continue
            raise IRValidationError(
                f"order_by[{i}].column_id: {ob.column_id!r} not found in table "
                f"{ir.table_id!r} columns or select aliases "
                f"(known columns: {sorted(columns_by_id.keys())}, "
                f"aliases: {sorted(select_aliases)})"
            )

        if ir.limit is None:
            return
        if ir.limit <= 0:
            raise IRValidationError(f"limit must be positive, got {ir.limit}")
        if ir.limit > LIMIT_HARD_CAP:
            raise IRValidationError(
                f"limit {ir.limit} exceeds hard cap {LIMIT_HARD_CAP}"
            )

    def _check_select_item(
        self, columns_by_id: dict[str, Column], item, where: str
    ) -> None:
        """Validate one select entry; ``where`` is its path used in error messages."""
        if item.kind == "column":
            self._require_column(columns_by_id, item.column_id, where)
            return

        # Any kind other than "column" is handled as an aggregate.
        if item.fn not in ALLOWED_AGG_FNS:
            raise IRValidationError(
                f"{where}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, "
                f"got {item.fn!r}"
            )
        if item.column_id is None:
            # A missing column is only acceptable for COUNT(*).
            if item.fn != "count":
                raise IRValidationError(
                    f"{where}.fn={item.fn!r} requires a column_id "
                    "(only 'count' may omit it for COUNT(*))"
                )
            return
        self._require_column(columns_by_id, item.column_id, where)

    def _check_filter(
        self, columns_by_id: dict[str, Column], f, where: str
    ) -> None:
        """Validate one filter entry; ``where`` is its path used in error messages."""
        # The column is resolved before the operator is inspected.
        col = self._require_column(columns_by_id, f.column_id, where)
        if f.op not in ALLOWED_FILTER_OPS:
            raise IRValidationError(
                f"{where}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, "
                f"got {f.op!r}"
            )
        if f.op in _NULLARY_FILTER_OPS:
            return

        # A data_type missing from TYPE_COMPATIBILITY allows no value_type.
        allowed = TYPE_COMPATIBILITY.get(col.data_type, frozenset())
        if f.value_type in allowed:
            return
        raise IRValidationError(
            f"{where}: value_type {f.value_type!r} incompatible with "
            f"column.data_type {col.data_type!r} "
            f"(allowed: {sorted(allowed)})"
        )

    @staticmethod
    def _find_source(catalog: Catalog, source_id: str) -> Source:
        """Return the first source in ``catalog`` whose ``source_id`` matches.

        Raises:
            IRValidationError: if no source matches; the message lists the
                known source ids.
        """
        match = next((s for s in catalog.sources if s.source_id == source_id), None)
        if match is not None:
            return match
        raise IRValidationError(
            f"source_id {source_id!r} not in catalog "
            f"(known: {[s.source_id for s in catalog.sources]})"
        )

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the first table of ``source`` whose ``table_id`` matches.

        Raises:
            IRValidationError: if no table matches; the message lists the
                known table ids of that source.
        """
        match = next((t for t in source.tables if t.table_id == table_id), None)
        if match is not None:
            return match
        raise IRValidationError(
            f"table_id {table_id!r} not in source {source.source_id!r} "
            f"(known: {[t.table_id for t in source.tables]})"
        )

    @staticmethod
    def _require_column(
        columns_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column registered under ``col_id``.

        ``where`` is the IR path prefixed to the error message.

        Raises:
            IRValidationError: if ``col_id`` is not a key of
                ``columns_by_id``; the message lists the known column ids.
        """
        col = columns_by_id.get(col_id)
        if col is not None:
            return col
        raise IRValidationError(
            f"{where}.column_id: {col_id!r} not in table "
            f"(known: {sorted(columns_by_id.keys())})"
        )
|