import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import plotly.express as px
import time
import plotly.io as pio
import traceback
import base64
from io import BytesIO
import re
import importlib.util
# Load environment variables from a local .env file (GROQ_API_KEY, etc.)
load_dotenv()

# Make the project root importable so the `backend` package resolves when
# this file is run directly from its own directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant

# Singleton assistant that owns document (e.g. PDF) ingestion and retrieval.
document_assistant = DocumentAssistant()

# LLM used for SQL generation and result interpretation.
# temperature=0 keeps the generated SQL as deterministic as possible;
# max_retries guards against transient Groq API failures.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite file that holds data imported from uploaded CSVs.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Data directory (shared by the SQLite DB and ChromaDB).
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Persistent storage location for the ChromaDB vector store.
CHROMA_DB_DIR = os.path.join(DATA_DIR, "chroma_db")
os.makedirs(CHROMA_DB_DIR, exist_ok=True)

# Tell the backend where ChromaDB should persist its data.
os.environ["CHROMA_DB_PATH"] = CHROMA_DB_DIR

# Tracks the most recent upload so queries can be routed to either the SQL
# path (CSV) or the document-retrieval path (PDF/other).
# file_type: "csv" | "pdf" | None; table_name is always "data_tab" for CSVs.
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}

# Add a global variable to store the current plot
# current_plot = None
# Prompt for translating a natural-language question into a SQLite query.
# The guidelines steer the model away from non-SQLite dialect features and
# ask for bare SQL (no markdown) so clean_sql_query() has less to repair.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
""")

# Prompt for summarizing an executed query's result set in plain language.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)
# Add this after the query_prompt definition
# visualization_prompt = ChatPromptTemplate.from_template("""
# You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
#
# Important guidelines for SQLite syntax:
# 1. Use strftime() for date functions:
# - Year: strftime('%Y', date_column)
# - Month: strftime('%m', date_column)
# - Day: strftime('%d', date_column)
# - Hour: strftime('%H', date_column)
#
# 2. For histograms and binning:
# - Use: CAST((column / bin_size) AS INT) * bin_size
# - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
#
# 3. For box plots:
# - SQLite doesn't support PERCENTILE_CONT or window functions
# - Simply return the raw data column: SELECT column_name FROM data_tab
# - The application will calculate quartiles and outliers
#
# 4. For heatmaps:
# - Return raw data for correlation analysis
# - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
#
# 5. Always use 'data_tab' as the table name
#
# 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
#
# Question: {question}
# Visualization type: {viz_type}
# """)
# Add this helper function to clean SQL queries
def clean_sql_query(query_text):
    """Normalize LLM-generated SQL into a bare, executable statement.

    Strips markdown code fences, conversational lead-ins (e.g. "Here is the
    SQL query:"), SQL line and block comments, and a trailing semicolon.

    Args:
        query_text: Raw text returned by the LLM; may be None or empty.

    Returns:
        A cleaned SQL string, or the safe fallback
        "SELECT * FROM data_tab LIMIT 10;" when the input is empty or
        nothing usable remains after cleaning.
    """
    # Check if input is None or empty
    if not query_text:
        return "SELECT * FROM data_tab LIMIT 10;"

    # Fix: strip surrounding whitespace up front so the prefix checks below
    # are not defeated by a leading newline or space in the LLM reply.
    query_text = query_text.strip()

    # Extract the content of a markdown code fence if one is present.
    # Fix: IGNORECASE so "```SQL" fences are handled as well as "```sql".
    if "```" in query_text:
        matches = re.findall(r"```(?:sql)?(.*?)```", query_text,
                             re.DOTALL | re.IGNORECASE)
        if matches:
            query_text = matches[0].strip()

    # Drop conversational lead-ins by jumping to the first SQL keyword.
    prefixes = (
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:",
    )
    if query_text.lower().startswith(prefixes):
        sql_keywords = ["select", "with", "create", "insert", "update", "delete"]
        positions = [pos for pos in
                     (query_text.lower().find(keyword) for keyword in sql_keywords)
                     if pos != -1]
        if positions:
            query_text = query_text[min(positions):]

    # Remove SQL comments: line comments (-- ...) and block comments (/* ... */).
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)
    query_text = re.sub(r'/\*.*?\*/', ' ', query_text, flags=re.DOTALL)

    # Remove trailing semicolon if present
    query_text = query_text.strip().rstrip(';')

    # Ensure the query is not empty after all the stripping above
    if not query_text.strip():
        return "SELECT * FROM data_tab LIMIT 10;"
    return query_text
def process_text_query(query, history):
    """Route a user query to the right answering path and update chat history.

    Dispatch depends on the current upload context:
      * CSV loaded       -> generate SQL with the LLM (or build canned queries
        for box/heatmap charts that need raw rows), execute it against the
        SQLite table 'data_tab', and return SQL + results + analysis. If a
        chart was requested, return the visualization HTML instead.
      * documents loaded -> delegate to the DocumentAssistant pipeline.
      * nothing loaded   -> answer from the LLM's general knowledge.

    Args:
        query: The user's natural-language question (may be empty).
        history: Gradio chat history, a list of [user, bot] pairs; mutated in
            place (the bot slot of the newly appended pair is filled in).

    Returns:
        (response, history) where response is markdown text or, for a
        successful visualization, an HTML string from generate_visualization.
    """
    if not query:
        return "", history
    # Add the user's query to history; the answer slot is filled in below
    history.append([query, None])
    start_time = time.time()  # NOTE(review): never read — timing appears unused
    # Keyword map used to recognize which chart type the user asked for
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }
    # Heuristic: any of these words means the user wants a chart
    is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])
    # Determine the specific visualization type from the query keywords
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in query.lower() for keyword in keywords):
                viz_type = vtype
                break
    # CSV path: answer by generating and executing SQL
    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            # Connect to the SQLite database holding the imported CSV
            conn = sqlite3.connect(DB_PATH)
            # Fetch the column names so the LLM knows the table schema
            cursor = conn.cursor()
            cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
            columns = [info[1] for info in cursor.fetchall()]
            columns_str = ", ".join(columns)
            # Prepend the schema to the question so generated SQL uses real columns
            question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
            # Box plots and heatmaps need raw (unaggregated) rows, so the SQL
            # is built directly instead of asking the LLM
            if is_visualization and viz_type in ['box', 'heatmap']:
                if viz_type == 'box':
                    # Box plot: select one numeric column (the one named in the
                    # query if any, else the first numeric column found)
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]
                    if numeric_cols:
                        # Prefer a numeric column explicitly mentioned in the query
                        target_col = None
                        for col in numeric_cols:
                            if col.lower() in query.lower():
                                target_col = col
                                break
                        # If no specific column is mentioned, use the first numeric column
                        if not target_col and numeric_cols:
                            target_col = numeric_cols[0]
                        # Pull the raw non-null values for quartile computation downstream
                        sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
                    else:
                        # No numeric columns found — fall back to a small preview
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
                elif viz_type == 'heatmap':
                    # Heatmap: needs several numeric columns for a correlation matrix
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]
                    if len(numeric_cols) >= 2:
                        # Cap at 10 columns to keep the correlation computation cheap
                        cols_to_use = numeric_cols[:10]
                        cols_str = ", ".join(cols_to_use)
                        sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
                    else:
                        # Not enough numeric columns — fall back to a small preview
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
            else:
                # Everything else: ask the LLM for SQL, then sanitize its reply
                sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content
                sql_query = clean_sql_query(sql_query)
            # Execute the (generated or canned) query
            result_df = pd.read_sql_query(sql_query, conn)
            conn.close()
            # Render the result set as plain text for the chat response
            df_str = result_df.to_string()
            # Ask the LLM to interpret the results in natural language
            data_summary = result_df.to_string()
            analysis = llm.invoke(interpret_prompt.format(
                question=query,
                sql_query=sql_query,
                data_summary=data_summary
            )).content
            # Compose the full markdown answer: query, results table, analysis
            comprehensive_response = f"""
### SQL Query:
```sql
{sql_query}
```
### Results:
```
{df_str}
```
### Analysis:
{analysis}
"""
            # If a chart was requested, try to build it; on success the chart
            # HTML is returned and the text answer goes into the history
            if is_visualization:
                viz_html = generate_visualization(result_df, query)
                if viz_html:
                    history[-1][1] = comprehensive_response
                    return viz_html, history
            # No visualization requested (or chart creation failed): text only
            history[-1][1] = comprehensive_response
            return comprehensive_response, history
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history
    elif document_assistant.get_all_documents():
        # Document path: delegate retrieval + answering to the assistant
        try:
            response = document_assistant.process_query(query)
            history[-1][1] = response
            return response, history
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history
    else:
        # Fallback path: no data loaded, answer from general LLM knowledge
        try:
            general_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant that provides clear, informative responses. Use your knowledge to answer the user's question concisely."),
                ("human", "{question}")
            ])
            response = llm.invoke(general_prompt.format(question=query)).content
            history[-1][1] = response
            return response, history
        except Exception as e:
            error_msg = f"Error processing query: {str(e)}"
            history[-1][1] = error_msg
            return error_msg, history
def process_file_upload(files):
    """Ingest uploaded files: CSVs into SQLite, everything else as documents.

    CSV files are loaded into a single SQLite table named 'data_tab'
    (replacing any previous contents); other file types are handed to the
    DocumentAssistant for processing. The global current_context is updated
    to reflect the most recent upload, which controls query routing in
    process_text_query.

    Args:
        files: List of uploaded file objects (each exposes a .name path);
            may be None or empty.

    Returns:
        A newline-joined, human-readable status report for all files.
    """
    if not files:
        return "No files uploaded"
    global current_context
    # Clear existing context before processing the new upload batch
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()
        if file_ext == '.csv':
            try:
                # Derive a table name from the filename
                # NOTE(review): computed but unused — data always lands in
                # 'data_tab' so the SQL prompts can rely on a fixed name
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()
                # Load CSV into SQLite
                conn = sqlite3.connect(DB_PATH)
                # Trade durability for speed during the bulk import
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")
                # Read the CSV and replace any previous data_tab contents
                df = pd.read_csv(file_path)
                df.to_sql('data_tab', conn, if_exists='replace', index=False)
                # Record that we are now in CSV mode
                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab"  # Always use data_tab as the table name
                }
                # Gather column names/types for the status report
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                # Row count for the status report
                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]
                conn.close()
                file_info.append("✅ CSV File Successfully Loaded")
                file_info.append(f"📊 Table Name: data_tab")
                file_info.append(f"📄 Source File: {file_name}")
                file_info.append(f"📈 Total Rows: {row_count:,}")
                file_info.append(f"📋 Columns: {', '.join(columns)}")
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
        else:
            # Non-CSV files (e.g. PDFs) go to the document pipeline
            try:
                result = document_assistant.upload_document(file_path)
                # Record that we are now in document mode
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }
                file_info.append("✅ Document Successfully Processed")
                file_info.append(f"📄 File: {file_name}")
                file_info.append(f"📚 Chunks: {result['chunks']}")
                file_info.append(result['message'])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")
    return "\n".join(file_info)
# Function commented out as it's no longer used
# def list_documents():
# """List all indexed documents"""
# try:
# docs = document_assistant.get_all_documents()
# if not docs:
# return "No documents indexed yet."
#
# result = "Indexed Documents:\n\n"
# for doc in docs:
# result += f"- {doc['filename']} ({doc['file_type']})\n"
#
# return result
# except Exception as e:
# return f"Error listing documents: {str(e)}"
def clear_context():
    """Reset the global upload context and return a fresh chat history.

    Returns:
        A chat-history list with a single [message, None] pair, suitable
        for resetting a Gradio Chatbot component.
    """
    global current_context
    # Forget any previously loaded CSV/document so query routing starts clean.
    # Fix: removed the dead try/except — neither the dict-literal assignment
    # nor the literal return can raise, so the error branch was unreachable.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return [["Context cleared. You can now upload new documents or CSV files.", None]]
def flush_databases():
    """Wipe all persisted state: SQLite tables, ChromaDB, and file context.

    Drops every user table from the SQLite CSV database, resets the document
    assistant's vector store (recreating the assistant as a fallback), and
    clears the in-memory upload context.

    Returns:
        A newline-joined status report with one ✅/⚠️/❌ line per subsystem.
    """
    global document_assistant
    global current_context
    result = []
    # --- SQLite: drop every user-created table ---
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            cursor = conn.cursor()
            # Fix: exclude SQLite's internal tables (sqlite_sequence,
            # sqlite_stat*) — they cannot be dropped and would raise,
            # aborting the cleanup midway.
            cursor.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            )
            tables = cursor.fetchall()
            for (table_name,) in tables:
                # Fix: double-quote the identifier (escaping embedded quotes)
                # so unusual table names cannot break the statement.
                quoted = '"{}"'.format(table_name.replace('"', '""'))
                cursor.execute(f"DROP TABLE IF EXISTS {quoted};")
            conn.commit()
        finally:
            # Fix: always release the connection, even if a DROP fails.
            conn.close()
        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")
    # --- ChromaDB: reset via the document assistant ---
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("✅ ChromaDB cleared successfully")
        else:
            # Even if reset fails, reinitializing the assistant gives us a
            # fresh (empty) instance as a workable fallback.
            document_assistant = DocumentAssistant()
            result.append("⚠️ ChromaDB reset partially completed - created new instance")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")
    # Forget whatever file was previously loaded.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return "\n".join(result)
# Defensive import of the vector-DB wrapper. Some versions of
# backend.vector_db raised a NameError at import time because of a stray
# module-level `response` reference; when that specific failure occurs, the
# module is scrubbed, reloaded, and the import retried.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        # The error mentions 'response': attempt the monkey-patch workaround
        import backend.vector_db
        # Remove the problematic attribute before reloading, if present
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')
        # Reload the module so the import can succeed this time
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
    # NOTE(review): if the NameError is about anything other than 'response',
    # it is silently swallowed and ChromaVectorDB stays undefined — confirm
    # this is intended rather than re-raising.
# Add this function to app.py
def generate_visualization(result_df, query):
"""Generate a visualization based on the query and data"""
try:
print("Visualization requested, attempting to create plot...")
# Set common figure parameters
fig_width = 1200 # Increased for better quality
fig_height = 800 # Maintain aspect ratio
# Determine visualization type from query
viz_type = 'bar' # Default
if any(word in query.lower() for word in ['pie', 'distribution', 'proportion']):
viz_type = 'pie'
elif any(word in query.lower() for word in ['line', 'trend', 'time series']):
viz_type = 'line'
elif any(word in query.lower() for word in ['scatter', 'relationship']):
viz_type = 'scatter'
elif any(word in query.lower() for word in ['histogram', 'distribution of']):
viz_type = 'histogram'
elif any(word in query.lower() for word in ['box', 'boxplot', 'outliers']):
viz_type = 'box'
elif any(word in query.lower() for word in ['heatmap', 'correlation']):
viz_type = 'heatmap'
print(f"Creating {viz_type} visualization...")
# Find numeric columns
numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
# Create basic visualization based on type
if viz_type == 'pie' and len(result_df) <= 20:
# Simple pie chart
labels = result_df.iloc[:, 0].tolist()
values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)
import plotly.graph_objects as go
fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
fig.update_layout(title_text='Pie Chart')
elif viz_type == 'histogram' and len(numeric_cols) > 0:
# Simple histogram
import plotly.express as px
fig = px.histogram(result_df, x=numeric_cols[0])
fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')
elif viz_type == 'box' and len(numeric_cols) > 0:
# Simple box plot
import plotly.express as px
fig = px.box(result_df, y=numeric_cols[0])
fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')
elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
# Simple heatmap
import plotly.express as px
# Create correlation matrix
corr_df = result_df[numeric_cols].corr()
fig = px.imshow(corr_df, text_auto=True)
fig.update_layout(title_text='Correlation Heatmap')
elif viz_type == 'scatter' and len(numeric_cols) >= 2:
# Simple scatter plot
import plotly.express as px
fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')
elif viz_type == 'line':
# Simple line chart
import plotly.express as px
x_col = result_df.columns[0]
y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None
if y_cols:
fig = px.line(result_df, x=x_col, y=y_cols[0])
fig.update_layout(
title_text=f'Line Chart of {y_cols[0]} over {x_col}',
xaxis=dict(
tickangle=-45,
tickmode='auto',
nticks=20
)
)
else:
# Fallback to bar chart
viz_type = 'bar'
if viz_type == 'bar' or 'fig' not in locals():
# Simple bar chart (default)
import plotly.express as px
x_col = result_df.columns[0]
y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None
# Check if we have many categories (more than 10)
if len(result_df) > 10:
# Use horizontal bar chart for many categories
if y_col:
fig = px.bar(
result_df,
y=x_col, # Swap x and y for horizontal orientation
x=y_col,
orientation='h', # Horizontal orientation
title=f'Bar Chart of {y_col} by {x_col}'
)
else:
fig = px.bar(
result_df,
y=x_col, # Swap x and y for horizontal orientation
orientation='h', # Horizontal orientation
title=f'Bar Chart of {x_col}'
)
else:
# Use vertical bar chart for fewer categories
if y_col:
fig = px.bar(
result_df,
x=x_col,
y=y_col,
title=f'Bar Chart of {y_col} by {x_col}'
)
else:
fig = px.bar(
result_df,
x=x_col,
title=f'Bar Chart of {x_col}'
)
# Improve bar chart layout
fig.update_layout(
bargap=0.2, # Increase gap between bars
uniformtext_minsize=8, # Minimum text size
uniformtext_mode='hide' # Hide text if it doesn't fit
)
# Set common layout properties
fig.update_layout(
width=fig_width,
height=fig_height,
template="plotly_white",
margin=dict(l=40, r=40, t=80, b=80, pad=4), # Balanced margins
autosize=True, # Allow the plot to resize with the container
plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background
paper_bgcolor='white',
font=dict(size=12) # Increase font size
)
# Add hover information
fig.update_traces(
hovertemplate="%{x}: %{y}