SVashishta1 committed on
Commit
77df513
·
1 Parent(s): e770679

Fix: Improve SQL query generation with better column checking and error handling

Browse files
Files changed (1) hide show
  1. app.py +124 -12
app.py CHANGED
@@ -69,14 +69,16 @@ query_prompt = ChatPromptTemplate.from_template("""
69
  You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
70
 
71
  Important guidelines:
72
- 1. Use SQLite syntax (not PostgreSQL or MySQL)
73
- 2. For date functions, use strftime() instead of EXTRACT
 
74
  - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
75
- 3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
76
- 4. For percentiles, use window functions or approximate methods
77
- 5. Keep queries efficient and focused on answering the specific question
78
- 6. Always use 'data_tab' as the table name
79
- 7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
 
80
 
81
  Question: {question}
82
  """)
@@ -210,11 +212,29 @@ def process_text_query(query, history):
210
  # Get column information for context
211
  cursor = conn.cursor()
212
  cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
213
- columns = [info[1] for info in cursor.fetchall()]
214
- columns_str = ", ".join(columns)
 
215
 
216
- # Create question with context
217
- question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  # Special handling for visualization types that need raw data
220
  if is_visualization and viz_type in ['box', 'heatmap']:
@@ -270,8 +290,100 @@ def process_text_query(query, history):
270
  sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
271
  sql_query = clean_sql_query(sql_query)
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # Execute the query
274
- result_df = pd.read_sql_query(sql_query, conn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  # Close the connection
277
  conn.close()
 
69
  You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
70
 
71
  Important guidelines:
72
+ 1. MOST IMPORTANT: Only use columns that are explicitly provided in the context. Do not assume or invent columns.
73
+ 2. Use SQLite syntax (not PostgreSQL or MySQL)
74
+ 3. For date functions, use strftime() instead of EXTRACT
75
  - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
76
+ 4. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
77
+ 5. For percentiles, use window functions or approximate methods
78
+ 6. Keep queries efficient and focused on answering the specific question
79
+ 7. Always use 'data_tab' as the table name
80
+ 8. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
81
+ 9. If the question seems to require a column that isn't provided, use the most relevant existing column instead
82
 
83
  Question: {question}
84
  """)
 
212
  # Get column information for context
213
  cursor = conn.cursor()
214
  cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
215
+ columns_info = cursor.fetchall()
216
+ columns = [info[1] for info in columns_info]
217
+ column_types = [info[2] for info in columns_info]
218
 
219
+ # Create rich context with column types
220
+ columns_with_types = [f"{col} ({typ})" for col, typ in zip(columns, column_types)]
221
+ columns_str = ", ".join(columns_with_types)
222
+
223
+ # Create sample data context
224
+ sample_query = "SELECT * FROM data_tab LIMIT 3;"
225
+ sample_df = pd.read_sql_query(sample_query, conn)
226
+ sample_data = sample_df.to_string(index=False, max_rows=3)
227
+
228
+ # Create question with detailed context
229
+ question_with_context = f"""
230
+ The table 'data_tab' has the following columns with their types:
231
+ {columns_str}
232
+
233
+ Here's a sample of the data:
234
+ {sample_data}
235
+
236
+ User question: {query}
237
+ """
238
 
239
  # Special handling for visualization types that need raw data
240
  if is_visualization and viz_type in ['box', 'heatmap']:
 
290
  sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
291
  sql_query = clean_sql_query(sql_query)
292
 
293
+ # Check if all columns in the query exist before executing
294
+ try:
295
+ # Get all column names
296
+ cursor.execute("PRAGMA table_info(data_tab);")
297
+ available_columns = [info[1] for info in cursor.fetchall()]
298
+
299
+ # Extract column names from the SQL query (simple approach)
300
+ query_columns = []
301
+ from_pos = sql_query.lower().find("from")
302
+ if from_pos > 0:
303
+ select_part = sql_query[:from_pos].lower()
304
+ # Remove SELECT keyword
305
+ if select_part.startswith("select "):
306
+ select_part = select_part[7:]
307
+
308
+ # Split by commas and extract column names
309
+ for col_expr in select_part.split(","):
310
+ col_expr = col_expr.strip()
311
+ # Handle AS aliases and functions
312
+ if " as " in col_expr:
313
+ col_expr = col_expr.split(" as ")[0].strip()
314
+
315
+ # Extract column name from functions
316
+ for func in ["max(", "min(", "avg(", "sum(", "count("]:
317
+ if func in col_expr:
318
+ # Extract column inside function
319
+ start_idx = col_expr.find(func) + len(func)
320
+ end_idx = col_expr.find(")", start_idx)
321
+ if end_idx > start_idx:
322
+ col_name = col_expr[start_idx:end_idx].strip()
323
+ if col_name != "*" and "(" not in col_name: # Skip nested functions and *
324
+ query_columns.append(col_name)
325
+
326
+ # Handle direct column references
327
+ if "(" not in col_expr and col_expr != "*":
328
+ query_columns.append(col_expr)
329
+
330
+ # Check for missing columns
331
+ missing_columns = []
332
+ for col in query_columns:
333
+ if col not in available_columns and col.strip() != "*":
334
+ missing_columns.append(col)
335
+
336
+ if missing_columns:
337
+ # Generate a simpler query with available columns
338
+ if "tip" in query.lower() or "gratuity" in query.lower():
339
+ # Look for a tip column
340
+ tip_columns = [col for col in available_columns if "tip" in col.lower() or "gratuity" in col.lower()]
341
+ if tip_columns:
342
+ sql_query = f"SELECT MAX({tip_columns[0]}) AS highest_tip FROM data_tab"
343
+ else:
344
+ # No tip column, return info about available columns
345
+ return f"I couldn't find a column related to tips or gratuity. Available columns are: {', '.join(available_columns)}", history
346
+ else:
347
+ # For other queries, suggest a generic query
348
+ return f"Some columns in the query don't exist in the current dataset: {', '.join(missing_columns)}. Available columns are: {', '.join(available_columns)}", history
349
+ except Exception as e:
350
+ print(f"Error checking columns: {str(e)}")
351
+ # Continue with the original query
352
+
353
  # Execute the query
354
+ try:
355
+ result_df = pd.read_sql_query(sql_query, conn)
356
+ except Exception as e:
357
+ error_message = str(e)
358
+
359
+ # Try to provide a more helpful error message
360
+ if "no such column" in error_message.lower():
361
+ # Extract column name from error
362
+ column_name = error_message.split("no such column: ")[-1].strip("'").strip('"')
363
+
364
+ # Look for similar columns
365
+ cursor.execute("PRAGMA table_info(data_tab);")
366
+ available_columns = [info[1] for info in cursor.fetchall()]
367
+
368
+ # Simple fuzzy matching
369
+ similar_columns = []
370
+ for col in available_columns:
371
+ # Check if column name contains parts of the error column
372
+ if column_name.lower() in col.lower() or any(part.lower() in col.lower() for part in column_name.split('_') if len(part) > 2):
373
+ similar_columns.append(col)
374
+
375
+ if similar_columns:
376
+ message = f"Column '{column_name}' doesn't exist in the current dataset. Did you mean one of these? {', '.join(similar_columns)}\n\nAvailable columns are: {', '.join(available_columns)}"
377
+ else:
378
+ message = f"Column '{column_name}' doesn't exist in the current dataset. Available columns are: {', '.join(available_columns)}"
379
+
380
+ history[-1][1] = message
381
+ return message, history
382
+ else:
383
+ # Generic error message
384
+ error_msg = f"Error executing query: {error_message}"
385
+ history[-1][1] = error_msg
386
+ return error_msg, history
387
 
388
  # Close the connection
389
  conn.close()