Medica_DecisionSupportAI / narrative_safetynet.py
Rajan Sharma
Update narrative_safetynet.py
1b29d16 verified
raw
history blame
17.2 kB
# narrative_safetynet.py
from __future__ import annotations

import math
import re
import unicodedata
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
# -------------------- helpers: dtype / formatting --------------------
# Groups with fewer records than this receive a small-sample caution note.
_DEF_MIN_SAMPLE = 5

# Column-name fragments that mark a likely metric column; every fragment found
# in a lower-cased column name raises that column's selection score.
_HINT_METRICS_DEFAULT = [
    "surgery_median",
    "consult_median",
    "surgery_90th",
    "consult_90th",
    "surgery",
    "consult",
    "wait",
    "median",
    "p90",
    "90th",
]

# Column-name fragments that mark a likely grouping column, tried in this order.
_HINT_GROUPS_DEFAULT = [
    "facility",
    "specialty",
    "zone",
    "hospital",
    "city",
    "region",
]

# Names of index / id / row-counter style columns, never used as the metric.
_BAD_METRIC_NAMES = ["index", "id", "row", "unnamed"]
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* in which placeholder cells are replaced by NaN.

    Only object-dtype columns are touched. A cell counts as a placeholder when
    it is empty / whitespace-only, or a lone hyphen, en dash or em dash.
    The input frame is left unmodified.
    """
    placeholders = {r"^\s*$": np.nan, r"^[-–—]$": np.nan}
    cleaned = df.copy()
    text_columns = [name for name in cleaned.columns if cleaned[name].dtype == "object"]
    for name in text_columns:
        cleaned[name] = cleaned[name].replace(placeholders, regex=True)
    return cleaned
def _is_numeric_series(s: pd.Series) -> bool:
    """Report whether *s* has a numeric dtype; any error counts as "not numeric"."""
    checker = pd.api.types.is_numeric_dtype
    try:
        verdict = checker(s)
    except Exception:
        verdict = False
    return verdict
def _to_numeric(s: pd.Series) -> pd.Series:
    """Coerce *s* to numbers; entries that cannot be parsed become NaN instead of raising."""
    coerced = pd.to_numeric(s, errors="coerce")
    return coerced
def _fmt_num(x: Any, decimals: int = 1) -> str:
    """
    Format a scalar for display in the narrative.

    - ``None`` and any scalar missing value (NaN of any float width, ``pd.NA``,
      ``NaT``) -> ``"n/a"``.
    - Integers and integral floats -> thousands-separated integer (``"1,234"``).
    - Other numbers -> thousands-separated with *decimals* places.
    - Anything that cannot be formatted -> ``str(x)``.
    """
    try:
        # np.ndim guard: pd.isna on a list/array returns an array, not a bool.
        if x is None or (np.ndim(x) == 0 and pd.isna(x)):
            return "n/a"
        if isinstance(x, (int, np.integer)):
            # int() directly: float() would lose precision for very large ints.
            return f"{int(x):,}"
        value = float(x)
        # np.floating covers float32/float16, which are not float subclasses.
        if isinstance(x, (float, np.floating)) and value.is_integer():
            return f"{int(value):,}"
        return f"{value:,.{decimals}f}"
    except Exception:
        return str(x)
# -------------------- metric & dataset selection (dynamic) --------------------
def _score_metric_name(col: str, hints: List[str]) -> int:
    """
    Score how strongly a column name looks like the metric to narrate.

    Each hint occurring (case-insensitively) as a substring of the name adds 3.
    A name is disqualified (score -1,000,000) when one of its *tokens* is in
    ``_BAD_METRIC_NAMES``. Tokens are split on non-alphanumerics and camelCase
    boundaries. ``patient_id``, ``PatientID`` and ``Unnamed: 0`` are therefore
    rejected. ``Incidence``, ``covid_cases`` or ``growth_rate`` are not; they
    merely contain the letters "id"/"row", which plain substring matching
    wrongly disqualified.
    """
    raw = "" if col is None else str(col)  # non-string labels (e.g. ints) must not crash
    name = raw.lower()
    spaced = re.sub(r"([a-z0-9])([A-Z])", r"\1 \2", raw)  # "PatientID" -> "Patient ID"
    tokens = {t for t in re.split(r"[^a-z0-9]+", spaced.lower()) if t}
    if tokens & set(_BAD_METRIC_NAMES):
        return -10**6  # disqualify obvious counters/ids
    return 3 * sum(1 for h in hints if str(h).lower() in name)
def _choose_df_and_metric(
    datasets: Dict[str, Any],
    metric_hints: List[str]
) -> Optional[Tuple[str, pd.DataFrame, str]]:
    """
    Sweep every non-empty DataFrame in *datasets* and pick the best metric column.

    Score per column: +3 per name hint (see ``_score_metric_name``), plus 1
    when the column has more than one distinct numeric value. Columns are
    skipped entirely when nothing in them parses as a number, or when the name
    is id-like (negative score). Ties keep the first candidate seen.

    Returns ``(dataset_key, cleaned_df, column)``, where ``cleaned_df`` is a
    copy with blank/dash placeholders set to NaN. Returns ``None`` if there is
    no candidate.
    """
    best: Optional[Tuple[int, str, pd.DataFrame, str]] = None
    for key, v in (datasets or {}).items():
        if not isinstance(v, pd.DataFrame) or v.empty:
            continue
        df = _nanlike_to_nan(v)
        for col in df.columns:
            col_num = _to_numeric(df[col])
            # pd.to_numeric(errors="coerce") always returns a numeric dtype, so a
            # dtype check alone accepts pure-text columns (all NaN after coercion).
            if not _is_numeric_series(col_num) or not col_num.notna().any():
                continue
            s = _score_metric_name(col, metric_hints)
            if s < 0:
                continue  # id-like: excluded even if it would be the only candidate
            if col_num.nunique(dropna=True) > 1:
                s += 1
            if best is None or s > best[0]:
                best = (s, key, df, col)
    if best is None:
        return None
    _, key, df, metric = best
    return key, df, metric
# -------------------- grouping detection (dynamic) --------------------
def _find_group_col(df: pd.DataFrame, candidates: List[str], avoid: Optional[List[str]] = None) -> Optional[str]:
    """
    Infer a grouping column in *df*.

    1. Name match: for each hint in *candidates* (in order), return the first
       column whose lower-cased name contains the hint and none of the *avoid*
       fragments.
    2. Fallback: the first text-like column (object, pandas string or
       categorical dtype) with more than one but fewer than
       ``max(50, len(df) // 10)`` distinct non-null values.

    Empty/None *avoid* entries are ignored. An empty fragment is a substring
    of every name and would otherwise block all name matches. Non-string
    column labels are compared via ``str()``; the original label is returned.
    Returns ``None`` when nothing qualifies.
    """
    avoid_l = [str(a).lower() for a in (avoid or []) if a is not None and str(a) != ""]
    named = [(c, str(c).lower()) for c in df.columns]
    # prefer name matches
    for cand in candidates:
        hint = str(cand).lower()
        for c, cname in named:
            if hint in cname and all(a not in cname for a in avoid_l):
                return c
    # fallback: a categorical with reasonable cardinality
    limit = max(50, len(df) // 10)
    for c, _ in named:
        dtype = df[c].dtype
        text_like = isinstance(dtype, (pd.StringDtype, pd.CategoricalDtype)) or dtype == object
        if text_like and 1 < df[c].nunique(dropna=True) < limit:
            return c
    return None
# -------------------- labels & cautions --------------------
def _label_vs_baseline(x: float, mu: float, band: float = 0.05) -> str:
    """
    Label *x* relative to the baseline *mu*.

    The relative gap ``(x - mu) / |mu|`` is compared against ``±band``.
    Dividing by ``|mu|`` keeps the direction correct for negative baselines:
    with mu = -10, x = -5 is higher, not lower. Returns "unknown" when either
    value is missing or ``mu == 0``.
    """
    if pd.isna(x) or pd.isna(mu) or mu == 0:
        return "unknown"
    rel = (x - mu) / abs(mu)
    if rel > band:
        return "higher than average"
    if rel < -band:
        return "lower than average"
    return "about average"
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """
    Return a caution sentence when a group has fewer than *min_n* records, else None.

    The test is ``n < min_n``, so a NaN count also yields ``None``. Uses the
    singular "record" for n == 1; the message previously read "only 1 records".
    """
    if n < min_n:
        unit = "record" if n == 1 else "records"
        return f"Interpret averages cautiously (only {n} {unit})."
    return None
def _pluralize(word: str, n: int) -> str:
    """Return *word* as-is for exactly one item, otherwise with a trailing "s"."""
    if n == 1:
        return f"{word}"
    return f"{word}s"
# -------------------- geo join (Top-5 only) --------------------
def _canon(s: str) -> str:
    """
    Build a canonical matching key for a facility name.

    The key is case-folded and has accents folded (``"Hôpital"`` -> ``"hopital"``).
    Every character outside ``[a-z0-9]`` is dropped. ``None`` maps to
    ``""``; other non-string values are stringified first instead of crashing.
    """
    if s is None:
        return ""
    # NFKD splits accented letters into base + combining mark; drop the marks so
    # "Hôpital" and "Hopital" produce the same key.
    decomposed = unicodedata.normalize("NFKD", str(s))
    base = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return re.sub(r"[^a-z0-9]+", "", base.casefold())
def _map_top_facilities_to_odhf(
    top_facilities: pd.DataFrame,
    odhf: pd.DataFrame,
    fac_col: str = "Facility",
    odhf_name_col: str = "facility_name"
) -> pd.DataFrame:
    """
    Best-effort join of facility names onto an ODHF-style reference table.

    For each distinct, non-blank facility in ``top_facilities[fac_col]``:
      1. exact match on the ``_canon`` key of the ODHF name (first row wins);
      2. otherwise, the first ODHF row whose name contains the facility name as
         a literal, case-insensitive substring.
    Unmatched facilities are omitted.

    Returns columns Facility, city, latitude, longitude. The last three are
    looked up case-insensitively in ``odhf`` and are None when absent. Returns
    an empty DataFrame when an input is empty or a required column is missing.
    """
    if odhf is None or odhf.empty or top_facilities is None or top_facilities.empty:
        return pd.DataFrame()
    if odhf_name_col not in odhf.columns or fac_col not in top_facilities.columns:
        return pd.DataFrame()

    # Output fields resolved case-insensitively ("City" and "city" both work).
    lower_to_col: Dict[str, Any] = {}
    for c in odhf.columns:
        lower_to_col.setdefault(str(c).lower(), c)
    field_cols = {f: lower_to_col.get(f) for f in ("city", "latitude", "longitude")}

    names = odhf[odhf_name_col]
    present = names.notna().to_numpy(dtype=bool)
    # canonical name -> row *position* (first occurrence wins). Positions rather
    # than index labels, so duplicate labels in `odhf` cannot yield a DataFrame.
    exact: Dict[str, int] = {}
    for pos, name in enumerate(names):
        if present[pos]:
            key = _canon(str(name))
            if key:
                exact.setdefault(key, pos)
    name_text = names.astype(str)

    out_rows: List[Dict[str, Any]] = []
    for fac in top_facilities[fac_col].dropna().astype(str).unique():
        if not fac.strip():
            continue  # a blank name would "contain"-match every row
        pos = exact.get(_canon(fac))
        if pos is None:
            # regex=False: names like "Health+" or "St. Joseph's (East" are not
            # valid regular expressions and used to raise re.error here.
            contains = name_text.str.contains(fac, case=False, regex=False, na=False)
            hits = np.flatnonzero(present & contains.to_numpy(dtype=bool))
            if hits.size == 0:
                continue
            pos = int(hits[0])
        row = odhf.iloc[pos]
        rec: Dict[str, Any] = {"Facility": fac}
        for field, col in field_cols.items():
            rec[field] = row[col] if col is not None else None
        out_rows.append(rec)
    return pd.DataFrame(out_rows)
# -------------------- main: narrative builder --------------------
def _has_numeric_values(s: pd.Series) -> bool:
    """Return True when column *s* holds at least one value that parses as a number.

    ``pd.to_numeric(errors="coerce")`` always yields a numeric dtype, so checking
    the dtype after coercion cannot tell text columns from numeric ones.
    """
    return bool(_to_numeric(s).notna().any())


def _top_lines_from_table(
    top_fac: pd.DataFrame,
    primary_metric: str,
    baseline: float,
    baseline_band: float,
    min_sample: int,
) -> Tuple[List[str], Optional[pd.DataFrame]]:
    """Rank a precomputed Top-N facility table; return bullet lines and its top 5 rows.

    Column detection is case-insensitive:
    - count column: the first whose name contains "count" or equals "n"/"records";
    - metric column: the first numeric non-count column named after
      *primary_metric*, else one containing "avg_"/"mean", else the first
      numeric non-facility column;
    - label column: the first remaining column containing "facility".
    Returns ``([], None)`` when no metric or label column is found.
    """
    cols = list(top_fac.columns)
    lower = {c: str(c).lower() for c in cols}
    pm = str(primary_metric).lower()
    cnt_col = next((c for c in cols if "count" in lower[c] or lower[c] in {"n", "records"}), None)
    numeric = [c for c in cols if c != cnt_col and _has_numeric_values(top_fac[c])]
    metric_col = next((c for c in numeric if pm in lower[c]), None)
    if metric_col is None:
        metric_col = next((c for c in numeric if "avg_" in lower[c] or "mean" in lower[c]), None)
    if metric_col is None:
        metric_col = next((c for c in numeric if "facility" not in lower[c]), None)
    lab_col = next(
        (c for c in cols if "facility" in lower[c] and c != metric_col and c != cnt_col), None
    )
    if metric_col is None or lab_col is None:
        return [], None

    ranked = top_fac.copy()
    ranked[metric_col] = _to_numeric(ranked[metric_col])
    if cnt_col is not None:
        ranked[cnt_col] = _to_numeric(ranked[cnt_col])
    ranked = ranked.sort_values(by=metric_col, ascending=False).head(5)

    lines: List[str] = []
    # to_dict() keeps the real column labels; itertuples() renames labels that are
    # not valid identifiers (e.g. "Facility Name"), which broke getattr() lookups.
    for rank, rec in enumerate(ranked.to_dict(orient="records"), 1):
        met = rec[metric_col]
        cnt = rec[cnt_col] if cnt_col is not None else np.nan
        msg = f"{rank}. **{rec[lab_col]}** — {primary_metric}: {_fmt_num(met)}"
        caution = None
        if pd.notna(cnt) and np.isfinite(cnt):  # a missing count must not reach int()
            n = int(cnt)
            msg += f"; {_pluralize('record', n)}: {n}"
            caution = _small_sample_note(n, min_sample)
        msg += f" → {_label_vs_baseline(met, baseline, baseline_band)}"
        if caution:
            msg += f" ({caution})"
        lines.append(msg)
    return lines, ranked


def _top_lines_from_groups(
    df: pd.DataFrame,
    group_col: Any,
    primary_metric: str,
    baseline: float,
    baseline_band: float,
    min_sample: int,
) -> List[str]:
    """Bullet lines for the five *group_col* groups with the highest mean metric."""
    stats = (
        df.groupby(group_col, dropna=False)[primary_metric]
        .agg(["mean", "count"])
        .sort_values(by="mean", ascending=False)
        .head(5)
    )
    lines: List[str] = []
    # Read index/columns directly: no reset_index() name clash when the group
    # column itself is called "mean"/"count", and no itertuples() renaming.
    for rank, (label, met, cnt) in enumerate(zip(stats.index, stats["mean"], stats["count"]), 1):
        n = int(cnt)
        dev = _label_vs_baseline(met, baseline, baseline_band)
        msg = f"{rank}. **{label}** — {primary_metric}: {_fmt_num(met)}; {_pluralize('record', n)}: {n} → {dev}"
        caution = _small_sample_note(n, min_sample)
        if caution:
            msg += f" ({caution})"
        lines.append(msg)
    return lines


def _zone_line(zone: Any, met: Any, baseline: float, baseline_band: float) -> str:
    """One comparison bullet: a zone's metric versus the overall baseline."""
    lab = _label_vs_baseline(met, baseline, baseline_band)
    return f"- **{zone}**: {_fmt_num(met)} (vs overall {_fmt_num(baseline)} → {lab})"


def _zone_lines_from_table(
    zone_tbl: pd.DataFrame,
    primary_metric: str,
    baseline: float,
    baseline_band: float,
) -> Tuple[List[str], Optional[Any]]:
    """Bullet lines from a precomputed zone summary, plus the zone column used.

    The metric column is the first non-zone, non-count column named after
    *primary_metric*, else the first containing "avg". Returns ``([], None)``
    when no zone or metric column is found.
    """
    cols = list(zone_tbl.columns)
    lower = {c: str(c).lower() for c in cols}
    pm = str(primary_metric).lower()
    zone_col = next((c for c in cols if "zone" in lower[c]), None)
    cnt_col = next((c for c in cols if "count" in lower[c] or lower[c] in {"n", "records"}), None)
    others = [c for c in cols if c != zone_col and c != cnt_col]
    met_col = next((c for c in others if pm in lower[c]), None)
    if met_col is None:
        met_col = next((c for c in others if "avg" in lower[c]), None)
    if zone_col is None or met_col is None:
        return [], None

    # Drop rows without a zone label; a literal "Total" row is kept as-is.
    z = zone_tbl.loc[zone_tbl[zone_col].notna().to_numpy()].copy()
    z[met_col] = _to_numeric(z[met_col])
    if cnt_col is not None:
        z[cnt_col] = _to_numeric(z[cnt_col])
    z = z.sort_values(by=met_col, ascending=False)

    lines: List[str] = []
    for rec in z.to_dict(orient="records"):
        msg = _zone_line(rec[zone_col], rec[met_col], baseline, baseline_band)
        cnt = rec[cnt_col] if cnt_col is not None else np.nan
        if pd.notna(cnt) and np.isfinite(cnt):
            msg += f"; n={int(cnt)}"
        lines.append(msg)
    return lines, zone_col


def _zone_lines_from_data(
    df: pd.DataFrame,
    primary_metric: str,
    baseline: float,
    baseline_band: float,
) -> Tuple[List[str], Optional[Any]]:
    """Derive zone (or fallback category) bullet lines from the metric frame itself."""
    zcol = _find_group_col(df, ["zone"], avoid=[str(primary_metric)])
    if zcol is None:
        return [], None
    z = df[[zcol, primary_metric]].copy()
    z[zcol] = z[zcol].astype("string").str.strip()
    # drop true NaN labels, but do NOT fabricate totals
    z = z[z[zcol].notna().to_numpy()]
    stats = (
        z.groupby(zcol)[primary_metric]
        .agg(["mean", "count"])
        .sort_values(by="mean", ascending=False)
    )
    lines = [
        f"{_zone_line(zone, met, baseline, baseline_band)}; n={int(cnt)}"
        for zone, met, cnt in zip(stats.index, stats["mean"], stats["count"])
    ]
    return lines, zcol


def _find_odhf_like(datasets: Optional[Dict[str, Any]]) -> Optional[pd.DataFrame]:
    """First non-empty DataFrame with facility_name and city columns (case-insensitive)."""
    for v in (datasets or {}).values():
        if isinstance(v, pd.DataFrame) and not v.empty:
            if {"facility_name", "city"} <= {str(c).lower() for c in v.columns}:
                return v
    return None


def _geo_lines(top5: pd.DataFrame, odhf_df: pd.DataFrame) -> List[str]:
    """Up to three notes locating the top facilities via the ODHF-like table."""
    fac_col = next((c for c in top5.columns if "facility" in str(c).lower()), None)
    if fac_col is None:
        return []
    name_col = next((c for c in odhf_df.columns if str(c).lower() == "facility_name"), "facility_name")
    mapped = _map_top_facilities_to_odhf(top5, odhf_df, fac_col=fac_col, odhf_name_col=name_col)
    lines: List[str] = []
    for rec in mapped.head(3).to_dict(orient="records"):
        city = rec.get("city")
        where = f" ({city})" if pd.notna(city) else ""  # no "(None)" / "(nan)" noise
        lines.append(
            f"- **{rec.get('Facility')}**{where} is among the highest-average groups; "
            "consider capacity and referral patterns."
        )
    return lines


def build_narrative(
    scenario_text: str,
    datasets: Dict[str, Any],
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    metric_hints: Optional[List[str]] = None,
    group_hints: Optional[List[str]] = None,
    min_sample: int = _DEF_MIN_SAMPLE,
    baseline_band: float = 0.05  # ±5% "about average"
) -> str:
    """
    Scenario-agnostic narrative fallback (returns Markdown text):
    1) Choose best (df, metric) dynamically using name hints + numeric sanity
    2) Prefer structured tables (top_facilities / zone_summary) if provided,
       falling back to aggregates computed from the chosen df
    3) Compute overall baseline + label groups vs baseline (±baseline_band)
    4) Geo notes via fuzzy Top-5 ↔ ODHF join (<= 3 bullets)
    5) Recommendations grounded in the same metric/groups

    ``scenario_text`` is accepted for interface compatibility but currently unused.
    ``min_sample`` controls the small-sample caution for every group listing.
    """
    metric_hints = metric_hints or _HINT_METRICS_DEFAULT
    group_hints = group_hints or _HINT_GROUPS_DEFAULT

    # ---------- 1) Pick dataset + metric ----------
    choice = _choose_df_and_metric(datasets, metric_hints)
    if not choice:
        return "No tabular data available. Unable to generate a narrative."
    _df_key, df, primary_metric = choice
    df = _nanlike_to_nan(df)  # private copy: the metric column is rewritten below
    if primary_metric not in df.columns:
        return "Chosen metric missing. Unable to generate a narrative."
    df[primary_metric] = _to_numeric(df[primary_metric])
    pm_name = str(primary_metric)

    # Optional comparator metric (e.g., consult vs surgery): must hold real numbers.
    comparator_metric = next(
        (
            c for c in df.columns
            if c != primary_metric
            and any(h in str(c).lower() for h in ("consult", "median", "wait", "p90", "90th"))
            and _has_numeric_values(df[c])
        ),
        None,
    )

    # ---------- 2) Structured tables (optional) ----------
    tables = structured_tables or {}
    top_fac = tables.get("top_facilities")
    zone_tbl = tables.get("zone_summary")

    # Baseline from the selected df/metric (not from ODHF)
    baseline = df[primary_metric].mean(skipna=True)

    # ---------- 3) Methodology ----------
    meth: List[str] = [
        f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**."
    ]
    if comparator_metric is not None:
        meth.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if df.isna().any().any():
        meth.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    g1 = _find_group_col(df, group_hints, avoid=[pm_name])
    if g1 is not None:
        meth.append(f"Primary grouping inferred: **{g1}**.")
        # Never pass an empty avoid fragment: it would block every name match.
        g2 = _find_group_col(df.drop(columns=[g1]), group_hints, avoid=[pm_name, str(g1)])
        if g2 is not None:
            meth.append(f"Secondary grouping inferred: **{g2}**.")
    sections: List[str] = ["## Methodology (Auto-generated)", *(f"- {m}" for m in meth), ""]

    # Highest averages by group (structured Top-5 preferred, derived from g1 otherwise)
    top_lines: List[str] = []
    top5: Optional[pd.DataFrame] = None
    if isinstance(top_fac, pd.DataFrame) and not top_fac.empty:
        top_lines, top5 = _top_lines_from_table(top_fac, primary_metric, baseline, baseline_band, min_sample)
    if not top_lines and g1 is not None:
        top_lines = _top_lines_from_groups(df, g1, primary_metric, baseline, baseline_band, min_sample)
    if top_lines:
        sections += ["## Highest average values by group", *top_lines, ""]

    # Zone comparison (structured zone table preferred, derived otherwise)
    zone_lines: List[str] = []
    zone_label: Optional[Any] = None
    if isinstance(zone_tbl, pd.DataFrame) and not zone_tbl.empty:
        zone_lines, zone_label = _zone_lines_from_table(zone_tbl, primary_metric, baseline, baseline_band)
    if not zone_lines:
        zone_lines, zone_label = _zone_lines_from_data(df, primary_metric, baseline, baseline_band)
    if zone_lines:
        # Title from the column actually used, not from concatenated df column names.
        kind = "Zone" if "zone" in str(zone_label).lower() else "Category"
        sections += [f"## {kind} comparison vs overall", *zone_lines, ""]

    # Geographic notes — map the ranked Top-5 facilities only (needs Top-5 + ODHF df)
    geo_lines: List[str] = []
    if isinstance(top_fac, pd.DataFrame) and not top_fac.empty:
        odhf_df = _find_odhf_like(datasets)
        if odhf_df is not None:
            geo_lines = _geo_lines(top5 if top5 is not None else top_fac.head(5), odhf_df)
    if geo_lines:
        sections += ["## Geographic notes", *geo_lines, ""]

    # Recommendations — grounded in the above
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize operating room time and staffing for the highest-average groups, especially those with substantial volume.")
    if comparator_metric is not None:
        recs.append(f"Track **{comparator_metric}** alongside {primary_metric} to identify upstream bottlenecks (e.g., long consult waits driving surgical delays).")
    if zone_lines:
        recs.append("Address zones persistently above the provincial baseline; deploy targeted resources and load balancing across facilities.")
    recs.append("Apply small-sample caution; pool or validate categories with very few records before acting on outliers.")
    recs.append("Standardize specialty/facility naming to reduce coding-induced variance in aggregates.")
    sections.append("## Recommendations (Auto-generated)")
    sections.extend(f"- {r}" for r in recs)
    return "\n".join(sections).strip()