SVashishta1
commited on
Commit
·
e770679
1
Parent(s):
b8b94fc
Restore: Visualization prompt and current_plot for better LLM guidance and SQL generation
Browse files
app.py
CHANGED
|
@@ -13,9 +13,7 @@ import plotly.io as pio
|
|
| 13 |
import traceback
|
| 14 |
import base64
|
| 15 |
from io import BytesIO
|
| 16 |
-
|
| 17 |
-
# import speech_recognition as sr
|
| 18 |
-
# from gtts import gTTS
|
| 19 |
import re
|
| 20 |
import importlib.util
|
| 21 |
|
|
@@ -64,7 +62,7 @@ current_context = {
|
|
| 64 |
}
|
| 65 |
|
| 66 |
# Add a global variable to store the current plot
|
| 67 |
-
|
| 68 |
|
| 69 |
# Define the prompt with examples for SQL query generation
|
| 70 |
query_prompt = ChatPromptTemplate.from_template("""
|
|
@@ -83,6 +81,38 @@ Important guidelines:
|
|
| 83 |
Question: {question}
|
| 84 |
""")
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
# Define the prompt for interpreting the SQL query result
|
| 87 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
| 88 |
[
|
|
@@ -91,38 +121,6 @@ interpret_prompt = ChatPromptTemplate.from_messages(
|
|
| 91 |
]
|
| 92 |
)
|
| 93 |
|
| 94 |
-
# Add this after the query_prompt definition
|
| 95 |
-
# visualization_prompt = ChatPromptTemplate.from_template("""
|
| 96 |
-
# You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
|
| 97 |
-
#
|
| 98 |
-
# Important guidelines for SQLite syntax:
|
| 99 |
-
# 1. Use strftime() for date functions:
|
| 100 |
-
# - Year: strftime('%Y', date_column)
|
| 101 |
-
# - Month: strftime('%m', date_column)
|
| 102 |
-
# - Day: strftime('%d', date_column)
|
| 103 |
-
# - Hour: strftime('%H', date_column)
|
| 104 |
-
#
|
| 105 |
-
# 2. For histograms and binning:
|
| 106 |
-
# - Use: CAST((column / bin_size) AS INT) * bin_size
|
| 107 |
-
# - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
|
| 108 |
-
#
|
| 109 |
-
# 3. For box plots:
|
| 110 |
-
# - SQLite doesn't support PERCENTILE_CONT or window functions
|
| 111 |
-
# - Simply return the raw data column: SELECT column_name FROM data_tab
|
| 112 |
-
# - The application will calculate quartiles and outliers
|
| 113 |
-
#
|
| 114 |
-
# 4. For heatmaps:
|
| 115 |
-
# - Return raw data for correlation analysis
|
| 116 |
-
# - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
|
| 117 |
-
#
|
| 118 |
-
# 5. Always use 'data_tab' as the table name
|
| 119 |
-
#
|
| 120 |
-
# 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 121 |
-
#
|
| 122 |
-
# Question: {question}
|
| 123 |
-
# Visualization type: {viz_type}
|
| 124 |
-
# """)
|
| 125 |
-
|
| 126 |
# Add this helper function to clean SQL queries
|
| 127 |
def clean_sql_query(query_text):
|
| 128 |
"""Clean SQL query text by removing markdown formatting and comments"""
|
|
@@ -260,6 +258,13 @@ def process_text_query(query, history):
|
|
| 260 |
sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
|
| 261 |
else:
|
| 262 |
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
else:
|
| 264 |
# For other queries, use the LLM to generate SQL
|
| 265 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|
|
|
|
| 13 |
import traceback
|
| 14 |
import base64
|
| 15 |
from io import BytesIO
|
| 16 |
+
|
|
|
|
|
|
|
| 17 |
import re
|
| 18 |
import importlib.util
|
| 19 |
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
# Add a global variable to store the current plot
|
| 65 |
+
current_plot = None
|
| 66 |
|
| 67 |
# Define the prompt with examples for SQL query generation
|
| 68 |
query_prompt = ChatPromptTemplate.from_template("""
|
|
|
|
| 81 |
Question: {question}
|
| 82 |
""")
|
| 83 |
|
| 84 |
+
# Add this after the query_prompt definition
|
| 85 |
+
visualization_prompt = ChatPromptTemplate.from_template("""
|
| 86 |
+
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
|
| 87 |
+
|
| 88 |
+
Important guidelines for SQLite syntax:
|
| 89 |
+
1. Use strftime() for date functions:
|
| 90 |
+
- Year: strftime('%Y', date_column)
|
| 91 |
+
- Month: strftime('%m', date_column)
|
| 92 |
+
- Day: strftime('%d', date_column)
|
| 93 |
+
- Hour: strftime('%H', date_column)
|
| 94 |
+
|
| 95 |
+
2. For histograms and binning:
|
| 96 |
+
- Use: CAST((column / bin_size) AS INT) * bin_size
|
| 97 |
+
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
|
| 98 |
+
|
| 99 |
+
3. For box plots:
|
| 100 |
+
- SQLite doesn't support PERCENTILE_CONT or window functions
|
| 101 |
+
- Simply return the raw data column: SELECT column_name FROM data_tab
|
| 102 |
+
- The application will calculate quartiles and outliers
|
| 103 |
+
|
| 104 |
+
4. For heatmaps:
|
| 105 |
+
- Return raw data for correlation analysis
|
| 106 |
+
- Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
|
| 107 |
+
|
| 108 |
+
5. Always use 'data_tab' as the table name
|
| 109 |
+
|
| 110 |
+
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
|
| 111 |
+
|
| 112 |
+
Question: {question}
|
| 113 |
+
Visualization type: {viz_type}
|
| 114 |
+
""")
|
| 115 |
+
|
| 116 |
# Define the prompt for interpreting the SQL query result
|
| 117 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
| 118 |
[
|
|
|
|
| 121 |
]
|
| 122 |
)
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
# Add this helper function to clean SQL queries
|
| 125 |
def clean_sql_query(query_text):
|
| 126 |
"""Clean SQL query text by removing markdown formatting and comments"""
|
|
|
|
| 258 |
sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
|
| 259 |
else:
|
| 260 |
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
| 261 |
+
elif is_visualization:
|
| 262 |
+
# For visualization queries, use the specialized visualization prompt
|
| 263 |
+
sql_query = llm.invoke(visualization_prompt.format(
|
| 264 |
+
question=question_with_context,
|
| 265 |
+
viz_type=viz_type or "bar"
|
| 266 |
+
)).content
|
| 267 |
+
sql_query = clean_sql_query(sql_query)
|
| 268 |
else:
|
| 269 |
# For other queries, use the LLM to generate SQL
|
| 270 |
sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
|