Rajan Sharma committed on
Commit
5d68d4f
·
verified ·
1 Parent(s): c5c42b6

Update scenario_engine.py

Browse files
Files changed (1) hide show
  1. scenario_engine.py +265 -25
scenario_engine.py CHANGED
@@ -1,37 +1,277 @@
1
- import pandas as pd
 
 
2
  import numpy as np
3
- import math, json, re, ast
4
- from schemas import ScenarioPlan, TaskSpec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  class ScenarioEngine:
7
  @staticmethod
8
- def render_plan(plan: ScenarioPlan, results: dict) -> str:
9
- sections = ["# Scenario Output\n"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  for t in plan.tasks:
11
- sections.append(ScenarioEngine._render_task(t, results))
12
- return "\n".join(sections)
13
 
14
  @staticmethod
15
- def _df(v):
16
- if isinstance(v, pd.DataFrame): return v
17
- if isinstance(v, list) and v and isinstance(v[0], dict): return pd.DataFrame(v)
18
- if isinstance(v, dict): return pd.DataFrame([v])
19
- return None
 
20
 
21
  @staticmethod
22
- def _render_task(t: TaskSpec, results: dict) -> str:
23
- out=[f"## {t.title}\n"]
24
- df=ScenarioEngine._df(results.get(t.data_key)) if t.data_key else None
25
- if df is None: return "\n".join(out+["_No data_"])
26
- if t.filter: df=df.query(t.filter)
27
- if t.group_by or t.agg: df=df.groupby(t.group_by).agg("first").reset_index()
28
- if t.sort_by: df=df.sort_values(by=t.sort_by, ascending=(t.sort_dir=="asc"))
29
- if t.top: df=df.head(t.top)
30
- if t.fields: df=df[t.fields]
31
 
32
- if t.format=="list":
33
- out += [f"- {row.to_dict()}" for _,row in df.iterrows()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  else:
35
- out.append(df.to_markdown(index=False))
36
- return "\n".join(out)
 
 
 
 
37
 
 
1
+ from __future__ import annotations
2
+ from typing import Dict, List, Any, Tuple, Union, Optional
3
+ import re, math, json, ast
4
  import numpy as np
5
+ import pandas as pd
6
+ from schema import ScenarioPlan, TaskPlan
7
+ from column_resolver import resolve_cols
8
+
9
# Whitelist of callables that user-supplied expressions may invoke.
# Expressions are evaluated with whole DataFrame columns bound to names, so
# every entry has to accept a pandas Series as well as a plain scalar.
_ALLOWED_FUNCS = {
    "abs": abs,
    "round": round,
    # NumPy ufuncs instead of math.sqrt/log/exp: the math versions only take
    # scalars and raise TypeError when handed a Series.
    "sqrt": np.sqrt,
    "log": np.log,
    "exp": np.exp,
    # Two-argument, element-wise min/max (e.g. ``max(a, b)``).
    "min": np.minimum,
    "max": np.maximum,
    "mean": np.mean,
    "avg": np.mean,
    "median": np.median,
    "sum": np.sum,
    "count": lambda x: np.size(x),
    "p50": lambda x: np.percentile(x, 50),
    "p75": lambda x: np.percentile(x, 75),
    "p90": lambda x: np.percentile(x, 90),
    "p95": lambda x: np.percentile(x, 95),
}
17
+
18
class _SafeExpr(ast.NodeTransformer):
    """Validate that an expression AST uses only whitelisted syntax and names.

    Visiting a tree raises ``ValueError`` on the first node that is not plain
    arithmetic / comparison / boolean syntax, a constant, a name contained in
    ``allowed_names``, or a call to a function listed in ``_ALLOWED_FUNCS``.
    Nodes are returned unchanged; the visitor only validates.
    """

    # Node types an expression may contain; anything else is rejected.
    _ALLOWED_NODES = (
        ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare,
        ast.Call, ast.Name, ast.Load, ast.Constant,
        ast.And, ast.Or, ast.Not,
        ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.FloorDiv,
        ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
        ast.USub, ast.UAdd,
    )
    # NOTE(review): ``and`` / ``or`` / ``not`` pass validation, but Python
    # evaluates them via the truth value of each operand, which pandas refuses
    # for a multi-element Series. Bitwise ``&`` / ``|`` / ``~`` are not in the
    # whitelist, so combining column conditions may not work - confirm intent.

    def __init__(self, allowed_names: set):
        self.allowed_names = allowed_names

    def visit_Name(self, node):
        """Accept only known data names and the literal-like constants."""
        if node.id not in self.allowed_names and node.id not in ("True", "False", "None"):
            raise ValueError(f"Unknown name: {node.id}")
        return node

    def visit_Call(self, node):
        """Accept ``func(args...)`` where ``func`` is a whitelisted plain name."""
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple calls allowed")
        if node.func.id not in _ALLOWED_FUNCS:
            raise ValueError(f"Function not allowed: {node.func.id}")
        # The callee was checked against the function whitelist above. It must
        # not be routed through visit_Name: function names are not data names,
        # so doing that rejected every call with "Unknown name".
        node.args = [self.visit(arg) for arg in node.args]
        # Keyword arguments are not whitelisted syntax; visiting them raises.
        for keyword in node.keywords:
            self.visit(keyword)
        return node

    def generic_visit(self, node):
        """Reject any node type outside the whitelist, then recurse."""
        if not isinstance(node, self._ALLOWED_NODES):
            raise ValueError(f"Unsupported syntax: {type(node).__name__}")
        return super().generic_visit(node)
36
+
37
def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
    """Evaluate a validated expression with the columns of *df* bound by name.

    The expression is parsed, checked by ``_SafeExpr`` (which raises
    ``ValueError`` for disallowed syntax, names or functions), compiled and
    evaluated with each column available under its column label and the
    functions from ``_ALLOWED_FUNCS`` available under their keys.

    Returns:
        A Series. When the expression yields something that is not already a
        Series (a constant such as ``True``, or an aggregate such as
        ``mean(x) > 3``), the value is broadcast over ``df.index`` so callers
        can always rely on Series methods.
    """
    names = set(df.columns) | {"True", "False", "None"}
    tree = ast.parse(expr, mode="eval")
    _SafeExpr(names).visit(tree)
    code = compile(tree, "<expr>", "eval")
    # Function entries are merged last, so they shadow same-named columns.
    env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
    # SECURITY: eval() on caller-supplied text. The only protection is the
    # _SafeExpr whitelist plus the emptied __builtins__; review both before
    # widening the allowed syntax or function set.
    result = eval(code, {"__builtins__": {}}, env)
    if isinstance(result, pd.Series):
        return result
    return pd.Series(result, index=df.index)
44
 
45
class ScenarioEngine:
    """Run the tasks of a scenario plan against named datasets and render Markdown.

    Every member is a static method; the class serves as a namespace.
    """

    # ``func(arg)`` aggregate call such as ``sum(revenue)`` or ``p95(latency)``.
    _AGG_CALL_RE = re.compile(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)')
    # ``key=value`` pairs of a pivot spec. A value runs until the next
    # ``key=`` (or end of text), so it may contain commas, e.g. ``index=a,b``.
    _PIVOT_PAIR_RE = re.compile(r'(\w+)\s*=\s*(.*?)(?=(?:\s+|\s*,\s*)\w+\s*=|\s*$)')
    # Spellings that mean "number of rows".
    _COUNT_ALL = ("count", "count(*)")

    @staticmethod
    def _as_df(v: Any) -> Optional[pd.DataFrame]:
        """Coerce a dataset value to a DataFrame, or return None if unsupported.

        * DataFrame: returned as is (not copied).
        * list: empty -> empty frame; list of dicts -> one row per dict;
          anything else -> single ``value`` column.
        * dict: if any value is a scalar/None -> one-row frame; otherwise each
          dict-valued entry becomes a row with its key stored under ``item``.
        """
        if isinstance(v, pd.DataFrame):
            return v
        if isinstance(v, list):
            if not v:
                return pd.DataFrame()
            return pd.DataFrame(v) if isinstance(v[0], dict) else pd.DataFrame({"value": v})
        if isinstance(v, dict):
            if any(isinstance(val, (int, float, str, bool, type(None))) for val in v.values()):
                return pd.DataFrame([v])
            rows = [{"item": k, **val} for k, val in v.items() if isinstance(val, dict)]
            if rows:
                return pd.DataFrame(rows)
        return None

    # ---------- Plan-first API ----------
    @staticmethod
    def execute_plan(plan: "ScenarioPlan", datasets: Dict[str, Any]) -> str:
        """Execute every task in ``plan.tasks`` and return the combined Markdown."""
        sections: List[str] = ["# Scenario Output\n"]
        sections.extend(ScenarioEngine._exec_task(t, datasets) for t in plan.tasks)
        return "\n".join(sections).strip()

    @staticmethod
    def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
        """Fetch ``datasets[key]`` as a DataFrame.

        When ``key`` is empty or unknown, falls back to the first dataset value
        that is a list, dict or DataFrame.
        """
        if key and key in datasets:
            v = datasets[key]
        else:
            v = next((vv for vv in datasets.values() if isinstance(vv, (list, dict, pd.DataFrame))), None)
        return ScenarioEngine._as_df(v) if v is not None else None

    @staticmethod
    def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
        """Return a copy of the rows of *df* for which *expr* is truthy."""
        mask = _eval_series_expr(expr, df)
        if not isinstance(mask, pd.Series):
            # A constant expression yields a scalar; apply it to every row.
            mask = pd.Series(mask, index=df.index)
        return df.loc[mask.astype(bool)].copy()

    @staticmethod
    def _split_top_level(text: str, separators: str = ";,") -> List[str]:
        """Split *text* on separators that are outside brackets and quotes."""
        pieces: List[str] = []
        current: List[str] = []
        depth = 0
        quote = None
        for ch in text:
            if quote:
                if ch == quote:
                    quote = None
            elif ch in "'\"":
                quote = ch
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth = max(depth - 1, 0)
            elif ch in separators and depth == 0:
                pieces.append("".join(current))
                current = []
                continue
            current.append(ch)
        pieces.append("".join(current))
        return pieces

    @staticmethod
    def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
        """Add derived columns described by ``col=expr`` items in *spec*.

        Items are separated by ``;`` or ``,``. Separators inside brackets or
        quotes belong to the expression, so ``z=max(a, b)`` is one item.
        Works on a copy, so the frame passed in (which may be the caller's
        dataset object) is left untouched.

        Raises:
            ValueError: if an item contains no ``=``.
        """
        out = df.copy()
        for raw in ScenarioEngine._split_top_level(spec):
            if not raw.strip():
                continue
            if "=" not in raw:
                raise ValueError(f"derive requires col=expr: '{raw}'")
            col, expr = raw.split("=", 1)
            out[col.strip()] = _eval_series_expr(expr.strip(), out)
        return out

    @staticmethod
    def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
        """Parse ``"sum(x), count"`` into ``(output_column, call)`` pairs.

        ``count`` and ``count(*)`` both map to ``("count", "count(*)")``;
        ``func(col)`` maps to ``("func_col", "func(col)")``.

        Raises:
            ValueError: for an item that is neither form.
        """
        if not spec:
            return []
        out: List[Tuple[str, str]] = []
        for it in (x.strip() for x in spec.split(",")):
            if not it:
                continue
            # Checked before the regex: ``count(*)`` also matches the generic
            # ``func(arg)`` pattern and would otherwise be named ``count_*``.
            if it.lower() in ScenarioEngine._COUNT_ALL:
                out.append(("count", "count(*)"))
                continue
            m = ScenarioEngine._AGG_CALL_RE.match(it)
            if not m:
                raise ValueError(f"Bad agg: {it}")
            func, arg = m.group(1).lower(), m.group(2).strip()
            out.append((f"{func}_{arg}", f"{func}({arg})"))
        return out

    @staticmethod
    def _apply_agg_call(df: pd.DataFrame, call: str):
        """Evaluate one aggregate call (as produced by ``_parse_aggs``) on *df*.

        Null values of the target column are dropped first. For an empty
        column ``sum`` gives 0.0, ``count`` gives 0 and the other functions
        give NaN.

        Raises:
            ValueError: malformed call, unknown column or unsupported function.
        """
        call = call.strip()
        if call.lower() in ScenarioEngine._COUNT_ALL:
            return int(len(df))
        m = ScenarioEngine._AGG_CALL_RE.match(call)
        if not m:
            raise ValueError(f"Bad agg: {call}")
        func, arg = m.group(1).lower(), m.group(2).strip()
        if arg not in df.columns:
            raise ValueError(f"Unknown column: {arg}")
        col = df[arg].dropna()
        if func == "count":
            # count(col): number of non-null values in that column.
            return int(len(col))
        if func == "sum":
            return float(np.sum(col)) if len(col) else 0.0
        nan = float("nan")
        if func in ("avg", "mean"):
            return float(np.mean(col)) if len(col) else nan
        if func == "median":
            return float(np.median(col)) if len(col) else nan
        if func in ("min", "max"):
            return float(getattr(np, func)(col)) if len(col) else nan
        if func.startswith("p") and func[1:].isdigit():
            return float(np.percentile(col, int(func[1:]))) if len(col) else nan
        raise ValueError(f"Unsupported agg: {func}")

    @staticmethod
    def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
        """Aggregate *df*, optionally per group.

        * no aggregates and no grouping: *df* is returned unchanged;
        * aggregates only: a one-row frame;
        * grouping: one row per group (null keys kept) holding the group keys
          and either the requested aggregates or a ``count`` of rows.
        """
        aggs = ScenarioEngine._parse_aggs(agg_spec)
        if not aggs and not group_by:
            return df
        if not group_by:
            return pd.DataFrame([{k: ScenarioEngine._apply_agg_call(df, call) for k, call in aggs}])
        rows = []
        for keys, g in df.groupby(group_by, dropna=False):
            # A single grouping key may come back as a bare scalar.
            if not isinstance(keys, tuple):
                keys = (keys,)
            rec = dict(zip(group_by, keys))
            if aggs:
                for out_col, call in aggs:
                    rec[out_col] = ScenarioEngine._apply_agg_call(g, call)
            else:
                rec["count"] = len(g)
            rows.append(rec)
        return pd.DataFrame(rows)

    @staticmethod
    def _pivot(df: pd.DataFrame, spec: str) -> pd.DataFrame:
        """Pivot *df* according to ``index=a,b columns=c values=v``.

        ``index`` may list several comma-separated columns. The first value
        per cell is kept (``aggfunc="first"``); MultiIndex column labels are
        flattened by joining their parts with ``_``.

        Raises:
            ValueError: if any of index / columns / values is missing.
        """
        parts = {k: v.strip() for k, v in ScenarioEngine._PIVOT_PAIR_RE.findall(spec)}
        idx = [x.strip() for x in parts.get("index", "").split(",") if x.strip()]
        cols = parts.get("columns")
        vals = parts.get("values")
        if not (idx and cols and vals):
            raise ValueError("pivot requires index=.. columns=.. values=..")
        pv = df.pivot_table(index=idx, columns=cols, values=vals, aggfunc="first").reset_index()
        if isinstance(pv.columns, pd.MultiIndex):
            pv.columns = ["_".join([str(c) for c in tup if c != ""]) for tup in pv.columns]
        return pv

    @staticmethod
    def _render_table(df: pd.DataFrame) -> str:
        """Render *df* as a Markdown pipe table.

        Floats are shown with four significant digits and thousands
        separators, NaN as ``NaN``; other values use ``str()``.
        """
        if df.empty:
            return "_No rows._"

        def fmt_cell(v):
            if isinstance(v, float):
                return "NaN" if math.isnan(v) else f"{v:,.4g}"
            return v

        dff = df.copy()
        for c in dff.columns:
            dff[c] = dff[c].apply(fmt_cell)
        # Column labels are not necessarily strings (e.g. after a pivot).
        header = "| " + " | ".join(str(c) for c in dff.columns) + " |"
        sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
        rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
        return "\n".join([header, sep, *rows])

    @staticmethod
    def _render_list(df: pd.DataFrame) -> str:
        """Render rows as a numbered list: first column, then the rest in parentheses."""
        if df.empty:
            return "_No items._"
        columns = list(df.columns)
        primary = columns[0]
        lines = []
        # Plain tuples with positional access: attribute access on namedtuples
        # breaks for labels that are not valid identifiers ("count_*", "a b").
        for i, row in enumerate(df.itertuples(index=False, name=None), 1):
            extras = [f"{c}: {v}" for c, v in zip(columns, row) if c != primary]
            lines.append(f"{i}. {row[0]}" + (f" ({', '.join(extras)})" if extras else ""))
        return "\n".join(lines)

    @staticmethod
    def _render_comparison(df: pd.DataFrame) -> str:
        """Render a current-vs-previous table with a computed change column.

        Columns are located case-insensitively: current/now/value,
        previous/prior/past, and name/metric/item (falling back to the first
        column) for the row label.
        """
        cols = {str(c).lower(): c for c in df.columns}
        cur = cols.get("current") or cols.get("now") or cols.get("value")
        prev = cols.get("previous") or cols.get("prior") or cols.get("past")
        name = cols.get("name") or cols.get("metric") or cols.get("item") or df.columns[0]
        if not (cur and prev):
            return "_Comparison requires 'current' and 'previous' columns._"
        # NumPy scalars included: NumPy integers are not instances of int.
        numeric = (int, float, np.integer, np.floating)
        header = "| Item | Current | Previous | Change |"
        sep = "|---|---:|---:|---:|"
        body = []
        for _, r in df.iterrows():
            c, p = r[cur], r[prev]
            ch = (c - p) if isinstance(c, numeric) and isinstance(p, numeric) else "N/A"
            body.append(f"| {r[name]} | {c} | {p} | {ch} |")
        return "\n".join([header, sep, *body])

    @staticmethod
    def _render_map(df: pd.DataFrame) -> str:
        """Render location-like columns as a table.

        Shows the name column (facility/name, else the first column) plus
        city and zone when present. If both latitude and longitude exist they
        are merged into one ``coordinates`` column; otherwise whichever of the
        two exists is shown as is.
        """
        if len(df.columns) == 0:
            return "_No geographic fields._"
        col = {str(c).lower(): c for c in df.columns}
        name = col.get("facility") or col.get("name") or df.columns[0]
        lat = col.get("latitude") or col.get("lat")
        lon = col.get("longitude") or col.get("lon")
        zone = col.get("zone")
        city = col.get("city")
        has_coords = bool(lat and lon)
        wanted = [name, city, zone] if has_coords else [name, city, zone, lat, lon]
        # Only columns that actually exist, without repeats (the fallback name
        # column can coincide with city/zone).
        show = list(dict.fromkeys(x for x in wanted if x))
        if not show:
            return "_No geographic fields._"
        tmp = df[show].copy()
        if has_coords:
            tmp["coordinates"] = df[lat].astype(str) + ", " + df[lon].astype(str)
        return ScenarioEngine._render_table(tmp)

    @staticmethod
    def _render_chart(df: pd.DataFrame, d: Dict[str, Any]) -> str:
        """Emit a fenced Vega-Lite v5 spec embedding the rows of *df*.

        ``d`` supplies ``chart`` (mark, default ``bar``), optional ``title``
        and the x / y / color / column field names; a channel is encoded only
        if its field is a column of *df*.
        """
        spec = {
            "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
            "description": d.get("title") or "Chart",
            "data": {"values": df.to_dict(orient="records")},
            "mark": d.get("chart", "bar"),
            "encoding": {},
        }
        for enc in ("x", "y", "color", "column"):
            if enc in d and d[enc] in df.columns:
                kind = "quantitative" if pd.api.types.is_numeric_dtype(df[d[enc]]) else "nominal"
                spec["encoding"][enc] = {"field": d[enc], "type": kind}
        return "```vega-lite\n" + json.dumps(spec, ensure_ascii=False, indent=2) + "\n```"

    @staticmethod
    def _apply_number_format(df: pd.DataFrame, formats: Dict[str, str]) -> pd.DataFrame:
        """Return a copy of *df* with per-column number formats applied.

        A format ending in ``%`` multiplies by 100, rounds to as many decimals
        as there are characters after the ``.`` and appends ``%``. Any other
        format rounds to the integer written after the ``.`` (0 without one).
        Formatting is best-effort: a column that cannot be converted, or a
        format that cannot be parsed, leaves that column unchanged.
        """
        out = df.copy()
        for col, fmt in formats.items():
            if col not in out.columns:
                continue
            try:
                if fmt.endswith("%"):
                    decimals = len(fmt.split(".")[-1].rstrip("%")) if "." in fmt else 0
                    out[col] = (out[col].astype(float) * 100).round(decimals).astype(str) + "%"
                else:
                    decimals = int(fmt.split(".")[-1]) if "." in fmt else 0
                    out[col] = out[col].astype(float).round(decimals)
            except (TypeError, ValueError):
                continue
        return out

    @staticmethod
    def _exec_task(t: "TaskPlan", datasets: Dict[str, Any]) -> str:
        """Run one task and return its Markdown section.

        Steps, in order: filter, derive, joins, group/aggregate, pivot, sort,
        top-N, field selection, number formatting, rendering in the task's
        format (table by default). A provenance footer names the data key.
        """
        section = [f"## {t.title_override or t.title}\n"]
        provenance = ["\n**Provenance**", f"- Data key: `{t.data_key or 'auto'}`"]
        df = ScenarioEngine._get_df(datasets, t.data_key)
        if df is None or df.empty:
            return "\n".join(section + ["_No matching data for this task._"] + provenance)

        if t.filter:
            df = ScenarioEngine._apply_filter(df, t.filter)
        if t.derive:
            for d in t.derive:
                df = ScenarioEngine._apply_derive(df, d)
        if t.joins:
            for j in t.joins:
                rk, lo, ro, how = j["right_key"], j["left_on"], j["right_on"], j.get("how", "left").lower()
                r = ScenarioEngine._as_df(datasets.get(rk))
                # A missing or unusable right-hand dataset skips the join.
                if r is not None:
                    df = df.merge(r, left_on=lo, right_on=ro, how=how)
        if t.group_by or t.agg:
            df = ScenarioEngine._group_agg(df, t.group_by, ", ".join(t.agg or []))
        if t.pivot:
            spec = t.pivot
            index = spec.get("index", [])
            if isinstance(index, str):
                # A bare string would otherwise be joined character by character.
                index = [index]
            df = ScenarioEngine._pivot(df, f"index={','.join(index)} columns={spec['columns']} values={spec['values']}")
        if t.sort_by and t.sort_by in df.columns:
            df = df.sort_values(by=t.sort_by, ascending=(t.sort_dir or "desc").lower() == "asc")
        if t.top and t.top > 0:
            df = df.head(t.top)
        if t.fields:
            cols = resolve_cols(t.fields, df.columns.tolist())
            cols = [c for c in cols if c in df.columns]
            if cols:
                df = df[cols]

        if t.number_format:
            # Works on a copy: at this point df can still be the caller's own
            # dataset object or a slice of it.
            df = ScenarioEngine._apply_number_format(df, t.number_format)

        fmt = (t.format or "table").lower()
        if fmt == "list":
            body = ScenarioEngine._render_list(df)
        elif fmt == "comparison":
            body = ScenarioEngine._render_comparison(df)
        elif fmt == "map":
            body = ScenarioEngine._render_map(df)
        elif fmt == "chart":
            body = ScenarioEngine._render_chart(df, {"chart": t.chart or "bar", **(t.encodings or {})})
        elif fmt == "narrative":
            lines = [
                f"{i}. " + "; ".join(f"**{k}**: {v}" for k, v in rec.items())
                for i, rec in enumerate(df.to_dict(orient="records"), 1)
            ]
            body = "\n".join(lines) if lines else "_No content._"
        else:
            body = ScenarioEngine._render_table(df)

        return "\n".join(section + [body] + provenance)
277