Spaces:
Sleeping
Sleeping
Rajan Sharma
committed on
Create narrative_safetynet.py
Browse files- narrative_safetynet.py +276 -0
narrative_safetynet.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# narrative_safetynet.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
# ---------- Generic helpers ----------
|
| 10 |
+
_DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
|
| 11 |
+
|
| 12 |
+
def _is_numeric(s: pd.Series) -> bool:
|
| 13 |
+
return pd.api.types.is_numeric_dtype(s)
|
| 14 |
+
|
| 15 |
+
def _fmt_num(x: Any, decimals: int = 1) -> str:
|
| 16 |
+
try:
|
| 17 |
+
if x is None or (isinstance(x, float) and math.isnan(x)):
|
| 18 |
+
return "n/a"
|
| 19 |
+
if isinstance(x, (int, np.integer)):
|
| 20 |
+
return f"{x:,}"
|
| 21 |
+
return f"{float(x):,.{decimals}f}"
|
| 22 |
+
except Exception:
|
| 23 |
+
return str(x)
|
| 24 |
+
|
| 25 |
+
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
    """Pick a numeric column name, preferring hint substrings.

    Hints are tried in order; each is matched case-insensitively as a
    substring of the column name, and the first numeric match wins.
    Falls back to the first numeric column, or None when there is none.
    """
    column_names = list(df.columns)
    for hint in hints:
        needle = hint.lower()
        for name in column_names:
            if needle in name.lower() and _is_numeric(df[name]):
                return name
    # No hint matched a numeric column; take the first numeric one instead.
    for name in column_names:
        if _is_numeric(df[name]):
            return name
    return None
|
| 36 |
+
|
| 37 |
+
def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
|
| 38 |
+
# choose a categorical/grouping column by fuzzy name
|
| 39 |
+
cols = list(df.columns)
|
| 40 |
+
for cand in candidates:
|
| 41 |
+
for c in cols:
|
| 42 |
+
if cand.lower() in c.lower():
|
| 43 |
+
return c
|
| 44 |
+
# fallback: first object/string column with reasonable cardinality
|
| 45 |
+
obj_cols = [c for c in cols if df[c].dtype == "object"]
|
| 46 |
+
for c in obj_cols:
|
| 47 |
+
nuniq = df[c].nunique(dropna=True)
|
| 48 |
+
if 1 < nuniq < max(50, len(df) // 10): # avoid IDs (too high cardinality) and constants
|
| 49 |
+
return c
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
|
| 53 |
+
# treat dashes and blank strings as NaN; do not hard-code schema
|
| 54 |
+
dff = df.copy()
|
| 55 |
+
for c in dff.columns:
|
| 56 |
+
if dff[c].dtype == "object":
|
| 57 |
+
dff[c] = dff[c].replace({r"^\s*$": np.nan, r"^[-–—]$": np.nan}, regex=True)
|
| 58 |
+
return dff
|
| 59 |
+
|
| 60 |
+
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """Caution string when a group has fewer than *min_n* records, else None."""
    if n >= min_n:
        return None
    return f"Interpret averages cautiously (only {n} records)."
|
| 62 |
+
|
| 63 |
+
def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
|
| 64 |
+
# tol is a fraction of mu for ≈ equal bucket (1% default)
|
| 65 |
+
if np.isnan(x) or np.isnan(mu):
|
| 66 |
+
return "unknown"
|
| 67 |
+
if mu == 0:
|
| 68 |
+
return "unknown"
|
| 69 |
+
rel = (x - mu) / mu
|
| 70 |
+
if rel > 0.05: # > +5% above average
|
| 71 |
+
return "higher than average"
|
| 72 |
+
if rel < -0.05: # < -5% below average
|
| 73 |
+
return "lower than average"
|
| 74 |
+
if abs(rel) <= max(tol, 0.05):
|
| 75 |
+
return "about average"
|
| 76 |
+
return "about average"
|
| 77 |
+
|
| 78 |
+
def _pluralize(label: str, n: int) -> str:
|
| 79 |
+
return f"{label}{'' if n==1 else 's'}"
|
| 80 |
+
|
| 81 |
+
# ---------- Core narrative generator ----------
|
| 82 |
+
def _grouped_stats(
    df: pd.DataFrame,
    group_col: str,
    primary_metric: str,
    comparator_metric: Optional[str],
) -> pd.DataFrame:
    """Per-group mean of the primary metric, record count, and comparator mean.

    When no comparator was detected, the "comp" column falls back to the
    primary metric so downstream code can always read it.
    """
    comp_source = comparator_metric or primary_metric
    return (
        df.groupby(group_col, dropna=False)
        .agg(
            metric=(primary_metric, "mean"),
            count=(primary_metric, "count"),
            # NOTE: named "comp" (not "_comp") on purpose — itertuples()/
            # namedtuple machinery renames underscore-prefixed fields, which
            # made the original getattr(row, "_comp") raise AttributeError.
            comp=(comp_source, "mean"),
        )
        .reset_index()
    )


def build_narrative(
    scenario_text: str,
    # A dict of dataframes your engine just produced / loaded (e.g., the same dict passed to ScenarioEngine)
    datasets: Dict[str, Any],
    # Optional structured outputs your engine already rendered (tables etc.) if you want to cross-reference
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    # Hints for metric selection (keeps it scenario-agnostic)
    metric_hints: Optional[List[str]] = None,  # e.g. ["Surgery_Median", "Consult_Median", "Wait"]
    group_hints: Optional[List[str]] = None,   # e.g. ["Facility","Specialty","Zone"]
    min_sample: int = _DEF_MIN_SAMPLE
) -> str:
    """
    Returns a markdown narrative that:
      - Summarizes methodology (cleaning, numeric detection)
      - Highlights top groups by the chosen metric
      - Computes an overall baseline and compares groups vs baseline
      - Flags small-sample groups
      - Adds geographic notes if city/lat/lon are present (fully optional)
    This function avoids any scenario-specific strings and infers columns dynamically.

    *scenario_text* and *structured_tables* are accepted for interface
    compatibility but are not currently consumed.

    Fixes over the original implementation:
      - Group rows are iterated as plain dicts instead of ``itertuples()``.
        The namedtuple produced by itertuples renames underscore-prefixed
        fields (the "_comp" aggregate) to positional names, so the original
        crashed with AttributeError whenever a grouping column was found;
        dict access also works for group columns whose names are not valid
        Python identifiers.
      - Geographic notes match the raw top-group labels back to the source
        rows; the original whitespace-normalized the labels first, which made
        ``isin()`` silently miss values containing newlines or doubled spaces.
    """
    metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p50"]
    group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]

    # 1) Pick a primary dataset (first non-empty table) and sanitize nan-like strings.
    df = None
    for value in datasets.values():
        if isinstance(value, pd.DataFrame) and not value.empty:
            df = _nanlike_to_nan(value)
            break
    if df is None:
        return "No tabular data available. Unable to generate a narrative."

    # 2) Pick primary metric (numeric) and an optional comparator column.
    primary_metric = _pick_numeric(df, metric_hints)
    if not primary_metric:
        return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."

    other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
    comparator_metric = next(
        (c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])),
        None,
    )

    # 3) Choose groupings dynamically (secondary must differ from primary).
    group1 = _find_group_col(df, group_hints)
    group2 = None
    if group1:
        alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
        group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)

    # 4) Baseline (overall mean) and grouped stats.
    baseline = df[primary_metric].astype(float).mean(skipna=True)
    g1 = _grouped_stats(df, group1, primary_metric, comparator_metric) if group1 else None
    g2 = _grouped_stats(df, group2, primary_metric, comparator_metric) if group2 else None

    # 5) Top groups (by primary-metric mean, descending) for group1.
    top_lines: List[str] = []
    if isinstance(g1, pd.DataFrame) and not g1.empty:
        g1 = g1.sort_values(by="metric", ascending=False)
        k = min(5, len(g1))
        for i, row in enumerate(g1.head(k).to_dict(orient="records"), 1):
            label = row[group1]
            metric = row["metric"]
            comp = row["comp"]
            cnt = int(row["count"])
            devlab = _deviation_label(metric, baseline)
            caution = _small_sample_note(cnt, min_sample)
            msg = f"{i}. **{label}** — {primary_metric}: {_fmt_num(metric)}"
            if comparator_metric:
                msg += f"; {comparator_metric}: {_fmt_num(comp)}"
            msg += f"; {_pluralize('record', cnt)}: {cnt}"
            msg += f" → {devlab}"
            if caution:
                msg += f" ({caution})"
            top_lines.append(msg)

    # 6) Zone/region style overview (group2) vs the overall baseline.
    region_lines: List[str] = []
    if isinstance(g2, pd.DataFrame) and not g2.empty:
        g2 = g2.sort_values(by="metric", ascending=False)
        for row in g2.to_dict(orient="records"):
            label = row[group2]
            metric = row["metric"]
            comp = row["comp"]
            cnt = int(row["count"])
            devlab = _deviation_label(metric, baseline)
            caution = _small_sample_note(cnt, min_sample)
            line = f"- **{label}**: {_fmt_num(metric)} (vs. overall {_fmt_num(baseline)} → {devlab}); n={cnt}"
            if comparator_metric:
                line += f"; {comparator_metric}: {_fmt_num(comp)}"
            if caution:
                line += f" — {caution}"
            region_lines.append(line)

    # 7) Geographic notes, only when city + lat + lon columns are all present.
    #    Column names are detected by pattern, never hard-coded.
    geo_notes: List[str] = []
    city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
    lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
    lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
    if group1 and city_col and lat_col and lon_col:
        if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
            # Join the top-10 group labels back to the source rows by raw value.
            top_labels = g1[group1].astype(str).head(10).tolist()
            sub = df[df[group1].astype(str).isin(top_labels)]
            if not sub.empty:
                by_city = (
                    sub.groupby(city_col, dropna=False)[primary_metric]
                    .mean()
                    .reset_index()
                    .sort_values(by=primary_metric, ascending=False)
                )
                # Brief, dynamic note only (no hard-coded cities).
                for r in by_city.head(3).to_dict(orient="records"):
                    cname = r.get(city_col)
                    val = r.get(primary_metric)
                    geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")

    # 8) Methodology bullets derived from actual data conditions.
    methodology: List[str] = []
    if df.isna().sum().sum() > 0:
        methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
    if comparator_metric:
        methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if group1:
        methodology.append(f"Primary grouping inferred: **{group1}**.")
    if group2:
        methodology.append(f"Secondary grouping inferred: **{group2}**.")
    if min_sample != _DEF_MIN_SAMPLE:
        methodology.append(f"Small-sample threshold set to {min_sample} records.")

    # 9) Compose the markdown document.
    lines: List[str] = []
    lines.append("## Methodology (Auto-generated)")
    for m in methodology:
        lines.append(f"- {m}")
    lines.append("")

    if top_lines:
        lines.append("## Highest average values by group")
        lines.extend(top_lines)
        lines.append("")

    if region_lines:
        lines.append(f"## {group2 or 'Region/Category'} comparison vs overall")
        lines.extend(region_lines)
        lines.append("")

    if geo_notes:
        lines.append("## Geographic notes")
        lines.extend(geo_notes)
        lines.append("")

    # Generic, data-driven recommendations (no scenario-specific strings).
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
    if comparator_metric:
        recs.append(f"Cross-check {comparator_metric} trends to identify upstream bottlenecks (e.g., long consult waits pushing surgery waits).")
    if isinstance(g2, pd.DataFrame) and not g2.empty:
        if not g2[g2["metric"] > baseline].empty:
            recs.append(f"Address regional disparities where average **{primary_metric}** exceeds the overall baseline.")
    recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
    recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")

    lines.append("## Recommendations (Auto-generated)")
    for r in recs:
        lines.append(f"- {r}")

    return "\n".join(lines).strip()
|