Medica_DecisionSupportAI / narrative_safetynet.py
Rajan Sharma
Create narrative_safetynet.py
42a6bd6 verified
raw
history blame
12.2 kB
# narrative_safetynet.py
from __future__ import annotations
from typing import Dict, Any, List, Optional, Tuple
import math
import numpy as np
import pandas as pd
import re
# ---------- Generic helpers ----------
# Default small-sample threshold: `_small_sample_note` (and `build_narrative`'s
# `min_sample` default) emit a caution note when a count falls below this value.
_DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
def _is_numeric(s: pd.Series) -> bool:
    """Return True when ``s`` holds numbers that can meaningfully be averaged.

    pandas reports boolean (and nullable "boolean") columns as numeric; they are
    excluded here so True/False flag columns are never treated as a metric.
    """
    return pd.api.types.is_numeric_dtype(s) and not pd.api.types.is_bool_dtype(s)
def _fmt_num(x: Any, decimals: int = 1) -> str:
try:
if x is None or (isinstance(x, float) and math.isnan(x)):
return "n/a"
if isinstance(x, (int, np.integer)):
return f"{x:,}"
return f"{float(x):,.{decimals}f}"
except Exception:
return str(x)
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
    """Choose the numeric column to summarize as the primary metric.

    Hints are tried in order (e.g. "Surgery_Median", "Consult_Median"). For each
    hint, the first numeric column (in column order) whose name contains it,
    case-insensitively, is returned. With no hint match, the first numeric
    column is returned, or None if the frame has no numeric column.

    Column labels are compared via ``str()`` so non-string labels (e.g. ints)
    do not raise.
    """
    # Evaluate numeric-ness once per column instead of once per hint x column.
    numeric_cols = [c for c in df.columns if _is_numeric(df[c])]
    for hint in hints:
        needle = hint.lower()
        for c in numeric_cols:
            if needle in str(c).lower():
                return c
    return numeric_cols[0] if numeric_cols else None
def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Choose a grouping column, preferring name hints over a dtype/cardinality fallback.

    1. For each hint in ``candidates`` (in order), return the first column whose
       name contains it, case-insensitively. Any dtype qualifies here, so e.g. a
       numeric "Zone" code can still be grouped on.
    2. Otherwise, return the first text-like column (object, pandas string or
       categorical dtype) with more than one and fewer than
       ``max(50, len(df) // 10)`` distinct non-null values. This skips constant
       columns and columns with very many distinct values (such as IDs in
       large tables).

    Column labels are compared via ``str()`` so non-string labels do not raise.
    Returns None when nothing qualifies.
    """
    cols = list(df.columns)
    for cand in candidates:
        needle = cand.lower()
        for c in cols:
            if needle in str(c).lower():
                return c
    max_cardinality = max(50, len(df) // 10)
    for c in cols:
        dtype = df[c].dtype
        text_like = pd.api.types.is_object_dtype(dtype) or isinstance(
            dtype, (pd.StringDtype, pd.CategoricalDtype)
        )
        if text_like and 1 < df[c].nunique(dropna=True) < max_cardinality:
            return c
    return None
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with placeholder cells turned into NaN.

    In text columns (object or pandas string dtype), cells that are empty or
    whitespace-only, or a lone dash ("-", "–" or "—", optionally padded with
    whitespace), become NaN. No schema is hard-coded.

    If a column contained such placeholders and every remaining non-null value
    parses as a (non-boolean) number, the column is converted to a numeric
    dtype. This lets e.g. ["12", "-", "30"] be detected as a metric. Columns
    without placeholders keep their dtype, so code-like strings such as "00123"
    are left untouched. The input frame is not modified.
    """
    placeholders = {r"^\s*$": np.nan, r"^\s*[-–—]\s*$": np.nan}
    dff = df.copy()
    for c in dff.columns:
        col = dff[c]
        if not (pd.api.types.is_object_dtype(col.dtype) or isinstance(col.dtype, pd.StringDtype)):
            continue
        cleaned = col.replace(placeholders, regex=True)
        n_valid = int(cleaned.notna().sum())
        had_placeholders = n_valid < int(col.notna().sum())
        if had_placeholders and n_valid > 0:
            try:
                numeric = pd.to_numeric(cleaned, errors="coerce")
            except (TypeError, ValueError):
                numeric = None  # e.g. nested objects; keep the text column
            if (
                numeric is not None
                and int(numeric.notna().sum()) == n_valid  # nothing was lost in coercion
                and pd.api.types.is_numeric_dtype(numeric)
                and not pd.api.types.is_bool_dtype(numeric)
            ):
                cleaned = numeric
        dff[c] = cleaned
    return dff
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """Return a caution sentence when fewer than ``min_n`` records back an average.

    Args:
        n: Number of records behind the average.
        min_n: Threshold below which the caution is emitted.

    Returns:
        The caution text when ``n < min_n``, otherwise None.
    """
    if n < min_n:
        return f"Interpret averages cautiously (only {n} records)."
    return None
def _deviation_label(x: float, mu: float, tol: float = 0.01, min_band: float = 0.05) -> str:
    """Describe how ``x`` compares with the baseline ``mu``.

    The relative gap ``(x - mu) / |mu|`` is compared against a symmetric
    "about average" band of half-width ``max(tol, min_band)``. With the
    defaults this is 5%, so a small ``tol`` never narrows the band. Dividing
    by ``|mu|`` keeps "higher"/"lower" correct for negative baselines.

    Args:
        x: Value to classify.
        mu: Baseline value (e.g. the overall mean).
        tol: Fraction of ``|mu|`` treated as ≈ equal; widens the band when
            larger than ``min_band``.
        min_band: Minimum half-width of the "about average" band.

    Returns:
        "higher than average", "lower than average", "about average", or
        "unknown" when either value is missing (None/NaN/pd.NA) or ``mu`` is 0.
    """
    if pd.isna(x) or pd.isna(mu) or mu == 0:
        return "unknown"
    band = max(tol, min_band)
    rel = (x - mu) / abs(mu)
    if rel > band:
        return "higher than average"
    if rel < -band:
        return "lower than average"
    return "about average"
def _pluralize(label: str, n: int) -> str:
    """Append an English plural "s" to ``label`` unless ``n`` equals 1."""
    if n == 1:
        return f"{label}"
    return f"{label}s"
# ---------- Core narrative generator ----------
def _group_summary(
    df: pd.DataFrame,
    group_col: Any,
    metric_col: Any,
    comparator_col: Optional[Any] = None,
) -> pd.DataFrame:
    """Aggregate ``metric_col`` (and optionally ``comparator_col``) per group.

    Returns a frame indexed by group label (missing labels kept as their own
    group) with these columns, sorted by ``metric`` descending (ties keep group
    order):

    - ``metric``: mean of ``metric_col``.
    - ``count``: non-null count of ``metric_col``.
    - ``comp``: mean of ``comparator_col``, or NaN when no comparator is given.

    Labels stay in the index rather than being reset into a column. Rows are
    read positionally by the caller, which avoids ``itertuples`` renaming
    fields such as "_comp" or "Facility Name".
    """
    spec: Dict[str, Tuple[Any, str]] = {
        "metric": (metric_col, "mean"),
        "count": (metric_col, "count"),
    }
    if comparator_col is not None:
        spec["comp"] = (comparator_col, "mean")
    # observed=True: report only categories that actually occur in the data.
    stats = df.groupby(group_col, dropna=False, observed=True).agg(**spec)
    if "comp" not in stats.columns:
        stats["comp"] = np.nan
    return stats.sort_values(by="metric", ascending=False, kind="mergesort")


def _normalized_labels(values: pd.Series) -> pd.Series:
    """Stringify labels, collapse whitespace runs (incl. newlines) to one space, and trim."""
    return values.astype(str).str.replace(r"\s+", " ", regex=True).str.strip()


def _first_col_matching(df: pd.DataFrame, pattern: str) -> Optional[Any]:
    """Return the first column whose stringified name matches ``pattern`` (case-insensitive)."""
    return next((c for c in df.columns if re.search(pattern, str(c), re.I)), None)


def _city_notes(
    df: pd.DataFrame,
    group_col: Any,
    group_stats: pd.DataFrame,
    city_col: Any,
    metric_col: Any,
    baseline: float,
) -> List[str]:
    """List up to three cities where rows of the top-10 groups average above ``baseline``.

    Rows belonging to the ten highest-ranked groups in ``group_stats`` are
    selected. Labels on both sides are normalized the same way so stray
    whitespace does not break the match. The rows are averaged per city, and
    only cities whose mean exceeds ``baseline`` are reported, highest first.
    """
    top_labels = set(_normalized_labels(pd.Series(group_stats.index[:10], dtype=object)))
    sub = df[_normalized_labels(df[group_col]).isin(top_labels)]
    if sub.empty:
        return []
    by_city = sub.groupby(city_col, dropna=False, observed=True)[metric_col].mean()
    above = by_city[by_city > baseline].sort_values(ascending=False, kind="mergesort")
    return [
        f"- **{city}** shows higher average {metric_col} among top groups ({_fmt_num(val)})."
        for city, val in above.head(3).items()
    ]


def build_narrative(
    scenario_text: str,
    # A dict of dataframes your engine just produced / loaded (e.g., the same dict passed to ScenarioEngine)
    datasets: Dict[str, Any],
    # Optional structured outputs your engine already rendered (tables etc.) if you want to cross-reference
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    # Hints for metric selection (keeps it scenario-agnostic)
    metric_hints: Optional[List[str]] = None,  # e.g. ["Surgery_Median", "Consult_Median", "Wait"]
    group_hints: Optional[List[str]] = None,  # e.g. ["Facility","Specialty","Zone"]
    min_sample: int = _DEF_MIN_SAMPLE,
) -> str:
    """Return a Markdown narrative summarizing the first non-empty table in ``datasets``.

    Sections produced:
    - Methodology: the cleaning note, primary/comparator metrics and the inferred groupings.
    - Top groups: the highest five groups of the primary grouping by mean metric.
    - Secondary comparison: every value of the secondary grouping vs the
      overall mean, with a higher/lower/about-average label.
    - Small-sample cautions: groups with fewer than ``min_sample`` non-null
      values are flagged.
    - Geographic notes: emitted only when city, latitude and longitude columns
      are all present. Only the city column is used.
    - Recommendations: generic text driven by which of the above were produced.

    Column choices are inferred from names/dtypes; nothing scenario-specific is
    hard-coded.

    Args:
        scenario_text: Free-text scenario description. Currently unused.
        datasets: Mapping of name -> object. The first non-empty DataFrame, in
            iteration order, is analysed; other values are skipped. None is
            treated as empty.
        structured_tables: Currently unused; accepted for interface compatibility.
        metric_hints: Case-insensitive substrings preferred when choosing the
            primary numeric metric.
        group_hints: Case-insensitive substrings preferred when choosing
            grouping columns.
        min_sample: Groups with fewer non-null metric values get a caution note.

    Returns:
        The Markdown text. If there is no usable table or numeric column, a
        short explanatory message is returned instead.
    """
    metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p50"]
    group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]

    # 1) Primary dataset: first non-empty DataFrame, with blank/dash cells nulled.
    df = next(
        (
            _nanlike_to_nan(v)
            for v in (datasets or {}).values()
            if isinstance(v, pd.DataFrame) and not v.empty
        ),
        None,
    )
    if df is None:
        return "No tabular data available. Unable to generate a narrative."

    # 2) Primary metric plus an optional comparator (e.g., consult vs surgery waits).
    primary_metric = _pick_numeric(df, metric_hints)
    if primary_metric is None:
        return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
    comparator_hints = ("consult", "wait", "median", "p90", "90th")
    comparator_metric = next(
        (
            c
            for c in df.columns
            if c != primary_metric
            and _is_numeric(df[c])
            and any(h in str(c).lower() for h in comparator_hints)
        ),
        None,
    )

    # 3) Grouping columns, chosen from everything except the metric columns.
    metric_cols = [c for c in (primary_metric, comparator_metric) if c is not None]
    group_candidates = df.drop(columns=metric_cols)
    group1 = _find_group_col(group_candidates, group_hints)
    group2 = None
    if group1 is not None:
        # A second grouping distinct from the first (e.g., Zone when Facility was chosen).
        alt_hints = [h for h in group_hints if h.lower() not in str(group1).lower()]
        group2 = _find_group_col(group_candidates.drop(columns=[group1]), alt_hints)

    # 4) Overall baseline and per-group statistics.
    baseline = df[primary_metric].astype(float).mean(skipna=True)
    g1 = _group_summary(df, group1, primary_metric, comparator_metric) if group1 is not None else None
    g2 = _group_summary(df, group2, primary_metric, comparator_metric) if group2 is not None else None

    # 5) Top five groups of the primary grouping.
    top_lines: List[str] = []
    if g1 is not None and not g1.empty:
        top = g1.head(5)
        top_rows = zip(top.index, top["metric"], top["count"], top["comp"])
        for rank, (label, metric, cnt, comp) in enumerate(top_rows, start=1):
            msg = f"{rank}. **{label}** — {primary_metric}: {_fmt_num(metric)}"
            if comparator_metric is not None:
                msg += f"; {comparator_metric}: {_fmt_num(comp)}"
            msg += f"; {_pluralize('record', int(cnt))}: {cnt}"
            msg += f" → {_deviation_label(metric, baseline)}"
            caution = _small_sample_note(int(cnt), min_sample)
            if caution:
                msg += f" ({caution})"
            top_lines.append(msg)

    # 6) Zone/region-style overview: every secondary group vs the baseline.
    region_lines: List[str] = []
    if g2 is not None and not g2.empty:
        region_rows = zip(g2.index, g2["metric"], g2["count"], g2["comp"])
        for label, metric, cnt, comp in region_rows:
            devlab = _deviation_label(metric, baseline)
            line = f"- **{label}**: {_fmt_num(metric)} (vs. overall {_fmt_num(baseline)} → {devlab}); n={cnt}"
            if comparator_metric is not None:
                line += f"; {comparator_metric}: {_fmt_num(comp)}"
            caution = _small_sample_note(int(cnt), min_sample)
            if caution:
                line += f" — {caution}"
            region_lines.append(line)

    # 7) Geographic notes. Column names are matched by pattern; underscores and
    #    digits count as separators, so "Facility_City" or "lat_2" also match.
    geo_notes: List[str] = []
    city_col = _first_col_matching(df, r"(?<![a-z])city(?![a-z])")
    lat_col = _first_col_matching(df, r"(?<![a-z])lat(?:itude)?(?![a-z])")
    lon_col = _first_col_matching(df, r"(?<![a-z])(?:lon|lng|longitude)(?![a-z])")
    if (
        g1 is not None
        and not g1.empty
        and city_col is not None
        and lat_col is not None
        and lon_col is not None
    ):
        geo_notes = _city_notes(df, group1, g1, city_col, primary_metric, baseline)

    # 8) Methodology, derived from the actual data conditions.
    methodology: List[str] = []
    if df.isna().sum().sum() > 0:
        methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
    if comparator_metric is not None:
        methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if group1 is not None:
        methodology.append(f"Primary grouping inferred: **{group1}**.")
    if group2 is not None:
        methodology.append(f"Secondary grouping inferred: **{group2}**.")
    if min_sample != _DEF_MIN_SAMPLE:
        methodology.append(f"Small-sample threshold set to {min_sample} records.")

    # 9) Compose markdown.
    lines: List[str] = ["## Methodology (Auto-generated)"]
    lines.extend(f"- {m}" for m in methodology)
    lines.append("")
    if top_lines:
        lines.append("## Highest average values by group")
        lines.extend(top_lines)
        lines.append("")
    if region_lines:
        lines.append(f"## {group2 or 'Region/Category'} comparison vs overall")
        lines.extend(region_lines)
        lines.append("")
    if geo_notes:
        lines.append("## Geographic notes")
        lines.extend(geo_notes)
        lines.append("")

    # Generic recommendations; which ones appear depends on the data above.
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
    if comparator_metric is not None:
        recs.append(f"Cross-check {comparator_metric} trends to identify upstream bottlenecks (e.g., long consult waits pushing surgery waits).")
    if g2 is not None and not g2.empty and bool((g2["metric"] > baseline).any()):
        recs.append(f"Address regional disparities where average **{primary_metric}** exceeds the overall baseline.")
    recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
    recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
    lines.append("## Recommendations (Auto-generated)")
    lines.extend(f"- {r}" for r in recs)
    return "\n".join(lines).strip()