LalitChaudhari3 committed on
Commit
77ad74c
·
verified ·
1 Parent(s): 94f744f

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +20 -21
src/sql_generator.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import re
3
  import json
4
  from dotenv import load_dotenv
@@ -8,21 +9,19 @@ class SQLGenerator:
8
  def __init__(self):
9
  load_dotenv()
10
 
11
- # 1. CLEAN THE KEY (Fixes "Invalid Header" error)
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- # 2. SETUP CLIENT (Fixes "404/410" errors)
16
- # The client automatically handles the complex routing logic
17
  if self.api_key:
18
  self.client = InferenceClient(api_key=self.api_key)
19
  else:
20
  self.client = None
21
  print(" ❌ FATAL: API Key missing.")
22
 
23
- # 3. USE THE BEST FREE MODEL
24
- # Change this line in src/sql_generator.py if 32B gets slow:
25
- self.model_id = "Qwen/Qwen2.5-Coder-7B-Instruct"
26
 
27
  def generate_followup_questions(self, question, sql_query):
28
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
@@ -36,29 +35,29 @@ class SQLGenerator:
36
  if any(word in question.upper() for word in forbidden):
37
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
38
 
39
- # Prompt
40
  messages = [
41
- {"role": "system", "content": f"""You are an SQL Expert.
42
- Database Schema:
43
- {context}
 
 
 
44
 
45
- Rules:
46
- 1. Output valid JSON: {{ "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }}
47
- 2. Read-only SELECT queries only.
48
- 3. No markdown formatting.
49
  """},
50
- {"role": "user", "content": question}
51
  ]
52
 
53
  try:
54
  print(f" ⚡ Generating SQL using {self.model_id}...")
55
 
56
- # 🚀 OFFICIAL CLIENT CALL (The Robust Way)
57
  response = self.client.chat.completions.create(
58
  model=self.model_id,
59
  messages=messages,
60
  max_tokens=500,
61
- temperature=0.1,
62
  stream=False
63
  )
64
 
@@ -67,15 +66,15 @@ class SQLGenerator:
67
 
68
  except Exception as e:
69
  print(f" ❌ AI ERROR: {e}")
70
- # Failover to backup model if Qwen is busy
71
- if "404" in str(e) or "429" in str(e):
72
  return self._fallback_generate(messages)
73
  return f"SELECT 'Error: {str(e)[:50]}' as status", "System Error", "AI Model unavailable."
74
 
75
  def _fallback_generate(self, messages):
76
- """Backup using a smaller model if the main one fails"""
77
  try:
78
- backup_model = "meta-llama/Llama-3.2-3B-Instruct"
 
79
  print(f" ⚠️ Switching to backup: {backup_model}...")
80
  response = self.client.chat.completions.create(
81
  model=backup_model,
 
1
  import os
2
+ import requests
3
  import re
4
  import json
5
  from dotenv import load_dotenv
 
9
  def __init__(self):
10
  load_dotenv()
11
 
12
+ # 1. CLEAN THE KEY
13
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
14
  self.api_key = raw_key.strip() if raw_key else None
15
 
16
+ # 2. SETUP CLIENT
 
17
  if self.api_key:
18
  self.client = InferenceClient(api_key=self.api_key)
19
  else:
20
  self.client = None
21
  print(" ❌ FATAL: API Key missing.")
22
 
23
+ # 3. USE QWEN 2.5 (Best Free Model)
24
+ self.model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
 
25
 
26
  def generate_followup_questions(self, question, sql_query):
27
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
 
35
  if any(word in question.upper() for word in forbidden):
36
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
37
 
38
+ # 🧠 SMART PROMPT (Fixes the "No Such Table" error)
39
  messages = [
40
+ {"role": "system", "content": f"""You are a precise SQL Expert.
41
+
42
+ CRITICAL RULES:
43
+ 1. You MUST use the EXACT table names and column names from the SCHEMA below.
44
+ 2. Do NOT hallucinate table names (e.g., if schema says 'Employee', do NOT use 'employees').
45
+ 3. Output valid JSON only.
46
 
47
+ SCHEMA:
48
+ {context}
 
 
49
  """},
50
+ {"role": "user", "content": f"Question: {question}\nReturn JSON format: {{ 'sql': 'SELECT ...', 'message': '...', 'explanation': '...' }}"}
51
  ]
52
 
53
  try:
54
  print(f" ⚡ Generating SQL using {self.model_id}...")
55
 
 
56
  response = self.client.chat.completions.create(
57
  model=self.model_id,
58
  messages=messages,
59
  max_tokens=500,
60
+ temperature=0.1, # Low temp = More strict
61
  stream=False
62
  )
63
 
 
66
 
67
  except Exception as e:
68
  print(f" ❌ AI ERROR: {e}")
69
+ # Failover to 7B if 32B is busy
70
+ if "404" in str(e) or "429" in str(e) or "503" in str(e):
71
  return self._fallback_generate(messages)
72
  return f"SELECT 'Error: {str(e)[:50]}' as status", "System Error", "AI Model unavailable."
73
 
74
  def _fallback_generate(self, messages):
 
75
  try:
76
+ # Fallback to the smaller, faster model
77
+ backup_model = "Qwen/Qwen2.5-Coder-7B-Instruct"
78
  print(f" ⚠️ Switching to backup: {backup_model}...")
79
  response = self.client.chat.completions.create(
80
  model=backup_model,