Rajan Sharma commited on
Commit
eb5677d
·
verified ·
1 Parent(s): 16b5d3f

Update scenario_engine.py

Browse files
Files changed (1) hide show
  1. scenario_engine.py +117 -89
scenario_engine.py CHANGED
@@ -1,71 +1,110 @@
1
  # scenario_engine.py
2
- # scenario_engine.py
3
  from __future__ import annotations
4
  from typing import Dict, List, Any, Tuple, Optional, Iterable
5
  import re, math, ast
6
  import numpy as np
7
  import pandas as pd
8
 
9
- # Optional import from column_resolver.py (recommended).
10
- # If it's not available, we define light fallbacks so the engine still works.
11
  try:
12
- from column_resolver import resolve_one, resolve_cols # full resolver (headers + synonyms)
 
 
13
  except Exception:
14
- # ---- Minimal, schema-agnostic fallback (headers-only; safe, no hard-coding) ----
15
- _ROLE_SYNONYMS = {
16
- "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "provider",
17
- "settlement", "community", "location"],
18
- "community": ["community", "settlement", "reserve", "town", "village", "city", "region", "area"],
19
- "zone": ["zone", "region", "district", "area", "healthzone"],
20
- "specialty": ["specialty", "programme", "program", "service", "discipline", "department"],
21
- "period": ["period", "quarter", "year", "month", "time", "fiscal", "date"],
22
- "city": ["city", "town", "village"],
23
- "lat": ["latitude", "lat"],
24
- "lon": ["longitude", "lon", "lng"],
25
- "population": ["population", "members", "residents", "census"],
26
- "prevalence": ["prevalence", "rate", "risk", "pct", "percentage"],
27
- "volume": ["count", "visits", "clients", "volume", "n", "cases"],
28
- "cost": ["cost", "expense", "spend", "budget", "perclient", "startup"],
29
- "capacity": ["capacity", "throughput", "slots", "dailycapacity", "clientsperday"],
30
- }
31
- def _canon(s: str) -> str:
32
- return re.sub(r"[^a-z0-9]+", "", (s or "").lower())
33
-
34
- def resolve_one(want: str, columns: Iterable[str]) -> Optional[str]:
35
- cols = list(columns or [])
36
- if not cols:
37
- return None
38
- w = _canon(want or "")
39
- if not w:
40
- return None
41
- canon_cols = { _canon(c): c for c in cols }
42
- if w in canon_cols:
43
- return canon_cols[w]
44
- syns = _ROLE_SYNONYMS.get(want.lower(), [])
45
- syns_canon = [_canon(s) for s in syns]
46
- # Try synonyms exact/startswith/contains
47
- best, score = None, -1
48
- for c in cols:
49
- cc = _canon(c)
50
- sc = 0
51
- if w and (cc == w or cc.startswith(w) or w in cc): sc += 3
52
- for s in syns_canon:
53
- if cc == s: sc += 5
54
- elif cc.startswith(s): sc += 3
55
- elif s in cc: sc += 2
56
- if sc > score:
57
- best, score = c, sc
58
- return best if score >= 2 else None
59
-
60
- def resolve_cols(requested: Iterable[str], columns: Iterable[str]) -> List[str]:
61
- out, seen = [], set()
62
- for r in requested or []:
63
- col = resolve_one(r, columns)
64
- if col and col not in seen:
65
- out.append(col); seen.add(col)
66
- return out
67
-
68
- # ---------- Safe expression evaluation (filters/derivations) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  _ALLOWED_FUNCS = {
70
  "abs": abs, "round": round,
71
  "sqrt": np.sqrt, "log": np.log, "exp": np.exp,
@@ -114,9 +153,9 @@ def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
114
  return pd.Series(val, index=df.index)
115
  if isinstance(val, (bool, np.bool_)):
116
  return pd.Series([val] * len(df), index=df.index)
117
- raise ValueError("Filter/derive expression must yield a vector or boolean")
118
 
119
- # ---------- Helpers ----------
120
  def _as_df(v: Any) -> Optional[pd.DataFrame]:
121
  if isinstance(v, pd.DataFrame):
122
  return v
@@ -136,13 +175,10 @@ def _get_df(datasets: Dict[str, Any], key: Optional[str]) -> Optional[pd.DataFra
136
 
137
  def _auto_group_cols(df: pd.DataFrame) -> List[str]:
138
  prefs = ["facility","community","settlement","provider","zone","region","district","specialty","program","service","city"]
139
- resolved = []
140
  for p in prefs:
141
- col = resolve_one(p, df.columns)
142
- if col and col not in resolved:
143
- resolved.append(col)
144
- if resolved:
145
- return [resolved[0]]
146
  obj_cols = [c for c in df.columns if df[c].dtype == "object"]
147
  return obj_cols[:1] if obj_cols else []
148
 
@@ -219,7 +255,6 @@ def _small_n_flags(df: pd.DataFrame, count_col: Optional[str] = None, threshold:
219
  return None
220
  if count_col and count_col in df.columns:
221
  return df[count_col].apply(lambda n: " (interpret cautiously: small n)" if pd.notnull(n) and float(n) < threshold else "")
222
- # Fallback if no explicit count column—don’t guess
223
  return None
224
 
225
  def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
@@ -231,7 +266,7 @@ def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
231
  notes.append(f"{c}: missing {miss:.1%}")
232
  return notes
233
 
234
- # ---------- Scenario Engine ----------
235
  class ScenarioEngine:
236
  """
237
  Execute a ScenarioPlan (or dict) consisting of tasks that specify:
@@ -256,10 +291,9 @@ class ScenarioEngine:
256
  mapping_log: List[str]) -> pd.DataFrame:
257
  # Resolve grouping to existing columns; tolerate roles or wrong names
258
  if group_by:
259
- gcols = resolve_cols(group_by, df.columns)
260
- # log role->actual for transparency
261
  for want in (group_by or []):
262
- got = resolve_one(want, df.columns)
263
  mapping_log.append(f"group_by: {want} → {got if got else '(unresolved)'}")
264
  else:
265
  gcols = _auto_group_cols(df)
@@ -268,13 +302,12 @@ class ScenarioEngine:
268
  else:
269
  mapping_log.append("group_by: (auto) → (none)")
270
 
271
- # If no grouping and no aggregations → return df as-is (trim wide frames)
272
  aggs = _parse_aggs(agg_spec or "")
 
 
273
  if not gcols:
274
  if not aggs:
275
- # Keep a reasonable view: first 50 rows
276
  return df.head(50).copy()
277
- # global aggregate row
278
  rec = { out_col: _apply_agg_call(df, call) for out_col, call in aggs }
279
  return pd.DataFrame([rec])
280
 
@@ -306,9 +339,9 @@ class ScenarioEngine:
306
  mapping_log: List[str]) -> pd.DataFrame:
307
  if not isinstance(out_df, pd.DataFrame) or out_df.empty or not fields:
308
  return out_df
309
- cols = resolve_cols(fields, out_df.columns)
310
  for want in fields:
311
- got = resolve_one(want, out_df.columns)
312
  mapping_log.append(f"field: {want} → {got if got else '(unresolved)'}")
313
  if cols:
314
  return out_df[cols]
@@ -336,7 +369,6 @@ class ScenarioEngine:
336
 
337
  @staticmethod
338
  def _exec_task(t: Any, datasets: Dict[str, Any]) -> str:
339
- # tolerate dict-like tasks or dataclass
340
  title = getattr(t, "title", None) or (isinstance(t, dict) and t.get("title")) or "Task"
341
  section_lines: List[str] = [f"## {title}\n"]
342
 
@@ -346,7 +378,7 @@ class ScenarioEngine:
346
  section_lines.append("_No matching data for this task._")
347
  return "\n".join(section_lines)
348
 
349
- # Optional filter(s)
350
  t_filter = getattr(t, "filter", None) or (isinstance(t, dict) and t.get("filter"))
351
  if t_filter:
352
  try:
@@ -354,7 +386,7 @@ class ScenarioEngine:
354
  except Exception as e:
355
  section_lines.append(f"_Warning: filter ignored ({e})._")
356
 
357
- # Optional derive(s)
358
  t_derive = getattr(t, "derive", None) or (isinstance(t, dict) and t.get("derive"))
359
  if t_derive:
360
  for d in (t_derive if isinstance(t_derive, (list, tuple)) else [t_derive]):
@@ -363,16 +395,12 @@ class ScenarioEngine:
363
  except Exception as e:
364
  section_lines.append(f"_Warning: derive ignored ({e})._")
365
 
366
- # Group/Aggregate
367
  t_group_by = getattr(t, "group_by", None) or (isinstance(t, dict) and t.get("group_by"))
368
- # allow single string in plans
369
  if isinstance(t_group_by, str):
370
  t_group_by = [t_group_by]
371
  t_agg = getattr(t, "agg", None) or (isinstance(t, dict) and t.get("agg"))
372
- if isinstance(t_agg, list):
373
- agg_spec = ", ".join(t_agg)
374
- else:
375
- agg_spec = (t_agg or None)
376
 
377
  mapping_log: List[str] = []
378
  out_df = ScenarioEngine._group_agg(df, t_group_by, agg_spec, mapping_log)
@@ -380,11 +408,11 @@ class ScenarioEngine:
380
  # Sort / Top
381
  t_sort_by = getattr(t, "sort_by", None) or (isinstance(t, dict) and t.get("sort_by"))
382
  t_sort_dir = (getattr(t, "sort_dir", None) or (isinstance(t, dict) and t.get("sort_dir")) or "desc").lower()
383
- if t_sort_by and isinstance(out_df, pd.DataFrame) and t_sort_by in out_df.columns:
384
  out_df = out_df.sort_values(t_sort_by, ascending=(t_sort_dir=="asc"))
385
 
386
  t_top = getattr(t, "top", None) or (isinstance(t, dict) and t.get("top"))
387
- if isinstance(t_top, int) and t_top > 0 and isinstance(out_df, pd.DataFrame):
388
  out_df = out_df.head(t_top)
389
 
390
  # Field projection
@@ -393,7 +421,7 @@ class ScenarioEngine:
393
  t_fields = [t_fields]
394
  out_df = ScenarioEngine._project_fields(out_df, t_fields, mapping_log)
395
 
396
- # Render table
397
  section_lines.append(_render_table(out_df))
398
 
399
  # Assumptions & Mappings
 
1
  # scenario_engine.py
 
2
  from __future__ import annotations
3
  from typing import Dict, List, Any, Tuple, Optional, Iterable
4
  import re, math, ast
5
  import numpy as np
6
  import pandas as pd
7
 
8
+ # ========= Robust role/column resolver (safe with pandas.Index) =========
 
9
  try:
10
+ # If you have an external, richer resolver, we will use it automatically.
11
+ from column_resolver import resolve_one as _ext_resolve_one, resolve_cols as _ext_resolve_cols # type: ignore
12
+ _HAS_EXT_RESOLVER = True
13
  except Exception:
14
+ _HAS_EXT_RESOLVER = False
15
+
16
+ _ROLE_SYNONYMS_FALLBACK = {
17
+ "facility": ["facility", "hospital", "centre", "center", "clinic", "site", "provider",
18
+ "settlement", "community", "location"],
19
+ "community": ["community", "settlement", "reserve", "town", "village", "city", "region", "area"],
20
+ "zone": ["zone", "region", "district", "area", "healthzone"],
21
+ "specialty": ["specialty", "programme", "program", "service", "discipline", "department"],
22
+ "period": ["period", "quarter", "year", "month", "time", "fiscal", "date"],
23
+ "city": ["city", "town", "village"],
24
+ "lat": ["latitude", "lat"],
25
+ "lon": ["longitude", "lon", "lng"],
26
+ "population": ["population", "members", "residents", "census"],
27
+ "prevalence": ["prevalence", "rate", "risk", "pct", "percentage"],
28
+ "volume": ["count", "visits", "clients", "volume", "n", "cases"],
29
+ "cost": ["cost", "expense", "spend", "budget", "perclient", "startup"],
30
+ "capacity": ["capacity", "throughput", "slots", "dailycapacity", "clientsperday"],
31
+ }
32
+
33
+ def _canon(s: str) -> str:
34
+ return re.sub(r"[^a-z0-9]+", "", (s or "").lower())
35
+
36
+ def _to_list(x: Iterable | None) -> List:
37
+ if x is None:
38
+ return []
39
+ try:
40
+ return list(x)
41
+ except Exception:
42
+ return [x]
43
+
44
+ def resolve_one(want: str, columns: Iterable[str]) -> Optional[str]:
45
+ """Return best matching column for a semantic role or exact header. Safe for pandas.Index."""
46
+ cols = _to_list(columns)
47
+ if _HAS_EXT_RESOLVER:
48
+ try:
49
+ return _ext_resolve_one(want, cols)
50
+ except Exception:
51
+ pass
52
+
53
+ if not cols:
54
+ return None
55
+
56
+ wcanon = _canon(want)
57
+ if not wcanon:
58
+ return None
59
+
60
+ canon_cols = { _canon(c): c for c in cols if isinstance(c, str) }
61
+ if wcanon in canon_cols:
62
+ return canon_cols[wcanon]
63
+
64
+ syns = _ROLE_SYNONYMS_FALLBACK.get((want or "").lower(), [])
65
+ syns_canon = [_canon(s) for s in syns]
66
+
67
+ best, score = None, -1
68
+ for c in cols:
69
+ if not isinstance(c, str):
70
+ continue
71
+ cc = _canon(c)
72
+ sc = 0
73
+ if wcanon and (cc == wcanon or cc.startswith(wcanon) or wcanon in cc):
74
+ sc += 3
75
+ for s in syns_canon:
76
+ if not s:
77
+ continue
78
+ if cc == s:
79
+ sc += 5
80
+ elif cc.startswith(s):
81
+ sc += 3
82
+ elif s in cc:
83
+ sc += 2
84
+ if sc > score:
85
+ best, score = c, sc
86
+ return best if score >= 2 else None
87
+
88
+ def resolve_cols(requested: Iterable[str], columns: Iterable[str]) -> List[str]:
89
+ """Resolve a list of roles/headers to existing columns, uniquely. Safe for pandas.Index."""
90
+ reqs = _to_list(requested)
91
+ cols = _to_list(columns)
92
+
93
+ if _HAS_EXT_RESOLVER:
94
+ try:
95
+ return _ext_resolve_cols(reqs, cols)
96
+ except Exception:
97
+ pass
98
+
99
+ out, seen = [], set()
100
+ for r in reqs:
101
+ col = resolve_one(r, cols)
102
+ if col and col not in seen:
103
+ out.append(col)
104
+ seen.add(col)
105
+ return out
106
+
107
+ # ========= Safe expression evaluation (filters/derivations) =========
108
  _ALLOWED_FUNCS = {
109
  "abs": abs, "round": round,
110
  "sqrt": np.sqrt, "log": np.log, "exp": np.exp,
 
153
  return pd.Series(val, index=df.index)
154
  if isinstance(val, (bool, np.bool_)):
155
  return pd.Series([val] * len(df), index=df.index)
156
+ raise ValueError("Expression must yield a vector or boolean")
157
 
158
+ # ========= Helpers =========
159
  def _as_df(v: Any) -> Optional[pd.DataFrame]:
160
  if isinstance(v, pd.DataFrame):
161
  return v
 
175
 
176
  def _auto_group_cols(df: pd.DataFrame) -> List[str]:
177
  prefs = ["facility","community","settlement","provider","zone","region","district","specialty","program","service","city"]
 
178
  for p in prefs:
179
+ col = resolve_one(p, _to_list(df.columns))
180
+ if col:
181
+ return [col]
 
 
182
  obj_cols = [c for c in df.columns if df[c].dtype == "object"]
183
  return obj_cols[:1] if obj_cols else []
184
 
 
255
  return None
256
  if count_col and count_col in df.columns:
257
  return df[count_col].apply(lambda n: " (interpret cautiously: small n)" if pd.notnull(n) and float(n) < threshold else "")
 
258
  return None
259
 
260
  def _missingness(df: pd.DataFrame, metric_cols: List[str]) -> List[str]:
 
266
  notes.append(f"{c}: missing {miss:.1%}")
267
  return notes
268
 
269
+ # ========= Scenario Engine =========
270
  class ScenarioEngine:
271
  """
272
  Execute a ScenarioPlan (or dict) consisting of tasks that specify:
 
291
  mapping_log: List[str]) -> pd.DataFrame:
292
  # Resolve grouping to existing columns; tolerate roles or wrong names
293
  if group_by:
294
+ gcols = resolve_cols(group_by, _to_list(df.columns))
 
295
  for want in (group_by or []):
296
+ got = resolve_one(want, _to_list(df.columns))
297
  mapping_log.append(f"group_by: {want} → {got if got else '(unresolved)'}")
298
  else:
299
  gcols = _auto_group_cols(df)
 
302
  else:
303
  mapping_log.append("group_by: (auto) → (none)")
304
 
 
305
  aggs = _parse_aggs(agg_spec or "")
306
+
307
+ # No grouping & no agg => just preview a slice
308
  if not gcols:
309
  if not aggs:
 
310
  return df.head(50).copy()
 
311
  rec = { out_col: _apply_agg_call(df, call) for out_col, call in aggs }
312
  return pd.DataFrame([rec])
313
 
 
339
  mapping_log: List[str]) -> pd.DataFrame:
340
  if not isinstance(out_df, pd.DataFrame) or out_df.empty or not fields:
341
  return out_df
342
+ cols = resolve_cols(fields, _to_list(out_df.columns))
343
  for want in fields:
344
+ got = resolve_one(want, _to_list(out_df.columns))
345
  mapping_log.append(f"field: {want} → {got if got else '(unresolved)'}")
346
  if cols:
347
  return out_df[cols]
 
369
 
370
  @staticmethod
371
  def _exec_task(t: Any, datasets: Dict[str, Any]) -> str:
 
372
  title = getattr(t, "title", None) or (isinstance(t, dict) and t.get("title")) or "Task"
373
  section_lines: List[str] = [f"## {title}\n"]
374
 
 
378
  section_lines.append("_No matching data for this task._")
379
  return "\n".join(section_lines)
380
 
381
+ # Filter(s)
382
  t_filter = getattr(t, "filter", None) or (isinstance(t, dict) and t.get("filter"))
383
  if t_filter:
384
  try:
 
386
  except Exception as e:
387
  section_lines.append(f"_Warning: filter ignored ({e})._")
388
 
389
+ # Derive(s)
390
  t_derive = getattr(t, "derive", None) or (isinstance(t, dict) and t.get("derive"))
391
  if t_derive:
392
  for d in (t_derive if isinstance(t_derive, (list, tuple)) else [t_derive]):
 
395
  except Exception as e:
396
  section_lines.append(f"_Warning: derive ignored ({e})._")
397
 
398
+ # Group/Agg
399
  t_group_by = getattr(t, "group_by", None) or (isinstance(t, dict) and t.get("group_by"))
 
400
  if isinstance(t_group_by, str):
401
  t_group_by = [t_group_by]
402
  t_agg = getattr(t, "agg", None) or (isinstance(t, dict) and t.get("agg"))
403
+ agg_spec = ", ".join(t_agg) if isinstance(t_agg, list) else (t_agg or None)
 
 
 
404
 
405
  mapping_log: List[str] = []
406
  out_df = ScenarioEngine._group_agg(df, t_group_by, agg_spec, mapping_log)
 
408
  # Sort / Top
409
  t_sort_by = getattr(t, "sort_by", None) or (isinstance(t, dict) and t.get("sort_by"))
410
  t_sort_dir = (getattr(t, "sort_dir", None) or (isinstance(t, dict) and t.get("sort_dir")) or "desc").lower()
411
+ if isinstance(out_df, pd.DataFrame) and t_sort_by and t_sort_by in out_df.columns:
412
  out_df = out_df.sort_values(t_sort_by, ascending=(t_sort_dir=="asc"))
413
 
414
  t_top = getattr(t, "top", None) or (isinstance(t, dict) and t.get("top"))
415
+ if isinstance(out_df, pd.DataFrame) and isinstance(t_top, int) and t_top > 0:
416
  out_df = out_df.head(t_top)
417
 
418
  # Field projection
 
421
  t_fields = [t_fields]
422
  out_df = ScenarioEngine._project_fields(out_df, t_fields, mapping_log)
423
 
424
+ # Render
425
  section_lines.append(_render_table(out_df))
426
 
427
  # Assumptions & Mappings