SVashishta1 committed on
Commit
e3d98a2
·
1 Parent(s): a8c9793
Files changed (1) hide show
  1. app.py +113 -75
app.py CHANGED
@@ -54,73 +54,22 @@ current_context = {
54
  current_plot = None
55
 
56
  # Define the prompt with examples for SQL query generation
57
- query_prompt = ChatPromptTemplate.from_messages(
58
- [
59
- ("system", """
60
- You are an SQL and data analysis expert. Generate an appropriate SQL query using SQLite syntax for the question provided, without any explanations or code comments.
61
- Follow SQLite-specific conventions, as shown in the examples below:
62
-
63
- Example 1:
64
- Question: "What is the average fare for trips over 10 miles?"
65
- SQL Query: SELECT AVG(fare_amount) FROM data_tab WHERE trip_distance > 10;
66
 
67
- Example 2:
68
- Question: "How many trips were taken in each month?"
69
- SQL Query: SELECT strftime('%m', pickup_datetime) AS month, COUNT(*) AS trip_count FROM data_tab GROUP BY month;
 
 
 
 
 
70
 
71
- Example 3:
72
- Question: "What is the total fare amount for each driver (medallion) per day?"
73
- SQL Query: SELECT DATE(pickup_datetime) AS date, medallion, SUM(fare_amount) AS total_fare FROM data_tab GROUP BY date, medallion;
74
-
75
- Example 4:
76
- Question: "What is the highest tip amount in the dataset?"
77
- SQL Query: SELECT MAX(tip_amount) as highest_tip FROM data_tab;
78
-
79
- Example 5:
80
- Question: "Plot a bar graph for tip trends by month"
81
- SQL Query: SELECT strftime('%Y-%m', pickup_datetime) as month, AVG(tip_amount) as avg_tip, COUNT(*) as count FROM data_tab GROUP BY month ORDER BY month;
82
-
83
- SQLite-Specific Conventions:
84
-
85
- 1. Date and Time Extraction:
86
- - Instead of `EXTRACT(YEAR FROM column)`, use `strftime('%Y', column)` to extract the year.
87
- - Example: `SELECT strftime('%Y', pickup_datetime) FROM data_tab;`
88
-
89
- 2. String Length:
90
- - Instead of `CHAR_LENGTH(column)`, use `LENGTH(column)`.
91
- - Example: `SELECT LENGTH(passenger_name) FROM data_tab;`
92
-
93
- 3. Regular Expressions:
94
- - SQLite does not support `REGEXP`. Use `LIKE` for simple patterns or avoid regular expressions.
95
- - Example: `SELECT * FROM data_tab WHERE passenger_name LIKE 'A%';`
96
-
97
- 4. Window Functions:
98
- - For row numbering, use `ROW_NUMBER()` if supported, or simulate with joins.
99
- - Example: `SELECT id, ROW_NUMBER() OVER (ORDER BY pickup_datetime) AS row_num FROM data_tab;`
100
-
101
- 5. Data Type Casting:
102
- - Use `CAST(column AS TYPE)`, but note that SQLite supports limited types.
103
- - Example: `SELECT CAST(fare_amount AS INTEGER) FROM data_tab;`
104
-
105
- 6. Full Outer Join Workaround:
106
- - SQLite doesn't support `FULL OUTER JOIN`. Combine `LEFT JOIN` and `UNION` for a similar effect.
107
- - Example:
108
- ```
109
- SELECT a.*, b.*
110
- FROM table_a a
111
- LEFT JOIN table_b b ON a.id = b.id
112
- UNION
113
- SELECT a.*, b.*
114
- FROM table_a a
115
- RIGHT JOIN table_b b ON a.id = b.id;
116
- ```
117
-
118
- Use these examples and guidelines to generate an SQL query compatible with SQLite syntax for the question provided.
119
- Always use 'data_tab' as the table name.
120
- """),
121
- ("human", "{question}"),
122
- ]
123
- )
124
 
125
  # Define the prompt for interpreting the SQL query result
126
  interpret_prompt = ChatPromptTemplate.from_messages(
@@ -130,6 +79,37 @@ interpret_prompt = ChatPromptTemplate.from_messages(
130
  ]
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def process_text_query(query, history):
134
  """Process a text query and update chat history"""
135
  if not query:
@@ -205,28 +185,86 @@ def process_text_query(query, history):
205
  if is_visualization and not result_df.empty:
206
  try:
207
  print("Visualization requested, attempting to create plot...")
208
- # Determine the type of visualization based on the data
 
 
 
 
 
 
 
209
  if len(result_df.columns) >= 2:
210
  # Find numeric columns for y-axis
211
  numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
212
 
213
  if len(numeric_cols) >= 1 and len(result_df) > 1:
214
- # Use the first column as x and first numeric column as y
215
- x_col = result_df.columns[0]
216
- y_cols = numeric_cols[:3] # Use up to 3 numeric columns
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- print(f"Creating plot with x={x_col}, y={y_cols}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
- # Create appropriate plot based on data characteristics
221
- if 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns or any('date' in col.lower() for col in result_df.columns):
 
 
 
 
 
222
  # Time series data - use line chart
223
- fig = px.line(result_df, x=x_col, y=y_cols, title="Time Series Analysis")
 
 
 
 
 
224
  else:
225
  # Regular data - use bar chart
226
- fig = px.bar(result_df, x=x_col, y=y_cols[0], title="Data Visualization")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  # Convert the figure to an image and encode it as base64
229
- img_bytes = fig.to_image(format="png", width=800, height=500)
230
  encoded = base64.b64encode(img_bytes).decode("ascii")
231
  img_src = f"data:image/png;base64,{encoded}"
232
 
 
54
  current_plot = None
55
 
56
  # Define the prompt with examples for SQL query generation
57
+ query_prompt = ChatPromptTemplate.from_template("""
58
+ You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
 
 
 
 
 
 
 
59
 
60
+ Important guidelines:
61
+ 1. Use SQLite syntax (not PostgreSQL or MySQL)
62
+ 2. For date functions, use strftime() instead of EXTRACT
63
+ - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
64
+ 3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
65
+ 4. For percentiles, use window functions or approximate methods
66
+ 5. Keep queries efficient and focused on answering the specific question
67
+ 6. Always use 'data_tab' as the table name
68
 
69
+ Question: {question}
70
+
71
+ SQL Query:
72
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Define the prompt for interpreting the SQL query result
75
  interpret_prompt = ChatPromptTemplate.from_messages(
 
79
  ]
80
  )
81
 
82
+ # Add this after the query_prompt definition
83
+ visualization_prompt = ChatPromptTemplate.from_template("""
84
+ You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
85
+
86
+ Important guidelines for SQLite syntax:
87
+ 1. Use strftime() for date functions:
88
+ - Year: strftime('%Y', date_column)
89
+ - Month: strftime('%m', date_column)
90
+ - Day: strftime('%d', date_column)
91
+ - Hour: strftime('%H', date_column)
92
+
93
+ 2. For histograms and binning:
94
+ - Use: CAST((column / bin_size) AS INT) * bin_size
95
+ - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
96
+
97
+ 3. For percentiles and statistics:
98
+ - SQLite doesn't have built-in percentile functions
99
+ - Use simple aggregations (MIN, MAX, AVG, COUNT) instead
100
+
101
+ 4. For time series:
102
+ - Group by date parts using strftime()
103
+ - Example: strftime('%Y-%m-%d', pickup_datetime) AS day
104
+
105
+ 5. Always use 'data_tab' as the table name
106
+
107
+ Question: {question}
108
+ Visualization type: {viz_type}
109
+
110
+ SQL Query:
111
+ """)
112
+
113
  def process_text_query(query, history):
114
  """Process a text query and update chat history"""
115
  if not query:
 
185
  if is_visualization and not result_df.empty:
186
  try:
187
  print("Visualization requested, attempting to create plot...")
188
+ # Determine the type of visualization based on the data and query
189
+
190
+ # Check for specific visualization types in the query
191
+ is_pie_chart = any(word in query.lower() for word in ['pie chart', 'pie graph', 'distribution'])
192
+ is_histogram = any(word in query.lower() for word in ['histogram', 'distribution of', 'frequency'])
193
+ is_heatmap = any(word in query.lower() for word in ['heatmap', 'heat map', 'correlation'])
194
+ is_scatter = any(word in query.lower() for word in ['scatter', 'relationship between', 'correlation'])
195
+
196
  if len(result_df.columns) >= 2:
197
  # Find numeric columns for y-axis
198
  numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
199
 
200
  if len(numeric_cols) >= 1 and len(result_df) > 1:
201
+ # Create appropriate plot based on query and data characteristics
202
+ if is_pie_chart and len(result_df) <= 20: # Pie charts work best with limited categories
203
+ # For pie charts, we need a category column and a value column
204
+ category_col = result_df.columns[0]
205
+ value_col = numeric_cols[0] if len(numeric_cols) > 0 else result_df.columns[1]
206
+
207
+ fig = px.pie(result_df, names=category_col, values=value_col,
208
+ title="Distribution Analysis",
209
+ hole=0.3) # Use a donut chart for better readability
210
+
211
+ elif is_histogram and len(numeric_cols) > 0:
212
+ # For histograms, we need a numeric column
213
+ fig = px.histogram(result_df, x=numeric_cols[0],
214
+ title=f"Distribution of {numeric_cols[0]}",
215
+ nbins=20)
216
 
217
+ elif is_heatmap and len(numeric_cols) >= 2:
218
+ # For heatmaps, we need at least 2 numeric columns
219
+ # Convert to a correlation matrix if needed
220
+ if len(result_df.columns) == len(numeric_cols) and len(numeric_cols) > 2:
221
+ # This is likely already a correlation matrix or similar data
222
+ fig = px.imshow(result_df,
223
+ title="Correlation Heatmap",
224
+ color_continuous_scale='RdBu_r',
225
+ aspect="auto")
226
+ else:
227
+ # Create a correlation matrix from the numeric columns
228
+ corr_df = result_df[numeric_cols].corr()
229
+ fig = px.imshow(corr_df,
230
+ title="Correlation Heatmap",
231
+ color_continuous_scale='RdBu_r',
232
+ aspect="auto")
233
 
234
+ elif is_scatter and len(numeric_cols) >= 2:
235
+ # For scatter plots, we need at least 2 numeric columns
236
+ fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1],
237
+ title=f"Relationship between {numeric_cols[0]} and {numeric_cols[1]}",
238
+ opacity=0.7)
239
+
240
+ elif 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns or any('date' in col.lower() for col in result_df.columns):
241
  # Time series data - use line chart
242
+ x_col = result_df.columns[0]
243
+ y_cols = numeric_cols[:3] # Use up to 3 numeric columns
244
+
245
+ fig = px.line(result_df, x=x_col, y=y_cols,
246
+ title="Time Series Analysis",
247
+ markers=True)
248
  else:
249
  # Regular data - use bar chart
250
+ x_col = result_df.columns[0]
251
+ y_cols = numeric_cols[0]
252
+
253
+ fig = px.bar(result_df, x=x_col, y=y_cols,
254
+ title="Data Visualization")
255
+
256
+ # Improve figure layout
257
+ fig.update_layout(
258
+ autosize=True,
259
+ width=900,
260
+ height=600,
261
+ margin=dict(l=50, r=50, b=100, t=100, pad=4),
262
+ template="plotly_white",
263
+ font=dict(size=14)
264
+ )
265
 
266
  # Convert the figure to an image and encode it as base64
267
+ img_bytes = fig.to_image(format="png", width=900, height=600, scale=2)
268
  encoded = base64.b64encode(img_bytes).decode("ascii")
269
  img_src = f"data:image/png;base64,{encoded}"
270