Rajan Sharma commited on
Commit
40da431
·
verified ·
1 Parent(s): 21827d6

Create scenario_planner.py

Browse files
Files changed (1) hide show
  1. scenario_planner.py +43 -0
scenario_planner.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from schemas import ScenarioPlan
3
+ from settings import HEALTHCARE_SYSTEM_PROMPT
4
+ from llm_router import generate_with_fallback
5
+
6
+ PLAN_INSTRUCTIONS = """
7
+ Return ONLY valid JSON. Schema:
8
+
9
+ {
10
+ "tasks": [
11
+ {
12
+ "title": "string",
13
+ "data_key": "string|null",
14
+ "format": "table|list|comparison|map|narrative|chart",
15
+ "filter": "expr|null",
16
+ "derive": ["col=expr"]|null,
17
+ "group_by": ["col"]|null,
18
+ "agg": ["sum(col)","avg(col)",...]|null,
19
+ "pivot": {"index":"a","columns":"b","values":"c"}|null,
20
+ "join": [{"right_key":"ds","left_on":"x","right_on":"y","how":"left"}]|null,
21
+ "sort_by": "col|null",
22
+ "sort_dir": "asc|desc",
23
+ "top": int|null,
24
+ "fields": ["col"]|null,
25
+ "chart": "bar|line|area|point|tick|rule"|null,
26
+ "x": "col|null", "y": "col|null", "color":"col|null", "column":"col|null"
27
+ }
28
+ ],
29
+ "narrative_required": true,
30
+ "notes": "optional"
31
+ }
32
+ """
33
+
34
+ def build_prompt(scenario: str, catalog: dict) -> str:
35
+ catalog_str = "\n".join([f"- {k}: {', '.join(v)}" for k,v in catalog.items()])
36
+ return f"{HEALTHCARE_SYSTEM_PROMPT}\n\nDATASETS:\n{catalog_str}\n\n{PLAN_INSTRUCTIONS}\n\nSCENARIO:\n{scenario}\n\nJSON:"
37
+
38
+ def plan_from_llm(scenario: str, catalog: dict) -> ScenarioPlan:
39
+ prompt = build_prompt(scenario, catalog)
40
+ raw = generate_with_fallback(prompt)
41
+ start, end = raw.find("{"), raw.rfind("}")
42
+ data = json.loads(raw[start:end+1])
43
+ return ScenarioPlan.model_validate(data)