SVashishta1
commited on
Commit
·
e3d98a2
1
Parent(s):
a8c9793
Error Fix
Browse files
app.py
CHANGED
|
@@ -54,73 +54,22 @@ current_context = {
|
|
| 54 |
current_plot = None
|
| 55 |
|
| 56 |
# Define the prompt with examples for SQL query generation
|
| 57 |
-
query_prompt = ChatPromptTemplate.
|
| 58 |
-
|
| 59 |
-
("system", """
|
| 60 |
-
You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
|
| 61 |
-
Follow SQLite-specific conventions, as shown in the examples below:
|
| 62 |
-
|
| 63 |
-
Example 1:
|
| 64 |
-
Question: "What is the average fare for trips over 10 miles?"
|
| 65 |
-
SQL Query: SELECT AVG(fare_amount) FROM data_tab WHERE trip_distance > 10;
|
| 66 |
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
Example 4:
|
| 76 |
-
Question: "What is the highest tip amount in the dataset?"
|
| 77 |
-
SQL Query: SELECT MAX(tip_amount) as highest_tip FROM data_tab;
|
| 78 |
-
|
| 79 |
-
Example 5:
|
| 80 |
-
Question: "Plot a bar graph for tip trends by month"
|
| 81 |
-
SQL Query: SELECT strftime('%Y-%m', pickup_datetime) as month, AVG(tip_amount) as avg_tip, COUNT(*) as count FROM data_tab GROUP BY month ORDER BY month;
|
| 82 |
-
|
| 83 |
-
SQLite-Specific Conventions:
|
| 84 |
-
|
| 85 |
-
1. Date and Time Extraction:
|
| 86 |
-
- Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
|
| 87 |
-
- Example: `SELECT strftime('%Y', pickup_datetime) FROM data_tab;`
|
| 88 |
-
|
| 89 |
-
2. String Length:
|
| 90 |
-
- Instead of `CHAR_LENGTH(column)`, use `LENGTH(column)`.
|
| 91 |
-
- Example: `SELECT LENGTH(passenger_name) FROM data_tab;`
|
| 92 |
-
|
| 93 |
-
3. Regular Expressions:
|
| 94 |
-
- SQLite does not support `REGEXP`. Use `LIKE` for simple patterns or avoid regular expressions.
|
| 95 |
-
- Example: `SELECT * FROM data_tab WHERE passenger_name LIKE 'A%';`
|
| 96 |
-
|
| 97 |
-
4. Window Functions:
|
| 98 |
-
- For row numbering, use `ROW_NUMBER()` if supported, or simulate with joins.
|
| 99 |
-
- Example: `SELECT id, ROW_NUMBER() OVER (ORDER BY pickup_datetime) AS row_num FROM data_tab;`
|
| 100 |
-
|
| 101 |
-
5. Data Type Casting:
|
| 102 |
-
- Use `CAST(column AS TYPE)`, but note that SQLite supports limited types.
|
| 103 |
-
- Example: `SELECT CAST(fare_amount AS INTEGER) FROM data_tab;`
|
| 104 |
-
|
| 105 |
-
6. Full Outer Join Workaround:
|
| 106 |
-
- SQLite doesn't support `FULL OUTER JOIN`. Combine `LEFT JOIN` and `UNION` for a similar effect.
|
| 107 |
-
- Example:
|
| 108 |
-
```
|
| 109 |
-
SELECT a.*, b.*
|
| 110 |
-
FROM table_a a
|
| 111 |
-
LEFT JOIN table_b b ON a.id = b.id
|
| 112 |
-
UNION
|
| 113 |
-
SELECT a.*, b.*
|
| 114 |
-
FROM table_a a
|
| 115 |
-
RIGHT JOIN table_b b ON a.id = b.id;
|
| 116 |
-
```
|
| 117 |
-
|
| 118 |
-
Use these examples and guidelines to generate an SQL query compatible with SQLite syntax for the question provided.
|
| 119 |
-
Always use 'data_tab' as the table name.
|
| 120 |
-
"""),
|
| 121 |
-
("human", "{question}"),
|
| 122 |
-
]
|
| 123 |
-
)
|
| 124 |
|
| 125 |
# Define the prompt for interpreting the SQL query result
|
| 126 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
|
@@ -130,6 +79,37 @@ interpret_prompt = ChatPromptTemplate.from_messages(
|
|
| 130 |
]
|
| 131 |
)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def process_text_query(query, history):
|
| 134 |
"""Process a text query and update chat history"""
|
| 135 |
if not query:
|
|
@@ -205,28 +185,86 @@ def process_text_query(query, history):
|
|
| 205 |
if is_visualization and not result_df.empty:
|
| 206 |
try:
|
| 207 |
print("Visualization requested, attempting to create plot...")
|
| 208 |
-
# Determine the type of visualization based on the data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
if len(result_df.columns) >= 2:
|
| 210 |
# Find numeric columns for y-axis
|
| 211 |
numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
|
| 212 |
|
| 213 |
if len(numeric_cols) >= 1 and len(result_df) > 1:
|
| 214 |
-
#
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
# Time series data - use line chart
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
else:
|
| 225 |
# Regular data - use bar chart
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
# Convert the figure to an image and encode it as base64
|
| 229 |
-
img_bytes = fig.to_image(format="png", width=
|
| 230 |
encoded = base64.b64encode(img_bytes).decode("ascii")
|
| 231 |
img_src = f"data:image/png;base64,{encoded}"
|
| 232 |
|
|
|
|
| 54 |
current_plot = None
|
| 55 |
|
| 56 |
# Define the prompt with examples for SQL query generation
|
| 57 |
+
query_prompt = ChatPromptTemplate.from_template("""
|
| 58 |
+
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
+
Important guidelines:
|
| 61 |
+
1. Use SQLite syntax (not PostgreSQL or MySQL)
|
| 62 |
+
2. For date functions, use strftime() instead of EXTRACT
|
| 63 |
+
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
|
| 64 |
+
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
|
| 65 |
+
4. For percentiles, use window functions or approximate methods
|
| 66 |
+
5. Keep queries efficient and focused on answering the specific question
|
| 67 |
+
6. Always use 'data_tab' as the table name
|
| 68 |
|
| 69 |
+
Question: {question}
|
| 70 |
+
|
| 71 |
+
SQL Query:
|
| 72 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# Define the prompt for interpreting the SQL query result
|
| 75 |
interpret_prompt = ChatPromptTemplate.from_messages(
|
|
|
|
| 79 |
]
|
| 80 |
)
|
| 81 |
|
| 82 |
+
# Add this after the query_prompt definition
|
| 83 |
+
visualization_prompt = ChatPromptTemplate.from_template("""
|
| 84 |
+
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
|
| 85 |
+
|
| 86 |
+
Important guidelines for SQLite syntax:
|
| 87 |
+
1. Use strftime() for date functions:
|
| 88 |
+
- Year: strftime('%Y', date_column)
|
| 89 |
+
- Month: strftime('%m', date_column)
|
| 90 |
+
- Day: strftime('%d', date_column)
|
| 91 |
+
- Hour: strftime('%H', date_column)
|
| 92 |
+
|
| 93 |
+
2. For histograms and binning:
|
| 94 |
+
- Use: CAST((column / bin_size) AS INT) * bin_size
|
| 95 |
+
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
|
| 96 |
+
|
| 97 |
+
3. For percentiles and statistics:
|
| 98 |
+
- SQLite doesn't have built-in percentile functions
|
| 99 |
+
- Use simple aggregations (MIN, MAX, AVG, COUNT) instead
|
| 100 |
+
|
| 101 |
+
4. For time series:
|
| 102 |
+
- Group by date parts using strftime()
|
| 103 |
+
- Example: strftime('%Y-%m-%d', pickup_datetime) AS day
|
| 104 |
+
|
| 105 |
+
5. Always use 'data_tab' as the table name
|
| 106 |
+
|
| 107 |
+
Question: {question}
|
| 108 |
+
Visualization type: {viz_type}
|
| 109 |
+
|
| 110 |
+
SQL Query:
|
| 111 |
+
""")
|
| 112 |
+
|
| 113 |
def process_text_query(query, history):
|
| 114 |
"""Process a text query and update chat history"""
|
| 115 |
if not query:
|
|
|
|
| 185 |
if is_visualization and not result_df.empty:
|
| 186 |
try:
|
| 187 |
print("Visualization requested, attempting to create plot...")
|
| 188 |
+
# Determine the type of visualization based on the data and query
|
| 189 |
+
|
| 190 |
+
# Check for specific visualization types in the query
|
| 191 |
+
is_pie_chart = any(word in query.lower() for word in ['pie chart', 'pie graph', 'distribution'])
|
| 192 |
+
is_histogram = any(word in query.lower() for word in ['histogram', 'distribution of', 'frequency'])
|
| 193 |
+
is_heatmap = any(word in query.lower() for word in ['heatmap', 'heat map', 'correlation'])
|
| 194 |
+
is_scatter = any(word in query.lower() for word in ['scatter', 'relationship between', 'correlation'])
|
| 195 |
+
|
| 196 |
if len(result_df.columns) >= 2:
|
| 197 |
# Find numeric columns for y-axis
|
| 198 |
numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
|
| 199 |
|
| 200 |
if len(numeric_cols) >= 1 and len(result_df) > 1:
|
| 201 |
+
# Create appropriate plot based on query and data characteristics
|
| 202 |
+
if is_pie_chart and len(result_df) <= 20: # Pie charts work best with limited categories
|
| 203 |
+
# For pie charts, we need a category column and a value column
|
| 204 |
+
category_col = result_df.columns[0]
|
| 205 |
+
value_col = numeric_cols[0] if len(numeric_cols) > 0 else result_df.columns[1]
|
| 206 |
+
|
| 207 |
+
fig = px.pie(result_df, names=category_col, values=value_col,
|
| 208 |
+
title="Distribution Analysis",
|
| 209 |
+
hole=0.3) # Use a donut chart for better readability
|
| 210 |
+
|
| 211 |
+
elif is_histogram and len(numeric_cols) > 0:
|
| 212 |
+
# For histograms, we need a numeric column
|
| 213 |
+
fig = px.histogram(result_df, x=numeric_cols[0],
|
| 214 |
+
title=f"Distribution of {numeric_cols[0]}",
|
| 215 |
+
nbins=20)
|
| 216 |
|
| 217 |
+
elif is_heatmap and len(numeric_cols) >= 2:
|
| 218 |
+
# For heatmaps, we need at least 2 numeric columns
|
| 219 |
+
# Convert to a correlation matrix if needed
|
| 220 |
+
if len(result_df.columns) == len(numeric_cols) and len(numeric_cols) > 2:
|
| 221 |
+
# This is likely already a correlation matrix or similar data
|
| 222 |
+
fig = px.imshow(result_df,
|
| 223 |
+
title="Correlation Heatmap",
|
| 224 |
+
color_continuous_scale='RdBu_r',
|
| 225 |
+
aspect="auto")
|
| 226 |
+
else:
|
| 227 |
+
# Create a correlation matrix from the numeric columns
|
| 228 |
+
corr_df = result_df[numeric_cols].corr()
|
| 229 |
+
fig = px.imshow(corr_df,
|
| 230 |
+
title="Correlation Heatmap",
|
| 231 |
+
color_continuous_scale='RdBu_r',
|
| 232 |
+
aspect="auto")
|
| 233 |
|
| 234 |
+
elif is_scatter and len(numeric_cols) >= 2:
|
| 235 |
+
# For scatter plots, we need at least 2 numeric columns
|
| 236 |
+
fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1],
|
| 237 |
+
title=f"Relationship between {numeric_cols[0]} and {numeric_cols[1]}",
|
| 238 |
+
opacity=0.7)
|
| 239 |
+
|
| 240 |
+
elif 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns or any('date' in col.lower() for col in result_df.columns):
|
| 241 |
# Time series data - use line chart
|
| 242 |
+
x_col = result_df.columns[0]
|
| 243 |
+
y_cols = numeric_cols[:3] # Use up to 3 numeric columns
|
| 244 |
+
|
| 245 |
+
fig = px.line(result_df, x=x_col, y=y_cols,
|
| 246 |
+
title="Time Series Analysis",
|
| 247 |
+
markers=True)
|
| 248 |
else:
|
| 249 |
# Regular data - use bar chart
|
| 250 |
+
x_col = result_df.columns[0]
|
| 251 |
+
y_cols = numeric_cols[0]
|
| 252 |
+
|
| 253 |
+
fig = px.bar(result_df, x=x_col, y=y_cols,
|
| 254 |
+
title="Data Visualization")
|
| 255 |
+
|
| 256 |
+
# Improve figure layout
|
| 257 |
+
fig.update_layout(
|
| 258 |
+
autosize=True,
|
| 259 |
+
width=900,
|
| 260 |
+
height=600,
|
| 261 |
+
margin=dict(l=50, r=50, b=100, t=100, pad=4),
|
| 262 |
+
template="plotly_white",
|
| 263 |
+
font=dict(size=14)
|
| 264 |
+
)
|
| 265 |
|
| 266 |
# Convert the figure to an image and encode it as base64
|
| 267 |
+
img_bytes = fig.to_image(format="png", width=900, height=600, scale=2)
|
| 268 |
encoded = base64.b64encode(img_bytes).decode("ascii")
|
| 269 |
img_src = f"data:image/png;base64,{encoded}"
|
| 270 |
|