LalitChaudhari3 committed on
Commit
91ed273
·
verified ·
1 Parent(s): 9b37324

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +120 -75
src/sql_generator.py CHANGED
@@ -1,76 +1,121 @@
1
- import os
2
- import re
3
- import json
4
- from huggingface_hub import InferenceClient
5
- from dotenv import load_dotenv
6
-
7
class SQLGenerator:
    """Turn natural-language questions into read-only SQL with a hosted LLM.

    The model is queried through ``huggingface_hub.InferenceClient``. Its reply
    is expected to be a JSON object with ``sql``, ``message`` and
    ``explanation`` keys; the SQL is accepted only if it is a single
    SELECT/WITH statement without data-modifying keywords.

    ``generate_sql`` always returns a ``(sql, explanation, message)`` tuple.
    Problems in the request/parse path are reported as a
    ``SELECT 'Error: ...' as status`` query rather than raised.
    """

    # Statement keywords that change data, schema or permissions.
    FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")

    # Whole-word match, so "updated_at" or "dropdown" do not trip the filter.
    _FORBIDDEN_RE = re.compile(
        r"\b(?:" + "|".join(FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE
    )
    # Group 1: quoted literal / identifier (kept). Group 2: comment (removed).
    # Quotes are tokenised first so that "--" inside a string is not a comment.
    _SQL_TOKEN_RE = re.compile(
        r"('(?:[^']|'')*'"
        r'|"(?:[^"]|"")*"'
        r"|`[^`]*`)"
        r"|(/\*.*?\*/|--[^\n]*)",
        re.DOTALL,
    )
    _READ_ONLY_START_RE = re.compile(r"(?:SELECT|WITH)\b", re.IGNORECASE)
    _FENCE_RE = re.compile(r"```json|```")
    # Fallback extraction: a CTE ("WITH name AS (") or a plain SELECT, up to
    # the first semicolon.
    _SQL_FALLBACK_RE = re.compile(
        r"\b((?:WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\(|SELECT\b)"
        r"[\s\S]+?;)",
        re.IGNORECASE,
    )

    def __init__(self, api_key=None):
        """Create the inference client.

        Args:
            api_key: Optional API token. When omitted, the
                ``HUGGINGFACEHUB_API_TOKEN`` environment variable (loaded via
                ``load_dotenv``) is used.
        """
        load_dotenv()
        # Honour an explicitly passed key before falling back to the env var.
        self.api_token = api_key or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
        self.client = InferenceClient(token=self.api_token, timeout=25.0)

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of suggested next actions (inputs are unused)."""
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Generate a read-only SQL query answering ``question``.

        Args:
            question: The user's natural-language request.
            context: Schema description inserted into the system prompt.
            history: Optional list of dicts with ``user`` and ``sql`` keys;
                only the last two entries are sent to the model.

        Returns:
            A ``(sql, explanation, message)`` tuple. On refusal or failure the
            SQL is a ``SELECT 'Error: ...' as status`` placeholder.
        """
        if history is None:
            history = []

        if self._contains_forbidden(question):
            return (
                "SELECT 'Error: Blocked by Safety Layer' as status",
                "Safety Alert",
                "I cannot execute commands that modify data.",
            )

        history_text = ""
        if history:
            turns = [f"User: {h['user']}\nSQL: {h['sql']}" for h in history[-2:]]
            history_text = "PREVIOUS CONVERSATION:\n" + "\n".join(turns)

        system_prompt = (
            "You are an elite SQL Expert.\n"
            "Schema:\n"
            f"{context}\n"
            "\n"
            f"{history_text}\n"
            "\n"
            "Rules:\n"
            '1. Output JSON: { "sql": "SELECT ...", "message": "Friendly text", '
            '"explanation": "Brief summary" }\n'
            "2. Query MUST be Read-Only (SELECT).\n"
            "3. Do not include markdown formatting like ```json.\n"
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]

        try:
            print(" ⚡ Generating SQL...")
            response = self.client.chat_completion(
                messages=messages,
                model=self.repo_id,
                max_tokens=1024,
                temperature=0.1,
            )
            # The content may be None; treat that as an empty reply.
            raw_text = response.choices[0].message.content or ""

            sql_query, message, explanation = self._parse_model_output(raw_text)
            sql_query = self._normalize_sql(sql_query)

            if not self._is_read_only(sql_query):
                print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
                return (
                    "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status",
                    "Safety Error",
                    "I can only perform read-only operations.",
                )

            return sql_query, explanation, message

        except Exception as e:
            # Boundary handler: callers expect a tuple, never an exception.
            print(f" ❌ Model Error: {e}")
            # Quotes are removed so the text can sit inside a SQL string literal.
            safe_e = str(e).replace("'", "").replace('"', "")
            return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."

    @classmethod
    def _contains_forbidden(cls, text):
        """Return True if ``text`` contains a forbidden keyword as a whole word."""
        return cls._FORBIDDEN_RE.search(text) is not None

    @staticmethod
    def _load_json_object(text):
        """Return the JSON object found in ``text``, or None if there is none."""
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            # The model sometimes wraps the object in extra prose.
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if isinstance(data, dict):
                return data
        return None

    @classmethod
    def _parse_model_output(cls, raw_text):
        """Extract ``(sql, message, explanation)`` from the raw model reply."""
        message = "Here is the data."
        explanation = "Query generated successfully."

        data = cls._load_json_object(cls._FENCE_RE.sub("", raw_text).strip())
        if data is not None:
            sql_query = data.get("sql", "")
            if not isinstance(sql_query, str):
                sql_query = ""
            return (
                sql_query,
                str(data.get("message") or message),
                str(data.get("explanation") or explanation),
            )

        # No usable JSON: look for a bare SQL statement in the text.
        match = cls._SQL_FALLBACK_RE.search(raw_text)
        return (match.group(1) if match else ""), message, explanation

    @classmethod
    def _normalize_sql(cls, sql_query):
        """Strip comments, flatten newlines and ensure a trailing semicolon."""
        # Comments must go BEFORE newlines are flattened; otherwise a "--"
        # comment would swallow the rest of the single-line query.
        without_comments = cls._SQL_TOKEN_RE.sub(lambda m: m.group(1) or " ", sql_query)
        flat = without_comments.strip().replace("\n", " ")
        if flat and not flat.endswith(";"):
            flat += ";"
        return flat

    @classmethod
    def _is_read_only(cls, sql_query):
        """Return True only for a single SELECT/WITH statement without write keywords."""
        # Blank out quoted text so keywords or semicolons inside literals are ignored.
        skeleton = cls._SQL_TOKEN_RE.sub(lambda m: "''" if m.group(1) else " ", sql_query)
        skeleton = skeleton.strip().rstrip("; \t\r\n")
        if not cls._READ_ONLY_START_RE.match(skeleton):
            return False
        if ";" in skeleton:
            # A remaining semicolon means stacked statements.
            return False
        return cls._FORBIDDEN_RE.search(skeleton) is None
 
1
+ import os
2
+ import requests
3
+ import re
4
+ import json
5
+ from dotenv import load_dotenv
6
+
7
class SQLGenerator:
    """Turn natural-language questions into read-only SQL with a hosted LLM.

    A ChatML-formatted prompt is POSTed directly to the Hugging Face Inference
    API with ``requests``. The reply is expected to be a JSON object with
    ``sql``, ``message`` and ``explanation`` keys; the SQL is accepted only if
    it is a single SELECT/WITH statement without data-modifying keywords.

    ``generate_sql`` always returns a ``(sql, explanation, message)`` tuple.
    Problems in the request/parse path are reported as a
    ``SELECT 'Error: ...' as status`` query rather than raised.
    """

    # Statement keywords that change data, schema or permissions.
    FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")

    # Whole-word match, so "updated_at" or "dropdown" do not trip the filter.
    _FORBIDDEN_RE = re.compile(
        r"\b(?:" + "|".join(FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE
    )
    # Group 1: quoted literal / identifier (kept). Group 2: comment (removed).
    # Quotes are tokenised first so that "--" inside a string is not a comment.
    _SQL_TOKEN_RE = re.compile(
        r"('(?:[^']|'')*'"
        r'|"(?:[^"]|"")*"'
        r"|`[^`]*`)"
        r"|(/\*.*?\*/|--[^\n]*)",
        re.DOTALL,
    )
    _READ_ONLY_START_RE = re.compile(r"(?:SELECT|WITH)\b", re.IGNORECASE)
    _FENCE_RE = re.compile(r"```json|```")
    # Fallback extraction: a CTE ("WITH name AS (") or a plain SELECT, up to
    # the first semicolon.
    _SQL_FALLBACK_RE = re.compile(
        r"\b((?:WITH\s+(?:RECURSIVE\s+)?\w+\s*(?:\([^)]*\)\s*)?AS\s*\(|SELECT\b)"
        r"[\s\S]+?;)",
        re.IGNORECASE,
    )

    def __init__(self, api_key=None):
        """Resolve the API key and the model endpoint.

        Args:
            api_key: Optional API token. When omitted, ``HF_API_KEY`` and then
                ``HUGGINGFACEHUB_API_TOKEN`` are read from the environment
                (after ``load_dotenv``).
        """
        load_dotenv()

        # Explicit argument first, then the standard env var, then the legacy one.
        self.api_key = (
            api_key
            or os.getenv("HF_API_KEY")
            or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        )

        self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
        self.api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of suggested next actions (inputs are unused)."""
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Generate a read-only SQL query answering ``question``.

        Args:
            question: The user's natural-language request.
            context: Schema description inserted into the system prompt.
            history: Optional list of dicts with ``user`` and ``sql`` keys;
                only the last two entries are sent to the model.

        Returns:
            A ``(sql, explanation, message)`` tuple. On refusal or failure the
            SQL is a ``SELECT 'Error: ...' as status`` placeholder.
        """
        if history is None:
            history = []

        # Stop early when no key is configured: the request could not succeed.
        if not self.api_key:
            return (
                "SELECT 'Error: HF_API_KEY Missing' as status",
                "Configuration Error",
                "Please add HF_API_KEY to your Space Secrets.",
            )

        # Safety layer: refuse requests that ask for a data-modifying command.
        if self._contains_forbidden(question):
            return (
                "SELECT 'Error: Blocked by Safety Layer' as status",
                "Safety Alert",
                "I cannot execute commands that modify data.",
            )

        system_prompt = self._build_system_prompt(context, history)

        payload = {
            "inputs": (
                f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
                f"<|im_start|>user\n{question}<|im_end|>\n"
                "<|im_start|>assistant\n"
            ),
            "parameters": {
                "max_new_tokens": 1024,
                "temperature": 0.1,  # low temperature keeps the SQL precise
                "return_full_text": False,
            },
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        try:
            print(" Generating SQL via Direct API...")
            response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)

            if response.status_code != 200:
                print(f" ❌ API Status: {response.status_code} - {response.text}")
                return (
                    f"SELECT 'Error: API returned {response.status_code}' as status",
                    "API Error",
                    "The AI model is currently unavailable.",
                )

            raw_text = self._extract_generated_text(response.json())
            sql_query, message, explanation = self._parse_model_output(raw_text)
            sql_query = self._normalize_sql(sql_query)

            if not self._is_read_only(sql_query):
                print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
                return (
                    "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status",
                    "Safety Error",
                    "I can only perform read-only operations.",
                )

            return sql_query, explanation, message

        except Exception as e:
            # Boundary handler: callers expect a tuple, never an exception.
            print(f" ❌ Model Error: {e}")
            # Quotes are removed so the text can sit inside a SQL string literal.
            safe_e = str(e).replace("'", "").replace('"', "")
            return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."

    @staticmethod
    def _build_system_prompt(context, history):
        """Assemble the system prompt from the schema and the last two turns."""
        history_text = ""
        if history:
            turns = [f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]]
            history_text = "PREVIOUS CONVERSATION:\n" + "\n".join(turns)

        return (
            "You are an elite SQL Expert.\n"
            "Schema:\n"
            f"{context}\n"
            "\n"
            f"{history_text}\n"
            "\n"
            "Rules:\n"
            '1. Output JSON: { "sql": "SELECT ...", "message": "Friendly text", '
            '"explanation": "Brief summary" }\n'
            "2. Query MUST be Read-Only (SELECT).\n"
            "3. Do not include markdown formatting like ```json.\n"
        )

    @classmethod
    def _contains_forbidden(cls, text):
        """Return True if ``text`` contains a forbidden keyword as a whole word."""
        return cls._FORBIDDEN_RE.search(text) is not None

    @staticmethod
    def _extract_generated_text(result):
        """Pull ``generated_text`` out of a list- or dict-shaped API response."""
        if isinstance(result, list) and result:
            result = result[0]
        if isinstance(result, dict):
            return str(result.get("generated_text", "") or "")
        return str(result)

    @staticmethod
    def _load_json_object(text):
        """Return the JSON object found in ``text``, or None if there is none."""
        candidates = [text]
        start, end = text.find("{"), text.rfind("}")
        if 0 <= start < end:
            # The model sometimes wraps the object in extra prose.
            candidates.append(text[start:end + 1])
        for candidate in candidates:
            try:
                data = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if isinstance(data, dict):
                return data
        return None

    @classmethod
    def _parse_model_output(cls, raw_text):
        """Extract ``(sql, message, explanation)`` from the raw model reply."""
        message = "Here is the data."
        explanation = "Query generated successfully."

        data = cls._load_json_object(cls._FENCE_RE.sub("", raw_text).strip())
        if data is not None:
            sql_query = data.get("sql", "")
            if not isinstance(sql_query, str):
                sql_query = ""
            return (
                sql_query,
                str(data.get("message") or message),
                str(data.get("explanation") or explanation),
            )

        # No usable JSON: look for a bare SQL statement in the text.
        match = cls._SQL_FALLBACK_RE.search(raw_text)
        return (match.group(1) if match else ""), message, explanation

    @classmethod
    def _normalize_sql(cls, sql_query):
        """Strip comments, flatten newlines and ensure a trailing semicolon."""
        # Comments must go BEFORE newlines are flattened; otherwise a "--"
        # comment would swallow the rest of the single-line query.
        without_comments = cls._SQL_TOKEN_RE.sub(lambda m: m.group(1) or " ", sql_query)
        flat = without_comments.strip().replace("\n", " ")
        if flat and not flat.endswith(";"):
            flat += ";"
        return flat

    @classmethod
    def _is_read_only(cls, sql_query):
        """Return True only for a single SELECT/WITH statement without write keywords."""
        # Blank out quoted text so keywords or semicolons inside literals are ignored.
        skeleton = cls._SQL_TOKEN_RE.sub(lambda m: "''" if m.group(1) else " ", sql_query)
        skeleton = skeleton.strip().rstrip("; \t\r\n")
        if not cls._READ_ONLY_START_RE.match(skeleton):
            return False
        if ";" in skeleton:
            # A remaining semicolon means stacked statements.
            return False
        return cls._FORBIDDEN_RE.search(skeleton) is None