Spaces:
Sleeping
Sleeping
| """ | |
| executor.py — Pure-Python ETL pipeline interpreter. | |
| Supports step types: cast, select, rename, filter, join, agg, dedup. | |
| Note: date_parse is not dispatched by execute_step; dates are validated via cast with to_type "DATE". | |
| """ | |
| import copy | |
| import datetime | |
| from typing import Any | |
| # --- Execution Engine --- | |
def execute_step(rows: list[dict], step: dict) -> list[dict]:
    """Apply one pipeline step to ``rows`` and return the resulting rows.

    The ``"op"`` entry of ``step`` selects the operation. The remaining
    entries are read with ``step.get`` and forwarded to the matching
    ``_op_*`` helper; optional entries fall back to the defaults shown below.

    Raises:
        ValueError: if ``"op"`` is absent or not one of the handled names.
    """
    op = step.get("op")
    target = step.get("field")

    if op == "cast":
        # null_handling is one of: error | coerce | drop
        return _op_cast(
            rows,
            target,
            step.get("to_type"),
            step.get("null_handling", "error"),
        )
    if op == "select":
        return _op_select(rows, step.get("columns", []))
    if op == "rename":
        return _op_rename(rows, step.get("mapping", {}))
    if op == "filter":
        return _op_filter(rows, target, step.get("condition"), step.get("value"))
    if op == "dedup":
        return _op_dedup(rows, step.get("subset", []), step.get("keep", "first"))
    if op == "join":
        how = step.get("join_type", "left")
        return _op_join(rows, step.get("on"), step.get("right", []), how)
    if op == "agg":
        return _op_agg(rows, step.get("group_by", []), step.get("aggregations", {}))

    raise ValueError(f"Unknown operation: {op!r}")
| def run_pipeline(config: dict, rows: list[dict]) -> tuple[list[dict], str | None]: | |
| """Execute a sequence of steps defined in config['steps'].""" | |
| data = copy.deepcopy(rows) | |
| for i, step in enumerate(config.get("steps", [])): | |
| try: | |
| data = execute_step(data, step) | |
| except Exception as e: | |
| # Warm Baseline: Return current data (last valid state) instead of empty list | |
| return data, f"Step {i} ({step.get('op')}): {type(e).__name__}: {e}" | |
| return data, None | |
| # --- Internal Ops --- | |
| def _op_cast(rows, field, to_type, null_handling): | |
| out = [] | |
| for row in rows: | |
| val = row.get(field) | |
| # Check for null-like values | |
| is_null = val is None or str(val).lower().strip() in ("none", "null", "n/a", "") | |
| if is_null: | |
| if null_handling == "coerce": | |
| out.append({**row, field: None}) | |
| elif null_handling == "drop": | |
| continue | |
| else: | |
| raise ValueError(f"invalid input {val!r} for type {to_type}") | |
| continue | |
| try: | |
| if to_type == "DATE": | |
| # Basic YYYY-MM-DD check, but keep logic flexible for tests | |
| # The prompt tests expect "2024-01-15" to pass and "2024-99-99" to fail | |
| try: | |
| # Try parsing to validate | |
| datetime.datetime.strptime(str(val).strip(), "%Y-%m-%d") | |
| out.append({**row, field: str(val).strip()}) | |
| except ValueError: | |
| raise ValueError(f"invalid date {val!r}") | |
| elif to_type == "INT": | |
| out.append({**row, field: int(float(str(val)))}) | |
| elif to_type == "FLOAT": | |
| out.append({**row, field: float(str(val))}) | |
| elif to_type == "STRING": | |
| out.append({**row, field: str(val)}) | |
| elif to_type == "BOOLEAN": | |
| out.append({**row, field: str(val).lower() in ("1", "true", "yes")}) | |
| else: | |
| out.append({**row, field: val}) | |
| except Exception as e: | |
| if null_handling == "coerce": | |
| out.append({**row, field: None}) | |
| elif null_handling == "drop": | |
| continue | |
| else: | |
| raise ValueError(f"invalid input {val!r} for type {to_type}: {e}") | |
| return out | |
| def _op_select(rows, columns): | |
| return [{col: row[col] for col in columns} for row in rows] | |
| def _op_rename(rows, mapping): | |
| return [{mapping.get(k, k): v for k, v in row.items()} for row in rows] | |
| def _op_filter(rows, field, condition, value): | |
| def check(val): | |
| if condition == "eq": return val == value | |
| if condition == "not_null": return val is not None | |
| if condition == "is_null": return val is None | |
| return True | |
| return [row for row in rows if check(row.get(field))] | |
| def _op_dedup(rows, subset, keep): | |
| seen = set() | |
| out = [] | |
| for row in rows: | |
| key = tuple(row.get(s) for s in subset) | |
| if key not in seen: | |
| out.append(row) | |
| seen.add(key) | |
| elif keep == "all": # Dummy for completeness | |
| out.append(row) | |
| return out | |
| def _op_join(rows, on, right, join_type): | |
| # Simplified left/inner join | |
| right_lookup = {} | |
| from collections import defaultdict | |
| multi_lookup = defaultdict(list) | |
| for r in right: | |
| k = r.get(on) | |
| multi_lookup[k].append(r) | |
| if k not in right_lookup: | |
| right_lookup[k] = r | |
| out = [] | |
| for row in rows: | |
| k = row.get(on) | |
| matches = multi_lookup.get(k) | |
| if matches: | |
| for m in matches: | |
| merged = {**row} | |
| for mk, mv in m.items(): | |
| if mk != on: merged[mk] = mv | |
| out.append(merged) | |
| else: | |
| if join_type == "left": | |
| merged = {**row} | |
| # Find all potential keys from right to fill with None | |
| if right: | |
| for rk in right[0].keys(): | |
| if rk != on and rk not in merged: | |
| merged[rk] = None | |
| out.append(merged) | |
| return out | |
| def _op_agg(rows, group_by, aggregations): | |
| from collections import defaultdict | |
| groups = defaultdict(list) | |
| for row in rows: | |
| key = tuple(row.get(g) for g in group_by) | |
| groups[key].append(row) | |
| out = [] | |
| for key, grouped_rows in groups.items(): | |
| base = {g: k for g, k in zip(group_by, key)} | |
| for key, agg_spec in aggregations.items(): | |
| field = agg_spec["field"] | |
| func = agg_spec["func"] | |
| # Naming Logic: use explicit output_name if provided, else fallback to default | |
| out_name = agg_spec.get("output_name", f"{field}_{func}") | |
| vals = [r.get(field) for r in grouped_rows if r.get(field) is not None] | |
| if func == "sum": | |
| base[out_name] = sum(vals) | |
| elif func == "count_distinct": | |
| base[out_name] = len(set(vals)) | |
| elif func == "count": | |
| base[out_name] = len(vals) | |
| elif func == "mean": | |
| base[out_name] = sum(vals)/len(vals) if vals else None | |
| out.append(base) | |
| return out | |
| # --- Utility --- | |
def _row_values_match(got_row: dict, expected_row: dict) -> bool:
    """Tell whether two rows hold the same multiset of values, ignoring keys.

    Numbers are compared after rounding to 3 decimal places; if that
    normalisation fails for either row, both rows are compared by plain
    ``str`` instead.
    """
    def rounded(row):
        return sorted(
            str(round(float(v), 3)) if isinstance(v, (int, float)) else str(v)
            for v in row.values()
        )

    try:
        return rounded(got_row) == rounded(expected_row)
    except (OverflowError, ValueError, TypeError):
        # e.g. an int too large for float(); fall back to raw string forms.
        return sorted(str(v) for v in got_row.values()) == sorted(
            str(v) for v in expected_row.values()
        )


def compare_output(got: list[dict], expected: list[dict]) -> dict:
    """Score ``got`` against ``expected``, comparing rows position by position.

    Returns a dict with:
        row_match: matched-row credit divided by ``len(expected)``, rounded to
            4 places. An identical row earns 1; when the schemas differ, a
            row whose values match regardless of keys earns 0.8.
        schema_match: whether the first row of each side has the same key set.
        exact_match: full row credit, matching schema and equal row counts.

    When ``expected`` is empty, all three fields reflect whether ``got`` is
    empty as well.
    """
    if not expected:
        both_empty = not got
        return {
            "row_match": 1.0 if both_empty else 0.0,
            "schema_match": both_empty,
            "exact_match": both_empty,
        }

    schema_match = bool(got) and set(got[0].keys()) == set(expected[0].keys())

    matching_rows = 0
    for got_row, expected_row in zip(got, expected):
        if got_row == expected_row:
            matching_rows += 1
        elif not schema_match and _row_values_match(got_row, expected_row):
            # 80% credit for correct values under the wrong keys.
            matching_rows += 0.8

    row_match = matching_rows / len(expected)
    return {
        "row_match": round(row_match, 4),
        "schema_match": schema_match,
        # zip() ignores surplus rows in ``got``, so the lengths are checked too.
        "exact_match": schema_match
        and row_match == 1.0
        and len(got) == len(expected),
    }
def apply_patch(config: dict, patch: dict) -> dict:
    """Return a deep copy of ``config`` with one step entry overwritten.

    ``patch`` must provide ``"step_index"``, ``"field"`` and ``"new_value"``;
    the copy's ``steps[step_index][field]`` is set to ``new_value``. The
    ``config`` passed in is left untouched.
    """
    patched = copy.deepcopy(config)
    index, key, replacement = patch["step_index"], patch["field"], patch["new_value"]
    # Steps use a flat schema, so the field lives directly on the step dict.
    target_step = patched["steps"][index]
    target_step[key] = replacement
    return patched