File size: 5,484 Bytes
00c16d0
3f718dc
40da431
3f718dc
 
40da431
3f718dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40da431
3f718dc
00c16d0
3f718dc
 
 
 
 
40da431
3f718dc
 
 
40da431
00c16d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f718dc
 
00c16d0
3f718dc
00c16d0
 
 
 
3f718dc
00c16d0
 
 
 
 
3f718dc
00c16d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import re
from typing import Dict, Any, List

from pydantic import ValidationError

from schema import ScenarioPlan, TaskPlan
from llm_router import cohere_chat, open_fallback_chat

# Fixed instruction header for the LLM planner. build_parser_prompt() appends
# the uploaded-dataset catalog and the scenario text after this string.
# The step vocabulary listed here (filter, derive, joins, group_by, agg, pivot,
# sort_by/sort_dir/top, fields, chart, number_format) presumably has to match
# the fields accepted by TaskPlan and the downstream executor.
# NOTE(review): schema/executor not visible here; keep them in sync.
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""

def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full prompt sent to the LLM planner.

    The prompt is the fixed ``_PARSER_PROMPT`` instructions, followed by one
    ``- <dataset key>: <col>, <col>, ...`` line per uploaded dataset, followed
    by the scenario text.

    Args:
        scenario_text: Free-text description of the requested analysis.
        dataset_catalog: Mapping of dataset key to its column labels. Labels
            are rendered with ``str`` so non-string labels (e.g. integer
            column labels) no longer make ``str.join`` raise TypeError.

    Returns:
        The complete prompt text. An empty catalog is rendered as
        ``- (no files uploaded yet)``.
    """
    cat_lines = [
        f"- {key}: {', '.join(map(str, cols))}"
        for key, cols in dataset_catalog.items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}

# Uploaded datasets and their columns
{catalog_str}

# Scenario
{scenario_text}
"""

def _llm_parse(prompt: str) -> str | None:
    """Send *prompt* to the LLM and return its raw text reply.

    ``cohere_chat`` is tried first. ``open_fallback_chat`` is consulted only
    when the primary reply is falsy (e.g. None or an empty string); its
    result is returned as-is.
    """
    primary_reply = cohere_chat(prompt)
    if primary_reply:
        return primary_reply
    return open_fallback_chat(prompt)

def _safe_json_slice(raw: str) -> str | None:
    raw = (raw or "").strip()
    i, j = raw.find("{"), raw.rfind("}")
    if i == -1 or j == -1: return None
    return raw[i:j+1]

def _make_safety_net_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Build a deterministic fallback plan for when the LLM planner fails.

    Datasets are chosen by key name:

    - The wait-time dataset is the first catalog key containing "wait"
      (case-insensitive). Otherwise it is the first key, or None when the
      catalog is empty.
    - The facility dataset, used only as the map task's join target, is the
      first key containing "facility" or "hospital". Otherwise it is the
      wait-time dataset.

    Three table tasks are always emitted: top facilities, top specialties
    and a zone comparison. Two tasks are conditional:

    - A map task is added when the scenario contains "geographic" or a word
      starting with "map".
    - A narrative "Recommendations" task is added when the scenario contains
      "recommend" or "allocation", or a word starting with "plan".

    The column names below (Facility, Zone, Specialty, Surgery_Median,
    Surgery_90th, Consult_Median, facility_name, city, latitude, longitude)
    are hard-coded. Presumably the executor resolves mismatches against the
    real dataset columns (TODO: confirm).

    Args:
        scenario_text: Free-text scenario, scanned for the keywords above.
        dataset_catalog: Mapping of dataset key to its column labels.

    Returns:
        A ScenarioPlan whose notes mark it as the safety-net plan.
    """
    csvs = list(dataset_catalog.keys())
    wt = next((k for k in csvs if "wait" in k.lower()), (csvs[0] if csvs else None))
    # NOTE(review): fac falls back to wt, so the `fac` guard on the map task
    # is only false when no dataset is uploaded. Without a facility/hospital
    # dataset the map task joins the wait-time table to itself on
    # facility_name. Confirm this is intended.
    fac = next((k for k in csvs if "facility" in k.lower() or "hospital" in k.lower()), wt)

    tasks: List[TaskPlan] = []
    text = scenario_text.lower()
    # Short keywords are matched only at the start of a word. A plain
    # substring test for "plan" also fires on "transplant", "implant" and
    # "explanation", which would wrongly add the recommendations task.
    wants_map = "geographic" in text or re.search(r"\bmap", text) is not None
    wants_reco = (
        "recommend" in text
        or "allocation" in text
        or re.search(r"\bplan", text) is not None
    )

    tasks.append(TaskPlan(
        title="Top facilities by average surgery wait",
        data_key=wt, format="table",
        group_by=["Facility","Zone"],
        agg=["avg(Surgery_Median)","avg(Surgery_90th)","count(*)"],
        sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
        fields=["Facility","Zone","avg_Surgery_Median","avg_Surgery_90th","count"],
        number_format={"avg_Surgery_Median":"0", "avg_Surgery_90th":"0"}
    ))

    tasks.append(TaskPlan(
        title="Top specialties by average surgery wait",
        data_key=wt, format="table",
        group_by=["Specialty"],
        agg=["avg(Surgery_Median)","avg(Consult_Median)","count(*)"],
        sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
        fields=["Specialty","avg_Surgery_Median","avg_Consult_Median","count"],
        number_format={"avg_Surgery_Median":"0", "avg_Consult_Median":"0"}
    ))

    tasks.append(TaskPlan(
        title="Zone-level surgery wait comparison",
        data_key=wt, format="table",
        group_by=["Zone"], agg=["avg(Surgery_Median)","count(*)"],
        sort_by="avg_Surgery_Median", sort_dir="desc",
        fields=["Zone","avg_Surgery_Median","count"],
        number_format={"avg_Surgery_Median":"0"}
    ))

    if wants_map and fac:
        # Left-join the per-facility aggregates onto the facility dataset to
        # pick up city and coordinates for plotting.
        tasks.append(TaskPlan(
            title="Geographic distribution of high-wait facilities",
            data_key=wt, format="map",
            group_by=["Facility","Zone"], agg=["avg(Surgery_Median)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=5,
            joins=[{"right_key": fac, "left_on": "Facility", "right_on": "facility_name", "how": "left"}],
            fields=["Facility","Zone","city","latitude","longitude","avg_Surgery_Median"],
            number_format={"avg_Surgery_Median":"0"}
        ))

    if wants_reco:
        # Supplies the top-10 facility aggregates as inputs for a narrative.
        tasks.append(TaskPlan(
            title="Recommendations (inputs for narrative)",
            data_key=wt, format="narrative",
            group_by=["Facility","Zone"], agg=["avg(Surgery_Median)","count(*)"],
            sort_by="avg_Surgery_Median", sort_dir="desc", top=10,
            fields=["Facility","Zone","avg_Surgery_Median","count"],
            number_format={"avg_Surgery_Median":"0"}
        ))

    return ScenarioPlan(tasks=tasks, notes="Safety-net plan (LLM planner failed).")

def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Turn a free-text scenario into a ScenarioPlan via the LLM planner.

    The planner prompt is built from the scenario and dataset catalog and
    sent to the LLM. The reply is accepted only if all of the following hold:

    - it is non-empty;
    - it contains a ``{...}`` span;
    - that span decodes as JSON and validates as a ``ScenarioPlan``;
    - the plan has at least one task and is not just a single narrative task.

    In every other case the deterministic plan from
    ``_make_safety_net_plan`` is returned instead.
    """
    reply = _llm_parse(build_parser_prompt(scenario_text, dataset_catalog))
    payload = _safe_json_slice(reply) if reply else None
    if payload:
        try:
            candidate = ScenarioPlan(**json.loads(payload))
        except (json.JSONDecodeError, ValidationError):
            candidate = None
        # A plan consisting of one narrative task is treated as degenerate.
        if candidate is not None and candidate.tasks and not (
            len(candidate.tasks) == 1 and candidate.tasks[0].format == "narrative"
        ):
            return candidate
    return _make_safety_net_plan(scenario_text, dataset_catalog)