SVashishta1
commited on
Commit
·
61a8a66
1
Parent(s):
2f13356
Cleanup: Remove unused voice library comments
Browse files
app.py
CHANGED
|
@@ -62,76 +62,65 @@ current_context = {
|
|
| 62 |
}
|
| 63 |
|
| 64 |
# Add a global variable to store the current plot
|
| 65 |
-
current_plot = None
|
| 66 |
|
| 67 |
# Define the prompt with examples for SQL query generation
|
| 68 |
query_prompt = ChatPromptTemplate.from_template("""
|
| 69 |
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
|
| 70 |
|
| 71 |
-
|
| 72 |
-
1.
|
| 73 |
-
2.
|
| 74 |
-
3. ALWAYS double-check that every column in your query is in the list of available columns.
|
| 75 |
-
|
| 76 |
-
Technical guidelines:
|
| 77 |
-
4. Use SQLite syntax (not PostgreSQL or MySQL)
|
| 78 |
-
5. For date functions, use strftime() instead of EXTRACT
|
| 79 |
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
|
| 86 |
Question: {question}
|
| 87 |
""")
|
| 88 |
|
| 89 |
-
# Add this after the query_prompt definition
|
| 90 |
-
visualization_prompt = ChatPromptTemplate.from_template("""
|
| 91 |
-
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
|
| 92 |
-
|
| 93 |
-
Important guidelines for SQLite syntax:
|
| 94 |
-
1. Use strftime() for date functions:
|
| 95 |
-
- Year: strftime('%Y', date_column)
|
| 96 |
-
- Month: strftime('%m', date_column)
|
| 97 |
-
- Day: strftime('%d', date_column)
|
| 98 |
-
- Hour: strftime('%H', date_column)
|
| 99 |
-
|
| 100 |
-
2. For histograms and binning:
|
| 101 |
-
- Use: CAST((column / bin_size) AS INT) * bin_size
|
| 102 |
-
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
|
| 103 |
-
|
| 104 |
-
3. For box plots:
|
| 105 |
-
- SQLite doesn't support PERCENTILE_CONT or window functions
|
| 106 |
-
- Simply return the raw data column: SELECT column_name FROM data_tab
|
| 107 |
-
- The application will calculate quartiles and outliers
|
| 108 |
-
|
| 109 |
-
4. For heatmaps:
|
| 110 |
-
- Return raw data for correlation analysis
|
| 111 |
-
- Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
|
| 112 |
-
|
| 113 |
-
5. Always use 'data_tab' as the table name
|
| 114 |
-
|
| 115 |
-
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 116 |
-
|
| 117 |
-
Question: {question}
|
| 118 |
-
Visualization type: {viz_type}
|
| 119 |
-
""")
|
| 120 |
-
|
| 121 |
# Define the prompt for interpreting the SQL query result
|
| 122 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
| 123 |
[
|
| 124 |
-
("system", "
|
| 125 |
-
|
| 126 |
-
If relevant, give key statistics, trends, or patterns. Be clear about what the data shows and doesn't show.
|
| 127 |
-
|
| 128 |
-
If the SQL query had to use alternative columns because the exact ones requested weren't available, explain this clearly to the user.
|
| 129 |
-
|
| 130 |
-
For example, if they asked about 'fare_amount' but the dataset has 'fare' or 'total_fare' instead, mention this substitution."""),
|
| 131 |
("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
|
| 132 |
]
|
| 133 |
)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# Add this helper function to clean SQL queries
|
| 136 |
def clean_sql_query(query_text):
|
| 137 |
"""Clean SQL query text by removing markdown formatting and comments"""
|
|
@@ -218,59 +207,14 @@ def process_text_query(query, history):
|
|
| 218 |
# Connect to the database
|
| 219 |
conn = sqlite3.connect(DB_PATH)
|
| 220 |
|
| 221 |
-
# Get
|
| 222 |
cursor = conn.cursor()
|
| 223 |
cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
column_types = [info[2] for info in columns_info]
|
| 227 |
-
|
| 228 |
-
# Create rich context with column types
|
| 229 |
-
columns_with_types = [f"{col} ({typ})" for col, typ in zip(columns, column_types)]
|
| 230 |
-
columns_str = ", ".join(columns_with_types)
|
| 231 |
-
|
| 232 |
-
# Handle specific queries directly based on schema
|
| 233 |
-
if "highest tip" in query.lower() or "largest tip" in query.lower() or "maximum tip" in query.lower():
|
| 234 |
-
# Look for tip-related columns
|
| 235 |
-
tip_columns = [col for col in columns if "tip" in col.lower() or "gratuity" in col.lower()]
|
| 236 |
-
if tip_columns:
|
| 237 |
-
print(f"Found tip-related columns: {tip_columns}")
|
| 238 |
-
sql_query = f"SELECT MAX({tip_columns[0]}) AS highest_tip FROM data_tab"
|
| 239 |
-
|
| 240 |
-
# Execute the query directly
|
| 241 |
-
result_df = pd.read_sql_query(sql_query, conn)
|
| 242 |
-
|
| 243 |
-
# Generate response
|
| 244 |
-
highest_tip = result_df.iloc[0, 0]
|
| 245 |
-
response = f"The highest tip in the dataset is {highest_tip}."
|
| 246 |
-
history[-1][1] = response
|
| 247 |
-
return response, history
|
| 248 |
-
else:
|
| 249 |
-
response = f"I couldn't find any columns related to tips in the dataset. Available columns are: {', '.join(columns)}"
|
| 250 |
-
history[-1][1] = response
|
| 251 |
-
return response, history
|
| 252 |
|
| 253 |
-
# Create
|
| 254 |
-
|
| 255 |
-
sample_df = pd.read_sql_query(sample_query, conn)
|
| 256 |
-
sample_data = sample_df.to_string(index=False, max_rows=3)
|
| 257 |
-
|
| 258 |
-
# Create question with detailed context
|
| 259 |
-
question_with_context = f"""
|
| 260 |
-
IMPORTANT: ONLY use the exact columns listed below. DO NOT use any columns not explicitly listed here.
|
| 261 |
-
|
| 262 |
-
The table 'data_tab' has these columns with their types:
|
| 263 |
-
{columns_str}
|
| 264 |
-
|
| 265 |
-
Available columns (exact names): {', '.join(columns)}
|
| 266 |
-
|
| 267 |
-
Here's a sample of the data:
|
| 268 |
-
{sample_data}
|
| 269 |
-
|
| 270 |
-
User question: {query}
|
| 271 |
-
|
| 272 |
-
Remember to ONLY use the columns listed above. If the question seems to require a column that doesn't exist, use the most relevant existing column instead or explain that the data doesn't contain that information.
|
| 273 |
-
"""
|
| 274 |
|
| 275 |
# Special handling for visualization types that need raw data
|
| 276 |
if is_visualization and viz_type in ['box', 'heatmap']:
|
|
@@ -314,112 +258,13 @@ Remember to ONLY use the columns listed above. If the question seems to require
|
|
| 314 |
sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
|
| 315 |
else:
|
| 316 |
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
| 317 |
-
elif is_visualization:
|
| 318 |
-
# For visualization queries, use the specialized visualization prompt
|
| 319 |
-
sql_query = llm.invoke(visualization_prompt.format(
|
| 320 |
-
question=question_with_context,
|
| 321 |
-
viz_type=viz_type or "bar"
|
| 322 |
-
)).content
|
| 323 |
-
sql_query = clean_sql_query(sql_query)
|
| 324 |
else:
|
| 325 |
# For other queries, use the LLM to generate SQL
|
| 326 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|
| 327 |
sql_query = clean_sql_query(sql_query)
|
| 328 |
|
| 329 |
-
# Check if all columns in the query exist before executing
|
| 330 |
-
try:
|
| 331 |
-
# Get all column names
|
| 332 |
-
cursor.execute("PRAGMA table_info(data_tab);")
|
| 333 |
-
available_columns = [info[1] for info in cursor.fetchall()]
|
| 334 |
-
|
| 335 |
-
# Extract column names from the SQL query (simple approach)
|
| 336 |
-
query_columns = []
|
| 337 |
-
from_pos = sql_query.lower().find("from")
|
| 338 |
-
if from_pos > 0:
|
| 339 |
-
select_part = sql_query[:from_pos].lower()
|
| 340 |
-
# Remove SELECT keyword
|
| 341 |
-
if select_part.startswith("select "):
|
| 342 |
-
select_part = select_part[7:]
|
| 343 |
-
|
| 344 |
-
# Split by commas and extract column names
|
| 345 |
-
for col_expr in select_part.split(","):
|
| 346 |
-
col_expr = col_expr.strip()
|
| 347 |
-
# Handle AS aliases and functions
|
| 348 |
-
if " as " in col_expr:
|
| 349 |
-
col_expr = col_expr.split(" as ")[0].strip()
|
| 350 |
-
|
| 351 |
-
# Extract column name from functions
|
| 352 |
-
for func in ["max(", "min(", "avg(", "sum(", "count("]:
|
| 353 |
-
if func in col_expr:
|
| 354 |
-
# Extract column inside function
|
| 355 |
-
start_idx = col_expr.find(func) + len(func)
|
| 356 |
-
end_idx = col_expr.find(")", start_idx)
|
| 357 |
-
if end_idx > start_idx:
|
| 358 |
-
col_name = col_expr[start_idx:end_idx].strip()
|
| 359 |
-
if col_name != "*" and "(" not in col_name: # Skip nested functions and *
|
| 360 |
-
query_columns.append(col_name)
|
| 361 |
-
|
| 362 |
-
# Handle direct column references
|
| 363 |
-
if "(" not in col_expr and col_expr != "*":
|
| 364 |
-
query_columns.append(col_expr)
|
| 365 |
-
|
| 366 |
-
# Check for missing columns
|
| 367 |
-
missing_columns = []
|
| 368 |
-
for col in query_columns:
|
| 369 |
-
if col not in available_columns and col.strip() != "*":
|
| 370 |
-
missing_columns.append(col)
|
| 371 |
-
|
| 372 |
-
if missing_columns:
|
| 373 |
-
# Generate a simpler query with available columns
|
| 374 |
-
if "tip" in query.lower() or "gratuity" in query.lower():
|
| 375 |
-
# Look for a tip column
|
| 376 |
-
tip_columns = [col for col in available_columns if "tip" in col.lower() or "gratuity" in col.lower()]
|
| 377 |
-
if tip_columns:
|
| 378 |
-
sql_query = f"SELECT MAX({tip_columns[0]}) AS highest_tip FROM data_tab"
|
| 379 |
-
else:
|
| 380 |
-
# No tip column, return info about available columns
|
| 381 |
-
return f"I couldn't find a column related to tips or gratuity. Available columns are: {', '.join(available_columns)}", history
|
| 382 |
-
else:
|
| 383 |
-
# For other queries, suggest a generic query
|
| 384 |
-
return f"Some columns in the query don't exist in the current dataset: {', '.join(missing_columns)}. Available columns are: {', '.join(available_columns)}", history
|
| 385 |
-
except Exception as e:
|
| 386 |
-
print(f"Error checking columns: {str(e)}")
|
| 387 |
-
# Continue with the original query
|
| 388 |
-
|
| 389 |
# Execute the query
|
| 390 |
-
|
| 391 |
-
result_df = pd.read_sql_query(sql_query, conn)
|
| 392 |
-
except Exception as e:
|
| 393 |
-
error_message = str(e)
|
| 394 |
-
|
| 395 |
-
# Try to provide a more helpful error message
|
| 396 |
-
if "no such column" in error_message.lower():
|
| 397 |
-
# Extract column name from error
|
| 398 |
-
column_name = error_message.split("no such column: ")[-1].strip("'").strip('"')
|
| 399 |
-
|
| 400 |
-
# Look for similar columns
|
| 401 |
-
cursor.execute("PRAGMA table_info(data_tab);")
|
| 402 |
-
available_columns = [info[1] for info in cursor.fetchall()]
|
| 403 |
-
|
| 404 |
-
# Simple fuzzy matching
|
| 405 |
-
similar_columns = []
|
| 406 |
-
for col in available_columns:
|
| 407 |
-
# Check if column name contains parts of the error column
|
| 408 |
-
if column_name.lower() in col.lower() or any(part.lower() in col.lower() for part in column_name.split('_') if len(part) > 2):
|
| 409 |
-
similar_columns.append(col)
|
| 410 |
-
|
| 411 |
-
if similar_columns:
|
| 412 |
-
message = f"Column '{column_name}' doesn't exist in the current dataset. Did you mean one of these? {', '.join(similar_columns)}\n\nAvailable columns are: {', '.join(available_columns)}"
|
| 413 |
-
else:
|
| 414 |
-
message = f"Column '{column_name}' doesn't exist in the current dataset. Available columns are: {', '.join(available_columns)}"
|
| 415 |
-
|
| 416 |
-
history[-1][1] = message
|
| 417 |
-
return message, history
|
| 418 |
-
else:
|
| 419 |
-
# Generic error message
|
| 420 |
-
error_msg = f"Error executing query: {error_message}"
|
| 421 |
-
history[-1][1] = error_msg
|
| 422 |
-
return error_msg, history
|
| 423 |
|
| 424 |
# Close the connection
|
| 425 |
conn.close()
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
# Add a global variable to store the current plot
|
| 65 |
+
# current_plot = None
|
| 66 |
|
| 67 |
# Define the prompt with examples for SQL query generation
|
| 68 |
query_prompt = ChatPromptTemplate.from_template("""
|
| 69 |
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
|
| 70 |
|
| 71 |
+
Important guidelines:
|
| 72 |
+
1. Use SQLite syntax (not PostgreSQL or MySQL)
|
| 73 |
+
2. For date functions, use strftime() instead of EXTRACT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
|
| 75 |
+
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
|
| 76 |
+
4. For percentiles, use window functions or approximate methods
|
| 77 |
+
5. Keep queries efficient and focused on answering the specific question
|
| 78 |
+
6. Always use 'data_tab' as the table name
|
| 79 |
+
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 80 |
|
| 81 |
Question: {question}
|
| 82 |
""")
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# Define the prompt for interpreting the SQL query result
|
| 85 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
| 86 |
[
|
| 87 |
+
("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
|
| 89 |
]
|
| 90 |
)
|
| 91 |
|
| 92 |
+
# Add this after the query_prompt definition
|
| 93 |
+
# visualization_prompt = ChatPromptTemplate.from_template("""
|
| 94 |
+
# You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
|
| 95 |
+
#
|
| 96 |
+
# Important guidelines for SQLite syntax:
|
| 97 |
+
# 1. Use strftime() for date functions:
|
| 98 |
+
# - Year: strftime('%Y', date_column)
|
| 99 |
+
# - Month: strftime('%m', date_column)
|
| 100 |
+
# - Day: strftime('%d', date_column)
|
| 101 |
+
# - Hour: strftime('%H', date_column)
|
| 102 |
+
#
|
| 103 |
+
# 2. For histograms and binning:
|
| 104 |
+
# - Use: CAST((column / bin_size) AS INT) * bin_size
|
| 105 |
+
# - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
|
| 106 |
+
#
|
| 107 |
+
# 3. For box plots:
|
| 108 |
+
# - SQLite doesn't support PERCENTILE_CONT or window functions
|
| 109 |
+
# - Simply return the raw data column: SELECT column_name FROM data_tab
|
| 110 |
+
# - The application will calculate quartiles and outliers
|
| 111 |
+
#
|
| 112 |
+
# 4. For heatmaps:
|
| 113 |
+
# - Return raw data for correlation analysis
|
| 114 |
+
# - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
|
| 115 |
+
#
|
| 116 |
+
# 5. Always use 'data_tab' as the table name
|
| 117 |
+
#
|
| 118 |
+
# 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 119 |
+
#
|
| 120 |
+
# Question: {question}
|
| 121 |
+
# Visualization type: {viz_type}
|
| 122 |
+
# """)
|
| 123 |
+
|
| 124 |
# Add this helper function to clean SQL queries
|
| 125 |
def clean_sql_query(query_text):
|
| 126 |
"""Clean SQL query text by removing markdown formatting and comments"""
|
|
|
|
| 207 |
# Connect to the database
|
| 208 |
conn = sqlite3.connect(DB_PATH)
|
| 209 |
|
| 210 |
+
# Get column information for context
|
| 211 |
cursor = conn.cursor()
|
| 212 |
cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
|
| 213 |
+
columns = [info[1] for info in cursor.fetchall()]
|
| 214 |
+
columns_str = ", ".join(columns)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
# Create question with context
|
| 217 |
+
question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# Special handling for visualization types that need raw data
|
| 220 |
if is_visualization and viz_type in ['box', 'heatmap']:
|
|
|
|
| 258 |
sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
|
| 259 |
else:
|
| 260 |
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
else:
|
| 262 |
# For other queries, use the LLM to generate SQL
|
| 263 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|
| 264 |
sql_query = clean_sql_query(sql_query)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
# Execute the query
|
| 267 |
+
result_df = pd.read_sql_query(sql_query, conn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
|
| 269 |
# Close the connection
|
| 270 |
conn.close()
|