"""Turn a free-text healthcare scenario into a ``ScenarioPlan``.

The scenario plus a catalogue of uploaded datasets is sent to an LLM planner
(``cohere_chat`` first, ``open_fallback_chat`` if that returns nothing). If no
usable plan comes back, a keyword-driven "safety-net" plan is returned instead.
"""

import json
from typing import Dict, Any, List

from pydantic import ValidationError

from schema import ScenarioPlan, TaskPlan
from llm_router import cohere_chat, open_fallback_chat

# Planner instructions prepended to every request. This is a plain string, not
# a format string, so the braces below are literal text shown to the LLM.
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes.
DO NOT include explanations outside JSON."""


def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full planner prompt.

    Args:
        scenario_text: Free-text scenario describing the requested analysis.
        dataset_catalog: Mapping of dataset key -> column names. It is listed
            in the prompt so the LLM can refer to real columns.

    Returns:
        ``_PARSER_PROMPT``, then a dataset/column listing (or a placeholder
        line when the catalogue is empty), then the scenario text.
    """
    # str() each label: column labels are not guaranteed to be strings
    # (e.g. integer labels), and str.join rejects non-str items.
    cat_lines = [
        f"- {key}: {', '.join(map(str, cols))}"
        for key, cols in dataset_catalog.items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}

# Uploaded datasets and their columns
{catalog_str}

# Scenario
{scenario_text}
"""


def _llm_parse(prompt: str) -> str | None:
    """Send *prompt* to Cohere, falling back to the open model on a falsy reply.

    Returns whatever the last model tried returned (possibly ``None``/empty).
    """
    return cohere_chat(prompt) or open_fallback_chat(prompt)


def _safe_json_slice(raw: str | None) -> str | None:
    """Extract the first JSON object embedded in an LLM reply.

    LLMs often wrap JSON in code fences or add prose around it. Starting at the
    first ``{``, this cuts out exactly one complete JSON object, so text after
    the object (even text containing ``}``) is ignored. If that object does not
    parse, it falls back to the span from the first ``{`` to the last ``}``.

    Args:
        raw: Raw model output; ``None`` is treated as empty.

    Returns:
        The candidate JSON text, or ``None`` when no ``{ ... }`` span exists.
    """
    text = (raw or "").strip()
    start = text.find("{")
    if start == -1:
        return None
    try:
        _, end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        # Not a clean object at the first brace: keep the legacy widest span
        # and let the caller's json.loads decide.
        last = text.rfind("}")
        if last < start:
            return None
        return text[start:last + 1]
    return text[start:end]


def _make_safety_net_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Build a heuristic plan when the LLM planner produced nothing usable.

    Dataset choice:
      * wait-times dataset: first key containing "wait" (case-insensitive),
        else the first key, else ``None``;
      * facility dataset: first key containing "facility" or "hospital",
        else the wait-times dataset.

    The plan always contains three ranking tables (by facility, by specialty,
    by zone). A map task is added when the scenario mentions "map" or
    "geographic". A narrative task is added when it mentions "recommend",
    "allocation" or "plan".

    NOTE(review): column names (Facility, Zone, Specialty, Surgery_Median,
    Surgery_90th, Consult_Median, facility_name, city, latitude, longitude)
    are hard-coded and assume a specific wait-times/facility schema; confirm
    against the uploaded data.
    """
    keys = list(dataset_catalog.keys())
    wait_key = next((k for k in keys if "wait" in k.lower()), keys[0] if keys else None)
    # NOTE(review): with an empty catalogue wait_key is None and every task gets
    # data_key=None — confirm TaskPlan accepts that.
    facility_key = next(
        (k for k in keys if "facility" in k.lower() or "hospital" in k.lower()),
        wait_key,
    )

    text = scenario_text.lower()
    # Plain substring tests: "map" also matches e.g. "mapping", and "plan"
    # also matches e.g. "explanation".
    wants_map = "map" in text or "geographic" in text
    wants_reco = "recommend" in text or "allocation" in text or "plan" in text

    tasks: List[TaskPlan] = [
        TaskPlan(
            title="Top facilities by average surgery wait",
            data_key=wait_key,
            format="table",
            group_by=["Facility", "Zone"],
            agg=["avg(Surgery_Median)", "avg(Surgery_90th)", "count(*)"],
            sort_by="avg_Surgery_Median",
            sort_dir="desc",
            top=5,
            fields=["Facility", "Zone", "avg_Surgery_Median", "avg_Surgery_90th", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Surgery_90th": "0"},
        ),
        TaskPlan(
            title="Top specialties by average surgery wait",
            data_key=wait_key,
            format="table",
            group_by=["Specialty"],
            agg=["avg(Surgery_Median)", "avg(Consult_Median)", "count(*)"],
            sort_by="avg_Surgery_Median",
            sort_dir="desc",
            top=5,
            fields=["Specialty", "avg_Surgery_Median", "avg_Consult_Median", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Consult_Median": "0"},
        ),
        TaskPlan(
            title="Zone-level surgery wait comparison",
            data_key=wait_key,
            format="table",
            group_by=["Zone"],
            agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median",
            sort_dir="desc",
            fields=["Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ),
    ]

    if wants_map and facility_key:
        # NOTE(review): if no facility/hospital dataset matched, facility_key
        # equals wait_key and this becomes a self-join on "facility_name";
        # confirm the executor tolerates that.
        tasks.append(TaskPlan(
            title="Geographic distribution of high-wait facilities",
            data_key=wait_key,
            format="map",
            group_by=["Facility", "Zone"],
            agg=["avg(Surgery_Median)"],
            sort_by="avg_Surgery_Median",
            sort_dir="desc",
            top=5,
            joins=[{"right_key": facility_key, "left_on": "Facility",
                    "right_on": "facility_name", "how": "left"}],
            fields=["Facility", "Zone", "city", "latitude", "longitude", "avg_Surgery_Median"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    if wants_reco:
        tasks.append(TaskPlan(
            title="Recommendations (inputs for narrative)",
            data_key=wait_key,
            format="narrative",
            group_by=["Facility", "Zone"],
            agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median",
            sort_dir="desc",
            top=10,
            fields=["Facility", "Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    return ScenarioPlan(tasks=tasks, notes="Safety-net plan (LLM planner failed).")


def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Plan the analysis for *scenario_text*, preferring the LLM planner.

    Falls back to ``_make_safety_net_plan`` when:
      * the LLM returns nothing;
      * no JSON object can be found in the reply;
      * the JSON is invalid or does not validate as a ``ScenarioPlan``;
      * the plan has no tasks, or only a single narrative task.

    Args:
        scenario_text: Free-text scenario describing the requested analysis.
        dataset_catalog: Mapping of dataset key -> column names.

    Returns:
        A ``ScenarioPlan`` (LLM-produced or heuristic).
    """
    prompt = build_parser_prompt(scenario_text, dataset_catalog)
    # _safe_json_slice treats None/empty replies as "no JSON".
    js = _safe_json_slice(_llm_parse(prompt))
    if not js:
        return _make_safety_net_plan(scenario_text, dataset_catalog)

    try:
        # js starts with "{", so a successful parse always yields a dict.
        plan = ScenarioPlan(**json.loads(js))
    except (json.JSONDecodeError, ValidationError):
        return _make_safety_net_plan(scenario_text, dataset_catalog)

    # An empty plan, or one that is only a single narrative task, is treated
    # as insufficient and replaced by the heuristic plan.
    if not plan.tasks or (len(plan.tasks) == 1 and plan.tasks[0].format == "narrative"):
        return _make_safety_net_plan(scenario_text, dataset_catalog)
    return plan