from flask import Flask, request, render_template
import os
import json
import re
import logging
import importlib.util
import tkinter as tk
from tkinter import simpledialog, messagebox
import uuid
import webbrowser
from groq import Groq

app = Flask(__name__)

# Configure logging
logging.basicConfig(filename='AI_SQL_Assistant.log', level=logging.INFO)
logger = logging.getLogger(__name__)

def get_groq_api_key():
    """Return the Groq API key, prompting with a Tkinter dialog if it is not set.

    The key is also cached in plain text in groq_config.json for later runs.
    """
    groq_api_key = os.getenv('GROQ_API_KEY')
    if not groq_api_key:
        root = tk.Tk()
        root.withdraw()
        groq_api_key = simpledialog.askstring("Groq API Key", "Enter your Groq API key (get from https://console.groq.com):", show='*')
        root.destroy()
    if groq_api_key:
        os.environ['GROQ_API_KEY'] = groq_api_key
        with open('groq_config.json', 'w') as f:
            json.dump({'GROQ_API_KEY': groq_api_key}, f)
    else:
        messagebox.showerror("Error", "Groq API key is required to run the app.")
        logger.error("No Groq API key provided")
        # SystemExit works even where the site-provided exit() helper is unavailable.
        raise SystemExit(1)
    return groq_api_key

# Load a cached Groq API key without overriding one already set in the environment
if os.path.exists('groq_config.json'):
    try:
        with open('groq_config.json', 'r') as f:
            config = json.load(f)
        cached_key = config.get('GROQ_API_KEY')
        if cached_key and not os.getenv('GROQ_API_KEY'):
            os.environ['GROQ_API_KEY'] = cached_key
    except Exception as e:
        logger.error("Failed to load groq_config.json: %s", str(e))

try:
    groq_client = Groq(api_key=get_groq_api_key())
    logger.info("Groq client initialized successfully")
except Exception as e:
    groq_client = None
    logger.error("Failed to initialize Groq client: %s", str(e))
    messagebox.showerror("Error", f"Failed to initialize Groq client: {str(e)}")
    raise SystemExit(1)

# Storage for current schema, summary, and mock data
current_schema = {}
current_summary = {}
current_db_name = None
results = None
generated_query = None
mock_data = {}

def parse_sql_file(file_content):
    """Parse SQL file to extract schema and mock data."""
    global current_db_name, current_schema, mock_data
    file_content = file_content.decode('utf-8') if isinstance(file_content, bytes) else file_content
    statements = []
    current_statement = ""
    in_comment = False

    db_name_match = re.search(r"CREATE\s+DATABASE\s+[`']?(\w+)[`']?", file_content, re.IGNORECASE)
    current_db_name = db_name_match.group(1) if db_name_match else f"temp_db_{uuid.uuid4().hex[:8]}"
    logger.info("Parsed SQL file: database name=%s", current_db_name)

    for line in file_content.splitlines():
        line = line.strip()
        if in_comment:
            # Inside a /* ... */ block: skip lines until the closing marker.
            if '*/' in line:
                in_comment = False
            continue
        if not line or line.startswith('--'):
            continue
        if line.startswith('/*'):
            # A block comment may open and close on the same line.
            in_comment = '*/' not in line[2:]
            continue
        current_statement += line + ' '
        if line.endswith(';'):
            statements.append(current_statement.strip())
            current_statement = ""

    current_schema = {}
    mock_data = {}
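    # Leading keywords of table-level constraints, which are not column names.
    constraint_keywords = {"PRIMARY", "FOREIGN", "UNIQUE", "KEY", "INDEX", "CONSTRAINT", "CHECK", "FULLTEXT", "SPATIAL"}
    for statement in statements:
        # Match case-insensitively and allow an optional IF NOT EXISTS clause.
        table_match = re.match(
            r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`']?(\w+)[`']?\s*\((.*)\)",
            statement, re.IGNORECASE)
        if table_match:
            table_name, body = table_match.groups()
            columns = []
            # Split on commas outside parentheses so types like DECIMAL(10,2) stay intact.
            for definition in re.split(r",(?![^()]*\))", body):
                tokens = definition.split()
                if tokens and tokens[0].upper() not in constraint_keywords:
                    columns.append(tokens[0].strip("`'\""))
            current_schema[table_name] = columns
            # Mock some data (e.g., 3 rows with sample values)
            mock_data[table_name] = [[f"val{i}_{col}" for col in columns] for i in range(3)]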
    logger.info("Extracted schema: %s", current_schema)
    return True, current_schema, None

def generate_sql_query(question, schema):
    """Generate SQL query using Groq API based on schema."""
    if not groq_client:
        logger.error("Groq client not initialized")
        return "ERROR: Groq client not initialized. Check API key and try again."
    schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
    prompt = f"""
You are a SQL expert. Based on the following database schema from an uploaded .sql file, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships if implied (e.g., table_id columns). Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.

Schema:
{schema_text}

User Question: {question}
"""
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-70b-8192"
        )
        query = response.choices[0].message.content.strip()
        query = re.sub(r'```(?:sql)?\n?', '', query)
        query = query.strip()
        logger.info("Generated SQL query: %s", query[:100])
        return query
    except Exception as e:
        logger.error("Failed to generate SQL query: %s", str(e))
        return f"ERROR: Failed to generate SQL query: {str(e)}"
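

# Hypothetical helper, not part of the original app: a minimal sketch of the
# fence-stripping step in generate_sql_query, pulled out so it can be checked
# in isolation. It assumes the model may still wrap its answer in Markdown
# code fences despite the prompt, and it uses plain string methods instead of
# the regular expression above. The name strip_code_fences is illustrative.
def strip_code_fences(text):
    """Remove Markdown code-fence markers and surrounding whitespace."""
    # Remove the longer "```sql" marker first so no stray "sql" is left behind.
    for marker in ("```sql", "```"):
        text = text.replace(marker, "")
    return text.strip()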

def execute_mock_query(query):
    """Simulate query execution with mock data."""
    if not current_schema or not mock_data:
        logger.error("No schema or mock data available")
        return False, "No data loaded. Please upload an .sql file first.", None
    try:
        # Simple mock execution: return data from the first table if query matches
        table_match = re.search(r"FROM\s+[`']?(\w+)[`']?", query, re.IGNORECASE)
        if table_match and table_match.group(1) in mock_data:
            return True, mock_data[table_match.group(1)], None
        return False, "Mock query execution failed: Table not found or query not supported.", None
    except Exception as e:
        logger.error("Mock query execution failed: %s", str(e))
        return False, f"Mock query execution failed: {str(e)}", None
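

# Hypothetical guard, not part of the original app: a minimal sketch of how a
# generated query could be checked before it is passed to execute_mock_query
# (or, later, to a real database). It assumes that only a single SELECT
# statement should ever be accepted, so WITH queries and any query containing
# a semicolon inside a string literal are rejected as well. The name
# is_single_select is illustrative.
def is_single_select(query):
    """Return True if query looks like exactly one SELECT statement."""
    body = query.strip().rstrip(';').strip()
    # A remaining ';' means several statements were chained together.
    if not body or ';' in body:
        return False
    return body.split(None, 1)[0].upper() == 'SELECT'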

@app.route('/', methods=['GET', 'POST'])
def index():
    global current_schema, current_summary, results, generated_query
    error = None

    if not groq_client:
        error = "Groq client not initialized. Please restart the app and enter a valid Groq API key."
        logger.error(error)

    if request.method == 'POST':
        logger.info("Received POST request")
        if 'sql_file' in request.files:
            file = request.files['sql_file']
            logger.info("SQL file upload detected: %s", file.filename if file else "No file")
            if file and file.filename.lower().endswith('.sql'):
                success, schema, _ = parse_sql_file(file.read())
                if success:
                    current_schema = schema
                    logger.info("SQL file parsed successfully")
                else:
                    error = "Failed to parse SQL file."
                    logger.error(error)
            else:
                error = "Please upload a valid .sql file."
                logger.error(error)
        elif 'question' in request.form:
            question = request.form['question']
            logger.info("Received question: %s", question)
            if not current_schema:
                error = "No schema loaded. Please upload an .sql file first."
                logger.error(error)
            else:
                generated_query = generate_sql_query(question, current_schema)
                if not generated_query.startswith('ERROR:'):
                    success, result, _ = execute_mock_query(generated_query)
                    if success:
                        results = result
                        logger.info("Mock query executed successfully, results: %d rows", len(result))
                    else:
                        error = result
                        logger.error(error)
                else:
                    error = generated_query
                    logger.error(error)

    logger.info("Rendering index.html: error=%s, schema=%s, summary=%s, results=%s",
                error, bool(current_schema), bool(current_summary), bool(results))
    return render_template('index.html', error=error, schema=current_schema, summary=current_summary, results=results, query=generated_query)

if __name__ == '__main__':
    url = 'http://localhost:7860'
    # webbrowser is already imported at module level; open() returns False
    # (rather than raising) when no browser could be launched.
    try:
        opened = webbrowser.open(url)
    except webbrowser.Error:
        opened = False
    if not opened:
        logger.error("webbrowser.open() failed, possibly due to environment issue")
        print(f"Warning: Could not open browser automatically. Please navigate to {url} manually.")
    app.run(host='0.0.0.0', port=7860, debug=False)