LalitChaudhari3 commited on
Commit
ef0f8e7
·
verified ·
1 Parent(s): d3786b9

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +22 -72
src/sql_generator.py CHANGED
@@ -1,24 +1,15 @@
1
  import os
2
- import requests
3
  import re
4
  import json
 
5
  from dotenv import load_dotenv
6
 
7
  class SQLGenerator:
8
- def __init__(self):
9
  load_dotenv()
10
-
11
- # 1. ROBUSTLY FETCH & CLEAN THE KEY
12
- raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
- self.api_key = raw_key.strip() if raw_key else None
14
-
15
- # 2. USE A MODEL THAT IS DEFINITELY FREE & ONLINE
16
- # The 32B model caused the 404. We switch to Mistral-7B-Instruct-v0.3
17
- # It is excellent for SQL and 100% supported on the free tier.
18
- self.repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
19
-
20
- # 3. USE THE NEW ROUTER URL
21
- self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
22
 
23
  def generate_followup_questions(self, question, sql_query):
24
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -26,94 +17,53 @@ class SQLGenerator:
26
  def generate_sql(self, question, context, history=None):
27
  if history is None: history = []
28
 
29
- # 🚨 ERROR CHECK
30
- if not self.api_key:
31
- return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
32
-
33
- # 🛡️ Safety Layer
34
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
35
  if any(word in question.upper() for word in forbidden):
36
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
37
 
38
- # Format History
39
  history_text = ""
40
  if history:
41
- history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
42
-
43
- # System Prompt (Optimized for Mistral)
44
- system_prompt = f"""<s>[INST] You are an expert SQL Assistant.
45
- DATABASE SCHEMA:
46
- {context}
47
 
48
- {history_text}
 
 
49
 
50
- RULES:
51
- 1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
52
- 2. The SQL query MUST be Read-Only (SELECT).
53
- 3. Do not use markdown formatting.
54
-
55
- QUESTION: {question} [/INST]"""
56
-
57
- # Payload
58
- payload = {
59
- "inputs": system_prompt,
60
- "parameters": {
61
- "max_new_tokens": 1024,
62
- "temperature": 0.1,
63
- "return_full_text": False
64
- }
65
- }
66
 
67
- headers = {
68
- "Authorization": f"Bearer {self.api_key}",
69
- "Content-Type": "application/json"
70
- }
 
 
71
 
72
  try:
73
- print(f" ⚡ Generating SQL via Direct API ({self.repo_id})...")
74
-
75
- # 🚀 DIRECT REQUEST
76
- response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
77
-
78
- if response.status_code != 200:
79
- print(f" ❌ API Status: {response.status_code} - {response.text}")
80
- return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", f"Model Error: {response.status_code}"
81
-
82
- # Parse Response
83
- result = response.json()
84
 
85
- if isinstance(result, list) and len(result) > 0:
86
- raw_text = result[0].get('generated_text', '')
87
- elif isinstance(result, dict):
88
- raw_text = result.get('generated_text', '')
89
- else:
90
- raw_text = str(result)
91
-
92
- # Clean JSON
93
  sql_query = ""
94
  message = "Here is the data."
95
  explanation = "Query generated successfully."
96
 
97
  try:
98
- # Mistral sometimes keeps the prompt. Split by [/INST] if present.
99
- if "[/INST]" in raw_text:
100
- raw_text = raw_text.split("[/INST]")[-1]
101
-
102
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
103
  data = json.loads(clean_json)
104
  sql_query = data.get("sql", "")
105
  message = data.get("message", message)
106
  explanation = data.get("explanation", explanation)
107
  except:
108
- # Fallback: Regex to find SQL if JSON parsing fails
109
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
110
  if match: sql_query = match.group(1)
111
 
112
- # Final Cleanup
113
  sql_query = sql_query.strip().replace("\n", " ")
114
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
115
 
 
116
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
 
 
117
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
118
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
119
  return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."
 
1
  import os
 
2
  import re
3
  import json
4
+ from huggingface_hub import InferenceClient
5
  from dotenv import load_dotenv
6
 
7
  class SQLGenerator:
8
+ def __init__(self, api_key=None):
9
  load_dotenv()
10
+ self.api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
11
+ self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
12
+ self.client = InferenceClient(token=self.api_token, timeout=25.0)
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_followup_questions(self, question, sql_query):
15
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
17
  def generate_sql(self, question, context, history=None):
18
  if history is None: history = []
19
 
 
 
 
 
 
20
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
21
  if any(word in question.upper() for word in forbidden):
22
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
23
 
 
24
  history_text = ""
25
  if history:
26
+ history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h['user']}\nSQL: {h['sql']}" for h in history[-2:]])
 
 
 
 
 
27
 
28
+ system_prompt = f"""You are an elite SQL Expert.
29
+ Schema:
30
+ {context}
31
 
32
+ {history_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ Rules:
35
+ 1. Output JSON: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
36
+ 2. Query MUST be Read-Only (SELECT).
37
+ 3. Do not include markdown formatting like ```json.
38
+ """
39
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]
40
 
41
  try:
42
+ print(f" ⚡ Generating SQL...")
43
+ response = self.client.chat_completion(messages=messages, model=self.repo_id, max_tokens=1024, temperature=0.1)
44
+ raw_text = response.choices[0].message.content
 
 
 
 
 
 
 
 
45
 
 
 
 
 
 
 
 
 
46
  sql_query = ""
47
  message = "Here is the data."
48
  explanation = "Query generated successfully."
49
 
50
  try:
 
 
 
 
51
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
52
  data = json.loads(clean_json)
53
  sql_query = data.get("sql", "")
54
  message = data.get("message", message)
55
  explanation = data.get("explanation", explanation)
56
  except:
 
57
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
58
  if match: sql_query = match.group(1)
59
 
 
60
  sql_query = sql_query.strip().replace("\n", " ")
61
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
62
 
63
+ # ✅ FIX: Strip comments and whitespace before validation
64
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
65
+
66
+ # ✅ FIX: Allow SELECT or WITH clauses
67
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
68
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
69
  return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."