Spaces:
Sleeping
Sleeping
import json
import re
from typing import Dict, Any, List

from pydantic import ValidationError

from schema import ScenarioPlan, TaskPlan
from llm_router import cohere_chat, open_fallback_chat
# Instruction preamble for the LLM planner. build_parser_prompt() appends the
# dataset catalog and the scenario text after it, and parse_to_plan() expects
# the reply to be a JSON object accepted by ScenarioPlan(**obj) (keys: tasks,
# notes). Several step names listed here (joins, group_by, agg, sort_by,
# sort_dir, top, fields, number_format) are also TaskPlan keyword arguments in
# _make_safety_net_plan(); NOTE(review): confirm filter/derive/pivot/chart/
# encodings against schema.TaskPlan — they are not visible in this file.
# The string is runtime data: editing it changes the planner's behaviour.
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""
def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full prompt sent to the LLM planner.

    Layout: the fixed planner instructions (``_PARSER_PROMPT``), then a
    "# Uploaded datasets and their columns" section with one
    ``- <key>: <col>, <col>, ...`` line per dataset (or a placeholder line when
    nothing is uploaded), then a "# Scenario" section with the raw scenario.

    Args:
        scenario_text: Free-text scenario description; inserted verbatim.
        dataset_catalog: Mapping of dataset key (presumably the uploaded file
            name) to its column labels. ``None`` is treated as empty. Labels
            are stringified, so non-string labels (e.g. integers) are accepted.

    Returns:
        The complete prompt text.
    """
    # str() every label: str.join raises TypeError on non-str items.
    cat_lines = [
        f"- {key}: {', '.join(map(str, cols))}"
        for key, cols in (dataset_catalog or {}).items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}
# Uploaded datasets and their columns
{catalog_str}
# Scenario
{scenario_text}
"""
def _llm_parse(prompt: str) -> str | None:
    """Ask the primary LLM for a completion, falling back to the secondary one.

    Returns Cohere's reply when it is truthy; otherwise returns whatever
    ``open_fallback_chat`` yields for the same prompt (possibly falsy).
    """
    primary_reply = cohere_chat(prompt)
    if primary_reply:
        return primary_reply
    # Primary produced nothing usable -> second provider.
    return open_fallback_chat(prompt)
| def _safe_json_slice(raw: str) -> str | None: | |
| raw = (raw or "").strip() | |
| i, j = raw.find("{"), raw.rfind("}") | |
| if i == -1 or j == -1: return None | |
| return raw[i:j+1] | |
def _make_safety_net_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Build a deterministic, keyword-driven plan used when the LLM planner fails.

    Dataset selection (case-insensitive on the catalog keys):
      * ``wt`` (wait-times data): first key containing "wait", else the first
        key, else ``None``.
      * ``fac`` (facility data): first key containing "facility" or
        "hospital", else ``wt``.

    Always emits three wait-time tables: top facilities, top specialties and a
    zone comparison. Additionally:
      * a map task when the scenario mentions a map (word starting with "map")
        or "geographic"; it left-joins ``fac`` on Facility == facility_name only
        when ``fac`` is a dataset distinct from ``wt``;
      * a narrative recommendations task when the scenario mentions
        "recommend", "allocation" or a word starting with "plan".

    Column names (Facility, Zone, Specialty, Surgery_Median, ...) are
    hard-coded guesses; presumably the executor resolves them against the
    actual data.

    Args:
        scenario_text: Free-text scenario; ``None`` is treated as empty.
        dataset_catalog: Mapping of dataset key to column labels; ``None`` is
            treated as empty.

    Returns:
        A ScenarioPlan whose notes mark it as the safety-net plan.
    """
    keys = list(dataset_catalog or {})
    wt = next((k for k in keys if "wait" in k.lower()), keys[0] if keys else None)
    fac = next((k for k in keys if "facility" in k.lower() or "hospital" in k.lower()), wt)
    # NOTE(review): with an empty catalog wt is None and every TaskPlan below
    # gets data_key=None — confirm schema.TaskPlan accepts that.

    text = (scenario_text or "").lower()
    # Match keywords at word starts: bare substring tests fired on unrelated
    # words ("roadmap" -> map, "explanation"/"airplane" -> plan).
    wants_map = re.search(r"\bmap", text) is not None or "geographic" in text
    wants_reco = (
        "recommend" in text
        or "allocation" in text
        or re.search(r"\bplan", text) is not None
    )

    tasks: List[TaskPlan] = [
        TaskPlan(
            title="Top facilities by average surgery wait",
            data_key=wt, format="table",
            group_by=["Facility", "Zone"],
            agg=["avg(Surgery_Median)", "avg(Surgery_90th)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            fields=["Facility", "Zone", "avg_Surgery_Median", "avg_Surgery_90th", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Surgery_90th": "0"},
        ),
        TaskPlan(
            title="Top specialties by average surgery wait",
            data_key=wt, format="table",
            group_by=["Specialty"],
            agg=["avg(Surgery_Median)", "avg(Consult_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            fields=["Specialty", "avg_Surgery_Median", "avg_Consult_Median", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Consult_Median": "0"},
        ),
        TaskPlan(
            title="Zone-level surgery wait comparison",
            data_key=wt, format="table",
            group_by=["Zone"], agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc",
            fields=["Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ),
    ]

    if wants_map and fac:
        # Join only against a separate facility dataset: when fac fell back to
        # wt, the join would self-join the wait table on Facility == facility_name.
        joins = (
            [{"right_key": fac, "left_on": "Facility", "right_on": "facility_name", "how": "left"}]
            if fac != wt
            else []
        )
        tasks.append(TaskPlan(
            title="Geographic distribution of high-wait facilities",
            data_key=wt, format="map",
            group_by=["Facility", "Zone"], agg=["avg(Surgery_Median)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            joins=joins,
            fields=["Facility", "Zone", "city", "latitude", "longitude", "avg_Surgery_Median"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    if wants_reco:
        tasks.append(TaskPlan(
            title="Recommendations (inputs for narrative)",
            data_key=wt, format="narrative",
            group_by=["Facility", "Zone"], agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=10,
            fields=["Facility", "Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    return ScenarioPlan(tasks=tasks, notes="Safety-net plan (LLM planner failed).")
def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Convert a free-text scenario into an executable ScenarioPlan.

    Asks the LLM planner for a JSON plan and validates it as a ScenarioPlan.
    Falls back to ``_make_safety_net_plan`` when the LLM reply is empty, when
    no "{...}" span can be sliced from it, when the JSON fails to decode or to
    validate, or when the plan has no tasks or only a single narrative task.

    Exceptions raised by the LLM call itself are not caught here; only
    ``json.JSONDecodeError`` and pydantic ``ValidationError`` trigger the
    fallback.
    """
    def attempt_llm_plan() -> "ScenarioPlan | None":
        # Each failure path returns None so the caller has one fallback site.
        reply = _llm_parse(build_parser_prompt(scenario_text, dataset_catalog))
        if not reply:
            return None
        snippet = _safe_json_slice(reply)
        if not snippet:
            return None
        try:
            candidate = ScenarioPlan(**json.loads(snippet))
        except (json.JSONDecodeError, ValidationError):
            return None
        planned = candidate.tasks
        if not planned:
            return None
        # A lone narrative task carries no analysis; treat it as a failure.
        if len(planned) == 1 and planned[0].format == "narrative":
            return None
        return candidate

    llm_plan = attempt_llm_plan()
    if llm_plan is None:
        return _make_safety_net_plan(scenario_text, dataset_catalog)
    return llm_plan