File size: 4,382 Bytes
bebb279
 
 
 
 
e4544a7
bebb279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4544a7
 
 
 
bebb279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd7254f
bebb279
 
 
 
 
 
 
 
e4544a7
 
 
 
bebb279
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# app/tools/llm_sqlgen.py
from __future__ import annotations
from typing import Optional, Dict, Any
import requests, json

# Chat-completions endpoint on the Hugging Face router. SQLGenTool POSTs an
# OpenAI-style payload here ("model", "messages", "temperature", "max_tokens")
# and reads the reply from choices[0].message.content.
HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"

# Schema description and generation rules for the model. This text is appended
# verbatim to the system prompt, so every character in it (including the
# example at the end) is model-facing behavior, not documentation.
# The trailing example maps a Hindi question ("tell me Ramesh's total
# generated revenue") to English SQL, with the name written as 'Ramesh %'
# (a LIKE prefix match on names starting with "Ramesh ").
# NOTE(review): the region-code list and column names must stay in sync with
# the actual SQLite schema — confirm against whatever builds the database.
SCHEMA_SPEC = """
Tables and columns (SQLite):

dim_region(code, name)
dim_product(sku, category, name, price)
dim_employee(emp_id, name, region_code, role, hire_date)

fact_sales(day, region_code, sku, channel, units, revenue)
fact_sales_detail(day, region_code, sku, channel, employee_id, units, revenue)

inv_stock(day, region_code, sku, on_hand_qty)

Rules:
- Use only SELECT. Never modify data.
- Prefer ISO date literals 'YYYY-MM-DD'.
- Region codes are 3 letters: NCR, BLR, MUM, HYD, CHN, PUN.
- For monthly rollups use strftime('%Y-%m', day).
- Join to dim_product when you need category/name/price.
- For per-employee metrics use fact_sales_detail (employee_id may be NULL for Online).
- Always generate the SQL Queries in English 
for example. 
    "q": रमेश का टोटल जेनरेटेड रेवेन्यू बताओ 
    "sql": SELECT SUM(d.revenue) AS total_revenue FROM fact_sales_detail d JOIN dim_employee e ON e.emp_id = d.employee_id WHERE e.name LIKE 'Ramesh %'
"""

# Few-shot demonstrations sent in the user message ahead of the real question.
# Each entry pairs a natural-language question ("q") with its target SQLite
# query ("sql"). The queries follow the SCHEMA_SPEC conventions: ISO date
# literals, 3-letter region codes, and joins to dim_product / dim_employee for
# category or employee names.
FEW_SHOTS = [
    {
        "q": "What is monthly revenue for Electronics in BLR for 2025-09?",
        "sql": """SELECT strftime('%Y-%m', fs.day) AS month, SUM(fs.revenue) AS revenue
FROM fact_sales fs
JOIN dim_product p ON p.sku = fs.sku
WHERE fs.region_code='BLR' AND p.category='Electronics' AND fs.day BETWEEN '2025-09-01' AND '2025-09-30'
GROUP BY month
ORDER BY month"""
    },
    {
        "q": "Show Ramesh's sales (units and revenue) in NCR on 2025-09-06",
        "sql": """SELECT e.name, d.units, d.revenue
FROM fact_sales_detail d
JOIN dim_employee e ON e.emp_id = d.employee_id
WHERE e.name LIKE 'Ramesh %' AND d.region_code='NCR' AND d.day='2025-09-06'"""
    },
    {
        "q": "What's the on-hand stock for sku ELEC-002 in MUM on 2025-09-05?",
        "sql": """SELECT on_hand_qty
FROM inv_stock
WHERE region_code='MUM' AND sku='ELEC-002' AND day='2025-09-05'"""
    },
    {
        "q": "Top 5 SKUs by revenue in HYD on 2025-09-06 (include category)",
        "sql": """SELECT fs.sku, p.category, SUM(fs.revenue) AS rev
FROM fact_sales fs
JOIN dim_product p ON p.sku=fs.sku
WHERE fs.region_code='HYD' AND fs.day='2025-09-06'
GROUP BY fs.sku, p.category
ORDER BY rev DESC
LIMIT 5"""
    }
]

class SQLGenTool:
    """Turn natural-language questions into SQLite queries via a hosted LLM.

    Each request is POSTed to ``HF_CHAT_URL``. ``SCHEMA_SPEC`` is part of the
    system prompt, and ``FEW_SHOTS`` are included as worked examples. The tool
    is usable only when both a model id and a token are set (see ``enabled``).

    Args:
        model_id: Model identifier sent as the request's ``model`` field.
        token: Bearer token for the endpoint; a falsy value disables the tool.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: Generation cap forwarded to the endpoint.
        timeout: HTTP timeout in seconds passed to ``requests.post``.
    """

    def __init__(self, model_id: str, token: Optional[str], temperature: float = 0.0, max_tokens: int = 400, timeout: int = 60):
        self.model_id = model_id
        self.token = token
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        # True only when both a token and a model id are present; set_token() recomputes it.
        self.enabled = bool(token and model_id)

    def set_token(self, token: Optional[str]) -> None:
        """Replace the API token and recompute ``enabled``."""
        self.token = token
        self.enabled = bool(token and self.model_id)

    def generate_sql(self, question: str) -> str:
        """Ask the model for a single SQLite query that answers ``question``.

        Args:
            question: Natural-language question. SCHEMA_SPEC asks the model to
                write the SQL in English even for non-English questions.

        Returns:
            The generated SQL, stripped. Returns ``""`` if the model's JSON
            object has no string ``"sql"`` value.

        Raises:
            RuntimeError: if the tool is disabled (no token or model id).
            requests.HTTPError: on a non-2xx response (from raise_for_status).
            KeyError / IndexError: if the response body lacks choices[0].message.
            ValueError: if the reply contains no recoverable SQL.

        NOTE(review): the returned text is model output. The SELECT-only rule is
        only requested in the prompt and is not enforced here, so callers should
        run it on a read-only connection or validate it before executing.
        """
        if not self.enabled:
            raise RuntimeError("SQLGenTool disabled: missing HF token or model_id.")
        payload = {
            "model": self.model_id,
            "stream": False,
            "messages": self._build_messages(question),
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            # Asks the server for an uncompressed body — presumably to sidestep a
            # decoding issue seen with this endpoint; confirm before removing.
            "Accept-Encoding": "identity",
        }
        r = requests.post(HF_CHAT_URL, headers=headers, json=payload, timeout=self.timeout)
        r.raise_for_status()
        message = r.json()["choices"][0]["message"]
        return self._extract_sql(self._message_text(message))

    @staticmethod
    def _build_messages(question: str) -> list[Dict[str, Any]]:
        """Build the system and user chat messages for one question."""
        # Render each example in the exact JSON shape the model must return, so
        # the demonstrations do not contradict the "output only JSON" instruction.
        examples = []
        for ex in FEW_SHOTS:
            answer = json.dumps({"sql": ex["sql"]}, ensure_ascii=False)
            examples.append(f"Q: {ex['q']}\nA: {answer}\n")
        fewshot_txt = "\n".join(examples)
        system_prompt = (
            "You are a SQL generator. Output only a single JSON object: {\"sql\": \"...\"}.\n"
            "No prose. No explanations. Use the provided schema only.\n" + SCHEMA_SPEC
        )
        user = f"Question:\n{question}\n\nReturn JSON with a single key 'sql'."
        return [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": [{"type": "text", "text": fewshot_txt + "\n\n" + user}]},
        ]

    @staticmethod
    def _message_text(message: Dict[str, Any]) -> str:
        """Return the stripped text of a chat message; raise ValueError if there is none."""
        content = message.get("content")
        if isinstance(content, list):
            # Content may come back as a list of typed parts (the same shape this
            # tool sends), so join the text of every part.
            content = "".join(
                part.get("text", "") for part in content if isinstance(part, dict)
            )
        if not isinstance(content, str) or not content.strip():
            raise ValueError("Model reply has no text content.")
        return content.strip()

    @staticmethod
    def _extract_sql(content: str) -> str:
        """Recover the SQL string from the model's reply text.

        The first JSON object in the reply wins, so surrounding prose or a
        Markdown code fence is tolerated. Its ``"sql"`` value is returned
        stripped, or ``""`` when that key is absent or not a string. If the
        reply has no JSON object but is itself a bare SELECT/WITH statement
        (optionally fenced), that statement is returned.

        Raises:
            ValueError: if neither a JSON object nor a bare query is found.
        """
        decoder = json.JSONDecoder()
        start = content.find("{")
        while start != -1:
            try:
                # raw_decode stops at the end of the first complete value, so a
                # stray "}" in trailing text cannot corrupt the parse.
                obj, _ = decoder.raw_decode(content, start)
            except json.JSONDecodeError:
                start = content.find("{", start + 1)
                continue
            sql = obj.get("sql")
            return sql.strip() if isinstance(sql, str) else ""

        body = content.strip()
        if body.startswith("```"):
            # Drop the opening fence line (e.g. ```sql) and the closing fence.
            body = body.partition("\n")[2].rstrip()
            if body.endswith("```"):
                body = body[:-3]
            body = body.strip()
        words = body.split(None, 1)
        if words and words[0].upper() in ("SELECT", "WITH"):
            return body
        raise ValueError(
            f"Could not find a JSON object with an 'sql' key in model reply: {content[:200]!r}"
        )