LalitChaudhari3 commited on
Commit
627c842
Β·
verified Β·
1 Parent(s): 3656fbb

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +29 -45
src/sql_generator.py CHANGED
@@ -8,30 +8,21 @@ class SQLGenerator:
8
  def __init__(self):
9
  load_dotenv()
10
 
11
- # 1. AUTHENTICATION (With Debugging)
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- if not self.api_key:
16
- print(" ❌ FATAL: API Key missing.")
17
- else:
18
- print(f" βœ… API Key loaded: {self.api_key[:5]}...")
19
-
20
- # 2. THE "SHOTGUN" MODEL LIST
21
- # We try these 5 models in order. One WILL work.
22
  self.models = [
23
- "microsoft/Phi-3-mini-4k-instruct", # High availability, very fast
24
- "google/gemma-1.1-7b-it", # Google's open model (very stable)
25
- "mistralai/Mistral-7B-Instruct-v0.3", # Standard free tier workhorse
26
- "HuggingFaceH4/zephyr-7b-beta", # Reliable fallback
27
- "Qwen/Qwen2.5-Coder-7B-Instruct" # Excellent coder (if online)
28
  ]
29
 
30
- # 3. ENDPOINTS (Router + Legacy)
31
- self.endpoints = [
32
- "https://router.huggingface.co/models/",
33
- "https://api-inference.huggingface.co/models/"
34
- ]
35
 
36
  def generate_followup_questions(self, question, sql_query):
37
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -42,12 +33,12 @@ class SQLGenerator:
42
  if not self.api_key:
43
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
44
 
45
- # πŸ›‘οΈ Safety
46
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
47
  if any(word in question.upper() for word in forbidden):
48
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
49
 
50
- # Simple Prompt
51
  system_prompt = f"""You are an SQL Expert.
52
  Schema:
53
  {context}
@@ -60,35 +51,31 @@ class SQLGenerator:
60
  Question: {question}"""
61
 
62
  payload = {
63
- "inputs": f"<|user|>\n{system_prompt}\n<|end|>\n<|assistant|>\n",
64
  "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}
65
  }
66
 
67
  headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
68
 
69
- # πŸ”„ ULTRA-ROBUST RETRY LOOP
70
  errors = []
71
-
72
  for model in self.models:
73
- for base_url in self.endpoints:
74
- api_url = f"{base_url}{model}"
75
- try:
76
- print(f" ⚑ Attempting: {model}...")
77
- response = requests.post(api_url, headers=headers, json=payload, timeout=15)
78
-
79
- if response.status_code == 200:
80
- print(f" βœ… SUCCESS with {model}!")
81
- return self._process_response(response.json())
82
-
83
- # Log failure and continue immediately
84
- print(f" ❌ Failed ({response.status_code})")
85
- errors.append(f"{model}: {response.status_code}")
86
-
87
- except Exception as e:
88
- print(f" ⚠️ Connection Error: {e}")
89
- errors.append(f"{model}: Error")
90
 
91
- # If we get here, literally everything failed (Rare)
92
  return f"SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"
93
 
94
  def _process_response(self, result):
@@ -104,9 +91,7 @@ class SQLGenerator:
104
  explanation = "Query generated successfully."
105
 
106
  try:
107
- # Clean and Extract
108
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
109
- # Regex to find JSON
110
  json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
111
  if json_match:
112
  data = json.loads(json_match.group(0))
@@ -123,8 +108,7 @@ class SQLGenerator:
123
  sql_query = sql_query.strip().replace("\n", " ")
124
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
125
 
126
- # Fallback for empty SQL
127
  if not sql_query:
128
- sql_query = "SELECT 'Error: AI generated empty query' as status"
129
 
130
  return sql_query, explanation, message
 
8
  def __init__(self):
9
  load_dotenv()
10
 
11
+ # 1. AUTHENTICATION
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
+ # 2. ACTIVE FREE TIER MODELS (2025)
16
+ # We prioritize "Showcase" models which are kept online by sponsors.
 
 
 
 
 
17
  self.models = [
18
+ "Qwen/Qwen2.5-72B-Instruct", # Currently the #1 Free Showcase Model
19
+ "Qwen/Qwen2.5-7B-Instruct", # Reliable Backup
20
+ "microsoft/Phi-3.5-mini-instruct", # Newest Microsoft Model (Active)
21
+ "mistralai/Mistral-Nemo-Instruct-2407" # New Mistral Standard
 
22
  ]
23
 
24
+ # 3. ENDPOINTS
25
+ self.base_url = "https://router.huggingface.co/models/"
 
 
 
26
 
27
  def generate_followup_questions(self, question, sql_query):
28
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
33
  if not self.api_key:
34
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
35
 
36
+ # πŸ›‘οΈ Safety Layer
37
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
38
  if any(word in question.upper() for word in forbidden):
39
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
40
 
41
+ # Prompt
42
  system_prompt = f"""You are an SQL Expert.
43
  Schema:
44
  {context}
 
51
  Question: {question}"""
52
 
53
  payload = {
54
+ "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
55
  "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}
56
  }
57
 
58
  headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
59
 
60
+ # πŸ”„ RETRY LOOP
61
  errors = []
 
62
  for model in self.models:
63
+ api_url = f"{self.base_url}{model}"
64
+ try:
65
+ print(f" ⚑ Attempting: {model}...")
66
+ response = requests.post(api_url, headers=headers, json=payload, timeout=20)
67
+
68
+ if response.status_code == 200:
69
+ print(f" βœ… SUCCESS with {model}!")
70
+ return self._process_response(response.json())
71
+
72
+ print(f" ❌ Failed ({response.status_code})")
73
+ errors.append(f"{model}: {response.status_code}")
74
+
75
+ except Exception as e:
76
+ print(f" ⚠️ Connection Error: {e}")
77
+ errors.append(f"{model}: Error")
 
 
78
 
 
79
  return f"SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"
80
 
81
  def _process_response(self, result):
 
91
  explanation = "Query generated successfully."
92
 
93
  try:
 
94
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
 
95
  json_match = re.search(r"\{.*\}", clean_json, re.DOTALL)
96
  if json_match:
97
  data = json.loads(json_match.group(0))
 
108
  sql_query = sql_query.strip().replace("\n", " ")
109
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
110
 
 
111
  if not sql_query:
112
+ sql_query = "SELECT 'Error: Empty Query' as status"
113
 
114
  return sql_query, explanation, message