LalitChaudhari3 commited on
Commit
304a74a
·
verified ·
1 Parent(s): 5d48e70

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +64 -84
src/sql_generator.py CHANGED
@@ -12,18 +12,13 @@ class SQLGenerator:
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- # 2. MODEL LIST (The "Shotgun" Strategy)
16
- # We try 4 different models. If one fails (404/429), it instantly tries the next.
17
- self.models = [
18
- "mistralai/Mistral-7B-Instruct-v0.3", # Most popular free model
19
- "google/gemma-1.1-7b-it", # Highly reliable backup
20
- "microsoft/Phi-3-mini-4k-instruct", # Very fast, almost always online
21
- "HuggingFaceH4/zephyr-7b-beta" # Old reliable (last resort)
22
- ]
23
 
24
- # 3. BASE URLs (Try Router first, then Legacy)
25
- self.router_url = "https://router.huggingface.co/models/"
26
- self.legacy_url = "https://api-inference.huggingface.co/models/"
27
 
28
  def generate_followup_questions(self, question, sql_query):
29
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -39,13 +34,13 @@ class SQLGenerator:
39
  if any(word in question.upper() for word in forbidden):
40
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
41
 
 
42
  history_text = ""
43
  if history:
44
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
45
 
46
- # System Prompt
47
- system_prompt = f"""<|system|>
48
- You are an elite SQL Expert.
49
  DATABASE SCHEMA:
50
  {context}
51
 
@@ -55,14 +50,13 @@ class SQLGenerator:
55
  1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
56
  2. The SQL query MUST be Read-Only (SELECT).
57
  3. Do not use markdown formatting.
 
58
 
59
- QUESTION: {question} </s>
60
- <|assistant|>"""
61
-
62
  payload = {
63
- "inputs": system_prompt,
64
  "parameters": {
65
- "max_new_tokens": 512,
66
  "temperature": 0.1,
67
  "return_full_text": False
68
  }
@@ -73,72 +67,58 @@ class SQLGenerator:
73
  "Content-Type": "application/json"
74
  }
75
 
76
- # 🔄 ROBUST RETRY LOOP
77
- last_error = ""
78
-
79
- for model in self.models:
80
- # Try Router URL first, then Legacy URL
81
- urls_to_try = [self.router_url + model, self.legacy_url + model]
82
 
83
- for api_url in urls_to_try:
84
- try:
85
- print(f" ⚡ Trying {model} at {api_url}...")
86
-
87
- response = requests.post(api_url, headers=headers, json=payload, timeout=20)
88
-
89
- if response.status_code == 200:
90
- print(f" ✅ SUCCESS with {model}!")
91
- return self._process_response(response.json())
92
-
93
- # If 404/410/500, we log and continue to next
94
- print(f" ❌ Failed ({response.status_code}). Trying next...")
95
- last_error = f"{response.status_code}: {response.text}"
96
-
97
- except Exception as e:
98
- print(f" ⚠️ Connection Error: {e}")
99
- last_error = str(e)
100
-
101
- # If ALL models fail
102
- return f"SELECT 'Error: {last_error}' as status", "System Error", "All AI models are currently unavailable. Check your HF_API_KEY."
103
-
104
- def _process_response(self, result):
105
- """Helper to parse the AI response cleanly"""
106
- if isinstance(result, list) and len(result) > 0:
107
- raw_text = result[0].get('generated_text', '')
108
- elif isinstance(result, dict):
109
- raw_text = result.get('generated_text', '')
110
- else:
111
- raw_text = str(result)
112
-
113
- sql_query = ""
114
- message = "Here is the data."
115
- explanation = "Query generated successfully."
116
 
117
- try:
118
- # Clean Markdown
119
- clean_json = re.sub(r"```json|```", "", raw_text).strip()
120
- # Attempt to find JSON object
121
- json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
122
- if json_match:
123
- data = json.loads(json_match.group(0))
124
- sql_query = data.get("sql", "")
125
- message = data.get("message", message)
126
- explanation = data.get("explanation", explanation)
127
  else:
128
- # Fallback regex
129
- match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
130
- if match: sql_query = match.group(1)
131
- except:
132
- match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
133
- if match: sql_query = match.group(1)
134
-
135
- # Final Cleanup
136
- sql_query = sql_query.strip().replace("\n", " ")
137
- if sql_query and not sql_query.endswith(";"): sql_query += ";"
138
-
139
- clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
140
- if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
141
- # Last resort fallback for clean "SELECT"
142
- return "SELECT 'Error: Invalid Query Type' as status", "Safety Error", "I can only perform read-only operations."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- return sql_query, explanation, message
 
 
 
 
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
+ # 2. USE QWEN 2.5 CODER (7B)
16
+ # This is the Best "Ungated" Free Model right now.
17
+ # It does not require a license agreement click, unlike Mistral/Llama.
18
+ self.repo_id = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
 
 
 
19
 
20
+ # 3. ROUTER URL
21
+ self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
 
22
 
23
  def generate_followup_questions(self, question, sql_query):
24
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
34
  if any(word in question.upper() for word in forbidden):
35
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
36
 
37
+ # Format History
38
  history_text = ""
39
  if history:
40
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
41
 
42
+ # System Prompt (Qwen ChatML Format)
43
+ system_prompt = f"""You are an elite SQL Expert.
 
44
  DATABASE SCHEMA:
45
  {context}
46
 
 
50
  1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
51
  2. The SQL query MUST be Read-Only (SELECT).
52
  3. Do not use markdown formatting.
53
+ """
54
 
55
+ # Payload (Qwen Specific Format)
 
 
56
  payload = {
57
+ "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
58
  "parameters": {
59
+ "max_new_tokens": 1024,
60
  "temperature": 0.1,
61
  "return_full_text": False
62
  }
 
67
  "Content-Type": "application/json"
68
  }
69
 
70
+ try:
71
+ print(f" ⚡ Generating SQL using {self.repo_id}...")
 
 
 
 
72
 
73
+ # 🚀 DIRECT REQUEST
74
+ response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
75
+
76
+ if response.status_code != 200:
77
+ print(f" ❌ API FAILURE: {response.status_code}")
78
+ print(f" ❌ RESPONSE: {response.text}")
79
+ return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", f"Model Error: {response.status_code}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ result = response.json()
82
+
83
+ # Helper to extract text
84
+ if isinstance(result, list) and len(result) > 0:
85
+ raw_text = result[0].get('generated_text', '')
86
+ elif isinstance(result, dict):
87
+ raw_text = result.get('generated_text', '')
 
 
 
88
  else:
89
+ raw_text = str(result)
90
+
91
+ # JSON Parsing
92
+ sql_query = ""
93
+ message = "Here is the data."
94
+ explanation = "Query generated successfully."
95
+
96
+ try:
97
+ clean_json = re.sub(r"```json|```", "", raw_text).strip()
98
+ json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
99
+ if json_match:
100
+ data = json.loads(json_match.group(0))
101
+ sql_query = data.get("sql", "")
102
+ message = data.get("message", message)
103
+ explanation = data.get("explanation", explanation)
104
+ else:
105
+ # Regex Fallback
106
+ match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
107
+ if match: sql_query = match.group(1)
108
+ except:
109
+ match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
110
+ if match: sql_query = match.group(1)
111
+
112
+ sql_query = sql_query.strip().replace("\n", " ")
113
+ if sql_query and not sql_query.endswith(";"): sql_query += ";"
114
+
115
+ clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
116
+ if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
117
+ return "SELECT 'Error: Invalid Query Type' as status", "Safety Error", "I can only perform read-only operations."
118
+
119
+ return sql_query, explanation, message
120
 
121
+ except Exception as e:
122
+ print(f" ❌ SYSTEM EXCEPTION: {e}")
123
+ safe_e = str(e).replace("'", "").replace('"', "")
124
+ return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."