Rajan Sharma committed
Commit 16b5d3f · verified · 1 Parent(s): 6953f37

Update scenario_engine.py

Files changed (1)
  1. scenario_engine.py +356 -132
scenario_engine.py CHANGED
@@ -1,32 +1,94 @@
  # scenario_engine.py
+ # scenario_engine.py
  from __future__ import annotations
- from typing import Dict, List, Any, Tuple, Optional
- import re, math, json, ast
+ from typing import Dict, List, Any, Tuple, Optional, Iterable
+ import re, math, ast
  import numpy as np
  import pandas as pd
- from schema import ScenarioPlan, TaskPlan
- from column_resolver import resolve_cols

- # Allowed safe functions
+ # Optional import from column_resolver.py (recommended).
+ # If it's not available, we define light fallbacks so the engine still works.
+ try:
+     from column_resolver import resolve_one, resolve_cols  # full resolver (headers + synonyms)
+ except Exception:
+     # ---- Minimal, schema-agnostic fallback (headers-only; safe, no hard-coding) ----
+     _ROLE_SYNONYMS = {
+         "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "provider",
+                      "settlement", "community", "location"],
+         "community": ["community", "settlement", "reserve", "town", "village", "city", "region", "area"],
+         "zone": ["zone", "region", "district", "area", "healthzone"],
+         "specialty": ["specialty", "programme", "program", "service", "discipline", "department"],
+         "period": ["period", "quarter", "year", "month", "time", "fiscal", "date"],
+         "city": ["city", "town", "village"],
+         "lat": ["latitude", "lat"],
+         "lon": ["longitude", "lon", "lng"],
+         "population": ["population", "members", "residents", "census"],
+         "prevalence": ["prevalence", "rate", "risk", "pct", "percentage"],
+         "volume": ["count", "visits", "clients", "volume", "n", "cases"],
+         "cost": ["cost", "expense", "spend", "budget", "perclient", "startup"],
+         "capacity": ["capacity", "throughput", "slots", "dailycapacity", "clientsperday"],
+     }
+     def _canon(s: str) -> str:
+         return re.sub(r"[^a-z0-9]+", "", (s or "").lower())
+
+     def resolve_one(want: str, columns: Iterable[str]) -> Optional[str]:
+         cols = list(columns or [])
+         if not cols:
+             return None
+         w = _canon(want or "")
+         if not w:
+             return None
+         canon_cols = { _canon(c): c for c in cols }
+         if w in canon_cols:
+             return canon_cols[w]
+         syns = _ROLE_SYNONYMS.get(want.lower(), [])
+         syns_canon = [_canon(s) for s in syns]
+         # Try synonyms exact/startswith/contains
+         best, score = None, -1
+         for c in cols:
+             cc = _canon(c)
+             sc = 0
+             if w and (cc == w or cc.startswith(w) or w in cc): sc += 3
+             for s in syns_canon:
+                 if cc == s: sc += 5
+                 elif cc.startswith(s): sc += 3
+                 elif s in cc: sc += 2
+             if sc > score:
+                 best, score = c, sc
+         return best if score >= 2 else None
+
+     def resolve_cols(requested: Iterable[str], columns: Iterable[str]) -> List[str]:
+         out, seen = [], set()
+         for r in requested or []:
+             col = resolve_one(r, columns)
+             if col and col not in seen:
+                 out.append(col); seen.add(col)
+         return out
+
+ # ---------- Safe expression evaluation (filters/derivations) ----------
  _ALLOWED_FUNCS = {
-     "abs": abs, "round": round, "sqrt": math.sqrt, "log": math.log, "exp": math.exp,
+     "abs": abs, "round": round,
+     "sqrt": np.sqrt, "log": np.log, "exp": np.exp,
      "min": np.minimum, "max": np.maximum,
      "mean": np.mean, "avg": np.mean, "median": np.median, "sum": np.sum,
      "count": lambda x: np.size(x),
-     "p50": lambda x: np.percentile(x, 50), "p75": lambda x: np.percentile(x, 75),
-     "p90": lambda x: np.percentile(x, 90), "p95": lambda x: np.percentile(x, 95),
+     "p50": lambda x: np.percentile(x, 50),
+     "p75": lambda x: np.percentile(x, 75),
+     "p90": lambda x: np.percentile(x, 90),
+     "p95": lambda x: np.percentile(x, 95),
  }

- # -------- SAFE EXPRESSION PARSER --------
  class _SafeExpr(ast.NodeTransformer):
-     def __init__(self, allowed: set): self.allowed = allowed
+     def __init__(self, allowed): self.allowed = allowed
      def visit_Name(self, node):
          if node.id not in self.allowed and node.id not in ("True","False","None"):
              raise ValueError(f"Unknown name: {node.id}")
          return node
+     def visit_Attribute(self, node):
+         raise ValueError("Attribute access is not allowed")
      def visit_Call(self, node):
          if not isinstance(node.func, ast.Name):
-             raise ValueError("Only simple calls allowed")
+             raise ValueError("Only simple function calls are allowed")
          if node.func.id not in _ALLOWED_FUNCS:
              raise ValueError(f"Function not allowed: {node.func.id}")
          return self.generic_visit(node)
@@ -47,155 +109,317 @@ def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
      _SafeExpr(names).visit(tree)
      code = compile(tree, "<expr>", "eval")
      env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
-     return eval(code, {"__builtins__": {}}, env)
-
- # -------- COLUMN ROLE RESOLVER --------
- SEMANTIC_ROLES = {
-     "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "settlement", "community"],
-     "zone": ["zone", "region", "area", "district"],
-     "specialty": ["specialty", "service", "program", "discipline"],
-     "city": ["city", "town", "village"],
-     "lat": ["latitude", "lat"],
-     "lon": ["longitude", "lon", "lng"],
- }
+     val = eval(code, {"__builtins__": {}}, env)
+     if isinstance(val, (pd.Series, np.ndarray, list)):
+         return pd.Series(val, index=df.index)
+     if isinstance(val, (bool, np.bool_)):
+         return pd.Series([val] * len(df), index=df.index)
+     raise ValueError("Filter/derive expression must yield a vector or boolean")

- def resolve_role(df: pd.DataFrame, role: str) -> Optional[str]:
-     """Find the best matching column for a semantic role."""
-     candidates = SEMANTIC_ROLES.get(role, [])
-     lower_cols = {c.lower(): c for c in df.columns}
-     for cand in candidates:
-         for col_lc, col in lower_cols.items():
-             if cand in col_lc:
-                 return col
+ # ---------- Helpers ----------
+ def _as_df(v: Any) -> Optional[pd.DataFrame]:
+     if isinstance(v, pd.DataFrame):
+         return v
+     if isinstance(v, list):
+         return pd.DataFrame(v) if v else pd.DataFrame()
+     if isinstance(v, dict):
+         flat = all(isinstance(val, (int,float,str,bool,type(None))) for val in v.values())
+         return pd.DataFrame([v]) if flat else pd.DataFrame()
      return None

- # -------- MAIN ENGINE --------
- class ScenarioEngine:
-     @staticmethod
-     def _as_df(v: Any) -> Optional[pd.DataFrame]:
-         if isinstance(v, list):
-             return pd.DataFrame(v) if v else pd.DataFrame()
-         if isinstance(v, dict):
-             return pd.DataFrame([v]) if all(isinstance(val, (int,float,str,bool,type(None))) for val in v.values()) else pd.DataFrame()
-         if isinstance(v, pd.DataFrame):
-             return v
-         return None
+ def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
+     if key and key in datasets:
+         v = datasets[key]
+     else:
+         v = next((vv for vv in datasets.values() if vv is not None), None)
+     return _as_df(v) if v is not None else None

-     @staticmethod
-     def execute_plan(plan: ScenarioPlan, datasets: Dict[str, Any]) -> str:
-         sections: List[str] = ["# Scenario Output\n"]
-         for t in plan.tasks:
-             sections.append(ScenarioEngine._exec_task(t, datasets))
-         return "\n".join(sections).strip()
+ def _auto_group_cols(df: pd.DataFrame) -> List[str]:
+     prefs = ["facility","community","settlement","provider","zone","region","district","specialty","program","service","city"]
+     resolved = []
+     for p in prefs:
+         col = resolve_one(p, df.columns)
+         if col and col not in resolved:
+             resolved.append(col)
+     if resolved:
+         return [resolved[0]]
+     obj_cols = [c for c in df.columns if df[c].dtype == "object"]
+     return obj_cols[:1] if obj_cols else []

-     @staticmethod
-     def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFrame]:
-         if key and key in datasets:
-             v = datasets[key]
-         else:
-             v = next((vv for vv in datasets.values() if isinstance(vv, (list, dict, pd.DataFrame))), None)
-         return ScenarioEngine._as_df(v) if v is not None else None
+ def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
+     """
+     "mean(wait_days), p90(wait_days), count(*)" -> [("mean_wait_days","mean(wait_days)"), ...]
+     bare token "wait_days" becomes mean(wait_days)
+     """
+     if not spec:
+         return []
+     out: List[Tuple[str,str]] = []
+     for it in [x.strip() for x in spec.split(",") if x.strip()]:
+         if it.lower() in ("count", "count(*)"):
+             out.append(("count_*", "count(*)")); continue
+         m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', it)
+         if not m:
+             arg = it
+             out.append((f"mean_{arg}", f"mean({arg})"))
+             continue
+         func, arg = m.group(1).lower(), m.group(2).strip()
+         out.append((f"{func}_{arg}", f"{func}({arg})"))
+     return out

-     @staticmethod
-     def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
-         m = _eval_series_expr(expr, df)
-         return df.loc[m.astype(bool)].copy()
+ def _apply_agg_call(df: pd.DataFrame, call: str):
+     call = call.strip().lower()
+     if call in ("count", "count(*)"):
+         return int(len(df))
+     m = re.match(r'([a-z_][a-z0-9_]*)\(([^)]+)\)', call)
+     if not m:
+         arg = call
+         if arg not in df.columns: return None
+         col = pd.to_numeric(df[arg], errors="coerce").dropna()
+         return float(col.mean()) if len(col) else float("nan")
+     func, arg = m.group(1), m.group(2).strip()
+     if arg not in df.columns:
+         return None
+     col = pd.to_numeric(df[arg], errors="coerce").dropna()
+     if not len(col):
+         return float("nan")
+     if func in ("avg","mean"): return float(col.mean())
+     if func == "median": return float(np.median(col))
+     if func == "sum": return float(col.sum())
+     if func in ("min","max"): return float(getattr(np, func)(col))
+     if func.startswith("p") and func[1:].isdigit(): return float(np.percentile(col, int(func[1:])))
+     return None

-     @staticmethod
-     def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
-         parts = re.split(r'[;,]\s*', spec)
-         for p in parts:
-             if "=" in p:
-                 col, expr = p.split("=", 1)
-                 df[col.strip()] = _eval_series_expr(expr.strip(), df)
-         return df
+ def _apply_filter(df: pd.DataFrame, expr: str) -> pd.DataFrame:
+     m = _eval_series_expr(expr, df)
+     return df.loc[m.astype(bool)].copy()

-     @staticmethod
-     def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
-         if not spec: return []
-         out = []
-         for it in [x.strip() for x in spec.split(",") if x.strip()]:
-             if it.lower() in ("count","count(*)"):
-                 out.append(("count","count(*)")); continue
-             m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', it)
-             if not m: continue
-             func, arg = m.group(1).lower(), m.group(2).strip()
-             out.append((f"{func}_{arg}", f"{func}({arg})"))
-         return out
+ def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
+     # supports "newcol = expr, other = expr2"
+     parts = re.split(r'[;,]\s*', spec)
+     for p in parts:
+         if "=" in p:
+             col, expr = p.split("=", 1)
+             df[col.strip()] = _eval_series_expr(expr.strip(), df)
+     return df

-     @staticmethod
-     def _apply_agg_call(df: pd.DataFrame, call: str):
-         call = call.strip()
-         if call.lower() in ("count","count(*)"): return int(len(df))
-         m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]+)\)', call)
-         func, arg = m.group(1).lower(), m.group(2).strip()
-         if arg not in df.columns: return None
-         col = df[arg].dropna()
-         if func in ("avg","mean"): return float(np.mean(col)) if len(col) else float("nan")
-         if func == "median": return float(np.median(col)) if len(col) else float("nan")
-         if func == "sum": return float(np.sum(col)) if len(col) else 0.0
-         if func in ("min","max"): return float(getattr(np, func)(col)) if len(col) else float("nan")
-         if func.startswith("p") and func[1:].isdigit(): return float(np.percentile(col, int(func[1:]))) if len(col) else float("nan")
+ def _render_table(df: pd.DataFrame) -> str:
+     if df is None or df.empty:
+         return "_No rows._"
+     dff = df.copy()
+     for c in dff.columns:
+         if pd.api.types.is_float_dtype(dff[c]) or pd.api.types.is_integer_dtype(dff[c]):
+             dff[c] = dff[c].apply(lambda v: "NaN" if (isinstance(v,float) and math.isnan(v)) else f"{v:,.4g}")
+     header = "| " + " | ".join(map(str, dff.columns)) + " |"
+     sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
+     rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
+     return "\n".join([header, sep, *rows])
+
+ def _small_n_flags(df: pd.DataFrame, count_col: Optional[str] = None, threshold: int = 5) -> Optional[pd.Series]:
+     if df is None or df.empty:
          return None
+     if count_col and count_col in df.columns:
+         return df[count_col].apply(lambda n: " (interpret cautiously: small n)" if pd.notnull(n) and float(n) < threshold else "")
+     # Fallback if no explicit count column—don’t guess
+     return None
+
+ def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
+     notes = []
+     for c in metric_cols:
+         if c in df.columns:
+             miss = df[c].isna().mean()
+             if miss > 0:
+                 notes.append(f"{c}: missing {miss:.1%}")
+     return notes

+ # ---------- Scenario Engine ----------
+ class ScenarioEngine:
+     """
+     Execute a ScenarioPlan (or dict) consisting of tasks that specify:
+     - data_key: name of dataset in `datasets`
+     - filter: boolean/vectorized expression (safe-eval)
+     - derive: "new = expr, ..."
+     - group_by: list of roles/column names (resolved dynamically)
+     - agg: "mean(col), p90(col), count(*)" (bare 'col' => mean(col))
+     - sort_by / sort_dir
+     - top
+     - fields: project/alias output columns by role/name (resolved dynamically)
+     Returns markdown with:
+     - task section
+     - table output
+     - Assumptions & Mappings
+     - Data Quality notes
+     """
      @staticmethod
-     def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
-         aggs = ScenarioEngine._parse_aggs(agg_spec)
-         if not aggs and not group_by: return df
-         if not group_by:
-             return pd.DataFrame([{k: ScenarioEngine._apply_agg_call(df, call) for k, call in aggs}])
+     def _group_agg(df: pd.DataFrame,
+                    group_by: Optional[List[str]],
+                    agg_spec: Optional[str],
+                    mapping_log: List[str]) -> pd.DataFrame:
+         # Resolve grouping to existing columns; tolerate roles or wrong names
+         if group_by:
+             gcols = resolve_cols(group_by, df.columns)
+             # log role->actual for transparency
+             for want in (group_by or []):
+                 got = resolve_one(want, df.columns)
+                 mapping_log.append(f"group_by: {want} → {got if got else '(unresolved)'}")
+         else:
+             gcols = _auto_group_cols(df)
+             if gcols:
+                 mapping_log.append(f"group_by: (auto) → {gcols[0]}")
+             else:
+                 mapping_log.append("group_by: (auto) → (none)")
+
+         # If no grouping and no aggregations → return df as-is (trim wide frames)
+         aggs = _parse_aggs(agg_spec or "")
+         if not gcols:
+             if not aggs:
+                 # Keep a reasonable view: first 50 rows
+                 return df.head(50).copy()
+             # global aggregate row
+             rec = { out_col: _apply_agg_call(df, call) for out_col, call in aggs }
+             return pd.DataFrame([rec])
+
+         if not aggs:
+             # default: mean of numeric cols + count(*)
+             num_cols = list(df.select_dtypes(include="number").columns)
+             gb = df.groupby(gcols, dropna=False)
+             if not num_cols:
+                 out = gb.size().reset_index(name="count_*")
+                 return out.sort_values("count_*", ascending=False)
+             out = gb[num_cols].mean(numeric_only=True)
+             out["count_*"] = gb.size()
+             return out.reset_index()
+
+         # Apply requested aggs
          rows = []
-         gb = df.groupby(group_by, dropna=False)
+         gb = df.groupby(gcols, dropna=False)
          for keys, g in gb:
              if not isinstance(keys, tuple): keys = (keys,)
-             rec = {group_by[i]: keys[i] for i in range(len(group_by))}
+             rec = { gcols[i]: keys[i] for i in range(len(gcols)) }
              for out_col, call in aggs:
-                 rec[out_col] = ScenarioEngine._apply_agg_call(g, call)
+                 rec[out_col] = _apply_agg_call(g, call)
              rows.append(rec)
          return pd.DataFrame(rows)

-     # -------- RENDERERS --------
      @staticmethod
-     def _render_table(df: pd.DataFrame) -> str:
-         if df.empty: return "_No rows._"
-         dff = df.copy()
-         for c in dff.columns:
-             dff[c] = dff[c].apply(lambda v: "NaN" if (isinstance(v,float) and math.isnan(v)) else f"{v:,.4g}" if isinstance(v,float) else v)
-         header = "| " + " | ".join(dff.columns) + " |"
-         sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
-         rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
-         return "\n".join([header, sep, *rows])
+     def _project_fields(out_df: pd.DataFrame,
+                         fields: Optional[List[str]],
+                         mapping_log: List[str]) -> pd.DataFrame:
+         if not isinstance(out_df, pd.DataFrame) or out_df.empty or not fields:
+             return out_df
+         cols = resolve_cols(fields, out_df.columns)
+         for want in fields:
+             got = resolve_one(want, out_df.columns)
+             mapping_log.append(f"field: {want} → {got if got else '(unresolved)'}")
+         if cols:
+             return out_df[cols]
+         return out_df
+
+     @staticmethod
+     def _data_quality_notes(out_df: pd.DataFrame) -> List[str]:
+         notes: List[str] = []
+         if out_df is None or out_df.empty:
+             return notes
+         # small-n flag if a count column exists
+         cnt_col = None
+         for c in out_df.columns:
+             if c.lower() in ("count", "count_*", "n", "records"):
+                 cnt_col = c; break
+         sn = _small_n_flags(out_df, count_col=cnt_col, threshold=5)
+         if sn is not None and sn.any():
+             n_small = (sn != "").sum()
+             if n_small > 0:
+                 notes.append(f"{n_small} row(s) flagged as small-n (interpret cautiously).")
+         # missingness for numeric columns
+         metric_cols = [c for c in out_df.columns if pd.api.types.is_numeric_dtype(out_df[c])]
+         notes.extend(_missingness(out_df, metric_cols))
+         return notes

      @staticmethod
-     def _exec_task(t: TaskPlan, datasets: Dict[str, Any]) -> str:
-         section = [f"## {t.title}\n"]
-         df = ScenarioEngine._get_df(datasets, t.data_key)
+     def _exec_task(t: Any, datasets: Dict[str, Any]) -> str:
+         # tolerate dict-like tasks or dataclass
+         title = getattr(t, "title", None) or (isinstance(t, dict) and t.get("title")) or "Task"
+         section_lines: List[str] = [f"## {title}\n"]
+
+         data_key = getattr(t, "data_key", None) or (isinstance(t, dict) and t.get("data_key"))
+         df = _get_df(datasets, data_key)
          if df is None or df.empty:
-             section.append("_No matching data for this task._")
-             return "\n".join(section)
+             section_lines.append("_No matching data for this task._")
+             return "\n".join(section_lines)
+
+         # Optional filter(s)
+         t_filter = getattr(t, "filter", None) or (isinstance(t, dict) and t.get("filter"))
+         if t_filter:
+             try:
+                 df = _apply_filter(df, t_filter)
+             except Exception as e:
+                 section_lines.append(f"_Warning: filter ignored ({e})._")
+
+         # Optional derive(s)
+         t_derive = getattr(t, "derive", None) or (isinstance(t, dict) and t.get("derive"))
+         if t_derive:
+             for d in (t_derive if isinstance(t_derive, (list, tuple)) else [t_derive]):
+                 try:
+                     df = _apply_derive(df, d)
+                 except Exception as e:
+                     section_lines.append(f"_Warning: derive ignored ({e})._")
+
+         # Group/Aggregate
+         t_group_by = getattr(t, "group_by", None) or (isinstance(t, dict) and t.get("group_by"))
+         # allow single string in plans
+         if isinstance(t_group_by, str):
+             t_group_by = [t_group_by]
+         t_agg = getattr(t, "agg", None) or (isinstance(t, dict) and t.get("agg"))
+         if isinstance(t_agg, list):
+             agg_spec = ", ".join(t_agg)
+         else:
+             agg_spec = (t_agg or None)
+
+         mapping_log: List[str] = []
+         out_df = ScenarioEngine._group_agg(df, t_group_by, agg_spec, mapping_log)

-         # Resolve semantic roles dynamically
-         if t.group_by:
-             t.group_by = resolve_cols(t.group_by, df.columns.tolist())
+         # Sort / Top
+         t_sort_by = getattr(t, "sort_by", None) or (isinstance(t, dict) and t.get("sort_by"))
+         t_sort_dir = (getattr(t, "sort_dir", None) or (isinstance(t, dict) and t.get("sort_dir")) or "desc").lower()
+         if t_sort_by and isinstance(out_df, pd.DataFrame) and t_sort_by in out_df.columns:
+             out_df = out_df.sort_values(t_sort_by, ascending=(t_sort_dir=="asc"))

-         if t.filter: df = ScenarioEngine._apply_filter(df, t.filter)
-         if t.derive:
-             for d in t.derive: df = ScenarioEngine._apply_derive(df, d)
+         t_top = getattr(t, "top", None) or (isinstance(t, dict) and t.get("top"))
+         if isinstance(t_top, int) and t_top > 0 and isinstance(out_df, pd.DataFrame):
+             out_df = out_df.head(t_top)

-         if t.group_by or t.agg:
-             df = ScenarioEngine._group_agg(df, t.group_by, ", ".join(t.agg or []))
+         # Field projection
+         t_fields = getattr(t, "fields", None) or (isinstance(t, dict) and t.get("fields"))
+         if isinstance(t_fields, str):
+             t_fields = [t_fields]
+         out_df = ScenarioEngine._project_fields(out_df, t_fields, mapping_log)

-         if t.sort_by and t.sort_by in df.columns:
-             df = df.sort_values(by=t.sort_by, ascending=(t.sort_dir or "desc").lower()=="asc")
+         # Render table
+         section_lines.append(_render_table(out_df))

-         if t.top and t.top > 0:
-             df = df.head(t.top)
+         # Assumptions & Mappings
+         if mapping_log:
+             section_lines.append("\n**Assumptions & Mappings**")
+             for line in mapping_log:
+                 section_lines.append(f"- {line}")

-         if t.fields:
-             cols = resolve_cols(t.fields, df.columns.tolist())
-             cols = [c for c in cols if c in df.columns]
-             if cols: df = df[cols]
+         # Data quality
+         dq = ScenarioEngine._data_quality_notes(out_df)
+         if dq:
+             section_lines.append("\n**Data Quality Notes**")
+             for n in dq:
+                 section_lines.append(f"- {n}")

-         section.append(ScenarioEngine._render_table(df))
-         return "\n".join(section)
+         return "\n".join(section_lines)
+
+     @staticmethod
+     def execute_plan(plan: Any, datasets: Dict[str, Any]) -> str:
+         """
+         plan: object or dict with `tasks: List[Task]`
+         Each Task can have: title, data_key, filter, derive, group_by, agg, sort_by, sort_dir, top, fields
+         """
+         sections: List[str] = ["# Scenario Output\n"]
+         tasks = getattr(plan, "tasks", None) or (isinstance(plan, dict) and plan.get("tasks")) or []
+         for t in tasks:
+             sections.append(ScenarioEngine._exec_task(t, datasets))
+         return "\n".join(sections).strip()
425