Rajan Sharma committed on
Commit
525ff8a
·
verified ·
1 Parent(s): 76fca29

Create plan_extractor.py

Browse files
Files changed (1) hide show
  1. plan_extractor.py +88 -0
plan_extractor.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # plan_extractor.py
2
+ from __future__ import annotations
3
+ from typing import Dict, Any, List, Optional
4
+ import json, re
5
+
6
+ # We’ll reuse your Cohere→HF fallback pattern from app.py
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM # only for local fallback
8
+
9
def _cohere_chat_fn():
    """Return a Cohere client built from ``COHERE_API_KEY``, or None.

    The key is read from the environment first; when it is missing or
    blank the function returns None without importing the ``cohere``
    package at all. Any failure while importing the SDK or constructing
    the client also yields None, so callers can treat None as
    "Cohere is not available" and move on to their next option.
    """
    import os

    key = (os.getenv("COHERE_API_KEY") or "").strip()
    if not key:
        # Not configured: skip the SDK import entirely.
        return None

    try:
        import cohere

        return cohere.Client(api_key=key)
    except Exception:
        # Deliberately best-effort: a missing or broken SDK must not
        # crash the caller, which only needs to know no client exists.
        return None
17
+
18
def _extract_json_object(text: Optional[str]) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in *text*, or None if none parses."""
    if not text:
        return None
    start = text.find("{")
    if start < 0:
        return None
    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # markdown fences or prose (even prose containing braces) that
        # follow the object do not break parsing.
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except ValueError:
        return None
    return obj if isinstance(obj, dict) else None


def draft_plan_from_scenario(
    scenario_text: str,
    column_bag: List[str],
    cohere_client=None,
    hf_tuple: Optional[tuple] = None,
    max_tokens: int = 800,
) -> Dict[str, Any]:
    """Ask an LLM to draft a JSON analysis plan for a scenario.

    The prompt asks the model for STRICT JSON with the keys ``goals``,
    ``requires`` and ``output_format``, and includes the de-duplicated,
    sorted column headers so the model can map requested inputs to them.

    Args:
        scenario_text: Free-text scenario description. Blank or None
            short-circuits to an empty plan carrying a ``notes`` entry.
        column_bag: Column headers collected from the uploaded files.
            None or empty is treated as "no headers".
        cohere_client: Object exposing ``chat(model=..., message=...,
            temperature=..., max_tokens=...)``. When falsy, a client is
            obtained from ``_cohere_chat_fn()``.
        hf_tuple: Optional ``(model, tokenizer)`` pair used when the
            Cohere path yields no plan.
        max_tokens: Generation budget passed to whichever backend runs.

    Returns:
        The parsed JSON object produced by the first backend that returns
        one. If neither does (including when a backend raises), a
        conservative plan with a single empty ``summary_table`` goal is
        returned. The model's JSON is returned as parsed; it is not
        validated against the requested keys.
    """
    scenario = (scenario_text or "").strip()
    if not scenario:
        return {
            "goals": [], "requires": [], "output_format": "structured_analysis_v1",
            "notes": "Empty scenario text; no plan."
        }

    # Feed the real column headers to the model so the plan stays grounded.
    # The hint is capped at 2000 characters to bound prompt size.
    headers = {str(col) for col in (column_bag or []) if col is not None}
    col_hint = ", ".join(sorted(headers))[:2000]

    system_msg = (
        "You design a data analysis plan from a scenario. "
        "You do NOT assume undocumented numbers. "
        "You only request metrics that plausibly map to the provided column headers.\n"
        "Return STRICT JSON with keys: goals, requires, output_format.\n"
        "Each goal has a type (rank_top_n | summary_table | delta_over_time | capacity_calc | cost_total | custom), "
        "and parameters (e.g., metric names, groupings, n, filters, periods). "
        "Each requires item lists an input name and a description of how it could map to columns.\n"
    )
    user_msg = (
        f"SCENARIO:\n{scenario}\n\n"
        f"AVAILABLE COLUMN HEADERS (from uploaded files, deduped):\n{col_hint}\n\n"
        "Produce the JSON plan now. Do not invent column names; propose inputs using phrases present in the scenario "
        "or clearly mappable to the provided headers."
    )

    # Try Cohere first.
    client = cohere_client or _cohere_chat_fn()
    if client is not None:
        try:
            resp = client.chat(
                model="command-r7b-12-2024",
                message=system_msg + "\n\n" + user_msg,
                temperature=0.2,
                max_tokens=max_tokens,
            )
            reply = getattr(resp, "text", None) or getattr(resp, "reply", None)
            plan = _extract_json_object(reply)
        except Exception:
            # Best-effort: any Cohere failure falls through to the next backend.
            plan = None
        if plan is not None:
            return plan

    # HF fallback (very lightweight).
    if hf_tuple is not None:
        try:
            model, tok = hf_tuple
            prompt = system_msg + "\n\n" + user_msg + "\n\nJSON:"
            inpt = tok.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )
            out = model.generate(inpt.to(model.device), max_new_tokens=max_tokens, do_sample=False)
            # Decode only the newly generated tokens, not the echoed prompt.
            gen = tok.decode(out[0, inpt.shape[-1]:], skip_special_tokens=True)
            plan = _extract_json_object(gen)
        except Exception:
            # Same policy as the Cohere path: a failing local model must not
            # prevent the conservative fallback below from being returned.
            plan = None
        if plan is not None:
            return plan

    # Ultra-conservative fallback
    return {
        "goals": [{"type": "summary_table", "metrics": [], "by": [], "note": "fallback-empty"}],
        "requires": [],
        "output_format": "structured_analysis_v1"
    }