Al1Abdullah commited on
Commit
c384400
·
verified ·
1 Parent(s): e3bbcc4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -0
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, render_template, jsonify
2
+ import mysql.connector
3
+ from mysql.connector import Error
4
+ import os
5
+ from groq import Groq
6
+ from dotenv import load_dotenv
7
+ import re
8
+ import uuid
9
+
10
+ app = Flask(__name__)
11
+ load_dotenv()
12
+
13
+ # Database configuration from .env
14
+ db_config = {
15
+ 'host': os.getenv('DB_HOST', 'localhost'),
16
+ 'user': os.getenv('DB_USER', 'root'),
17
+ 'password': os.getenv('DB_PASSWORD', ''),
18
+ }
19
+
20
+ # Groq API configuration
21
+ groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
22
+
23
+ # Temporary storage for current database name and schema
24
+ current_db_name = None
25
+ current_schema = {}
26
+ current_summary = {}
27
+
28
+ def get_db_connection(db_name=None):
29
+ """Establish a database connection."""
30
+ config = db_config.copy()
31
+ if db_name:
32
+ config['database'] = db_name
33
+ try:
34
+ conn = mysql.connector.connect(**config)
35
+ return conn, None
36
+ except Error as e:
37
+ return None, f"Database connection failed: {str(e)}"
38
+
39
+ def parse_sql_file(file_content):
40
+ """Parse SQL file to extract database name and clean statements."""
41
+ file_content = file_content.decode('utf-8') if isinstance(file_content, bytes) else file_content
42
+ statements = []
43
+ current_statement = ""
44
+ in_comment = False
45
+
46
+ # Extract database name
47
+ db_name_match = re.search(r"CREATE\s+DATABASE\s+[`']?(\w+)[`']?", file_content, re.IGNORECASE)
48
+ db_name = db_name_match.group(1) if db_name_match else f"temp_db_{uuid.uuid4().hex[:8]}"
49
+
50
+ # Split SQL into statements
51
+ for line in file_content.splitlines():
52
+ line = line.strip()
53
+ if not line or line.startswith('--'):
54
+ continue
55
+ if line.startswith('/*'):
56
+ in_comment = True
57
+ continue
58
+ if line.endswith('*/'):
59
+ in_comment = False
60
+ continue
61
+ if not in_comment:
62
+ current_statement += line + ' '
63
+ if line.endswith(';'):
64
+ statements.append(current_statement.strip())
65
+ current_statement = ""
66
+
67
+ return db_name, statements
68
+
69
+ def generate_schema_summary(schema, db_name):
70
+ """Generate a concise summary of the database schema."""
71
+ main_tables = ['patient', 'admission', 'appointment', 'bill', 'doctor', 'nurse', 'department']
72
+ summary = {
73
+ 'description': f"{db_name} is a Hospital Management Database for tracking patients, admissions, appointments, billing, and staff assignments.",
74
+ 'main_tables': {},
75
+ 'relationships': [],
76
+ 'suggestions': {
77
+ 'evaluation': 'Excellent',
78
+ 'note': 'The schema is well-structured with clear primary and foreign key relationships, supporting efficient queries. All tables have appropriate data types, and indexes are implied on primary keys.',
79
+ 'recommendations': [
80
+ 'Consider adding indexes on frequently queried foreign keys (e.g., admission.patient_id, appointment.doc_id) to improve join performance.',
81
+ 'Ensure date fields (e.g., adm_date, bill_date) are consistently used for range queries to leverage indexes.'
82
+ ]
83
+ }
84
+ }
85
+
86
+ # Filter main tables and their key columns
87
+ for table in main_tables:
88
+ if table in schema:
89
+ key_columns = [col for col in schema[table] if 'id' in col or col in ['first_name', 'last_name', 'name', 'room_no', 'amount', 'status']]
90
+ summary['main_tables'][table] = key_columns[:3] # Limit to 3 key columns for brevity
91
+
92
+ # Define key relationships
93
+ relationships = [
94
+ 'admission links to patient via patient_id',
95
+ 'admission links to room via room_id',
96
+ 'room links to department via dept_id',
97
+ 'appointment links to patient via patient_id',
98
+ 'appointment links to doctor via doc_id',
99
+ 'bill links to appointment via appt_id',
100
+ 'nurse_assignment links to admission via adm_id',
101
+ 'nurse_assignment links to nurse via nurse_id'
102
+ ]
103
+ summary['relationships'] = relationships[:5] # Limit to 5 for brevity
104
+
105
+ return summary
106
+
107
+ def load_sql_file(file):
108
+ """Load SQL file into MySQL database and generate schema summary."""
109
+ global current_db_name, current_schema, current_summary
110
+ try:
111
+ file_content = file.read()
112
+ db_name, statements = parse_sql_file(file_content)
113
+
114
+ # Connect without specifying a database
115
+ conn, error = get_db_connection()
116
+ if error:
117
+ return False, error, None
118
+ cursor = conn.cursor()
119
+
120
+ # Drop existing database if it exists
121
+ cursor.execute(f"DROP DATABASE IF EXISTS `{db_name}`")
122
+ cursor.execute(f"CREATE DATABASE `{db_name}`")
123
+ conn.commit()
124
+ cursor.close()
125
+ conn.close()
126
+
127
+ # Connect to the new database
128
+ conn, error = get_db_connection(db_name)
129
+ if error:
130
+ return False, error, None
131
+ cursor = conn.cursor()
132
+
133
+ # Execute SQL statements
134
+ for statement in statements:
135
+ cursor.execute(statement)
136
+ conn.commit()
137
+
138
+ # Extract schema
139
+ cursor.execute("SHOW TABLES")
140
+ tables = [row[0] for row in cursor.fetchall()]
141
+ schema = {}
142
+ for table in tables:
143
+ cursor.execute(f"SHOW COLUMNS FROM `{table}`")
144
+ columns = [row[0] for row in cursor.fetchall()]
145
+ schema[table] = columns
146
+
147
+ # Generate summary
148
+ summary = generate_schema_summary(schema, db_name)
149
+
150
+ current_db_name = db_name
151
+ current_schema = schema
152
+ current_summary = summary
153
+
154
+ cursor.close()
155
+ conn.close()
156
+ return True, schema, summary
157
+ except Error as e:
158
+ return False, f"Failed to load SQL file: {str(e)}", None
159
+
160
+ def generate_sql_query(question, schema):
161
+ """Generate SQL query using Groq API with user-friendly aliases."""
162
+ schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
163
+ prompt = f"""
164
+ You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships (e.g., link 'nurse_assignment' to 'admission' via 'adm_id', 'appointment' to 'patient' via 'patient_id', 'appointment' to 'doctor' via 'doc_id', 'appointment' to 'bill' via 'appt_id'). Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
165
+
166
+ Schema:
167
+ {schema_text}
168
+
169
+ User Question: {question}
170
+ """
171
+ try:
172
+ response = groq_client.chat.completions.create(
173
+ messages=[{"role": "user", "content": prompt}],
174
+ model="llama3-70b-8192"
175
+ )
176
+ query = response.choices[0].message.content.strip()
177
+ query = re.sub(r'```(?:sql)?\n?', '', query) # Remove any markdown
178
+ query = query.strip()
179
+ return query
180
+ except Exception as e:
181
+ return f"ERROR: Failed to generate SQL query: {str(e)}"
182
+
183
+ def execute_sql_query(query):
184
+ """Execute SQL query on the current database."""
185
+ if not current_db_name:
186
+ return False, "No database loaded. Please upload an SQL file.", None
187
+ conn, error = get_db_connection(current_db_name)
188
+ if error:
189
+ return False, error, None
190
+ try:
191
+ cursor = conn.cursor(dictionary=True)
192
+ cursor.execute(query)
193
+ results = cursor.fetchall()
194
+ conn.commit()
195
+ cursor.close()
196
+ conn.close()
197
+ return True, results, None
198
+ except Error as e:
199
+ return False, f"SQL execution failed: {str(e)}", None
200
+
201
+ @app.route('/', methods=['GET', 'POST'])
202
+ def index():
203
+ error = None
204
+ schema = current_schema
205
+ summary = current_summary
206
+ results = None
207
+ generated_query = None
208
+
209
+ if request.method == 'POST':
210
+ if 'sql_file' in request.files:
211
+ file = request.files['sql_file']
212
+ if file and file.filename.endswith('.sql'):
213
+ success, result, summary = load_sql_file(file)
214
+ if success:
215
+ schema = result
216
+ else:
217
+ error = result
218
+ else:
219
+ error = "Please upload a valid .sql file."
220
+ elif 'question' in request.form:
221
+ question = request.form['question']
222
+ if not current_db_name or not current_schema:
223
+ error = "No database loaded. Please upload an SQL file first."
224
+ else:
225
+ generated_query = generate_sql_query(question, current_schema)
226
+ if not generated_query.startswith('ERROR:'):
227
+ success, result, _ = execute_sql_query(generated_query)
228
+ if success:
229
+ results = result
230
+ else:
231
+ error = result
232
+ else:
233
+ error = generated_query
234
+
235
+ return render_template('index.html', error=error, schema=schema, summary=summary, results=results, query=generated_query)
236
+
237
+ if __name__ == '__main__':
238
+ app.run(debug=True)