File size: 2,333 Bytes
3f718dc
 
40da431
3f718dc
 
40da431
3f718dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40da431
3f718dc
 
 
 
 
 
 
 
 
40da431
3f718dc
 
 
40da431
3f718dc
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Dict, Any
from pydantic import ValidationError
import json
from schema import ScenarioPlan, TaskPlan
from llm_router import cohere_chat, open_fallback_chat

# System instructions for the LLM planner. build_parser_prompt() prepends this
# text to the dataset catalog and the scenario description. The model is told
# to answer with strict JSON carrying "tasks" and "notes" keys; parse_to_plan()
# then validates that JSON against ScenarioPlan.
# NOTE: this is runtime text sent to the model — edits change planner behavior.
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""

def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the complete planner prompt sent to the LLM.

    The result is ``_PARSER_PROMPT`` followed by a ``# Uploaded datasets and
    their columns`` section (one ``- key: col1, col2, ...`` line per dataset)
    and a ``# Scenario`` section containing *scenario_text*.

    Args:
        scenario_text: Free-text scenario description to be planned.
        dataset_catalog: Mapping of dataset key to its list of column labels.
            Labels are rendered with ``str()`` so non-string labels (e.g.
            integer column names) do not make the join raise ``TypeError``.

    Returns:
        The prompt string. When the catalog is empty, the placeholder line
        ``- (no files uploaded yet)`` is emitted in place of dataset entries.
    """
    cat_lines = [
        f"- {key}: {', '.join(str(col) for col in cols)}"
        for key, cols in dataset_catalog.items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}

# Uploaded datasets and their columns
{catalog_str}

# Scenario
{scenario_text}
"""

def _fallback_plan(note: str) -> ScenarioPlan:
    """Build the minimal one-task narrative plan returned when planning fails.

    Args:
        note: Explanation stored in the plan's ``notes`` field.
    """
    return ScenarioPlan(
        tasks=[TaskPlan(title="Scenario Summary", format="narrative")],
        notes=note,
    )


def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Ask the LLM to turn a scenario into a validated ``ScenarioPlan``.

    The prompt from ``build_parser_prompt`` is sent to ``cohere_chat``; if that
    returns a falsy value, ``open_fallback_chat`` is tried instead. The reply
    is expected to contain a JSON object, possibly surrounded by extra text
    such as code fences or commentary.

    Args:
        scenario_text: Free-text scenario description to be planned.
        dataset_catalog: Mapping of dataset key to its list of column labels.

    Returns:
        The plan decoded from the reply, or a one-task "Scenario Summary"
        narrative fallback when no reply is available ("LLM unavailable"),
        or when no JSON object can be decoded or it fails ``ScenarioPlan``
        validation ("Plan validation failed").
    """
    prompt = build_parser_prompt(scenario_text, dataset_catalog)
    raw = cohere_chat(prompt) or open_fallback_chat(prompt)
    if not raw:
        return _fallback_plan("LLM unavailable; minimal plan.")

    start = raw.find("{")
    if start == -1:
        # Without a '{' the reply cannot hold a JSON object, and
        # ScenarioPlan(**obj) requires a mapping — decoding a bare list or
        # scalar here used to crash with TypeError instead of falling back.
        return _fallback_plan("Plan validation failed, fallback.")

    try:
        # raw_decode parses the first complete JSON value and ignores any
        # trailing text, so commentary after the object (even if it contains
        # braces) no longer corrupts the payload. The input starts with '{',
        # so a successful decode always yields a dict.
        obj, _end = json.JSONDecoder().raw_decode(raw[start:])
        return ScenarioPlan(**obj)
    except (json.JSONDecodeError, ValidationError):
        return _fallback_plan("Plan validation failed, fallback.")