SVashishta1 commited on
Commit
e770679
·
1 Parent(s): b8b94fc

Restore: Visualization prompt and current_plot for better LLM guidance and SQL generation

Browse files
Files changed (1) hide show
  1. app.py +41 -36
app.py CHANGED
@@ -13,9 +13,7 @@ import plotly.io as pio
13
  import traceback
14
  import base64
15
  from io import BytesIO
16
- # I am commenting out voice libraries because we are not using them right now
17
- # import speech_recognition as sr
18
- # from gtts import gTTS
19
  import re
20
  import importlib.util
21
 
@@ -64,7 +62,7 @@ current_context = {
64
  }
65
 
66
  # Add a global variable to store the current plot
67
- # current_plot = None
68
 
69
  # Define the prompt with examples for SQL query generation
70
  query_prompt = ChatPromptTemplate.from_template("""
@@ -83,6 +81,38 @@ Important guidelines:
83
  Question: {question}
84
  """)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # Define the prompt for interpreting the SQL query result
87
  interpret_prompt = ChatPromptTemplate.from_messages(
88
  [
@@ -91,38 +121,6 @@ interpret_prompt = ChatPromptTemplate.from_messages(
91
  ]
92
  )
93
 
94
- # Add this after the query_prompt definition
95
- # visualization_prompt = ChatPromptTemplate.from_template("""
96
- # You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
97
- #
98
- # Important guidelines for SQLite syntax:
99
- # 1. Use strftime() for date functions:
100
- # - Year: strftime('%Y', date_column)
101
- # - Month: strftime('%m', date_column)
102
- # - Day: strftime('%d', date_column)
103
- # - Hour: strftime('%H', date_column)
104
- #
105
- # 2. For histograms and binning:
106
- # - Use: CAST((column / bin_size) AS INT) * bin_size
107
- # - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
108
- #
109
- # 3. For box plots:
110
- # - SQLite doesn't support PERCENTILE_CONT or window functions
111
- # - Simply return the raw data column: SELECT column_name FROM data_tab
112
- # - The application will calculate quartiles and outliers
113
- #
114
- # 4. For heatmaps:
115
- # - Return raw data for correlation analysis
116
- # - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
117
- #
118
- # 5. Always use 'data_tab' as the table name
119
- #
120
- # 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
121
- #
122
- # Question: {question}
123
- # Visualization type: {viz_type}
124
- # """)
125
-
126
  # Add this helper function to clean SQL queries
127
  def clean_sql_query(query_text):
128
  """Clean SQL query text by removing markdown formatting and comments"""
@@ -260,6 +258,13 @@ def process_text_query(query, history):
260
  sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
261
  else:
262
  sql_query = "SELECT * FROM data_tab LIMIT 10;"
 
 
 
 
 
 
 
263
  else:
264
  # For other queries, use the LLM to generate SQL
265
  sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
 
13
  import traceback
14
  import base64
15
  from io import BytesIO
16
+
 
 
17
  import re
18
  import importlib.util
19
 
 
62
  }
63
 
64
  # Add a global variable to store the current plot
65
+ current_plot = None
66
 
67
  # Define the prompt with examples for SQL query generation
68
  query_prompt = ChatPromptTemplate.from_template("""
 
81
  Question: {question}
82
  """)
83
 
84
+ # Add this after the query_prompt definition
85
+ visualization_prompt = ChatPromptTemplate.from_template("""
86
+ You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
87
+
88
+ Important guidelines for SQLite syntax:
89
+ 1. Use strftime() for date functions:
90
+ - Year: strftime('%Y', date_column)
91
+ - Month: strftime('%m', date_column)
92
+ - Day: strftime('%d', date_column)
93
+ - Hour: strftime('%H', date_column)
94
+
95
+ 2. For histograms and binning:
96
+ - Use: CAST((column / bin_size) AS INT) * bin_size
97
+ - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
98
+
99
+ 3. For box plots:
100
+ - SQLite doesn't support PERCENTILE_CONT or window functions
101
+ - Simply return the raw data column: SELECT column_name FROM data_tab
102
+ - The application will calculate quartiles and outliers
103
+
104
+ 4. For heatmaps:
105
+ - Return raw data for correlation analysis
106
+ - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
107
+
108
+ 5. Always use 'data_tab' as the table name
109
+
110
+ 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
111
+
112
+ Question: {question}
113
+ Visualization type: {viz_type}
114
+ """)
115
+
116
  # Define the prompt for interpreting the SQL query result
117
  interpret_prompt = ChatPromptTemplate.from_messages(
118
  [
 
121
  ]
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # Add this helper function to clean SQL queries
125
  def clean_sql_query(query_text):
126
  """Clean SQL query text by removing markdown formatting and comments"""
 
258
  sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
259
  else:
260
  sql_query = "SELECT * FROM data_tab LIMIT 10;"
261
+ elif is_visualization:
262
+ # For visualization queries, use the specialized visualization prompt
263
+ sql_query = llm.invoke(visualization_prompt.format(
264
+ question=question_with_context,
265
+ viz_type=viz_type or "bar"
266
+ )).content
267
+ sql_query = clean_sql_query(sql_query)
268
  else:
269
  # For other queries, use the LLM to generate SQL
270
  sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content