Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,8 +19,12 @@ default_db_config = {
|
|
| 19 |
'port': int(os.getenv('DB_PORT', 3306))
|
| 20 |
}
|
| 21 |
|
| 22 |
-
# Groq API configuration
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Temporary storage for current database name and schema
|
| 26 |
current_db_name = None
|
|
@@ -215,6 +219,9 @@ def load_sql_file(file):
|
|
| 215 |
|
| 216 |
def generate_sql_query(question, schema):
|
| 217 |
"""Generate SQL query using Groq API with user-friendly aliases."""
|
|
|
|
|
|
|
|
|
|
| 218 |
schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
|
| 219 |
prompt = f"""
|
| 220 |
You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
|
|
@@ -262,6 +269,9 @@ def index():
|
|
| 262 |
results = None
|
| 263 |
generated_query = None
|
| 264 |
|
|
|
|
|
|
|
|
|
|
| 265 |
if request.method == 'POST':
|
| 266 |
if 'sql_file' in request.files:
|
| 267 |
file = request.files['sql_file']
|
|
|
|
| 19 |
'port': int(os.getenv('DB_PORT', 3306))
|
| 20 |
}
|
| 21 |
|
| 22 |
+
# Groq API configuration with error handling
|
| 23 |
+
try:
|
| 24 |
+
groq_client = Groq(api_key=os.getenv('GROQ_API_KEY'))
|
| 25 |
+
except Exception as e:
|
| 26 |
+
groq_client = None
|
| 27 |
+
print(f"Failed to initialize Groq client: {str(e)}")
|
| 28 |
|
| 29 |
# Temporary storage for current database name and schema
|
| 30 |
current_db_name = None
|
|
|
|
| 219 |
|
| 220 |
def generate_sql_query(question, schema):
|
| 221 |
"""Generate SQL query using Groq API with user-friendly aliases."""
|
| 222 |
+
if not groq_client:
|
| 223 |
+
return "ERROR: Groq client not initialized. Check API key and try again."
|
| 224 |
+
|
| 225 |
schema_text = "\n".join([f"Table: {table}\nColumns: {', '.join(columns)}" for table, columns in schema.items()])
|
| 226 |
prompt = f"""
|
| 227 |
You are a SQL expert. Based on the following database schema, generate a valid MySQL query for the user's question. Only use tables and columns that exist in the schema. Use user-friendly aliases for column names (e.g., 'cust_id' becomes 'Customer ID', 'admission_date' becomes 'Admission Date'). Return ONLY the SQL query, without explanations, markdown, or code block formatting (e.g., no ```). If the question references non-existent tables or columns, return an error message starting with 'ERROR:'. Do not use GROUP BY or aggregation functions (e.g., SUM, COUNT, AVG) unless the question explicitly requests aggregation (e.g., 'sum of all bills', 'average cost', 'count of patients'). Treat 'total bill amount' as the individual bill amount (e.g., bill.amount) unless aggregation is clearly specified. For names, concatenate first_name and last_name if applicable (e.g., CONCAT(first_name, ' ', last_name) AS 'Full Name'). Use direct JOINs with correct foreign key relationships. Avoid subqueries unless absolutely necessary. Place filtering conditions (e.g., department name, status) in the WHERE clause, not JOIN clauses. Handle case sensitivity in string comparisons by using LOWER() for status fields (e.g., LOWER(status) = 'unpaid'). Verify table relationships before joining.
|
|
|
|
| 269 |
results = None
|
| 270 |
generated_query = None
|
| 271 |
|
| 272 |
+
if not groq_client:
|
| 273 |
+
error = "Groq client not initialized. Please check GROQ_API_KEY and restart the app."
|
| 274 |
+
|
| 275 |
if request.method == 'POST':
|
| 276 |
if 'sql_file' in request.files:
|
| 277 |
file = request.files['sql_file']
|