Rajan Sharma committed on
Commit
492569d
·
verified ·
1 Parent(s): 1b29d16

Update scenario_engine.py

Browse files
Files changed (1) hide show
  1. scenario_engine.py +68 -142
scenario_engine.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from __future__ import annotations
2
  from typing import Dict, List, Any, Tuple, Optional
3
  import re, math, json, ast
@@ -6,6 +7,7 @@ import pandas as pd
6
  from schema import ScenarioPlan, TaskPlan
7
  from column_resolver import resolve_cols
8
 
 
9
  _ALLOWED_FUNCS = {
10
  "abs": abs, "round": round, "sqrt": math.sqrt, "log": math.log, "exp": math.exp,
11
  "min": np.minimum, "max": np.maximum,
@@ -15,21 +17,26 @@ _ALLOWED_FUNCS = {
15
  "p90": lambda x: np.percentile(x, 90), "p95": lambda x: np.percentile(x, 95),
16
  }
17
 
 
18
  class _SafeExpr(ast.NodeTransformer):
19
- def __init__(self, allowed_names: set): self.allowed_names = allowed_names
20
  def visit_Name(self, node):
21
- if node.id not in self.allowed_names and node.id not in ("True","False","None"):
22
  raise ValueError(f"Unknown name: {node.id}")
23
  return node
24
  def visit_Call(self, node):
25
- if not isinstance(node.func, ast.Name): raise ValueError("Only simple calls allowed")
26
- if node.func.id not in _ALLOWED_FUNCS: raise ValueError(f"Function not allowed: {node.func.id}")
 
 
27
  return self.generic_visit(node)
28
  def generic_visit(self, node):
29
- allowed = (ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Call, ast.Name,
30
- ast.Load, ast.Constant, ast.And, ast.Or, ast.Not, ast.Add, ast.Sub, ast.Mult, ast.Div,
31
- ast.Mod, ast.Pow, ast.FloorDiv, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
32
- ast.USub, ast.UAdd)
 
 
33
  if not isinstance(node, allowed):
34
  raise ValueError(f"Unsupported syntax: {type(node).__name__}")
35
  return super().generic_visit(node)
@@ -42,21 +49,36 @@ def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
42
  env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
43
  return eval(code, {"__builtins__": {}}, env)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  class ScenarioEngine:
46
  @staticmethod
47
  def _as_df(v: Any) -> Optional[pd.DataFrame]:
48
  if isinstance(v, list):
49
- if not v: return pd.DataFrame()
50
- return pd.DataFrame(v) if isinstance(v[0], dict) else pd.DataFrame({"value": v})
51
  if isinstance(v, dict):
52
- if any(isinstance(val, (int, float, str, bool, type(None))) for val in v.values()):
53
- return pd.DataFrame([v])
54
- rows = []
55
- for k, val in v.items():
56
- if isinstance(val, dict):
57
- rec = {"item": k}; rec.update(val); rows.append(rec)
58
- if rows: return pd.DataFrame(rows)
59
- if isinstance(v, pd.DataFrame): return v
60
  return None
61
 
62
  @staticmethod
@@ -83,21 +105,20 @@ class ScenarioEngine:
83
  def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
84
  parts = re.split(r'[;,]\s*', spec)
85
  for p in parts:
86
- if not p.strip(): continue
87
- if "=" not in p: raise ValueError(f"derive requires col=expr: '{p}'")
88
- col, expr = p.split("=", 1); df[col.strip()] = _eval_series_expr(expr.strip(), df)
89
  return df
90
 
91
  @staticmethod
92
  def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
93
  if not spec: return []
94
- items = [x.strip() for x in spec.split(",") if x.strip()]
95
  out = []
96
- for it in items:
97
- m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', it)
98
- if not m:
99
- if it.lower() in ("count","count(*)"): out.append(("count","count(*)")); continue
100
- raise ValueError(f"Bad agg: {it}")
101
  func, arg = m.group(1).lower(), m.group(2).strip()
102
  out.append((f"{func}_{arg}", f"{func}({arg})"))
103
  return out
@@ -106,16 +127,16 @@ class ScenarioEngine:
106
  def _apply_agg_call(df: pd.DataFrame, call: str):
107
  call = call.strip()
108
  if call.lower() in ("count","count(*)"): return int(len(df))
109
- m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', call)
110
  func, arg = m.group(1).lower(), m.group(2).strip()
111
- if arg not in df.columns: raise ValueError(f"Unknown column: {arg}")
112
  col = df[arg].dropna()
113
  if func in ("avg","mean"): return float(np.mean(col)) if len(col) else float("nan")
114
  if func == "median": return float(np.median(col)) if len(col) else float("nan")
115
  if func == "sum": return float(np.sum(col)) if len(col) else 0.0
116
  if func in ("min","max"): return float(getattr(np, func)(col)) if len(col) else float("nan")
117
  if func.startswith("p") and func[1:].isdigit(): return float(np.percentile(col, int(func[1:]))) if len(col) else float("nan")
118
- raise ValueError(f"Unsupported agg: {func}")
119
 
120
  @staticmethod
121
  def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
@@ -128,24 +149,12 @@ class ScenarioEngine:
128
  for keys, g in gb:
129
  if not isinstance(keys, tuple): keys = (keys,)
130
  rec = {group_by[i]: keys[i] for i in range(len(group_by))}
131
- if aggs:
132
- for out_col, call in aggs: rec[out_col] = ScenarioEngine._apply_agg_call(g, call)
133
- else:
134
- rec["count"] = len(g)
135
  rows.append(rec)
136
  return pd.DataFrame(rows)
137
 
138
- @staticmethod
139
- def _pivot(df: pd.DataFrame, spec: str) -> pd.DataFrame:
140
- parts = dict(re.findall(r'(\w+)\s*=\s*([^\s,]+)', spec))
141
- idx = [x.strip() for x in parts.get("index","").split(",") if x.strip()]
142
- cols = parts.get("columns"); vals = parts.get("values")
143
- if not (idx and cols and vals): raise ValueError("pivot requires index=.. columns=.. values=..")
144
- pv = df.pivot_table(index=idx, columns=cols, values=vals, aggfunc="first").reset_index()
145
- if isinstance(pv.columns, pd.MultiIndex):
146
- pv.columns = ["_".join([str(c) for c in tup if c!=""]) for tup in pv.columns]
147
- return pv
148
-
149
  @staticmethod
150
  def _render_table(df: pd.DataFrame) -> str:
151
  if df.empty: return "_No rows._"
@@ -157,119 +166,36 @@ class ScenarioEngine:
157
  rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
158
  return "\n".join([header, sep, *rows])
159
 
160
- @staticmethod
161
- def _render_list(df: pd.DataFrame) -> str:
162
- if df.empty: return "_No items._"
163
- primary = df.columns[0]
164
- lines = []
165
- for i, row in enumerate(df.itertuples(index=False), 1):
166
- extras = [f"{c}: {getattr(row,c)}" for c in df.columns if c != primary]
167
- lines.append(f"{i}. {getattr(row, primary)}" + (f" ({', '.join(extras)})" if extras else ""))
168
- return "\n".join(lines)
169
-
170
- @staticmethod
171
- def _render_comparison(df: pd.DataFrame) -> str:
172
- cols = {c.lower(): c for c in df.columns}
173
- cur = cols.get("current") or cols.get("now") or cols.get("value")
174
- prev = cols.get("previous") or cols.get("prior") or cols.get("past")
175
- name = cols.get("name") or cols.get("metric") or cols.get("item") or df.columns[0]
176
- if not (cur and prev): return "_Comparison requires 'current' and 'previous' columns._"
177
- header = "| Item | Current | Previous | Change |"; sep="|---|---:|---:|---:|"; body=[]
178
- for _, r in df.iterrows():
179
- c, p = r[cur], r[prev]
180
- ch = (c - p) if isinstance(c,(int,float)) and isinstance(p,(int,float)) else "N/A"
181
- body.append(f"| {r[name]} | {c} | {p} | {ch} |")
182
- return "\n".join([header, sep, *body])
183
-
184
- @staticmethod
185
- def _render_map(df: pd.DataFrame) -> str:
186
- col = {c.lower(): c for c in df.columns}
187
- name = col.get("facility") or col.get("name") or df.columns[0]
188
- lat = col.get("latitude") or col.get("lat"); lon = col.get("longitude") or col.get("lon")
189
- zone = col.get("zone"); city = col.get("city")
190
- show = [x for x in [name, city, zone, lat, lon] if x]
191
- if not show: return "_No geographic fields._"
192
- tmp = df[show].copy()
193
- if lat and lon:
194
- tmp["coordinates"] = tmp[lat].astype(str) + ", " + tmp[lon].astype(str)
195
- show = [name, city or "city", zone or "zone", "coordinates"]
196
- return ScenarioEngine._render_table(tmp[show])
197
-
198
- @staticmethod
199
- def _render_chart(df: pd.DataFrame, d: Dict[str, Any]) -> str:
200
- mark = d.get("chart","bar")
201
- spec = {
202
- "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
203
- "description": d.get("title") or "Chart",
204
- "data": {"values": df.to_dict(orient="records")},
205
- "mark": mark, "encoding": {}
206
- }
207
- for enc in ("x","y","color","column"):
208
- if enc in d and d[enc] in df.columns:
209
- spec["encoding"][enc] = {"field": d[enc], "type": "quantitative" if pd.api.types.is_numeric_dtype(df[d[enc]]) else "nominal"}
210
- return "```vega-lite\n" + json.dumps(spec, ensure_ascii=False, indent=2) + "\n```"
211
-
212
  @staticmethod
213
  def _exec_task(t: TaskPlan, datasets: Dict[str, Any]) -> str:
214
- section = [f"## {t.title_override or t.title}\n"]
215
  df = ScenarioEngine._get_df(datasets, t.data_key)
216
  if df is None or df.empty:
217
- section += ["_No matching data for this task._", "\n**Provenance**", f"- Data key: `{t.data_key or 'auto'}`"]
218
  return "\n".join(section)
219
 
 
 
 
 
220
  if t.filter: df = ScenarioEngine._apply_filter(df, t.filter)
221
  if t.derive:
222
  for d in t.derive: df = ScenarioEngine._apply_derive(df, d)
223
- if t.joins:
224
- for j in t.joins:
225
- rk, lo, ro, how = j["right_key"], j["left_on"], j["right_on"], j.get("how","left").lower()
226
- r = ScenarioEngine._as_df(datasets.get(rk))
227
- if r is not None:
228
- df = df.merge(r, left_on=lo, right_on=ro, how=how)
229
  if t.group_by or t.agg:
230
  df = ScenarioEngine._group_agg(df, t.group_by, ", ".join(t.agg or []))
231
- if t.pivot:
232
- spec = t.pivot
233
- df = ScenarioEngine._pivot(df, f"index={','.join(spec.get('index', []))} columns={spec['columns']} values={spec['values']}")
234
  if t.sort_by and t.sort_by in df.columns:
235
  df = df.sort_values(by=t.sort_by, ascending=(t.sort_dir or "desc").lower()=="asc")
236
- if t.top and t.top>0: df = df.head(t.top)
 
 
 
237
  if t.fields:
238
  cols = resolve_cols(t.fields, df.columns.tolist())
239
  cols = [c for c in cols if c in df.columns]
240
  if cols: df = df[cols]
241
 
242
- if t.number_format:
243
- for col, fmt in t.number_format.items():
244
- if col in df.columns:
245
- if fmt.endswith("%"):
246
- decimals = len(fmt.split(".")[-1].rstrip("%")) if "." in fmt else 0
247
- df[col] = (df[col].astype(float) * 100).round(decimals).astype(str) + "%"
248
- else:
249
- try:
250
- decimals = int(fmt.split(".")[-1]) if "." in fmt else 0
251
- df[col] = df[col].astype(float).round(decimals)
252
- except Exception:
253
- pass
254
-
255
- fmt = (t.format or "table").lower()
256
- if fmt == "list": body = ScenarioEngine._render_list(df)
257
- elif fmt == "comparison": body = ScenarioEngine._render_comparison(df)
258
- elif fmt == "map": body = ScenarioEngine._render_map(df)
259
- elif fmt == "chart":
260
- enc = t.encodings or {}
261
- d = {"chart": t.chart or "bar", **enc}
262
- body = ScenarioEngine._render_chart(df, d)
263
- elif fmt == "narrative":
264
- lines = []
265
- for i, rec in enumerate(df.to_dict(orient="records"), 1):
266
- parts = [f"**{k}**: {v}" for k, v in rec.items()]
267
- lines.append(f"{i}. " + "; ".join(parts))
268
- body = "\n".join(lines) if lines else "_No content._"
269
- else:
270
- body = ScenarioEngine._render_table(df)
271
-
272
- section.append(body)
273
- section.append("\n**Provenance**")
274
- section.append(f"- Data key: `{t.data_key or 'auto'}`")
275
  return "\n".join(section)
 
 
1
+ # scenario_engine.py
2
  from __future__ import annotations
3
  from typing import Dict, List, Any, Tuple, Optional
4
  import re, math, json, ast
 
7
  from schema import ScenarioPlan, TaskPlan
8
  from column_resolver import resolve_cols
9
 
10
+ # Allowed safe functions
11
  _ALLOWED_FUNCS = {
12
  "abs": abs, "round": round, "sqrt": math.sqrt, "log": math.log, "exp": math.exp,
13
  "min": np.minimum, "max": np.maximum,
 
17
  "p90": lambda x: np.percentile(x, 90), "p95": lambda x: np.percentile(x, 95),
18
  }
19
 
20
+ # -------- SAFE EXPRESSION PARSER --------
21
  class _SafeExpr(ast.NodeTransformer):
22
+ def __init__(self, allowed: set): self.allowed = allowed
23
  def visit_Name(self, node):
24
+ if node.id not in self.allowed and node.id not in ("True","False","None"):
25
  raise ValueError(f"Unknown name: {node.id}")
26
  return node
27
  def visit_Call(self, node):
28
+ if not isinstance(node.func, ast.Name):
29
+ raise ValueError("Only simple calls allowed")
30
+ if node.func.id not in _ALLOWED_FUNCS:
31
+ raise ValueError(f"Function not allowed: {node.func.id}")
32
  return self.generic_visit(node)
33
  def generic_visit(self, node):
34
+ allowed = (
35
+ ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp, ast.Compare, ast.Call, ast.Name,
36
+ ast.Load, ast.Constant, ast.And, ast.Or, ast.Not, ast.Add, ast.Sub, ast.Mult, ast.Div,
37
+ ast.Mod, ast.Pow, ast.FloorDiv, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE,
38
+ ast.USub, ast.UAdd
39
+ )
40
  if not isinstance(node, allowed):
41
  raise ValueError(f"Unsupported syntax: {type(node).__name__}")
42
  return super().generic_visit(node)
 
49
  env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
50
  return eval(code, {"__builtins__": {}}, env)
51
 
52
+ # -------- COLUMN ROLE RESOLVER --------
53
+ SEMANTIC_ROLES = {
54
+ "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "settlement", "community"],
55
+ "zone": ["zone", "region", "area", "district"],
56
+ "specialty": ["specialty", "service", "program", "discipline"],
57
+ "city": ["city", "town", "village"],
58
+ "lat": ["latitude", "lat"],
59
+ "lon": ["longitude", "lon", "lng"],
60
+ }
61
+
62
+ def resolve_role(df: pd.DataFrame, role: str) -> Optional[str]:
63
+ """Find the best matching column for a semantic role."""
64
+ candidates = SEMANTIC_ROLES.get(role, [])
65
+ lower_cols = {c.lower(): c for c in df.columns}
66
+ for cand in candidates:
67
+ for col_lc, col in lower_cols.items():
68
+ if cand in col_lc:
69
+ return col
70
+ return None
71
+
72
+ # -------- MAIN ENGINE --------
73
  class ScenarioEngine:
74
  @staticmethod
75
  def _as_df(v: Any) -> Optional[pd.DataFrame]:
76
  if isinstance(v, list):
77
+ return pd.DataFrame(v) if v else pd.DataFrame()
 
78
  if isinstance(v, dict):
79
+ return pd.DataFrame([v]) if all(isinstance(val, (int,float,str,bool,type(None))) for val in v.values()) else pd.DataFrame()
80
+ if isinstance(v, pd.DataFrame):
81
+ return v
 
 
 
 
 
82
  return None
83
 
84
  @staticmethod
 
105
  def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
106
  parts = re.split(r'[;,]\s*', spec)
107
  for p in parts:
108
+ if "=" in p:
109
+ col, expr = p.split("=", 1)
110
+ df[col.strip()] = _eval_series_expr(expr.strip(), df)
111
  return df
112
 
113
  @staticmethod
114
  def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
115
  if not spec: return []
 
116
  out = []
117
+ for it in [x.strip() for x in spec.split(",") if x.strip()]:
118
+ if it.lower() in ("count","count(*)"):
119
+ out.append(("count","count(*)")); continue
120
+ m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', it)
121
+ if not m: continue
122
  func, arg = m.group(1).lower(), m.group(2).strip()
123
  out.append((f"{func}_{arg}", f"{func}({arg})"))
124
  return out
 
127
  def _apply_agg_call(df: pd.DataFrame, call: str):
128
  call = call.strip()
129
  if call.lower() in ("count","count(*)"): return int(len(df))
130
+ m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', call)
131
  func, arg = m.group(1).lower(), m.group(2).strip()
132
+ if arg not in df.columns: return None
133
  col = df[arg].dropna()
134
  if func in ("avg","mean"): return float(np.mean(col)) if len(col) else float("nan")
135
  if func == "median": return float(np.median(col)) if len(col) else float("nan")
136
  if func == "sum": return float(np.sum(col)) if len(col) else 0.0
137
  if func in ("min","max"): return float(getattr(np, func)(col)) if len(col) else float("nan")
138
  if func.startswith("p") and func[1:].isdigit(): return float(np.percentile(col, int(func[1:]))) if len(col) else float("nan")
139
+ return None
140
 
141
  @staticmethod
142
  def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
 
149
  for keys, g in gb:
150
  if not isinstance(keys, tuple): keys = (keys,)
151
  rec = {group_by[i]: keys[i] for i in range(len(group_by))}
152
+ for out_col, call in aggs:
153
+ rec[out_col] = ScenarioEngine._apply_agg_call(g, call)
 
 
154
  rows.append(rec)
155
  return pd.DataFrame(rows)
156
 
157
+ # -------- RENDERERS --------
 
 
 
 
 
 
 
 
 
 
158
  @staticmethod
159
  def _render_table(df: pd.DataFrame) -> str:
160
  if df.empty: return "_No rows._"
 
166
  rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
167
  return "\n".join([header, sep, *rows])
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  @staticmethod
170
  def _exec_task(t: TaskPlan, datasets: Dict[str, Any]) -> str:
171
+ section = [f"## {t.title}\n"]
172
  df = ScenarioEngine._get_df(datasets, t.data_key)
173
  if df is None or df.empty:
174
+ section.append("_No matching data for this task._")
175
  return "\n".join(section)
176
 
177
+ # Resolve semantic roles dynamically
178
+ if t.group_by:
179
+ t.group_by = resolve_cols(t.group_by, df.columns.tolist())
180
+
181
  if t.filter: df = ScenarioEngine._apply_filter(df, t.filter)
182
  if t.derive:
183
  for d in t.derive: df = ScenarioEngine._apply_derive(df, d)
184
+
 
 
 
 
 
185
  if t.group_by or t.agg:
186
  df = ScenarioEngine._group_agg(df, t.group_by, ", ".join(t.agg or []))
187
+
 
 
188
  if t.sort_by and t.sort_by in df.columns:
189
  df = df.sort_values(by=t.sort_by, ascending=(t.sort_dir or "desc").lower()=="asc")
190
+
191
+ if t.top and t.top > 0:
192
+ df = df.head(t.top)
193
+
194
  if t.fields:
195
  cols = resolve_cols(t.fields, df.columns.tolist())
196
  cols = [c for c in cols if c in df.columns]
197
  if cols: df = df[cols]
198
 
199
+ section.append(ScenarioEngine._render_table(df))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  return "\n".join(section)
201
+