Rajan Sharma commited on
Commit
3f718dc
·
verified ·
1 Parent(s): b374ec8

Update scenario_planner.py

Browse files
Files changed (1) hide show
  1. scenario_planner.py +45 -38
scenario_planner.py CHANGED
@@ -1,43 +1,50 @@
 
 
1
  import json
2
- from schemas import ScenarioPlan
3
- from settings import HEALTHCARE_SYSTEM_PROMPT
4
- from llm_router import generate_with_fallback
5
 
6
- PLAN_INSTRUCTIONS = """
7
- Return ONLY valid JSON. Schema:
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- {
10
- "tasks": [
11
- {
12
- "title": "string",
13
- "data_key": "string|null",
14
- "format": "table|list|comparison|map|narrative|chart",
15
- "filter": "expr|null",
16
- "derive": ["col=expr"]|null,
17
- "group_by": ["col"]|null,
18
- "agg": ["sum(col)","avg(col)",...]|null,
19
- "pivot": {"index":"a","columns":"b","values":"c"}|null,
20
- "join": [{"right_key":"ds","left_on":"x","right_on":"y","how":"left"}]|null,
21
- "sort_by": "col|null",
22
- "sort_dir": "asc|desc",
23
- "top": int|null,
24
- "fields": ["col"]|null,
25
- "chart": "bar|line|area|point|tick|rule"|null,
26
- "x": "col|null", "y": "col|null", "color":"col|null", "column":"col|null"
27
- }
28
- ],
29
- "narrative_required": true,
30
- "notes": "optional"
31
- }
32
- """
33
 
34
- def build_prompt(scenario: str, catalog: dict) -> str:
35
- catalog_str = "\n".join([f"- {k}: {', '.join(v)}" for k,v in catalog.items()])
36
- return f"{HEALTHCARE_SYSTEM_PROMPT}\n\nDATASETS:\n{catalog_str}\n\n{PLAN_INSTRUCTIONS}\n\nSCENARIO:\n{scenario}\n\nJSON:"
37
 
38
- def plan_from_llm(scenario: str, catalog: dict) -> ScenarioPlan:
39
- prompt = build_prompt(scenario, catalog)
40
- raw = generate_with_fallback(prompt)
41
- start, end = raw.find("{"), raw.rfind("}")
42
- data = json.loads(raw[start:end+1])
43
- return ScenarioPlan.model_validate(data)
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from pydantic import ValidationError
3
  import json
4
+ from schema import ScenarioPlan, TaskPlan
5
+ from llm_router import cohere_chat, open_fallback_chat
 
6
 
7
+ _PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
8
+ Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
9
+ Each task must specify: title, format, data_key (if you can infer), and any needed steps:
10
+ - filter: boolean expression using dataframe column names
11
+ - derive: list of 'col = expression' for calculated columns
12
+ - joins: list of objects {right_key, left_on, right_on, how}
13
+ - group_by: list of columns
14
+ - agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
15
+ - pivot: {"index":[...], "columns":"...", "values":"..."}
16
+ - sort_by, sort_dir, top
17
+ - fields: columns to output in order
18
+ - chart + encodings: when a chart is requested (x,y,color,column)
19
+ - number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
20
+ If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
21
+ Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""
22
 
23
+ def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
24
+ cat_lines = []
25
+ for k, cols in dataset_catalog.items():
26
+ cat_lines.append(f"- {k}: {', '.join(cols)}")
27
+ catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
28
+ return f"""{_PARSER_PROMPT}
29
+
30
+ # Uploaded datasets and their columns
31
+ {catalog_str}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Scenario
34
+ {scenario_text}
35
+ """
36
 
37
+ def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
38
+ prompt = build_parser_prompt(scenario_text, dataset_catalog)
39
+ raw = cohere_chat(prompt) or open_fallback_chat(prompt)
40
+ if not raw:
41
+ return ScenarioPlan(tasks=[TaskPlan(title="Scenario Summary", format="narrative")], notes="LLM unavailable; minimal plan.")
42
+ raw = raw.strip()
43
+ start = raw.find("{"); end = raw.rfind("}")
44
+ if start != -1 and end != -1:
45
+ raw = raw[start:end+1]
46
+ try:
47
+ obj = json.loads(raw)
48
+ return ScenarioPlan(**obj)
49
+ except (json.JSONDecodeError, ValidationError):
50
+ return ScenarioPlan(tasks=[TaskPlan(title="Scenario Summary", format="narrative")], notes="Plan validation failed, fallback.")