SVashishta1 commited on
Commit
e06f3ce
·
1 Parent(s): 6e54ca7
Files changed (1) hide show
  1. app.py +30 -11
app.py CHANGED
@@ -9,6 +9,7 @@ from langchain_core.prompts import ChatPromptTemplate
9
  from langchain_groq import ChatGroq
10
  import plotly.express as px
11
  import time
 
12
 
13
  # Load environment variables
14
  load_dotenv()
@@ -43,6 +44,9 @@ current_context = {
43
  "table_name": None
44
  }
45
 
 
 
 
46
  # Define the prompt with examples for SQL query generation
47
  query_prompt = ChatPromptTemplate.from_messages(
48
  [
@@ -122,8 +126,11 @@ interpret_prompt = ChatPromptTemplate.from_messages(
122
 
123
  def process_text_query(query, history):
124
  """Process a text query and update chat history"""
 
 
 
125
  if not query:
126
- return "", history
127
 
128
  # Add the user's query to history
129
  history.append({"role": "user", "content": query})
@@ -194,6 +201,7 @@ def process_text_query(query, history):
194
  # Add visualization if requested
195
  if is_visualization and not result_df.empty:
196
  try:
 
197
  # Determine the type of visualization based on the data
198
  if len(result_df.columns) >= 2:
199
  # Find numeric columns for y-axis
@@ -202,19 +210,25 @@ def process_text_query(query, history):
202
  if len(numeric_cols) >= 1 and len(result_df) > 1:
203
  # Use the first column as x and first numeric column as y
204
  x_col = result_df.columns[0]
205
- y_col = numeric_cols[0]
206
 
207
  # Create appropriate plot based on data characteristics
208
  if 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns:
209
  # Time series data - use line chart
210
- fig = px.line(result_df, x=x_col, y=numeric_cols, title="Time Series Analysis")
211
  else:
212
  # Regular data - use bar chart
213
- fig = px.bar(result_df, x=x_col, y=y_col, title="Data Visualization")
214
 
215
- # Convert to HTML and add to response
216
- plot_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
217
- response += f"\n\n**Visualization:**\n<div>{plot_html}</div>"
 
 
 
 
 
 
218
  except Exception as viz_error:
219
  print(f"Visualization error: {str(viz_error)}")
220
  # Continue without visualization if there's an error
@@ -241,7 +255,7 @@ def process_text_query(query, history):
241
  # Add the response to history
242
  history.append({"role": "assistant", "content": response})
243
 
244
- return "", history
245
 
246
  def process_file_upload(files):
247
  """Process uploaded files and index them"""
@@ -411,6 +425,9 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
411
  with gr.Tab("Chat"):
412
  chatbot = gr.Chatbot(height=400, type="messages")
413
 
 
 
 
414
  with gr.Row():
415
  with gr.Column(scale=8):
416
  msg = gr.Textbox(
@@ -438,16 +455,18 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
438
  submit_btn.click(
439
  process_text_query,
440
  inputs=[msg, chatbot],
441
- outputs=[msg, chatbot]
 
442
  )
443
 
444
  msg.submit(
445
  process_text_query,
446
  inputs=[msg, chatbot],
447
- outputs=[msg, chatbot]
 
448
  )
449
 
450
- clear_btn.click(lambda: None, None, chatbot, queue=False)
451
  clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
452
 
453
  voice_btn.click(
 
9
  from langchain_groq import ChatGroq
10
  import plotly.express as px
11
  import time
12
+ import plotly.io as pio
13
 
14
  # Load environment variables
15
  load_dotenv()
 
44
  "table_name": None
45
  }
46
 
47
+ # Add a global variable to store the current plot
48
+ current_plot = None
49
+
50
  # Define the prompt with examples for SQL query generation
51
  query_prompt = ChatPromptTemplate.from_messages(
52
  [
 
126
 
127
  def process_text_query(query, history):
128
  """Process a text query and update chat history"""
129
+ global current_plot
130
+ current_plot = None # Reset the plot
131
+
132
  if not query:
133
+ return "", history, None
134
 
135
  # Add the user's query to history
136
  history.append({"role": "user", "content": query})
 
201
  # Add visualization if requested
202
  if is_visualization and not result_df.empty:
203
  try:
204
+ print("Visualization requested, attempting to create plot...")
205
  # Determine the type of visualization based on the data
206
  if len(result_df.columns) >= 2:
207
  # Find numeric columns for y-axis
 
210
  if len(numeric_cols) >= 1 and len(result_df) > 1:
211
  # Use the first column as x and first numeric column as y
212
  x_col = result_df.columns[0]
213
+ y_cols = numeric_cols[:3] # Use up to 3 numeric columns
214
 
215
  # Create appropriate plot based on data characteristics
216
  if 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns:
217
  # Time series data - use line chart
218
+ fig = px.line(result_df, x=x_col, y=y_cols, title="Time Series Analysis")
219
  else:
220
  # Regular data - use bar chart
221
+ fig = px.bar(result_df, x=x_col, y=y_cols[0], title="Data Visualization")
222
 
223
+ # Store the figure for display
224
+ current_plot = fig
225
+
226
+ # Add note about visualization
227
+ response += "\n\n**A visualization has been generated and is displayed below.**"
228
+
229
+ # After creating the plot
230
+ print(f"Plot created: {current_plot is not None}")
231
+
232
  except Exception as viz_error:
233
  print(f"Visualization error: {str(viz_error)}")
234
  # Continue without visualization if there's an error
 
255
  # Add the response to history
256
  history.append({"role": "assistant", "content": response})
257
 
258
+ return "", history, current_plot
259
 
260
  def process_file_upload(files):
261
  """Process uploaded files and index them"""
 
425
  with gr.Tab("Chat"):
426
  chatbot = gr.Chatbot(height=400, type="messages")
427
 
428
+ # Add a plot component
429
+ plot_output = gr.Plot(label="Data Visualization", visible=False)
430
+
431
  with gr.Row():
432
  with gr.Column(scale=8):
433
  msg = gr.Textbox(
 
455
  submit_btn.click(
456
  process_text_query,
457
  inputs=[msg, chatbot],
458
+ outputs=[msg, chatbot, plot_output],
459
+ postprocess=lambda _, __, fig: gr.update(value=fig, visible=fig is not None)
460
  )
461
 
462
  msg.submit(
463
  process_text_query,
464
  inputs=[msg, chatbot],
465
+ outputs=[msg, chatbot, plot_output],
466
+ postprocess=lambda _, __, fig: gr.update(value=fig, visible=fig is not None)
467
  )
468
 
469
+ clear_btn.click(lambda: [None, None, gr.update(visible=False)], None, [chatbot, plot_output, plot_output], queue=False)
470
  clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
471
 
472
  voice_btn.click(