SVashishta1
Error Fix
e3d98a2
raw
history blame
28.5 kB
import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import plotly.express as px
import time
import plotly.io as pio
import traceback
import base64
from io import BytesIO
import speech_recognition as sr
from gtts import gTTS
import re
# Load environment variables (e.g. GROQ_API_KEY) from a .env file.
load_dotenv()
# Make the directory above this file importable so the ``backend`` package
# (presumably located there) can be imported. This must run before the import below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant
# Single shared assistant used for non-CSV uploads, document queries,
# listing and database reset throughout this module.
document_assistant = DocumentAssistant()
# Shared Groq chat model (llama3-8b-8192) used both to write SQL for CSV
# questions and to interpret the query results. Temperature 0 keeps output as
# repeatable as possible. The key comes from the GROQ_API_KEY environment variable.
llm = ChatGroq(
    api_key=os.getenv("GROQ_API_KEY"),
    model="llama3-8b-8192",
    temperature=0,
    max_retries=2,
    max_tokens=None,
    timeout=None,
    verbose=True,
)
# SQLite file that receives uploaded CSV data. It lives in a "data" folder
# beside this module, and that folder is created at import time if missing.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.split(DB_PATH)[0], exist_ok=True)

# What the chat is currently working with: file_type is "csv" or "pdf" (as set
# by the upload handler), and table_name is only set for CSV uploads.
current_context = dict.fromkeys(("file_type", "file_name", "table_name"))

# Meant to hold the current plot.
# NOTE(review): nothing visible in this module reads or writes it.
current_plot = None
# Prompt that turns a natural-language question into a SQLite query over the
# uploaded CSV table ('data_tab'). The caller prefixes {question} with the
# table's column names. The guidelines steer the model away from non-SQLite
# syntax (EXTRACT, TRUNCATE, percentile functions).
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
Question: {question}
SQL Query:
""")
# Prompt that asks the model to explain a query result in plain language.
# Template variables:
#   {question}     - the user's original question
#   {sql_query}    - the SQL that was executed
#   {data_summary} - summary text of the result DataFrame built by the caller
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)
# Prompt for generating a SQL query that fetches data for a chart, with
# SQLite-specific hints for date parts, binning and time series.
# Template variables: {question} and {viz_type}.
# NOTE(review): not referenced anywhere else in this module. process_text_query
# uses query_prompt even for visualization requests.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
- Year: strftime('%Y', date_column)
- Month: strftime('%m', date_column)
- Day: strftime('%d', date_column)
- Hour: strftime('%H', date_column)
2. For histograms and binning:
- Use: CAST((column / bin_size) AS INT) * bin_size
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
3. For percentiles and statistics:
- SQLite doesn't have built-in percentile functions
- Use simple aggregations (MIN, MAX, AVG, COUNT) instead
4. For time series:
- Group by date parts using strftime()
- Example: strftime('%Y-%m-%d', pickup_datetime) AS day
5. Always use 'data_tab' as the table name
Question: {question}
Visualization type: {viz_type}
SQL Query:
""")
def process_text_query(query, history):
    """Answer a chat query and append the exchange to the chat history.

    When a CSV file is active (per ``current_context``), the question is
    turned into SQL by the LLM and run against the SQLite database. Otherwise
    it is forwarded to ``document_assistant.process_query``.

    Args:
        query: The user's question. An empty value leaves the history untouched.
        history: Gradio "messages" history (list of role/content dicts). It is
            appended to in place; None is treated as an empty history.

    Returns:
        ``("", history)``, which clears the textbox and refreshes the chatbot.
    """
    if not query:
        return "", history
    if history is None:
        history = []

    history.append({"role": "user", "content": query})
    start_time = time.time()

    if current_context["file_type"] == "csv" and current_context["table_name"]:
        response = _answer_csv_query(query, current_context["table_name"])
    else:
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"

    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"
    history.append({"role": "assistant", "content": response})
    return "", history


def _answer_csv_query(query, table_name):
    """Answer *query* against the loaded CSV table and return a Markdown reply.

    The connection is closed on every path, including failures while reading
    the schema or generating SQL.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)

        # Give the model the column names so it can write a valid query.
        quoted_table = '"' + table_name.replace('"', '""') + '"'
        cursor = conn.cursor()
        cursor.execute(f"PRAGMA table_info({quoted_table});")
        columns_str = ", ".join(info[1] for info in cursor.fetchall())
        question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"

        ai_msg = query_prompt | llm
        sql_query = _extract_sql(ai_msg.invoke({"question": question_with_context}).content)
        print(f"Generated SQL Query: {sql_query}")

        try:
            result_df = pd.read_sql_query(sql_query, conn)
            return _format_csv_answer(query, sql_query, result_df)
        except Exception as e:
            # Keep the generated SQL visible so the user can see what failed.
            return f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"
    except Exception as e:
        return f"Error processing query: {str(e)}"
    finally:
        if conn is not None:
            conn.close()


def _extract_sql(raw_reply):
    """Return the SQL statement contained in an LLM reply.

    The model may wrap its query in a Markdown code fence (```sql ... ```),
    which SQLite cannot execute. If a non-empty fence is present, its body is
    returned. Otherwise the stripped reply is returned unchanged.
    """
    reply = raw_reply.strip()
    # "sqlite" is listed before "sql" so the longer language tag is consumed whole.
    fenced = re.search(r"```(?:sqlite|sql)?\s*(.*?)```", reply, flags=re.DOTALL | re.IGNORECASE)
    if fenced and fenced.group(1).strip():
        return fenced.group(1).strip()
    return reply


def _format_csv_answer(query, sql_query, result_df):
    """Build the Markdown reply: SQL, a result preview, the LLM analysis, and an optional chart."""
    if result_df.empty:
        data_summary = "No relevant data found."
    else:
        data_summary = result_df.describe(include='all').to_string()
        # Small results are shown in full; larger ones are previewed.
        if len(result_df) <= 10:
            data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
        else:
            data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"

    answer_chain = interpret_prompt | llm
    interpretation = answer_chain.invoke({
        "question": query,
        "sql_query": sql_query,
        "data_summary": data_summary
    }).content.strip()

    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
    if result_df.empty:
        response += "**No results found.**\n\n"
    elif len(result_df) > 10:
        response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
    else:
        response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
    response += f"**Analysis:**\n{interpretation}"

    lowered = query.lower()
    wants_chart = any(word in lowered for word in ('plot', 'graph', 'chart', 'visualize', 'visualization', 'trend'))
    if wants_chart and not result_df.empty:
        response += _render_visualization(query, result_df)
    return response


def _render_visualization(query, result_df):
    """Return HTML/Markdown embedding a PNG chart of *result_df*, or "" if none is made.

    Rendering errors are printed and swallowed so the text answer is still delivered.
    """
    try:
        print("Visualization requested, attempting to create plot...")
        fig = _build_figure(query.lower(), result_df)
        if fig is None:
            return ""
        fig.update_layout(
            autosize=True,
            width=900,
            height=600,
            margin=dict(l=50, r=50, b=100, t=100, pad=4),
            template="plotly_white",
            font=dict(size=14)
        )
        # Embed as a base64 PNG data URI so the chat can display it inline.
        img_bytes = fig.to_image(format="png", width=900, height=600, scale=2)
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"
        return (f"\n\n<img src='{img_src}' width='100%' />"
                "\n\n**A visualization has been generated and is displayed above.**")
    except Exception as viz_error:
        print(f"Visualization error: {str(viz_error)}")
        traceback.print_exc()
        return ""


def _build_figure(query_lower, result_df):
    """Pick a chart type from keywords in the lower-cased question and build it.

    The first column is the x axis / category. Returns None when the data is
    unsuitable (fewer than two columns, no numeric column, or a single row).
    """
    if len(result_df.columns) < 2:
        print("Not enough columns for visualization")
        return None
    numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
    if not numeric_cols or len(result_df) <= 1:
        print("Not enough numeric columns or data points for visualization")
        return None

    def mentions(*phrases):
        return any(phrase in query_lower for phrase in phrases)

    x_col = result_df.columns[0]
    # Chart values from numeric columns other than the x/category column, so a
    # numeric first column (year, bin, count key) is not plotted against itself.
    # Fall back to all numeric columns when x is the only numeric one.
    value_cols = [col for col in numeric_cols if col != x_col] or numeric_cols

    if mentions('pie chart', 'pie graph', 'distribution') and len(result_df) <= 20:
        # Pie charts only stay readable with few categories; draw a donut.
        return px.pie(result_df, names=x_col, values=value_cols[0],
                      title="Distribution Analysis", hole=0.3)
    if mentions('histogram', 'distribution of', 'frequency'):
        return px.histogram(result_df, x=numeric_cols[0],
                            title=f"Distribution of {numeric_cols[0]}", nbins=20)
    if mentions('heatmap', 'heat map', 'correlation') and len(numeric_cols) >= 2:
        if len(result_df.columns) == len(numeric_cols) and len(numeric_cols) > 2:
            # All-numeric result: treat it as a matrix that is already ready to show.
            heat_df = result_df
        else:
            heat_df = result_df[numeric_cols].corr()
        return px.imshow(heat_df, title="Correlation Heatmap",
                         color_continuous_scale='RdBu_r', aspect="auto")
    if mentions('scatter', 'relationship between', 'correlation') and len(numeric_cols) >= 2:
        return px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1],
                          title=f"Relationship between {numeric_cols[0]} and {numeric_cols[1]}",
                          opacity=0.7)
    looks_temporal = (any(name in result_df.columns for name in ('month', 'date', 'year'))
                      or any('date' in str(col).lower() for col in result_df.columns))
    if looks_temporal:
        # Time series: up to three value columns as lines.
        return px.line(result_df, x=x_col, y=value_cols[:3],
                       title="Time Series Analysis", markers=True)
    return px.bar(result_df, x=x_col, y=value_cols[0], title="Data Visualization")
def process_file_upload(files):
    """Load uploaded files and return a status report.

    CSV files are read with pandas and written to the SQLite table
    ``data_tab`` (replacing its previous contents) so they can be queried with
    SQL. Every other file is handed to ``document_assistant.upload_document``.
    ``current_context`` is reset first and then updated as each file loads.

    Args:
        files: Uploaded files as given by ``gr.File``. Each item is either a
            path (str / os.PathLike) or an object exposing the path as ``.name``.

    Returns:
        A newline-separated status report, or "No files uploaded".
    """
    global current_context
    if not files:
        return "No files uploaded"

    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    file_info = []
    for file in files:
        # Gradio may pass plain path strings or file-like objects with .name.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            conn = None
            try:
                df = pd.read_csv(file_path)
                conn = sqlite3.connect(DB_PATH)
                # Faster bulk import at the cost of crash safety.
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")
                # Every CSV lands in the same table name, which the SQL prompts rely on.
                df.to_sql('data_tab', conn, if_exists='replace', index=False)
                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab"
                }

                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]

                file_info.extend([
                    "✅ CSV File Successfully Loaded",
                    "📊 Table Name: data_tab",
                    f"📄 Source File: {file_name}",
                    f"📈 Total Rows: {row_count:,}",
                    f"📋 Columns: {', '.join(columns)}",
                ])
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
            finally:
                # Always release the connection, even when loading fails.
                if conn is not None:
                    conn.close()
        else:
            # PDF or other document types go to the document assistant.
            try:
                result = document_assistant.upload_document(file_path)
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }
                # Build every line first so a bad result never yields "✅" followed by "❌".
                success_lines = [
                    "✅ Document Successfully Processed",
                    f"📄 File: {file_name}",
                    f"📚 Chunks: {result['chunks']}",
                    result['message'],
                ]
                file_info.extend(success_lines)
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")
    return "\n".join(file_info)
def list_documents():
    """Describe the loaded CSV tables and the indexed documents.

    Returns:
        A newline-separated listing of SQLite tables (with row and column
        counts) and of documents reported by ``document_assistant``, or
        "No data or documents loaded yet" when both are empty. Failures in
        either source are reported inline rather than raised.
    """
    info_list = []

    # CSV data stored in SQLite.
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        # Skip SQLite's reserved internal tables (names starting with "sqlite_").
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        tables = [row[0] for row in cursor.fetchall()]
        if tables:
            info_list.append("📊 CSV Data Tables:")
            for table in tables:
                # Quote the identifier: table names come from the database, not code.
                quoted = '"' + table.replace('"', '""') + '"'
                cursor.execute(f"PRAGMA table_info({quoted});")
                column_count = len(cursor.fetchall())
                cursor.execute(f"SELECT COUNT(*) FROM {quoted};")
                row_count = cursor.fetchone()[0]
                info_list.append(f"- {table} ({row_count:,} rows, {column_count} columns)")
    except Exception as e:
        info_list.append(f"Error accessing CSV data: {str(e)}")
    finally:
        if conn is not None:
            conn.close()

    # Documents indexed by the document assistant.
    try:
        docs = document_assistant.get_all_documents()
    except Exception as e:
        info_list.append(f"Error accessing indexed documents: {str(e)}")
        docs = None
    if docs:
        info_list.append("\n📑 Indexed Documents:")
        info_list.extend(f"- {doc['filename']} (ID: {doc['id']})" for doc in docs)

    if not info_list:
        return "No data or documents loaded yet"
    return "\n".join(info_list)
def clear_context():
    """Forget which file is active.

    The Chat tab wires this return value to the chatbot, so returning None
    also clears the conversation shown on screen.
    """
    global current_context
    current_context = dict.fromkeys(("file_type", "file_name", "table_name"))
def process_voice_input(audio_path):
    """Transcribe a recorded clip so it can be used as a chat query.

    Args:
        audio_path: Path to the recorded audio file, or None if nothing was recorded.

    Returns:
        The recognized text on success. Otherwise a short message describing
        the problem: no recording, unintelligible speech, a recognition service
        error, or any other failure.
    """
    if audio_path is None:
        return "No audio recorded"
    try:
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as audio_file:
            clip = recognizer.record(audio_file)
        # Google Speech Recognition via the speech_recognition package.
        return recognizer.recognize_google(clip)
    except sr.UnknownValueError:
        return "Could not understand audio"
    except sr.RequestError as err:
        return f"Error with speech recognition service: {err}"
    except Exception as err:
        return f"Error processing audio: {str(err)}"
def text_to_speech_output(text):
    """Speak the most recent assistant reply from the chat history.

    Args:
        text: Chat history in Gradio "messages" format, i.e. a list of dicts
            with "role" and "content" keys. It may be empty or None.

    Returns:
        Path to a temporary MP3 containing the spoken reply, or None when there
        is no assistant reply or speech generation fails.
    """
    if not text:
        return None

    # Newest assistant message wins.
    last_message = next(
        (msg["content"] for msg in reversed(text) if msg["role"] == "assistant"),
        None,
    )
    if not last_message:
        return None

    audio_path = None
    try:
        # Strip markup that should not be read aloud.
        clean_text = re.sub(r'<.*?>', '', last_message)  # HTML tags (e.g. inline chart <img>)
        clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', clean_text)  # Bold markdown
        clean_text = re.sub(r'\n\n', ' ', clean_text)  # Paragraph breaks -> spaces
        clean_text = re.sub(r'```.*?```', 'Code block removed for speech.', clean_text, flags=re.DOTALL)

        # Build the TTS request before creating the output file, so invalid
        # input fails without leaving anything on disk.
        tts = gTTS(text=clean_text, lang='en', slow=False)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_path = temp_file.name
        tts.save(audio_path)
        return audio_path
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        if audio_path is not None:
            # Don't leak the empty/partial MP3 when synthesis fails.
            try:
                os.remove(audio_path)
            except OSError:
                pass  # best-effort cleanup; the original error was already reported
        return None
def create_test_visualization():
    """Build a small sample bar chart to check that Plotly figures render.

    Returns:
        A Plotly figure of six months of dummy values, laid out at 800x500.
    """
    sample = pd.DataFrame(
        {
            "Month": ["Jan", "Feb", "Mar", "Apr", "May", "Jun"],
            "Value": [10, 15, 13, 17, 20, 25],
        }
    )
    figure = px.bar(sample, x="Month", y="Value", title="Test Visualization")
    figure.update_layout(autosize=True, width=800, height=500)
    return figure
def create_test_html_visualization():
    """Render a sample bar chart as an embeddable HTML fragment.

    Backs the Settings tab's "Test Visualization" button.

    Returns:
        An HTML snippet (not a full document) that loads plotly.js from a CDN.
    """
    sample = pd.DataFrame(
        {
            "Month": ["Jan", "Feb", "Mar", "Apr", "May", "Jun"],
            "Value": [10, 15, 13, 17, 20, 25],
        }
    )
    return pio.to_html(
        px.bar(sample, x="Month", y="Value", title="Test Visualization"),
        full_html=False,
        include_plotlyjs="cdn",
    )
def flush_databases():
    """Drop all user tables from SQLite, reset the document store, and clear the context.

    Each step is attempted independently, and failures are reported rather than raised.

    Returns:
        A newline-separated report with one line per store (SQLite, ChromaDB).
    """
    global current_context
    result = []

    # Flush SQLite database.
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        # Tables named "sqlite_*" are reserved by SQLite and cannot be dropped.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        for (table,) in cursor.fetchall():
            quoted = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"DROP TABLE IF EXISTS {quoted};")
        conn.commit()
        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")
    finally:
        # Always release the connection, even when a DROP fails.
        if conn is not None:
            conn.close()

    # Flush ChromaDB by resetting the document assistant.
    try:
        document_assistant.reset_database()
        result.append("✅ ChromaDB cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return "\n".join(result)
# Create Gradio interface
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        # Custom CSS so chart images embedded in replies (<img> tags) fit the chat width.
        gr.HTML("""
<style>
.chatbot-container img {
max-width: 100%;
height: auto;
display: block;
margin: 10px 0;
}
</style>
""")
        # "messages" format: history is a list of {"role", "content"} dicts.
        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")
        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("🎤")
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")
        audio_output = gr.Audio(label="Voice Response", type="filepath")
        # Voice input stays hidden until the microphone button is pressed.
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        # Event handlers: both the button and Enter submit the query.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )
        # The transcription (or an error message) is written into the query box.
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )
        # Text-to-speech for the latest assistant reply.
        tts_btn = gr.Button("🔊 Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )

    with gr.Tab("Document Upload"):
        with gr.Row():
            file_upload = gr.File(
                label="Upload Documents",
                file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
                file_count="multiple"
            )
            flush_db_btn_doc = gr.Button("🗑️ Flush All Databases", variant="stop")
        upload_button = gr.Button("Process & Index Documents")
        upload_output = gr.Textbox(label="Upload Status")
        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )
        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )
        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")
        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        with gr.Row():
            gr.Markdown("## Database Management")
            flush_db_btn = gr.Button("🗑️ Flush All Databases", variant="stop", scale=1)
        flush_result = gr.Textbox(label="Flush Result")
        flush_db_btn.click(
            flush_databases,
            inputs=[],
            outputs=[flush_result]
        )

        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")

        def save_settings(key):
            """Store the Groq API key and rebuild the shared LLM client with it.

            The module-level ``llm`` captured the key when it was created, so
            only updating the environment variable would have no effect on
            later queries. process_text_query reads the global ``llm`` on every
            call, so rebinding it here makes the new key take effect.
            NOTE(review): document_assistant may hold its own client and is not updated.
            """
            global llm
            key = (key or "").strip()
            if not key:
                return "Please enter a Groq API key."
            os.environ["GROQ_API_KEY"] = key
            # Mirrors the start-up configuration of the module-level client.
            llm = ChatGroq(
                model="llama3-8b-8192",
                temperature=0,
                max_tokens=None,
                timeout=None,
                max_retries=2,
                verbose=True,
                api_key=key
            )
            return "Settings saved!"

        settings_status = gr.Textbox(label="Status")
        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[settings_status]
        )

        gr.Markdown("## Debugging")
        test_viz_btn = gr.Button("Test Visualization")
        test_viz_output = gr.HTML(label="Test Visualization")
        test_viz_btn.click(
            create_test_html_visualization,
            inputs=[],
            outputs=[test_viz_output]
        )
# Launch the Gradio app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()