Al1Abdullah commited on
Commit
3874558
·
verified ·
1 Parent(s): 23ad945

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -19,8 +19,12 @@ default_db_config = {
19
  'port': int(os.getenv('DB_PORT', 3306))
20
  }
21
 
22
- # Groq API configuration
23
- groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
 
 
 
 
24
 
25
  # Temporary storage for current database name and schema
26
  current_db_name = None
@@ -215,6 +219,9 @@ def load_sql_file(file):
215
 
216
  def generate_sql_query(question, schema):
217
  """Generate SQL query using Groq API with user-friendly aliases."""
 
 
 
218
  schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
219
  prompt = f"""
220
  You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
@@ -262,6 +269,9 @@ def index():
262
  results = None
263
  generated_query = None
264
 
 
 
 
265
  if request.method == 'POST':
266
  if 'sql_file' in request.files:
267
  file = request.files['sql_file']
 
19
  'port': int(os.getenv('DB_PORT', 3306))
20
  }
21
 
22
+ # Groq API configuration with error handling
23
+ try:
24
+ groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
25
+ except Exception as e:
26
+ groq_client = None
27
+ print(f"Failed to initialize Groq client: {str(e)}")
28
 
29
  # Temporary storage for current database name and schema
30
  current_db_name = None
 
219
 
220
  def generate_sql_query(question, schema):
221
  """Generate SQL query using Groq API with user-friendly aliases."""
222
+ if not groq_client:
223
+ return "ERROR: Groq client not initialized. Check API key and try again."
224
+
225
  schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
226
  prompt = f"""
227
  You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
 
269
  results = None
270
  generated_query = None
271
 
272
+ if not groq_client:
273
+ error = "Groq client not initialized. Please check GROQ_API_KEY and restart the app."
274
+
275
  if request.method == 'POST':
276
  if 'sql_file' in request.files:
277
  file = request.files['sql_file']