File size: 20,289 Bytes
9225010
 
 
 
 
 
 
 
e05b320
9225010
 
 
 
 
e05b320
 
 
 
18c53c6
9225010
18c53c6
 
 
 
9225010
 
de187e0
 
e05b320
 
 
 
 
 
de187e0
3874558
 
 
e05b320
3874558
 
e05b320
9225010
 
 
 
 
 
 
e05b320
 
 
 
 
 
 
de187e0
e05b320
 
 
 
 
 
 
 
9225010
 
 
 
e05b320
9225010
 
e05b320
de187e0
9225010
 
 
 
 
 
 
 
 
 
 
e05b320
9225010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e05b320
9225010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89feb1e
9225010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e05b320
9225010
 
 
e05b320
 
 
9225010
 
 
 
 
 
e05b320
9225010
 
e05b320
 
 
 
9225010
 
 
 
 
e05b320
9225010
 
 
 
e05b320
 
 
 
 
 
 
 
 
 
9225010
 
 
 
 
 
e05b320
9225010
 
 
 
 
e05b320
 
 
 
 
 
 
 
9225010
 
 
 
 
 
 
 
 
 
e05b320
9225010
 
 
 
 
 
 
 
 
 
e05b320
9225010
e05b320
 
 
9225010
 
 
3874558
e05b320
3874558
 
9225010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e05b320
9225010
 
e05b320
9225010
 
 
 
 
e05b320
9225010
 
 
e05b320
9225010
 
 
 
 
 
e05b320
9225010
 
 
 
e05b320
9225010
 
 
 
 
 
 
 
 
 
3874558
 
e05b320
3874558
9225010
e05b320
9225010
 
e05b320
 
 
9225010
 
 
 
e05b320
9225010
 
e05b320
9225010
 
e05b320
9225010
 
e05b320
9225010
 
e05b320
9225010
 
 
 
 
 
e05b320
9225010
 
e05b320
9225010
 
e05b320
9225010
e05b320
 
9225010
 
 
 
 
e05b320
9225010
 
 
18c53c6
9225010
 
e05b320
 
 
9225010
 
 
 
e05b320
 
 
9225010
 
 
 
 
e05b320
9225010
 
 
 
 
e05b320
 
 
 
de187e0
9225010
89feb1e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
from flask import Flask, request, render_template, jsonify, session
import mysql.connector
from mysql.connector import Error
import os
from groq import Groq
from dotenv import load_dotenv
import re
import uuid
import logging

# Load .env first so every later os.getenv() call (including the secret key
# below) sees values defined there.
load_dotenv()

app = Flask(__name__)
# Sessions are signed with this key. Setting FLASK_SECRET_KEY keeps sessions
# valid across restarts and shared between worker processes; when it is unset
# a random per-process key is generated, so sessions are lost on restart and
# are not recognised by other workers.
app.secret_key = os.getenv('FLASK_SECRET_KEY') or os.urandom(24)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default database configuration from .env
default_db_config = {
    'host': os.getenv('DB_HOST'),
    'user': os.getenv('DB_USER'),
    'password': os.getenv('DB_PASSWORD'),
    # `or` instead of a getenv() default also covers DB_PORT being present but
    # blank (e.g. "DB_PORT=" in a .env template), which int('') would reject
    # and crash the app at import time.
    'port': int(os.getenv('DB_PORT') or 4000)  # Default to TiDB port
}

# Validate default config at startup (logged only; the app still starts so a
# custom connection can be configured instead).
if not all([default_db_config['host'], default_db_config['user'], default_db_config['password']]):
    logger.error("Incomplete default MySQL credentials in Secrets: host=%s, user=%s, password=%s, port=%s",
                 default_db_config['host'], default_db_config['user'], '***' if default_db_config['password'] else None,
                 default_db_config['port'])
else:
    logger.info("Default MySQL credentials loaded successfully: host=%s, user=%s, port=%s",
                default_db_config['host'], default_db_config['user'], default_db_config['port'])

# Groq API configuration with error handling. Any construction failure leaves
# groq_client as None; request handlers check for that before generating SQL.
groq_client = None
try:
    groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
except Exception as init_error:
    logger.error("Failed to initialize Groq client: %s", str(init_error))
else:
    logger.info("Groq client initialized successfully")

# Temporary storage for the most recently loaded database: its name, its
# schema (table name -> list of column names) and its generated summary.
# These are process-wide module globals that load_sql_file() overwrites after
# each successful upload.
current_db_name, current_schema, current_summary = None, {}, {}

def get_db_connection(db_name=None):
    """Open a MySQL/TiDB connection, reporting failure as a message instead of raising.

    A configuration stored in ``session['db_config']`` is preferred, but only
    when its host, user and password are all non-empty; otherwise the
    environment-derived ``default_db_config`` is used. The chosen mapping is
    copied before ``database`` is added, so neither source is modified.

    Args:
        db_name: Optional database to select on the new connection.

    Returns:
        ``(connection, None)`` on success, or ``(None, error_message)`` when no
        complete credentials are available or ``mysql.connector.connect``
        raises ``mysql.connector.Error``.
    """
    credential_keys = ('host', 'user', 'password')

    session_is_usable = 'db_config' in session and all(
        session['db_config'].get(key) for key in credential_keys
    )
    if session_is_usable:
        config = dict(session['db_config'])
        logger.info("Using session config for DB connection: host=%s, user=%s, port=%s",
                    config['host'], config['user'], config['port'])
    else:
        config = dict(default_db_config)
        logger.info("Using default config for DB connection: host=%s, user=%s, port=%s",
                    config['host'], config['user'], config['port'])

    has_credentials = all(config.get(key) for key in credential_keys)
    if not has_credentials:
        logger.error("No valid MySQL credentials: host=%s, user=%s, password=%s",
                     config.get('host'), config.get('user'), '***' if config.get('password') else None)
        return None, "No valid MySQL credentials provided. Ensure Secrets (DB_HOST, DB_USER, DB_PASSWORD, DB_PORT) are set in Hugging Face Space settings or configure a custom connection."

    if db_name:
        config['database'] = db_name

    try:
        connection = mysql.connector.connect(**config)
    except Error as exc:
        logger.error("Database connection failed: %s", str(exc))
        return None, f"Database connection failed: {str(exc)}. Verify credentials, ensure the MySQL/TiDB server is running, and check network settings (e.g., IP whitelist, TLS)."

    logger.info("Database connection successful%s", f" to {db_name}" if db_name else "")
    return connection, None

# Captures the target of a CREATE DATABASE / CREATE SCHEMA statement. The
# optional IF NOT EXISTS group matters: without it,
# "CREATE DATABASE IF NOT EXISTS shop" would be read as a database named "IF".
_CREATE_DATABASE_NAME = re.compile(
    r"CREATE\s+(?:DATABASE|SCHEMA)\s+(?:IF\s+NOT\s+EXISTS\s+)?[`'\"]?(\w+)[`'\"]?",
    re.IGNORECASE,
)


def parse_sql_file(file_content):
    """Parse an uploaded SQL script into a database name and a list of statements.

    Args:
        file_content: Script text as ``str`` or UTF-8 ``bytes``. A leading
            UTF-8 byte-order mark is ignored.

    Returns:
        ``(db_name, statements)``. ``db_name`` is taken from the first
        ``CREATE DATABASE``/``CREATE SCHEMA`` statement, or is a generated
        ``temp_db_<8 hex chars>`` name when there is none. ``statements`` lists
        the script's statements in file order.

    A statement ends at a line whose stripped text ends with ``;``. A final
    statement without a terminating ``;`` is still returned. Whole-line ``--``
    and ``#`` comments and lines belonging to ``/* ... */`` comments are
    dropped. The lines of one statement are joined with newlines, so an inline
    ``-- comment`` only hides the rest of its own line. ``DELIMITER``
    directives (stored routines, triggers) are not supported.
    """
    if isinstance(file_content, bytes):
        # 'utf-8-sig' behaves like 'utf-8' but also drops a leading BOM.
        file_content = file_content.decode('utf-8-sig')
    elif file_content.startswith('\ufeff'):
        file_content = file_content[1:]

    # Extract database name
    db_name_match = _CREATE_DATABASE_NAME.search(file_content)
    db_name = db_name_match.group(1) if db_name_match else f"temp_db_{uuid.uuid4().hex[:8]}"
    logger.info("Parsed SQL file: database name=%s", db_name)

    # Split SQL into statements
    statements = []
    pending_lines = []
    in_block_comment = False
    for raw_line in file_content.splitlines():
        line = raw_line.strip()
        if in_block_comment:
            # Inside a multi-line /* ... */ comment: skip until it closes.
            if '*/' in line:
                in_block_comment = False
            continue
        if not line or line.startswith('--') or line.startswith('#'):
            continue
        if line.startswith('/*'):
            # A comment that also closes on this line (e.g. mysqldump's
            # "/*!40101 SET NAMES utf8mb4 */;") must not switch the parser
            # into comment mode, or every following statement would be lost.
            if '*/' not in line[2:]:
                in_block_comment = True
            continue
        pending_lines.append(line)
        if line.endswith(';'):
            statements.append('\n'.join(pending_lines))
            pending_lines = []

    if pending_lines:
        # Last statement had no terminating semicolon.
        statements.append('\n'.join(pending_lines))

    return db_name, statements

def generate_schema_summary(schema, db_name):
    """Build a heuristic, human-readable summary of a loaded database schema.

    The description and main tables are derived from *schema* alone.
    Relationships and the schema evaluation come from INFORMATION_SCHEMA
    queries against *db_name* over a connection from ``get_db_connection``.

    Args:
        schema: Mapping of table name to its list of column names.
        db_name: Database whose INFORMATION_SCHEMA metadata is inspected.

    Returns:
        A dict with ``description`` (str), ``main_tables`` (table -> up to 3
        key columns), ``relationships`` (list of str, at most 5 entries from
        either source) and ``suggestions`` (``evaluation``, ``note``,
        ``recommendations``). Connection or metadata-query failures do not
        raise; they are reported in ``suggestions['note']``.
    """
    logger.info("Generating schema summary for database: %s", db_name)
    summary = {
        'description': '',
        'main_tables': {},
        'relationships': [],
        'suggestions': {
            'evaluation': 'Good',
            'note': '',
            'recommendations': []
        }
    }

    # Infer description from table names. Matching ignores case and also
    # accepts the plural form ("orders" as well as "order").
    table_names = list(schema.keys())
    normalized_names = {name.lower() for name in table_names}

    def mentions_any(keywords):
        return any(word in normalized_names or f"{word}s" in normalized_names for word in keywords)

    if mentions_any(['patient', 'doctor', 'admission', 'appointment']):
        summary['description'] = f"{db_name} appears to be a Hospital Management Database for tracking entities like patients, staff, and appointments."
    elif mentions_any(['customer', 'order', 'product', 'employee']):
        summary['description'] = f"{db_name} appears to be a Retail or E-commerce Database for managing customers, orders, and products."
    elif mentions_any(['book', 'author', 'loan', 'member']):
        summary['description'] = f"{db_name} appears to be a Library Management Database for tracking books, authors, and loans."
    else:
        summary['description'] = f"{db_name} is a database with {len(table_names)} tables for managing various entities."

    # Main tables: the (up to) 5 tables with the most columns, each listing up
    # to 3 columns that look like identifiers or common descriptive fields.
    descriptive_columns = {'name', 'first_name', 'last_name', 'title', 'amount', 'status', 'price'}
    widest_tables = sorted(schema.items(), key=lambda item: len(item[1]), reverse=True)[:5]
    for table, columns in widest_tables:
        key_columns = [col for col in columns if 'id' in col.lower() or col in descriptive_columns]
        summary['main_tables'][table] = key_columns[:3]

    # Connect to database to detect relationships and suggestions
    conn, error = get_db_connection(db_name)
    if not conn:
        logger.error("Failed to connect for schema summary: %s", error)
        summary['suggestions']['note'] = 'Unable to analyze schema due to connection issues.'
        return summary

    cursor = conn.cursor()
    try:
        # Detect foreign keys using INFORMATION_SCHEMA
        cursor.execute("""
            SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL
        """, (db_name,))
        relationships = cursor.fetchall()
        for rel in relationships[:5]:  # Limit to 5 relationships
            summary['relationships'].append(f"{rel[0]} links to {rel[2]} via {rel[1]}")

        # Fallback: infer links from "<table>_id" columns (singular or plural
        # target table), skipping self-references, with the same cap of 5.
        if not relationships:
            inferred = []
            for table, columns in schema.items():
                for column in columns:
                    if not column.lower().endswith('_id'):
                        continue
                    stem = column[:-3]
                    if stem in schema:
                        target = stem
                    elif f"{stem}s" in schema:
                        target = f"{stem}s"
                    else:
                        continue
                    if target != table:
                        inferred.append(f"{table} likely links to {target} via {column}")
            summary['relationships'].extend(inferred[:5])

        # Check for non-primary indexes
        cursor.execute("""
            SELECT TABLE_NAME, NON_UNIQUE, INDEX_NAME
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = %s AND INDEX_NAME != 'PRIMARY'
        """, (db_name,))
        indexes = cursor.fetchall()

        # Evaluate schema
        has_foreign_keys = bool(relationships)
        has_indexes = bool(indexes)
        if has_foreign_keys and has_indexes:
            summary['suggestions']['evaluation'] = 'Excellent'
            summary['suggestions']['note'] = 'The schema is well-structured with defined foreign key constraints and indexes, supporting efficient queries.'
        elif has_foreign_keys:
            summary['suggestions']['evaluation'] = 'Good'
            summary['suggestions']['note'] = 'The schema has clear foreign key relationships but may lack sufficient indexes.'
        else:
            summary['suggestions']['evaluation'] = 'Needs Improvement'
            summary['suggestions']['note'] = 'The schema lacks explicit foreign key constraints, which may affect query reliability.'

        # Recommendations
        if not has_foreign_keys:
            summary['suggestions']['recommendations'].append('Add explicit foreign key constraints to ensure data integrity.')
        if not has_indexes:
            summary['suggestions']['recommendations'].append('Add indexes on frequently queried columns (e.g., foreign keys, date fields) to improve performance.')
        summary['suggestions']['recommendations'].append('Verify that date and numeric fields use appropriate data types for efficient querying.')
    except Error as e:
        logger.error("Schema summary error: %s", str(e))
        summary['suggestions']['note'] = f'Analysis limited due to: {str(e)}'
        if not summary['relationships']:
            summary['relationships'] = ['Unable to detect relationships due to limited metadata access.']
    finally:
        # Always release the cursor and connection, including on query errors.
        cursor.close()
        conn.close()

    return summary

# MySQL/TiDB system schemas that an upload must never make us DROP and recreate.
_RESERVED_DATABASES = frozenset({
    'mysql', 'information_schema', 'performance_schema', 'metrics_schema', 'sys',
})

# Matches a CREATE DATABASE/SCHEMA statement and captures the database it targets.
_CREATE_DATABASE_STMT = re.compile(
    r"^\s*CREATE\s+(?:DATABASE|SCHEMA)\s+(?:IF\s+NOT\s+EXISTS\s+)?[`'\"]?(\w+)",
    re.IGNORECASE,
)


def _quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier, doubling embedded backticks."""
    return "`" + str(name).replace("`", "``") + "`"


def _recreate_database(db_name):
    """DROP (if present) and CREATE *db_name*; return an error message, or None on success."""
    conn, error = get_db_connection()
    if error:
        logger.error("Connection failed in load_sql_file: %s", error)
        return error
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(f"DROP DATABASE IF EXISTS {_quote_identifier(db_name)}")
            cursor.execute(f"CREATE DATABASE {_quote_identifier(db_name)}")
            conn.commit()
        finally:
            cursor.close()
    except Error as e:
        logger.error("Failed to create database %s: %s", db_name, str(e))
        return f"Failed to create database: {str(e)}"
    finally:
        conn.close()
    logger.info("Created database: %s", db_name)
    return None


def _run_statements_and_read_schema(db_name, statements):
    """Execute *statements* in *db_name*; return ``(schema, None)`` or ``(None, error_message)``."""
    conn, error = get_db_connection(db_name)
    if error:
        logger.error("Connection to %s failed: %s", db_name, error)
        return None, error
    try:
        cursor = conn.cursor()
        try:
            for statement in statements:
                target = _CREATE_DATABASE_STMT.match(statement)
                if target and target.group(1) == db_name:
                    # The database was already created by _recreate_database();
                    # a plain "CREATE DATABASE <name>" would fail with
                    # "database exists" and abort the whole load.
                    logger.info("Skipping statement that re-creates database %s", db_name)
                    continue
                try:
                    cursor.execute(statement)
                    if cursor.with_rows:
                        # Consume any result set; an unread result makes the
                        # next execute() on this connection fail.
                        cursor.fetchall()
                    logger.info("Executed statement: %s", statement[:50])
                except Error as e:
                    logger.error("Failed to execute statement: %s, error: %s", statement[:50], str(e))
                    return None, f"Failed to execute SQL statement: {str(e)}"
            conn.commit()

            # Extract schema
            cursor.execute("SHOW TABLES")
            tables = [row[0] for row in cursor.fetchall()]
            schema = {}
            for table in tables:
                cursor.execute(f"SHOW COLUMNS FROM {_quote_identifier(table)}")
                schema[table] = [row[0] for row in cursor.fetchall()]
                logger.info("Extracted schema for table: %s", table)
            return schema, None
        finally:
            cursor.close()
    finally:
        conn.close()


def load_sql_file(file):
    """Load an uploaded SQL script into a freshly (re)created database.

    Steps: parse the upload, DROP and CREATE the target database, run the
    script's statements inside it, read every table's columns back, and build
    a summary with ``generate_schema_summary``. On success the module-level
    ``current_db_name``, ``current_schema`` and ``current_summary`` are
    replaced; on failure they are left unchanged.

    NOTE(review): the script is executed verbatim with the app's credentials,
    and any existing database with the same name is dropped. Only the system
    schemas in ``_RESERVED_DATABASES`` are protected here.

    Args:
        file: File-like object with ``read()`` returning the script contents.

    Returns:
        ``(True, schema, summary)`` on success, where ``schema`` maps table
        name to its column names; ``(False, error_message, None)`` otherwise.
    """
    global current_db_name, current_schema, current_summary
    logger.info("Attempting to load SQL file")
    try:
        file_content = file.read()
        if not file_content:
            logger.error("Empty SQL file uploaded")
            return False, "Uploaded SQL file is empty.", None

        db_name, statements = parse_sql_file(file_content)
        if db_name.lower() in _RESERVED_DATABASES:
            logger.error("Refusing to recreate reserved database: %s", db_name)
            return False, f"Refusing to load into reserved system database '{db_name}'.", None

        error = _recreate_database(db_name)
        if error is not None:
            return False, error, None

        schema, error = _run_statements_and_read_schema(db_name, statements)
        if error is not None:
            return False, error, None

        # Generate summary
        summary = generate_schema_summary(schema, db_name)

        current_db_name = db_name
        current_schema = schema
        current_summary = summary

        logger.info("SQL file loaded successfully: %s", db_name)
        return True, schema, summary
    except Exception as e:
        logger.error("Unexpected error in load_sql_file: %s", str(e))
        return False, f"Unexpected error while loading SQL file: {str(e)}", None

# Markdown code fences (``` , ```sql, ```SQL, ```mysql) that the model sometimes
# adds despite being told not to.
_SQL_CODE_FENCE = re.compile(r'```(?:mysql|sql)?[ \t]*\r?\n?', re.IGNORECASE)


def generate_sql_query(question, schema, model=None):
    """Ask the Groq chat API to translate *question* into a MySQL query.

    Args:
        question: The user's natural-language question.
        schema: Mapping of table name to column names. It is rendered into
            the prompt so the model only references existing tables/columns.
        model: Groq model id. When omitted, the ``GROQ_MODEL`` environment
            variable is used, falling back to ``llama3-70b-8192``.
            NOTE(review): older Groq model ids may be retired over time;
            overriding via GROQ_MODEL avoids a code change.

    Returns:
        The SQL text with surrounding whitespace and Markdown fences removed,
        or a string starting with ``"ERROR:"`` when the client is unavailable
        or the API call fails. The prompt also asks the model itself to reply
        with ``"ERROR:"`` for unanswerable questions.
    """
    if not groq_client:
        logger.error("Groq client not initialized")
        return "ERROR: Groq client not initialized. Check API key and try again."

    schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
    prompt = f"""
You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.

Schema:
{schema_text}

User Question: {question}
"""
    model = model or os.getenv('GROQ_MODEL') or "llama3-70b-8192"
    try:
        response = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model
        )
        query = response.choices[0].message.content.strip()
        query = _SQL_CODE_FENCE.sub('', query)  # Remove any markdown fences, any case
        query = query.strip()
        logger.info("Generated SQL query: %s", query[:100])
        return query
    except Exception as e:
        logger.error("Failed to generate SQL query: %s", str(e))
        return f"ERROR: Failed to generate SQL query: {str(e)}"

# Leading keywords of statements that only read data.
_READ_ONLY_LEADING_KEYWORDS = frozenset({'SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'DESC', 'EXPLAIN'})
# String literals and quoted identifiers. They are blanked out before keyword
# checks so aliases such as 'Last Update' or `Deleted At` are not flagged.
_QUOTED_SQL_TEXT = re.compile(
    r"'(?:[^'\\]|\\.|'')*'|\"(?:[^\"\\]|\\.|\"\")*\"|`(?:[^`]|``)*`",
    re.DOTALL,
)
# Constructs that still write despite a read-only leading keyword: MySQL
# allows WITH ... UPDATE/DELETE, EXPLAIN ANALYZE executes its statement, and
# SELECT ... INTO OUTFILE/DUMPFILE writes server-side files.
_WRITE_CONSTRUCTS = re.compile(
    r"\b(?:UPDATE|DELETE)\b|\bINTO\s+(?:OUTFILE|DUMPFILE)\b",
    re.IGNORECASE,
)


def _is_read_only_query(query):
    """Best-effort check that *query* is a single statement that only reads data."""
    text = _QUOTED_SQL_TEXT.sub("''", query).strip().rstrip(';').strip()
    if not text or ';' in text:
        return False  # empty, or several statements packed into one string
    leading = re.match(r"[\s(]*([A-Za-z]+)", text)
    if leading is None or leading.group(1).upper() not in _READ_ONLY_LEADING_KEYWORDS:
        return False
    return _WRITE_CONSTRUCTS.search(text) is None


def execute_sql_query(query):
    """Run a generated query against the currently loaded database.

    The query comes from an LLM prompted with user input, so it is treated as
    untrusted. Only a single read-only statement (see ``_is_read_only_query``)
    is executed, and nothing is committed.

    Args:
        query: SQL text to execute.

    Returns:
        ``(True, rows, None)``, where ``rows`` is the list of result rows as
        dicts keyed by column name/alias, or ``(False, error_message, None)``.
    """
    if not current_db_name:
        logger.error("No database loaded for query execution")
        return False, "No database loaded. Please upload an SQL file.", None
    if not _is_read_only_query(query):
        logger.error("Rejected non-read-only query: %s", query[:100])
        return False, "SQL execution failed: only a single read-only statement (SELECT/WITH, SHOW, DESCRIBE, EXPLAIN) can be executed.", None

    conn, error = get_db_connection(current_db_name)
    if error:
        logger.error("Connection failed for query execution: %s", error)
        return False, error, None
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query)
            results = cursor.fetchall()
        finally:
            cursor.close()
    except Error as e:
        logger.error("SQL execution failed: %s", str(e))
        return False, f"SQL execution failed: {str(e)}", None
    finally:
        # Release the connection on both the success and the error path.
        conn.close()

    logger.info("Query executed successfully: %s", query[:50])
    return True, results, None

@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the main page; handle SQL-file uploads and natural-language questions.

    A POST carrying a ``sql_file`` upload loads it via ``load_sql_file``. A
    POST carrying a ``question`` form field turns it into SQL with
    ``generate_sql_query`` and runs it with ``execute_sql_query``. The page is
    rendered with the current schema/summary plus any results, the generated
    query and an error message.
    """
    error = None
    schema = current_schema
    summary = current_summary
    results = None
    generated_query = None

    if not groq_client:
        error = "Groq client not initialized. Please check GROQ_API_KEY and restart the app."
        logger.error(error)

    if request.method == 'POST':
        logger.info("Received POST request")
        if 'sql_file' in request.files:
            file = request.files['sql_file']
            logger.info("SQL file upload detected: %s", file.filename if file else "No file")
            # Clear session config to ensure default TiDB backend is used for uploads
            # NOTE(review): this also discards a connection saved via /configure_db,
            # so uploads and later queries always use the default credentials.
            session.pop('db_config', None)
            # Case-insensitive so "dump.SQL" is accepted as well.
            if file and file.filename.lower().endswith('.sql'):
                success, result, loaded_summary = load_sql_file(file)
                if success:
                    schema = result
                    summary = loaded_summary
                    logger.info("SQL file loaded successfully")
                else:
                    # Keep showing the previously loaded schema together with
                    # its summary instead of replacing the summary with None.
                    error = result
                    logger.error("Failed to load SQL file: %s", error)
            else:
                error = "Please upload a valid .sql file."
                logger.error(error)
        elif 'question' in request.form:
            question = request.form['question']
            logger.info("Received question: %s", question)
            if not current_db_name or not current_schema:
                error = "No database loaded. Please upload an SQL file first."
                logger.error(error)
            else:
                generated_query = generate_sql_query(question, current_schema)
                if not generated_query.startswith('ERROR:'):
                    success, result, _ = execute_sql_query(generated_query)
                    if success:
                        results = result
                        logger.info("Query executed successfully, results: %d rows", len(results))
                    else:
                        error = result
                        logger.error("Query execution failed: %s", error)
                else:
                    error = generated_query
                    logger.error("Query generation failed: %s", error)

    logger.info("Rendering index.html: error=%s, schema=%s, summary=%s, results=%s",
                error, bool(schema), bool(summary), bool(results))
    return render_template('index.html', error=error, schema=schema, summary=summary, results=results, query=generated_query)

@app.route('/configure_db', methods=['POST'])
def configure_db():
    """Validate custom MySQL/TiDB credentials and store them in the session.

    Reads ``host``, ``user``, ``password`` and ``port`` (default ``4000``) from
    the form and opens a test connection with exactly those values. Only if
    that succeeds are they saved to ``session['db_config']``.
    """
    logger.info("Received configure_db request")
    host = request.form.get('host', '').strip()
    user = request.form.get('user', '').strip()
    password = request.form.get('password', '')
    port = request.form.get('port', '4000').strip()

    if not host or not user:
        logger.error("Missing host or user in configure_db")
        return render_template('index.html', error="Host and user are required for custom MySQL configuration.",
                              schema=current_schema, summary=current_summary)

    try:
        port = int(port)
    except ValueError:
        logger.error("Invalid port number: %s", port)
        return render_template('index.html', error="Port must be a valid number.",
                              schema=current_schema, summary=current_summary)

    # Test the submitted credentials themselves. get_db_connection() uses
    # whichever config is already active (session or default), so it cannot
    # validate new input.
    test_config = {'host': host, 'user': user, 'password': password, 'port': port}
    try:
        conn = mysql.connector.connect(**test_config)
    except Error as e:
        error = (f"Database connection failed: {str(e)}. Verify credentials, ensure the MySQL/TiDB "
                 f"server is running, and check network settings (e.g., IP whitelist, TLS).")
        logger.error("Test connection failed in configure_db: %s", error)
        return render_template('index.html', error=error, schema=current_schema, summary=current_summary)
    conn.close()

    # Store in session.
    # NOTE(review): Flask's default session is a signed but unencrypted cookie,
    # so the password below is readable client-side; consider server-side
    # session storage. Also note get_db_connection() ignores a session config
    # whose password is empty and silently falls back to the defaults.
    session['db_config'] = test_config
    logger.info("Custom MySQL connection configured: host=%s, user=%s, port=%s",
                host, user, port)
    return render_template('index.html', error=None, schema=current_schema, summary=current_summary,
                          success="Custom MySQL connection configured successfully. You can now upload .sql files and query your database.")

if __name__ == '__main__':
    # Serve on all interfaces; the port comes from $PORT and defaults to 7860.
    listen_port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port, debug=False)