SVashishta1
commited on
Commit
·
e06f3ce
1
Parent(s):
6e54ca7
Error Fix
Browse files
app.py
CHANGED
|
@@ -9,6 +9,7 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
| 9 |
from langchain_groq import ChatGroq
|
| 10 |
import plotly.express as px
|
| 11 |
import time
|
|
|
|
| 12 |
|
| 13 |
# Load environment variables
|
| 14 |
load_dotenv()
|
|
@@ -43,6 +44,9 @@ current_context = {
|
|
| 43 |
"table_name": None
|
| 44 |
}
|
| 45 |
|
|
|
|
|
|
|
|
|
|
| 46 |
# Define the prompt with examples for SQL query generation
|
| 47 |
query_prompt = ChatPromptTemplate.from_messages(
|
| 48 |
[
|
|
@@ -122,8 +126,11 @@ interpret_prompt = ChatPromptTemplate.from_messages(
|
|
| 122 |
|
| 123 |
def process_text_query(query, history):
|
| 124 |
"""Process a text query and update chat history"""
|
|
|
|
|
|
|
|
|
|
| 125 |
if not query:
|
| 126 |
-
return "", history
|
| 127 |
|
| 128 |
# Add the user's query to history
|
| 129 |
history.append({"role": "user", "content": query})
|
|
@@ -194,6 +201,7 @@ def process_text_query(query, history):
|
|
| 194 |
# Add visualization if requested
|
| 195 |
if is_visualization and not result_df.empty:
|
| 196 |
try:
|
|
|
|
| 197 |
# Determine the type of visualization based on the data
|
| 198 |
if len(result_df.columns) >= 2:
|
| 199 |
# Find numeric columns for y-axis
|
|
@@ -202,19 +210,25 @@ def process_text_query(query, history):
|
|
| 202 |
if len(numeric_cols) >= 1 and len(result_df) > 1:
|
| 203 |
# Use the first column as x and first numeric column as y
|
| 204 |
x_col = result_df.columns[0]
|
| 205 |
-
|
| 206 |
|
| 207 |
# Create appropriate plot based on data characteristics
|
| 208 |
if 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns:
|
| 209 |
# Time series data - use line chart
|
| 210 |
-
fig = px.line(result_df, x=x_col, y=
|
| 211 |
else:
|
| 212 |
# Regular data - use bar chart
|
| 213 |
-
fig = px.bar(result_df, x=x_col, y=
|
| 214 |
|
| 215 |
-
#
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
except Exception as viz_error:
|
| 219 |
print(f"Visualization error: {str(viz_error)}")
|
| 220 |
# Continue without visualization if there's an error
|
|
@@ -241,7 +255,7 @@ def process_text_query(query, history):
|
|
| 241 |
# Add the response to history
|
| 242 |
history.append({"role": "assistant", "content": response})
|
| 243 |
|
| 244 |
-
return "", history
|
| 245 |
|
| 246 |
def process_file_upload(files):
|
| 247 |
"""Process uploaded files and index them"""
|
|
@@ -411,6 +425,9 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
|
| 411 |
with gr.Tab("Chat"):
|
| 412 |
chatbot = gr.Chatbot(height=400, type="messages")
|
| 413 |
|
|
|
|
|
|
|
|
|
|
| 414 |
with gr.Row():
|
| 415 |
with gr.Column(scale=8):
|
| 416 |
msg = gr.Textbox(
|
|
@@ -438,16 +455,18 @@ with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
|
|
| 438 |
submit_btn.click(
|
| 439 |
process_text_query,
|
| 440 |
inputs=[msg, chatbot],
|
| 441 |
-
outputs=[msg, chatbot]
|
|
|
|
| 442 |
)
|
| 443 |
|
| 444 |
msg.submit(
|
| 445 |
process_text_query,
|
| 446 |
inputs=[msg, chatbot],
|
| 447 |
-
outputs=[msg, chatbot]
|
|
|
|
| 448 |
)
|
| 449 |
|
| 450 |
-
clear_btn.click(lambda: None, None, chatbot, queue=False)
|
| 451 |
clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
|
| 452 |
|
| 453 |
voice_btn.click(
|
|
|
|
| 9 |
from langchain_groq import ChatGroq
|
| 10 |
import plotly.express as px
|
| 11 |
import time
|
| 12 |
+
import plotly.io as pio
|
| 13 |
|
| 14 |
# Load environment variables
|
| 15 |
load_dotenv()
|
|
|
|
| 44 |
"table_name": None
|
| 45 |
}
|
| 46 |
|
| 47 |
+
# Add a global variable to store the current plot
|
| 48 |
+
current_plot = None
|
| 49 |
+
|
| 50 |
# Define the prompt with examples for SQL query generation
|
| 51 |
query_prompt = ChatPromptTemplate.from_messages(
|
| 52 |
[
|
|
|
|
| 126 |
|
| 127 |
def process_text_query(query, history):
|
| 128 |
"""Process a text query and update chat history"""
|
| 129 |
+
global current_plot
|
| 130 |
+
current_plot = None # Reset the plot
|
| 131 |
+
|
| 132 |
if not query:
|
| 133 |
+
return "", history, None
|
| 134 |
|
| 135 |
# Add the user's query to history
|
| 136 |
history.append({"role": "user", "content": query})
|
|
|
|
| 201 |
# Add visualization if requested
|
| 202 |
if is_visualization and not result_df.empty:
|
| 203 |
try:
|
| 204 |
+
print("Visualization requested, attempting to create plot...")
|
| 205 |
# Determine the type of visualization based on the data
|
| 206 |
if len(result_df.columns) >= 2:
|
| 207 |
# Find numeric columns for y-axis
|
|
|
|
| 210 |
if len(numeric_cols) >= 1 and len(result_df) > 1:
|
| 211 |
# Use the first column as x and first numeric column as y
|
| 212 |
x_col = result_df.columns[0]
|
| 213 |
+
y_cols = numeric_cols[:3] # Use up to 3 numeric columns
|
| 214 |
|
| 215 |
# Create appropriate plot based on data characteristics
|
| 216 |
if 'month' in result_df.columns or 'date' in result_df.columns or 'year' in result_df.columns:
|
| 217 |
# Time series data - use line chart
|
| 218 |
+
fig = px.line(result_df, x=x_col, y=y_cols, title="Time Series Analysis")
|
| 219 |
else:
|
| 220 |
# Regular data - use bar chart
|
| 221 |
+
fig = px.bar(result_df, x=x_col, y=y_cols[0], title="Data Visualization")
|
| 222 |
|
| 223 |
+
# Store the figure for display
|
| 224 |
+
current_plot = fig
|
| 225 |
+
|
| 226 |
+
# Add note about visualization
|
| 227 |
+
response += "\n\n**A visualization has been generated and is displayed below.**"
|
| 228 |
+
|
| 229 |
+
# After creating the plot
|
| 230 |
+
print(f"Plot created: {current_plot is not None}")
|
| 231 |
+
|
| 232 |
except Exception as viz_error:
|
| 233 |
print(f"Visualization error: {str(viz_error)}")
|
| 234 |
# Continue without visualization if there's an error
|
|
|
|
| 255 |
# Add the response to history
|
| 256 |
history.append({"role": "assistant", "content": response})
|
| 257 |
|
| 258 |
+
return "", history, current_plot
|
| 259 |
|
| 260 |
def process_file_upload(files):
|
| 261 |
"""Process uploaded files and index them"""
|
|
|
|
| 425 |
with gr.Tab("Chat"):
|
| 426 |
chatbot = gr.Chatbot(height=400, type="messages")
|
| 427 |
|
| 428 |
+
# Add a plot component
|
| 429 |
+
plot_output = gr.Plot(label="Data Visualization", visible=False)
|
| 430 |
+
|
| 431 |
with gr.Row():
|
| 432 |
with gr.Column(scale=8):
|
| 433 |
msg = gr.Textbox(
|
|
|
|
| 455 |
submit_btn.click(
|
| 456 |
process_text_query,
|
| 457 |
inputs=[msg, chatbot],
|
| 458 |
+
outputs=[msg, chatbot, plot_output],
|
| 459 |
+
postprocess=lambda _, __, fig: gr.update(value=fig, visible=fig is not None)
|
| 460 |
)
|
| 461 |
|
| 462 |
msg.submit(
|
| 463 |
process_text_query,
|
| 464 |
inputs=[msg, chatbot],
|
| 465 |
+
outputs=[msg, chatbot, plot_output],
|
| 466 |
+
postprocess=lambda _, __, fig: gr.update(value=fig, visible=fig is not None)
|
| 467 |
)
|
| 468 |
|
| 469 |
+
clear_btn.click(lambda: [None, None, gr.update(visible=False)], None, [chatbot, plot_output, plot_output], queue=False)
|
| 470 |
clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
|
| 471 |
|
| 472 |
voice_btn.click(
|