Spaces:
Sleeping
Sleeping
| # app/tools/llm_sqlgen.py | |
| from __future__ import annotations | |
| from typing import Optional, Dict, Any | |
| import requests, json | |
# Chat-completions endpoint that SQLGenTool.generate_sql sends its POST request to.
| HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions" | |
# Prompt fragment appended to the system message in SQLGenTool.generate_sql.
# It lists the tables/columns the model may reference and the rules the
# generated SQL should follow, and ends with one example of a non-English
# question mapped to an English SQL query.
# NOTE: this whole literal is sent to the model verbatim, so any edit to the
# text below changes the prompt.
| SCHEMA_SPEC = """ | |
| Tables and columns (SQLite): | |
| dim_region(code, name) | |
| dim_product(sku, category, name, price) | |
| dim_employee(emp_id, name, region_code, role, hire_date) | |
| fact_sales(day, region_code, sku, channel, units, revenue) | |
| fact_sales_detail(day, region_code, sku, channel, employee_id, units, revenue) | |
| inv_stock(day, region_code, sku, on_hand_qty) | |
| Rules: | |
| - Use only SELECT. Never modify data. | |
| - Prefer ISO date literals 'YYYY-MM-DD'. | |
| - Region codes are 3 letters: NCR, BLR, MUM, HYD, CHN, PUN. | |
| - For monthly rollups use strftime('%Y-%m', day). | |
| - Join to dim_product when you need category/name/price. | |
| - For per-employee metrics use fact_sales_detail (employee_id may be NULL for Online). | |
| - Always generate the SQL Queries in English | |
| for example. | |
| "q": रमेश का टोटल जेनरेटेड रेवेन्यू बताओ | |
| "sql": SELECT SUM(d.revenue) AS total_revenue FROM fact_sales_detail d JOIN dim_employee e ON e.emp_id = d.employee_id WHERE e.name LIKE 'Ramesh %' | |
| """ | |
# Few-shot examples prepended to the user message in SQLGenTool.generate_sql.
# Each entry is a dict with two keys:
#   "q"   - a natural-language question
#   "sql" - the SQL query that answers it against the tables in SCHEMA_SPEC
# generate_sql renders every entry as "Q: <q>\nSQL:\n<sql>\n", so the SQL text
# below reaches the model verbatim.
| FEW_SHOTS = [ | |
| { | |
| "q": "What is monthly revenue for Electronics in BLR for 2025-09?", | |
| "sql": """SELECT strftime('%Y-%m', fs.day) AS month, SUM(fs.revenue) AS revenue | |
| FROM fact_sales fs | |
| JOIN dim_product p ON p.sku = fs.sku | |
| WHERE fs.region_code='BLR' AND p.category='Electronics' AND fs.day BETWEEN '2025-09-01' AND '2025-09-30' | |
| GROUP BY month | |
| ORDER BY month""" | |
| }, | |
| { | |
| "q": "Show Ramesh's sales (units and revenue) in NCR on 2025-09-06", | |
| "sql": """SELECT e.name, d.units, d.revenue | |
| FROM fact_sales_detail d | |
| JOIN dim_employee e ON e.emp_id = d.employee_id | |
| WHERE e.name LIKE 'Ramesh %' AND d.region_code='NCR' AND d.day='2025-09-06'""" | |
| }, | |
| { | |
| "q": "What's the on-hand stock for sku ELEC-002 in MUM on 2025-09-05?", | |
| "sql": """SELECT on_hand_qty | |
| FROM inv_stock | |
| WHERE region_code='MUM' AND sku='ELEC-002' AND day='2025-09-05'""" | |
| }, | |
| { | |
| "q": "Top 5 SKUs by revenue in HYD on 2025-09-06 (include category)", | |
| "sql": """SELECT fs.sku, p.category, SUM(fs.revenue) AS rev | |
| FROM fact_sales fs | |
| JOIN dim_product p ON p.sku=fs.sku | |
| WHERE fs.region_code='HYD' AND fs.day='2025-09-06' | |
| GROUP BY fs.sku, p.category | |
| ORDER BY rev DESC | |
| LIMIT 5""" | |
| } | |
| ] | |
class SQLGenTool:
    """Turn a natural-language question into a SQL string via a chat model.

    The tool POSTs a chat-completion request to ``HF_CHAT_URL`` containing the
    schema description (``SCHEMA_SPEC``) and the few-shot examples
    (``FEW_SHOTS``), then extracts the ``"sql"`` value from the JSON object in
    the model's reply.

    Attributes:
        model_id: Model identifier sent as the ``"model"`` field.
        token: Bearer token for the ``Authorization`` header (may be ``None``).
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit forwarded to the API.
        timeout: Timeout passed to ``requests.post``.
        enabled: ``True`` only when both ``token`` and ``model_id`` are truthy.
    """

    def __init__(self, model_id: str, token: Optional[str], temperature: float = 0.0, max_tokens: int = 400, timeout: int = 60):
        self.model_id = model_id
        self.token = token
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.enabled = bool(token and model_id)

    def set_token(self, token: Optional[str]) -> None:
        """Replace the API token and recompute ``enabled`` accordingly."""
        self.token = token
        self.enabled = bool(token and self.model_id)

    @staticmethod
    def _extract_sql(content: Optional[str]) -> str:
        """Pull the ``"sql"`` value out of a model reply.

        The reply is scanned for the first ``{`` from which a complete JSON
        object can be decoded, so surrounding prose, code fences, or stray
        braces before/after the object do not break parsing.

        Args:
            content: Raw message content returned by the model (``None`` is
                treated as an empty reply).

        Returns:
            The stripped SQL string, or ``""`` when the object has no
            ``"sql"`` key or its value is ``null``.

        Raises:
            json.JSONDecodeError: No JSON object could be decoded from the reply.
            ValueError: The ``"sql"`` value is present but is not a string.
        """
        text = (content or "").strip()
        decoder = json.JSONDecoder()

        start = text.find("{")
        while start != -1:
            try:
                # raw_decode parses exactly one JSON value beginning at `start`
                # and ignores whatever follows it.
                obj, _end = decoder.raw_decode(text, start)
                break
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
        else:
            # Same exception type json.loads raised before, with a clearer message.
            raise json.JSONDecodeError("No JSON object found in model response", text, 0)

        sql = obj.get("sql", "")
        if sql is None:
            return ""
        if not isinstance(sql, str):
            raise ValueError(f"Expected 'sql' to be a string, got {type(sql).__name__}")
        return sql.strip()

    def generate_sql(self, question: str) -> str:
        """Ask the model for a SQL query answering ``question``.

        Args:
            question: The user's natural-language question.

        Returns:
            The SQL string taken from the ``"sql"`` key of the model's JSON
            reply (empty string if that key is missing).

        Raises:
            RuntimeError: The tool is disabled (no token or no model id).
            requests.HTTPError: The API answered with an error status
                (raised by ``raise_for_status``).
            json.JSONDecodeError: The reply contains no decodable JSON object.
            ValueError: The ``"sql"`` value in the reply is not a string.
        """
        if not self.enabled:
            raise RuntimeError("SQLGenTool disabled: missing HF token or model_id.")

        fewshot_txt = "\n".join(f"Q: {ex['q']}\nSQL:\n{ex['sql']}\n" for ex in FEW_SHOTS)
        # Named system_prompt rather than `sys` to avoid shadowing the stdlib module name.
        system_prompt = (
            "You are a SQL generator. Output only a single JSON object: {\"sql\": \"...\"}.\n"
            "No prose. No explanations. Use the provided schema only.\n" + SCHEMA_SPEC
        )
        user_prompt = f"Question:\n{question}\n\nReturn JSON with a single key 'sql'."

        payload = {
            "model": self.model_id,
            "stream": False,
            "messages": [
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": [{"type": "text", "text": fewshot_txt + "\n\n" + user_prompt}]},
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Accept-Encoding": "identity",
        }

        resp = requests.post(HF_CHAT_URL, headers=headers, json=payload, timeout=self.timeout)
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        return self._extract_sql(content)