LalitChaudhari3 committed on
Commit
3656fbb
·
verified ·
1 Parent(s): c741092

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +93 -87
src/sql_generator.py CHANGED
@@ -8,17 +8,30 @@ class SQLGenerator:
8
  def __init__(self):
9
  load_dotenv()
10
 
11
- # 1. GET KEY (Cleaned)
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- # 2. USE QWEN 2.5 CODER (7B)
16
- # This is the Best "Ungated" Free Model right now.
17
- # It does not require a license agreement click, unlike Mistral/Llama.
18
- self.repo_id = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
 
 
 
 
 
 
 
 
 
19
 
20
- # 3. ROUTER URL
21
- self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
 
 
 
22
 
23
  def generate_followup_questions(self, question, sql_query):
24
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -29,96 +42,89 @@ class SQLGenerator:
29
  if not self.api_key:
30
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
31
 
32
- # 🛡️ Safety Layer
33
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
34
  if any(word in question.upper() for word in forbidden):
35
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
36
 
37
- # Format History
38
- history_text = ""
39
- if history:
40
- history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
41
-
42
- # System Prompt (Qwen ChatML Format)
43
- system_prompt = f"""You are an elite SQL Expert.
44
- DATABASE SCHEMA:
45
  {context}
46
 
47
- {history_text}
48
-
49
- RULES:
50
- 1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
51
- 2. The SQL query MUST be Read-Only (SELECT).
52
- 3. Do not use markdown formatting.
53
- """
54
 
55
- # Payload (Qwen Specific Format)
56
  payload = {
57
- "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
58
- "parameters": {
59
- "max_new_tokens": 1024,
60
- "temperature": 0.1,
61
- "return_full_text": False
62
- }
63
  }
64
 
65
- headers = {
66
- "Authorization": f"Bearer {self.api_key}",
67
- "Content-Type": "application/json"
68
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  try:
71
- print(f" ⚡ Generating SQL using {self.repo_id}...")
72
-
73
- # 🚀 DIRECT REQUEST
74
- response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
75
-
76
- if response.status_code != 200:
77
- print(f" ❌ API FAILURE: {response.status_code}")
78
- print(f" ❌ RESPONSE: {response.text}")
79
- return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", f"Model Error: {response.status_code}"
80
-
81
- result = response.json()
82
-
83
- # Helper to extract text
84
- if isinstance(result, list) and len(result) > 0:
85
- raw_text = result[0].get('generated_text', '')
86
- elif isinstance(result, dict):
87
- raw_text = result.get('generated_text', '')
88
  else:
89
- raw_text = str(result)
90
-
91
- # JSON Parsing
92
- sql_query = ""
93
- message = "Here is the data."
94
- explanation = "Query generated successfully."
95
-
96
- try:
97
- clean_json = re.sub(r"```json|```", "", raw_text).strip()
98
- json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
99
- if json_match:
100
- data = json.loads(json_match.group(0))
101
- sql_query = data.get("sql", "")
102
- message = data.get("message", message)
103
- explanation = data.get("explanation", explanation)
104
- else:
105
- # Regex Fallback
106
- match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
107
- if match: sql_query = match.group(1)
108
- except:
109
- match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
110
- if match: sql_query = match.group(1)
111
-
112
- sql_query = sql_query.strip().replace("\n", " ")
113
- if sql_query and not sql_query.endswith(";"): sql_query += ";"
114
-
115
- clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
116
- if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
117
- return "SELECT 'Error: Invalid Query Type' as status", "Safety Error", "I can only perform read-only operations."
118
-
119
- return sql_query, explanation, message
120
-
121
- except Exception as e:
122
- print(f" ❌ SYSTEM EXCEPTION: {e}")
123
- safe_e = str(e).replace("'", "").replace('"', "")
124
- return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."
 
8
  def __init__(self):
9
  load_dotenv()
10
 
11
+ # 1. AUTHENTICATION (With Debugging)
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
+ if not self.api_key:
16
+ print(" FATAL: API Key missing.")
17
+ else:
18
+ print(f" ✅ API Key loaded: {self.api_key[:5]}...")
19
+
20
+ # 2. THE "SHOTGUN" MODEL LIST
21
+ # We try these 5 models in order. One WILL work.
22
+ self.models = [
23
+ "microsoft/Phi-3-mini-4k-instruct", # High availability, very fast
24
+ "google/gemma-1.1-7b-it", # Google's open model (very stable)
25
+ "mistralai/Mistral-7B-Instruct-v0.3", # Standard free tier workhorse
26
+ "HuggingFaceH4/zephyr-7b-beta", # Reliable fallback
27
+ "Qwen/Qwen2.5-Coder-7B-Instruct" # Excellent coder (if online)
28
+ ]
29
 
30
+ # 3. ENDPOINTS (Router + Legacy)
31
+ self.endpoints = [
32
+ "https://router.huggingface.co/models/",
33
+ "https://api-inference.huggingface.co/models/"
34
+ ]
35
 
36
  def generate_followup_questions(self, question, sql_query):
37
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
42
  if not self.api_key:
43
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
44
 
45
+ # 🛡️ Safety
46
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
47
  if any(word in question.upper() for word in forbidden):
48
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
49
 
50
+ # Simple Prompt
51
+ system_prompt = f"""You are an SQL Expert.
52
+ Schema:
 
 
 
 
 
53
  {context}
54
 
55
+ Rules:
56
+ 1. Output valid JSON: {{ "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }}
57
+ 2. Read-only SELECT queries only.
58
+ 3. No markdown.
59
+
60
+ Question: {question}"""
 
61
 
 
62
  payload = {
63
+ "inputs": f"<|user|>\n{system_prompt}\n<|end|>\n<|assistant|>\n",
64
+ "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}
 
 
 
 
65
  }
66
 
67
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
68
+
69
+ # 🔄 ULTRA-ROBUST RETRY LOOP
70
+ errors = []
71
+
72
+ for model in self.models:
73
+ for base_url in self.endpoints:
74
+ api_url = f"{base_url}{model}"
75
+ try:
76
+ print(f" ⚡ Attempting: {model}...")
77
+ response = requests.post(api_url, headers=headers, json=payload, timeout=15)
78
+
79
+ if response.status_code == 200:
80
+ print(f" ✅ SUCCESS with {model}!")
81
+ return self._process_response(response.json())
82
+
83
+ # Log failure and continue immediately
84
+ print(f" ❌ Failed ({response.status_code})")
85
+ errors.append(f"{model}: {response.status_code}")
86
+
87
+ except Exception as e:
88
+ print(f" ⚠️ Connection Error: {e}")
89
+ errors.append(f"{model}: Error")
90
+
91
+ # If we get here, literally everything failed (Rare)
92
+ return f"SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"
93
+
94
+ def _process_response(self, result):
95
+ if isinstance(result, list) and len(result) > 0:
96
+ raw_text = result[0].get('generated_text', '')
97
+ elif isinstance(result, dict):
98
+ raw_text = result.get('generated_text', '')
99
+ else:
100
+ raw_text = str(result)
101
+
102
+ sql_query = ""
103
+ message = "Here is the data."
104
+ explanation = "Query generated successfully."
105
 
106
  try:
107
+ # Clean and Extract
108
+ clean_json = re.sub(r"```json|```", "", raw_text).strip()
109
+ # Regex to find JSON
110
+ json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
111
+ if json_match:
112
+ data = json.loads(json_match.group(0))
113
+ sql_query = data.get("sql", "")
114
+ message = data.get("message", message)
115
+ explanation = data.get("explanation", explanation)
 
 
 
 
 
 
 
 
116
  else:
117
+ match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
118
+ if match: sql_query = match.group(1)
119
+ except:
120
+ match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
121
+ if match: sql_query = match.group(1)
122
+
123
+ sql_query = sql_query.strip().replace("\n", " ")
124
+ if sql_query and not sql_query.endswith(";"): sql_query += ";"
125
+
126
+ # Fallback for empty SQL
127
+ if not sql_query:
128
+ sql_query = "SELECT 'Error: AI generated empty query' as status"
129
+
130
+ return sql_query, explanation, message