LalitChaudhari3 committed on
Commit
d3786b9
·
verified ·
1 Parent(s): 10a32e9

Update src/sql_generator.py

Browse files
Files changed (1) hide show
  1. src/sql_generator.py +22 -14
src/sql_generator.py CHANGED
@@ -12,10 +12,12 @@ class SQLGenerator:
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
- # 2. Use Qwen 2.5 Coder
16
- self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
 
 
17
 
18
- # ✅ FIX: Updated to the new Hugging Face Router URL
19
  self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
20
 
21
  def generate_followup_questions(self, question, sql_query):
@@ -38,22 +40,23 @@ class SQLGenerator:
38
  if history:
39
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
40
 
41
- # System Prompt
42
- system_prompt = f"""You are an elite SQL Expert.
43
- Schema:
44
  {context}
45
 
46
  {history_text}
47
 
48
- Rules:
49
- 1. Output JSON: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
50
- 2. Query MUST be Read-Only (SELECT).
51
- 3. Do not include markdown formatting like ```json.
52
- """
 
53
 
54
  # Payload
55
  payload = {
56
- "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
57
  "parameters": {
58
  "max_new_tokens": 1024,
59
  "temperature": 0.1,
@@ -67,14 +70,14 @@ class SQLGenerator:
67
  }
68
 
69
  try:
70
- print(f" ⚑ Generating SQL via Direct API...")
71
 
72
  # 🚀 DIRECT REQUEST
73
  response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
74
 
75
  if response.status_code != 200:
76
  print(f" ❌ API Status: {response.status_code} - {response.text}")
77
- return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", "The AI model is currently unavailable."
78
 
79
  # Parse Response
80
  result = response.json()
@@ -92,12 +95,17 @@ class SQLGenerator:
92
  explanation = "Query generated successfully."
93
 
94
  try:
 
 
 
 
95
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
96
  data = json.loads(clean_json)
97
  sql_query = data.get("sql", "")
98
  message = data.get("message", message)
99
  explanation = data.get("explanation", explanation)
100
  except:
 
101
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
102
  if match: sql_query = match.group(1)
103
 
 
12
  raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
13
  self.api_key = raw_key.strip() if raw_key else None
14
 
15
+ # 2. USE A MODEL THAT IS DEFINITELY FREE & ONLINE
16
+ # The 32B model caused the 404. We switch to Mistral-7B-Instruct-v0.3
17
+ # It is excellent for SQL and 100% supported on the free tier.
18
+ self.repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
19
 
20
+ # 3. USE THE NEW ROUTER URL
21
  self.api_url = f"https://router.huggingface.co/models/{self.repo_id}"
22
 
23
  def generate_followup_questions(self, question, sql_query):
 
40
  if history:
41
  history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h.get('user')}\nSQL: {h.get('sql')}" for h in history[-2:]])
42
 
43
+ # System Prompt (Optimized for Mistral)
44
+ system_prompt = f"""<s>[INST] You are an expert SQL Assistant.
45
+ DATABASE SCHEMA:
46
  {context}
47
 
48
  {history_text}
49
 
50
+ RULES:
51
+ 1. Output ONLY a valid JSON object. Format: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
52
+ 2. The SQL query MUST be Read-Only (SELECT).
53
+ 3. Do not use markdown formatting.
54
+
55
+ QUESTION: {question} [/INST]"""
56
 
57
  # Payload
58
  payload = {
59
+ "inputs": system_prompt,
60
  "parameters": {
61
  "max_new_tokens": 1024,
62
  "temperature": 0.1,
 
70
  }
71
 
72
  try:
73
+ print(f" ⚑ Generating SQL via Direct API ({self.repo_id})...")
74
 
75
  # 🚀 DIRECT REQUEST
76
  response = requests.post(self.api_url, headers=headers, json=payload, timeout=25)
77
 
78
  if response.status_code != 200:
79
  print(f" ❌ API Status: {response.status_code} - {response.text}")
80
+ return f"SELECT 'Error: API returned {response.status_code}' as status", "API Error", f"Model Error: {response.status_code}"
81
 
82
  # Parse Response
83
  result = response.json()
 
95
  explanation = "Query generated successfully."
96
 
97
  try:
98
+ # Mistral sometimes keeps the prompt. Split by [/INST] if present.
99
+ if "[/INST]" in raw_text:
100
+ raw_text = raw_text.split("[/INST]")[-1]
101
+
102
  clean_json = re.sub(r"```json|```", "", raw_text).strip()
103
  data = json.loads(clean_json)
104
  sql_query = data.get("sql", "")
105
  message = data.get("message", message)
106
  explanation = data.get("explanation", explanation)
107
  except:
108
+ # Fallback: Regex to find SQL if JSON parsing fails
109
  match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
110
  if match: sql_query = match.group(1)
111