Spaces:
Sleeping
Sleeping
File size: 3,606 Bytes
525ff8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
# plan_extractor.py
from __future__ import annotations

import json
import logging
import re
from typing import Any, Dict, List, Optional

# We’ll reuse your Cohere→HF fallback pattern from app.py
from transformers import AutoTokenizer, AutoModelForCausalLM  # only for local fallback
def _cohere_chat_fn():
    """Build a Cohere client from the ``COHERE_API_KEY`` environment variable.

    Returns:
        A ``cohere.Client`` instance, or None when the ``cohere`` package
        cannot be imported, the variable is unset/empty, or constructing the
        client raises.
    """
    try:
        import cohere
        import os

        api_key = os.getenv("COHERE_API_KEY")
        client = cohere.Client(api_key=api_key) if api_key else None
    except Exception:
        # A missing SDK or a failing constructor both mean "Cohere unavailable";
        # callers treat None as "skip the Cohere backend".
        client = None
    return client
_LOG = logging.getLogger(__name__)

# Top-level keys the system prompt asks the model to emit. Used to recognise the
# plan object when a reply contains several JSON-looking fragments.
_PLAN_KEYS = ("goals", "requires", "output_format")

# Upper bound (in characters) on the column-header list injected into the prompt.
_COLUMN_HINT_LIMIT = 2000


def _column_hint(column_bag, limit: int = _COLUMN_HINT_LIMIT) -> str:
    """Render deduplicated, sorted column headers as a comma-separated string.

    Headers are stringified before sorting/joining so non-string labels (e.g.
    integer column labels) and mixed types do not raise; None entries and a
    None bag are skipped. The result is cut at a header boundary, never
    mid-name, so the model is not shown a truncated (i.e. invented) header.
    """
    headers = sorted({str(c) for c in (column_bag or []) if c is not None})
    parts: List[str] = []
    used = 0
    for header in headers:
        extra = len(header) + (2 if parts else 0)  # 2 == len(", ")
        if used + extra > limit:
            break
        parts.append(header)
        used += extra
    return ", ".join(parts)


def _extract_json_object(text: Optional[str]) -> Optional[Dict[str, Any]]:
    """Return the first JSON object in *text* that carries a plan key, else None.

    Chat models often wrap JSON in Markdown fences or surround it with prose
    that itself contains braces, so every ``{`` is tried as a start position
    and decoded with ``raw_decode`` (which ignores whatever follows the
    object). Requiring one of ``_PLAN_KEYS`` lowers the chance that a fragment
    nested inside a malformed outer object is mistaken for the plan.
    """
    if not isinstance(text, str) or not text:
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict) and any(key in obj for key in _PLAN_KEYS):
                return obj
        start = text.find("{", start + 1)
    return None


def _plan_from_cohere(client, prompt: str, max_tokens: int) -> Optional[Dict[str, Any]]:
    """Ask the Cohere client for a plan; return the parsed plan or None on failure."""
    try:
        resp = client.chat(
            model="command-r7b-12-2024",
            message=prompt,
            temperature=0.2,
            max_tokens=max_tokens,
        )
    except Exception:
        # Network/auth/SDK errors: fall through to the next backend, but say so.
        _LOG.warning("Cohere plan drafting failed; trying fallback", exc_info=True)
        return None
    txt = getattr(resp, "text", None) or getattr(resp, "reply", None)
    return _extract_json_object(txt)


def _plan_from_hf(hf_tuple, prompt: str, max_tokens: int) -> Optional[Dict[str, Any]]:
    """Generate a plan with a local ``(model, tokenizer)`` pair; None on failure."""
    model, tok = hf_tuple
    try:
        inpt = tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        out = model.generate(inpt.to(model.device), max_new_tokens=max_tokens, do_sample=False)
        # Decode only the newly generated tokens, skipping the echoed prompt.
        gen = tok.decode(out[0, inpt.shape[-1]:], skip_special_tokens=True)
    except Exception:
        # Missing chat template, OOM, device errors, ...: degrade to the
        # conservative fallback plan instead of crashing the caller.
        _LOG.warning("Local HF plan drafting failed; using fallback plan", exc_info=True)
        return None
    return _extract_json_object(gen)


def draft_plan_from_scenario(
    scenario_text: str,
    column_bag: List[str],
    cohere_client=None,
    hf_tuple: Optional[tuple] = None,
    max_tokens: int = 800,
) -> Dict[str, Any]:
    """Draft a JSON analysis plan for a free-text scenario using an LLM.

    The model is prompted to return a plan with ``goals``, ``requires`` and
    ``output_format``, using only metrics that plausibly map to the given
    column headers. The returned dict is the model's JSON as parsed; it is not
    schema-validated beyond containing at least one of those keys.

    Args:
        scenario_text: Free-text scenario. Blank or None yields an empty plan.
        column_bag: Column headers from uploaded files. Duplicates, None
            entries, non-string labels and a None bag are tolerated.
        cohere_client: Object exposing ``chat(...)``. When None, one is built
            from ``COHERE_API_KEY`` via ``_cohere_chat_fn``.
        hf_tuple: Optional ``(model, tokenizer)`` pair for local generation,
            tried only when Cohere yields no usable plan.
        max_tokens: Generation budget passed to either backend.

    Returns:
        The first plan obtained from Cohere, then from the local model.
        Otherwise a fixed conservative fallback plan. Backend failures are
        logged and never raised.
    """
    scenario = (scenario_text or "").strip()
    if not scenario:
        return {
            "goals": [], "requires": [], "output_format": "structured_analysis_v1",
            "notes": "Empty scenario text; no plan."
        }

    # Feed the real (deduplicated) headers to steer the plan toward mappable inputs.
    col_hint = _column_hint(column_bag)
    system_prompt = (
        "You design a data analysis plan from a scenario. "
        "You do NOT assume undocumented numbers. "
        "You only request metrics that plausibly map to the provided column headers.\n"
        "Return STRICT JSON with keys: goals, requires, output_format.\n"
        "Each goal has a type (rank_top_n | summary_table | delta_over_time | capacity_calc | cost_total | custom), "
        "and parameters (e.g., metric names, groupings, n, filters, periods). "
        "Each requires item lists an input name and a description of how it could map to columns.\n"
    )
    user_prompt = (
        f"SCENARIO:\n{scenario}\n\n"
        f"AVAILABLE COLUMN HEADERS (from uploaded files, deduped):\n{col_hint}\n\n"
        "Produce the JSON plan now. Do not invent column names; propose inputs using phrases present in the scenario "
        "or clearly mappable to the provided headers."
    )
    prompt = system_prompt + "\n\n" + user_prompt

    # Try Cohere first.
    client = cohere_client or _cohere_chat_fn()
    if client is not None:
        plan = _plan_from_cohere(client, prompt, max_tokens)
        if plan is not None:
            return plan

    # Local HF fallback; the trailing "JSON:" nudges the model to start the object.
    if hf_tuple is not None:
        plan = _plan_from_hf(hf_tuple, prompt + "\n\nJSON:", max_tokens)
        if plan is not None:
            return plan

    # Ultra-conservative fallback.
    return {
        "goals": [{"type": "summary_table", "metrics": [], "by": [], "note": "fallback-empty"}],
        "requires": [],
        "output_format": "structured_analysis_v1"
    }
|