Al1Abdullah committed on
Commit
55c6e2b
·
verified ·
1 Parent(s): eaf5245

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -405
app.py CHANGED
@@ -1,423 +1,123 @@
1
- from flask import Flask, request, render_template, jsonify, session
2
- import mysql.connector
3
- from mysql.connector import Error
4
- import os
5
- from groq import Groq
6
- from dotenv import load_dotenv
7
  import re
8
- import uuid
9
- import logging
10
-
11
- app = Flask(__name__)
12
- app.secret_key = os.urandom(24) # Required for session management
13
- load_dotenv()
14
-
15
- # Configure logging
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
- # Default database configuration from .env
20
- default_db_config = {
21
- 'host': os.getenv('DB_HOST'),
22
- 'user': os.getenv('DB_USER'),
23
- 'password': os.getenv('DB_PASSWORD'),
24
- 'port': int(os.getenv('DB_PORT', 4000)) # Default to TiDB port
25
- }
26
-
27
- # Validate default config at startup
28
- if not all([default_db_config['host'], default_db_config['user'], default_db_config['password']]):
29
- logger.error("Incomplete default MySQL credentials in Secrets: host=%s, user=%s, password=%s, port=%s",
30
- default_db_config['host'], default_db_config['user'], '***' if default_db_config['password'] else None,
31
- default_db_config['port'])
32
- else:
33
- logger.info("Default MySQL credentials loaded successfully: host=%s, user=%s, port=%s",
34
- default_db_config['host'], default_db_config['user'], default_db_config['port'])
35
-
36
- # Groq API configuration with error handling
37
- try:
38
- groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
39
- logger.info("Groq client initialized successfully")
40
- except Exception as e:
41
- groq_client = None
42
- logger.error("Failed to initialize Groq client: %s", str(e))
43
-
44
- # Temporary storage for current database name and schema
45
- current_db_name = None
46
- current_schema = {}
47
- current_summary = {}
48
-
49
- def get_db_connection(db_name=None):
50
- """Establish a database connection using default config unless session config is explicitly set."""
51
- # Use session config only if explicitly set and complete
52
- if 'db_config' in session and all([session['db_config'].get('host'), session['db_config'].get('user'), session['db_config'].get('password')]):
53
- config = session['db_config'].copy()
54
- logger.info("Using session config for DB connection: host=%s, user=%s, port=%s",
55
- config['host'], config['user'], config['port'])
56
- else:
57
- config = default_db_config.copy()
58
- logger.info("Using default config for DB connection: host=%s, user=%s, port=%s",
59
- config['host'], config['user'], config['port'])
60
-
61
- if not all([config.get('host'), config.get('user'), config.get('password')]):
62
- logger.error("No valid MySQL credentials: host=%s, user=%s, password=%s",
63
- config.get('host'), config.get('user'), '***' if config.get('password') else None)
64
- return None, "No valid MySQL credentials provided. Ensure Secrets (DB_HOST, DB_USER, DB_PASSWORD, DB_PORT) are set in Hugging Face Space settings or configure a custom connection."
65
-
66
- if db_name:
67
- config['database'] = db_name
68
- try:
69
- conn = mysql.connector.connect(**config)
70
- logger.info("Database connection successful%s", f" to {db_name}" if db_name else "")
71
- return conn, None
72
- except Error as e:
73
- logger.error("Database connection failed: %s", str(e))
74
- return None, f"Database connection failed: {str(e)}. Verify credentials, ensure the MySQL/TiDB server is running, and check network settings (e.g., IP whitelist, TLS)."
75
-
76
- def parse_sql_file(file_content):
77
- """Parse SQL file to extract database name and clean statements."""
78
- file_content = file_content.decode('utf-8') if isinstance(file_content, bytes) else file_content
79
- statements = []
80
- current_statement = ""
81
- in_comment = False
82
-
83
- # Extract database name
84
- db_name_match = re.search(r"CREATE\s+DATABASE\s+[`']?(\w+)[`']?", file_content, re.IGNORECASE)
85
- db_name = db_name_match.group(1) if db_name_match else f"temp_db_{uuid.uuid4().hex[:8]}"
86
- logger.info("Parsed SQL file: database name=%s", db_name)
87
 
88
- # Split SQL into statements
89
- for line in file_content.splitlines():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  line = line.strip()
91
- if not line or line.startswith('--'):
92
- continue
93
- if line.startswith('/*'):
94
- in_comment = True
95
- continue
96
- if line.endswith('*/'):
97
- in_comment = False
98
- continue
99
- if not in_comment:
100
- current_statement += line + ' '
101
- if line.endswith(';'):
102
- statements.append(current_statement.strip())
103
- current_statement = ""
104
-
105
- return db_name, statements
106
-
107
- def generate_schema_summary(schema, db_name):
108
- """Generate a dynamic summary of any MySQL database schema."""
109
- logger.info("Generating schema summary for database: %s", db_name)
110
- summary = {
111
- 'description': '',
112
- 'main_tables': {},
113
- 'relationships': [],
114
- 'suggestions': {
115
- 'evaluation': 'Good',
116
- 'note': '',
117
- 'recommendations': []
118
- }
119
- }
120
-
121
- # Infer description based on table names
122
- table_names = list(schema.keys())
123
- if any(table in table_names for table in ['patient', 'doctor', 'admission', 'appointment']):
124
- summary['description'] = f"{db_name} appears to be a Hospital Management Database for tracking entities like patients, staff, and appointments."
125
- elif any(table in table_names for table in ['customer', 'order', 'product', 'employee']):
126
- summary['description'] = f"{db_name} appears to be a Retail or E-commerce Database for managing customers, orders, and products."
127
- elif any(table in table_names for table in ['book', 'author', 'loan', 'member']):
128
- summary['description'] = f"{db_name} appears to be a Library Management Database for tracking books, authors, and loans."
129
- else:
130
- summary['description'] = f"{db_name} is a database with {len(table_names)} tables for managing various entities."
131
-
132
- # Select main tables (up to 5, prioritized by column count or presence of 'id')
133
- sorted_tables = sorted(schema.items(), key=lambda x: len(x[1]), reverse=True)[:5]
134
- for table, columns in sorted_tables:
135
- key_columns = [col for col in columns if 'id' in col.lower() or col in ['name', 'first_name', 'last_name', 'title', 'amount', 'status', 'price']]
136
- summary['main_tables'][table] = key_columns[:3] # Limit to 3 key columns
137
-
138
- # Connect to database to detect relationships and suggestions
139
- conn, error = get_db_connection(db_name)
140
- if conn:
141
- cursor = conn.cursor()
142
- try:
143
- # Detect foreign keys using INFORMATION_SCHEMA
144
- cursor.execute("""
145
- SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
146
- FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
147
- WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL
148
- """, (db_name,))
149
- relationships = cursor.fetchall()
150
- for rel in relationships[:5]: # Limit to 5 relationships
151
- summary['relationships'].append(f"{rel[0]} links to {rel[2]} via {rel[1]}")
152
-
153
- # Fallback: Infer relationships from common column names
154
- if not relationships:
155
- for table1, columns1 in schema.items():
156
- for col1 in columns1:
157
- if '_id' in col1 and col1 != f"{table1}_id":
158
- target_table = col1.replace('_id', '')
159
- if target_table in schema:
160
- summary['relationships'].append(f"{table1} likely links to {target_table} via {col1}")
161
-
162
- # Check for indexes and constraints
163
- cursor.execute("""
164
- SELECT TABLE_NAME, NON_UNIQUE, INDEX_NAME
165
- FROM INFORMATION_SCHEMA.STATISTICS
166
- WHERE TABLE_SCHEMA = %s AND INDEX_NAME != 'PRIMARY'
167
- """, (db_name,))
168
- indexes = cursor.fetchall()
169
- indexed_columns = set(row[0] + '.' + row[2] for row in indexes if row[1] == 0)
170
-
171
- # Evaluate schema
172
- has_foreign_keys = bool(relationships)
173
- has_indexes = bool(indexes)
174
- if has_foreign_keys and has_indexes:
175
- summary['suggestions']['evaluation'] = 'Excellent'
176
- summary['suggestions']['note'] = 'The schema is well-structured with defined foreign key constraints and indexes, supporting efficient queries.'
177
- elif has_foreign_keys:
178
- summary['suggestions']['evaluation'] = 'Good'
179
- summary['suggestions']['note'] = 'The schema has clear foreign key relationships but may lack sufficient indexes.'
180
- else:
181
- summary['suggestions']['evaluation'] = 'Needs Improvement'
182
- summary['suggestions']['note'] = 'The schema lacks explicit foreign key constraints, which may affect query reliability.'
183
-
184
- # Recommendations
185
- if not has_foreign_keys:
186
- summary['suggestions']['recommendations'].append('Add explicit foreign key constraints to ensure data integrity.')
187
- if not has_indexes:
188
- summary['suggestions']['recommendations'].append('Add indexes on frequently queried columns (e.g., foreign keys, date fields) to improve performance.')
189
- summary['suggestions']['recommendations'].append('Verify that date and numeric fields use appropriate data types for efficient querying.')
190
-
191
- cursor.close()
192
- conn.close()
193
- except Error as e:
194
- logger.error("Schema summary error: %s", str(e))
195
- summary['suggestions']['note'] = f'Analysis limited due to: {str(e)}'
196
- if not summary['relationships']:
197
- summary['relationships'] = ['Unable to detect relationships due to limited metadata access.']
198
- else:
199
- logger.error("Failed to connect for schema summary: %s", error)
200
- summary['suggestions']['note'] = 'Unable to analyze schema due to connection issues.'
201
-
202
- return summary
203
-
204
- def load_sql_file(file):
205
- """Load SQL file into MySQL database and generate schema summary."""
206
- global current_db_name, current_schema, current_summary
207
- logger.info("Attempting to load SQL file")
208
- try:
209
- file_content = file.read()
210
- if not file_content:
211
- logger.error("Empty SQL file uploaded")
212
- return False, "Uploaded SQL file is empty.", None
213
-
214
- db_name, statements = parse_sql_file(file_content)
215
-
216
- # Connect without specifying a database
217
- conn, error = get_db_connection()
218
- if error:
219
- logger.error("Connection failed in load_sql_file: %s", error)
220
- return False, error, None
221
- cursor = conn.cursor()
222
-
223
- # Drop existing database if it exists
224
- try:
225
- cursor.execute(f"DROP DATABASE IF EXISTS `{db_name}`")
226
- cursor.execute(f"CREATE DATABASE `{db_name}`")
227
- conn.commit()
228
- logger.info("Created database: %s", db_name)
229
- except Error as e:
230
- logger.error("Failed to create database %s: %s", db_name, str(e))
231
- cursor.close()
232
- conn.close()
233
- return False, f"Failed to create database: {str(e)}", None
234
- cursor.close()
235
- conn.close()
236
-
237
- # Connect to the new database
238
- conn, error = get_db_connection(db_name)
239
- if error:
240
- logger.error("Connection to %s failed: %s", db_name, error)
241
- return False, error, None
242
- cursor = conn.cursor()
243
-
244
- # Execute SQL statements
245
- for statement in statements:
246
- try:
247
- cursor.execute(statement)
248
- logger.info("Executed statement: %s", statement[:50])
249
- except Error as e:
250
- logger.error("Failed to execute statement: %s, error: %s", statement[:50], str(e))
251
- cursor.close()
252
- conn.close()
253
- return False, f"Failed to execute SQL statement: {str(e)}", None
254
- conn.commit()
255
-
256
- # Extract schema
257
- cursor.execute("SHOW TABLES")
258
- tables = [row[0] for row in cursor.fetchall()]
259
- schema = {}
260
- for table in tables:
261
- cursor.execute(f"SHOW COLUMNS FROM `{table}`")
262
- columns = [row[0] for row in cursor.fetchall()]
263
- schema[table] = columns
264
- logger.info("Extracted schema for table: %s", table)
265
-
266
- # Generate summary
267
- summary = generate_schema_summary(schema, db_name)
268
-
269
- current_db_name = db_name
270
- current_schema = schema
271
- current_summary = summary
272
-
273
- cursor.close()
274
- conn.close()
275
- logger.info("SQL file loaded successfully: %s", db_name)
276
- return True, schema, summary
277
- except Exception as e:
278
- logger.error("Unexpected error in load_sql_file: %s", str(e))
279
- return False, f"Unexpected error while loading SQL file: {str(e)}", None
280
-
281
  def generate_sql_query(question, schema):
282
- """Generate SQL query using Groq API with user-friendly aliases."""
283
  if not groq_client:
284
- logger.error("Groq client not initialized")
285
- return "ERROR: Groq client not initialized. Check API key and try again."
286
-
287
  schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
288
  prompt = f"""
289
- You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
290
-
291
- Schema:
292
- {schema_text}
293
-
294
- User Question: {question}
295
- """
296
  try:
297
  response = groq_client.chat.completions.create(
298
  messages=[{"role": "user", "content": prompt}],
299
  model="llama3-70b-8192"
300
  )
301
  query = response.choices[0].message.content.strip()
302
- query = re.sub(r'```(?:sql)?\n?', '', query) # Remove any markdown
303
- query = query.strip()
304
- logger.info("Generated SQL query: %s", query[:100])
305
  return query
306
  except Exception as e:
307
- logger.error("Failed to generate SQL query: %s", str(e))
308
  return f"ERROR: Failed to generate SQL query: {str(e)}"
309
 
310
- def execute_sql_query(query):
311
- """Execute SQL query on the current database."""
312
- if not current_db_name:
313
- logger.error("No database loaded for query execution")
314
- return False, "No database loaded. Please upload an SQL file.", None
315
- conn, error = get_db_connection(current_db_name)
316
- if error:
317
- logger.error("Connection failed for query execution: %s", error)
318
- return False, error, None
319
  try:
320
- cursor = conn.cursor(dictionary=True)
321
- cursor.execute(query)
322
- results = cursor.fetchall()
323
- conn.commit()
324
- logger.info("Query executed successfully: %s", query[:50])
325
- cursor.close()
326
- conn.close()
327
- return True, results, None
328
- except Error as e:
329
- logger.error("SQL execution failed: %s", str(e))
330
- return False, f"SQL execution failed: {str(e)}", None
331
-
332
- @app.route('/', methods=['GET', 'POST'])
333
- def index():
334
- error = None
335
- schema = current_schema
336
- summary = current_summary
337
- results = None
338
- generated_query = None
339
-
340
- if not groq_client:
341
- error = "Groq client not initialized. Please check GROQ_API_KEY and restart the app."
342
- logger.error(error)
343
-
344
- if request.method == 'POST':
345
- logger.info("Received POST request")
346
- if 'sql_file' in request.files:
347
- file = request.files['sql_file']
348
- logger.info("SQL file upload detected: %s", file.filename if file else "No file")
349
- # Clear session config to ensure default TiDB backend is used for uploads
350
- session.pop('db_config', None)
351
- if file and file.filename.endswith('.sql'):
352
- success, result, summary = load_sql_file(file)
353
- if success:
354
- schema = result
355
- logger.info("SQL file loaded successfully")
356
- else:
357
- error = result
358
- logger.error("Failed to load SQL file: %s", error)
359
- else:
360
- error = "Please upload a valid .sql file."
361
- logger.error(error)
362
- elif 'question' in request.form:
363
- question = request.form['question']
364
- logger.info("Received question: %s", question)
365
- if not current_db_name or not current_schema:
366
- error = "No database loaded. Please upload an SQL file first."
367
- logger.error(error)
368
- else:
369
- generated_query = generate_sql_query(question, current_schema)
370
- if not generated_query.startswith('ERROR:'):
371
- success, result, _ = execute_sql_query(generated_query)
372
- if success:
373
- results = result
374
- logger.info("Query executed successfully, results: %d rows", len(results))
375
- else:
376
- error = result
377
- logger.error("Query execution failed: %s", error)
378
- else:
379
- error = generated_query
380
- logger.error("Query generation failed: %s", error)
381
-
382
- logger.info("Rendering index.html: error=%s, schema=%s, summary=%s, results=%s",
383
- error, bool(schema), bool(summary), bool(results))
384
- return render_template('index.html', error=error, schema=schema, summary=summary, results=results, query=generated_query)
385
-
386
- @app.route('/configure_db', methods=['POST'])
387
- def configure_db():
388
- """Handle MySQL connection configuration."""
389
- logger.info("Received configure_db request")
390
- host = request.form.get('host', '').strip()
391
- user = request.form.get('user', '').strip()
392
- password = request.form.get('password', '')
393
- port = request.form.get('port', '4000').strip()
394
-
395
- if not host or not user:
396
- logger.error("Missing host or user in configure_db")
397
- return render_template('index.html', error="Host and user are required for custom MySQL configuration.",
398
- schema=current_schema, summary=current_summary)
399
-
400
- try:
401
- port = int(port)
402
- except ValueError:
403
- logger.error("Invalid port number: %s", port)
404
- return render_template('index.html', error="Port must be a valid number.",
405
- schema=current_schema, summary=current_summary)
406
-
407
- # Test connection
408
- test_config = {'host': host, 'user': user, 'password': password, 'port': port}
409
- conn, error = get_db_connection()
410
- if error:
411
- logger.error("Test connection failed in configure_db: %s", error)
412
- return render_template('index.html', error=error, schema=current_schema, summary=current_summary)
413
-
414
- # Store in session
415
- session['db_config'] = test_config
416
- conn.close()
417
- logger.info("Custom MySQL connection configured: host=%s, user=%s, port=%s",
418
- host, user, port)
419
- return render_template('index.html', error=None, schema=current_schema, summary=current_summary,
420
- success="Custom MySQL connection configured successfully. You can now upload .sql files and query your database.")
421
-
422
- if __name__ == '__main__':
423
- app.run(host='0.0.0.0', port=int(os.getenv('PORT', 7860)), debug=False)
 
1
+ import gradio as gr
2
+ import pandas as pd
 
 
 
 
3
  import re
4
+ import json
5
+ from groq import Groq
6
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # --- Groq API Key Handling ---
9
+ GROQ_CONFIG = 'groq_config.json'
10
+ def get_groq_api_key():
11
+ if os.path.exists(GROQ_CONFIG):
12
+ with open(GROQ_CONFIG, 'r') as f:
13
+ config = json.load(f)
14
+ return config.get('GROQ_API_KEY', '')
15
+ return ''
16
+
17
+ groq_api_key = get_groq_api_key()
18
+ groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
19
+
20
+ # --- Schema Parsing and Summarization ---
21
+ def parse_sql_schema(sql_text):
22
+ tables = {}
23
+ current_table = None
24
+ for line in sql_text.splitlines():
25
  line = line.strip()
26
+ if line.lower().startswith('create table'):
27
+ match = re.search(r'CREATE TABLE [`"]?(\w+)[`"]?', line, re.IGNORECASE)
28
+ if match:
29
+ current_table = match.group(1)
30
+ tables[current_table] = []
31
+ elif current_table and line and not line.startswith('--') and not line.startswith(')'):
32
+ col_match = re.match(r'[`"]?(\w+)[`"]?\s', line)
33
+ if col_match:
34
+ col = col_match.group(1)
35
+ tables[current_table].append(col)
36
+ elif line.startswith(')'):
37
+ current_table = None
38
+ return tables
39
+
40
+ def generate_schema_summary(schema):
41
+ summary = ''
42
+ for table, columns in schema.items():
43
+ summary += f"Table: {table}\n Columns: {', '.join(columns)}\n"
44
+ return summary or 'No tables found.'
45
+
46
+ # --- SQL Generation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def generate_sql_query(question, schema):
 
48
  if not groq_client:
49
+ return "ERROR: Groq client not initialized. Please set your API key in groq_config.json."
 
 
50
  schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
51
  prompt = f"""
52
+ You are a SQL expert. Based on the following database schema, generate a valid SQL query for the user's question. Only use tables and columns that exist in the schema. Return ONLY the SQL query, without explanations, markdown, or code block formatting.\n\nSchema:\n{schema_text}\n\nUser Question: {question}\n"""
 
 
 
 
 
 
53
  try:
54
  response = groq_client.chat.completions.create(
55
  messages=[{"role": "user", "content": prompt}],
56
  model="llama3-70b-8192"
57
  )
58
  query = response.choices[0].message.content.strip()
59
+ query = re.sub(r'```(?:sql)?\n?', '', query)
 
 
60
  return query
61
  except Exception as e:
 
62
  return f"ERROR: Failed to generate SQL query: {str(e)}"
63
 
64
+ # --- SQL Execution on CSV (optional demo) ---
65
+ def run_sql_on_csv(sql_query, csv_file):
 
 
 
 
 
 
 
66
  try:
67
+ df = pd.read_csv(csv_file)
68
+ # Only support SELECT * FROM table [WHERE ...] for demo
69
+ match = re.match(r'SELECT \* FROM (\w+)(?: WHERE (.*))?;', sql_query, re.IGNORECASE)
70
+ if not match:
71
+ return None, "Only simple SELECT * FROM ... queries are supported for CSV demo."
72
+ # Ignore table name, just use df
73
+ where_clause = match.group(2)
74
+ if where_clause:
75
+ # Very basic filter: col = value
76
+ col_val = re.match(r'(\w+) *= *["\']?(.*?)["\']?$', where_clause)
77
+ if col_val:
78
+ col, val = col_val.groups()
79
+ df = df[df[col] == val]
80
+ return df, None
81
+ except Exception as e:
82
+ return None, f"Error running SQL on CSV: {str(e)}"
83
+
84
+ # --- Gradio Interface Logic ---
85
+ def process(schema_file, question, csv_file):
86
+ if not schema_file:
87
+ return "Please upload a .sql schema file.", None, None, None
88
+ sql_text = schema_file.read().decode('utf-8')
89
+ schema = parse_sql_schema(sql_text)
90
+ summary = generate_schema_summary(schema)
91
+ if not question:
92
+ return summary, None, None, None
93
+ sql_query = generate_sql_query(question, schema)
94
+ if sql_query.startswith('ERROR:'):
95
+ return summary, sql_query, None, None
96
+ if csv_file:
97
+ df, err = run_sql_on_csv(sql_query, csv_file)
98
+ if err:
99
+ return summary, sql_query, None, err
100
+ return summary, sql_query, df, None
101
+ return summary, sql_query, None, None
102
+
103
+ with gr.Blocks() as demo:
104
+ gr.Markdown("""
105
+ # AI SQL Assistant (Hugging Face Demo)
106
+ - Upload a `.sql` schema file (no MySQL needed!)
107
+ - See a summary of your schema
108
+ - Ask a question in natural language
109
+ - Get the generated SQL query
110
+ - (Optional) Upload a CSV file to run the query and see results as a table
111
+ """)
112
+ with gr.Row():
113
+ schema_file = gr.File(label="Upload .sql Schema File", file_types=[".sql"])
114
+ csv_file = gr.File(label="(Optional) Upload CSV Data File", file_types=[".csv"])
115
+ question = gr.Textbox(label="Ask a Question (e.g., Show all patients admitted in July 2025)")
116
+ btn = gr.Button("Generate SQL and Query Table")
117
+ summary = gr.Textbox(label="Schema Summary", interactive=False)
118
+ sql_query = gr.Textbox(label="Generated SQL Query", interactive=False)
119
+ table = gr.Dataframe(label="Query Results Table (if CSV provided)")
120
+ error = gr.Textbox(label="Error", interactive=False)
121
+ btn.click(process, inputs=[schema_file, question, csv_file], outputs=[summary, sql_query, table, error])
122
+
123
+ demo.launch()