SVashishta1
committed on
Commit
·
77df513
1
Parent(s):
e770679
Fix: Improve SQL query generation with better column checking and error handling
Browse files
app.py
CHANGED
|
@@ -69,14 +69,16 @@ query_prompt = ChatPromptTemplate.from_template("""
|
|
| 69 |
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
|
| 70 |
|
| 71 |
Important guidelines:
|
| 72 |
-
1.
|
| 73 |
-
2.
|
|
|
|
| 74 |
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
Question: {question}
|
| 82 |
""")
|
|
@@ -210,11 +212,29 @@ def process_text_query(query, history):
|
|
| 210 |
# Get column information for context
|
| 211 |
cursor = conn.cursor()
|
| 212 |
cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
|
| 213 |
-
|
| 214 |
-
|
|
|
|
| 215 |
|
| 216 |
-
# Create
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# Special handling for visualization types that need raw data
|
| 220 |
if is_visualization and viz_type in ['box', 'heatmap']:
|
|
@@ -270,8 +290,100 @@ def process_text_query(query, history):
|
|
| 270 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|
| 271 |
sql_query = clean_sql_query(sql_query)
|
| 272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
# Execute the query
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
# Close the connection
|
| 277 |
conn.close()
|
|
|
|
| 69 |
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
|
| 70 |
|
| 71 |
Important guidelines:
|
| 72 |
+
1. MOST IMPORTANT: Only use columns that are explicitly provided in the context. Do not assume or invent columns.
|
| 73 |
+
2. Use SQLite syntax (not PostgreSQL or MySQL)
|
| 74 |
+
3. For date functions, use strftime() instead of EXTRACT
|
| 75 |
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
|
| 76 |
+
4. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
|
| 77 |
+
5. For percentiles, use window functions or approximate methods
|
| 78 |
+
6. Keep queries efficient and focused on answering the specific question
|
| 79 |
+
7. Always use 'data_tab' as the table name
|
| 80 |
+
8. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 81 |
+
9. If the question seems to require a column that isn't provided, use the most relevant existing column instead
|
| 82 |
|
| 83 |
Question: {question}
|
| 84 |
""")
|
|
|
|
| 212 |
# Get column information for context
|
| 213 |
cursor = conn.cursor()
|
| 214 |
cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
|
| 215 |
+
columns_info = cursor.fetchall()
|
| 216 |
+
columns = [info[1] for info in columns_info]
|
| 217 |
+
column_types = [info[2] for info in columns_info]
|
| 218 |
|
| 219 |
+
# Create rich context with column types
|
| 220 |
+
columns_with_types = [f"{col} ({typ})" for col, typ in zip(columns, column_types)]
|
| 221 |
+
columns_str = ", ".join(columns_with_types)
|
| 222 |
+
|
| 223 |
+
# Create sample data context
|
| 224 |
+
sample_query = "SELECT * FROM data_tab LIMIT 3;"
|
| 225 |
+
sample_df = pd.read_sql_query(sample_query, conn)
|
| 226 |
+
sample_data = sample_df.to_string(index=False, max_rows=3)
|
| 227 |
+
|
| 228 |
+
# Create question with detailed context
|
| 229 |
+
question_with_context = f"""
|
| 230 |
+
The table 'data_tab' has the following columns with their types:
|
| 231 |
+
{columns_str}
|
| 232 |
+
|
| 233 |
+
Here's a sample of the data:
|
| 234 |
+
{sample_data}
|
| 235 |
+
|
| 236 |
+
User question: {query}
|
| 237 |
+
"""
|
| 238 |
|
| 239 |
# Special handling for visualization types that need raw data
|
| 240 |
if is_visualization and viz_type in ['box', 'heatmap']:
|
|
|
|
| 290 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|
| 291 |
sql_query = clean_sql_query(sql_query)
|
| 292 |
|
| 293 |
+
# Check if all columns in the query exist before executing
|
| 294 |
+
try:
|
| 295 |
+
# Get all column names
|
| 296 |
+
cursor.execute("PRAGMA table_info(data_tab);")
|
| 297 |
+
available_columns = [info[1] for info in cursor.fetchall()]
|
| 298 |
+
|
| 299 |
+
# Extract column names from the SQL query (simple approach)
|
| 300 |
+
query_columns = []
|
| 301 |
+
from_pos = sql_query.lower().find("from")
|
| 302 |
+
if from_pos > 0:
|
| 303 |
+
select_part = sql_query[:from_pos].lower()
|
| 304 |
+
# Remove SELECT keyword
|
| 305 |
+
if select_part.startswith("select "):
|
| 306 |
+
select_part = select_part[7:]
|
| 307 |
+
|
| 308 |
+
# Split by commas and extract column names
|
| 309 |
+
for col_expr in select_part.split(","):
|
| 310 |
+
col_expr = col_expr.strip()
|
| 311 |
+
# Handle AS aliases and functions
|
| 312 |
+
if " as " in col_expr:
|
| 313 |
+
col_expr = col_expr.split(" as ")[0].strip()
|
| 314 |
+
|
| 315 |
+
# Extract column name from functions
|
| 316 |
+
for func in ["max(", "min(", "avg(", "sum(", "count("]:
|
| 317 |
+
if func in col_expr:
|
| 318 |
+
# Extract column inside function
|
| 319 |
+
start_idx = col_expr.find(func) + len(func)
|
| 320 |
+
end_idx = col_expr.find(")", start_idx)
|
| 321 |
+
if end_idx > start_idx:
|
| 322 |
+
col_name = col_expr[start_idx:end_idx].strip()
|
| 323 |
+
if col_name != "*" and "(" not in col_name: # Skip nested functions and *
|
| 324 |
+
query_columns.append(col_name)
|
| 325 |
+
|
| 326 |
+
# Handle direct column references
|
| 327 |
+
if "(" not in col_expr and col_expr != "*":
|
| 328 |
+
query_columns.append(col_expr)
|
| 329 |
+
|
| 330 |
+
# Check for missing columns
|
| 331 |
+
missing_columns = []
|
| 332 |
+
for col in query_columns:
|
| 333 |
+
if col not in available_columns and col.strip() != "*":
|
| 334 |
+
missing_columns.append(col)
|
| 335 |
+
|
| 336 |
+
if missing_columns:
|
| 337 |
+
# Generate a simpler query with available columns
|
| 338 |
+
if "tip" in query.lower() or "gratuity" in query.lower():
|
| 339 |
+
# Look for a tip column
|
| 340 |
+
tip_columns = [col for col in available_columns if "tip" in col.lower() or "gratuity" in col.lower()]
|
| 341 |
+
if tip_columns:
|
| 342 |
+
sql_query = f"SELECT MAX({tip_columns[0]}) AS highest_tip FROM data_tab"
|
| 343 |
+
else:
|
| 344 |
+
# No tip column, return info about available columns
|
| 345 |
+
return f"I couldn't find a column related to tips or gratuity. Available columns are: {', '.join(available_columns)}", history
|
| 346 |
+
else:
|
| 347 |
+
# For other queries, suggest a generic query
|
| 348 |
+
return f"Some columns in the query don't exist in the current dataset: {', '.join(missing_columns)}. Available columns are: {', '.join(available_columns)}", history
|
| 349 |
+
except Exception as e:
|
| 350 |
+
print(f"Error checking columns: {str(e)}")
|
| 351 |
+
# Continue with the original query
|
| 352 |
+
|
| 353 |
# Execute the query
|
| 354 |
+
try:
|
| 355 |
+
result_df = pd.read_sql_query(sql_query, conn)
|
| 356 |
+
except Exception as e:
|
| 357 |
+
error_message = str(e)
|
| 358 |
+
|
| 359 |
+
# Try to provide a more helpful error message
|
| 360 |
+
if "no such column" in error_message.lower():
|
| 361 |
+
# Extract column name from error
|
| 362 |
+
column_name = error_message.split("no such column: ")[-1].strip("'").strip('"')
|
| 363 |
+
|
| 364 |
+
# Look for similar columns
|
| 365 |
+
cursor.execute("PRAGMA table_info(data_tab);")
|
| 366 |
+
available_columns = [info[1] for info in cursor.fetchall()]
|
| 367 |
+
|
| 368 |
+
# Simple fuzzy matching
|
| 369 |
+
similar_columns = []
|
| 370 |
+
for col in available_columns:
|
| 371 |
+
# Check if column name contains parts of the error column
|
| 372 |
+
if column_name.lower() in col.lower() or any(part.lower() in col.lower() for part in column_name.split('_') if len(part) > 2):
|
| 373 |
+
similar_columns.append(col)
|
| 374 |
+
|
| 375 |
+
if similar_columns:
|
| 376 |
+
message = f"Column '{column_name}' doesn't exist in the current dataset. Did you mean one of these? {', '.join(similar_columns)}\n\nAvailable columns are: {', '.join(available_columns)}"
|
| 377 |
+
else:
|
| 378 |
+
message = f"Column '{column_name}' doesn't exist in the current dataset. Available columns are: {', '.join(available_columns)}"
|
| 379 |
+
|
| 380 |
+
history[-1][1] = message
|
| 381 |
+
return message, history
|
| 382 |
+
else:
|
| 383 |
+
# Generic error message
|
| 384 |
+
error_msg = f"Error executing query: {error_message}"
|
| 385 |
+
history[-1][1] = error_msg
|
| 386 |
+
return error_msg, history
|
| 387 |
|
| 388 |
# Close the connection
|
| 389 |
conn.close()
|