Spaces:
Sleeping
Sleeping
Rajan Sharma
committed on
Update scenario_planner.py
Browse files- scenario_planner.py +45 -38
scenario_planner.py
CHANGED
|
@@ -1,43 +1,50 @@
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
-
from
|
| 3 |
-
from
|
| 4 |
-
from llm_router import generate_with_fallback
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
"agg": ["sum(col)","avg(col)",...]|null,
|
| 19 |
-
"pivot": {"index":"a","columns":"b","values":"c"}|null,
|
| 20 |
-
"join": [{"right_key":"ds","left_on":"x","right_on":"y","how":"left"}]|null,
|
| 21 |
-
"sort_by": "col|null",
|
| 22 |
-
"sort_dir": "asc|desc",
|
| 23 |
-
"top": int|null,
|
| 24 |
-
"fields": ["col"]|null,
|
| 25 |
-
"chart": "bar|line|area|point|tick|rule"|null,
|
| 26 |
-
"x": "col|null", "y": "col|null", "color":"col|null", "column":"col|null"
|
| 27 |
-
}
|
| 28 |
-
],
|
| 29 |
-
"narrative_required": true,
|
| 30 |
-
"notes": "optional"
|
| 31 |
-
}
|
| 32 |
-
"""
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
|
| 38 |
-
def
|
| 39 |
-
prompt =
|
| 40 |
-
raw =
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any
|
| 2 |
+
from pydantic import ValidationError
|
| 3 |
import json
|
| 4 |
+
from schema import ScenarioPlan, TaskPlan
|
| 5 |
+
from llm_router import cohere_chat, open_fallback_chat
|
|
|
|
| 6 |
|
| 7 |
+
# Instruction preamble placed at the top of every prompt produced by
# build_parser_prompt. It tells the model which per-task keys a plan may
# contain and that the reply must be strict JSON with keys "tasks" and "notes".
_PARSER_PROMPT = """You are a Canadian healthcare analysis planner.
Read the scenario and produce a JSON plan with tasks that exactly match the scenario requests.
Each task must specify: title, format, data_key (if you can infer), and any needed steps:
- filter: boolean expression using dataframe column names
- derive: list of 'col = expression' for calculated columns
- joins: list of objects {right_key, left_on, right_on, how}
- group_by: list of columns
- agg: list of aggregations like 'avg(x)', 'median(y)', 'p90(z)', 'count(*)'
- pivot: {"index":[...], "columns":"...", "values":"..."}
- sort_by, sort_dir, top
- fields: columns to output in order
- chart + encodings: when a chart is requested (x,y,color,column)
- number_format: per-column display formats (e.g., {"wait":"0", "rate":"0.0%"})
If you cannot infer a field/column name, leave it as-is; the executor will attempt to resolve.
Output STRICT JSON only with keys: tasks, notes. DO NOT include explanations outside JSON."""


def build_parser_prompt(scenario_text: str, dataset_catalog: Dict[str, list]) -> str:
    """Assemble the full planner prompt for one scenario.

    The prompt consists of the fixed ``_PARSER_PROMPT`` preamble, a bulleted
    catalog of the uploaded datasets with their columns, and the scenario
    text itself.

    Args:
        scenario_text: Free-form scenario description, inserted verbatim.
        dataset_catalog: Mapping of dataset key to its column labels. An
            empty mapping or ``None`` yields a "no files uploaded yet"
            placeholder line instead of a catalog.

    Returns:
        The complete prompt string.
    """
    # Tolerate a missing catalog rather than failing on ``None.items()``.
    catalog = dataset_catalog or {}
    # Column labels are coerced with str() because they are not guaranteed to
    # be strings (e.g. integer headers), and str.join rejects non-str items.
    cat_lines = [
        f"- {key}: {', '.join(str(col) for col in cols)}"
        for key, cols in catalog.items()
    ]
    catalog_str = "\n".join(cat_lines) if cat_lines else "- (no files uploaded yet)"
    return f"""{_PARSER_PROMPT}

# Uploaded datasets and their columns
{catalog_str}

# Scenario
{scenario_text}
"""
|
| 36 |
|
| 37 |
+
def _fallback_plan(notes: str) -> ScenarioPlan:
    """Return the minimal single-task narrative plan, annotated with *notes*."""
    return ScenarioPlan(
        tasks=[TaskPlan(title="Scenario Summary", format="narrative")],
        notes=notes,
    )


def parse_to_plan(scenario_text: str, dataset_catalog: Dict[str, list]) -> ScenarioPlan:
    """Ask the LLM for a plan and parse its reply into a ``ScenarioPlan``.

    The prompt is sent to ``cohere_chat`` first; ``open_fallback_chat`` is
    tried only when that returns a falsy value. The reply is trimmed to the
    span between its first ``{`` and last ``}`` before JSON decoding, so text
    surrounding the JSON object is ignored.

    Args:
        scenario_text: Free-form scenario description.
        dataset_catalog: Mapping of dataset key to its column labels.

    Returns:
        The validated plan, or a minimal one-task narrative plan when no
        reply is available, the reply is not valid JSON, the JSON is not an
        object, or the object fails ``ScenarioPlan`` validation.
    """
    prompt = build_parser_prompt(scenario_text, dataset_catalog)
    raw = cohere_chat(prompt) or open_fallback_chat(prompt)
    if not raw:
        return _fallback_plan("LLM unavailable; minimal plan.")

    raw = raw.strip()
    start = raw.find("{")
    end = raw.rfind("}")
    # Slice only when a well-ordered brace pair exists; otherwise leave the
    # text untouched and let JSON decoding decide.
    if start != -1 and end > start:
        raw = raw[start:end + 1]

    try:
        obj = json.loads(raw)
    except json.JSONDecodeError:
        return _fallback_plan("Plan validation failed, fallback.")

    # A reply such as "[]" or "42" decodes successfully but cannot be
    # unpacked with ``**``; treat it as a validation failure instead of
    # letting a TypeError escape.
    if not isinstance(obj, dict):
        return _fallback_plan("Plan validation failed, fallback.")

    try:
        return ScenarioPlan(**obj)
    except ValidationError:
        return _fallback_plan("Plan validation failed, fallback.")
|