"""PandasCompiler — IR → callable that runs against a DataFrame. For tabular sources. The callable encapsulates the chain of operations (filter → select/agg → sort → limit) so the executor can apply them to a DataFrame loaded from a Parquet blob. Returns a `CompiledPandas` dataclass (mirrors `CompiledSql`) whose `.apply` is a pure function `(pd.DataFrame) -> pd.DataFrame`. No LLM, no I/O. """ from __future__ import annotations import re from collections.abc import Callable from dataclasses import dataclass from typing import Any import pandas as pd from ...catalog.models import Catalog, Column, Source, Table from ..ir.models import AggSelect, ColumnSelect, FilterClause, OrderByClause, QueryIR, SelectItem from .base import BaseCompiler @dataclass class CompiledPandas: """Compiled IR as a pandas operation chain. `apply(df)` executes the full filter → select/agg → sort → limit pipeline and returns the result as a new DataFrame. `output_columns` lists the expected column names so callers can label an empty result without inspecting rows. """ apply: Callable[[pd.DataFrame], pd.DataFrame] output_columns: list[str] class PandasCompilerError(Exception): pass class PandasCompiler(BaseCompiler): """Deterministic IR → pandas op chain. No LLM.""" def __init__(self, catalog: Catalog) -> None: self._catalog = catalog def compile(self, ir: QueryIR) -> CompiledPandas: _, table, cols_by_id = self._lookup(ir) output_columns = _output_column_names(ir.select, cols_by_id) # Capture IR fields explicitly so the closure is self-contained _filters = ir.filters _select = ir.select _group_by = ir.group_by _order_by = ir.order_by _limit = ir.limit _cols = cols_by_id def apply(df: pd.DataFrame) -> pd.DataFrame: df = _apply_filters(df, _filters, _cols) has_agg = any(isinstance(s, AggSelect) for s in _select) if has_agg: df = _apply_agg(df, _select, _group_by, _cols) else: df = _apply_select(df, _select, _cols) if _order_by: df = _apply_orderby(df, _order_by, _select, _cols) if _limit is not None: df = df.head(_limit) return df.reset_index(drop=True) return CompiledPandas(apply=apply, output_columns=output_columns) # ------------------------------------------------------------------ # Catalog lookup (mirrors SqlCompiler._lookup) # ------------------------------------------------------------------ def _lookup(self, ir: QueryIR) -> tuple[Source, Table, dict[str, Column]]: source = next((s for s in self._catalog.sources if s.source_id == ir.source_id), None) if source is None: raise PandasCompilerError(f"source_id {ir.source_id!r} not in catalog") table = next((t for t in source.tables if t.table_id == ir.table_id), None) if table is None: raise PandasCompilerError( f"table_id {ir.table_id!r} not in source {ir.source_id!r}" ) return source, table, {c.column_id: c for c in table.columns} # --------------------------------------------------------------------------- # Module-level helpers (pure functions — easier to test in isolation) # --------------------------------------------------------------------------- def _output_column_names(select: list[SelectItem], cols_by_id: dict[str, Column]) -> list[str]: names = [] for s in select: if isinstance(s, ColumnSelect): names.append(s.alias or cols_by_id[s.column_id].name) else: names.append(_agg_output_name(s, cols_by_id)) return names def _agg_output_name(s: AggSelect, cols_by_id: dict[str, Column]) -> str: if s.alias: return s.alias if s.fn == "count" and s.column_id is None: return "count" return f"{s.fn}_{cols_by_id[s.column_id].name}" def _like_to_regex(pattern: str) -> str: """Convert SQL LIKE pattern to Python regex string (no anchors — use fullmatch).""" parts: list[str] = [] for ch in pattern: if ch == "%": parts.append(".*") elif ch == "_": parts.append(".") else: parts.append(re.escape(ch)) return "".join(parts) def _apply_filters( df: pd.DataFrame, filters: list[FilterClause], cols_by_id: dict[str, Column], ) -> pd.DataFrame: if not filters: return df mask = pd.Series(True, index=df.index) for f in filters: col_name = cols_by_id[f.column_id].name series = df[col_name] op, val = f.op, f.value if op == "=": mask &= series == val elif op == "!=": mask &= series != val elif op == "<": mask &= series < val elif op == "<=": mask &= series <= val elif op == ">": mask &= series > val elif op == ">=": mask &= series >= val elif op == "in": mask &= series.isin(val) elif op == "not_in": mask &= ~series.isin(val) elif op == "is_null": mask &= series.isna() elif op == "is_not_null": mask &= series.notna() elif op == "like": mask &= series.astype(str).str.fullmatch(_like_to_regex(val), case=True, na=False) elif op == "between": mask &= (series >= val[0]) & (series <= val[1]) return df[mask].copy() def _apply_select( df: pd.DataFrame, select: list[SelectItem], cols_by_id: dict[str, Column], ) -> pd.DataFrame: col_names = [cols_by_id[s.column_id].name for s in select if isinstance(s, ColumnSelect)] result = df[col_names].copy() rename_map = { cols_by_id[s.column_id].name: s.alias for s in select if isinstance(s, ColumnSelect) and s.alias } if rename_map: result = result.rename(columns=rename_map) return result def _scalar_agg(df: pd.DataFrame, s: AggSelect, cols_by_id: dict[str, Column]) -> Any: if s.fn == "count" and s.column_id is None: return int(len(df)) col_name = cols_by_id[s.column_id].name series = df[col_name] match s.fn: case "count": return int(series.count()) case "count_distinct": return int(series.nunique()) case "sum": return series.sum() case "avg": return series.mean() case "min": return series.min() case "max": return series.max() raise PandasCompilerError(f"unhandled agg fn {s.fn!r}") def _group_agg_series( grouped: Any, s: AggSelect, cols_by_id: dict[str, Column], ) -> pd.Series: if s.fn == "count" and s.column_id is None: return grouped.size() col_name = cols_by_id[s.column_id].name match s.fn: case "count": return grouped[col_name].count() case "count_distinct": return grouped[col_name].nunique() case "sum": return grouped[col_name].sum() case "avg": return grouped[col_name].mean() case "min": return grouped[col_name].min() case "max": return grouped[col_name].max() raise PandasCompilerError(f"unhandled agg fn {s.fn!r}") def _apply_agg( df: pd.DataFrame, select: list[SelectItem], group_by: list[str], cols_by_id: dict[str, Column], ) -> pd.DataFrame: agg_items = [s for s in select if isinstance(s, AggSelect)] col_items = [s for s in select if isinstance(s, ColumnSelect)] group_col_names = [cols_by_id[col_id].name for col_id in group_by] if group_col_names: grouped = df.groupby(group_col_names, sort=False) series_list = [ _group_agg_series(grouped, s, cols_by_id).rename(_agg_output_name(s, cols_by_id)) for s in agg_items ] result = pd.concat(series_list, axis=1).reset_index() rename_map = { cols_by_id[s.column_id].name: s.alias for s in col_items if s.alias } if rename_map: result = result.rename(columns=rename_map) else: row = { _agg_output_name(s, cols_by_id): _scalar_agg(df, s, cols_by_id) for s in agg_items } result = pd.DataFrame([row]) return result def _resolve_order_col( col_id_or_alias: str, select: list[SelectItem], cols_by_id: dict[str, Column], ) -> str: """Map an order_by column_id (or alias) to the actual output column name.""" for s in select: if isinstance(s, ColumnSelect) and s.column_id == col_id_or_alias: return s.alias or cols_by_id[s.column_id].name if isinstance(s, AggSelect) and s.column_id == col_id_or_alias: return _agg_output_name(s, cols_by_id) return col_id_or_alias # treat as alias / output name directly def _apply_orderby( df: pd.DataFrame, order_by: list[OrderByClause], select: list[SelectItem], cols_by_id: dict[str, Column], ) -> pd.DataFrame: sort_cols: list[str] = [] ascending: list[bool] = [] for ob in order_by: out_name = _resolve_order_col(ob.column_id, select, cols_by_id) if out_name in df.columns: sort_cols.append(out_name) ascending.append(ob.dir == "asc") if sort_cols: return df.sort_values(by=sort_cols, ascending=ascending) return df