File size: 3,398 Bytes
8642c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import re
import json
from huggingface_hub import InferenceClient
from dotenv import load_dotenv

class SQLGenerator:
    """Translate natural-language questions into read-only SQL using a hosted chat model.

    The generator sends the schema, the last two conversation turns, and the
    user's question to the model, expects a JSON reply containing the SQL, and
    only returns the statement if it passes a read-only safety check.
    """

    # Keywords that mark a data- or schema-modifying statement.
    _FORBIDDEN_KEYWORDS = ("DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT")
    _KEYWORD_ALTERNATION = "|".join(_FORBIDDEN_KEYWORDS)

    # Whole-word match so that questions mentioning "updates" or "dropdown"
    # are not rejected.
    _FORBIDDEN_IN_QUESTION = re.compile(r"\b(?:" + _KEYWORD_ALTERNATION + r")\b", re.IGNORECASE)

    # In generated SQL a keyword directly followed by "(" is a function call
    # (e.g. TRUNCATE(price, 1)), not a statement, so it is tolerated.
    _FORBIDDEN_IN_SQL = re.compile(
        r"\b(?:" + _KEYWORD_ALTERNATION + r")\b(?!\s*\()", re.IGNORECASE
    )

    # Group 1 captures quoted text (string literals / quoted identifiers) so
    # that comment markers or keywords inside quotes are never misread.
    _LITERAL_OR_COMMENT = re.compile(
        r"('[^']*'|\"[^\"]*\")"   # quoted text
        r"|--[^\n]*"              # line comment
        r"|/\*.*?\*/",            # block comment
        re.DOTALL,
    )

    def __init__(self, api_key=None):
        """Create the inference client.

        Args:
            api_key: Optional Hugging Face token used when the
                ``HUGGINGFACEHUB_API_TOKEN`` environment variable is not set.
                The environment variable keeps precedence so existing
                deployments behave exactly as before.
        """
        load_dotenv()
        self.api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or api_key
        self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
        self.client = InferenceClient(token=self.api_token, timeout=25.0)

    def generate_followup_questions(self, question, sql_query):
        """Return a fixed list of follow-up suggestions.

        Both arguments are currently unused; the suggestions are static.
        """
        return ["Visualize this result", "Export as CSV", "Compare with last year"]

    @classmethod
    def _strip_comments(cls, sql):
        """Replace SQL comments with a space, leaving quoted text untouched."""
        return cls._LITERAL_OR_COMMENT.sub(lambda m: m.group(1) or " ", sql)

    @classmethod
    def _is_read_only(cls, sql):
        """Return True if *sql* is a single SELECT/WITH statement without modifying keywords."""
        # Blank out quoted text so keywords or semicolons inside literals are ignored.
        masked = cls._LITERAL_OR_COMMENT.sub(lambda m: "''" if m.group(1) else " ", sql)
        masked = masked.strip().rstrip("; \t")
        if not masked.upper().startswith(("SELECT", "WITH")):
            return False
        # A semicolon left after trimming the trailing one means stacked statements.
        if ";" in masked:
            return False
        return cls._FORBIDDEN_IN_SQL.search(masked) is None

    @staticmethod
    def _parse_model_reply(raw_text):
        """Extract ``(sql, explanation, message)`` from the model's raw reply.

        The reply is expected to be a JSON object; if it is not, the first
        ``SELECT ... ;`` span found in the text is used as the SQL.
        """
        sql_query = ""
        message = "Here is the data."
        explanation = "Query generated successfully."

        cleaned = re.sub(r"```json|```", "", raw_text).strip()
        try:
            data = json.loads(cleaned)
        except (ValueError, TypeError):
            data = None

        if isinstance(data, dict):
            candidate = data.get("sql")
            sql_query = candidate if isinstance(candidate, str) else ""
            message = data.get("message") or message
            explanation = data.get("explanation") or explanation
        else:
            match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
            if match:
                sql_query = match.group(1)

        return sql_query, explanation, message

    def generate_sql(self, question, context, history=None):
        """Generate a read-only SQL statement answering *question*.

        Args:
            question: The user's natural-language request.
            context: Schema description inserted into the system prompt.
            history: Optional list of dicts with ``'user'`` and ``'sql'`` keys;
                only the last two entries are included in the prompt.

        Returns:
            A ``(sql, explanation, message)`` tuple. On a safety block or a
            model/parsing failure, ``sql`` is a harmless ``SELECT '<error>' as
            status`` statement and the other two fields describe the problem.
        """
        history = history or []

        if self._FORBIDDEN_IN_QUESTION.search(question):
            return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."

        history_text = ""
        if history:
            turns = [f"User: {h['user']}\nSQL: {h['sql']}" for h in history[-2:]]
            history_text = "PREVIOUS CONVERSATION:\n" + "\n".join(turns)

        system_prompt = f"""You are an elite SQL Expert.

Schema:

{context}



{history_text}



Rules:

1. Output JSON: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}

2. Query MUST be Read-Only (SELECT).

3. Do not include markdown formatting like ```json.

"""
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]

        try:
            print("   ⚡ Generating SQL...")
            response = self.client.chat_completion(messages=messages, model=self.repo_id, max_tokens=1024, temperature=0.1)
            raw_text = response.choices[0].message.content or ""

            sql_query, explanation, message = self._parse_model_reply(raw_text)

            # Comments must be removed BEFORE newlines are collapsed; otherwise
            # a "--" comment would swallow the rest of the single-line query.
            sql_query = self._strip_comments(sql_query)
            sql_query = sql_query.replace("\r", " ").replace("\n", " ").strip()
            if sql_query and not sql_query.endswith(";"):
                sql_query += ";"

            if not self._is_read_only(sql_query):
                print(f"   ⚠️ Invalid SQL Blocked: {sql_query}")
                return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."

            return sql_query, explanation, message

        except Exception as e:
            # Top-level boundary: any model/network failure is reported to the
            # caller as a harmless SELECT rather than raised.
            print(f"   ❌ Model Error: {e}")
            # Quotes are removed so the message cannot break out of the SQL literal.
            safe_e = str(e).replace("'", "").replace('"', "")
            return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."