LalitChaudhari3 committed on
Commit
8fbaa76
·
verified ·
1 Parent(s): 91ed273

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +13 -17
src/sql_generator.py CHANGED
@@ -8,31 +8,31 @@ class SQLGenerator:
8
  def __init__(self):
9
  load_dotenv()
10
 
11
- # 1. Robustly fetch the API Key
12
- # Try HF_API_KEY first (our standard), then fall back to others if user renamed it
13
- self.api_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
14
 
15
- # 2. Use the powerful Qwen 2.5 Coder model
16
  self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
17
  self.api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"
18
 
19
  def generate_followup_questions(self, question, sql_query):
20
- # Basic heuristics to suggest next steps
21
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
22
 
23
  def generate_sql(self, question, context, history=None):
24
  if history is None: history = []
25
 
26
- # 🚨 ERROR CHECK: Stop early if key is missing
27
  if not self.api_key:
28
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
29
 
30
- # 🛡️ Safety Layer: Block dangerous keywords
31
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
32
  if any(word in question.upper() for word in forbidden):
33
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
34
 
35
- # Format History for the AI
36
  history_text = ""
37
  if history:
38
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
@@ -50,12 +50,12 @@ class SQLGenerator:
50
  3. Do not include markdown formatting like ```json.
51
  """
52
 
53
- # Prepare the payload (Direct HTTP Request)
54
  payload = {
55
  "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
56
  "parameters": {
57
  "max_new_tokens": 1024,
58
- "temperature": 0.1, # Low temp = precise SQL
59
  "return_full_text": False
60
  }
61
  }
@@ -68,7 +68,7 @@ class SQLGenerator:
68
  try:
69
  print(f" ⚡ Generating SQL via Direct API...")
70
 
71
- # 🚀 DIRECT REQUEST (Bypasses library auth issues)
72
  response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
73
 
74
  if response.status_code != 200:
@@ -78,7 +78,6 @@ class SQLGenerator:
78
  # Parse Response
79
  result = response.json()
80
 
81
- # Handle different response formats (sometimes list, sometimes dict)
82
  if isinstance(result, list) and len(result) > 0:
83
  raw_text = result[0].get('generated_text', '')
84
  elif isinstance(result, dict):
@@ -86,28 +85,25 @@ class SQLGenerator:
86
  else:
87
  raw_text = str(result)
88
 
89
- # JSON Parsing & Cleanup
90
  sql_query = ""
91
  message = "Here is the data."
92
  explanation = "Query generated successfully."
93
 
94
  try:
95
- # Remove markdown code blocks if AI added them
96
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
97
  data = json.loads(clean_json)
98
  sql_query = data.get("sql", "")
99
  message = data.get("message", message)
100
  explanation = data.get("explanation", explanation)
101
  except:
102
- # Fallback: Regex to find SQL if JSON parsing fails
103
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
104
  if match: sql_query = match.group(1)
105
 
106
- # Final SQL Cleanup
107
  sql_query = sql_query.strip().replace("\n", " ")
108
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
109
 
110
- # 🛡️ Final Validation (Allow SELECT or WITH)
111
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
112
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
113
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
 
8
  def __init__(self):
9
  load_dotenv()
10
 
11
+ # 1. ROBUSTLY FETCH & CLEAN THE KEY
12
+ # We use .strip() to remove accidental leading/trailing whitespace, such as a '\n' (newline), from the key, which was causing the error
13
+ raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
14
+ self.api_key = raw_key.strip() if raw_key else None
15
 
16
+ # 2. Use Qwen 2.5 Coder
17
  self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
18
  self.api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"
19
 
20
  def generate_followup_questions(self, question, sql_query):
 
21
  return ["Visualize this result", "Export as CSV", "Compare with last year"]
22
 
23
  def generate_sql(self, question, context, history=None):
24
  if history is None: history = []
25
 
26
+ # 🚨 ERROR CHECK
27
  if not self.api_key:
28
  return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."
29
 
30
+ # 🛡️ Safety Layer
31
  forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
32
  if any(word in question.upper() for word in forbidden):
33
  return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
34
 
35
+ # Format History
36
  history_text = ""
37
  if history:
38
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
 
50
  3. Do not include markdown formatting like ```json.
51
  """
52
 
53
+ # Payload
54
  payload = {
55
  "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
56
  "parameters": {
57
  "max_new_tokens": 1024,
58
+ "temperature": 0.1,
59
  "return_full_text": False
60
  }
61
  }
 
68
  try:
69
  print(f" ⚡ Generating SQL via Direct API...")
70
 
71
+ # 🚀 DIRECT REQUEST
72
  response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
73
 
74
  if response.status_code != 200:
 
78
  # Parse Response
79
  result = response.json()
80
 
 
81
  if isinstance(result, list) and len(result) > 0:
82
  raw_text = result[0].get('generated_text', '')
83
  elif isinstance(result, dict):
 
85
  else:
86
  raw_text = str(result)
87
 
88
+ # Clean JSON
89
  sql_query = ""
90
  message = "Here is the data."
91
  explanation = "Query generated successfully."
92
 
93
  try:
 
94
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
95
  data = json.loads(clean_json)
96
  sql_query = data.get("sql", "")
97
  message = data.get("message", message)
98
  explanation = data.get("explanation", explanation)
99
  except:
 
100
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
101
  if match: sql_query = match.group(1)
102
 
103
+ # Final Cleanup
104
  sql_query = sql_query.strip().replace("\n", " ")
105
  if sql_query and not sql_query.endswith(";"): sql_query += ";"
106
 
 
107
  clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
108
  if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
109
  print(f" ⚠️ Invalid SQL Blocked: {sql_query}")