SVashishta1
committed on
Commit
·
80b8363
1
Parent(s):
5421c65
Error Fix
Browse files
app.py
CHANGED
|
@@ -168,6 +168,28 @@ def process_text_query(query, history):
|
|
| 168 |
|
| 169 |
start_time = time.time()
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
# Check if we're in CSV context
|
| 172 |
if current_context["file_type"] == "csv" and current_context["table_name"]:
|
| 173 |
try:
|
|
@@ -184,53 +206,48 @@ def process_text_query(query, history):
|
|
| 184 |
question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
|
| 185 |
|
| 186 |
# Special handling for visualization types that need raw data
|
| 187 |
-
if is_visualization:
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
# For box plots and heatmaps, we need raw data
|
| 196 |
-
if viz_type == 'box':
|
| 197 |
-
# For box plots, we need a single numeric column
|
| 198 |
-
numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
|
| 199 |
-
cursor.execute(numeric_cols_query)
|
| 200 |
-
numeric_cols = [row[0] for row in cursor.fetchall()]
|
| 201 |
-
|
| 202 |
-
if numeric_cols:
|
| 203 |
-
# Find the relevant numeric column based on the query
|
| 204 |
-
target_col = None
|
| 205 |
-
for col in numeric_cols:
|
| 206 |
-
if col.lower() in query.lower():
|
| 207 |
-
target_col = col
|
| 208 |
-
break
|
| 209 |
-
|
| 210 |
-
# If no specific column is mentioned, use the first numeric column
|
| 211 |
-
if not target_col and numeric_cols:
|
| 212 |
-
target_col = numeric_cols[0]
|
| 213 |
-
|
| 214 |
-
# Generate a simple query to get the raw data
|
| 215 |
-
sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
|
| 216 |
-
else:
|
| 217 |
-
# No numeric columns found
|
| 218 |
-
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
| 219 |
|
| 220 |
-
|
| 221 |
-
#
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
else:
|
| 235 |
# Generate SQL query using LLM
|
| 236 |
ai_msg = query_prompt | llm
|
|
@@ -238,11 +255,8 @@ def process_text_query(query, history):
|
|
| 238 |
|
| 239 |
# Clean the SQL query
|
| 240 |
sql_query = clean_sql_query(raw_sql_query)
|
| 241 |
-
|
| 242 |
-
print(f"Generated SQL Query: {sql_query}")
|
| 243 |
|
| 244 |
-
|
| 245 |
-
is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend'])
|
| 246 |
|
| 247 |
try:
|
| 248 |
# Execute the query
|
|
|
|
| 168 |
|
| 169 |
start_time = time.time()
|
| 170 |
|
| 171 |
+
# Define visualization keywords at the beginning
|
| 172 |
+
viz_keywords = {
|
| 173 |
+
'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
|
| 174 |
+
'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
|
| 175 |
+
'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
|
| 176 |
+
'histogram': ['histogram', 'distribution of', 'frequency distribution'],
|
| 177 |
+
'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
|
| 178 |
+
'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
|
| 179 |
+
'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
# Check if this is a visualization request
|
| 183 |
+
is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])
|
| 184 |
+
|
| 185 |
+
# Determine visualization type from query
|
| 186 |
+
viz_type = None
|
| 187 |
+
if is_visualization:
|
| 188 |
+
for vtype, keywords in viz_keywords.items():
|
| 189 |
+
if any(keyword in query.lower() for keyword in keywords):
|
| 190 |
+
viz_type = vtype
|
| 191 |
+
break
|
| 192 |
+
|
| 193 |
# Check if we're in CSV context
|
| 194 |
if current_context["file_type"] == "csv" and current_context["table_name"]:
|
| 195 |
try:
|
|
|
|
| 206 |
question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
|
| 207 |
|
| 208 |
# Special handling for visualization types that need raw data
|
| 209 |
+
if is_visualization and viz_type in ['box', 'heatmap']:
|
| 210 |
+
# For box plots and heatmaps, we need raw data
|
| 211 |
+
if viz_type == 'box':
|
| 212 |
+
# For box plots, we need a single numeric column
|
| 213 |
+
numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
|
| 214 |
+
cursor = conn.cursor()
|
| 215 |
+
cursor.execute(numeric_cols_query)
|
| 216 |
+
numeric_cols = [row[0] for row in cursor.fetchall()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
+
if numeric_cols:
|
| 219 |
+
# Find the relevant numeric column based on the query
|
| 220 |
+
target_col = None
|
| 221 |
+
for col in numeric_cols:
|
| 222 |
+
if col.lower() in query.lower():
|
| 223 |
+
target_col = col
|
| 224 |
+
break
|
| 225 |
|
| 226 |
+
# If no specific column is mentioned, use the first numeric column
|
| 227 |
+
if not target_col and numeric_cols:
|
| 228 |
+
target_col = numeric_cols[0]
|
| 229 |
+
|
| 230 |
+
# Generate a simple query to get the raw data
|
| 231 |
+
sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
|
| 232 |
+
else:
|
| 233 |
+
# No numeric columns found
|
| 234 |
+
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
| 235 |
+
|
| 236 |
+
elif viz_type == 'heatmap':
|
| 237 |
+
# For heatmaps, we need multiple numeric columns
|
| 238 |
+
numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
|
| 239 |
+
cursor = conn.cursor()
|
| 240 |
+
cursor.execute(numeric_cols_query)
|
| 241 |
+
numeric_cols = [row[0] for row in cursor.fetchall()]
|
| 242 |
+
|
| 243 |
+
if len(numeric_cols) >= 2:
|
| 244 |
+
# Use all numeric columns (up to a reasonable limit)
|
| 245 |
+
cols_to_use = numeric_cols[:10] # Limit to 10 columns for performance
|
| 246 |
+
cols_str = ", ".join(cols_to_use)
|
| 247 |
+
sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
|
| 248 |
+
else:
|
| 249 |
+
# Not enough numeric columns
|
| 250 |
+
sql_query = "SELECT * FROM data_tab LIMIT 10;"
|
| 251 |
else:
|
| 252 |
# Generate SQL query using LLM
|
| 253 |
ai_msg = query_prompt | llm
|
|
|
|
| 255 |
|
| 256 |
# Clean the SQL query
|
| 257 |
sql_query = clean_sql_query(raw_sql_query)
|
|
|
|
|
|
|
| 258 |
|
| 259 |
+
print(f"Generated SQL Query: {sql_query}")
|
|
|
|
| 260 |
|
| 261 |
try:
|
| 262 |
# Execute the query
|