File size: 9,615 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""PandasCompiler β€” IR β†’ callable that runs against a DataFrame.

For tabular sources. The callable encapsulates the chain of operations
(filter → select/agg → sort → limit) so the executor can apply them
to a DataFrame loaded from a Parquet blob.

Returns a `CompiledPandas` dataclass (mirrors `CompiledSql`) whose `.apply`
is a pure function `(pd.DataFrame) -> pd.DataFrame`. No LLM, no I/O.
"""

from __future__ import annotations

import re
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import pandas as pd

from ...catalog.models import Catalog, Column, Source, Table
from ..ir.models import AggSelect, ColumnSelect, FilterClause, OrderByClause, QueryIR, SelectItem
from .base import BaseCompiler


@dataclass
class CompiledPandas:
    """Compiled IR as a pandas operation chain.

    `apply(df)` executes the full filter → select/agg → sort → limit
    pipeline and returns the result as a new DataFrame.

    `output_columns` lists the expected column names so callers can label
    an empty result without inspecting rows.
    """

    # Closure built by `PandasCompiler.compile`; takes the source DataFrame
    # and returns the query result with a fresh 0..n-1 index.
    apply: Callable[[pd.DataFrame], pd.DataFrame]
    # One label per select item, in select order: the alias when given,
    # otherwise the catalog column name or the generated aggregate name.
    output_columns: list[str]


class PandasCompilerError(Exception):
    """Raised when a QueryIR cannot be compiled into a pandas operation chain."""


class PandasCompiler(BaseCompiler):
    """Deterministic IR → pandas op chain. No LLM."""

    def __init__(self, catalog: Catalog) -> None:
        """Keep a reference to the catalog used to resolve source/table/column ids."""
        self._catalog = catalog

    def compile(self, ir: QueryIR) -> CompiledPandas:
        """Build a `CompiledPandas` whose `apply` runs `ir` against a DataFrame.

        The pipeline order is filter, then select or aggregate, then sort,
        then limit; the result always has a fresh 0..n-1 index.

        Raises:
            PandasCompilerError: if the IR's source or table is not in the catalog.
        """
        _source, _table, columns = self._lookup(ir)
        labels = _output_column_names(ir.select, columns)

        # Bind the IR fields to locals so the closure does not hold on to `ir`.
        filters, select = ir.filters, ir.select
        group_by, order_by = ir.group_by, ir.order_by
        limit = ir.limit

        def apply(df: pd.DataFrame) -> pd.DataFrame:
            frame = _apply_filters(df, filters, columns)

            # Any aggregate in the select list switches the whole query to
            # the aggregation path.
            if any(isinstance(item, AggSelect) for item in select):
                frame = _apply_agg(frame, select, group_by, columns)
            else:
                frame = _apply_select(frame, select, columns)

            if order_by:
                frame = _apply_orderby(frame, order_by, select, columns)

            if limit is not None:
                frame = frame.head(limit)

            return frame.reset_index(drop=True)

        return CompiledPandas(apply=apply, output_columns=labels)

    # ------------------------------------------------------------------
    # Catalog lookup (mirrors SqlCompiler._lookup)
    # ------------------------------------------------------------------

    def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]:
        """Resolve the IR's source and table and index the table's columns by id.

        Raises:
            PandasCompilerError: if the source id or table id is unknown.
        """
        for source in self._catalog.sources:
            if source.source_id == ir.source_id:
                break
        else:
            raise PandasCompilerError(f"source_id {ir.source_id!r} not in catalog")

        for table in source.tables:
            if table.table_id == ir.table_id:
                break
        else:
            raise PandasCompilerError(
                f"table_id {ir.table_id!r} not in source {ir.source_id!r}"
            )

        columns_by_id = {column.column_id: column for column in table.columns}
        return source, table, columns_by_id


# ---------------------------------------------------------------------------
# Module-level helpers (pure functions — easier to test in isolation)
# ---------------------------------------------------------------------------

def _output_column_names(select: list[SelectItem], cols_by_id: dict[str, Column]) -> list[str]:
    """Return the output label of every select item, in select order.

    A plain column selection is labelled by its alias, falling back to the
    catalog column name; any other item is labelled by `_agg_output_name`.
    """
    return [
        (item.alias or cols_by_id[item.column_id].name)
        if isinstance(item, ColumnSelect)
        else _agg_output_name(item, cols_by_id)
        for item in select
    ]


def _agg_output_name(s: AggSelect, cols_by_id: dict[str, Column]) -> str:
    """Return the output label for an aggregate select item.

    Precedence: a non-empty alias; then the literal "count" for a count
    with no column; otherwise "<fn>_<column name>".
    """
    alias = s.alias
    if alias:
        return alias
    is_count_star = s.fn == "count" and s.column_id is None
    return "count" if is_count_star else f"{s.fn}_{cols_by_id[s.column_id].name}"


def _like_to_regex(pattern: str) -> str:
    """Translate a SQL LIKE pattern into a Python regex string.

    `%` becomes `.*`, `_` becomes `.`, and every other character is passed
    through `re.escape`. No anchors are added — match with `fullmatch`.
    """
    wildcards = {"%": ".*", "_": "."}
    return "".join(
        wildcards[ch] if ch in wildcards else re.escape(ch)
        for ch in pattern
    )


def _apply_filters(
    df: pd.DataFrame,
    filters: list[FilterClause],
    cols_by_id: dict[str, Column],
) -> pd.DataFrame:
    if not filters:
        return df
    mask = pd.Series(True, index=df.index)
    for f in filters:
        col_name = cols_by_id[f.column_id].name
        series = df[col_name]
        op, val = f.op, f.value
        if op == "=":
            mask &= series == val
        elif op == "!=":
            mask &= series != val
        elif op == "<":
            mask &= series < val
        elif op == "<=":
            mask &= series <= val
        elif op == ">":
            mask &= series > val
        elif op == ">=":
            mask &= series >= val
        elif op == "in":
            mask &= series.isin(val)
        elif op == "not_in":
            mask &= ~series.isin(val)
        elif op == "is_null":
            mask &= series.isna()
        elif op == "is_not_null":
            mask &= series.notna()
        elif op == "like":
            mask &= series.astype(str).str.fullmatch(_like_to_regex(val), case=True, na=False)
        elif op == "between":
            mask &= (series >= val[0]) & (series <= val[1])
    return df[mask].copy()


def _apply_select(
    df: pd.DataFrame,
    select: list[SelectItem],
    cols_by_id: dict[str, Column],
) -> pd.DataFrame:
    """Project `df` onto the plain column selections, applying aliases.

    Args:
        df: Frame to project; it is not modified.
        select: Select items; only `ColumnSelect` entries are used.
        cols_by_id: Catalog columns keyed by column id; `.name` is the
            DataFrame column label.

    Returns:
        A new frame with one column per `ColumnSelect`, in select order,
        labelled with the alias when given and the column name otherwise.

    Raises:
        KeyError: if a selected column id is missing from `cols_by_id`, or
            the column is missing from `df`.
    """
    picked = [s for s in select if isinstance(s, ColumnSelect)]
    source_names = [cols_by_id[s.column_id].name for s in picked]
    result = df[source_names].copy()
    # Label positionally rather than through a {source name: alias} map.
    # Such a map keeps only one alias per source column, so selecting the
    # same column twice under different aliases gave both outputs the same
    # label and disagreed with `_output_column_names`.
    result.columns = [s.alias or name for s, name in zip(picked, source_names)]
    return result


def _scalar_agg(df: pd.DataFrame, s: AggSelect, cols_by_id: dict[str, Column]) -> Any:
    """Compute one aggregate over the whole of `df` (no grouping).

    A count with no column returns the row count. Every other function is
    applied to the column named by `s.column_id`; both count variants are
    returned as plain `int`.

    Raises:
        PandasCompilerError: if `s.fn` is not a supported aggregate.
    """
    if s.fn == "count" and s.column_id is None:
        return int(len(df))

    series = df[cols_by_id[s.column_id].name]
    reducers: dict[str, Callable[[], Any]] = {
        "count": lambda: int(series.count()),
        "count_distinct": lambda: int(series.nunique()),
        "sum": series.sum,
        "avg": series.mean,
        "min": series.min,
        "max": series.max,
    }
    reducer = reducers.get(s.fn)
    if reducer is None:
        raise PandasCompilerError(f"unhandled agg fn {s.fn!r}")
    return reducer()


def _group_agg_series(
    grouped: Any,
    s: AggSelect,
    cols_by_id: dict[str, Column],
) -> pd.Series:
    """Compute one aggregate per group of `grouped`.

    A count with no column returns the group sizes; every other function is
    applied to the grouped column named by `s.column_id`.

    Raises:
        PandasCompilerError: if `s.fn` is not a supported aggregate.
    """
    if s.fn == "count" and s.column_id is None:
        return grouped.size()

    col_name = cols_by_id[s.column_id].name
    # IR aggregate name -> name of the pandas groupby reduction method.
    method_by_fn = {
        "count": "count",
        "count_distinct": "nunique",
        "sum": "sum",
        "avg": "mean",
        "min": "min",
        "max": "max",
    }
    method_name = method_by_fn.get(s.fn)
    if method_name is None:
        raise PandasCompilerError(f"unhandled agg fn {s.fn!r}")
    return getattr(grouped[col_name], method_name)()


def _apply_agg(
    df: pd.DataFrame,
    select: list[SelectItem],
    group_by: list[str],
    cols_by_id: dict[str, Column],
) -> pd.DataFrame:
    """Evaluate the aggregate select items, optionally per group.

    Args:
        df: Already-filtered frame; it is not modified.
        select: Select items. `AggSelect` entries are computed;
            `ColumnSelect` entries only contribute aliases for grouping
            columns.
        group_by: Column ids to group on; empty means one global aggregate.
        cols_by_id: Catalog columns keyed by column id; `.name` is the
            DataFrame column label.

    Returns:
        Without grouping: a single-row frame with one column per aggregate.
        With grouping: one row per group, in order of first appearance.
        Selected outputs come first in SELECT order, so the frame lines up
        with `_output_column_names`; grouping columns that were not
        selected are kept after them.

    Raises:
        PandasCompilerError: if an aggregate function is not supported.
        KeyError: if a referenced column id is missing from `cols_by_id`.
    """
    agg_items = [s for s in select if isinstance(s, AggSelect)]
    col_items = [s for s in select if isinstance(s, ColumnSelect)]
    group_col_names = [cols_by_id[col_id].name for col_id in group_by]

    if not group_col_names:
        row = {
            _agg_output_name(s, cols_by_id): _scalar_agg(df, s, cols_by_id)
            for s in agg_items
        }
        return pd.DataFrame([row])

    # dropna=False keeps rows whose grouping key is null as their own group,
    # as SQL GROUP BY does; the pandas default silently drops those rows.
    grouped = df.groupby(group_col_names, sort=False, dropna=False)
    series_list = [
        _group_agg_series(grouped, s, cols_by_id).rename(_agg_output_name(s, cols_by_id))
        for s in agg_items
    ]
    result = pd.concat(series_list, axis=1).reset_index()

    rename_map = {
        cols_by_id[s.column_id].name: s.alias
        for s in col_items
        if s.alias
    }
    if rename_map:
        result = result.rename(columns=rename_map)

    # reset_index() always yields "group columns, then aggregates". Put the
    # selected outputs first in SELECT order so the frame agrees with
    # `_output_column_names`. Grouping columns that were not selected are
    # kept at the end rather than dropped, so no previously returned data
    # disappears.
    available = set(result.columns)
    selected = list(
        dict.fromkeys(
            name
            for name in _output_column_names(select, cols_by_id)
            if name in available
        )
    )
    trailing = list(dict.fromkeys(c for c in result.columns if c not in selected))
    return result[selected + trailing]


def _resolve_order_col(
    col_id_or_alias: str,
    select: list[SelectItem],
    cols_by_id: dict[str, Column],
) -> str:
    """Map an order_by column_id (or alias) to the actual output column name.

    The first select item whose `column_id` equals the given value decides
    the name: alias-or-column-name for a plain column, the aggregate output
    name for an aggregate. With no match the value is returned unchanged and
    is treated as an alias / output name.
    """
    candidates = (
        (item.alias or cols_by_id[item.column_id].name)
        if isinstance(item, ColumnSelect)
        else _agg_output_name(item, cols_by_id)
        for item in select
        if isinstance(item, (ColumnSelect, AggSelect))
        and item.column_id == col_id_or_alias
    )
    # The generator is lazy, so only the first matching item is evaluated.
    return next(candidates, col_id_or_alias)


def _apply_orderby(
    df: pd.DataFrame,
    order_by: list[OrderByClause],
    select: list[SelectItem],
    cols_by_id: dict[str, Column],
) -> pd.DataFrame:
    """Sort `df` by the order_by clauses that resolve to an existing column.

    Each clause is resolved to an output column name via
    `_resolve_order_col`; clauses whose name is not a column of `df` are
    skipped. A clause sorts ascending only when its `dir` is "asc". When no
    clause is usable, `df` is returned as-is.
    """
    sort_keys = [
        (name, clause.dir == "asc")
        for clause in order_by
        if (name := _resolve_order_col(clause.column_id, select, cols_by_id)) in df.columns
    ]
    if not sort_keys:
        return df
    return df.sort_values(
        by=[name for name, _ in sort_keys],
        ascending=[is_asc for _, is_asc in sort_keys],
    )