Spaces:
Sleeping
Sleeping
| # narrative_safetynet.py | |
| from __future__ import annotations | |
| from typing import Dict, Any, List, Optional, Tuple | |
| import re | |
| import math | |
| import numpy as np | |
| import pandas as pd | |
| # -------------------- helpers: dtype / formatting -------------------- | |
| _DEF_MIN_SAMPLE = 5 # generic caution threshold for group sizes | |
| _HINT_METRICS_DEFAULT = [ | |
| "surgery_median", "consult_median", | |
| "surgery_90th", "consult_90th", | |
| "surgery", "consult", | |
| "wait", "median", "p90", "90th" | |
| ] | |
| _HINT_GROUPS_DEFAULT = [ | |
| "facility", "specialty", "zone", | |
| "hospital", "city", "region" | |
| ] | |
| _BAD_METRIC_NAMES = ["index", "id", "row", "unnamed"] | |
| def _nanlike_to_nan(df: pd.DataFrame) -> pd.DataFrame: | |
| dff = df.copy() | |
| for c in dff.columns: | |
| if dff[c].dtype == "object": | |
| dff[c] = dff[c].replace({r"^\s*$": np.nan, r"^[-ββ]$": np.nan}, regex=True) | |
| return dff | |
| def _is_numeric_series(s: pd.Series) -> bool: | |
| try: | |
| return pd.api.types.is_numeric_dtype(s) | |
| except Exception: | |
| return False | |
| def _to_numeric(s: pd.Series) -> pd.Series: | |
| return pd.to_numeric(s, errors="coerce") | |
| def _fmt_num(x: Any, decimals: int = 1) -> str: | |
| try: | |
| if x is None or (isinstance(x, float) and math.isnan(x)): | |
| return "n/a" | |
| if isinstance(x, (int, np.integer)) or (isinstance(x, float) and float(x).is_integer()): | |
| return f"{int(round(float(x))):,}" | |
| return f"{float(x):,.{decimals}f}" | |
| except Exception: | |
| return str(x) | |
| # -------------------- metric & dataset selection (dynamic) -------------------- | |
def _score_metric_name(col: Any, hints: List[str]) -> int:
    """
    Score a column label as a candidate metric.

    Each hint that occurs as a substring of the lower-cased label adds 3.
    Labels containing an id-like token from ``_BAD_METRIC_NAMES`` get a large
    negative score so they can never win.

    The id check compares whole tokens, not substrings. The label is split on
    non-alphanumerics and camelCase boundaries, so "patient_id", "PatientID",
    "row" and "Unnamed: 0" are rejected, while real metrics such as
    "provided_median" or "growth_rate" (which merely contain "id"/"row") are
    not. Non-string labels (e.g. integer column names) are scored via str().
    """
    name = "" if col is None else str(col)
    tokens = {
        t.lower()
        for t in re.findall(r"[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+", name)
    }
    if any(bad in tokens for bad in _BAD_METRIC_NAMES):
        return -10**6  # disqualify obvious counters/ids
    lowered = name.lower()
    return 3 * sum(1 for h in hints if h in lowered)
def _choose_df_and_metric(
    datasets: Dict[str, Any],
    metric_hints: List[str]
) -> Optional[Tuple[str, pd.DataFrame, str]]:
    """
    Sweep all DataFrames in *datasets* and pick the best (dataset, metric column).

    Scoring per column: ``_score_metric_name`` (+3 per hint match), plus 1
    when the column has more than one distinct numeric value. Ties keep the
    first candidate seen. These columns are skipped entirely:

    - columns with no value that coerces to a number. ``pd.to_numeric(errors=
      "coerce")`` turns pure text into an all-NaN float column, which would
      otherwise pass a dtype-only numeric check;
    - datetime/timedelta columns, which coerce to epoch integers;
    - columns whose name is disqualified as id-like (negative score);
    - duplicated column labels and columns that cannot be coerced at all.

    Returns ``(dataset_key, cleaned_df, metric_column)``, or None when there
    is no candidate. The frame has had ``_nanlike_to_nan`` applied; the metric
    column itself is not yet converted to numbers.
    """
    best: Optional[Tuple[int, str, pd.DataFrame, str]] = None
    for key, v in (datasets or {}).items():
        if not isinstance(v, pd.DataFrame) or v.empty:
            continue
        df = _nanlike_to_nan(v)
        for col in df.columns:
            raw = df[col]
            if not isinstance(raw, pd.Series):
                continue  # duplicated column label -> ambiguous, skip
            if (pd.api.types.is_datetime64_any_dtype(raw)
                    or pd.api.types.is_timedelta64_dtype(raw)):
                continue
            try:
                col_num = _to_numeric(raw)
            except (TypeError, ValueError):
                continue
            if not _is_numeric_series(col_num) or not col_num.notna().any():
                continue
            s = _score_metric_name(col, metric_hints)
            if s < 0:
                continue  # id-like name: never a metric
            if col_num.nunique(dropna=True) > 1:
                s += 1
            if best is None or s > best[0]:
                best = (s, key, df, col)
    if best is None:
        return None
    _, key, df, metric = best
    return key, df, metric
| # -------------------- grouping detection (dynamic) -------------------- | |
| def _find_group_col(df: pd.DataFrame, candidates: List[str], avoid: Optional[List[str]] = None) -> Optional[str]: | |
| avoid = [a.lower() for a in (avoid or [])] | |
| cols = list(df.columns) | |
| # prefer name matches | |
| for cand in candidates: | |
| for c in cols: | |
| cname = c.lower() | |
| if cand.lower() in cname and all(a not in cname for a in avoid): | |
| return c | |
| # fallback: a categorical with reasonable cardinality | |
| obj_cols = [c for c in cols if df[c].dtype == "object"] | |
| for c in obj_cols: | |
| nuniq = df[c].nunique(dropna=True) | |
| if 1 < nuniq < max(50, len(df)//10): | |
| return c | |
| return None | |
| # -------------------- labels & cautions -------------------- | |
| def _label_vs_baseline(x: float, mu: float, band: float = 0.05) -> str: | |
| if pd.isna(x) or pd.isna(mu) or mu == 0: | |
| return "unknown" | |
| rel = (x - mu) / mu | |
| if rel > band: | |
| return "higher than average" | |
| if rel < -band: | |
| return "lower than average" | |
| return "about average" | |
def _small_sample_note(n: int, min_n: int = _DEF_MIN_SAMPLE) -> Optional[str]:
    """
    Return a caution sentence when a group has fewer than *min_n* records.

    Returns None when ``n >= min_n``. The noun agrees with *n*
    ("only 1 record", "only 3 records").
    """
    if n < min_n:
        noun = "record" if n == 1 else "records"
        return f"Interpret averages cautiously (only {n} {noun})."
    return None
| def _pluralize(word: str, n: int) -> str: | |
| return f"{word}{'' if n == 1 else 's'}" | |
| # -------------------- geo join (Top-5 only) -------------------- | |
| def _canon(s: str) -> str: | |
| return re.sub(r"[^a-z0-9]+", "", (s or "").lower()) | |
def _map_top_facilities_to_odhf(
    top_facilities: pd.DataFrame,
    odhf: pd.DataFrame,
    fac_col: str = "Facility",
    odhf_name_col: str = "facility_name"
) -> pd.DataFrame:
    """
    Look up city / latitude / longitude in an ODHF-style facility table for
    each distinct facility name in ``top_facilities[fac_col]``.

    Matching per facility (in order of first appearance):
      1. exact match on the ``_canon`` key (lower-cased, non-alphanumerics
         stripped). When several ODHF rows share a key, the first row wins;
      2. otherwise, the first ODHF row whose name contains the facility name
         as a literal, case-insensitive substring. Regex mode is off, so names
         with "(", "+", "." etc. cannot raise or mis-match.
    Facilities without a match are omitted.

    Returns a DataFrame with columns "Facility", "city", "latitude",
    "longitude". The last three are looked up case-insensitively in *odhf*
    and are None when absent. Returns an empty DataFrame when an input is
    missing/empty or a required name column does not exist.
    """
    if odhf is None or odhf.empty or top_facilities is None or top_facilities.empty:
        return pd.DataFrame()
    if fac_col not in top_facilities.columns or odhf_name_col not in odhf.columns:
        return pd.DataFrame()
    names = odhf[odhf_name_col]
    if not isinstance(names, pd.Series):
        return pd.DataFrame()  # duplicated name column -> ambiguous

    # Resolve attribute columns case-insensitively ("City" works as well as "city").
    by_lower: Dict[str, Any] = {}
    for c in odhf.columns:
        by_lower.setdefault(str(c).lower(), c)
    attr_cols = {attr: by_lower.get(attr) for attr in ("city", "latitude", "longitude")}

    # canonical name -> first row POSITION (positional, so duplicate index
    # labels cannot make a lookup return several rows).
    present = names.notna().to_numpy(dtype=bool)
    key_to_pos: Dict[str, int] = {}
    for pos, raw in enumerate(names.tolist()):
        if not present[pos]:
            continue
        key = _canon(str(raw))
        if key:
            key_to_pos.setdefault(key, pos)
    # Text used by the containment fallback; missing names never match.
    name_text = names.astype(str).where(present, "")

    out_rows: List[Dict[str, Any]] = []
    for fac in top_facilities[fac_col].dropna().astype(str).unique():
        key = _canon(fac)
        pos = key_to_pos.get(key) if key else None
        if pos is None and fac.strip():
            hits = np.flatnonzero(
                name_text.str.contains(fac, case=False, regex=False, na=False).to_numpy(dtype=bool)
            )
            if hits.size:
                pos = int(hits[0])
        if pos is None:
            continue
        row = odhf.iloc[pos]
        record: Dict[str, Any] = {"Facility": fac}
        for attr, col in attr_cols.items():
            record[attr] = row[col] if col is not None else None
        out_rows.append(record)
    return pd.DataFrame(out_rows)
| # -------------------- main: narrative builder -------------------- | |
# Separator used in rendered bullets. Written as an escape so it cannot be
# mangled again (the previous literal had been mis-encoded as "β").
_NARRATIVE_DASH = "\u2014"
# Name substrings that mark a secondary (comparator) wait-time metric.
_COMPARATOR_HINTS = ("consult", "median", "wait", "p90", "90th")


def _has_numeric_values(s: Any) -> bool:
    """True when *s* is a Series holding at least one value that coerces to a number."""
    if not isinstance(s, pd.Series):
        return False  # e.g. a DataFrame returned for a duplicated column label
    try:
        return bool(_to_numeric(s).notna().any())
    except (TypeError, ValueError):
        return False


def _is_count_name(col: Any) -> bool:
    """True for count-like column labels ("count_*", "n", "records")."""
    name = str(col).lower()
    return "count" in name or name in {"n", "records"}


def _as_count(value: Any) -> Optional[int]:
    """Return *value* as an int record count, or None when missing or non-numeric."""
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return None
    try:
        num = pd.to_numeric(value, errors="coerce")
    except (TypeError, ValueError):
        return None
    if not pd.api.types.is_scalar(num) or pd.isna(num):
        return None
    try:
        return int(num)
    except (OverflowError, ValueError):
        return None


def _find_comparator_metric(df: pd.DataFrame, primary_metric: Any) -> Optional[Any]:
    """First non-primary column whose name has a comparator hint AND that holds numbers."""
    for c in df.columns:
        if c == primary_metric:
            continue
        name = str(c).lower()
        if any(h in name for h in _COMPARATOR_HINTS) and _has_numeric_values(df[c]):
            return c
    return None


def _find_odhf_table(datasets: Dict[str, Any]) -> Optional[pd.DataFrame]:
    """First non-empty DataFrame whose (case-insensitive) columns include facility_name and city."""
    for v in datasets.values():
        if not isinstance(v, pd.DataFrame) or v.empty:
            continue
        if {"facility_name", "city"} <= {str(c).lower() for c in v.columns}:
            return v
    return None


def _pick_metric_column(tbl: pd.DataFrame, primary_metric: Any,
                        exclude: Optional[List[Any]] = None) -> Optional[Any]:
    """
    Pick the value column of a summary table.

    Prefers the first non-count column whose name contains the primary metric
    name, "avg" or "mean" and that holds numbers. Falls back to the first
    non-count numeric column. Count columns are never returned, so
    "count_<metric>" cannot be mistaken for the metric.
    """
    excluded = list(exclude or [])
    pm = str(primary_metric).lower()
    candidates = [c for c in tbl.columns if c not in excluded and not _is_count_name(c)]
    for c in candidates:
        name = str(c).lower()
        if (pm in name or "avg" in name or "mean" in name) and _has_numeric_values(tbl[c]):
            return c
    for c in candidates:
        if _has_numeric_values(tbl[c]):
            return c
    return None


def _format_group_line(rank: int, label: Any, metric_name: Any, value: Any,
                       count: Optional[int], baseline: float, band: float,
                       min_sample: int) -> str:
    """Render one ranked "highest average" bullet, with optional count and small-sample caution."""
    dev = _label_vs_baseline(value, baseline, band)
    msg = f"{rank}. **{label}** {_NARRATIVE_DASH} {metric_name}: {_fmt_num(value)}"
    if count is not None:
        msg += f"; {_pluralize('record', count)}: {count}"
    msg += f" {_NARRATIVE_DASH} {dev}"
    caution = _small_sample_note(count, min_sample) if count is not None else None
    if caution:
        msg += f" ({caution})"
    return msg


def _format_zone_line(zone: Any, value: Any, count: Optional[int],
                      baseline: float, band: float) -> str:
    """Render one zone/category bullet compared with the overall baseline."""
    lab = _label_vs_baseline(value, baseline, band)
    msg = f"- **{zone}**: {_fmt_num(value)} (vs overall {_fmt_num(baseline)} {_NARRATIVE_DASH} {lab})"
    if count is not None:
        msg += f"; n={count}"
    return msg


def _structured_top_lines(top_fac: pd.DataFrame, primary_metric: Any, baseline: float,
                          band: float, min_sample: int) -> Tuple[List[str], Optional[pd.DataFrame]]:
    """
    Build the Top-5 bullets from a pre-aggregated facility table.

    Returns ``(lines, ranked_table)``. The table is sorted by the metric,
    descending. Returns ``([], None)`` when no facility label column or no
    numeric metric column can be found.
    """
    lab_col = next((c for c in top_fac.columns
                    if "facility" in str(c).lower() and not _is_count_name(c)), None)
    if lab_col is None:
        return [], None
    metric_col = _pick_metric_column(top_fac, primary_metric, exclude=[lab_col])
    if metric_col is None:
        return [], None
    cnt_col = next((c for c in top_fac.columns if _is_count_name(c)), None)

    ranked = top_fac.copy()
    ranked[metric_col] = _to_numeric(ranked[metric_col])
    ranked = ranked.sort_values(by=metric_col, ascending=False)

    lines: List[str] = []
    # to_dict(records) keys rows by the real column labels. itertuples would
    # rename labels such as "Facility Name" and break attribute lookup.
    for rank, rec in enumerate(ranked.head(5).to_dict(orient="records"), 1):
        count = _as_count(rec.get(cnt_col)) if cnt_col is not None else None
        lines.append(_format_group_line(rank, rec.get(lab_col), primary_metric,
                                        rec.get(metric_col), count, baseline, band, min_sample))
    return lines, ranked


def _derived_top_lines(df: pd.DataFrame, group_col: Any, primary_metric: Any,
                       baseline: float, band: float, min_sample: int) -> List[str]:
    """Build the Top-5 bullets by averaging *primary_metric* per *group_col* value."""
    stats = (
        df.groupby(group_col, dropna=False, observed=True)[primary_metric]
        .agg(["mean", "count"])
        .sort_values(by="mean", ascending=False)
        .head(5)
    )
    return [
        _format_group_line(rank, label, primary_metric, mean, _as_count(count),
                           baseline, band, min_sample)
        for rank, (label, mean, count)
        in enumerate(zip(stats.index, stats["mean"], stats["count"]), 1)
    ]


def _structured_zone_lines(zone_tbl: pd.DataFrame, primary_metric: Any, baseline: float,
                           band: float) -> Tuple[List[str], Optional[Any]]:
    """
    Build zone bullets from a pre-aggregated zone table.

    Returns ``(lines, zone_column)``. Rows whose zone label is missing or
    blank are dropped; every other label, including a literal "Total" row,
    is kept. Returns ``([], None)`` when the needed columns cannot be found.
    """
    zone_col = next((c for c in zone_tbl.columns
                     if "zone" in str(c).lower() and not _is_count_name(c)), None)
    if zone_col is None:
        return [], None
    zmet_col = _pick_metric_column(zone_tbl, primary_metric, exclude=[zone_col])
    if zmet_col is None:
        return [], None
    zcnt_col = next((c for c in zone_tbl.columns if _is_count_name(c)), None)

    z = zone_tbl.copy()
    z[zone_col] = z[zone_col].astype("string").str.strip()
    # fillna("") first: a plain notna()/== mask on the string dtype can contain
    # pd.NA, and boolean indexing with NA raises.
    keep = z[zone_col].fillna("").ne("").to_numpy(dtype=bool)
    z = z.loc[keep].copy()
    z[zmet_col] = _to_numeric(z[zmet_col])
    z = z.sort_values(by=zmet_col, ascending=False)

    lines: List[str] = []
    for rec in z.to_dict(orient="records"):
        count = _as_count(rec.get(zcnt_col)) if zcnt_col is not None else None
        lines.append(_format_zone_line(rec.get(zone_col), rec.get(zmet_col), count, baseline, band))
    return lines, zone_col


def _derived_zone_lines(df: pd.DataFrame, zone_col: Any, primary_metric: Any,
                        baseline: float, band: float) -> List[str]:
    """Build zone bullets by averaging *primary_metric* per stripped *zone_col* label (blank/missing dropped)."""
    labels = df[zone_col].astype("string").str.strip()
    keep = labels.fillna("").ne("").to_numpy(dtype=bool)
    values = pd.Series(df[primary_metric].to_numpy()[keep])
    keys = labels.to_numpy()[keep]
    stats = values.groupby(keys).agg(["mean", "count"]).sort_values(by="mean", ascending=False)
    return [
        _format_zone_line(zone, mean, _as_count(count), baseline, band)
        for zone, mean, count in zip(stats.index, stats["mean"], stats["count"])
    ]


def _geo_note_lines(top_fac: pd.DataFrame, odhf_df: pd.DataFrame) -> List[str]:
    """Up to 3 geographic bullets for the first 5 facilities of *top_fac* that match in ODHF."""
    fac_col = next((c for c in top_fac.columns
                    if "facility" in str(c).lower() and not _is_count_name(c)), None)
    if fac_col is None:
        return []
    name_col = next((c for c in odhf_df.columns if str(c).lower() == "facility_name"), "facility_name")
    mapped = _map_top_facilities_to_odhf(top_fac.head(5), odhf_df,
                                         fac_col=fac_col, odhf_name_col=name_col)
    if mapped is None or mapped.empty:
        return []
    lines: List[str] = []
    for rec in mapped.head(3).to_dict(orient="records"):
        city = rec.get("city")
        no_city = (city is None
                   or (pd.api.types.is_scalar(city) and pd.isna(city))
                   or not str(city).strip())
        where = "" if no_city else f" ({city})"  # avoid rendering "(nan)" / "(None)"
        lines.append(
            f"- **{rec.get('Facility')}**{where} is among the highest-average groups; "
            "consider capacity and referral patterns."
        )
    return lines


def build_narrative(
    scenario_text: str,
    datasets: Dict[str, Any],
    structured_tables: Optional[Dict[str, pd.DataFrame]] = None,
    metric_hints: Optional[List[str]] = None,
    group_hints: Optional[List[str]] = None,
    min_sample: int = _DEF_MIN_SAMPLE,
    baseline_band: float = 0.05  # +/-5% band is reported as "about average"
) -> str:
    """
    Scenario-agnostic Markdown narrative fallback.

    1) Choose the best (DataFrame, metric) in *datasets* from name hints plus
       numeric sanity checks (``_choose_df_and_metric``).
    2) Use structured tables when provided: ``structured_tables["top_facilities"]``
       and ``structured_tables["zone_summary"]``. When such a table is missing
       or unusable, derive the same section from the chosen DataFrame instead.
    3) Compute the overall baseline (mean of the metric) and label each group
       as higher / lower / about average relative to it.
    4) Geographic notes: match the ranked Top-5 facilities against an
       ODHF-like table found in *datasets* (<= 3 bullets).
    5) Recommendations tied to the sections that were produced.

    ``scenario_text`` is accepted for interface compatibility but is not used
    by this fallback. Returns the narrative as a Markdown string, or a short
    explanatory message when no usable metric exists.
    """
    metric_hints = metric_hints or _HINT_METRICS_DEFAULT
    group_hints = group_hints or _HINT_GROUPS_DEFAULT
    datasets = datasets or {}

    # ---------- 1) Pick dataset + metric ----------
    choice = _choose_df_and_metric(datasets, metric_hints)
    if not choice:
        return "No tabular data available. Unable to generate a narrative."
    _df_key, df, primary_metric = choice
    df = _nanlike_to_nan(df)  # also yields a private copy that is safe to mutate
    if primary_metric not in df.columns:
        return "Chosen metric missing. Unable to generate a narrative."
    df[primary_metric] = _to_numeric(df[primary_metric])
    # Optional comparator metric (e.g., consult vs surgery)
    comparator_metric = _find_comparator_metric(df, primary_metric)

    # ---------- 2) Structured tables / ODHF detection ----------
    structured = structured_tables or {}
    top_fac = structured.get("top_facilities")
    zone_tbl = structured.get("zone_summary")
    odhf_df = _find_odhf_table(datasets)

    # Baseline comes from the selected df/metric (never from ODHF).
    baseline = df[primary_metric].mean(skipna=True)

    # ---------- 3) Build sections ----------
    sections: List[str] = []

    # Methodology
    meth: List[str] = [
        f"Primary metric: **{primary_metric}**; overall average: **{_fmt_num(baseline)}**."
    ]
    if comparator_metric is not None:
        comp_mean = _to_numeric(df[comparator_metric]).mean(skipna=True)
        meth.append(
            f"Comparator metric detected: **{comparator_metric}**; "
            f"overall average: **{_fmt_num(comp_mean)}**."
        )
    if df.isna().sum().sum() > 0:
        meth.append("Missing values (blank/dash) were treated as nulls and excluded from means.")
    g1 = _find_group_col(df, group_hints, avoid=[str(primary_metric)])
    if g1 is not None:  # "is not None": a column labelled 0 is still a column
        meth.append(f"Primary grouping inferred: **{g1}**.")
        g2 = _find_group_col(df.drop(columns=[g1]), group_hints,
                             avoid=[str(primary_metric), str(g1)])
        if g2 is not None:
            meth.append(f"Secondary grouping inferred: **{g2}**.")
    sections.append("## Methodology (Auto-generated)")
    sections.extend(f"- {m}" for m in meth)
    sections.append("")

    # Highest averages by group: structured Top-5 first, derived from g1 otherwise.
    has_top_fac = isinstance(top_fac, pd.DataFrame) and not top_fac.empty
    top_lines: List[str] = []
    ranked_fac: Optional[pd.DataFrame] = None
    if has_top_fac:
        top_lines, ranked_fac = _structured_top_lines(top_fac, primary_metric, baseline,
                                                      baseline_band, min_sample)
    if not top_lines and g1 is not None:
        top_lines = _derived_top_lines(df, g1, primary_metric, baseline, baseline_band, min_sample)
    if top_lines:
        sections.append("## Highest average values by group")
        sections.extend(top_lines)
        sections.append("")

    # Zone comparison: structured zone table first, derived otherwise.
    zone_lines: List[str] = []
    zone_col: Optional[Any] = None
    if isinstance(zone_tbl, pd.DataFrame) and not zone_tbl.empty:
        zone_lines, zone_col = _structured_zone_lines(zone_tbl, primary_metric, baseline, baseline_band)
    if not zone_lines:
        zcol = _find_group_col(df, ["zone"], avoid=[str(primary_metric)])
        if zcol is not None:
            zone_lines = _derived_zone_lines(df, zcol, primary_metric, baseline, baseline_band)
            zone_col = zcol
    if zone_lines:
        noun = "Zone" if "zone" in str(zone_col).lower() else "Category"
        sections.append(f"## {noun} comparison vs overall")
        sections.extend(zone_lines)
        sections.append("")

    # Geographic notes: map the ranked Top-5 facilities only (needs Top-5 + ODHF).
    geo_lines: List[str] = []
    if has_top_fac and isinstance(odhf_df, pd.DataFrame) and not odhf_df.empty:
        geo_source = ranked_fac if ranked_fac is not None else top_fac
        geo_lines = _geo_note_lines(geo_source, odhf_df)
    if geo_lines:
        sections.append("## Geographic notes")
        sections.extend(geo_lines)
        sections.append("")

    # Recommendations, grounded in the sections above.
    recs: List[str] = []
    if top_lines:
        recs.append("Prioritize operating room time and staffing for the highest-average groups, especially those with substantial volume.")
    if comparator_metric is not None:
        recs.append(f"Track **{comparator_metric}** alongside {primary_metric} to identify upstream bottlenecks (e.g., long consult waits driving surgical delays).")
    if zone_lines:
        recs.append("Address zones persistently above the provincial baseline; deploy targeted resources and load balancing across facilities.")
    recs.append("Apply small-sample caution; pool or validate categories with very few records before acting on outliers.")
    recs.append("Standardize specialty/facility naming to reduce coding-induced variance in aggregates.")
    sections.append("## Recommendations (Auto-generated)")
    sections.extend(f"- {r}" for r in recs)
    return "\n".join(sections).strip()