Spaces:
Sleeping
Sleeping
Rajan Sharma
committed on
Update narrative_safetynet.py
Browse files- narrative_safetynet.py +63 -77
narrative_safetynet.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
# narrative_safetynet.py
|
| 2 |
from __future__ import annotations
|
| 3 |
-
from typing import Dict, Any, List, Optional
|
| 4 |
import math
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
import re
|
| 8 |
|
| 9 |
-
# ---------- Generic helpers ----------
|
| 10 |
_DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
|
| 11 |
|
| 12 |
def _is_numeric(s: pd.Series) -> bool:
|
|
@@ -23,7 +22,7 @@ def _fmt_num(x: Any, decimals: int = 1) -> str:
|
|
| 23 |
return str(x)
|
| 24 |
|
| 25 |
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
|
| 26 |
-
# choose a numeric column;
|
| 27 |
cols = list(df.columns)
|
| 28 |
for h in hints:
|
| 29 |
for c in cols:
|
|
@@ -35,22 +34,20 @@ def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
|
|
| 35 |
return None
|
| 36 |
|
| 37 |
def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
|
| 38 |
-
# choose a categorical/grouping column by fuzzy name
|
| 39 |
cols = list(df.columns)
|
| 40 |
for cand in candidates:
|
| 41 |
for c in cols:
|
| 42 |
if cand.lower() in c.lower():
|
| 43 |
return c
|
| 44 |
-
# fallback: first
|
| 45 |
obj_cols = [c for c in cols if df[c].dtype == "object"]
|
| 46 |
for c in obj_cols:
|
| 47 |
nuniq = df[c].nunique(dropna=True)
|
| 48 |
-
if 1 < nuniq < max(50, len(df) // 10):
|
| 49 |
return c
|
| 50 |
return None
|
| 51 |
|
| 52 |
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
|
| 53 |
-
# treat dashes and blank strings as NaN; do not hard-code schema
|
| 54 |
dff = df.copy()
|
| 55 |
for c in dff.columns:
|
| 56 |
if dff[c].dtype == "object":
|
|
@@ -61,15 +58,12 @@ def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
|
|
| 61 |
return f"Interpret averages cautiously (only {n} records)." if n < min_n else None
|
| 62 |
|
| 63 |
def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
|
| 64 |
-
|
| 65 |
-
if np.isnan(x) or np.isnan(mu):
|
| 66 |
-
return "unknown"
|
| 67 |
-
if mu == 0:
|
| 68 |
return "unknown"
|
| 69 |
rel = (x - mu) / mu
|
| 70 |
-
if rel > 0.05:
|
| 71 |
return "higher than average"
|
| 72 |
-
if rel < -0.05:
|
| 73 |
return "lower than average"
|
| 74 |
if abs(rel) <= max(tol, 0.05):
|
| 75 |
return "about average"
|
|
@@ -78,31 +72,25 @@ def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
|
|
| 78 |
def _pluralize(label: str, n: int) -> str:
|
| 79 |
return f"{label}{'' if n==1 else 's'}"
|
| 80 |
|
| 81 |
-
# ---------- Core narrative generator ----------
|
| 82 |
def build_narrative(
|
| 83 |
scenario_text: str,
|
| 84 |
-
# A dict of dataframes your engine just produced / loaded (e.g., the same dict passed to ScenarioEngine)
|
| 85 |
datasets: Dict[str, Any],
|
| 86 |
-
# Optional structured outputs your engine already rendered (tables etc.) if you want to cross-reference
|
| 87 |
structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
group_hints: Optional[List[str]] = None, # e.g. ["Facility","Specialty","Zone"]
|
| 91 |
min_sample: int = _DEF_MIN_SAMPLE
|
| 92 |
) -> str:
|
| 93 |
"""
|
| 94 |
-
|
| 95 |
-
-
|
| 96 |
-
-
|
| 97 |
-
-
|
| 98 |
-
-
|
| 99 |
-
- Adds geographic notes if city/lat/lon are present (fully optional)
|
| 100 |
-
This function avoids any scenario-specific strings and infers columns dynamically.
|
| 101 |
"""
|
| 102 |
-
metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "
|
| 103 |
group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]
|
| 104 |
|
| 105 |
-
# 1)
|
| 106 |
df = None
|
| 107 |
df_key = None
|
| 108 |
for k, v in datasets.items():
|
|
@@ -113,50 +101,50 @@ def build_narrative(
|
|
| 113 |
if df is None:
|
| 114 |
return "No tabular data available. Unable to generate a narrative."
|
| 115 |
|
| 116 |
-
# 2)
|
| 117 |
primary_metric = _pick_numeric(df, metric_hints) # e.g., Surgery_Median
|
| 118 |
if not primary_metric:
|
| 119 |
return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
|
| 120 |
|
| 121 |
other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
|
| 122 |
-
comparator_metric = next(
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
# 3)
|
| 125 |
group1 = _find_group_col(df, group_hints) # e.g., Facility
|
| 126 |
group2 = None
|
| 127 |
if group1:
|
| 128 |
-
# try to find a second group that isn't identical (e.g., Zone if Facility selected)
|
| 129 |
alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
|
| 130 |
group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)
|
| 131 |
|
| 132 |
-
# 4)
|
| 133 |
-
baseline = df[primary_metric]
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
.agg(
|
| 152 |
-
metric=(primary_metric, "mean"),
|
| 153 |
-
count=(primary_metric, "count"),
|
| 154 |
-
_comp=(comparator_metric, "mean") if comparator_metric else (primary_metric, "mean"),
|
| 155 |
-
)
|
| 156 |
-
.reset_index()
|
| 157 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
# 5)
|
| 160 |
top_lines: List[str] = []
|
| 161 |
if isinstance(g1, pd.DataFrame) and not g1.empty:
|
| 162 |
g1 = g1.sort_values(by="metric", ascending=False)
|
|
@@ -164,7 +152,7 @@ def build_narrative(
|
|
| 164 |
for i, row in enumerate(g1.head(k).itertuples(index=False), 1):
|
| 165 |
label = getattr(row, group1)
|
| 166 |
metric = getattr(row, "metric")
|
| 167 |
-
comp = getattr(row, "
|
| 168 |
cnt = getattr(row, "count")
|
| 169 |
devlab = _deviation_label(metric, baseline)
|
| 170 |
caution = _small_sample_note(int(cnt), min_sample)
|
|
@@ -177,15 +165,14 @@ def build_narrative(
|
|
| 177 |
msg += f" ({caution})"
|
| 178 |
top_lines.append(msg)
|
| 179 |
|
| 180 |
-
# 6)
|
| 181 |
region_lines: List[str] = []
|
| 182 |
if isinstance(g2, pd.DataFrame) and not g2.empty:
|
| 183 |
-
# order by metric descending
|
| 184 |
g2 = g2.sort_values(by="metric", ascending=False)
|
| 185 |
for row in g2.itertuples(index=False):
|
| 186 |
label = getattr(row, group2)
|
| 187 |
metric = getattr(row, "metric")
|
| 188 |
-
comp = getattr(row, "
|
| 189 |
cnt = getattr(row, "count")
|
| 190 |
devlab = _deviation_label(metric, baseline)
|
| 191 |
caution = _small_sample_note(int(cnt), min_sample)
|
|
@@ -196,34 +183,33 @@ def build_narrative(
|
|
| 196 |
line += f" — {caution}"
|
| 197 |
region_lines.append(line)
|
| 198 |
|
| 199 |
-
# 7) Geographic notes (
|
| 200 |
-
# We never hard-code field names; we look for city/lat/lon patterns
|
| 201 |
geo_notes: List[str] = []
|
| 202 |
city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
|
| 203 |
lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
|
| 204 |
lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
|
| 205 |
if group1 and city_col and (lat_col and lon_col):
|
| 206 |
-
# summarize whether top groups cluster in specific cities
|
| 207 |
if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
sub = df[df[group1].astype(str).isin(top_labels)]
|
| 211 |
if not sub.empty:
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
cname = r.get(city_col)
|
| 217 |
val = r.get(primary_metric)
|
| 218 |
geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")
|
| 219 |
|
| 220 |
-
# 8) Methodology (
|
| 221 |
methodology: List[str] = []
|
| 222 |
-
# missing values
|
| 223 |
na_counts = df.isna().sum().sum()
|
| 224 |
if na_counts > 0:
|
| 225 |
methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
|
| 226 |
-
# numeric coercion note
|
| 227 |
methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
|
| 228 |
if comparator_metric:
|
| 229 |
methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
|
|
@@ -256,7 +242,6 @@ def build_narrative(
|
|
| 256 |
lines.extend(geo_notes)
|
| 257 |
lines.append("")
|
| 258 |
|
| 259 |
-
# Generic recommendations template (data-driven, not hard-coded)
|
| 260 |
recs: List[str] = []
|
| 261 |
if top_lines:
|
| 262 |
recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
|
|
@@ -265,7 +250,7 @@ def build_narrative(
|
|
| 265 |
if isinstance(g2, pd.DataFrame) and not g2.empty:
|
| 266 |
high = g2[g2["metric"] > baseline]
|
| 267 |
if not high.empty:
|
| 268 |
-
recs.append(f"Address
|
| 269 |
recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
|
| 270 |
recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
|
| 271 |
|
|
@@ -274,3 +259,4 @@ def build_narrative(
|
|
| 274 |
lines.append(f"- {r}")
|
| 275 |
|
| 276 |
return "\n".join(lines).strip()
|
|
|
|
|
|
| 1 |
# narrative_safetynet.py
|
| 2 |
from __future__ import annotations
|
| 3 |
+
from typing import Dict, Any, List, Optional
|
| 4 |
import math
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
import re
|
| 8 |
|
|
|
|
| 9 |
_DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
|
| 10 |
|
| 11 |
def _is_numeric(s: pd.Series) -> bool:
|
|
|
|
| 22 |
return str(x)
|
| 23 |
|
| 24 |
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
|
| 25 |
+
# choose a numeric column; prefer hinted names
|
| 26 |
cols = list(df.columns)
|
| 27 |
for h in hints:
|
| 28 |
for c in cols:
|
|
|
|
| 34 |
return None
|
| 35 |
|
| 36 |
def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
|
|
|
|
| 37 |
cols = list(df.columns)
|
| 38 |
for cand in candidates:
|
| 39 |
for c in cols:
|
| 40 |
if cand.lower() in c.lower():
|
| 41 |
return c
|
| 42 |
+
# fallback: first reasonable categorical column
|
| 43 |
obj_cols = [c for c in cols if df[c].dtype == "object"]
|
| 44 |
for c in obj_cols:
|
| 45 |
nuniq = df[c].nunique(dropna=True)
|
| 46 |
+
if 1 < nuniq < max(50, len(df) // 10):
|
| 47 |
return c
|
| 48 |
return None
|
| 49 |
|
| 50 |
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
|
| 51 |
dff = df.copy()
|
| 52 |
for c in dff.columns:
|
| 53 |
if dff[c].dtype == "object":
|
|
|
|
| 58 |
return f"Interpret averages cautiously (only {n} records)." if n < min_n else None
|
| 59 |
|
| 60 |
def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
|
| 61 |
+
if np.isnan(x) or np.isnan(mu) or mu == 0:
|
|
|
|
|
|
|
|
|
|
| 62 |
return "unknown"
|
| 63 |
rel = (x - mu) / mu
|
| 64 |
+
if rel > 0.05:
|
| 65 |
return "higher than average"
|
| 66 |
+
if rel < -0.05:
|
| 67 |
return "lower than average"
|
| 68 |
if abs(rel) <= max(tol, 0.05):
|
| 69 |
return "about average"
|
|
|
|
| 72 |
def _pluralize(label: str, n: int) -> str:
|
| 73 |
return f"{label}{'' if n==1 else 's'}"
|
| 74 |
|
|
|
|
| 75 |
def build_narrative(
|
| 76 |
scenario_text: str,
|
|
|
|
| 77 |
datasets: Dict[str, Any],
|
|
|
|
| 78 |
structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
|
| 79 |
+
metric_hints: Optional[List[str]] = None,
|
| 80 |
+
group_hints: Optional[List[str]] = None,
|
|
|
|
| 81 |
min_sample: int = _DEF_MIN_SAMPLE
|
| 82 |
) -> str:
|
| 83 |
"""
|
| 84 |
+
Scenario-agnostic narrative fallback:
|
| 85 |
+
- Picks numeric metric & groupings dynamically
|
| 86 |
+
- Computes overall baseline + deviations
|
| 87 |
+
- Warns on small samples
|
| 88 |
+
- Optional geographic notes if city/lat/lon exist
|
|
|
|
|
|
|
| 89 |
"""
|
| 90 |
+
metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p90", "90th"]
|
| 91 |
group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]
|
| 92 |
|
| 93 |
+
# 1) choose first non-empty table-like dataset
|
| 94 |
df = None
|
| 95 |
df_key = None
|
| 96 |
for k, v in datasets.items():
|
|
|
|
| 101 |
if df is None:
|
| 102 |
return "No tabular data available. Unable to generate a narrative."
|
| 103 |
|
| 104 |
+
# 2) metrics
|
| 105 |
primary_metric = _pick_numeric(df, metric_hints) # e.g., Surgery_Median
|
| 106 |
if not primary_metric:
|
| 107 |
return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
|
| 108 |
|
| 109 |
other_numeric = [c for c in df.columns if _is_numeric(df[c]) and c != primary_metric]
|
| 110 |
+
comparator_metric = next(
|
| 111 |
+
(c for c in other_numeric if any(h in c.lower() for h in ["consult", "wait", "median", "p90", "90th"])),
|
| 112 |
+
None
|
| 113 |
+
)
|
| 114 |
|
| 115 |
+
# 3) groups
|
| 116 |
group1 = _find_group_col(df, group_hints) # e.g., Facility
|
| 117 |
group2 = None
|
| 118 |
if group1:
|
|
|
|
| 119 |
alt_hints = [h for h in group_hints if h.lower() not in group1.lower()]
|
| 120 |
group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)
|
| 121 |
|
| 122 |
+
# 4) baseline + grouped
|
| 123 |
+
baseline = pd.to_numeric(df[primary_metric], errors="coerce").mean(skipna=True)
|
| 124 |
+
|
| 125 |
+
def _group_stats(col: str) -> Optional[pd.DataFrame]:
|
| 126 |
+
if not col:
|
| 127 |
+
return None
|
| 128 |
+
tmp = df.copy()
|
| 129 |
+
tmp[primary_metric] = pd.to_numeric(tmp[primary_metric], errors="coerce")
|
| 130 |
+
comp_col = comparator_metric or primary_metric
|
| 131 |
+
if comp_col in tmp.columns:
|
| 132 |
+
tmp[comp_col] = pd.to_numeric(tmp[comp_col], errors="coerce")
|
| 133 |
+
agg = (
|
| 134 |
+
tmp.groupby(col, dropna=False)
|
| 135 |
+
.agg(
|
| 136 |
+
metric=(primary_metric, "mean"),
|
| 137 |
+
count=(primary_metric, "count"),
|
| 138 |
+
comp=(comp_col, "mean") if comp_col in tmp.columns else (primary_metric, "mean"),
|
| 139 |
+
)
|
| 140 |
+
.reset_index()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
)
|
| 142 |
+
return agg
|
| 143 |
+
|
| 144 |
+
g1 = _group_stats(group1)
|
| 145 |
+
g2 = _group_stats(group2)
|
| 146 |
|
| 147 |
+
# 5) Top groups (by primary metric) from group1
|
| 148 |
top_lines: List[str] = []
|
| 149 |
if isinstance(g1, pd.DataFrame) and not g1.empty:
|
| 150 |
g1 = g1.sort_values(by="metric", ascending=False)
|
|
|
|
| 152 |
for i, row in enumerate(g1.head(k).itertuples(index=False), 1):
|
| 153 |
label = getattr(row, group1)
|
| 154 |
metric = getattr(row, "metric")
|
| 155 |
+
comp = getattr(row, "comp")
|
| 156 |
cnt = getattr(row, "count")
|
| 157 |
devlab = _deviation_label(metric, baseline)
|
| 158 |
caution = _small_sample_note(int(cnt), min_sample)
|
|
|
|
| 165 |
msg += f" ({caution})"
|
| 166 |
top_lines.append(msg)
|
| 167 |
|
| 168 |
+
# 6) Group2 overview
|
| 169 |
region_lines: List[str] = []
|
| 170 |
if isinstance(g2, pd.DataFrame) and not g2.empty:
|
|
|
|
| 171 |
g2 = g2.sort_values(by="metric", ascending=False)
|
| 172 |
for row in g2.itertuples(index=False):
|
| 173 |
label = getattr(row, group2)
|
| 174 |
metric = getattr(row, "metric")
|
| 175 |
+
comp = getattr(row, "comp")
|
| 176 |
cnt = getattr(row, "count")
|
| 177 |
devlab = _deviation_label(metric, baseline)
|
| 178 |
caution = _small_sample_note(int(cnt), min_sample)
|
|
|
|
| 183 |
line += f" — {caution}"
|
| 184 |
region_lines.append(line)
|
| 185 |
|
| 186 |
+
# 7) Geographic notes (optional)
|
|
|
|
| 187 |
geo_notes: List[str] = []
|
| 188 |
city_col = next((c for c in df.columns if re.search(r"\bcity\b", c, re.I)), None)
|
| 189 |
lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", c, re.I)), None)
|
| 190 |
lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", c, re.I)), None)
|
| 191 |
if group1 and city_col and (lat_col and lon_col):
|
|
|
|
| 192 |
if isinstance(g1, pd.DataFrame) and not g1.empty and group1 in df.columns:
|
| 193 |
+
top_labels = g1[group1].astype(str).head(10).tolist()
|
| 194 |
+
sub = df[df[group1].astype(str).isin(top_labels)].copy()
|
|
|
|
| 195 |
if not sub.empty:
|
| 196 |
+
sub[primary_metric] = pd.to_numeric(sub[primary_metric], errors="coerce")
|
| 197 |
+
by_city = (
|
| 198 |
+
sub.groupby(city_col, dropna=False)[primary_metric]
|
| 199 |
+
.mean()
|
| 200 |
+
.reset_index()
|
| 201 |
+
.sort_values(by=primary_metric, ascending=False)
|
| 202 |
+
)
|
| 203 |
+
for r in by_city.head(3).to_dict(orient="records"):
|
| 204 |
cname = r.get(city_col)
|
| 205 |
val = r.get(primary_metric)
|
| 206 |
geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")
|
| 207 |
|
| 208 |
+
# 8) Methodology (auto)
|
| 209 |
methodology: List[str] = []
|
|
|
|
| 210 |
na_counts = df.isna().sum().sum()
|
| 211 |
if na_counts > 0:
|
| 212 |
methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
|
|
|
|
| 213 |
methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
|
| 214 |
if comparator_metric:
|
| 215 |
methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
|
|
|
|
| 242 |
lines.extend(geo_notes)
|
| 243 |
lines.append("")
|
| 244 |
|
|
|
|
| 245 |
recs: List[str] = []
|
| 246 |
if top_lines:
|
| 247 |
recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
|
|
|
|
| 250 |
if isinstance(g2, pd.DataFrame) and not g2.empty:
|
| 251 |
high = g2[g2["metric"] > baseline]
|
| 252 |
if not high.empty:
|
| 253 |
+
recs.append(f"Address disparities where average **{primary_metric}** exceeds the overall baseline.")
|
| 254 |
recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
|
| 255 |
recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
|
| 256 |
|
|
|
|
| 259 |
lines.append(f"- {r}")
|
| 260 |
|
| 261 |
return "\n".join(lines).strip()
|
| 262 |
+
|