"""Build an LLM prompt for a scenario and parse the reply into a ScenarioPlan."""

import json

from schemas import ScenarioPlan
from settings import HEALTHCARE_SYSTEM_PROMPT
from llm_router import generate_with_fallback

# Instructions appended to every planning prompt; they describe the JSON
# document the model is asked to return.
PLAN_INSTRUCTIONS = """
Return ONLY valid JSON. Schema:
{
  "tasks": [
    {
      "title": "string",
      "data_key": "string|null",
      "format": "table|list|comparison|map|narrative|chart",
      "filter": "expr|null",
      "derive": ["col=expr"]|null,
      "group_by": ["col"]|null,
      "agg": ["sum(col)","avg(col)",...]|null,
      "pivot": {"index":"a","columns":"b","values":"c"}|null,
      "join": [{"right_key":"ds","left_on":"x","right_on":"y","how":"left"}]|null,
      "sort_by": "col|null",
      "sort_dir": "asc|desc",
      "top": int|null,
      "fields": ["col"]|null,
      "chart": "bar|line|area|point|tick|rule"|null,
      "x": "col|null", "y": "col|null", "color":"col|null", "column":"col|null"
    }
  ],
  "narrative_required": true,
  "notes": "optional"
}
"""


def build_prompt(scenario: str, catalog: dict) -> str:
    """Assemble the full planning prompt sent to the LLM.

    Args:
        scenario: Free-text scenario description placed under ``SCENARIO:``.
        catalog: Mapping of dataset key to an iterable of strings; each entry
            is rendered as ``- key: a, b, c`` under ``DATASETS:``.
            (The values are presumably column names -- confirm with callers.)

    Returns:
        The system prompt, dataset listing, plan instructions and scenario
        joined into one string that ends with ``JSON:``.
    """
    catalog_str = "\n".join(
        f"- {k}: {', '.join(v)}" for k, v in catalog.items()
    )
    return f"{HEALTHCARE_SYSTEM_PROMPT}\n\nDATASETS:\n{catalog_str}\n\n{PLAN_INSTRUCTIONS}\n\nSCENARIO:\n{scenario}\n\nJSON:"


def plan_from_llm(scenario: str, catalog: dict) -> ScenarioPlan:
    """Ask the LLM for a plan and validate it as a ``ScenarioPlan``.

    Args:
        scenario: Free-text scenario description.
        catalog: Dataset catalog, passed through to ``build_prompt``.

    Returns:
        The result of ``ScenarioPlan.model_validate`` on the decoded JSON.

    Raises:
        json.JSONDecodeError: If the reply contains no ``{`` or the text
            starting at the first ``{`` is not a valid JSON document.
        Exception: Whatever ``ScenarioPlan.model_validate`` raises for data
            that does not match the schema (presumably a pydantic
            ``ValidationError`` -- confirm against ``schemas``).
    """
    prompt = build_prompt(scenario, catalog)
    raw = generate_with_fallback(prompt)

    # Models often wrap the JSON in prose or code fences, so skip to the
    # first opening brace instead of parsing the reply as-is.
    start = raw.find("{")
    if start == -1:
        # Keep the exception type callers already see from json.loads, but
        # say what actually went wrong.
        raise json.JSONDecodeError(
            "No JSON object found in LLM response", raw, 0
        )

    # raw_decode parses exactly one JSON document and ignores anything after
    # it, so trailing commentary (even text containing '}') no longer breaks
    # parsing the way a first-'{'-to-last-'}' slice did.
    data, _ = json.JSONDecoder().raw_decode(raw[start:])
    return ScenarioPlan.model_validate(data)