File size: 4,647 Bytes
91ed273
146146a
91ed273
 
 
 
 
146146a
91ed273
146146a
627c842
146146a
 
 
627c842
 
3656fbb
627c842
 
 
 
3656fbb
146146a
627c842
 
91ed273
 
 
 
 
 
 
146146a
 
 
627c842
91ed273
 
 
 
627c842
3656fbb
 
146146a
91ed273
3656fbb
 
 
 
 
 
146146a
 
627c842
3656fbb
146146a
91ed273
3656fbb
 
627c842
3656fbb
 
627c842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3656fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
91ed273
304a74a
3656fbb
 
 
 
 
 
 
5d48e70
3656fbb
 
 
 
 
 
 
 
 
 
627c842
3656fbb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import requests
import re
import json
from dotenv import load_dotenv

class SQLGenerator:
    """Translate natural-language questions into read-only SQL via hosted LLMs.

    The generator posts a prompt to each model in ``self.models`` (in order)
    under ``self.base_url`` and returns the first successful answer. Every
    public result is a 3-tuple ``(sql, explanation, message)``; failures are
    reported as a harmless ``SELECT '<error>' as status`` statement so callers
    can always execute the returned SQL.
    """

    # Statement keywords that modify data, schema or permissions.
    _FORBIDDEN_KEYWORDS = (
        "DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT",
    )
    # Whole-word, case-insensitive match: "updated_at" or "alternative" must
    # not trip the guard, while "update users ..." must.
    _FORBIDDEN_RE = re.compile(
        r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b", re.IGNORECASE
    )
    # Fallback extraction of a bare SELECT statement from free-form text.
    _SELECT_RE = re.compile(r"(SELECT[\s\S]+?;)", re.IGNORECASE)
    # Result returned whenever the safety layer rejects a question or a query.
    _BLOCKED = (
        "SELECT 'Error: Blocked by Safety Layer' as status",
        "Safety Alert",
        "I cannot execute commands that modify data.",
    )

    def __init__(self):
        """Load the API token from the environment and set up model routing."""
        load_dotenv()

        # 1. AUTHENTICATION
        # Either variable name is accepted; surrounding whitespace is dropped
        # and a missing/blank token leaves ``api_key`` falsy.
        raw_key = os.getenv("HF_API_KEY") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.api_key = raw_key.strip() if raw_key else None

        # 2. MODELS, tried in this order until one answers with HTTP 200.
        self.models = [
            "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-7B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "mistralai/Mistral-Nemo-Instruct-2407",
        ]

        # 3. ENDPOINTS - the model id is appended to this prefix.
        self.base_url = "https://router.huggingface.co/models/"

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of follow-up suggestions.

        Both arguments are currently ignored; the same three suggestions are
        returned for every call.
        """
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    def generate_sql(self, question, context, history=None):
        """Generate a read-only SQL query answering ``question``.

        Args:
            question: The user's natural-language question.
            context: Schema description embedded verbatim in the prompt.
            history: Accepted for interface compatibility; currently unused.

        Returns:
            A tuple ``(sql, explanation, message)``. When the API key is
            missing, the question is rejected by the safety layer, or every
            model fails, ``sql`` is a ``SELECT '<error>' as status`` statement
            and the other two fields describe the problem.
        """
        if not self.api_key:
            return (
                "SELECT 'Error: HF_API_KEY Missing' as status",
                "Configuration Error",
                "Please add HF_API_KEY to your Space Secrets.",
            )

        # Safety layer: reject questions that explicitly ask for a mutation.
        if self._FORBIDDEN_RE.search(question):
            return self._BLOCKED

        # Prompt (the literal's indentation is part of the text sent).
        system_prompt = f"""You are an SQL Expert.
        Schema:
        {context}

        Rules:
        1. Output valid JSON: {{ "sql": "SELECT ...", "message": "Short text", "explanation": "Brief summary" }}
        2. Read-only SELECT queries only.
        3. No markdown.
        
        Question: {question}"""

        # NOTE(review): the <|im_start|>/<|im_end|> markers are sent to every
        # model in self.models - confirm the non-Qwen fallbacks accept them.
        payload = {
            "inputs": f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n",
            "parameters": {"max_new_tokens": 512, "temperature": 0.1, "return_full_text": False}
        }

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

        # Try each model in turn; the first HTTP 200 wins.
        errors = []
        for model in self.models:
            api_url = f"{self.base_url}{model}"
            try:
                print(f"   ⚡ Attempting: {model}...")
                response = requests.post(api_url, headers=headers, json=payload, timeout=20)

                if response.status_code == 200:
                    print(f"   ✅ SUCCESS with {model}!")
                    return self._process_response(response.json())

                print(f"   ❌ Failed ({response.status_code})")
                errors.append(f"{model}: {response.status_code}")

            except Exception as e:
                # Deliberate best-effort boundary: any failure with one model
                # (network, timeout, undecodable body) moves on to the next.
                print(f"   ⚠️ Connection Error: {e}")
                errors.append(f"{model}: Error")

        return "SELECT 'Error: All models failed' as status", "System Error", f"Debug Info: {', '.join(errors)}"

    @staticmethod
    def _extract_generated_text(result):
        """Pull the ``generated_text`` string out of a decoded API response.

        Accepts a dict or a non-empty list whose first element is a dict.
        Any other shape falls back to ``str(result)``; a non-string
        ``generated_text`` value yields an empty string.
        """
        candidate = result
        if isinstance(result, list) and result:
            candidate = result[0]
        if isinstance(candidate, dict):
            text = candidate.get("generated_text", "")
            return text if isinstance(text, str) else ""
        return str(result)

    @staticmethod
    def _parse_json_object(raw_text):
        """Return the first JSON object embedded in ``raw_text``, or ``None``.

        Markdown code fences are stripped first. Decoding starts at the first
        ``{`` and stops at the end of that object, so trailing commentary
        (even commentary containing braces) does not break parsing.
        """
        cleaned = re.sub(r"```json|```", "", raw_text).strip()
        start = cleaned.find("{")
        if start == -1:
            return None
        try:
            parsed, _end = json.JSONDecoder().raw_decode(cleaned, start)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return parsed if isinstance(parsed, dict) else None

    def _process_response(self, result):
        """Convert a decoded model response into ``(sql, explanation, message)``.

        The model is asked for a JSON object with ``sql``, ``message`` and
        ``explanation`` keys. If no JSON object can be decoded, the first
        ``SELECT ...;`` statement found in the raw text is used instead.
        Newlines in the query are replaced by spaces and a trailing ``;`` is
        ensured. An empty result becomes an ``Error: Empty Query`` statement,
        and a query containing a forbidden keyword is replaced by the safety
        layer's blocked result.
        """
        raw_text = self._extract_generated_text(result)

        sql_query = ""
        message = "Here is the data."
        explanation = "Query generated successfully."

        data = self._parse_json_object(raw_text)
        if data is not None:
            # Only string values are trusted; null/other types keep defaults
            # so the string operations below cannot raise.
            sql_value = data.get("sql")
            if isinstance(sql_value, str):
                sql_query = sql_value
            message_value = data.get("message", message)
            if isinstance(message_value, str):
                message = message_value
            explanation_value = data.get("explanation", explanation)
            if isinstance(explanation_value, str):
                explanation = explanation_value
        else:
            match = self._SELECT_RE.search(raw_text)
            if match:
                sql_query = match.group(1)

        sql_query = sql_query.strip().replace("\n", " ")
        if not sql_query:
            return "SELECT 'Error: Empty Query' as status", explanation, message

        # The question filter alone cannot stop the model from emitting a
        # mutating statement, so the generated SQL gets the same guard.
        if self._FORBIDDEN_RE.search(sql_query):
            return self._BLOCKED

        if not sql_query.endswith(";"):
            sql_query += ";"

        return sql_query, explanation, message