LalitChaudhari3 commited on
Commit
146146a
·
verified ·
1 Parent(s): ef0f8e7

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +74 -22
src/sql_generator.py CHANGED
@@ -1,15 +1,25 @@
1
  import os
 
2
  import re
3
  import json
4
- from huggingface_hub import InferenceClient
5
  from dotenv import load_dotenv
6
 
7
  class SQLGenerator:
8
- def __init__(self, api_key=None):
9
  load_dotenv()
10
- self.api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
11
- self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
12
- self.client = InferenceClient(token=self.api_token, timeout=25.0)
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_followup_questions(self, question, sql_query):
15
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -17,53 +27,95 @@ class SQLGenerator:
17
  def generate_sql(self, question, context, history=None):
18
  if history is None: history = []
19
 
 
 
 
 
 
20
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
21
  if any(word in question.upper() for word in forbidden):
22
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
23
 
 
24
  history_text = ""
25
  if history:
26
- history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h['user']}\nSQL: {h['sql']}" for h in history[-2:]])
 
 
 
 
 
27
 
28
- system_prompt = f"""You are an elite SQL Expert.
29
- Schema:
30
- {context}
31
 
32
- {history_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- Rules:
35
- 1. Output JSON: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
36
- 2. Query MUST be Read-Only (SELECT).
37
- 3. Do not include markdown formatting like ```json.
38
- """
39
- messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]
40
 
41
  try:
42
- print(f" ⚡ Generating SQL...")
43
- response = self.client.chat_completion(messages=messages, model=self.repo_id, max_tokens=1024, temperature=0.1)
44
- raw_text = response.choices[0].message.content
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  sql_query = ""
47
  message = "Here is the data."
48
  explanation = "Query generated successfully."
49
 
50
  try:
 
 
51
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
 
52
  data = json.loads(clean_json)
53
  sql_query = data.get("sql", "")
54
  message = data.get("message", message)
55
  explanation = data.get("explanation", explanation)
56
  except:
 
57
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
58
  if match: sql_query = match.group(1)
59
 
 
60
  sql_query = sql_query.strip().replace("\n", " ")
61
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
62
 
63
- # FIX: Strip comments and whitespace before validation
64
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
65
-
66
- # ✅ FIX: Allow SELECT or WITH clauses
67
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
68
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
69
  return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."
 
1
  import os
2
+ import requests
3
  import re
4
  import json
 
5
  from dotenv import load_dotenv
6
 
7
  class SQLGenerator:
8
+ def __init__(self):
9
  load_dotenv()
10
+
11
+ # 1. ROBUSTLY FETCH & CLEAN THE KEY
12
+ # We check both names and use .strip() to remove invisible newlines
13
+ raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
14
+ self.api_key = raw_key.strip() if raw_key else None
15
+
16
+ # 2. USE A RELIABLE FREE TIER MODEL
17
+ # We switch to Mistral-7B-Instruct-v0.3 (Supported on Free Tier)
18
+ # Qwen-32B is too big and causes 404 errors.
19
+ self.repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
20
+
21
+ # 3. USE THE NEW ROUTER URL (Fixes 410 Gone errors)
22
+ self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
23
 
24
  def generate_followup_questions(self, question, sql_query):
25
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
27
  def generate_sql(self, question, context, history=None):
28
  if history is None: history = []
29
 
30
+ # 🚨 ERROR CHECK
31
+ if not self.api_key:
32
+ return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
33
+
34
+ # 🛡️ Safety Layer (Keyword Block)
35
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
36
  if any(word in question.upper() for word in forbidden):
37
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
38
 
39
+ # Format History
40
  history_text = ""
41
  if history:
42
+ history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
43
+
44
+ # System Prompt (Optimized for JSON output)
45
+ system_prompt = f"""<s>[INST] You are an elite SQL Expert.
46
+ DATABASE SCHEMA:
47
+ {context}
48
 
49
+ {history_text}
 
 
50
 
51
+ RULES:
52
+ 1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
53
+ 2. The SQL query MUST be Read-Only (SELECT).
54
+ 3. Do not use markdown formatting.
55
+
56
+ QUESTION: {question} [/INST]"""
57
+
58
+ # Payload
59
+ payload = {
60
+ "inputs": system_prompt,
61
+ "parameters": {
62
+ "max_new_tokens": 1024,
63
+ "temperature": 0.1,
64
+ "return_full_text": False
65
+ }
66
+ }
67
 
68
+ headers = {
69
+ "Authorization": f"Bearer {self.api_key}",
70
+ "Content-Type": "application/json"
71
+ }
 
 
72
 
73
  try:
74
+ print(f" ⚡ Generating SQL via Direct API ({self.repo_id})...")
 
 
75
 
76
+ # 🚀 DIRECT REQUEST (Bypasses library auth issues)
77
+ response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
78
+
79
+ # Handle Errors clearly
80
+ if response.status_code != 200:
81
+ print(f" ❌ API Status: {response.status_code} - {response.text}")
82
+ return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", f"Model unavailable ({response.status_code})"
83
+
84
+ # Parse Response
85
+ result = response.json()
86
+
87
+ if isinstance(result, list) and len(result) > 0:
88
+ raw_text = result[0].get('generated_text', '')
89
+ elif isinstance(result, dict):
90
+ raw_text = result.get('generated_text', '')
91
+ else:
92
+ raw_text = str(result)
93
+
94
+ # Clean JSON
95
  sql_query = ""
96
  message = "Here is the data."
97
  explanation = "Query generated successfully."
98
 
99
  try:
100
+ # Remove Markdown and extra text
101
+ if "[/INST]" in raw_text: raw_text = raw_text.split("[/INST]")[-1]
102
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
103
+
104
  data = json.loads(clean_json)
105
  sql_query = data.get("sql", "")
106
  message = data.get("message", message)
107
  explanation = data.get("explanation", explanation)
108
  except:
109
+ # Fallback: Regex to find SQL if JSON parsing fails
110
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
111
  if match: sql_query = match.group(1)
112
 
113
+ # Final Cleanup
114
  sql_query = sql_query.strip().replace("\n", " ")
115
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
116
 
117
+ # 🛡️ Final Validation (Allows SELECT or WITH)
118
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
 
 
119
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
120
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
121
  return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."