PlainSQL-Agent / src /sql_generator.py
LalitChaudhari3's picture
Update src/sql_generator.py
627c842 verified
raw
history blame
4.65 kB
import os
import requests
import re
import json
from dotenv import load_dotenv
class SQLGenerator:
    """Translate natural-language questions into read-only SQL using hosted LLMs.

    The models in ``self.models`` are tried in order against the Hugging Face
    router until one answers with HTTP 200. Public results are 3-tuples of
    strings whose first element is always a SQL statement; on any failure that
    statement is a ``SELECT '<error>' as status`` placeholder.
    """

    # Statement keywords that modify data or permissions. They are matched as
    # whole words, case-insensitively. So "updated", "deleted" or an
    # ``updated_at`` column do not trip the filter, while a real ``DELETE ...``
    # statement does. A forbidden word inside a string literal (e.g.
    # ``WHERE action = 'DELETE'``) is still rejected; this errs on the safe side.
    FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")
    _FORBIDDEN_RE = re.compile(r"\b(?:" + "|".join(FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE)

    # Last-resort extraction: the first "SELECT ... ;" span in the raw model output.
    _SQL_FALLBACK_RE = re.compile(r"(SELECT[\s\S]+?;)", re.IGNORECASE)

    # Result returned whenever the safety layer rejects a question or a query.
    _SAFETY_BLOCK = (
        "SELECT 'Error: Blocked by Safety Layer' as status",
        "Safety Alert",
        "I cannot execute commands that modify data.",
    )

    def __init__(self):
        """Load ``.env`` settings and configure the API key, model list and endpoint."""
        load_dotenv()

        # 1. AUTHENTICATION: either variable name is accepted. Surrounding
        # whitespace (common when pasting secrets) is stripped.
        raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.api_key = raw_key.strip() if raw_key else None

        # 2. MODELS: generate_sql() tries them in this order and stops at the
        # first one that answers with HTTP 200.
        self.models = [
            "Qwen/Qwen2.5-72B-Instruct",  # primary
            "Qwen/Qwen2.5-7B-Instruct",  # backup
            "microsoft/Phi-3.5-mini-instruct",
            "mistralai/Mistral-Nemo-Instruct-2407",
        ]

        # 3. ENDPOINT: the model id is appended directly to this prefix.
        # NOTE(review): confirm this path against the current Hugging Face
        # router documentation — TODO verify.
        self.base_url = "https://router.huggingface.co/models/"

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of follow-up suggestions (arguments are currently unused)."""
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Ask the configured models for a read-only SQL query answering *question*.

        Args:
            question: The user's natural-language question.
            context: Schema text embedded verbatim into the system prompt.
            history: Accepted for API compatibility; not used in the prompt.

        Returns:
            On success, the tuple produced by ``_process_response``:
            ``(sql, explanation, message)``. On a missing API key, a
            safety-layer rejection, or when every model fails:
            ``(placeholder_select, short_label, detail)``.
        """
        if not self.api_key:
            return "SELECT 'Error: HF_API_KEY Missing' as status", "Configuration Error", "Please add HF_API_KEY to your Space Secrets."

        # Safety layer (1/2): reject questions that ask for write operations
        # before spending an API call.
        if self._contains_forbidden_keyword(question):
            return self._SAFETY_BLOCK

        system_prompt = (
            "You are an SQL Expert.\n"
            "Schema:\n"
            f"{context}\n"
            "Rules:\n"
            '1. Output valid JSON: { "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }\n'
            "2. Read-only SELECT queries only.\n"
            "3. No markdown.\n"
            f"Question: {question}"
        )
        # The prompt is hand-formatted with ChatML (<|im_start|>/<|im_end|>) tokens.
        # NOTE(review): other models in the list may expect a different chat
        # template — TODO confirm.
        payload = {
            "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
            "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False},
        }
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        # Retry loop: fall through to the next model on any non-200 status or error.
        errors = []
        for model in self.models:
            api_url = f"{self.base_url}{model}"
            try:
                print(f" ⚡ Attempting: {model}...")
                response = requests.post(api_url, headers=headers, json=payload, timeout=20)
                if response.status_code == 200:
                    print(f" ✅ SUCCESS with {model}!")
                    return self._process_response(response.json())
                print(f" ❌ Failed ({response.status_code})")
                errors.append(f"{model}: {response.status_code}")
            except Exception as e:
                # Deliberately broad: this is the per-model retry boundary
                # (network errors, non-JSON bodies, unexpected payloads).
                print(f" ⚠️ Connection Error: {e}")
                errors.append(f"{model}: Error")
        return "SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"

    @classmethod
    def _contains_forbidden_keyword(cls, text):
        """Return True if *text* contains a FORBIDDEN_KEYWORDS entry as a whole word."""
        return bool(cls._FORBIDDEN_RE.search(text))

    @staticmethod
    def _extract_generated_text(result):
        """Pull the ``generated_text`` string out of a decoded API response."""
        if isinstance(result, list) and result:
            first = result[0]
            text = first.get("generated_text", "") if isinstance(first, dict) else first
        elif isinstance(result, dict):
            text = result.get("generated_text", "")
        else:
            text = result
        if text is None:
            return ""
        return text if isinstance(text, str) else str(text)

    @staticmethod
    def _parse_json_object(text):
        """Decode the first JSON object in *text* (markdown fences removed), or return None.

        ``raw_decode`` stops at the end of the first complete object, so trailing
        prose (even prose containing braces) does not break parsing.
        """
        cleaned = re.sub(r"```json|```", "", text).strip()
        start = cleaned.find("{")
        if start == -1:
            return None
        try:
            data, _ = json.JSONDecoder().raw_decode(cleaned[start:])
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return data if isinstance(data, dict) else None

    def _process_response(self, result):
        """Turn a decoded API response into ``(sql, explanation, message)``.

        The SQL comes from the model's JSON ``"sql"`` field when available.
        Otherwise it is the first ``SELECT ... ;`` span in the raw text. It is
        flattened to one line and terminated with ``;``. An empty result yields
        an error placeholder. A result containing a forbidden keyword yields
        the safety-layer tuple instead.
        """
        raw_text = self._extract_generated_text(result)
        sql_query = ""
        message = "Here is the data."
        explanation = "Query generated successfully."

        data = self._parse_json_object(raw_text)
        if data is not None:
            # Only accept strings: the model may emit null or other JSON types.
            sql_value = data.get("sql")
            if isinstance(sql_value, str):
                sql_query = sql_value
            message_value = data.get("message", message)
            if isinstance(message_value, str):
                message = message_value
            explanation_value = data.get("explanation", explanation)
            if isinstance(explanation_value, str):
                explanation = explanation_value

        if not sql_query.strip():
            match = self._SQL_FALLBACK_RE.search(raw_text)
            if match:
                sql_query = match.group(1)

        sql_query = sql_query.strip().replace("\n", " ")
        if sql_query and not sql_query.endswith(";"):
            sql_query += ";"
        if not sql_query:
            return "SELECT 'Error: Empty Query' as status", explanation, message

        # Safety layer (2/2): the model can ignore the read-only rule, so the
        # final SQL is checked too, not just the user's question.
        if self._contains_forbidden_keyword(sql_query):
            return self._SAFETY_BLOCK
        return sql_query, explanation, message