File size: 5,060 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""IRValidator — checks a QueryIR against a user's catalog.

See ARCHITECTURE.md §7 for the validation rules. On failure, the planner
is re-prompted with the error context (max 3 retries) — error messages
must therefore be specific enough that the LLM can self-correct.
"""

from ...catalog.models import Catalog, Column, Source, Table
from .models import QueryIR
from .operators import (
    ALLOWED_AGG_FNS,
    ALLOWED_FILTER_OPS,
    LIMIT_HARD_CAP,
    TYPE_COMPATIBILITY,
)

_NULLARY_FILTER_OPS = frozenset({"is_null", "is_not_null"})


class IRValidationError(Exception):
    """Raised by IRValidator when a QueryIR violates a catalog or whitelist rule.

    The message is fed back to the planner on retry, so it should name the
    offending IR path and the accepted alternatives.
    """


class IRValidator:
    """Reject IRs that reference unknown sources/tables/columns or use disallowed ops.

    Rules:
    - source_id exists in catalog for this user
    - table_id belongs to that source
    - every column_id exists in that table
    - every agg.fn and filter.op is whitelisted (see operators.py)
    - value_type consistent with column.data_type (TYPE_COMPATIBILITY)
    - limit positive int, ≤ LIMIT_HARD_CAP

    Sections are checked in this order: source, table, select, filters,
    group_by, order_by, limit. The first violation raises IRValidationError
    with a message naming the offending IR path (e.g. ``filters[2].op``).
    """

    def validate(self, ir: QueryIR, catalog: Catalog) -> None:
        """Raise IRValidationError on the first rule ``ir`` violates against ``catalog``.

        Returns None when the IR is valid. Error messages embed the offending
        path and the known alternatives so a re-prompted planner can fix it.
        """
        source = self._find_source(catalog, ir.source_id)
        table = self._find_table(source, ir.table_id)
        columns_by_id: dict[str, Column] = {c.column_id: c for c in table.columns}

        select_aliases = self._check_select(ir, columns_by_id)
        self._check_filters(ir, columns_by_id)
        self._check_group_by(ir, columns_by_id)
        self._check_order_by(ir, columns_by_id, select_aliases)
        self._check_limit(ir)

    def _check_select(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> set[str]:
        """Validate ``ir.select`` and return the set of aliases it declares.

        Items of kind "column" must name an existing column. Every other item
        is treated as an aggregate: ``fn`` must be whitelisted, and a
        column_id is required unless ``fn`` is "count" (COUNT(*)).
        """
        select_aliases: set[str] = set()
        for i, item in enumerate(ir.select):
            where = f"select[{i}]"
            if item.kind == "column":
                self._require_column(columns_by_id, item.column_id, where)
            else:  # "agg"
                if item.fn not in ALLOWED_AGG_FNS:
                    raise IRValidationError(
                        f"{where}.fn: must be in {sorted(ALLOWED_AGG_FNS)}, "
                        f"got {item.fn!r}"
                    )
                if item.column_id is not None:
                    self._require_column(columns_by_id, item.column_id, where)
                elif item.fn != "count":
                    raise IRValidationError(
                        f"{where}.fn={item.fn!r} requires a column_id "
                        "(only 'count' may omit it for COUNT(*))"
                    )
            # Aliases are collected so order_by may reference them.
            if item.alias:
                select_aliases.add(item.alias)
        return select_aliases

    def _check_filters(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> None:
        """Validate each filter's column, op whitelist and value_type compatibility."""
        for i, f in enumerate(ir.filters):
            where = f"filters[{i}]"
            col = self._require_column(columns_by_id, f.column_id, where)
            if f.op not in ALLOWED_FILTER_OPS:
                raise IRValidationError(
                    f"{where}.op: must be in {sorted(ALLOWED_FILTER_OPS)}, "
                    f"got {f.op!r}"
                )
            # NULL-test ops are exempt from the value_type check.
            if f.op not in _NULLARY_FILTER_OPS:
                # An unknown data_type maps to an empty set, so any value_type fails.
                allowed = TYPE_COMPATIBILITY.get(col.data_type, frozenset())
                if f.value_type not in allowed:
                    raise IRValidationError(
                        f"{where}: value_type {f.value_type!r} incompatible with "
                        f"column.data_type {col.data_type!r} "
                        f"(allowed: {sorted(allowed)})"
                    )

    def _check_group_by(
        self, ir: QueryIR, columns_by_id: dict[str, Column]
    ) -> None:
        """Require every group_by entry to name an existing column."""
        for i, col_id in enumerate(ir.group_by):
            self._require_column(columns_by_id, col_id, f"group_by[{i}]")

    @staticmethod
    def _check_order_by(
        ir: QueryIR, columns_by_id: dict[str, Column], select_aliases: set[str]
    ) -> None:
        """Require every order_by target to be a table column or a select alias."""
        for i, ob in enumerate(ir.order_by):
            if ob.column_id not in columns_by_id and ob.column_id not in select_aliases:
                raise IRValidationError(
                    f"order_by[{i}].column_id: {ob.column_id!r} not found in table "
                    f"{ir.table_id!r} columns or select aliases "
                    f"(known columns: {sorted(columns_by_id.keys())}, "
                    f"aliases: {sorted(select_aliases)})"
                )

    @staticmethod
    def _check_limit(ir: QueryIR) -> None:
        """Require ``ir.limit``, when set, to be an int in [1, LIMIT_HARD_CAP]."""
        if ir.limit is None:
            return
        # Enforce the documented "int" part of the rule: a float such as 10.5
        # would otherwise pass both range checks below. bool is an int subclass,
        # so it is rejected explicitly.
        if isinstance(ir.limit, bool) or not isinstance(ir.limit, int):
            raise IRValidationError(
                f"limit must be a positive integer, got {ir.limit!r}"
            )
        if ir.limit <= 0:
            raise IRValidationError(f"limit must be positive, got {ir.limit}")
        if ir.limit > LIMIT_HARD_CAP:
            raise IRValidationError(
                f"limit {ir.limit} exceeds hard cap {LIMIT_HARD_CAP}"
            )

    @staticmethod
    def _find_source(catalog: Catalog, source_id: str) -> Source:
        """Return the catalog source with ``source_id`` or raise IRValidationError."""
        for s in catalog.sources:
            if s.source_id == source_id:
                return s
        raise IRValidationError(
            f"source_id {source_id!r} not in catalog "
            f"(known: {[s.source_id for s in catalog.sources]})"
        )

    @staticmethod
    def _find_table(source: Source, table_id: str) -> Table:
        """Return the table in ``source`` with ``table_id`` or raise IRValidationError."""
        for t in source.tables:
            if t.table_id == table_id:
                return t
        raise IRValidationError(
            f"table_id {table_id!r} not in source {source.source_id!r} "
            f"(known: {[t.table_id for t in source.tables]})"
        )

    @staticmethod
    def _require_column(
        columns_by_id: dict[str, Column], col_id: str, where: str
    ) -> Column:
        """Return the column for ``col_id``; ``where`` prefixes the error path."""
        col = columns_by_id.get(col_id)
        if col is None:
            raise IRValidationError(
                f"{where}.column_id: {col_id!r} not in table "
                f"(known: {sorted(columns_by_id.keys())})"
            )
        return col