File size: 1,486 Bytes
40da431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import json
from schemas import ScenarioPlan
from settings import HEALTHCARE_SYSTEM_PROMPT
from llm_router import generate_with_fallback

# Instruction suffix appended to every planning prompt. Describes the exact JSON
# schema the LLM must emit for a ScenarioPlan (see schemas.ScenarioPlan).
# NOTE: this text is sent verbatim to the model — do not reword casually, and
# keep it in sync with the ScenarioPlan schema when fields change.
PLAN_INSTRUCTIONS = """
Return ONLY valid JSON. Schema:

{
  "tasks": [
    {
      "title": "string",
      "data_key": "string|null",
      "format": "table|list|comparison|map|narrative|chart",
      "filter": "expr|null",
      "derive": ["col=expr"]|null,
      "group_by": ["col"]|null,
      "agg": ["sum(col)","avg(col)",...]|null,
      "pivot": {"index":"a","columns":"b","values":"c"}|null,
      "join": [{"right_key":"ds","left_on":"x","right_on":"y","how":"left"}]|null,
      "sort_by": "col|null",
      "sort_dir": "asc|desc",
      "top": int|null,
      "fields": ["col"]|null,
      "chart": "bar|line|area|point|tick|rule"|null,
      "x": "col|null", "y": "col|null", "color":"col|null", "column":"col|null"
    }
  ],
  "narrative_required": true,
  "notes": "optional"
}
"""

def build_prompt(scenario: str, catalog: dict) -> str:
    """Assemble the full planning prompt for the LLM.

    Combines the system prompt, a bullet list of available datasets
    (dataset name -> column names), the JSON-plan instructions, and the
    user scenario into a single prompt string.
    """
    bullet_lines = [
        f"- {dataset}: {', '.join(columns)}"
        for dataset, columns in catalog.items()
    ]
    catalog_str = "\n".join(bullet_lines)
    return f"{HEALTHCARE_SYSTEM_PROMPT}\n\nDATASETS:\n{catalog_str}\n\n{PLAN_INSTRUCTIONS}\n\nSCENARIO:\n{scenario}\n\nJSON:"

def plan_from_llm(scenario: str, catalog: dict) -> ScenarioPlan:
    """Request a scenario plan from the LLM and parse it into a ScenarioPlan.

    Builds the prompt, calls the LLM (with fallback), extracts the outermost
    ``{...}`` span from the raw response (models often wrap JSON in prose or
    code fences), then parses and validates it.

    Raises:
        ValueError: if the response contains no JSON object at all.
        json.JSONDecodeError: if the extracted span is not valid JSON.
        pydantic.ValidationError: if the JSON does not match ScenarioPlan.
    """
    prompt = build_prompt(scenario, catalog)
    raw = generate_with_fallback(prompt)
    start, end = raw.find("{"), raw.rfind("}")
    # find/rfind return -1 on a miss; the original code then sliced with
    # negative indices and produced an empty/garbage string, surfacing as a
    # confusing JSONDecodeError. Fail fast with a clear message instead.
    if start == -1 or end == -1 or end < start:
        raise ValueError(
            f"LLM response contains no JSON object: {raw[:200]!r}"
        )
    data = json.loads(raw[start:end + 1])
    return ScenarioPlan.model_validate(data)