File size: 2,143 Bytes
3512a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pandas as pd

# Aggregation ops a plan metric may request; each must also be the name of a
# pandas Series/GroupBy method, since the ops are dispatched by name.
ALLOWED_OPS = {"mean", "median", "count"}

# Filter operators accepted inside plan["filters"][col], each mapped to a
# function (column Series, operand) -> boolean row mask. They are applied in
# this order, and every operator present in a rule must hold (AND semantics).
_FILTER_OPS = {
    "eq": lambda s, v: s == v,
    "in": lambda s, v: s.isin(v),
    "not_in": lambda s, v: ~s.isin(v),
    "gte": lambda s, v: s >= v,
    "lte": lambda s, v: s <= v,
}

_DEFAULT_LIMIT = 20  # rows returned when the plan gives no "limit"
_MAX_LIMIT = 50  # hard upper bound on rows returned for aggregated results
_RAW_ROW_CAP = 20  # un-aggregated (raw) rows are never returned beyond this


def _require_column(frame: pd.DataFrame, col) -> None:
    """Raise ValueError unless *col* names a column of *frame*."""
    if col not in frame.columns:
        raise ValueError("Unknown column: %s" % col)


def _apply_filters(q: pd.DataFrame, filters: dict) -> pd.DataFrame:
    """Return the rows of *q* that satisfy every rule in *filters*."""
    for col, rule in filters.items():
        _require_column(q, col)
        if not isinstance(rule, dict):
            raise ValueError("Bad filter rule for %s" % col)
        unknown = set(rule) - set(_FILTER_OPS)
        if unknown:
            # Ignoring an unknown operator (e.g. a "gt" typo) would silently
            # return unfiltered rows, so reject the rule instead.
            raise ValueError(
                "Bad filter rule for %s: unsupported operator(s) %s"
                % (col, ", ".join(sorted(map(str, unknown))))
            )
        for op, predicate in _FILTER_OPS.items():
            if op in rule:
                # Re-read q[col] each time: q shrinks after every predicate.
                q = q[predicate(q[col], rule[op])]
    return q


def _parse_metric(m: dict, frame: pd.DataFrame) -> tuple:
    """Validate one metric spec against *frame*; return (label, col, op)."""
    col, op = m.get("col"), m.get("op")
    if op not in ALLOWED_OPS:
        raise ValueError("Unsupported op: %s" % op)
    _require_column(frame, col)
    return m.get("label", "%s_%s" % (op, col)), col, op


def _sort(res: pd.DataFrame, sort_by: list) -> pd.DataFrame:
    """Sort *res* by every key at once; the first sort_by entry is primary."""
    if not sort_by:
        return res
    cols = [s.get("col") for s in sort_by]
    for c in cols:
        _require_column(res, c)
    ascending = [bool(s.get("asc", True)) for s in sort_by]
    # One multi-key call: sorting key-by-key with the default (unstable)
    # quicksort would keep only the last key's ordering.
    return res.sort_values(cols, ascending=ascending, kind="stable")


def execute_plan(df: pd.DataFrame, plan: dict) -> pd.DataFrame:
    """Filter, aggregate, sort and truncate *df* according to *plan*.

    Recognised plan keys (all optional):

    * ``filters``: ``{col: {op: value}}`` with ops ``eq``, ``in``,
      ``not_in``, ``gte``, ``lte``. All rules and operators are ANDed.
    * ``groupby``: a column name or list of column names. NaN keys form
      their own group (``dropna=False``).
    * ``metrics``: list of ``{"col", "op", optional "label"}``. ``op`` must
      be in ``ALLOWED_OPS``, and ``label`` defaults to ``"<op>_<col>"``.
      With ``groupby`` and no metrics, a per-group ``count`` of rows is
      returned. Without ``groupby``, metrics produce a single summary row.
      Without either, the filtered raw rows are returned.
    * ``sort_by``: list of ``{"col", optional "asc" (default True)}``. The
      first entry is the primary sort key.
    * ``limit``: row count, default 20. It is clamped to ``[0, 50]``, and to
      at most 20 when raw (un-aggregated) rows are returned.

    Args:
        df: Input frame. It is copied and never modified.
        plan: Query plan dict as described above.

    Returns:
        The resulting DataFrame, sorted and truncated to the limit.

    Raises:
        ValueError: unknown column, malformed or unsupported filter rule,
            or unsupported metric op.
    """
    q = _apply_filters(df.copy(), plan.get("filters") or {})

    groupby = plan.get("groupby") or []
    # A bare column name is accepted by pandas; normalise it so validation
    # checks the name rather than iterating its characters.
    groupby = [groupby] if isinstance(groupby, str) else list(groupby)
    metrics = plan.get("metrics") or []

    if groupby:
        for g in groupby:
            _require_column(q, g)
        named = {}
        for m in metrics:
            label, col, op = _parse_metric(m, q)
            named[label] = (col, op)  # pandas named aggregation
        gb = q.groupby(groupby, dropna=False)
        if named:
            res = gb.agg(**named).reset_index()
        else:
            res = gb.size().reset_index(name="count")
        row_cap = _MAX_LIMIT
    else:
        summary = {}
        for m in metrics:
            label, col, op = _parse_metric(m, q)
            series = q[col]
            if op == "count":
                summary[label] = int(series.count())
            else:
                # op is whitelisted by ALLOWED_OPS, so dispatch-by-name is safe.
                summary[label] = float(getattr(series, op)())
        if summary:
            res = pd.DataFrame([summary])
            row_cap = _MAX_LIMIT
        else:
            # No metrics: return raw rows. They are truncated only after
            # sorting, so sort_by sees every filtered row.
            res = q
            row_cap = _RAW_ROW_CAP

    raw_limit = plan.get("limit")
    limit = _DEFAULT_LIMIT if raw_limit is None else int(raw_limit)
    # Clamp to [0, row_cap]; head() with a negative n would drop rows from
    # the end instead of returning nothing.
    limit = max(0, min(limit, row_cap))
    return _sort(res, plan.get("sort_by") or []).head(limit)