# SVashishta1 — "Error Fix" — commit 028022d — raw / history / blame — 45.5 kB
# NOTE(review): the lines above are Hugging Face file-viewer page chrome that
# was captured along with the source; kept as a comment so the module parses.
import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import plotly.express as px
import time
import plotly.io as pio
import traceback
import base64
from io import BytesIO
import speech_recognition as sr
from gtts import gTTS
import re
import importlib.util
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()
# Add parent directory to path to import backend modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant
# Initialize the document assistant (handles non-CSV document indexing/queries).
document_assistant = DocumentAssistant()
# Initialize the LLM using the llama3-8b-8192 model from Groq.
# temperature=0 keeps the generated SQL deterministic; transient API
# failures are retried up to twice.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)
# Database path for CSV data; the containing directory is created on startup.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
# Current context to track what we're working with.  Consulted by
# process_text_query to route between the SQL pipeline (CSV uploads) and
# the document-assistant pipeline (everything else).
current_context = {
    "file_type": None,   # "csv" or "pdf" once a file has been uploaded
    "file_name": None,   # basename of the last uploaded file
    "table_name": None   # always "data_tab" for CSV uploads
}
# Add a global variable to store the current plot
# NOTE(review): current_plot appears unused in this file — confirm before removing.
current_plot = None
# Define the prompt with examples for SQL query generation
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
""")
# Define the prompt for interpreting the SQL query result
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)
# Add this after the query_prompt definition
# NOTE(review): visualization_prompt is defined but not referenced anywhere
# in this file (the viz path builds its SQL directly) — confirm before removing.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
- Year: strftime('%Y', date_column)
- Month: strftime('%m', date_column)
- Day: strftime('%d', date_column)
- Hour: strftime('%H', date_column)
2. For histograms and binning:
- Use: CAST((column / bin_size) AS INT) * bin_size
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
3. For box plots:
- SQLite doesn't support PERCENTILE_CONT or window functions
- Simply return the raw data column: SELECT column_name FROM data_tab
- The application will calculate quartiles and outliers
4. For heatmaps:
- Return raw data for correlation analysis
- Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
5. Always use 'data_tab' as the table name
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
Visualization type: {viz_type}
""")
# Add this helper function to clean SQL queries
def clean_sql_query(query_text):
    """Normalize raw LLM output into a bare, executable SQLite statement.

    Strips markdown code fences, chatty "here is the query" prefixes,
    SQL line comments (``--``) and block comments (``/* */``), and any
    trailing semicolon.

    Args:
        query_text: Raw text returned by the LLM (may be None or empty).

    Returns:
        The cleaned SQL string, or the safe default
        ``"SELECT * FROM data_tab LIMIT 10;"`` when nothing usable remains.
    """
    # Empty/None input: return a safe default so callers always get SQL.
    if not query_text:
        return "SELECT * FROM data_tab LIMIT 10;"
    # Pull the payload out of a markdown fence (```sql ... ```), if present.
    if "```" in query_text:
        matches = re.findall(r"```(?:sql)?(.*?)```", query_text, re.DOTALL)
        if matches:
            query_text = matches[0].strip()
    # Drop conversational prefixes the model sometimes emits before the SQL.
    prefixes = (
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:"
    )
    lowered = query_text.lower()
    if lowered.startswith(prefixes):
        # Jump forward to the first real SQL keyword after the prefix.
        sql_keywords = ("select", "with", "create", "insert", "update", "delete")
        positions = [pos for pos in (lowered.find(kw) for kw in sql_keywords) if pos != -1]
        if positions:
            query_text = query_text[min(positions):]
    # Remove SQL line comments, then block comments (the block-comment pass
    # is a fix: the original only stripped '--' comments).
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)
    query_text = re.sub(r'/\*.*?\*/', ' ', query_text, flags=re.DOTALL)
    # Trim surrounding whitespace and any trailing semicolon(s); the extra
    # strip() fixes trailing whitespace the original left behind after ';'.
    query_text = query_text.strip().rstrip(';').strip()
    # If cleaning consumed everything, fall back to the default query.
    if not query_text:
        return "SELECT * FROM data_tab LIMIT 10;"
    return query_text
def _numeric_table_columns(conn):
    """Names of data_tab's numeric columns, read from the SQLite schema."""
    cursor = conn.cursor()
    cursor.execute(
        "SELECT name FROM pragma_table_info('data_tab') "
        "WHERE type LIKE '%INT%' OR type LIKE '%REAL%' "
        "OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
    )
    return [row[0] for row in cursor.fetchall()]


def _raw_data_query(conn, lowered_query, viz_type):
    """Build the raw-data SELECT used for box plots and heatmaps.

    Box plots get one numeric column (the one named in the user's query if
    any, else the first); heatmaps get up to 10 numeric columns.  Falls back
    to a generic SELECT when the table lacks suitable columns.
    """
    numeric_cols = _numeric_table_columns(conn)
    if viz_type == 'box':
        if not numeric_cols:
            return "SELECT * FROM data_tab LIMIT 10;"
        target_col = next(
            (col for col in numeric_cols if col.lower() in lowered_query),
            numeric_cols[0]
        )
        return f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
    # heatmap: need at least two numeric columns for correlation/density.
    if len(numeric_cols) < 2:
        return "SELECT * FROM data_tab LIMIT 10;"
    cols_str = ", ".join(numeric_cols[:10])  # cap at 10 columns for performance
    return f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"


def _answer_from_results(question, sql_query, result_df):
    """Summarize query results and have the LLM interpret them.

    Returns the markdown response body (SQL echo, result preview, analysis).
    """
    if not result_df.empty:
        data_summary = result_df.describe(include='all').to_string()
        # Small result sets are shown in full; larger ones are previewed.
        if len(result_df) <= 10:
            data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
        else:
            data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"
    else:
        data_summary = "No relevant data found."
    interpretation = (interpret_prompt | llm).invoke({
        "question": question,
        "sql_query": sql_query,
        "data_summary": data_summary
    }).content.strip()
    response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
    if not result_df.empty:
        if len(result_df) > 10:
            response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
        else:
            response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
    else:
        response += "**No results found.**\n\n"
    response += f"**Analysis:**\n{interpretation}"
    return response


def _build_figure(result_df, viz_type, numeric_cols):
    """Build the Plotly figure for a visualization request.

    Args:
        result_df: Query results to plot (may gain a synthetic 'value'
            column in the bar-chart fallback).
        viz_type: One of 'pie', 'histogram', 'box', 'heatmap', 'scatter',
            'line', or None (anything else falls back to a bar chart).
        numeric_cols: Names of result_df's numeric columns.

    Returns:
        A plotly Figure sized 1000x700 with the shared layout applied.
    """
    fig_width = 1000
    fig_height = 700
    if viz_type == 'pie' and len(result_df) <= 20:
        # Pie: first column is the category, first numeric column the value.
        category_col = result_df.columns[0]
        value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]
        if len(numeric_cols) == len(result_df.columns):
            # Every column is numeric: fall back to the index as the category.
            category_col = result_df.index.name or 'index'
            result_df = result_df.reset_index()
        fig = px.pie(
            result_df,
            names=category_col,
            values=value_col,
            title=f"Distribution of {value_col} by {category_col}",
            hole=0.3,  # donut chart for better readability
            color_discrete_sequence=px.colors.qualitative.Pastel
        )
    elif viz_type == 'histogram' and len(result_df.columns) > 0:
        # Prefer a numeric column for the histogram axis.
        x_col = numeric_cols[0] if numeric_cols else result_df.columns[0]
        if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
            # The SQL already binned the data: plot the bins as a bar chart.
            bin_col = 'bin' if 'bin' in result_df.columns else 'range'
            count_col = 'count' if 'count' in result_df.columns else (
                numeric_cols[0] if numeric_cols else result_df.columns[1])
            fig = px.bar(
                result_df,
                x=bin_col,
                y=count_col,
                title=f"Histogram of {x_col}",
                labels={bin_col: x_col, count_col: 'Frequency'},
                color_discrete_sequence=['#636EFA']
            )
        else:
            # Raw data: let Plotly bin it, with a marginal box plot.
            fig = px.histogram(
                result_df,
                x=x_col,
                title=f"Distribution of {x_col}",
                nbins=20,
                marginal="box",
                color_discrete_sequence=['#636EFA'],
                opacity=0.8
            )
            fig.update_layout(
                bargap=0.1,
                xaxis_title=x_col,
                yaxis_title='Frequency',
                showlegend=True
            )
    elif viz_type == 'box' and numeric_cols:
        # Quartiles/outliers are computed by Plotly client-side because
        # SQLite cannot compute percentiles.
        x_col = numeric_cols[0]
        fig = px.box(
            result_df,
            y=x_col,
            title=f"Box Plot of {x_col}",
            points="outliers",
            color_discrete_sequence=['#636EFA']
        )
        # Overlay individual points for extra context.
        fig.add_trace(
            px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
        )
    elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
        if len(numeric_cols) >= 3:
            # 3+ numeric columns: show a correlation matrix.
            clean_df = result_df[numeric_cols].dropna()
            if len(clean_df) > 1:  # need at least 2 rows for correlation
                corr_df = clean_df.corr().round(2)
                fig = px.imshow(
                    corr_df,
                    title="Correlation Heatmap",
                    color_continuous_scale='RdBu_r',
                    text_auto=True,
                    aspect="auto",
                    zmin=-1, zmax=1
                )
                fig.update_layout(
                    xaxis_title="Features",
                    yaxis_title="Features",
                    coloraxis_colorbar=dict(
                        title="Correlation",
                        thicknessmode="pixels", thickness=20,
                        lenmode="pixels", len=300,
                        yanchor="top", y=1,
                        ticks="outside"
                    )
                )
            else:
                # Not enough rows for a correlation matrix.
                fig = px.bar(
                    pd.DataFrame({'Message': ['Not enough data for heatmap']}),
                    title="Cannot create heatmap - insufficient data"
                )
        else:
            # Exactly 2 numeric columns: 2D density heatmap.
            x_col, y_col = numeric_cols[0], numeric_cols[1]
            fig = px.density_heatmap(
                result_df,
                x=x_col,
                y=y_col,
                title=f"Density Heatmap of {x_col} vs {y_col}",
                color_continuous_scale='Viridis',
                nbinsx=20,
                nbinsy=20,
                marginal_x="histogram",
                marginal_y="histogram"
            )
            fig.update_layout(
                xaxis_title=x_col,
                yaxis_title=y_col,
                coloraxis_colorbar=dict(
                    title="Count",
                    thicknessmode="pixels", thickness=20,
                    lenmode="pixels", len=300,
                    yanchor="top", y=1,
                    ticks="outside"
                )
            )
    elif viz_type == 'scatter' and len(numeric_cols) >= 2:
        x_col, y_col = numeric_cols[0], numeric_cols[1]
        size_col = numeric_cols[2] if len(numeric_cols) > 2 else None
        # Use a non-numeric column for the color dimension, when one exists.
        if len(result_df.columns) > len(numeric_cols):
            categorical_cols = [c for c in result_df.columns if c not in numeric_cols]
            color_col = categorical_cols[0] if categorical_cols else None
        else:
            color_col = None
        fig = px.scatter(
            result_df,
            x=x_col,
            y=y_col,
            size=size_col,
            color=color_col,
            title=f"Relationship between {x_col} and {y_col}",
            opacity=0.7,
            size_max=15,
            color_discrete_sequence=px.colors.qualitative.Plotly
        )
        # NOTE(review): this draws a min->max diagonal, not a fitted trend
        # line — confirm whether a regression line was intended.
        if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
            fig.update_layout(
                shapes=[
                    dict(
                        type='line',
                        xref='x', yref='y',
                        x0=result_df[x_col].min(),
                        y0=result_df[y_col].min(),
                        x1=result_df[x_col].max(),
                        y1=result_df[y_col].max(),
                        line=dict(color='red', width=2, dash='dash')
                    )
                ]
            )
        fig.update_layout(
            xaxis_title=x_col,
            yaxis_title=y_col,
            showlegend=True,
            legend=dict(
                title=color_col if color_col else "",
                orientation="h",
                yanchor="bottom", y=1.02,
                xanchor="right", x=1
            )
        )
    elif viz_type == 'line':
        # Prefer a date/time-looking column on the x axis.
        time_cols = [col for col in result_df.columns
                     if any(w in col.lower() for w in ['date', 'time', 'month', 'year', 'day'])]
        x_col = time_cols[0] if time_cols else result_df.columns[0]
        y_cols = numeric_cols[:3]  # plot up to 3 numeric series
        if not y_cols and len(result_df.columns) > 1:
            y_cols = [result_df.columns[1]]
        fig = px.line(
            result_df,
            x=x_col,
            y=y_cols,
            title="Time Series Analysis",
            markers=True,
            color_discrete_sequence=px.colors.qualitative.Plotly
        )
        # Range slider helps navigate long series.
        fig.update_layout(
            xaxis=dict(
                rangeslider=dict(visible=True),
                type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-'
            )
        )
    else:
        # Default: bar chart of the first suitable column pairing.
        x_col = result_df.columns[0]
        if numeric_cols and x_col not in numeric_cols:
            y_cols = numeric_cols[:3]
        elif len(result_df.columns) > 1:
            y_cols = [result_df.columns[1]]
        else:
            y_cols = ['value']
            result_df['value'] = 1  # synthetic constant when nothing else fits
        fig = px.bar(
            result_df,
            x=x_col,
            y=y_cols[0],  # bar charts use only the first y column
            title="Data Visualization",
            color_discrete_sequence=['#636EFA']
        )
    # Shared layout for every chart type.
    fig.update_layout(
        autosize=True,
        width=fig_width,
        height=fig_height,
        margin=dict(l=50, r=50, b=100, t=100, pad=4),
        template="plotly_white",
        font=dict(size=14),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        plot_bgcolor='rgba(240,240,240,0.2)',
        paper_bgcolor='white'
    )
    return fig


def _render_visualization(result_df, viz_type):
    """Render the requested chart and return markdown/HTML to append.

    Returns an empty string when plotting fails (errors are logged only,
    matching the original best-effort behavior).
    """
    try:
        print("Visualization requested, attempting to create plot...")
        # FIX: derive numeric columns from the actual result frame.  The
        # original only defined numeric_cols on the box/heatmap code paths,
        # so every other chart type died with a (swallowed) NameError.
        numeric_cols = [col for col in result_df.columns
                        if pd.api.types.is_numeric_dtype(result_df[col])]
        fig = _build_figure(result_df, viz_type, numeric_cols)
        # Embed the chart as a base64 PNG so it renders inside the chat.
        img_bytes = fig.to_image(format="png", width=1000, height=700, scale=2)
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"
        return (
            f"\n\n<img src='{img_src}' width='100%' style='min-height:700px;' />"
            f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"
        )
    except Exception as viz_error:
        print(f"Visualization error: {str(viz_error)}")
        traceback.print_exc()
        return ""


def process_text_query(query, history):
    """Handle one chat turn: answer *query* and append to *history*.

    For CSV context the flow is: LLM generates SQL -> execute against the
    local SQLite DB -> LLM interprets the results -> optionally render a
    Plotly chart inline.  For any other context the query is forwarded to
    the document assistant.

    Args:
        query: User message (empty/None is a no-op).
        history: Gradio "messages"-style chat history; mutated in place.

    Returns:
        ("", history) — the empty string clears the input textbox.
    """
    if not query:
        return "", history
    history.append({"role": "user", "content": query})
    start_time = time.time()
    lowered = query.lower()
    # Keyword map used to guess which chart type the user wants.
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }
    is_visualization = any(
        word in lowered
        for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me']
    )
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in lowered for keyword in keywords):
                viz_type = vtype
                break
    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            conn = sqlite3.connect(DB_PATH)
            try:
                # Describe the table so the LLM writes valid column references.
                cursor = conn.cursor()
                cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
                columns = [info[1] for info in cursor.fetchall()]
                question_with_context = (
                    f"The table 'data_tab' has columns: {', '.join(columns)}. {query}"
                )
                if is_visualization and viz_type in ('box', 'heatmap'):
                    # These chart types need raw rows, not an aggregate query.
                    sql_query = _raw_data_query(conn, lowered, viz_type)
                else:
                    raw_sql = (query_prompt | llm).invoke(
                        {"question": question_with_context}
                    ).content.strip()
                    sql_query = clean_sql_query(raw_sql)
                print(f"Generated SQL Query: {sql_query}")
                try:
                    result_df = pd.read_sql_query(sql_query, conn)
                    response = _answer_from_results(query, sql_query, result_df)
                    if is_visualization and not result_df.empty:
                        response += _render_visualization(result_df, viz_type)
                except Exception as e:
                    response = (
                        f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
                        f"**Error executing query:** {str(e)}"
                    )
            finally:
                # FIX: the original leaked the connection when SQL generation
                # failed before execution; close unconditionally.
                conn.close()
        except Exception as e:
            response = f"Error processing query: {str(e)}"
    else:
        # Non-CSV context: delegate to the document assistant.
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"
    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"
    history.append({"role": "assistant", "content": response})
    return "", history
def process_file_upload(files):
    """Ingest uploaded files: CSVs go into SQLite, everything else is indexed.

    CSV files are loaded into the fixed ``data_tab`` table (replacing any
    previous contents); other files are handed to the document assistant.
    The global ``current_context`` ends up describing the last file
    processed, which is what routes subsequent chat queries.

    Args:
        files: Gradio file objects (each exposing a ``.name`` path), or None.

    Returns:
        A newline-joined, human-readable status report.
    """
    if not files:
        return "No files uploaded"
    global current_context
    # Reset the routing context before ingesting the new batch.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()
        if file_ext == '.csv':
            try:
                conn = sqlite3.connect(DB_PATH)
                try:
                    # Speed up the bulk insert; durability doesn't matter for
                    # a rebuildable cache database.
                    conn.execute("PRAGMA synchronous = OFF")
                    conn.execute("PRAGMA journal_mode = MEMORY")
                    df = pd.read_csv(file_path)
                    # Always load into 'data_tab' so generated SQL has a
                    # stable table name.
                    df.to_sql('data_tab', conn, if_exists='replace', index=False)
                    current_context = {
                        "file_type": "csv",
                        "file_name": file_name,
                        "table_name": "data_tab"
                    }
                    cursor = conn.cursor()
                    cursor.execute("PRAGMA table_info(data_tab);")
                    columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                    cursor.execute("SELECT COUNT(*) FROM data_tab;")
                    row_count = cursor.fetchone()[0]
                finally:
                    # FIX: the original leaked the connection when loading failed.
                    conn.close()
                file_info.append("βœ… CSV File Successfully Loaded")
                file_info.append("πŸ“Š Table Name: data_tab")
                file_info.append(f"πŸ“„ Source File: {file_name}")
                file_info.append(f"πŸ“ˆ Total Rows: {row_count:,}")
                file_info.append(f"πŸ“‹ Columns: {', '.join(columns)}")
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
        else:
            # Process PDF or other document types via the RAG pipeline.
            try:
                result = document_assistant.upload_document(file_path)
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }
                file_info.append("βœ… Document Successfully Processed")
                file_info.append(f"πŸ“„ File: {file_name}")
                file_info.append(f"πŸ“š Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")
    return "\n".join(file_info)
def list_documents():
    """Return a human-readable listing of every indexed document."""
    try:
        docs = document_assistant.get_all_documents()
        if not docs:
            return "No documents indexed yet."
        listing = "".join(
            f"- {doc['filename']} ({doc['file_type']})\n" for doc in docs
        )
        return "Indexed Documents:\n\n" + listing
    except Exception as exc:
        return f"Error listing documents: {str(exc)}"
def clear_context():
    """Reset the tracked upload context and post a confirmation message."""
    global current_context
    try:
        # Blank every routing field so queries stop hitting stale state.
        current_context = dict.fromkeys(("file_type", "file_name", "table_name"))
        message = "Context cleared. You can now upload new documents or CSV files."
    except Exception as exc:
        message = f"Error clearing context: {str(exc)}"
    return [{"role": "assistant", "content": message}]
def process_voice_input(audio_path):
    """Transcribe a recorded audio file via Google Speech Recognition.

    Returns the recognized text, or a human-readable error string.
    """
    if audio_path is None:
        return "No audio recorded"
    try:
        recognizer = sr.Recognizer()
        with sr.AudioFile(audio_path) as source:
            # Capture the whole file, then send it off for recognition.
            captured = recognizer.record(source)
            return recognizer.recognize_google(captured)
    except sr.UnknownValueError:
        return "Could not understand audio"
    except sr.RequestError as e:
        return f"Error with speech recognition service: {e}"
    except Exception as e:
        return f"Error processing audio: {str(e)}"
def text_to_speech_output(text):
    """Speak the most recent assistant reply from the chat history.

    Returns the path of a generated MP3 file, or None when there is
    nothing to speak or synthesis fails.
    """
    if not text:
        return None
    # Walk the history backwards to find the newest assistant message.
    last_message = next(
        (msg["content"] for msg in reversed(text) if msg["role"] == "assistant"),
        None,
    )
    if not last_message:
        return None
    try:
        # Strip markup that would sound wrong when read aloud.
        spoken = re.sub(r'<.*?>', '', last_message)                 # HTML tags
        spoken = re.sub(r'\*\*(.*?)\*\*', r'\1', spoken)            # bold markdown
        spoken = re.sub(r'\n\n', ' ', spoken)                       # paragraph breaks
        spoken = re.sub(r'```.*?```', 'Code block removed for speech.',
                        spoken, flags=re.DOTALL)                    # code fences
        # Persist to a temp file Gradio can serve back to the browser.
        out_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        out_file.close()
        gTTS(text=spoken, lang='en', slow=False).save(out_file.name)
        return out_file.name
    except Exception as exc:
        print(f"Error generating speech: {str(exc)}")
        return None
def create_test_visualization():
    """Build a small fixed bar chart to confirm that plotting works.

    Returns the plotly Figure, or None when figure creation fails.
    """
    try:
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })
        chart = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        chart.update_layout(autosize=True, width=800, height=500)
        return chart
    except Exception as exc:
        print(f"Error creating test visualization: {str(exc)}")
        return None
def create_test_html_visualization():
    """Render the fixed test bar chart as an embeddable HTML fragment.

    Returns the HTML string, or None when rendering fails.
    """
    try:
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })
        chart = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        # full_html=False yields a <div> fragment suitable for gr.HTML.
        return pio.to_html(chart, full_html=False)
    except Exception as exc:
        print(f"Error creating test HTML visualization: {str(exc)}")
        return None
def flush_databases():
    """Clear both storage backends: drop SQLite tables and reset ChromaDB.

    Also resets the global ``current_context`` so the UI starts fresh.

    Returns:
        A newline-joined status report with one line per backend.
    """
    result = []
    # --- SQLite: drop every user table ---
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            # FIX: skip SQLite's internal tables (sqlite_sequence etc.) —
            # DROP on those raises an error the original could hit.
            cursor.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            )
            tables = cursor.fetchall()
            for (table_name,) in tables:
                # Quote the identifier in case the name needs escaping.
                cursor.execute(f'DROP TABLE IF EXISTS "{table_name}";')
            conn.commit()
        finally:
            # FIX: the original leaked the connection if any statement failed.
            conn.close()
        result.append("βœ… SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")
    # --- ChromaDB: delegate to the document assistant ---
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("βœ… ChromaDB cleared successfully")
        else:
            result.append("⚠️ ChromaDB reset may not have been complete")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")
    # Reset the routing context so stale table/file references are gone.
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return "\n".join(result)
# At the beginning of app.py, after the imports
# Add this code to monkey patch the vector_db module
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    # A NameError here means module-level code inside backend.vector_db
    # referenced an undefined name while being imported (observed: 'response').
    if "response" in str(e):
        # If the error is about 'response' not being defined, fix the module
        # NOTE(review): a module that raised during import is removed from
        # sys.modules, so this re-import re-executes it and will likely raise
        # the same NameError again — confirm this recovery path actually works.
        import backend.vector_db
        # Remove the problematic code
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')
        # Reload the module ('importlib' is bound by 'import importlib.util' above)
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
# Create Gradio interface.  Three tabs: Chat (query + voice I/O),
# Document Upload (ingest + flush), and Settings (API key, debug tools).
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# πŸ€– AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")
    with gr.Tab("Chat"):
        # Use a custom CSS to ensure images (inline base64 charts) are
        # displayed properly inside the chat area.
        gr.HTML("""
        <style>
        .chatbot-container img {
            max-width: 100%;
            height: auto;
            display: block;
            margin: 10px 0;
        }
        </style>
        """)
        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")
        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("🎀")
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")
        audio_output = gr.Audio(label="Voice Response", type="filepath")
        # Voice input recorder, hidden until the mic button reveals it.
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )
        # Event handlers: button click and Enter-in-textbox run the same
        # pipeline; both clear the textbox and refresh the chat display.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        # Clear wipes only the chat display; the upload context survives.
        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
        # The mic button only reveals the recorder; transcription happens
        # when the recording (voice_input) changes.
        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )
        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )
        # Add TTS functionality: speak the latest assistant reply.
        tts_btn = gr.Button("πŸ”Š Speak Response")
        tts_btn.click(
            text_to_speech_output,
            inputs=[chatbot],
            outputs=[audio_output]
        )
    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        with gr.Row():
            upload_button = gr.Button("Process & Index Documents", scale=2)
            flush_db_btn_doc = gr.Button("πŸ—‘οΈ Flush All Databases", variant="stop", scale=1)
        upload_output = gr.Textbox(label="Upload Status")
        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )
        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )
        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")
        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )
    with gr.Tab("Settings"):
        with gr.Row():
            gr.Markdown("## Database Management")
            flush_db_btn = gr.Button("πŸ—‘οΈ Flush All Databases", variant="stop", scale=1)
        flush_result = gr.Textbox(label="Flush Result")
        flush_db_btn.click(
            flush_databases,
            inputs=[],
            outputs=[flush_result]
        )
        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")
        def save_settings(key):
            """Store the entered API key in this process's environment only."""
            try:
                os.environ["GROQ_API_KEY"] = key
                return "Settings saved!"
            except Exception as e:
                return f"Error saving settings: {str(e)}"
        # NOTE(review): outputs=[gr.Textbox(...)] creates a component outside
        # the laid-out UI, so the "Settings saved!" text is never visible —
        # confirm whether a pre-declared status textbox was intended.
        save_btn.click(
            save_settings,
            inputs=[api_key],
            outputs=[gr.Textbox(label="Status")]
        )
        gr.Markdown("## Debugging")
        test_viz_btn = gr.Button("Test Visualization")
        test_viz_output = gr.HTML(label="Test Visualization")
        test_viz_btn.click(
            create_test_html_visualization,
            inputs=[],
            outputs=[test_viz_output]
        )
# Launch the app
if __name__ == "__main__":
    demo.launch()