File size: 3,606 Bytes
525ff8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# plan_extractor.py
from __future__ import annotations
from typing import Dict, Any, List, Optional
import json, re

# We’ll reuse your Cohere→HF fallback pattern from app.py
from transformers import AutoTokenizer, AutoModelForCausalLM  # only for local fallback

def _cohere_chat_fn():
    try:
        import cohere, os
        key = os.getenv("COHERE_API_KEY")
        if not key: return None
        return cohere.Client(api_key=key)
    except Exception:
        return None

def _extract_json_obj(text: str) -> Optional[Dict[str, Any]]:
    """Pull the outermost {...} span out of model output and parse it.

    Models frequently wrap JSON in markdown fences or prose; grab the
    widest brace-delimited span and attempt a strict parse. Returns the
    parsed dict, or None when no span is found or it is not valid JSON.
    """
    m = re.search(r"\{.*\}", text, re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None

def draft_plan_from_scenario(
    scenario_text: str,
    column_bag: List[str],
    cohere_client=None,
    hf_tuple: Optional[tuple]=None,
    max_tokens: int = 800
) -> Dict[str, Any]:
    """
    Returns a JSON plan that is 100% scenario-derived.
    The plan includes: goals, required_inputs (metrics/entities/time windows), and an output_format hint.
    It may reference column candidates by *ideas* (not fixed labels).

    Args:
        scenario_text: Free-text scenario describing the desired analysis.
        column_bag: Deduplicatable column headers from the uploaded files;
            fed to the model so requested inputs stay mappable.
        cohere_client: Optional pre-built Cohere client; when falsy, one is
            created from COHERE_API_KEY via _cohere_chat_fn().
        hf_tuple: Optional (model, tokenizer) pair for the local HF fallback.
        max_tokens: Generation budget for either backend.

    Returns:
        Dict with keys "goals", "requires", "output_format". Always returns
        a dict — an empty scenario or total backend failure yields a
        conservative fallback plan rather than raising.
    """
    scenario = (scenario_text or "").strip()
    if not scenario:
        return {
            "goals": [], "requires": [], "output_format": "structured_analysis_v1",
            "notes": "Empty scenario text; no plan."
        }

    # Build prompt that feeds real columns to steer the plan dynamically.
    # Cap the hint at 2000 chars to bound prompt size (may truncate mid-name).
    col_hint = ", ".join(sorted(set(column_bag)))[:2000]
    sys_prompt = (
        "You design a data analysis plan from a scenario. "
        "You do NOT assume undocumented numbers. "
        "You only request metrics that plausibly map to the provided column headers.\n"
        "Return STRICT JSON with keys: goals, requires, output_format.\n"
        "Each goal has a type (rank_top_n | summary_table | delta_over_time | capacity_calc | cost_total | custom), "
        "and parameters (e.g., metric names, groupings, n, filters, periods). "
        "Each requires item lists an input name and a description of how it could map to columns.\n"
    )
    user = (
        f"SCENARIO:\n{scenario}\n\n"
        f"AVAILABLE COLUMN HEADERS (from uploaded files, deduped):\n{col_hint}\n\n"
        "Produce the JSON plan now. Do not invent column names; propose inputs using phrases present in the scenario "
        "or clearly mappable to the provided headers."
    )

    # Try Cohere first; any API failure falls through to the HF path.
    client = cohere_client or _cohere_chat_fn()
    if client is not None:
        try:
            resp = client.chat(model="command-r7b-12-2024", message=sys_prompt + "\n\n" + user, temperature=0.2, max_tokens=max_tokens)
            # Different SDK versions expose the reply under different attrs.
            txt = getattr(resp, "text", None) or getattr(resp, "reply", None)
            if txt:
                plan = _extract_json_obj(txt)
                if plan is not None:
                    return plan
        except Exception:
            # Best-effort backend: swallow API errors and try the fallback.
            pass

    # HF fallback (very lightweight). Guarded like the Cohere branch so a
    # tokenizer without a chat template (or a generation error) degrades to
    # the conservative fallback instead of crashing the caller.
    if hf_tuple is not None:
        try:
            model, tok = hf_tuple
            prompt = sys_prompt + "\n\n" + user + "\n\nJSON:"
            inpt = tok.apply_chat_template([{"role":"user","content":prompt}], tokenize=True, add_generation_prompt=True, return_tensors="pt")
            out = model.generate(inpt.to(model.device), max_new_tokens=max_tokens, do_sample=False)
            # Decode only the newly generated tail, not the echoed prompt.
            gen = tok.decode(out[0, inpt.shape[-1]:], skip_special_tokens=True)
            plan = _extract_json_obj(gen)
            if plan is not None:
                return plan
        except Exception:
            pass

    # Ultra-conservative fallback: a valid-shaped, empty plan.
    return {
        "goals": [{"type":"summary_table","metrics":[],"by":[],"note":"fallback-empty"}],
        "requires": [],
        "output_format": "structured_analysis_v1"
    }