Medica_DecisionSupportAI / narrative_safetynet.py
Rajan Sharma
Update narrative_safetynet.py
87088be verified
raw
history blame
10.6 kB
# narrative_safetynet.py
from __future__ import annotations
from typing import Dict, Any, List, Optional
import math
import numpy as np
import pandas as pd
import re
_DEF_MIN_SAMPLE = 5 # threshold for "interpret with caution" (fully generic)
def _is_numeric(s: pd.Series) -> bool:
return pd.api.types.is_numeric_dtype(s)
def _fmt_num(x: Any, decimals: int = 1) -> str:
try:
if x is None or (isinstance(x, float) and math.isnan(x)):
return "n/a"
if isinstance(x, (int, np.integer)):
return f"{x:,}"
return f"{float(x):,.{decimals}f}"
except Exception:
return str(x)
def _pick_numeric(df: pd.DataFrame, hints: List[str]) -> Optional[str]:
    """Choose the column to use as the primary numeric metric.

    Hints are tried in order. For each hint, the first numeric column whose
    label contains it (case-insensitive) is returned. If no hint matches, the
    first numeric column is returned. Returns ``None`` when ``df`` has no
    numeric column.

    Labels are compared through ``str()``, so non-string column labels (e.g. the
    integer labels produced by ``read_csv(header=None)``) do not raise
    ``AttributeError`` on ``.lower()``.
    """
    # Evaluate the dtype check once per column rather than once per (hint, column).
    numeric_cols = [c for c in df.columns if _is_numeric(df[c])]
    lowered = [str(c).lower() for c in numeric_cols]
    for hint in hints:
        needle = hint.lower()
        for col, low in zip(numeric_cols, lowered):
            if needle in low:
                return col
    return numeric_cols[0] if numeric_cols else None
def _find_group_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
cols = list(df.columns)
for cand in candidates:
for c in cols:
if cand.lower() in c.lower():
return c
# fallback: first reasonable categorical column
obj_cols = [c for c in cols if df[c].dtype == "object"]
for c in obj_cols:
nuniq = df[c].nunique(dropna=True)
if 1 < nuniq < max(50, len(df) // 10):
return c
return None
def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame:
dff = df.copy()
for c in dff.columns:
if dff[c].dtype == "object":
dff[c] = dff[c].replace({r"^\s*$": np.nan, r"^[-–—]$": np.nan}, regex=True)
return dff
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """Return a caution message when ``n`` is below ``min_n``; otherwise ``None``."""
    if n < min_n:
        return f"Interpret averages cautiously (only {n} records)."
    return None
def _deviation_label(x: float, mu: float, tol: float = 0.01) -> str:
if np.isnan(x) or np.isnan(mu) or mu == 0:
return "unknown"
rel = (x - mu) / mu
if rel > 0.05:
return "higher than average"
if rel < -0.05:
return "lower than average"
if abs(rel) <= max(tol, 0.05):
return "about average"
return "about average"
def _pluralize(label: str, n: int) -> str:
return f"{label}{'' if n==1 else 's'}"
def build_narrative(
    scenario_text: str,
    datasets: Dict[str, Any],
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    metric_hints: Optional[List[str]] = None,
    group_hints: Optional[List[str]] = None,
    min_sample: int = _DEF_MIN_SAMPLE
) -> str:
    """Build a scenario-agnostic Markdown narrative from the first usable table.

    The function:
    - picks a numeric metric and groupings dynamically,
    - computes an overall baseline and per-group deviations,
    - warns on small samples,
    - adds optional geographic notes when city/lat/lon columns exist.

    Args:
        scenario_text: Not read by this function; kept for interface compatibility.
        datasets: Mapping of name -> value. The first value (in iteration order)
            that is a non-empty ``pd.DataFrame`` is summarised; others are ignored.
        structured_tables: Not read by this function; kept for interface compatibility.
        metric_hints: Case-insensitive substrings used to prefer the primary
            metric column. Defaults to wait-time-style names.
        group_hints: Case-insensitive substrings used to prefer grouping columns.
        min_sample: Groups with fewer non-null metric values than this get a
            caution note.

    Returns:
        A Markdown string. Returns a one-line explanation instead when there is
        no non-empty DataFrame or no numeric column to summarise.
    """
    metric_hints = metric_hints or ["surgery_median", "consult_median", "wait", "median", "p90", "90th"]
    group_hints = group_hints or ["facility", "specialty", "zone", "hospital", "city", "region"]

    # 1) First non-empty DataFrame, with blank/dash placeholders turned into NaN.
    df = next(
        (
            _nanlike_to_nan(v)
            for v in datasets.values()
            if isinstance(v, pd.DataFrame) and not v.empty
        ),
        None,
    )
    if df is None:
        return "No tabular data available. Unable to generate a narrative."

    # 2) Metrics. Column labels are tested with ``is None`` rather than by
    # truthiness, so labels such as 0 or "" are not mistaken for "not found".
    primary_metric = _pick_numeric(df, metric_hints)
    if primary_metric is None:
        return "No numeric metric found to summarize; please ensure at least one numeric wait-time column is present."
    comparator_keys = ("consult", "wait", "median", "p90", "90th")
    comparator_metric = next(
        (
            c
            for c in df.columns
            if c != primary_metric
            and _is_numeric(df[c])
            and any(k in str(c).lower() for k in comparator_keys)
        ),
        None,
    )

    # 3) Groups: a primary grouping, then a secondary one chosen without it.
    group1 = _find_group_col(df, group_hints)
    group2 = None
    if group1 is not None:
        alt_hints = [h for h in group_hints if h.lower() not in str(group1).lower()]
        group2 = _find_group_col(df.drop(columns=[group1], errors="ignore"), alt_hints)

    # 4) Overall baseline and per-group statistics.
    metric_values = pd.to_numeric(df[primary_metric], errors="coerce")
    baseline = metric_values.mean(skipna=True)
    # Without a comparator, "comp" simply mirrors the primary metric.
    comp_col = comparator_metric if comparator_metric is not None else primary_metric

    def _group_stats(col: Any) -> Optional[pd.DataFrame]:
        """Compute per-group mean/count of the primary metric and mean of ``comp_col``."""
        if col is None:
            return None
        tmp = df.copy()
        tmp[primary_metric] = metric_values
        tmp[comp_col] = pd.to_numeric(tmp[comp_col], errors="coerce")
        return (
            tmp.groupby(col, dropna=False)
            .agg(
                metric=(primary_metric, "mean"),
                count=(primary_metric, "count"),
                comp=(comp_col, "mean"),
            )
            .reset_index()
        )

    def _annotated(stats: pd.DataFrame, col: Any):
        """Yield (label, metric, comp, count, deviation label, caution) per group row.

        ``to_dict(orient="records")`` keys each row by the real column label.
        ``itertuples()`` renames labels that are not valid identifiers
        (e.g. "Facility Name"), so ``getattr(row, col)`` would raise.
        """
        for rec in stats.to_dict(orient="records"):
            cnt = int(rec["count"])
            yield (
                rec[col],
                rec["metric"],
                rec["comp"],
                cnt,
                _deviation_label(rec["metric"], baseline),
                _small_sample_note(cnt, min_sample),
            )

    g1 = _group_stats(group1)
    g2 = _group_stats(group2)

    # 5) Top groups (by primary metric) from group1.
    top_lines: List[str] = []
    if g1 is not None and not g1.empty:
        g1 = g1.sort_values(by="metric", ascending=False)
        for i, (label, metric, comp, cnt, devlab, caution) in enumerate(
            _annotated(g1.head(5), group1), 1
        ):
            msg = f"{i}. **{label}** — {primary_metric}: {_fmt_num(metric)}"
            if comparator_metric is not None:
                msg += f"; {comparator_metric}: {_fmt_num(comp)}"
            msg += f"; {_pluralize('record', cnt)}: {cnt}"
            msg += f" → {devlab}"
            if caution:
                msg += f" ({caution})"
            top_lines.append(msg)

    # 6) Group2 overview against the overall baseline.
    region_lines: List[str] = []
    if g2 is not None and not g2.empty:
        g2 = g2.sort_values(by="metric", ascending=False)
        for label, metric, comp, cnt, devlab, caution in _annotated(g2, group2):
            line = f"- **{label}**: {_fmt_num(metric)} (vs. overall {_fmt_num(baseline)} → {devlab}); n={cnt}"
            if comparator_metric is not None:
                line += f"; {comparator_metric}: {_fmt_num(comp)}"
            if caution:
                line += f" — {caution}"
            region_lines.append(line)

    # 7) Geographic notes (optional). All three of city, lat and lon columns must
    # exist; only the city column is actually used for the breakdown.
    geo_notes: List[str] = []
    city_col = next((c for c in df.columns if re.search(r"\bcity\b", str(c), re.I)), None)
    lat_col = next((c for c in df.columns if re.search(r"\b(lat|latitude)\b", str(c), re.I)), None)
    lon_col = next((c for c in df.columns if re.search(r"\b(lon|longitude)\b", str(c), re.I)), None)
    has_geo = city_col is not None and lat_col is not None and lon_col is not None
    if has_geo and g1 is not None and not g1.empty:
        # g1 is already sorted by metric (step 5), so head(10) is the top-10 groups.
        top_labels = g1[group1].astype(str).head(10).tolist()
        sub = df[df[group1].astype(str).isin(top_labels)].copy()
        if not sub.empty:
            sub[primary_metric] = pd.to_numeric(sub[primary_metric], errors="coerce")
            by_city = (
                sub.groupby(city_col, dropna=False)[primary_metric]
                .mean()
                .reset_index()
                .sort_values(by=primary_metric, ascending=False)
            )
            for r in by_city.head(3).to_dict(orient="records"):
                cname = r.get(city_col)
                val = r.get(primary_metric)
                geo_notes.append(f"- **{cname}** shows higher average {primary_metric} among top groups ({_fmt_num(val)}).")

    # 8) Methodology (auto).
    methodology: List[str] = []
    if df.isna().sum().sum() > 0:
        methodology.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    methodology.append(f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**.")
    if comparator_metric is not None:
        methodology.append(f"Comparator metric detected: **{comparator_metric}** (means shown when available).")
    if group1 is not None:
        methodology.append(f"Primary grouping inferred: **{group1}**.")
    if group2 is not None:
        methodology.append(f"Secondary grouping inferred: **{group2}**.")
    if min_sample != _DEF_MIN_SAMPLE:
        methodology.append(f"Small-sample threshold set to {min_sample} records.")

    # 9) Compose markdown.
    lines: List[str] = ["## Methodology (Auto-generated)"]
    lines.extend(f"- {m}" for m in methodology)
    lines.append("")
    if top_lines:
        lines.append("## Highest average values by group")
        lines.extend(top_lines)
        lines.append("")
    if region_lines:
        heading = group2 if group2 is not None else "Region/Category"
        lines.append(f"## {heading} comparison vs overall")
        lines.extend(region_lines)
        lines.append("")
    if geo_notes:
        lines.append("## Geographic notes")
        lines.extend(geo_notes)
        lines.append("")

    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize resources to the highest-average groups (above overall baseline), especially those with sufficient volume.")
    if comparator_metric is not None:
        recs.append(f"Cross-check {comparator_metric} trends to identify upstream bottlenecks (e.g., long consult waits pushing surgery waits).")
    if g2 is not None and not g2.empty and (g2["metric"] > baseline).any():
        recs.append(f"Address disparities where average **{primary_metric}** exceeds the overall baseline.")
    recs.append("For very small groups, validate data quality and consider pooling across similar categories to stabilize estimates.")
    recs.append("Validate coding differences (similar specialties or labels spelled differently) to ensure apples-to-apples comparison.")
    lines.append("## Recommendations (Auto-generated)")
    lines.extend(f"- {r}" for r in recs)
    return "\n".join(lines).strip()