Al1Abdullah committed on
Commit
4fe7172
·
verified ·
1 Parent(s): 5202f44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -119
app.py CHANGED
@@ -1,119 +1,213 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import re
4
- import json
5
- from groq import Groq
6
- import os
7
-
8
- # --- Groq API Key Handling ---
9
def get_groq_api_key():
    """Return the Groq API key from the ``GROQ_API_KEY`` environment variable.

    Surrounding whitespace is stripped so a key pasted with a trailing space
    or newline still works, and a blank value counts as "not configured".

    Returns:
        The key, or an empty string when the variable is unset or blank.
    """
    return os.environ.get('GROQ_API_KEY', '').strip()


groq_api_key = get_groq_api_key()
# No client is built without a key; users of ``groq_client`` must handle None.
groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
14
-
15
- # --- Schema Parsing and Summarization ---
16
# Header of a CREATE TABLE statement; tolerates extra whitespace and IF NOT EXISTS.
_SCHEMA_CREATE_TABLE_RE = re.compile(
    r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`"]?(\w+)[`"]?',
    re.IGNORECASE,
)

# Unquoted leading words that introduce a table constraint rather than a column.
_SCHEMA_CONSTRAINT_WORDS = frozenset({
    'PRIMARY', 'FOREIGN', 'UNIQUE', 'KEY', 'INDEX',
    'CONSTRAINT', 'CHECK', 'FULLTEXT', 'SPATIAL',
})


def parse_sql_schema(sql_text):
    """Extract table and column names from the text of a SQL schema file.

    The parser is line based: a ``CREATE TABLE`` line opens a table, every
    following line whose first word is an identifier followed by whitespace
    contributes a column, and a line starting with ``)`` closes the table.
    Column definitions written on the same line as ``CREATE TABLE`` are not
    picked up.

    Args:
        sql_text: Full contents of the ``.sql`` file.

    Returns:
        dict mapping each table name to the list of its column names, in
        file order.
    """
    tables = {}
    current_table = None
    for raw_line in sql_text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith('--'):
            continue

        header = _SCHEMA_CREATE_TABLE_RE.match(line)
        if header:
            current_table = header.group(1)
            tables[current_table] = []
            continue

        if current_table is None:
            continue
        if line.startswith(')'):
            current_table = None
            continue

        col_match = re.match(r'([`"]?)(\w+)[`"]?\s', line)
        if not col_match:
            continue
        quote, name = col_match.groups()
        # A quoted identifier is always a column, even if it spells a keyword;
        # an unquoted PRIMARY/FOREIGN/... starts a constraint clause instead.
        if not quote and name.upper() in _SCHEMA_CONSTRAINT_WORDS:
            continue
        tables[current_table].append(name)
    return tables
34
-
35
def generate_schema_summary(schema):
    """Render a parsed schema as human-readable text.

    Args:
        schema: dict mapping table name to a list of column names.

    Returns:
        One ``Table: ...`` / ``Columns: ...`` pair of lines per table, or
        ``'No tables found.'`` when the schema is empty.
    """
    # Build all entries in one join instead of growing a string with +=.
    summary = ''.join(
        f"Table: {table}\n Columns: {', '.join(columns)}\n"
        for table, columns in schema.items()
    )
    return summary or 'No tables found.'
40
-
41
- # --- SQL Generation ---
42
def generate_sql_query(question, schema):
    """Ask the Groq chat model to turn a natural-language question into SQL.

    Args:
        question: The user's question.
        schema: dict mapping table name to a list of column names.

    Returns:
        The generated SQL with Markdown code fences removed, or a string
        starting with ``'ERROR:'`` when no client is configured or the API
        call fails.
    """
    if not groq_client:
        # The key is read from the environment, so point the user there.
        return "ERROR: Groq client not initialized. Please set the GROQ_API_KEY environment variable."
    schema_text = "\n".join(
        f"Table: {table}\nColumns: {', '.join(columns)}"
        for table, columns in schema.items()
    )
    prompt = (
        "\n"
        "You are a SQL expert. Based on the following database schema, generate a valid SQL query "
        "for the user's question. Only use tables and columns that exist in the schema. "
        "Return ONLY the SQL query, without explanations, markdown, or code block formatting."
        f"\n\nSchema:\n{schema_text}\n\nUser Question: {question}\n"
    )
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-70b-8192"
        )
        query = response.choices[0].message.content.strip()
        # Models sometimes wrap the answer in ``` fences despite the prompt;
        # strip again afterwards so no newline is left behind by the removal.
        return re.sub(r'```(?:sql)?\n?', '', query).strip()
    except Exception as e:
        return f"ERROR: Failed to generate SQL query: {str(e)}"
58
-
59
- # --- SQL Execution on CSV (optional demo) ---
60
# SELECT * FROM <table> [WHERE ...] with an optional trailing semicolon.
_CSV_SELECT_RE = re.compile(
    r'\s*SELECT\s+\*\s+FROM\s+[`"]?(\w+)[`"]?(?:\s+WHERE\s+(.*?))?\s*;?\s*$',
    re.IGNORECASE | re.DOTALL,
)


def run_sql_on_csv(sql_query, csv_file):
    """Run a very small subset of SQL against a CSV file (demo only).

    Supported: ``SELECT * FROM <table>`` optionally followed by
    ``WHERE <column> = <value>``. The table name is ignored; the CSV is the
    table.

    Args:
        sql_query: SQL text to interpret.
        csv_file: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        ``(dataframe, None)`` on success, ``(None, error_message)`` otherwise.
    """
    try:
        df = pd.read_csv(csv_file)
        match = _CSV_SELECT_RE.match(sql_query)
        if not match:
            return None, "Only simple SELECT * FROM ... queries are supported for CSV demo."

        where_clause = match.group(2)
        if where_clause:
            # Very basic filter: col = value
            col_val = re.match(r'(\w+) *= *["\']?(.*?)["\']?$', where_clause)
            if not col_val:
                # Refuse rather than return unfiltered rows for a filter we cannot apply.
                return None, "Only WHERE <column> = <value> filters are supported for CSV demo."
            col, val = col_val.groups()
            if col not in df.columns:
                return None, f"Column '{col}' not found in CSV."
            column = df[col]
            if pd.api.types.is_numeric_dtype(column):
                # The SQL literal arrives as text; compare it as a number so
                # "2" matches 2. Unparseable text becomes NaN and matches nothing.
                df = df[column == pd.to_numeric(val, errors='coerce')]
            else:
                df = df[column.astype(str) == val]
        return df, None
    except Exception as e:
        return None, f"Error running SQL on CSV: {str(e)}"
78
-
79
- # --- Gradio Interface Logic ---
80
def process(schema_file, question, csv_file):
    """Handle one click of the UI button.

    Args:
        schema_file: Uploaded ``.sql`` file; either an object exposing a
            ``.name`` path or a plain path string.
        question: Natural-language question; may be empty.
        csv_file: Optional CSV upload used to execute the generated query.

    Returns:
        A 4-tuple ``(summary, sql_query, dataframe, error)`` matching the
        four output components; unused slots are ``None``.
    """
    if not schema_file:
        return "Please upload a .sql schema file.", None, None, None

    # Accept both an upload wrapper with ``.name`` and a bare path string
    # (presumably what the UI passes depends on the Gradio version - verify).
    schema_path = getattr(schema_file, 'name', schema_file)
    try:
        with open(schema_path, 'r', encoding='utf-8') as f:
            sql_text = f.read()
    except (OSError, UnicodeDecodeError) as e:
        # Report an unreadable upload in the error box instead of crashing.
        return None, None, None, f"Could not read schema file: {e}"

    schema = parse_sql_schema(sql_text)
    summary = generate_schema_summary(schema)
    if not question:
        return summary, None, None, None

    sql_query = generate_sql_query(question, schema)
    if sql_query.startswith('ERROR:') or not csv_file:
        return summary, sql_query, None, None

    df, err = run_sql_on_csv(sql_query, csv_file)
    if err:
        return summary, sql_query, None, err
    return summary, sql_query, df, None
98
-
99
# Gradio UI: two file uploads (.sql schema, optional .csv), a question box and
# a button; clicking the button calls process() and fills the four outputs
# (schema summary, generated SQL, results table, error). demo.launch() starts it.
# NOTE(review): indentation is lost in this diff view, so which components sit
# inside gr.Row() cannot be told from here - confirm against the real file.
- with gr.Blocks() as demo:
100
- gr.Markdown("""
101
- # AI SQL Assistant (Hugging Face Demo)
102
- Upload a `.sql` schema file (no MySQL needed!)
103
- See a summary of your schema
104
- Ask a question in natural language
105
- Get the generated SQL query
106
- (Optional) Upload a CSV file to run the query and see results as a table
107
- """)
108
- with gr.Row():
109
- schema_file = gr.File(label="Upload .sql Schema File", file_types=[".sql"])
110
- csv_file = gr.File(label="(Optional) Upload CSV Data File", file_types=[".csv"])
111
- question = gr.Textbox(label="Ask a Question (e.g., Show all patients admitted in July 2025)")
112
- btn = gr.Button("Generate SQL and Query Table")
113
- summary = gr.Textbox(label="Schema Summary", interactive=False)
114
- sql_query = gr.Textbox(label="Generated SQL Query", interactive=False)
115
- table = gr.Dataframe(label="Query Results Table (if CSV provided)")
116
- error = gr.Textbox(label="Error", interactive=False)
117
- btn.click(process, inputs=[schema_file, question, csv_file], outputs=[summary, sql_query, table, error])
118
-
119
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, render_template
2
+ import os
3
+ import json
4
+ import re
5
+ import logging
6
+ import tkinter as tk
7
+ from tkinter import simpledialog, messagebox
8
+ import importlib.util
9
+ from groq import Groq
10
+
11
+ app = Flask(__name__)
12
+
13
+ # Configure logging
14
+ logging.basicConfig(filename='AI_SQL_Assistant.log', level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Prompt for Groq API key using Tkinter
18
def get_groq_api_key():
    """Return the Groq API key, prompting for it in a Tk dialog when unset.

    When ``GROQ_API_KEY`` is already in the environment it is returned as is.
    Otherwise a masked input dialog asks for the key; an entered key is
    exported to the environment and saved to ``groq_config.json``.

    Returns:
        The API key string.

    Raises:
        SystemExit: With status 1 when no key is entered or no dialog can be
            shown (for example on a machine without a display).
    """
    groq_api_key = os.getenv('GROQ_API_KEY')
    if groq_api_key:
        return groq_api_key

    # GUI modules are only needed on this interactive fallback path.
    import tkinter as tk
    from tkinter import messagebox, simpledialog

    try:
        root = tk.Tk()
    except tk.TclError as e:
        logger.error("Cannot prompt for Groq API key (no display available): %s", e)
        raise SystemExit(1) from e

    try:
        root.withdraw()  # only the dialogs should be visible, not the root window
        groq_api_key = simpledialog.askstring(
            "Groq API Key",
            "Enter your Groq API key (get from https://console.groq.com):",
            show='*',
        )
        if not groq_api_key:
            messagebox.showerror("Error", "Groq API key is required to run the app.")
            logger.error("No Groq API key provided")
            raise SystemExit(1)
        os.environ['GROQ_API_KEY'] = groq_api_key
        with open('groq_config.json', 'w') as f:
            json.dump({'GROQ_API_KEY': groq_api_key}, f)
    finally:
        # Release the hidden root window on every path, including the exit above.
        root.destroy()
    return groq_api_key
34
+
35
# Load a previously saved Groq API key. An explicitly exported GROQ_API_KEY
# wins over the file, and a missing/empty entry never clobbers the environment.
if not os.environ.get('GROQ_API_KEY') and os.path.exists('groq_config.json'):
    try:
        with open('groq_config.json', 'r') as f:
            config = json.load(f)
        saved_key = config.get('GROQ_API_KEY')
        if saved_key:
            os.environ['GROQ_API_KEY'] = saved_key
    except (OSError, ValueError, AttributeError, TypeError) as e:
        # Best effort: a broken config file just means the key is asked for again.
        logger.error("Failed to load groq_config.json: %s", str(e))

try:
    groq_client = Groq(api_key=get_groq_api_key())
    logger.info("Groq client initialized successfully")
except Exception as e:
    groq_client = None
    logger.error("Failed to initialize Groq client: %s", str(e))
    try:
        messagebox.showerror("Error", f"Failed to initialize Groq client: {str(e)}")
    except Exception as dialog_error:
        # The dialog itself can fail (e.g. no display); the log entry above
        # already records the real cause, so just note it and exit.
        logger.error("Could not show error dialog: %s", dialog_error)
    raise SystemExit(1)

# Storage for current schema, summary, and mock data.
# These are process-wide globals shared by every request.
current_schema = {}       # table name -> list of column names
current_summary = {}
current_db_name = None    # set by parse_sql_file()
results = None            # rows from the last executed mock query
generated_query = None    # last SQL text returned by the model
mock_data = {}            # table name -> list of generated sample rows
60
+
61
# CREATE TABLE header up to the opening parenthesis of the column list.
_TABLE_HEADER_RE = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`'\"]?(\w+)[`'\"]?\s*\(",
    re.IGNORECASE,
)

# Leading identifier of one column definition, with its optional quote.
_IDENTIFIER_RE = re.compile(r"([`'\"]?)(\w+)")

# Unquoted leading words that start a table constraint, not a column.
_NON_COLUMN_KEYWORDS = frozenset({
    'PRIMARY', 'FOREIGN', 'UNIQUE', 'KEY', 'INDEX',
    'CONSTRAINT', 'CHECK', 'FULLTEXT', 'SPATIAL',
})


def _split_sql_statements(text):
    """Split SQL text into ';'-terminated statements, dropping comments."""
    # Remove /* ... */ comments wherever they start or end, including ones
    # that open and close on the same line.
    text = re.sub(r"/\*.*?\*/", " ", text, flags=re.DOTALL)
    statements = []
    parts = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith('--'):
            continue
        parts.append(line)
        if line.endswith(';'):
            statements.append(' '.join(parts))
            parts = []
    return statements


def _extract_columns(statement):
    """Return the column names declared in one CREATE TABLE statement."""
    start = statement.find('(')
    if start == -1:
        return []

    columns = []
    piece = []

    def flush():
        match = _IDENTIFIER_RE.match(''.join(piece).strip())
        piece.clear()
        if not match:
            return
        quote, name = match.groups()
        # Quoted identifiers are always columns; unquoted constraint keywords are not.
        if quote or name.upper() not in _NON_COLUMN_KEYWORDS:
            columns.append(name)

    depth = 0
    quote_char = None
    for ch in statement[start + 1:]:
        if quote_char is not None:
            if ch == quote_char:
                quote_char = None
        elif ch in ("'", '"'):
            quote_char = ch
        elif ch == '(':
            depth += 1
        elif ch == ')':
            if depth == 0:
                break  # end of the column list
            depth -= 1
        elif ch == ',' and depth == 0:
            # Only top-level commas separate definitions; DECIMAL(10,2) does not.
            flush()
            continue
        piece.append(ch)
    flush()
    return columns


def parse_sql_file(file_content):
    """Parse an uploaded SQL file into a schema plus generated mock rows.

    Updates the module globals ``current_db_name``, ``current_schema`` and
    ``mock_data``. The database name comes from a ``CREATE DATABASE``
    statement when present, otherwise a random ``temp_db_<hex>`` name is used.
    Every table receives three mock rows of the form ``val<i>_<column>``.

    Args:
        file_content: The file contents as ``str`` or UTF-8 encoded ``bytes``.

    Returns:
        ``(True, schema, None)`` where ``schema`` maps table name to a list
        of column names.
    """
    global current_db_name, current_schema, mock_data
    # Imported here because the fallback database name needs it and the
    # module's import block does not provide it.
    import uuid

    if isinstance(file_content, bytes):
        file_content = file_content.decode('utf-8')

    db_name_match = re.search(
        r"CREATE\s+DATABASE\s+(?:IF\s+NOT\s+EXISTS\s+)?[`']?(\w+)[`']?",
        file_content,
        re.IGNORECASE,
    )
    current_db_name = db_name_match.group(1) if db_name_match else f"temp_db_{uuid.uuid4().hex[:8]}"
    logger.info("Parsed SQL file: database name=%s", current_db_name)

    current_schema = {}
    mock_data = {}
    for statement in _split_sql_statements(file_content):
        table_match = _TABLE_HEADER_RE.match(statement)
        if not table_match:
            continue
        table_name = table_match.group(1)
        columns = _extract_columns(statement)
        current_schema[table_name] = columns
        # Mock some data (3 rows with sample values)
        mock_data[table_name] = [[f"val{i}_{col}" for col in columns] for i in range(3)]
    logger.info("Extracted schema: %s", current_schema)
    return True, current_schema, None
102
+
103
def generate_sql_query(question, schema):
    """Generate a MySQL query for ``question`` with the Groq chat model.

    Args:
        question: The user's natural-language question.
        schema: dict mapping table name to a list of column names.

    Returns:
        The SQL text with Markdown code fences removed, or a string starting
        with ``'ERROR:'`` when the client is missing, the question is blank,
        the model returns nothing, or the API call fails. The prompt also
        instructs the model to answer with an ``ERROR:`` message for unknown
        tables or columns.
    """
    if not groq_client:
        logger.error("Groq client not initialized")
        return "ERROR: Groq client not initialized. Check API key and try again."
    if not question or not question.strip():
        # Nothing to translate; avoid a pointless API round trip.
        logger.error("Empty question received")
        return "ERROR: Please enter a question."

    schema_text = "\n".join(
        f"Table: {table}\nColumns: {', '.join(columns)}"
        for table, columns in schema.items()
    )
    instructions = (
        "You are a SQL expert. Based on the following database schema from an uploaded .sql file, "
        "generate a valid MySQL query for the user's question. "
        "Only use tables and columns that exist in the schema. "
        "Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', "
        "'admission_date' becomes 'Admission Date'). "
        "Return ONLY the SQL query, without explanations, markdown, or code block formatting "
        "(e.g., no ```). "
        "If the question references non-existent tables or columns, return an error message "
        "starting with 'ERROR:'. "
        "Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question "
        "explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', "
        "'count of patients'). "
        "Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless "
        "aggregation is clearly specified. "
        "For names, concatenate first_name and last_name if applicable "
        "(e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). "
        "Use direct JOINs with correct foreign key relationships if implied "
        "(e.g., table_id columns). "
        "Avoid subqueries unless absolutely necessary. "
        "Place filtering conditions (e.g., department name, status) in the WHERE clause, "
        "not JOIN clauses. "
        "Handle case sensitivity in string comparisons by using LOWER() for status fields "
        "(e.g., LOWER(status) = 'unpaid'). "
        "Verify table relationships before joining."
    )
    prompt = f"\n{instructions}\n\nSchema:\n{schema_text}\n\nUser Question: {question}\n"
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-70b-8192"
        )
        content = response.choices[0].message.content or ""
        # Drop ``` / ```sql / ```SQL fences the model may add despite the prompt.
        query = re.sub(r'```(?:sql)?\n?', '', content, flags=re.IGNORECASE).strip()
        if not query:
            logger.error("Model returned an empty response")
            return "ERROR: The model returned an empty response."
        logger.info("Generated SQL query: %s", query[:100])
        return query
    except Exception as e:
        logger.error("Failed to generate SQL query: %s", str(e))
        return f"ERROR: Failed to generate SQL query: {str(e)}"
130
+
131
def execute_mock_query(query):
    """Simulate running ``query`` by returning the mock rows of its table.

    The query is not actually evaluated: the first table named after a
    ``FROM`` keyword is looked up in ``mock_data`` and all of its rows are
    returned. The lookup falls back to a case-insensitive match so that
    ``FROM patients`` still finds a table declared as ``Patients``.

    Args:
        query: SQL text, normally produced by ``generate_sql_query``.

    Returns:
        ``(True, rows, None)`` on success or ``(False, message, None)`` when
        nothing is loaded, the table is unknown, or the query is unsupported.
    """
    if not current_schema or not mock_data:
        logger.error("No schema or mock data available")
        return False, "No data loaded. Please upload an .sql file first.", None
    try:
        table_match = re.search(r"\bFROM\s+[`'\"]?(\w+)[`'\"]?", query, re.IGNORECASE)
        if table_match:
            requested = table_match.group(1)
            if requested in mock_data:
                return True, mock_data[requested], None
            wanted = requested.lower()
            for table_name, rows in mock_data.items():
                if table_name.lower() == wanted:
                    return True, rows, None
        return False, "Mock query execution failed: Table not found or query not supported.", None
    except Exception as e:
        # e.g. a non-string query makes re.search raise TypeError
        logger.error("Mock query execution failed: %s", str(e))
        return False, f"Mock query execution failed: {str(e)}", None
145
+
146
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the single page: upload a schema, ask a question, show results.

    A POST carrying ``sql_file`` parses the upload into the global schema;
    a POST carrying ``question`` generates SQL and runs it against the mock
    data. Outcome and errors are rendered through ``index.html``.

    State lives in module globals, so it is shared by every client of this
    process.
    """
    global current_schema, current_summary, results, generated_query
    error = None

    if not groq_client:
        error = "Groq client not initialized. Please restart the app and enter a valid Groq API key."
        logger.error(error)

    if request.method == 'POST':
        logger.info("Received POST request")
        if 'sql_file' in request.files:
            file = request.files['sql_file']
            logger.info("SQL file upload detected: %s", file.filename if file else "No file")
            # Accept .SQL as well as .sql.
            if file and file.filename and file.filename.lower().endswith('.sql'):
                success, schema, _ = parse_sql_file(file.read())
                if success:
                    current_schema = schema
                    # Output from the previous schema no longer applies.
                    results = None
                    generated_query = None
                    logger.info("SQL file parsed successfully")
                else:
                    error = "Failed to parse SQL file."
                    logger.error(error)
            else:
                error = "Please upload a valid .sql file."
                logger.error(error)
        elif 'question' in request.form:
            question = request.form['question']
            logger.info("Received question: %s", question)
            # Clear the last answer first so a failure below is not rendered
            # next to rows that belong to an earlier question.
            results = None
            generated_query = None
            if not current_schema:
                error = "No schema loaded. Please upload an .sql file first."
                logger.error(error)
            else:
                generated_query = generate_sql_query(question, current_schema)
                if generated_query.startswith('ERROR:'):
                    error = generated_query
                    logger.error(error)
                else:
                    success, result, _ = execute_mock_query(generated_query)
                    if success:
                        results = result
                        logger.info("Mock query executed successfully, results: %d rows", len(result))
                    else:
                        error = result
                        logger.error(error)

    logger.info("Rendering index.html: error=%s, schema=%s, summary=%s, results=%s",
                error, bool(current_schema), bool(current_summary), bool(results))
    return render_template('index.html', error=error, schema=current_schema,
                           summary=current_summary, results=results, query=generated_query)
194
+
195
if __name__ == '__main__':
    # Standard-library module, only needed when the app is started as a script.
    import webbrowser

    url = 'http://localhost:7860'
    try:
        # open() reports failure by returning False rather than raising.
        opened = webbrowser.open(url)
    except (AttributeError, webbrowser.Error):
        opened = False
    if not opened:
        logger.error("webbrowser.open() failed, possibly due to environment issue")
        print(f"Warning: Could not open browser automatically. Please navigate to {url} manually.")
    app.run(host='0.0.0.0', port=7860, debug=False)