# NOTE(review): the six lines below are web-scrape residue (Hugging Face Spaces
# page chrome), not Python source; commented out so the module parses.
# Spaces:
# Sleeping
# Sleeping
# File size: 5,484 Bytes
# 00c16d0 3f718dc 40da431 3f718dc 40da431 3f718dc 40da431 3f718dc 00c16d0 3f718dc 40da431 3f718dc 40da431 00c16d0 3f718dc 00c16d0 3f718dc 00c16d0 3f718dc 00c16d0 3f718dc 00c16d0 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import Dict, Any, List
from pydantic import ValidationError
import json
from schema import ScenarioPlan, TaskPlan
from llm_router import cohere_chat, open_fallback_chat
# System prompt for the LLM planner. It instructs the model to emit STRICT JSON
# with keys {tasks, notes} matching the ScenarioPlan schema; build_parser_prompt()
# appends the dataset catalog and scenario text, and parse_to_plan() slices and
# validates the model's response. Do not edit casually: executor behavior depends
# on the exact step vocabulary listed here (filter/derive/joins/group_by/...).
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""
def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full planner prompt: system instructions, dataset catalog, scenario.

    dataset_catalog maps an uploaded file's key to its list of column names; an
    empty catalog is rendered as a placeholder line so the model knows no data
    has been uploaded yet.
    """
    if dataset_catalog:
        catalog_str = "\n".join(
            f"- {name}: {', '.join(columns)}"
            for name, columns in dataset_catalog.items()
        )
    else:
        catalog_str = "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}
# Uploaded datasets and their columns
{catalog_str}
# Scenario
{scenario_text}
"""
def _llm_parse(prompt: str) -> str | None:
    """Query the primary (Cohere) model; on an empty/None reply, try the open fallback."""
    reply = cohere_chat(prompt)
    if reply:
        return reply
    return open_fallback_chat(prompt)
def _safe_json_slice(raw: str) -> str | None:
raw = (raw or "").strip()
i, j = raw.find("{"), raw.rfind("}")
if i == -1 or j == -1: return None
return raw[i:j+1]
def _make_safety_net_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Deterministic fallback plan used when the LLM planner fails.

    Heuristically picks a wait-times dataset (and optionally a facility dataset)
    from the uploaded file keys, then emits a fixed battery of summary tables,
    plus a map task and/or a narrative-input task when the scenario text asks
    for geography or recommendations.
    """
    keys = list(dataset_catalog.keys())

    def _first_key_containing(*needles: str) -> str | None:
        # First uploaded key whose lowercased name contains any needle.
        for key in keys:
            lowered_key = key.lower()
            if any(needle in lowered_key for needle in needles):
                return key
        return None

    wait_key = _first_key_containing("wait") or (keys[0] if keys else None)
    facility_key = _first_key_containing("facility", "hospital") or wait_key

    lowered = scenario_text.lower()
    wants_map = any(term in lowered for term in ("map", "geographic distribution", "geographic"))
    wants_reco = any(term in lowered for term in ("recommend", "allocation", "plan"))

    # Always-on baseline: facility, specialty, and zone summaries.
    tasks: List[TaskPlan] = [
        TaskPlan(
            title="Top facilities by average surgery wait",
            data_key=wait_key, format="table",
            group_by=["Facility", "Zone"],
            agg=["avg(Surgery_Median)", "avg(Surgery_90th)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            fields=["Facility", "Zone", "avg_Surgery_Median", "avg_Surgery_90th", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Surgery_90th": "0"},
        ),
        TaskPlan(
            title="Top specialties by average surgery wait",
            data_key=wait_key, format="table",
            group_by=["Specialty"],
            agg=["avg(Surgery_Median)", "avg(Consult_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            fields=["Specialty", "avg_Surgery_Median", "avg_Consult_Median", "count"],
            number_format={"avg_Surgery_Median": "0", "avg_Consult_Median": "0"},
        ),
        TaskPlan(
            title="Zone-level surgery wait comparison",
            data_key=wait_key, format="table",
            group_by=["Zone"], agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc",
            fields=["Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ),
    ]

    if wants_map and facility_key:
        # Join the wait data onto facility coordinates for the map view.
        tasks.append(TaskPlan(
            title="Geographic distribution of high-wait facilities",
            data_key=wait_key, format="map",
            group_by=["Facility", "Zone"], agg=["avg(Surgery_Median)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            joins=[{"right_key": facility_key, "left_on": "Facility", "right_on": "facility_name", "how": "left"}],
            fields=["Facility", "Zone", "city", "latitude", "longitude", "avg_Surgery_Median"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    if wants_reco:
        tasks.append(TaskPlan(
            title="Recommendations (inputs for narrative)",
            data_key=wait_key, format="narrative",
            group_by=["Facility", "Zone"], agg=["avg(Surgery_Median)", "count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=10,
            fields=["Facility", "Zone", "avg_Surgery_Median", "count"],
            number_format={"avg_Surgery_Median": "0"},
        ))

    return ScenarioPlan(tasks=tasks, notes="Safety-net plan (LLM planner failed).")
def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Turn a scenario description into a validated ScenarioPlan.

    Asks the LLM planner for a JSON plan; on any failure — no response, no JSON
    object in the response, invalid JSON/schema, or an empty / narrative-only
    task list — falls back to the deterministic safety-net plan.
    """
    def _fallback() -> ScenarioPlan:
        return _make_safety_net_plan(scenario_text, dataset_catalog)

    raw = _llm_parse(build_parser_prompt(scenario_text, dataset_catalog))
    if not raw:
        return _fallback()

    candidate = _safe_json_slice(raw)
    if not candidate:
        return _fallback()

    try:
        payload = json.loads(candidate)
        plan = ScenarioPlan(**payload)
    except (json.JSONDecodeError, ValidationError):
        return _fallback()

    # A plan with no concrete tasks (or only a narrative stub) is not useful.
    if not plan.tasks or (len(plan.tasks) == 1 and plan.tasks[0].format == "narrative"):
        return _fallback()
    return plan