| # narrative_safetynet.py | |
| from __future__ import annotations | |
| from typing import Dict, Any, List, Optional, Tuple | |
| import math | |
| import numpy as np | |
| import pandas as pd | |
| import re | |
# ---------- Generic helpers ----------
# Minimum record count below which a group's average gets an
# "interpret with caution" note (consumed by _small_sample_note).
_DEF_MIN_SAMPLE = 5  # threshold for "interpret with caution" (fully generic)
| def _is_numeric(s: pd.Series) -> bool: | |
| return pd.api.types.is_numeric_dtype(s) | |
| def _fmt_num(x: Any, decimals: int = 1) -> str: | |
| try: | |
| if x is None or (isinstance(x, float) and math.isnan(x)): | |
| return "n/a" | |
| if isinstance(x, (int, np.integer)): | |
| return f"{x:,}" | |
| return f"{float(x):,.{decimals}f}" | |
| except Exception: | |
| return str(x) | |
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
    """Pick a numeric column, preferring names that contain one of *hints*.

    Hints (e.g. "Surgery_Median", "Consult_Median") are matched as
    case-insensitive substrings, in hint order; failing that, the first
    numeric column in column order wins. Returns None when the frame has
    no numeric column at all.
    """
    for hint in hints:
        needle = hint.lower()
        for col in df.columns:
            if needle in col.lower() and _is_numeric(df[col]):
                return col
    # No hinted match -- fall back to the first numeric column in order.
    return next((col for col in df.columns if _is_numeric(df[col])), None)
| def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]: | |
| # choose a categorical/grouping column by fuzzy name | |
| cols = list(df.columns) | |
| for cand in candidates: | |
| for c in cols: | |
| if cand.lower() in c.lower(): | |
| return c | |
| # fallback: first object/string column with reasonable cardinality | |
| obj_cols = [c for c in cols if df[c].dtype == "object"] | |
| for c in obj_cols: | |
| nuniq = df[c].nunique(dropna=True) | |
| if 1 < nuniq < max(50, len(df) // 10): # avoid IDs (too high cardinality) and constants | |
| return c | |
| return None | |
| def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame: | |
| # treat dashes and blank strings as NaN; do not hard-code schema | |
| dff = df.copy() | |
| for c in dff.columns: | |
| if dff[c].dtype == "object": | |
| dff[c] = dff[c].replace({r"^\s*$": np.nan, r"^[-ββ]$": np.nan}, regex=True) | |
| return dff | |
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """Return a caution string when the record count *n* is below *min_n*, else None."""
    if n >= min_n:
        return None
    return f"Interpret averages cautiously (only {n} records)."
| def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str: | |
| # tol is a fraction of mu for β equal bucket (1% default) | |
| if np.isnan(x) or np.isnan(mu): | |
| return "unknown" | |
| if mu == 0: | |
| return "unknown" | |
| rel = (x - mu) / mu | |
| if rel > 0.05: # > +5% above average | |
| return "higher than average" | |
| if rel < -0.05: # < -5% below average | |
| return "lower than average" | |
| if abs(rel) <= max(tol, 0.05): | |
| return "about average" | |
| return "about average" | |
| def _pluralize(label: str, n: int) -> str: | |
| return f"{label}{'' if n==1 else 's'}" | |
# ---------- Core narrative generator ----------
def _group_summary(
    df: pd.DataFrame,
    group_col: str,
    primary_metric: str,
    comparator_metric: Optional[str],
) -> pd.DataFrame:
    """Aggregate per group: primary-metric mean, record count, comparator mean.

    Result columns: ``<group_col>``, ``"metric"``, ``"count"``, ``"comp"``.
    When no comparator metric exists, ``"comp"`` repeats the primary-metric
    mean so downstream code can read it unconditionally.

    NOTE: the aggregate names deliberately avoid a leading underscore. The
    previous inline version named the comparator column ``"_comp"`` and read
    it back via ``itertuples()``; namedtuples reject underscore-leading
    fields, so pandas renamed it positionally and ``getattr(row, "_comp")``
    raised AttributeError.
    """
    comp_source = comparator_metric or primary_metric
    return (
        df.groupby(group_col, dropna=False)
        .agg(
            metric=(primary_metric, "mean"),
            count=(primary_metric, "count"),
            comp=(comp_source, "mean"),
        )
        .reset_index()
    )


def build_narrative(
    scenario_text: str,
    # A dict of dataframes your engine just produced / loaded (e.g., the same dict passed to ScenarioEngine)
    datasets: Dict[str, Any],
    # Optional structured outputs your engine already rendered (tables etc.) if you want to cross-reference
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    # Hints for metric selection (keeps it scenario-agnostic)
    metric_hints: Optional[List[str]] = None,  # e.g. ["Surgery_Median", "Consult_Median", "Wait"]
    group_hints: Optional[List[str]] = None,   # e.g. ["Facility","Specialty","Zone"]
    min_sample: int = _DEF_MIN_SAMPLE,
) -> str:
    """
    Return a markdown narrative that:
      - Summarizes methodology (cleaning, numeric detection)
      - Highlights top groups by the chosen metric
      - Computes an overall baseline and compares groups vs baseline
      - Flags small-sample groups
      - Adds geographic notes if city/lat/lon columns are present (fully optional)

    All columns are inferred dynamically; no scenario-specific strings are
    used. ``scenario_text`` and ``structured_tables`` are accepted for
    interface compatibility but are not consulted by the current
    implementation.
    """
    metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p50"]
    group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]

    # 1) Pick the first non-empty DataFrame and sanitize NaN-like strings.
    df = None
    for v in datasets.values():
        if isinstance(v, pd.DataFrame) and not v.empty:
            df = _nanlike_to_nan(v)
            break
    if df is None:
        return "No tabular data available. Unable to generate a narrative."

    # 2) Primary metric (numeric) plus an optional comparator (e.g. consult vs surgery).
    primary_metric = _pick_numeric(df, metric_hints)
    if not primary_metric:
        return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
    other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
    comparator_metric = next(
        (c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])),
        None,
    )

    # 3) Groupings, chosen dynamically by fuzzy name.
    group1 = _find_group_col(df, group_hints)
    group2 = None
    if group1:
        # Second grouping must differ from the column already chosen
        # (e.g. Zone after Facility was selected).
        alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
        group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)

    # 4) Overall baseline and per-group stats.
    baseline = df[primary_metric].astype(float).mean(skipna=True)
    g1 = _group_summary(df, group1, primary_metric, comparator_metric) if group1 else None
    g2 = _group_summary(df, group2, primary_metric, comparator_metric) if group2 else None

    # 5) Top groups (by mean of the primary metric) for the primary grouping.
    # Iterate dict records rather than itertuples(): record dicts preserve
    # arbitrary column names (spaces, leading underscores) that namedtuple
    # fields cannot.
    top_lines: List[str] = []
    if isinstance(g1, pd.DataFrame) and not g1.empty:
        g1 = g1.sort_values(by="metric", ascending=False)
        for i, row in enumerate(g1.head(min(5, len(g1))).to_dict(orient="records"), 1):
            cnt = int(row["count"])
            devlab = _deviation_label(row["metric"], baseline)
            caution = _small_sample_note(cnt, min_sample)
            msg = f"{i}. **{row[group1]}** — {primary_metric}: {_fmt_num(row['metric'])}"
            if comparator_metric:
                msg += f"; {comparator_metric}: {_fmt_num(row['comp'])}"
            msg += f"; {_pluralize('record', cnt)}: {cnt}"
            msg += f" — {devlab}"
            if caution:
                msg += f" ({caution})"
            top_lines.append(msg)

    # 6) Region/category style overview for the secondary grouping.
    region_lines: List[str] = []
    if isinstance(g2, pd.DataFrame) and not g2.empty:
        g2 = g2.sort_values(by="metric", ascending=False)
        for row in g2.to_dict(orient="records"):
            cnt = int(row["count"])
            devlab = _deviation_label(row["metric"], baseline)
            caution = _small_sample_note(cnt, min_sample)
            line = f"- **{row[group2]}**: {_fmt_num(row['metric'])} (vs. overall {_fmt_num(baseline)} — {devlab}); n={cnt}"
            if comparator_metric:
                line += f"; {comparator_metric}: {_fmt_num(row['comp'])}"
            if caution:
                line += f" — {caution}"
            region_lines.append(line)

    # 7) Geographic notes, only when city + lat + lon style columns all exist.
    # Field names are never hard-coded; we look for city/lat/lon patterns.
    geo_notes: List[str] = []
    city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
    lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
    lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
    if group1 and city_col and lat_col and lon_col and isinstance(g1, pd.DataFrame) and not g1.empty:
        # Join the top-10 groups back to the raw rows to see where they cluster.
        # Labels are compared verbatim: the previous whitespace "normalization"
        # could alter a label so it no longer matched its own raw rows.
        top_labels = g1[group1].astype(str).head(10).tolist()
        sub = df[df[group1].astype(str).isin(top_labels)]
        if not sub.empty:
            by_city = (
                sub.groupby(city_col, dropna=False)[primary_metric]
                .mean()
                .reset_index()
                .sort_values(by=primary_metric, ascending=False)
            )
            # Only a brief, dynamic note (no hard-coded cities).
            for r in by_city.head(3).to_dict(orient="records"):
                geo_notes.append(
                    f"- **{r.get(city_col)}** shows higher average {primary_metric} among top groups ({_fmt_num(r.get(primary_metric))})."
                )

    # 8) Methodology bullets, derived from actual data conditions.
    methodology: List[str] = []
    if df.isna().sum().sum() > 0:
        methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
    if comparator_metric:
        methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if group1:
        methodology.append(f"Primary grouping inferred: **{group1}**.")
    if group2:
        methodology.append(f"Secondary grouping inferred: **{group2}**.")
    if min_sample != _DEF_MIN_SAMPLE:
        methodology.append(f"Small-sample threshold set to {min_sample} records.")

    # 9) Compose the markdown document.
    lines: List[str] = ["## Methodology (Auto-generated)"]
    lines.extend(f"- {m}" for m in methodology)
    lines.append("")
    if top_lines:
        lines.append("## Highest average values by group")
        lines.extend(top_lines)
        lines.append("")
    if region_lines:
        lines.append(f"## {group2 or 'Region/Category'} comparison vs overall")
        lines.extend(region_lines)
        lines.append("")
    if geo_notes:
        lines.append("## Geographic notes")
        lines.extend(geo_notes)
        lines.append("")

    # Generic, data-driven recommendations (nothing scenario-specific).
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
    if comparator_metric:
        recs.append(f"Cross-check {comparator_metric} trends to identify upstream bottlenecks (e.g., long consult waits pushing surgery waits).")
    if isinstance(g2, pd.DataFrame) and not g2.empty and not g2[g2["metric"] > baseline].empty:
        recs.append(f"Address regional disparities where average **{primary_metric}** exceeds the overall baseline.")
    recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
    recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
    lines.append("## Recommendations (Auto-generated)")
    lines.extend(f"- {r}" for r in recs)
    return "\n".join(lines).strip()