Rajan Sharma committed on
Commit
87088be
·
verified ·
1 Parent(s): 2fccbc6

Update narrative_safetynet.py

Browse files
Files changed (1) hide show
  1. narrative_safetynet.py +63 -77
narrative_safetynet.py CHANGED
@@ -1,12 +1,11 @@
1
  # narrative_safetynet.py
2
  from __future__ import annotations
3
- from typing import Dict, Any, List, Optional, Tuple
4
  import math
5
  import numpy as np
6
  import pandas as pd
7
  import re
8
 
9
- # ---------- Generic helpers ----------
10
  _DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
11
 
12
  def _is_numeric(s: pd.Series) -> bool:
@@ -23,7 +22,7 @@ def _fmt_num(x: Any, decimals: int = 1) -> str:
23
  return str(x)
24
 
25
  def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
26
- # choose a numeric column; use hints like "Surgery_Median", "Consult_Median" if present
27
  cols = list(df.columns)
28
  for h in hints:
29
  for c in cols:
@@ -35,22 +34,20 @@ def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
35
  return None
36
 
37
  def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
38
- # choose a categorical/grouping column by fuzzy name
39
  cols = list(df.columns)
40
  for cand in candidates:
41
  for c in cols:
42
  if cand.lower() in c.lower():
43
  return c
44
- # fallback: first object/string column with reasonable cardinality
45
  obj_cols = [c for c in cols if df[c].dtype == "object"]
46
  for c in obj_cols:
47
  nuniq = df[c].nunique(dropna=True)
48
- if 1 < nuniq < max(50, len(df) // 10): # avoid IDs (too high cardinality) and constants
49
  return c
50
  return None
51
 
52
  def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
53
- # treat dashes and blank strings as NaN; do not hard-code schema
54
  dff = df.copy()
55
  for c in dff.columns:
56
  if dff[c].dtype == "object":
@@ -61,15 +58,12 @@ def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
61
  return f"Interpret averages cautiously (only {n} records)." if n < min_n else None
62
 
63
  def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
64
- # tol is a fraction of mu for ≈ equal bucket (1% default)
65
- if np.isnan(x) or np.isnan(mu):
66
- return "unknown"
67
- if mu == 0:
68
  return "unknown"
69
  rel = (x - mu) / mu
70
- if rel > 0.05: # > +5% above average
71
  return "higher than average"
72
- if rel < -0.05: # < -5% below average
73
  return "lower than average"
74
  if abs(rel) <= max(tol, 0.05):
75
  return "about average"
@@ -78,31 +72,25 @@ def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
78
  def _pluralize(label: str, n: int) -> str:
79
  return f"{label}{'' if n==1 else 's'}"
80
 
81
- # ---------- Core narrative generator ----------
82
  def build_narrative(
83
  scenario_text: str,
84
- # A dict of dataframes your engine just produced / loaded (e.g., the same dict passed to ScenarioEngine)
85
  datasets: Dict[str, Any],
86
- # Optional structured outputs your engine already rendered (tables etc.) if you want to cross-reference
87
  structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
88
- # Hints for metric selection (keeps it scenario-agnostic)
89
- metric_hints: Optional[List[str]] = None, # e.g. ["Surgery_Median", "Consult_Median", "Wait"]
90
- group_hints: Optional[List[str]] = None, # e.g. ["Facility","Specialty","Zone"]
91
  min_sample: int = _DEF_MIN_SAMPLE
92
  ) -> str:
93
  """
94
- Returns a markdown narrative that:
95
- - Summarizes methodology (cleaning, numeric detection)
96
- - Highlights top groups by the chosen metric
97
- - Computes an overall baseline and compares groups vs baseline
98
- - Flags small-sample groups
99
- - Adds geographic notes if city/lat/lon are present (fully optional)
100
- This function avoids any scenario-specific strings and infers columns dynamically.
101
  """
102
- metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p50"]
103
  group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]
104
 
105
- # 1) Pick a primary dataset (first table-like) and sanitize
106
  df = None
107
  df_key = None
108
  for k, v in datasets.items():
@@ -113,50 +101,50 @@ def build_narrative(
113
  if df is None:
114
  return "No tabular data available. Unable to generate a narrative."
115
 
116
- # 2) Pick primary metric (numeric) and up to two comparators (e.g., consult vs surgery)
117
  primary_metric = _pick_numeric(df, metric_hints) # e.g., Surgery_Median
118
  if not primary_metric:
119
  return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
120
 
121
  other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
122
- comparator_metric = next((c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])), None)
 
 
 
123
 
124
- # 3) Choose groupings dynamically
125
  group1 = _find_group_col(df, group_hints) # e.g., Facility
126
  group2 = None
127
  if group1:
128
- # try to find a second group that isn't identical (e.g., Zone if Facility selected)
129
  alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
130
  group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)
131
 
132
- # 4) Baseline (overall) and grouped stats
133
- baseline = df[primary_metric].astype(float).mean(skipna=True)
134
- # grouped (group1)
135
- g1 = None
136
- if group1:
137
- g1 = (
138
- df.groupby(group1, dropna=False)
139
- .agg(
140
- metric=(primary_metric, "mean"),
141
- count=(primary_metric, "count"),
142
- _comp=(comparator_metric, "mean") if comparator_metric else (primary_metric, "mean"),
143
- )
144
- .reset_index()
145
- )
146
- # grouped (group2)
147
- g2 = None
148
- if group2:
149
- g2 = (
150
- df.groupby(group2, dropna=False)
151
- .agg(
152
- metric=(primary_metric, "mean"),
153
- count=(primary_metric, "count"),
154
- _comp=(comparator_metric, "mean") if comparator_metric else (primary_metric, "mean"),
155
- )
156
- .reset_index()
157
  )
 
 
 
 
158
 
159
- # 5) Identify top/bottom (by deviation) for group1
160
  top_lines: List[str] = []
161
  if isinstance(g1, pd.DataFrame) and not g1.empty:
162
  g1 = g1.sort_values(by="metric", ascending=False)
@@ -164,7 +152,7 @@ def build_narrative(
164
  for i, row in enumerate(g1.head(k).itertuples(index=False), 1):
165
  label = getattr(row, group1)
166
  metric = getattr(row, "metric")
167
- comp = getattr(row, "_comp")
168
  cnt = getattr(row, "count")
169
  devlab = _deviation_label(metric, baseline)
170
  caution = _small_sample_note(int(cnt), min_sample)
@@ -177,15 +165,14 @@ def build_narrative(
177
  msg += f" ({caution})"
178
  top_lines.append(msg)
179
 
180
- # 6) Zone/region style overview (group2)
181
  region_lines: List[str] = []
182
  if isinstance(g2, pd.DataFrame) and not g2.empty:
183
- # order by metric descending
184
  g2 = g2.sort_values(by="metric", ascending=False)
185
  for row in g2.itertuples(index=False):
186
  label = getattr(row, group2)
187
  metric = getattr(row, "metric")
188
- comp = getattr(row, "_comp")
189
  cnt = getattr(row, "count")
190
  devlab = _deviation_label(metric, baseline)
191
  caution = _small_sample_note(int(cnt), min_sample)
@@ -196,34 +183,33 @@ def build_narrative(
196
  line += f" — {caution}"
197
  region_lines.append(line)
198
 
199
- # 7) Geographic notes (if present)
200
- # We never hard-code field names; we look for city/lat/lon patterns
201
  geo_notes: List[str] = []
202
  city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
203
  lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
204
  lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
205
  if group1 and city_col and (lat_col and lon_col):
206
- # summarize whether top groups cluster in specific cities
207
  if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
208
- # join back to get city data for topK
209
- top_labels = [re.sub(r"\s+", " ", re.sub(r"^\s+|\s+$", "", re.sub(r"\n", " ", l))) for l in g1[group1].astype(str).head(10).tolist()]
210
- sub = df[df[group1].astype(str).isin(top_labels)]
211
  if not sub.empty:
212
- by_city = sub.groupby(city_col, dropna=False)[primary_metric].mean().reset_index().sort_values(by=primary_metric, ascending=False)
213
- # Only a brief, dynamic note (no hard-coded cities)
214
- top_city_rows = by_city.head(3).to_dict(orient="records")
215
- for r in top_city_rows:
 
 
 
 
216
  cname = r.get(city_col)
217
  val = r.get(primary_metric)
218
  geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")
219
 
220
- # 8) Methodology (derived from actual data conditions)
221
  methodology: List[str] = []
222
- # missing values
223
  na_counts = df.isna().sum().sum()
224
  if na_counts > 0:
225
  methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
226
- # numeric coercion note
227
  methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
228
  if comparator_metric:
229
  methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
@@ -256,7 +242,6 @@ def build_narrative(
256
  lines.extend(geo_notes)
257
  lines.append("")
258
 
259
- # Generic recommendations template (data-driven, not hard-coded)
260
  recs: List[str] = []
261
  if top_lines:
262
  recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
@@ -265,7 +250,7 @@ def build_narrative(
265
  if isinstance(g2, pd.DataFrame) and not g2.empty:
266
  high = g2[g2["metric"] > baseline]
267
  if not high.empty:
268
- recs.append(f"Address regional disparities where average **{primary_metric}** exceeds the overall baseline.")
269
  recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
270
  recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
271
 
@@ -274,3 +259,4 @@ def build_narrative(
274
  lines.append(f"- {r}")
275
 
276
  return "\n".join(lines).strip()
 
 
1
  # narrative_safetynet.py
2
  from __future__ import annotations
3
+ from typing import Dict, Any, List, Optional
4
  import math
5
  import numpy as np
6
  import pandas as pd
7
  import re
8
 
 
9
  _DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
10
 
11
  def _is_numeric(s: pd.Series) -> bool:
 
22
  return str(x)
23
 
24
  def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
25
+ # choose a numeric column; prefer hinted names
26
  cols = list(df.columns)
27
  for h in hints:
28
  for c in cols:
 
34
  return None
35
 
36
  def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
 
37
  cols = list(df.columns)
38
  for cand in candidates:
39
  for c in cols:
40
  if cand.lower() in c.lower():
41
  return c
42
+ # fallback: first reasonable categorical column
43
  obj_cols = [c for c in cols if df[c].dtype == "object"]
44
  for c in obj_cols:
45
  nuniq = df[c].nunique(dropna=True)
46
+ if 1 < nuniq < max(50, len(df) // 10):
47
  return c
48
  return None
49
 
50
  def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
 
51
  dff = df.copy()
52
  for c in dff.columns:
53
  if dff[c].dtype == "object":
 
58
  return f"Interpret averages cautiously (only {n} records)." if n < min_n else None
59
 
60
  def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
61
+ if np.isnan(x) or np.isnan(mu) or mu == 0:
 
 
 
62
  return "unknown"
63
  rel = (x - mu) / mu
64
+ if rel > 0.05:
65
  return "higher than average"
66
+ if rel < -0.05:
67
  return "lower than average"
68
  if abs(rel) <= max(tol, 0.05):
69
  return "about average"
 
72
  def _pluralize(label: str, n: int) -> str:
73
  return f"{label}{'' if n==1 else 's'}"
74
 
 
75
  def build_narrative(
76
  scenario_text: str,
 
77
  datasets: Dict[str, Any],
 
78
  structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
79
+ metric_hints: Optional[List[str]] = None,
80
+ group_hints: Optional[List[str]] = None,
 
81
  min_sample: int = _DEF_MIN_SAMPLE
82
  ) -> str:
83
  """
84
+ Scenario-agnostic narrative fallback:
85
+ - Picks numeric metric & groupings dynamically
86
+ - Computes overall baseline + deviations
87
+ - Warns on small samples
88
+ - Optional geographic notes if city/lat/lon exist
 
 
89
  """
90
+ metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p90", "90th"]
91
  group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]
92
 
93
+ # 1) choose first non-empty table-like dataset
94
  df = None
95
  df_key = None
96
  for k, v in datasets.items():
 
101
  if df is None:
102
  return "No tabular data available. Unable to generate a narrative."
103
 
104
+ # 2) metrics
105
  primary_metric = _pick_numeric(df, metric_hints) # e.g., Surgery_Median
106
  if not primary_metric:
107
  return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
108
 
109
  other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
110
+ comparator_metric = next(
111
+ (c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])),
112
+ None
113
+ )
114
 
115
+ # 3) groups
116
  group1 = _find_group_col(df, group_hints) # e.g., Facility
117
  group2 = None
118
  if group1:
 
119
  alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
120
  group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)
121
 
122
+ # 4) baseline + grouped
123
+ baseline = pd.to_numeric(df[primary_metric], errors="coerce").mean(skipna=True)
124
+
125
+ def _group_stats(col: str) -> Optional[pd.DataFrame]:
126
+ if not col:
127
+ return None
128
+ tmp = df.copy()
129
+ tmp[primary_metric] = pd.to_numeric(tmp[primary_metric], errors="coerce")
130
+ comp_col = comparator_metric or primary_metric
131
+ if comp_col in tmp.columns:
132
+ tmp[comp_col] = pd.to_numeric(tmp[comp_col], errors="coerce")
133
+ agg = (
134
+ tmp.groupby(col, dropna=False)
135
+ .agg(
136
+ metric=(primary_metric, "mean"),
137
+ count=(primary_metric, "count"),
138
+ comp=(comp_col, "mean") if comp_col in tmp.columns else (primary_metric, "mean"),
139
+ )
140
+ .reset_index()
 
 
 
 
 
 
141
  )
142
+ return agg
143
+
144
+ g1 = _group_stats(group1)
145
+ g2 = _group_stats(group2)
146
 
147
+ # 5) Top groups (by primary metric) from group1
148
  top_lines: List[str] = []
149
  if isinstance(g1, pd.DataFrame) and not g1.empty:
150
  g1 = g1.sort_values(by="metric", ascending=False)
 
152
  for i, row in enumerate(g1.head(k).itertuples(index=False), 1):
153
  label = getattr(row, group1)
154
  metric = getattr(row, "metric")
155
+ comp = getattr(row, "comp")
156
  cnt = getattr(row, "count")
157
  devlab = _deviation_label(metric, baseline)
158
  caution = _small_sample_note(int(cnt), min_sample)
 
165
  msg += f" ({caution})"
166
  top_lines.append(msg)
167
 
168
+ # 6) Group2 overview
169
  region_lines: List[str] = []
170
  if isinstance(g2, pd.DataFrame) and not g2.empty:
 
171
  g2 = g2.sort_values(by="metric", ascending=False)
172
  for row in g2.itertuples(index=False):
173
  label = getattr(row, group2)
174
  metric = getattr(row, "metric")
175
+ comp = getattr(row, "comp")
176
  cnt = getattr(row, "count")
177
  devlab = _deviation_label(metric, baseline)
178
  caution = _small_sample_note(int(cnt), min_sample)
 
183
  line += f" — {caution}"
184
  region_lines.append(line)
185
 
186
+ # 7) Geographic notes (optional)
 
187
  geo_notes: List[str] = []
188
  city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
189
  lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
190
  lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
191
  if group1 and city_col and (lat_col and lon_col):
 
192
  if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
193
+ top_labels = g1[group1].astype(str).head(10).tolist()
194
+ sub = df[df[group1].astype(str).isin(top_labels)].copy()
 
195
  if not sub.empty:
196
+ sub[primary_metric] = pd.to_numeric(sub[primary_metric], errors="coerce")
197
+ by_city = (
198
+ sub.groupby(city_col, dropna=False)[primary_metric]
199
+ .mean()
200
+ .reset_index()
201
+ .sort_values(by=primary_metric, ascending=False)
202
+ )
203
+ for r in by_city.head(3).to_dict(orient="records"):
204
  cname = r.get(city_col)
205
  val = r.get(primary_metric)
206
  geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")
207
 
208
+ # 8) Methodology (auto)
209
  methodology: List[str] = []
 
210
  na_counts = df.isna().sum().sum()
211
  if na_counts > 0:
212
  methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
 
213
  methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
214
  if comparator_metric:
215
  methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
 
242
  lines.extend(geo_notes)
243
  lines.append("")
244
 
 
245
  recs: List[str] = []
246
  if top_lines:
247
  recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
 
250
  if isinstance(g2, pd.DataFrame) and not g2.empty:
251
  high = g2[g2["metric"] > baseline]
252
  if not high.empty:
253
+ recs.append(f"Address disparities where average **{primary_metric}** exceeds the overall baseline.")
254
  recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
255
  recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
256
 
 
259
  lines.append(f"- {r}")
260
 
261
  return "\n".join(lines).strip()
262
+