LalitChaudhari3 committed on
Commit
3ee6432
·
verified ·
1 Parent(s): 00a888e

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +59 -59
src/sql_generator.py CHANGED
@@ -1,36 +1,33 @@
1
  import os
2
- import requests
3
  import re
4
  import json
5
  from dotenv import load_dotenv
 
6
 
7
  class SQLGenerator:
8
  def __init__(self):
9
  load_dotenv()
10
 
11
- # 1. AUTHENTICATION
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- # 2. ACTIVE FREE TIER MODELS (2025)
16
- # We prioritize "Showcase" models which are kept online by sponsors.
17
- self.models = [
18
- "Qwen/Qwen2.5-72B-Instruct", # Currently the #1 Free Showcase Model
19
- "Qwen/Qwen2.5-7B-Instruct", # Reliable Backup
20
- "microsoft/Phi-3.5-mini-instruct", # Newest Microsoft Model (Active)
21
- "mistralai/Mistral-Nemo-Instruct-2407" # New Mistral Standard
22
- ]
23
-
24
- # 3. ENDPOINTS
25
- self.base_url = "https://router.huggingface.co/models/"
26
 
27
  def generate_followup_questions(self, question, sql_query):
28
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
29
 
30
  def generate_sql(self, question, context, history=None):
31
- if history is None: history = []
32
-
33
- if not self.api_key:
34
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
35
 
36
  # 🛡️ Safety Layer
@@ -39,53 +36,56 @@ class SQLGenerator:
39
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
40
 
41
  # Prompt
42
- system_prompt = f"""You are an SQL Expert.
43
- Schema:
44
- {context}
45
-
46
- Rules:
47
- 1. Output valid JSON: {{ "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }}
48
- 2. Read-only SELECT queries only.
49
- 3. No markdown.
50
-
51
- Question: {question}"""
52
-
53
- payload = {
54
- "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
55
- "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}
56
- }
57
-
58
- headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
59
 
60
- # 🔄 RETRY LOOP
61
- errors = []
62
- for model in self.models:
63
- api_url = f"{self.base_url}{model}"
64
- try:
65
- print(f" ⚡ Attempting: {model}...")
66
- response = requests.post(api_url, headers=headers, json=payload, timeout=20)
67
-
68
- if response.status_code == 200:
69
- print(f" ✅ SUCCESS with {model}!")
70
- return self._process_response(response.json())
71
-
72
- print(f" ❌ Failed ({response.status_code})")
73
- errors.append(f"{model}: {response.status_code}")
74
-
75
- except Exception as e:
76
- print(f" ⚠️ Connection Error: {e}")
77
- errors.append(f"{model}: Error")
78
 
79
- return f"SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"
 
 
 
 
 
80
 
81
- def _process_response(self, result):
82
- if isinstance(result, list) and len(result) > 0:
83
- raw_text = result[0].get('generated_text', '')
84
- elif isinstance(result, dict):
85
- raw_text = result.get('generated_text', '')
86
- else:
87
- raw_text = str(result)
 
 
 
 
 
 
88
 
 
89
  sql_query = ""
90
  message = "Here is the data."
91
  explanation = "Query generated successfully."
 
1
  import os
 
2
  import re
3
  import json
4
  from dotenv import load_dotenv
5
+ from huggingface_hub import InferenceClient
6
 
7
  class SQLGenerator:
8
def __init__(self):
    """Load credentials from the environment and prepare the HF inference client.

    Reads HF_API_KEY (or HUGGINGFACEHUB_API_TOKEN) from the environment,
    strips stray whitespace from the token, and builds an InferenceClient.
    If no key is present, self.client is left as None so generate_sql can
    report a configuration error instead of crashing.
    """
    load_dotenv()

    # 1. CLEAN THE KEY (Fixes "Invalid Header" error)
    token = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    self.api_key = token.strip() if token else None

    # 2. SETUP CLIENT (Fixes "404/410" errors)
    # The client automatically handles the complex routing logic.
    if not self.api_key:
        self.client = None
        print(" FATAL: API Key missing.")
    else:
        self.client = InferenceClient(api_key=self.api_key)

    # 3. USE THE BEST FREE MODEL
    self.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
 
25
 
26
def generate_followup_questions(self, question, sql_query):
    """Return canned follow-up suggestions for the UI.

    The question and sql_query arguments are currently unused; the
    suggestions are static.
    """
    suggestions = [
        "Visualize this result",
        "Export as CSV",
        "Compare with last year",
    ]
    return suggestions
28
 
29
  def generate_sql(self, question, context, history=None):
30
+ if not self.client:
 
 
31
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
32
 
33
  # 🛡️ Safety Layer
 
36
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
37
 
38
  # Prompt
39
+ messages = [
40
+ {"role": "system", "content": f"""You are an SQL Expert.
41
+ Database Schema:
42
+ {context}
43
+
44
+ Rules:
45
+ 1. Output valid JSON: {{ "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }}
46
+ 2. Read-only SELECT queries only.
47
+ 3. No markdown formatting.
48
+ """},
49
+ {"role": "user", "content": question}
50
+ ]
 
 
 
 
 
51
 
52
+ try:
53
+ print(f" ⚡ Generating SQL using {self.model_id}...")
54
+
55
+ # 🚀 OFFICIAL CLIENT CALL (The Robust Way)
56
+ response = self.client.chat.completions.create(
57
+ model=self.model_id,
58
+ messages=messages,
59
+ max_tokens=500,
60
+ temperature=0.1,
61
+ stream=False
62
+ )
63
+
64
+ raw_text = response.choices[0].message.content
65
+ return self._process_response(raw_text)
 
 
 
 
66
 
67
+ except Exception as e:
68
+ print(f" ❌ AI ERROR: {e}")
69
+ # Failover to backup model if Qwen is busy
70
+ if "404" in str(e) or "429" in str(e):
71
+ return self._fallback_generate(messages)
72
+ return f"SELECT 'Error: {str(e)[:50]}' as status", "System Error", "AI Model unavailable."
73
 
74
+ def _fallback_generate(self, messages):
75
+ """Backup using a smaller model if the main one fails"""
76
+ try:
77
+ backup_model = "meta-llama/Llama-3.2-3B-Instruct"
78
+ print(f" ⚠️ Switching to backup: {backup_model}...")
79
+ response = self.client.chat.completions.create(
80
+ model=backup_model,
81
+ messages=messages,
82
+ max_tokens=500
83
+ )
84
+ return self._process_response(response.choices[0].message.content)
85
+ except Exception as e:
86
+ return "SELECT 'Error: All models failed' as status", "System Error", "Please check your API Key permissions."
87
 
88
+ def _process_response(self, raw_text):
89
  sql_query = ""
90
  message = "Here is the data."
91
  explanation = "Query generated successfully."