SVashishta1 committed on
Commit
80b8363
·
1 Parent(s): 5421c65
Files changed (1) hide show
  1. app.py +63 -49
app.py CHANGED
@@ -168,6 +168,28 @@ def process_text_query(query, history):
168
 
169
  start_time = time.time()
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # Check if we're in CSV context
172
  if current_context["file_type"] == "csv" and current_context["table_name"]:
173
  try:
@@ -184,53 +206,48 @@ def process_text_query(query, history):
184
  question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
185
 
186
  # Special handling for visualization types that need raw data
187
- if is_visualization:
188
- viz_type = None
189
- for vtype, keywords in viz_keywords.items():
190
- if any(keyword in query.lower() for keyword in keywords):
191
- viz_type = vtype
192
- break
193
-
194
- if viz_type in ['box', 'heatmap']:
195
- # For box plots and heatmaps, we need raw data
196
- if viz_type == 'box':
197
- # For box plots, we need a single numeric column
198
- numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
199
- cursor.execute(numeric_cols_query)
200
- numeric_cols = [row[0] for row in cursor.fetchall()]
201
-
202
- if numeric_cols:
203
- # Find the relevant numeric column based on the query
204
- target_col = None
205
- for col in numeric_cols:
206
- if col.lower() in query.lower():
207
- target_col = col
208
- break
209
-
210
- # If no specific column is mentioned, use the first numeric column
211
- if not target_col and numeric_cols:
212
- target_col = numeric_cols[0]
213
-
214
- # Generate a simple query to get the raw data
215
- sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
216
- else:
217
- # No numeric columns found
218
- sql_query = "SELECT * FROM data_tab LIMIT 10;"
219
 
220
- elif viz_type == 'heatmap':
221
- # For heatmaps, we need multiple numeric columns
222
- numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
223
- cursor.execute(numeric_cols_query)
224
- numeric_cols = [row[0] for row in cursor.fetchall()]
 
 
225
 
226
- if len(numeric_cols) >= 2:
227
- # Use all numeric columns (up to a reasonable limit)
228
- cols_to_use = numeric_cols[:10] # Limit to 10 columns for performance
229
- cols_str = ", ".join(cols_to_use)
230
- sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
231
- else:
232
- # Not enough numeric columns
233
- sql_query = "SELECT * FROM data_tab LIMIT 10;"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  else:
235
  # Generate SQL query using LLM
236
  ai_msg = query_prompt | llm
@@ -238,11 +255,8 @@ def process_text_query(query, history):
238
 
239
  # Clean the SQL query
240
  sql_query = clean_sql_query(raw_sql_query)
241
-
242
- print(f"Generated SQL Query: {sql_query}")
243
 
244
- # Check if this is a visualization request
245
- is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend'])
246
 
247
  try:
248
  # Execute the query
 
168
 
169
  start_time = time.time()
170
 
171
+ # Define visualization keywords at the beginning
172
+ viz_keywords = {
173
+ 'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
174
+ 'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
175
+ 'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
176
+ 'histogram': ['histogram', 'distribution of', 'frequency distribution'],
177
+ 'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
178
+ 'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
179
+ 'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
180
+ }
181
+
182
+ # Check if this is a visualization request
183
+ is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])
184
+
185
+ # Determine visualization type from query
186
+ viz_type = None
187
+ if is_visualization:
188
+ for vtype, keywords in viz_keywords.items():
189
+ if any(keyword in query.lower() for keyword in keywords):
190
+ viz_type = vtype
191
+ break
192
+
193
  # Check if we're in CSV context
194
  if current_context["file_type"] == "csv" and current_context["table_name"]:
195
  try:
 
206
  question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
207
 
208
  # Special handling for visualization types that need raw data
209
+ if is_visualization and viz_type in ['box', 'heatmap']:
210
+ # For box plots and heatmaps, we need raw data
211
+ if viz_type == 'box':
212
+ # For box plots, we need a single numeric column
213
+ numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
214
+ cursor = conn.cursor()
215
+ cursor.execute(numeric_cols_query)
216
+ numeric_cols = [row[0] for row in cursor.fetchall()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ if numeric_cols:
219
+ # Find the relevant numeric column based on the query
220
+ target_col = None
221
+ for col in numeric_cols:
222
+ if col.lower() in query.lower():
223
+ target_col = col
224
+ break
225
 
226
+ # If no specific column is mentioned, use the first numeric column
227
+ if not target_col and numeric_cols:
228
+ target_col = numeric_cols[0]
229
+
230
+ # Generate a simple query to get the raw data
231
+ sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
232
+ else:
233
+ # No numeric columns found
234
+ sql_query = "SELECT * FROM data_tab LIMIT 10;"
235
+
236
+ elif viz_type == 'heatmap':
237
+ # For heatmaps, we need multiple numeric columns
238
+ numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
239
+ cursor = conn.cursor()
240
+ cursor.execute(numeric_cols_query)
241
+ numeric_cols = [row[0] for row in cursor.fetchall()]
242
+
243
+ if len(numeric_cols) >= 2:
244
+ # Use all numeric columns (up to a reasonable limit)
245
+ cols_to_use = numeric_cols[:10] # Limit to 10 columns for performance
246
+ cols_str = ", ".join(cols_to_use)
247
+ sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
248
+ else:
249
+ # Not enough numeric columns
250
+ sql_query = "SELECT * FROM data_tab LIMIT 10;"
251
  else:
252
  # Generate SQL query using LLM
253
  ai_msg = query_prompt | llm
 
255
 
256
  # Clean the SQL query
257
  sql_query = clean_sql_query(raw_sql_query)
 
 
258
 
259
+ print(f"Generated SQL Query: {sql_query}")
 
260
 
261
  try:
262
  # Execute the query