Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Dict, List, Any, Tuple, Union, Optional | |
| import re, math, json, ast | |
| import numpy as np | |
| import pandas as pd | |
| from schema import ScenarioPlan, TaskPlan | |
| from column_resolver import resolve_cols | |
def _log(x, base=None):
    """Return the natural log of ``x``, or its logarithm in ``base`` if given.

    Built on ``np.log`` so it accepts plain numbers and pandas Series alike
    (``math.log`` raises ``TypeError`` on a multi-row Series).
    """
    if base is None:
        return np.log(x)
    return np.log(x) / np.log(base)


# Functions callable from user expressions.  The scalar-math entries use
# NumPy ufuncs because expression operands are normally whole columns.
# ``min``/``max`` are element-wise and take two operands (np.minimum/maximum).
_ALLOWED_FUNCS = {
    "abs": abs, "round": round, "sqrt": np.sqrt, "log": _log, "exp": np.exp,
    "min": np.minimum, "max": np.maximum,
    "mean": np.mean, "avg": np.mean, "median": np.median, "sum": np.sum,
    "count": lambda x: np.size(x),
    "p50": lambda x: np.percentile(x, 50), "p75": lambda x: np.percentile(x, 75),
    "p90": lambda x: np.percentile(x, 90), "p95": lambda x: np.percentile(x, 95),
}

# Prefix under which whitelisted functions are exposed to ``eval`` so that a
# column called e.g. "count" or "sum" cannot collide with the function.
_FN_PREFIX = "__fn_"


class _SafeExpr(ast.NodeTransformer):
    """AST validator that rejects anything outside a small expression subset.

    Accepted: constants, names listed in ``allowed_names``, arithmetic,
    comparisons, boolean/bitwise operators and calls to ``_ALLOWED_FUNCS``.
    Attribute access, subscripts, lambdas, comprehensions, keyword arguments
    etc. raise ``ValueError``.  The tree is never modified.
    """

    _ALLOWED_NODES = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Call, ast.Name,
        ast.Load, ast.Constant, ast.And, ast.Or, ast.Not, ast.Add, ast.Sub, ast.Mult, ast.Div,
        ast.Mod, ast.Pow, ast.FloorDiv, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
        ast.USub, ast.UAdd,
        # Element-wise boolean algebra: the only way to combine Series masks.
        ast.BitAnd, ast.BitOr, ast.Invert,
    )

    def __init__(self, allowed_names: set):
        self.allowed_names = allowed_names

    def visit_Name(self, node):
        """Allow only known data names (and the literal-like True/False/None)."""
        if node.id not in self.allowed_names and node.id not in ("True", "False", "None"):
            raise ValueError(f"Unknown name: {node.id}")
        return node

    def visit_Call(self, node):
        """Allow ``func(arg, ...)`` where ``func`` is a whitelisted bare name."""
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple calls allowed")
        if node.func.id not in _ALLOWED_FUNCS:
            raise ValueError(f"Function not allowed: {node.func.id}")
        if node.keywords:
            raise ValueError("Unsupported syntax: keyword")
        # Validate the arguments only.  The callee was checked above; sending
        # it through visit_Name would reject it, because function names are
        # not part of ``allowed_names``.
        node.args = [self.visit(arg) for arg in node.args]
        return node

    def generic_visit(self, node):
        """Reject every node type that is not explicitly whitelisted."""
        if not isinstance(node, self._ALLOWED_NODES):
            raise ValueError(f"Unsupported syntax: {type(node).__name__}")
        return super().generic_visit(node)


def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
    """Evaluate a restricted Python expression against the columns of ``df``.

    Column names are bound to their Series; calls are limited to
    ``_ALLOWED_FUNCS``.  The result is whatever the expression produces:
    normally a Series, but a scalar for purely aggregate or constant
    expressions such as ``mean(x)``.

    Raises:
        SyntaxError: ``expr`` is not a valid Python expression.
        ValueError: ``expr`` uses a name, function or syntax that is not allowed.

    NOTE(review): ``and`` / ``or`` / ``not`` pass validation but rely on Python
    truthiness, which pandas rejects for multi-row Series; combine masks with
    ``&``, ``|``, ``~`` and parenthesised comparisons instead.
    """
    names = set(df.columns) | {"True", "False", "None"}
    tree = ast.parse(expr, mode="eval")
    _SafeExpr(names).visit(tree)

    # After validation every call target is a whitelisted bare name; move it
    # into the reserved namespace so columns and functions cannot shadow
    # each other inside the evaluation environment.
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            node.func.id = _FN_PREFIX + node.func.id

    code = compile(tree, "<expr>", "eval")
    env = {col: df[col] for col in df.columns}
    env.update((_FN_PREFIX + name, fn) for name, fn in _ALLOWED_FUNCS.items())
    # SECURITY: eval() on caller-supplied text.  Safety rests entirely on the
    # _SafeExpr whitelist above plus the empty __builtins__; do not relax
    # either without review.  ``**`` is allowed, so very large powers can
    # still consume CPU/memory.
    return eval(code, {"__builtins__": {}}, env)
class ScenarioEngine:
    """Stateless helpers that run a scenario plan over in-memory datasets.

    Every method is a ``staticmethod`` invoked through the class
    (``ScenarioEngine.execute_plan(plan, datasets)``); no instance state is
    kept.  Output is Markdown text.
    """

    @staticmethod
    def _as_df(v: Any) -> Optional[pd.DataFrame]:
        """Coerce a raw dataset value to a DataFrame, or ``None`` if unsupported.

        * list of dicts -> one row per dict; any other list -> one ``value`` column
        * dict with at least one scalar value -> single-row frame
        * dict of dicts -> one row per entry, with the key stored under ``item``
        * DataFrame -> returned as-is (not copied)
        """
        if isinstance(v, list):
            if not v:
                return pd.DataFrame()
            return pd.DataFrame(v) if isinstance(v[0], dict) else pd.DataFrame({"value": v})
        if isinstance(v, dict):
            if any(isinstance(val, (int, float, str, bool, type(None))) for val in v.values()):
                return pd.DataFrame([v])
            rows = [{"item": k, **val} for k, val in v.items() if isinstance(val, dict)]
            if rows:
                return pd.DataFrame(rows)
        if isinstance(v, pd.DataFrame):
            return v
        return None

    # ---------- Plan-first API ----------
    @staticmethod
    def execute_plan(plan: ScenarioPlan, datasets: Dict[str, Any]) -> str:
        """Render every task of ``plan`` and return the combined Markdown."""
        sections: List[str] = ["# Scenario Output\n"]
        sections.extend(ScenarioEngine._exec_task(t, datasets) for t in plan.tasks)
        return "\n".join(sections).strip()

    @staticmethod
    def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
        """Return the dataset stored under ``key`` as a DataFrame.

        When ``key`` is empty or not present, the first list/dict/DataFrame
        value in ``datasets`` is used instead.
        """
        if key and key in datasets:
            v = datasets[key]
        else:
            v = next((vv for vv in datasets.values() if isinstance(vv, (list, dict, pd.DataFrame))), None)
        return ScenarioEngine._as_df(v) if v is not None else None

    @staticmethod
    def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
        """Return a copy of the rows of ``df`` for which ``expr`` is truthy."""
        mask = _eval_series_expr(expr, df)
        if not isinstance(mask, pd.Series):
            # Constant/aggregate expressions yield a scalar; apply it to every row.
            mask = pd.Series(mask, index=df.index)
        return df.loc[mask.astype(bool)].copy()

    @staticmethod
    def _split_top_level(spec: str) -> List[str]:
        """Split ``spec`` on ``;`` / ``,`` that are not inside parentheses.

        Returns the stripped, non-empty pieces.  Quoted strings are not
        treated specially.
        """
        parts: List[str] = []
        buf: List[str] = []
        depth = 0
        for ch in spec:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth = max(depth - 1, 0)
            if depth == 0 and ch in ";,":
                parts.append("".join(buf))
                buf = []
            else:
                buf.append(ch)
        parts.append("".join(buf))
        return [p.strip() for p in parts if p.strip()]

    @staticmethod
    def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
        """Add derived columns described by ``col=expr`` items in ``spec``.

        Items are separated by ``;`` or ``,``; separators inside call
        parentheses (``m=max(a, b)``) belong to the expression.  Items are
        applied in order, so later ones may reference earlier ones.  Works on
        a copy so the caller's frame is never modified.

        Raises:
            ValueError: an item has no ``=``.
        """
        df = df.copy()
        for part in ScenarioEngine._split_top_level(spec):
            if "=" not in part:
                raise ValueError(f"derive requires col=expr: '{part}'")
            col, expr = part.split("=", 1)
            df[col.strip()] = _eval_series_expr(expr.strip(), df)
        return df

    @staticmethod
    def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
        """Parse ``"sum(x), p90(y), count"`` into ``(output_column, call)`` pairs.

        ``count`` and ``count(*)`` both map to ``("count", "count(*)")``;
        ``func(col)`` maps to ``("func_col", "func(col)")``.

        Raises:
            ValueError: an item is not of the form ``func(arg)``.
        """
        if not spec:
            return []
        out: List[Tuple[str, str]] = []
        for item in (x.strip() for x in spec.split(",")):
            if not item:
                continue
            # Row count must be recognised before the generic pattern, which
            # would otherwise turn "count(*)" into a column named "count_*".
            if item.lower() in ("count", "count(*)"):
                out.append(("count", "count(*)"))
                continue
            m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', item)
            if not m:
                raise ValueError(f"Bad agg: {item}")
            func, arg = m.group(1).lower(), m.group(2).strip()
            if func == "count" and arg == "*":
                out.append(("count", "count(*)"))
            else:
                out.append((f"{func}_{arg}", f"{func}({arg})"))
        return out

    @staticmethod
    def _apply_agg_call(df: pd.DataFrame, call: str):
        """Evaluate one aggregate call such as ``sum(x)`` over ``df``.

        Supports ``count`` / ``count(*)`` (row count), ``count(col)``
        (non-null count), ``avg``/``mean``, ``median``, ``sum``, ``min``,
        ``max`` and ``pNN`` percentiles.  Nulls are dropped first; an empty
        column yields NaN (``0.0`` for ``sum``).

        Raises:
            ValueError: malformed call, unknown column or unsupported function.
        """
        call = call.strip()
        if call.lower() in ("count", "count(*)"):
            return int(len(df))
        m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', call)
        if not m:
            raise ValueError(f"Bad agg: {call}")
        func, arg = m.group(1).lower(), m.group(2).strip()
        if func == "count" and arg == "*":
            return int(len(df))
        if arg not in df.columns:
            raise ValueError(f"Unknown column: {arg}")
        col = df[arg].dropna()
        if func == "count":
            return int(len(col))
        if func in ("avg", "mean"):
            return float(np.mean(col)) if len(col) else float("nan")
        if func == "median":
            return float(np.median(col)) if len(col) else float("nan")
        if func == "sum":
            return float(np.sum(col)) if len(col) else 0.0
        if func in ("min", "max"):
            return float(getattr(np, func)(col)) if len(col) else float("nan")
        if func.startswith("p") and func[1:].isdigit():
            return float(np.percentile(col, int(func[1:]))) if len(col) else float("nan")
        raise ValueError(f"Unsupported agg: {func}")

    @staticmethod
    def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
        """Aggregate ``df``, optionally per group.

        * no grouping and no aggregates -> ``df`` unchanged
        * aggregates only -> single-row frame
        * grouping -> one row per group (null keys kept); a plain ``count``
          column is produced when no aggregates are given
        """
        aggs = ScenarioEngine._parse_aggs(agg_spec)
        if isinstance(group_by, str):
            # Tolerate a bare column name; indexing it would yield characters.
            group_by = [group_by]
        if not aggs and not group_by:
            return df
        if not group_by:
            return pd.DataFrame([{name: ScenarioEngine._apply_agg_call(df, call) for name, call in aggs}])
        rows = []
        for keys, group in df.groupby(list(group_by), dropna=False):
            if not isinstance(keys, tuple):
                keys = (keys,)
            rec = dict(zip(group_by, keys))
            if aggs:
                for out_col, call in aggs:
                    rec[out_col] = ScenarioEngine._apply_agg_call(group, call)
            else:
                rec["count"] = len(group)
            rows.append(rec)
        return pd.DataFrame(rows)

    @staticmethod
    def _pivot(df: pd.DataFrame, spec: Union[str, Dict[str, Any]]) -> pd.DataFrame:
        """Pivot ``df`` taking the first value per cell.

        ``spec`` is either a mapping with ``index`` (name or list of names),
        ``columns`` and ``values``, or the equivalent string
        ``"index=a,b columns=c values=v"`` (pairs separated by spaces and/or
        commas).

        Raises:
            ValueError: any of the three settings is missing.
        """
        if isinstance(spec, str):
            # A value runs until the next "key=" or the end of the string, so
            # comma-separated index lists survive intact.
            pairs = re.findall(r'(\w+)\s*=\s*(.*?)(?=[\s,]+\w+\s*=|[\s,]*$)', spec)
            parts = {k: v.strip() for k, v in pairs}
            idx = [x.strip() for x in parts.get("index", "").split(",") if x.strip()]
            cols = parts.get("columns")
            vals = parts.get("values")
        else:
            raw_index = spec.get("index") or []
            idx = [raw_index] if isinstance(raw_index, str) else list(raw_index)
            cols = spec.get("columns")
            vals = spec.get("values")
        if not (idx and cols and vals):
            raise ValueError("pivot requires index=.. columns=.. values=..")
        pv = df.pivot_table(index=idx, columns=cols, values=vals, aggfunc="first").reset_index()
        if isinstance(pv.columns, pd.MultiIndex):
            pv.columns = ["_".join([str(c) for c in tup if c != ""]) for tup in pv.columns]
        return pv

    @staticmethod
    def _render_table(df: pd.DataFrame) -> str:
        """Render ``df`` as a Markdown table (floats with 4 significant digits)."""
        if df.empty:
            return "_No rows._"

        def fmt_cell(v):
            if isinstance(v, float):
                return "NaN" if math.isnan(v) else f"{v:,.4g}"
            return v

        dff = df.copy()
        for c in dff.columns:
            dff[c] = dff[c].apply(fmt_cell)
        # Column labels may be non-strings (e.g. pivoted numeric values).
        header = "| " + " | ".join(map(str, dff.columns)) + " |"
        sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
        rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
        return "\n".join([header, sep, *rows])

    @staticmethod
    def _render_list(df: pd.DataFrame) -> str:
        """Render rows as a numbered list: first column, then ``(col: value, ...)``."""
        if df.empty:
            return "_No items._"
        extra_cols = list(df.columns)[1:]
        lines = []
        # Plain tuples: named tuples rename columns that are not valid
        # identifiers, which would break attribute lookup by column name.
        for i, values in enumerate(df.itertuples(index=False, name=None), 1):
            extras = [f"{c}: {v}" for c, v in zip(extra_cols, values[1:])]
            lines.append(f"{i}. {values[0]}" + (f" ({', '.join(extras)})" if extras else ""))
        return "\n".join(lines)

    @staticmethod
    def _render_comparison(df: pd.DataFrame) -> str:
        """Render a current-vs-previous Markdown table with a change column.

        Columns are matched case-insensitively: current/now/value,
        previous/prior/past and name/metric/item (falling back to the first
        column).  Change is ``current - previous`` for numeric pairs, else
        ``N/A``.
        """
        cols = {str(c).lower(): c for c in df.columns}
        cur = cols.get("current") or cols.get("now") or cols.get("value")
        prev = cols.get("previous") or cols.get("prior") or cols.get("past")
        name = cols.get("name") or cols.get("metric") or cols.get("item") or df.columns[0]
        if not (cur and prev):
            return "_Comparison requires 'current' and 'previous' columns._"
        numeric = (int, float, np.integer, np.floating)
        body = []
        for _, r in df.iterrows():
            c, p = r[cur], r[prev]
            ch = (c - p) if isinstance(c, numeric) and isinstance(p, numeric) else "N/A"
            body.append(f"| {r[name]} | {c} | {p} | {ch} |")
        header = "| Item | Current | Previous | Change |"
        sep = "|---|---:|---:|---:|"
        return "\n".join([header, sep, *body])

    @staticmethod
    def _render_map(df: pd.DataFrame) -> str:
        """Render location columns as a table, merging lat/lon into ``coordinates``.

        Uses facility/name (or the first column), city, zone and
        latitude/longitude when present; absent columns are simply omitted.
        """
        if len(df.columns) == 0:
            return "_No geographic fields._"
        col = {str(c).lower(): c for c in df.columns}
        name = col.get("facility") or col.get("name") or df.columns[0]
        lat = col.get("latitude") or col.get("lat")
        lon = col.get("longitude") or col.get("lon")
        zone = col.get("zone")
        city = col.get("city")
        # dict.fromkeys de-duplicates while keeping order (the fallback name
        # column may coincide with city/zone/lat/lon).
        show = list(dict.fromkeys(x for x in (name, city, zone, lat, lon) if x))
        if not show:
            return "_No geographic fields._"
        tmp = df[show].copy()
        if lat and lon:
            tmp["coordinates"] = tmp[lat].astype(str) + ", " + tmp[lon].astype(str)
            show = list(dict.fromkeys(x for x in (name, city, zone) if x)) + ["coordinates"]
        return ScenarioEngine._render_table(tmp[show])

    @staticmethod
    def _json_safe(v: Any) -> Any:
        """Map missing-value markers (NaN, NaT, NA) to ``None`` for JSON output."""
        if isinstance(v, float) and math.isnan(v):
            return None
        if v is pd.NaT or v is pd.NA:
            return None
        return v

    @staticmethod
    def _render_chart(df: pd.DataFrame, d: Dict[str, Any]) -> str:
        """Return a fenced ``vega-lite`` block with ``df`` inlined as data.

        ``d`` may carry ``chart`` (mark, default ``bar``), ``title`` and the
        encodings ``x``/``y``/``color``/``column``; an encoding is emitted
        only when it names an existing column.
        """
        records = [
            {k: ScenarioEngine._json_safe(v) for k, v in rec.items()}
            for rec in df.to_dict(orient="records")
        ]
        spec = {
            "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
            "description": d.get("title") or "Chart",
            "data": {"values": records},
            "mark": d.get("chart", "bar"),
            "encoding": {},
        }
        for enc in ("x", "y", "color", "column"):
            if enc in d and d[enc] in df.columns:
                series = df[d[enc]]
                if pd.api.types.is_numeric_dtype(series):
                    kind = "quantitative"
                elif pd.api.types.is_datetime64_any_dtype(series):
                    kind = "temporal"
                else:
                    kind = "nominal"
                spec["encoding"][enc] = {"field": d[enc], "type": kind}
        # default=str is deliberate: timestamps and similar values are emitted
        # as text instead of aborting the render.
        payload = json.dumps(spec, ensure_ascii=False, indent=2, default=str)
        return "```vega-lite\n" + payload + "\n```"

    @staticmethod
    def _format_decimals(fmt: str) -> int:
        """Return the number of decimals requested by a number-format string.

        A fraction part made only of ``0``/``#`` is counted (``"0.00"`` -> 2,
        ``"0.0%"`` -> 1); otherwise it is read as a digit count (``".2"`` ->
        2).  No ``.`` means 0.  Raises ``ValueError`` for anything else.
        """
        if "." not in fmt:
            return 0
        frac = fmt.rsplit(".", 1)[-1].rstrip("%").rstrip("fF").strip()
        if not frac:
            return 0
        if set(frac) <= {"0", "#"}:
            return len(frac)
        return int(frac)

    @staticmethod
    def _exec_task(t: TaskPlan, datasets: Dict[str, Any]) -> str:
        """Run one task and return its Markdown section.

        Steps, in order: load -> filter -> derive -> joins -> group/agg ->
        pivot -> sort -> top-N -> field selection -> number formatting ->
        render in the requested format, followed by a provenance footer.
        """
        section = [f"## {t.title_override or t.title}\n"]
        provenance = ["\n**Provenance**", f"- Data key: `{t.data_key or 'auto'}`"]

        df = ScenarioEngine._get_df(datasets, t.data_key)
        if df is None or df.empty:
            return "\n".join(section + ["_No matching data for this task._"] + provenance)

        if t.filter:
            df = ScenarioEngine._apply_filter(df, t.filter)
        if t.derive:
            for d in t.derive:
                df = ScenarioEngine._apply_derive(df, d)
        for j in t.joins or []:
            rk, lo, ro = j["right_key"], j["left_on"], j["right_on"]
            how = j.get("how", "left").lower()
            right = ScenarioEngine._as_df(datasets.get(rk))
            if right is not None:
                df = df.merge(right, left_on=lo, right_on=ro, how=how)
        if t.group_by or t.agg:
            agg_spec = t.agg if isinstance(t.agg, str) else ", ".join(t.agg or [])
            df = ScenarioEngine._group_agg(df, t.group_by, agg_spec)
        if t.pivot:
            df = ScenarioEngine._pivot(df, t.pivot)
        if t.sort_by and t.sort_by in df.columns:
            df = df.sort_values(by=t.sort_by, ascending=(t.sort_dir or "desc").lower() == "asc")
        if t.top and t.top > 0:
            df = df.head(t.top)
        if t.fields:
            cols = resolve_cols(t.fields, df.columns.tolist())
            cols = [c for c in cols if c in df.columns]
            if cols:
                df = df[cols]

        if t.number_format:
            # Format a copy: ``df`` may still be the caller's dataset object.
            df = df.copy()
            for col, fmt in t.number_format.items():
                if col not in df.columns:
                    continue
                try:
                    decimals = ScenarioEngine._format_decimals(fmt)
                    numeric = df[col].astype(float)
                except (ValueError, TypeError):
                    # Best effort: leave columns/formats we cannot interpret.
                    continue
                if fmt.endswith("%"):
                    df[col] = numeric.map(
                        lambda v, n=decimals: "NaN" if pd.isna(v) else f"{v * 100:.{n}f}%"
                    )
                else:
                    df[col] = numeric.round(decimals)

        fmt = (t.format or "table").lower()
        if fmt == "list":
            body = ScenarioEngine._render_list(df)
        elif fmt == "comparison":
            body = ScenarioEngine._render_comparison(df)
        elif fmt == "map":
            body = ScenarioEngine._render_map(df)
        elif fmt == "chart":
            body = ScenarioEngine._render_chart(df, {"chart": t.chart or "bar", **(t.encodings or {})})
        elif fmt == "narrative":
            lines = [
                f"{i}. " + "; ".join(f"**{k}**: {v}" for k, v in rec.items())
                for i, rec in enumerate(df.to_dict(orient="records"), 1)
            ]
            body = "\n".join(lines) if lines else "_No content._"
        else:
            body = ScenarioEngine._render_table(df)

        return "\n".join(section + [body] + provenance)