Medica_DecisionSupportAI / scenario_planner.py
Rajan Sharma
Update scenario_planner.py
3f718dc verified
raw
history blame
2.33 kB
import json
import logging
from typing import Any, Dict

from pydantic import ValidationError

from llm_router import cohere_chat, open_fallback_chat
from schema import ScenarioPlan, TaskPlan
# Instruction preamble sent to the LLM planner. build_parser_prompt() appends
# the uploaded-dataset catalog and the scenario text after it, and
# parse_to_plan() expects the reply to be a JSON object that
# ScenarioPlan(**obj) accepts.
# This is a plain (non-f) string, so the literal braces in the pivot and
# number_format examples reach the model unchanged.
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""
def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full planner prompt sent to the LLM.

    The prompt is ``_PARSER_PROMPT`` followed by a catalog of the uploaded
    datasets (one ``- <key>: <col>, <col>, ...`` line per dataset) and then
    the raw scenario text.

    Args:
        scenario_text: Free-text scenario describing the requested analyses.
        dataset_catalog: Mapping of dataset key to its column names. When it
            is empty (or ``None``), a "no files uploaded" line is emitted in
            place of the catalog.

    Returns:
        The prompt string, terminated by a newline.
    """
    # Column names are only annotated as ``list``, so they may not be str;
    # str.join rejects non-str items, so coerce each one explicitly.
    cat_lines = [
        f"- {key}: {', '.join(str(col) for col in (cols or ()))}"
        for key, cols in (dataset_catalog or {}).items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return (
        f"{_PARSER_PROMPT}\n"
        "# Uploaded datasets and their columns\n"
        f"{catalog_str}\n"
        "# Scenario\n"
        f"{scenario_text}\n"
    )
def _fallback_plan(notes: str) -> ScenarioPlan:
    """Return the minimal single-task narrative plan used when planning fails."""
    return ScenarioPlan(
        tasks=[TaskPlan(title="Scenario Summary", format="narrative")],
        notes=notes,
    )


def _extract_json_object(raw: str) -> str:
    """Return the outermost ``{...}`` span of *raw*, or *raw* stripped if none.

    LLM replies often wrap the JSON in prose or markdown fences; slicing from
    the first ``{`` to the last ``}`` discards that wrapper. The slice is only
    taken when the braces are in that order.
    """
    text = raw.strip()
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        return text[start : end + 1]
    return text


def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Ask the LLM to turn a scenario into a validated ScenarioPlan.

    The prompt from build_parser_prompt() is sent to cohere_chat(). If that
    returns an empty or falsy reply, open_fallback_chat() is tried instead.
    Any text outside the outermost ``{...}`` of the reply is discarded before
    JSON parsing.

    Args:
        scenario_text: Free-text scenario to plan.
        dataset_catalog: Mapping of dataset key to its column names.

    Returns:
        The parsed plan. A single "Scenario Summary" narrative task is
        returned instead in two cases: when neither chat helper returns
        text, or when the reply is not a JSON object that ScenarioPlan
        accepts. The fallback's ``notes`` says which case occurred.

    Exceptions raised by the chat helpers themselves are not caught here.
    """
    logger = logging.getLogger(__name__)
    prompt = build_parser_prompt(scenario_text, dataset_catalog)
    raw = cohere_chat(prompt) or open_fallback_chat(prompt)
    if not raw:
        return _fallback_plan("LLM unavailable; minimal plan.")

    try:
        obj = json.loads(_extract_json_object(raw))
        # A valid JSON array/scalar would make ``**obj`` raise TypeError;
        # treat it as a bad plan instead of crashing the caller.
        if isinstance(obj, dict):
            return ScenarioPlan(**obj)
        reason = f"expected a JSON object, got {type(obj).__name__}"
    except (json.JSONDecodeError, ValidationError) as err:
        reason = str(err)

    logger.warning("LLM scenario plan rejected (%s); using fallback plan.", reason)
    return _fallback_plan("Plan validation failed, fallback.")