# scenario_engine.py
from __future__ import annotations

from typing import Dict, List, Any, Tuple, Optional
import re
import math
import json
import ast

import pandas as pd
import numpy as np
# ----------------------------
# Safe expression evaluation
# ----------------------------
_ALLOWED_FUNCS = {
    "abs": abs,
    "round": round,
    "sqrt": math.sqrt,
    "log": math.log,
    "exp": math.exp,
    "min": np.minimum,  # vectorized
    "max": np.maximum,  # vectorized
    "mean": np.mean,
    "avg": np.mean,
    "median": np.median,
    "sum": np.sum,
    "count": lambda x: np.size(x),
    "p50": lambda x: np.percentile(x, 50),
    "p75": lambda x: np.percentile(x, 75),
    "p90": lambda x: np.percentile(x, 90),
    "p95": lambda x: np.percentile(x, 95),
    "p99": lambda x: np.percentile(x, 99),
    "ceil": np.ceil,
    "floor": np.floor,
}
class _SafeExpr(ast.NodeTransformer):
    """
    Restrict expressions to:
      - Names (columns), numbers, strings, booleans
      - Arithmetic: + - * / // % **, comparisons, and/or/not
      - Calls to allowed functions (above)
    `and`/`or`/`not` are rewritten to elementwise `&`/`|`/`~` so that
    they behave sensibly on pandas Series.
    """

    def __init__(self, allowed_names: set):
        self.allowed_names = allowed_names

    def visit_Name(self, node):
        # Allowed function names are vetted in visit_Call; without this
        # exemption, `p90(wait)` would be rejected as an unknown name.
        if (node.id not in self.allowed_names
                and node.id not in _ALLOWED_FUNCS
                and node.id not in ("True", "False", "None")):
            raise ValueError(f"Unknown name in expression: {node.id}")
        return node

    def visit_Call(self, node):
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only simple function calls are allowed")
        func = node.func.id
        if func not in _ALLOWED_FUNCS:
            raise ValueError(f"Function not allowed: {func}")
        self.generic_visit(node)
        return node

    def visit_BoolOp(self, node):
        # `a and b` would call bool() on a Series and raise, so rewrite
        # to `a & b` (likewise `or` -> `|`) after validating the children.
        self.generic_visit(node)
        op = ast.BitAnd() if isinstance(node.op, ast.And) else ast.BitOr()
        result = node.values[0]
        for operand in node.values[1:]:
            result = ast.BinOp(left=result, op=op, right=operand)
        return ast.copy_location(result, node)

    def visit_UnaryOp(self, node):
        self.generic_visit(node)
        if isinstance(node.op, ast.Not):
            # `not x` -> `~x`, for the same reason as above
            return ast.copy_location(ast.UnaryOp(op=ast.Invert(), operand=node.operand), node)
        return node

    def generic_visit(self, node):
        allowed = (
            ast.Expression, ast.BoolOp, ast.BinOp, ast.UnaryOp,
            ast.Compare, ast.Call, ast.Name, ast.Load, ast.Constant,
            ast.And, ast.Or, ast.Not,
            ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.FloorDiv,
            ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.In, ast.NotIn,
            ast.USub, ast.UAdd,
        )
        if not isinstance(node, allowed):
            raise ValueError(f"Unsupported syntax: {type(node).__name__}")
        return super().generic_visit(node)
def _eval_series_expr(expr: str, df: pd.DataFrame) -> pd.Series:
    allowed_names = set(df.columns) | {"True", "False", "None"}
    tree = ast.parse(expr, mode="eval")
    tree = _SafeExpr(allowed_names).visit(tree)
    ast.fix_missing_locations(tree)  # rewritten nodes need positions before compile
    code = compile(tree, "<expr>", "eval")
    env = {**{k: df[k] for k in df.columns}, **_ALLOWED_FUNCS}
    return eval(code, {"__builtins__": {}}, env)
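
# Example (illustrative column names):
#   df = pd.DataFrame({"zone": ["North", "South"], "wait_time": [12, 3]})
#   _eval_series_expr('zone == "North" and wait_time > 5', df)
#   -> boolean Series [True, False]  (`and` is rewritten to elementwise `&`)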
# ----------------------------
# Engine
# ----------------------------
class ScenarioEngine:
    """
    Scenario-first engine:
      - Parse tasks + inline directives from scenario text
      - For each task, execute a pipeline over analysis_results:
        load -> filter -> derive -> groupby/agg -> pivot -> sort/top -> select fields -> render
      - Render formats: table | list | comparison | map | narrative | chart (Vega-Lite spec)
      - Strict: only what is asked is emitted.
    """

    @staticmethod
    def render(scenario_text: str, analysis_results: Dict[str, Any]) -> str:
        scen = ScenarioEngine._parse_scenario(scenario_text)
        out: List[str] = ["# Scenario Output\n"]
        for task in scen["tasks"]:
            out.append(ScenarioEngine._render_task(task, analysis_results))
        return "\n".join(out).strip()
    # ------------- Parsing -------------
    @staticmethod
    def _parse_scenario(s: str) -> Dict[str, Any]:
        """
        Detect a 'Tasks/Deliverables/Requirements/Your Tasks' block; fall back to any
        bullet/numbered lines. Each task may include inline `key: value` directives.
        Supported directives (per task):
          format:   table|list|comparison|map|narrative|chart
          data_key: <key in analysis_results>
          filter:   <expr using columns>, e.g. zone == "North" and wait_time > 5
          derive:   <col>=<expr>[, <col2>=<expr2> ...]
          group_by: col1[, col2 ...]
          agg:      avg(x), median(y), sum(z), p90(wait), count(*)
          pivot:    index=a[,b] columns=c values=v (values must be an aggregated column)
          sort_by:  col    sort_dir: asc|desc
          top:      N
          fields:   col1 col2 col3 (space or comma separated)
          title:    Custom name
          chart:    bar|line|area|point (Vega-Lite spec emitted)
          x: <field>  y: <field>  color: <field>  column: <facet>
        """
        lines = [ln.rstrip() for ln in s.splitlines()]
        task_hdr = re.compile(r'^\s*(tasks?|deliverables|requirements|your tasks?)\s*$', re.I)
        bullet = re.compile(r'^\s*(?:\d+\.\s+|[-*•]\s+)')
        in_tasks = False
        raw_tasks: List[str] = []
        for ln in lines:
            if task_hdr.match(ln):
                in_tasks = True
                continue
            if in_tasks:
                if bullet.match(ln.strip()):
                    raw_tasks.append(ln.strip())
                elif ln.strip() == "":
                    continue
                else:
                    # stop when we hit a non-task-looking line after capturing some tasks
                    if raw_tasks:
                        in_tasks = False
        if not raw_tasks:
            # fallback: grab any bullet/numbered lines
            raw_tasks = [ln.strip() for ln in lines if bullet.match(ln.strip())]
        tasks: List[Dict[str, Any]] = []
        for raw in raw_tasks:
            directives = ScenarioEngine._extract_directives(raw)
            title = directives.get("title") or ScenarioEngine._strip_bullet(raw)
            tasks.append({"title": title, "raw": raw, "d": directives})
        return {"tasks": tasks}
    @staticmethod
    def _strip_bullet(line: str) -> str:
        return re.sub(r'^\s*(?:\d+\.\s+|[-*•]\s+)', '', line).strip()

    @staticmethod
    def _extract_directives(text: str) -> Dict[str, Any]:
        d: Dict[str, Any] = {}
        # key: value pairs; a value runs until ';', '|', end of line, or the next 'key:'
        # (the lookahead lets comma-separated values like `agg: avg(x), p90(y)` survive)
        for m in re.finditer(r'([a-z_]+)\s*:\s*(.+?)(?=\s+[a-z_]+\s*:|[|;\n]|$)', text, re.I):
            k = m.group(1).strip().lower()
            v = m.group(2).strip().rstrip(",").strip()
            d[k] = v

        def _split_csv(val: str) -> List[str]:
            return [x.strip() for x in re.split(r'[,\s]+', val) if x.strip()]

        if "fields" in d:
            d["fields"] = _split_csv(d["fields"])
        if "group_by" in d:
            d["group_by"] = _split_csv(d["group_by"])
        if "top" in d:
            try:
                d["top"] = int(re.findall(r'\d+', d["top"])[0])
            except Exception:
                d["top"] = None
        if "sort_dir" in d:
            d["sort_dir"] = "desc" if d["sort_dir"].lower().startswith("d") else "asc"
        if "format" in d:
            d["format"] = d["format"].lower()
        if "chart" in d:
            d["chart"] = d["chart"].lower()
        return d
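
    # Example (illustrative task line):
    #   ScenarioEngine._extract_directives(
    #       "1. Busiest zones format: table group_by: zone agg: avg(wait), count(*) top: 3")
    #   -> {"format": "table", "group_by": ["zone"], "agg": "avg(wait), count(*)", "top": 3}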
    # ------------- Rendering -------------
    @staticmethod
    def _render_task(task: Dict[str, Any], analysis_results: Dict[str, Any]) -> str:
        title, d = task["title"], task["d"]
        section: List[str] = [f"## {title}\n"]

        # 1) Resolve data
        df, key_used, why = ScenarioEngine._resolve_df(d, analysis_results)
        if df is None:
            section.append("_No matching data for this task._")
            section.append(f"\n> Resolver note: {why}")
            return "\n".join(section)

        # 2) Filter
        if "filter" in d:
            mask = ScenarioEngine._safe_filter(df, d["filter"])
            df = df.loc[mask].copy()

        # 3) Derive columns
        if "derive" in d:
            df = ScenarioEngine._apply_derive(df, d["derive"])

        # 4) Group & aggregate
        if "group_by" in d or "agg" in d:
            df = ScenarioEngine._group_agg(df, d.get("group_by"), d.get("agg"))

        # 5) Pivot
        if "pivot" in d:
            df = ScenarioEngine._pivot(df, d["pivot"])

        # 6) Sort + Top
        if "sort_by" in d:
            asc = (d.get("sort_dir", "desc") == "asc")
            df = df.sort_values(by=d["sort_by"], ascending=asc)
        if isinstance(d.get("top"), int) and d["top"] > 0:
            df = df.head(d["top"])

        # 7) Field selection
        if "fields" in d:
            cols = [c for c in d["fields"] if c in df.columns]
            if cols:
                df = df[cols]

        # 8) Render by format
        fmt = d.get("format", "table")
        if fmt == "list":
            section.append(ScenarioEngine._render_list(df))
        elif fmt == "comparison":
            section.append(ScenarioEngine._render_comparison(df))
        elif fmt == "map":
            section.append(ScenarioEngine._render_map(df))
        elif fmt == "narrative":
            section.append(ScenarioEngine._render_narrative(df))
        elif fmt == "chart":
            section.append(ScenarioEngine._render_chart_spec(df, d))
        else:
            section.append(ScenarioEngine._render_table(df))

        # 9) Per-task provenance (kept minimal)
        section.append("\n**Provenance**")
        section.append(f"- Data key: `{key_used}`")
        section.append(f"- Match note: {why}")
        return "\n".join(section)
    # ------------- Data resolution -------------
    @staticmethod
    def _resolve_df(d: Dict[str, Any], analysis_results: Dict[str, Any]) -> Tuple[Optional[pd.DataFrame], Optional[str], str]:
        # explicit key
        if "data_key" in d and d["data_key"] in analysis_results:
            return ScenarioEngine._as_df(analysis_results[d["data_key"]]), d["data_key"], "explicit data_key"
        # Jaccard match on result keys, using words hinted by fields/sort_by
        hints = set()
        for k in ("fields", "sort_by"):
            v = d.get(k)
            if isinstance(v, list):
                hints |= set(v)
            elif isinstance(v, str):
                hints |= set(re.findall(r'[A-Za-z0-9_]+', v.lower()))
        best_key, best_score = None, 0.0
        for k in analysis_results:
            words = set(re.findall(r'[A-Za-z0-9_]+', k.lower()))
            if not words:
                continue
            inter = len(hints & words)
            union = len(hints | words) or 1
            score = inter / union
            if score > best_score:
                best_key, best_score = k, score
        if best_key:
            return ScenarioEngine._as_df(analysis_results[best_key]), best_key, f"keyword match (score={best_score:.2f})"
        # fallback: first list-of-dicts or dict-like
        for k, v in analysis_results.items():
            df = ScenarioEngine._as_df(v)
            if df is not None and not df.empty:
                return df, k, "fallback first structured"
        return None, None, "no suitable dataset found"

    @staticmethod
    def _as_df(v: Any) -> Optional[pd.DataFrame]:
        if isinstance(v, list):
            if not v:
                return pd.DataFrame()
            if isinstance(v[0], dict):
                return pd.DataFrame(v)
            return pd.DataFrame({"value": v})
        if isinstance(v, dict):
            # expand scalar entries into a one-row frame where sensible
            flat = {}
            any_scalar = False
            for k, val in v.items():
                if isinstance(val, (int, float, str, bool, type(None))):
                    flat[k] = [val]
                    any_scalar = True
            if any_scalar:
                return pd.DataFrame(flat)
            # complex dict -> try records
            recs = []
            for k, val in v.items():
                if isinstance(val, dict):
                    rec = {"item": k}
                    rec.update(val)
                    recs.append(rec)
            if recs:
                return pd.DataFrame(recs)
        return None
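
    # Shapes accepted above (illustrative):
    #   [{"zone": "N", "wait": 4}, ...]   -> one row per dict
    #   [3, 5, 8]                         -> single "value" column
    #   {"beds": 12, "staff": 7}          -> one-row frame of scalars
    #   {"siteA": {"beds": 12}, ...}      -> records with an "item" column
    # Anything else (e.g. a dict of lists) currently resolves to None.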
    # ------------- Pipeline ops -------------
    @staticmethod
    def _safe_filter(df: pd.DataFrame, expr: str) -> pd.Series:
        try:
            s = _eval_series_expr(expr, df)
            if not isinstance(s, (pd.Series, np.ndarray)):
                raise ValueError("filter must evaluate to a boolean Series/array")
            return pd.Series(s).astype(bool).reindex(df.index, fill_value=False)
        except Exception as e:
            raise ValueError(f"Invalid filter expression: {e}")

    @staticmethod
    def _apply_derive(df: pd.DataFrame, spec: str) -> pd.DataFrame:
        # e.g., "load = patients / capacity, rate = 100*admits/pop"
        parts = re.split(r'[;,]\s*', spec)
        for p in parts:
            if not p.strip():
                continue
            if "=" not in p:
                raise ValueError(f"derive requires assignments: '{p}'")
            col, expr = p.split("=", 1)
            col = col.strip()
            expr = expr.strip()
            df[col] = _eval_series_expr(expr, df)
        return df

    @staticmethod
    def _parse_aggs(spec: Optional[str]) -> List[Tuple[str, str]]:
        """
        Returns a list of (out_col, func_call_string), e.g. [("avg_wait_time", "avg(wait_time)")].
        """
        if not spec:
            return []
        items = [x.strip() for x in spec.split(",") if x.strip()]
        out: List[Tuple[str, str]] = []
        for it in items:
            m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', it)
            if not m:
                if it.lower() in ("count", "count(*)"):
                    out.append(("count", "count(*)"))
                    continue
                raise ValueError(f"Bad agg item: '{it}' (use avg(x), median(y), p90(z), sum(a), count(*))")
            func = m.group(1)
            arg = m.group(2).strip()
            out_col = f"{func.lower()}_{arg}"
            out.append((out_col, f"{func}({arg})"))
        return out
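
    # Example (illustrative):
    #   ScenarioEngine._parse_aggs("avg(wait_time), p90(wait_time), count(*)")
    #   -> [("avg_wait_time", "avg(wait_time)"),
    #       ("p90_wait_time", "p90(wait_time)"),
    #       ("count_*", "count(*)")]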
    @staticmethod
    def _group_agg(df: pd.DataFrame, group_by: Optional[List[str]], agg_spec: Optional[str]) -> pd.DataFrame:
        aggs = ScenarioEngine._parse_aggs(agg_spec)
        if not aggs and not group_by:
            return df
        if not group_by:
            # reduce to a single row with the requested aggs
            res = {}
            for out_col, call in aggs:
                res[out_col] = ScenarioEngine._apply_agg_call(df, call)
            return pd.DataFrame([res])
        # grouped
        gb = df.groupby(group_by, dropna=False)
        rows = []
        for keys, g in gb:
            if not isinstance(keys, tuple):
                keys = (keys,)
            rec = {group_by[i]: keys[i] for i in range(len(group_by))}
            for out_col, call in aggs:
                rec[out_col] = ScenarioEngine._apply_agg_call(g, call)
            if not aggs:
                # no aggs? carry counts by default
                rec["count"] = len(g)
            rows.append(rec)
        return pd.DataFrame(rows)

    @staticmethod
    def _apply_agg_call(df: pd.DataFrame, call: str) -> Any:
        call = call.strip()
        if call.lower() in ("count", "count(*)"):
            return int(len(df))
        m = re.match(r'([a-zA-Z_][a-zA-Z0-9_]*)\s*\(([^)]+)\)', call)
        if not m:
            raise ValueError(f"Bad agg call: {call}")
        func, arg = m.group(1).lower(), m.group(2).strip()
        if arg not in df.columns:
            raise ValueError(f"Unknown column in agg: {arg}")
        col = df[arg].dropna()
        if func in ("avg", "mean"):
            return float(np.mean(col)) if len(col) else float("nan")
        if func == "median":
            return float(np.median(col)) if len(col) else float("nan")
        if func == "sum":
            return float(np.sum(col)) if len(col) else 0.0
        if func in ("min", "max"):
            f = getattr(np, func)
            return float(f(col)) if len(col) else float("nan")
        if func.startswith("p") and func[1:].isdigit():
            q = int(func[1:])
            return float(np.percentile(col, q)) if len(col) else float("nan")
        raise ValueError(f"Unsupported agg function: {func}")
    @staticmethod
    def _pivot(df: pd.DataFrame, spec: str) -> pd.DataFrame:
        # spec: index=a[,b] columns=c values=v
        parts = dict(re.findall(r'(\w+)\s*=\s*(\S+)', spec))  # values may contain commas
        idx = parts.get("index")
        cols = parts.get("columns")
        vals = parts.get("values")
        if not (idx and cols and vals):
            raise ValueError("pivot requires 'index=.. columns=.. values=..'")
        idx = [x.strip() for x in idx.split(",")]
        pv = df.pivot_table(index=idx, columns=cols, values=vals, aggfunc="first")
        pv = pv.reset_index()
        # flatten columns if needed
        if isinstance(pv.columns, pd.MultiIndex):
            pv.columns = ["_".join([str(c) for c in tup if c != ""]) for tup in pv.columns]
        return pv
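
    # Example (illustrative): after `group_by: zone, week agg: avg(wait)`, a task could add
    #   pivot: index=zone columns=week values=avg_wait
    # to turn long-form rows into one row per zone with a column per week.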
    # ------------- Output renderers -------------
    @staticmethod
    def _render_table(df: pd.DataFrame) -> str:
        if df.empty:
            return "_No rows to display._"
        # convert everything to string-friendly values
        dff = df.copy()
        for c in dff.columns:
            dff[c] = dff[c].apply(lambda v: ScenarioEngine._fmt_val(v))
        header = "| " + " | ".join(map(str, dff.columns)) + " |"
        sep = "|" + "|".join(["---"] * len(dff.columns)) + "|"
        rows = ["| " + " | ".join(map(str, r)) + " |" for r in dff.to_numpy().tolist()]
        return "\n".join([header, sep, *rows])

    @staticmethod
    def _render_list(df: pd.DataFrame) -> str:
        if df.empty:
            return "_No items._"
        # pick the first column as the primary label
        primary = df.columns[0]
        lines = []
        for i, row in enumerate(df.itertuples(index=False), 1):
            parts = []
            for c, v in zip(df.columns, row):
                if c == primary:
                    continue
                parts.append(f"{c}: {ScenarioEngine._fmt_val(v)}")
            extra = f" ({', '.join(parts)})" if parts else ""
            # index positionally: itertuples mangles non-identifier column names
            lines.append(f"{i}. {ScenarioEngine._fmt_val(row[0])}{extra}")
        return "\n".join(lines)
    @staticmethod
    def _render_comparison(df: pd.DataFrame) -> str:
        # look for columns named like current/previous
        cols = {c.lower(): c for c in df.columns}
        cur = cols.get("current") or cols.get("now") or cols.get("value")
        prev = cols.get("previous") or cols.get("prior") or cols.get("past")
        name = cols.get("name") or cols.get("metric") or cols.get("item") or df.columns[0]
        if not (cur and prev):
            return "_Comparison format requires columns 'current' and 'previous' (or aliases)._"
        header = "| Item | Current | Previous | Change |"
        sep = "|---|---:|---:|---:|"
        body = []
        num = (int, float, np.integer, np.floating)  # numpy ints are not Python ints
        for _, r in df.iterrows():
            c, p = r[cur], r[prev]
            change = (c - p) if isinstance(c, num) and isinstance(p, num) else "N/A"
            body.append(
                f"| {ScenarioEngine._fmt_val(r[name])} | {ScenarioEngine._fmt_val(c)} "
                f"| {ScenarioEngine._fmt_val(p)} | {ScenarioEngine._fmt_val(change)} |"
            )
        return "\n".join([header, sep, *body])
    @staticmethod
    def _render_map(df: pd.DataFrame) -> str:
        # simple location table
        colmap = {c.lower(): c for c in df.columns}
        name = colmap.get("name") or colmap.get("facility") or colmap.get("title") or df.columns[0]
        zone = colmap.get("zone")
        city = colmap.get("city")
        region = colmap.get("region")
        lat = colmap.get("latitude") or colmap.get("lat")
        lon = colmap.get("longitude") or colmap.get("lon")
        cols = [x for x in [name, city, region, zone, lat, lon] if x]
        if not cols:
            return "_No geographic fields to show._"
        dff = df[cols].copy()
        if lat and lon:
            dff["coordinates"] = np.where(
                dff[lat].notna() & dff[lon].notna(),
                dff[lat].astype(str) + ", " + dff[lon].astype(str),
                "N/A",
            )
        else:
            # indexing a missing/None column would raise, so guard first
            dff["coordinates"] = "N/A"
        show = [name, city or "city", region or "region", zone or "zone", "coordinates"]
        # ensure all exist
        for c in show:
            if c not in dff.columns:
                dff[c] = ""
        dff = dff[show]
        return ScenarioEngine._render_table(dff)
    @staticmethod
    def _render_narrative(df: pd.DataFrame) -> str:
        if df.empty:
            return "_No content._"
        paras = []
        for i, row in enumerate(df.to_dict(orient="records"), 1):
            parts = [f"**{k}**: {ScenarioEngine._fmt_val(v)}" for k, v in row.items()]
            paras.append(f"{i}. " + "; ".join(parts))
        return "\n".join(paras)

    @staticmethod
    def _render_chart_spec(df: pd.DataFrame, d: Dict[str, Any]) -> str:
        """
        Emits a Vega-Lite spec in a fenced code block that downstream renderers can plot exactly.
        Accepts: chart (bar|line|area|point), x, y, color, column (facet)
        """
        mark = d.get("chart", "bar")
        spec = {
            "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
            "description": d.get("title") or "Chart",
            "data": {"values": df.to_dict(orient="records")},
            "mark": mark,
            "encoding": {},
        }
        for enc in ("x", "y", "color", "column"):
            if enc in d and d[enc] in df.columns:
                spec["encoding"][enc] = {
                    "field": d[enc],
                    "type": "quantitative" if pd.api.types.is_numeric_dtype(df[d[enc]]) else "nominal",
                }
        # default=str keeps numpy scalars / timestamps from breaking serialization
        return "```vega-lite\n" + json.dumps(spec, ensure_ascii=False, indent=2, default=str) + "\n```"
    # ------------- Helpers -------------
    @staticmethod
    def _fmt_val(v: Any) -> str:
        if isinstance(v, (float, np.floating)):  # np.float32 is not a float subclass
            if math.isnan(v):
                return "NaN"
            return f"{v:,.4g}"
        if isinstance(v, (int, np.integer)):
            return f"{int(v):,}"
        return str(v)
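
# ---------------------------------------------------------------
# Minimal smoke test: run `python scenario_engine.py` directly.
# The data key, scenario text, and column names below are made up
# purely to exercise the pipeline; swap in your own analysis_results.
# ---------------------------------------------------------------
if __name__ == "__main__":
    demo_results = {
        "er_visits": [
            {"zone": "North", "wait_time": 42, "patients": 120, "capacity": 100},
            {"zone": "North", "wait_time": 55, "patients": 95, "capacity": 100},
            {"zone": "South", "wait_time": 18, "patients": 60, "capacity": 90},
            {"zone": "South", "wait_time": 25, "patients": 70, "capacity": 90},
        ]
    }
    demo_scenario = (
        "Tasks\n"
        "1. title: Wait times by zone data_key: er_visits "
        "group_by: zone agg: avg(wait_time), p90(wait_time), count(*) "
        "sort_by: avg_wait_time sort_dir: desc\n"
        "2. title: Long waits format: list data_key: er_visits "
        "filter: wait_time > 40 derive: load = patients / capacity "
        "fields: zone load wait_time\n"
    )
    print(ScenarioEngine.render(demo_scenario, demo_results))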