Rajan Sharma committed on
Commit
42a6bd6
·
verified ·
1 Parent(s): dddc062

Create narrative_safetynet.py

Browse files
Files changed (1) hide show
  1. narrative_safetynet.py +276 -0
narrative_safetynet.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # narrative_safetynet.py
2
+ from __future__ import annotations
3
+ from typing import Dict, Any, List, Optional, Tuple
4
+ import math
5
+ import numpy as np
6
+ import pandas as pd
7
+ import re
8
+
9
+ # ---------- Generic helpers ----------
10
+ _DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
11
+
12
+ def _is_numeric(s: pd.Series) -> bool:
13
+ return pd.api.types.is_numeric_dtype(s)
14
+
15
+ def _fmt_num(x: Any, decimals: int = 1) -> str:
16
+ try:
17
+ if x is None or (isinstance(x, float) and math.isnan(x)):
18
+ return "n/a"
19
+ if isinstance(x, (int, np.integer)):
20
+ return f"{x:,}"
21
+ return f"{float(x):,.{decimals}f}"
22
+ except Exception:
23
+ return str(x)
24
+
25
+ def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
26
+ # choose a numeric column; use hints like "Surgery_Median", "Consult_Median" if present
27
+ cols = list(df.columns)
28
+ for h in hints:
29
+ for c in cols:
30
+ if h.lower() in c.lower() and _is_numeric(df[c]):
31
+ return c
32
+ for c in cols:
33
+ if _is_numeric(df[c]):
34
+ return c
35
+ return None
36
+
37
+ def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
38
+ # choose a categorical/grouping column by fuzzy name
39
+ cols = list(df.columns)
40
+ for cand in candidates:
41
+ for c in cols:
42
+ if cand.lower() in c.lower():
43
+ return c
44
+ # fallback: first object/string column with reasonable cardinality
45
+ obj_cols = [c for c in cols if df[c].dtype == "object"]
46
+ for c in obj_cols:
47
+ nuniq = df[c].nunique(dropna=True)
48
+ if 1 < nuniq < max(50, len(df) // 10): # avoid IDs (too high cardinality) and constants
49
+ return c
50
+ return None
51
+
52
+ def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
53
+ # treat dashes and blank strings as NaN; do not hard-code schema
54
+ dff = df.copy()
55
+ for c in dff.columns:
56
+ if dff[c].dtype == "object":
57
+ dff[c] = dff[c].replace({r"^\s*$": np.nan, r"^[-–—]$": np.nan}, regex=True)
58
+ return dff
59
+
60
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """Return a caution message when *n* is below *min_n*, else None."""
    if n >= min_n:
        return None
    return f"Interpret averages cautiously (only {n} records)."
62
+
63
+ def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
64
+ # tol is a fraction of mu for ≈ equal bucket (1% default)
65
+ if np.isnan(x) or np.isnan(mu):
66
+ return "unknown"
67
+ if mu == 0:
68
+ return "unknown"
69
+ rel = (x - mu) / mu
70
+ if rel > 0.05: # > +5% above average
71
+ return "higher than average"
72
+ if rel < -0.05: # < -5% below average
73
+ return "lower than average"
74
+ if abs(rel) <= max(tol, 0.05):
75
+ return "about average"
76
+ return "about average"
77
+
78
+ def _pluralize(label: str, n: int) -> str:
79
+ return f"{label}{'' if n==1 else 's'}"
80
+
81
+ # ---------- Core narrative generator ----------
82
def build_narrative(
    scenario_text: str,
    # Dict of dataframes the engine produced/loaded (same dict passed to ScenarioEngine).
    datasets: Dict[str, Any],
    # Optional structured outputs already rendered (tables etc.), for cross-reference.
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    # Hints for metric selection (keeps it scenario-agnostic).
    metric_hints: Optional[List[str]] = None,  # e.g. ["Surgery_Median", "Consult_Median", "Wait"]
    group_hints: Optional[List[str]] = None,   # e.g. ["Facility", "Specialty", "Zone"]
    min_sample: int = _DEF_MIN_SAMPLE,
) -> str:
    """
    Build a markdown narrative from the first non-empty DataFrame in *datasets*.

    The narrative:
      - summarizes methodology (cleaning, numeric/grouping detection),
      - ranks the top groups by the chosen metric,
      - compares group means against the overall baseline mean,
      - flags small-sample groups (fewer than *min_sample* records),
      - adds geographic notes when city/lat/lon-style columns are present.

    Columns are inferred dynamically via *metric_hints* / *group_hints*; no
    scenario-specific names are hard-coded. *scenario_text* and
    *structured_tables* are accepted for interface compatibility and are
    currently not consulted — TODO confirm whether callers expect them used.

    Bug fix vs. the original: grouped rows are now read via
    ``to_dict(orient="records")`` instead of ``itertuples``. pandas
    ``itertuples`` renames columns that are not valid namedtuple fields —
    including the helper column ``_comp``, which starts with an underscore —
    so ``getattr(row, "_comp")`` raised AttributeError for every non-empty
    grouping. Dict access also tolerates group columns whose names are not
    valid Python identifiers (e.g. containing spaces).
    """
    metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p50"]
    group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]

    # 1) Pick the primary dataset (first non-empty DataFrame) and sanitize it.
    df = None
    for _key, candidate in datasets.items():
        if isinstance(candidate, pd.DataFrame) and not candidate.empty:
            df = _nanlike_to_nan(candidate)
            break
    if df is None:
        return "No tabular data available. Unable to generate a narrative."

    # 2) Pick the primary numeric metric and, optionally, one comparator.
    primary_metric = _pick_numeric(df, metric_hints)  # e.g., Surgery_Median
    if not primary_metric:
        return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."

    other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
    comparator_metric = next(
        (c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])),
        None,
    )

    # 3) Choose groupings dynamically; the second group must differ from the first.
    group1 = _find_group_col(df, group_hints)  # e.g., Facility
    group2 = None
    if group1:
        alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
        group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)

    # 4) Overall baseline and per-group aggregates.
    baseline = df[primary_metric].astype(float).mean(skipna=True)

    def _agg_by(group_col: str) -> pd.DataFrame:
        # Mean metric, non-null record count, and comparator mean. Falls back
        # to the primary metric when no comparator exists (as the original did).
        comp_source = comparator_metric if comparator_metric else primary_metric
        return (
            df.groupby(group_col, dropna=False)
            .agg(
                metric=(primary_metric, "mean"),
                count=(primary_metric, "count"),
                _comp=(comp_source, "mean"),
            )
            .reset_index()
        )

    g1 = _agg_by(group1) if group1 else None
    g2 = _agg_by(group2) if group2 else None

    # 5) Top groups (by mean metric, descending) for the primary grouping.
    top_lines: List[str] = []
    if isinstance(g1, pd.DataFrame) and not g1.empty:
        g1 = g1.sort_values(by="metric", ascending=False)
        k = min(5, len(g1))
        # Dict-records access avoids itertuples' field renaming (see docstring).
        for i, row in enumerate(g1.head(k).to_dict(orient="records"), 1):
            label = row[group1]
            metric = row["metric"]
            comp = row["_comp"]
            cnt = row["count"]
            devlab = _deviation_label(metric, baseline)
            caution = _small_sample_note(int(cnt), min_sample)
            msg = f"{i}. **{label}** — {primary_metric}: {_fmt_num(metric)}"
            if comparator_metric:
                msg += f"; {comparator_metric}: {_fmt_num(comp)}"
            msg += f"; {_pluralize('record', int(cnt))}: {cnt}"
            msg += f" → {devlab}"
            if caution:
                msg += f" ({caution})"
            top_lines.append(msg)

    # 6) Region/zone-style overview for the secondary grouping.
    region_lines: List[str] = []
    if isinstance(g2, pd.DataFrame) and not g2.empty:
        g2 = g2.sort_values(by="metric", ascending=False)
        for row in g2.to_dict(orient="records"):
            label = row[group2]
            metric = row["metric"]
            comp = row["_comp"]
            cnt = row["count"]
            devlab = _deviation_label(metric, baseline)
            caution = _small_sample_note(int(cnt), min_sample)
            line = f"- **{label}**: {_fmt_num(metric)} (vs. overall {_fmt_num(baseline)} → {devlab}); n={cnt}"
            if comparator_metric:
                line += f"; {comparator_metric}: {_fmt_num(comp)}"
            if caution:
                line += f" — {caution}"
            region_lines.append(line)

    # 7) Geographic notes, only when city + lat + lon style columns all exist.
    #    Field names are pattern-matched, never hard-coded.
    geo_notes: List[str] = []
    city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
    lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
    lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
    if group1 and city_col and lat_col and lon_col:
        if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
            # Normalize whitespace in the top labels, then join back to the raw rows.
            top_labels = [" ".join(str(label).split()) for label in g1[group1].astype(str).head(10).tolist()]
            sub = df[df[group1].astype(str).isin(top_labels)]
            if not sub.empty:
                by_city = (
                    sub.groupby(city_col, dropna=False)[primary_metric]
                    .mean()
                    .reset_index()
                    .sort_values(by=primary_metric, ascending=False)
                )
                # Brief, dynamic note only (no hard-coded city names).
                for r in by_city.head(3).to_dict(orient="records"):
                    cname = r.get(city_col)
                    val = r.get(primary_metric)
                    geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")

    # 8) Methodology notes derived from actual data conditions.
    methodology: List[str] = []
    if df.isna().sum().sum() > 0:
        methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
    if comparator_metric:
        methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if group1:
        methodology.append(f"Primary grouping inferred: **{group1}**.")
    if group2:
        methodology.append(f"Secondary grouping inferred: **{group2}**.")
    if min_sample != _DEF_MIN_SAMPLE:
        methodology.append(f"Small-sample threshold set to {min_sample} records.")

    # 9) Compose the markdown document.
    lines: List[str] = []
    lines.append("## Methodology (Auto-generated)")
    for m in methodology:
        lines.append(f"- {m}")
    lines.append("")

    if top_lines:
        lines.append("## Highest average values by group")
        lines.extend(top_lines)
        lines.append("")

    if region_lines:
        lines.append(f"## {group2 or 'Region/Category'} comparison vs overall")
        lines.extend(region_lines)
        lines.append("")

    if geo_notes:
        lines.append("## Geographic notes")
        lines.extend(geo_notes)
        lines.append("")

    # Data-driven recommendation templates (no scenario-specific text).
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
    if comparator_metric:
        recs.append(f"Cross-check {comparator_metric} trends to identify upstream bottlenecks (e.g., long consult waits pushing surgery waits).")
    if isinstance(g2, pd.DataFrame) and not g2.empty:
        if not g2[g2["metric"] > baseline].empty:
            recs.append(f"Address regional disparities where average **{primary_metric}** exceeds the overall baseline.")
    recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
    recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")

    lines.append("## Recommendations (Auto-generated)")
    for r in recs:
        lines.append(f"- {r}")

    return "\n".join(lines).strip()