SVashishta1
Update supported file types in capabilities description
42d8891
raw
history blame
37.5 kB
import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import plotly.express as px
import time
import plotly.io as pio
import traceback
import base64
from io import BytesIO
# I am commenting out voice libraries because we are not using them right now
# import speech_recognition as sr
# from gtts import gTTS
import re
import importlib.util
# Load environment variables from a local .env file (e.g. GROQ_API_KEY).
load_dotenv()

# Add parent directory to path to import backend modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# Initialize the document assistant (handles non-CSV document indexing and QA).
document_assistant = DocumentAssistant()

# Initialize the LLM using the llama3-8b-8192 model from Groq.
# temperature=0 keeps SQL generation as deterministic as possible.
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# SQLite database file used to hold data from uploaded CSV files.
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Tracks what kind of file we are currently working with (set on upload).
current_context = {
    "file_type": None,    # "csv" or "pdf" (non-CSV documents are tagged "pdf")
    "file_name": None,    # basename of the most recently uploaded file
    "table_name": None    # SQLite table for CSV data (always "data_tab")
}

# Global variable to store the current plot.
# NOTE(review): never written in this view of the file — confirm it is used elsewhere.
current_plot = None
# Prompt that turns a natural-language question into a SQLite query.
# The guidelines steer the LLM away from non-SQLite dialects and markdown.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
""")

# Prompt that turns a SQL result summary into a natural-language answer.
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)

# Prompt that produces the SQL needed to feed a requested visualization.
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
- Year: strftime('%Y', date_column)
- Month: strftime('%m', date_column)
- Day: strftime('%d', date_column)
- Hour: strftime('%H', date_column)
2. For histograms and binning:
- Use: CAST((column / bin_size) AS INT) * bin_size
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
3. For box plots:
- SQLite doesn't support PERCENTILE_CONT or window functions
- Simply return the raw data column: SELECT column_name FROM data_tab
- The application will calculate quartiles and outliers
4. For heatmaps:
- Return raw data for correlation analysis
- Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
5. Always use 'data_tab' as the table name
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
Visualization type: {viz_type}
""")
# Helper to normalise raw LLM output into a bare, executable SQL string.
def clean_sql_query(query_text):
    """Clean SQL query text by removing markdown formatting and comments.

    Args:
        query_text: Raw text from the LLM, possibly wrapped in markdown
            code fences and/or preceded by explanatory prose.

    Returns:
        str: A bare SQL statement with no code fences, single-line
        comments, or trailing semicolon. Falls back to a harmless default
        query when the input is empty or cleans down to nothing.
    """
    fallback = "SELECT * FROM data_tab LIMIT 10;"
    # Guard against None or empty input from the LLM.
    if not query_text:
        return fallback
    # Pull the content out of a fenced markdown code block if present.
    if "```" in query_text:
        matches = re.findall(r"```(?:sql)?(.*?)```", query_text, re.DOTALL)
        if matches:
            query_text = matches[0].strip()
        else:
            # Fix: an unmatched fence (e.g. truncated LLM response) used to
            # be passed through to SQLite verbatim; strip the markers instead.
            query_text = re.sub(r"```(?:sql)?", " ", query_text, flags=re.IGNORECASE)
    # Drop a leading "here is the SQL query"-style preamble.
    query_text = query_text.lstrip()
    preamble_prefixes = (
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:",
    )
    if query_text.lower().startswith(preamble_prefixes):
        # Jump to the first SQL keyword that follows the preamble.
        sql_keywords = ["select", "with", "create", "insert", "update", "delete"]
        positions = [query_text.lower().find(keyword) for keyword in sql_keywords]
        positions = [pos for pos in positions if pos != -1]
        if positions:
            query_text = query_text[min(positions):]
    # Remove single-line SQL comments ("-- ...").
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)
    # Strip whitespace and any trailing semicolon.
    query_text = query_text.strip().rstrip(';')
    # Never return an empty query.
    if not query_text.strip():
        return fallback
    return query_text
def process_text_query(query, history):
    """Process a text query and update chat history.

    Routes the query either to the SQL pipeline (when a CSV is loaded into
    'data_tab') or to the document assistant, optionally generating an
    inline chart when the query looks like a visualization request.

    Args:
        query: The user's natural-language question.
        history: Chat history (list of {"role", "content"} dicts); mutated
            in place.

    Returns:
        A ("", history) tuple so Gradio clears the textbox and refreshes
        the chat component.
    """
    if not query:
        return "", history
    # Add the user's query to history
    history.append({"role": "user", "content": query})
    start_time = time.time()
    # Keyword lists used to infer the requested chart type.
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }
    # Heuristic: does the query ask for a chart at all?
    is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])
    # Determine the specific visualization type from the query keywords.
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in query.lower() for keyword in keywords):
                viz_type = vtype
                break
    # CSV mode: translate the question into SQL and run it locally.
    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            # Connect to the database
            conn = sqlite3.connect(DB_PATH)
            # Get column information so the LLM knows the schema.
            cursor = conn.cursor()
            cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
            columns = [info[1] for info in cursor.fetchall()]
            columns_str = ", ".join(columns)
            # Prefix the question with the schema for better SQL generation.
            question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}"
            # Box plots and heatmaps need raw data rather than aggregates,
            # so their SQL is built directly instead of asking the LLM.
            if is_visualization and viz_type in ['box', 'heatmap']:
                if viz_type == 'box':
                    # Box plots need a single numeric column.
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]
                    if numeric_cols:
                        # Prefer a numeric column mentioned in the query text.
                        target_col = None
                        for col in numeric_cols:
                            if col.lower() in query.lower():
                                target_col = col
                                break
                        # If no specific column is mentioned, use the first numeric column.
                        if not target_col and numeric_cols:
                            target_col = numeric_cols[0]
                        # Simple query to fetch the raw (non-NULL) values.
                        sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;"
                    else:
                        # No numeric columns found
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
                elif viz_type == 'heatmap':
                    # Heatmaps need several numeric columns for a correlation matrix.
                    numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';"
                    cursor = conn.cursor()
                    cursor.execute(numeric_cols_query)
                    numeric_cols = [row[0] for row in cursor.fetchall()]
                    if len(numeric_cols) >= 2:
                        # Use all numeric columns (capped at 10 for performance).
                        cols_to_use = numeric_cols[:10]
                        cols_str = ", ".join(cols_to_use)
                        sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;"
                    else:
                        # Not enough numeric columns
                        sql_query = "SELECT * FROM data_tab LIMIT 10;"
            else:
                # Generate the SQL query using the LLM, then strip markdown etc.
                ai_msg = query_prompt | llm
                raw_sql_query = ai_msg.invoke({"question": question_with_context}).content.strip()
                sql_query = clean_sql_query(raw_sql_query)
            print(f"Generated SQL Query: {sql_query}")
            try:
                # Execute the query
                result_df = pd.read_sql_query(sql_query, conn)
                # Build a textual summary for the interpretation prompt.
                if not result_df.empty:
                    data_summary = result_df.describe(include='all').to_string()
                    # For small result sets, include the actual data verbatim.
                    if len(result_df) <= 10:
                        data_summary += f"\n\nFull Results:\n{result_df.to_string()}"
                    else:
                        data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}"
                else:
                    data_summary = "No relevant data found."
                # Ask the LLM to interpret the results in natural language.
                answer_chain = interpret_prompt | llm
                interpretation = answer_chain.invoke({
                    "question": query,
                    "sql_query": sql_query,
                    "data_summary": data_summary
                }).content.strip()
                # Assemble the markdown response shown in the chat.
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"
                if not result_df.empty:
                    if len(result_df) > 10:
                        response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
                    else:
                        response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
                else:
                    response += "**No results found.**\n\n"
                response += f"**Analysis:**\n{interpretation}"
                # Append an inline chart when one was requested.
                if is_visualization and not result_df.empty:
                    try:
                        # Generate visualization (returns an <img> tag or None).
                        viz_html = generate_visualization(result_df, query)
                        if viz_html:
                            response += f"\n\n{viz_html}"
                            response += "\n\n**A visualization has been generated and is displayed above.**"
                        else:
                            response += "\n\n**Could not generate visualization due to an error.**"
                    except Exception as viz_error:
                        print(f"Visualization error: {str(viz_error)}")
                        import traceback
                        traceback.print_exc()
            except Exception as e:
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"
            # NOTE(review): if an exception fires before this line, the
            # connection is never closed — consider try/finally.
            conn.close()
        except Exception as e:
            response = f"Error processing query: {str(e)}"
    else:
        # Non-CSV context: delegate the question to the document assistant.
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"
    # Calculate processing time
    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"
    # Add the response to history
    history.append({"role": "assistant", "content": response})
    return "", history
def process_file_upload(files):
    """Process uploaded files and index them.

    CSV files are loaded into the local SQLite database (table 'data_tab');
    every other type is handed to the document assistant. The global
    ``current_context`` ends up describing the last file processed.

    Args:
        files: List of Gradio file objects (each with a ``.name`` path),
            or None/empty when nothing was uploaded.

    Returns:
        str: A newline-joined, human-readable status report.
    """
    if not files:
        return "No files uploaded"
    global current_context
    # Reset the context before processing the new batch.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    file_info = []
    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()
        if file_ext == '.csv':
            file_info.extend(_load_csv_file(file_path, file_name))
        else:
            file_info.extend(_index_document(file_path, file_name))
    return "\n".join(file_info)


def _load_csv_file(file_path, file_name):
    """Load one CSV into SQLite table 'data_tab'; return status lines."""
    global current_context
    info = []
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        # Speed up the bulk insert; durability is irrelevant for this cache DB.
        conn.execute("PRAGMA synchronous = OFF")
        conn.execute("PRAGMA journal_mode = MEMORY")
        # Read the CSV and load it into SQLite, replacing any previous data.
        df = pd.read_csv(file_path)
        df.to_sql('data_tab', conn, if_exists='replace', index=False)
        current_context = {
            "file_type": "csv",
            "file_name": file_name,
            "table_name": "data_tab"  # queries are always issued against data_tab
        }
        # Gather column and row metadata for the status report.
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(data_tab);")
        columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
        cursor.execute("SELECT COUNT(*) FROM data_tab;")
        row_count = cursor.fetchone()[0]
        info.append("βœ… CSV File Successfully Loaded")
        info.append("πŸ“Š Table Name: data_tab")
        info.append(f"πŸ“„ Source File: {file_name}")
        info.append(f"πŸ“ˆ Total Rows: {row_count:,}")
        info.append(f"πŸ“‹ Columns: {', '.join(columns)}")
    except Exception as e:
        info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
    finally:
        # Fix: the connection previously leaked when loading raised mid-way.
        if conn is not None:
            conn.close()
    return info


def _index_document(file_path, file_name):
    """Index one non-CSV document via the assistant; return status lines."""
    global current_context
    info = []
    try:
        result = document_assistant.upload_document(file_path)
        current_context = {
            # NOTE(review): every non-CSV upload is tagged "pdf", even
            # .txt/.docx/.xlsx — confirm downstream code only checks "csv" vs not.
            "file_type": "pdf",
            "file_name": file_name,
            "table_name": None
        }
        info.append("βœ… Document Successfully Processed")
        info.append(f"πŸ“„ File: {file_name}")
        info.append(f"πŸ“š Chunks: {result['chunks']}")
        info.append(result['message'])
    except Exception as e:
        info.append(f"❌ Error processing document {file_name}: {str(e)}")
    return info
def list_documents():
    """Return a human-readable listing of all indexed documents."""
    try:
        docs = document_assistant.get_all_documents()
        if not docs:
            return "No documents indexed yet."
        # One bullet line per document, with a trailing newline like before.
        lines = [f"- {doc['filename']} ({doc['file_type']})" for doc in docs]
        return "Indexed Documents:\n\n" + "\n".join(lines) + "\n"
    except Exception as err:
        return f"Error listing documents: {str(err)}"
def clear_context():
    """Reset the global file context and return a fresh assistant message."""
    global current_context
    try:
        # Blank out every tracked field at once.
        current_context = {key: None for key in ("file_type", "file_name", "table_name")}
        message = "Context cleared. You can now upload new documents or CSV files."
    except Exception as err:
        message = f"Error clearing context: {str(err)}"
    return [{"role": "assistant", "content": message}]
# Voice input is still in development; the implementation below is kept
# disabled as a plain (discarded) string literal so it has no runtime effect.
"""
def process_voice_input(audio_path):
# I am checking if there is audio
if audio_path is None:
return "No audio recorded"
try:
# I am making a recognizer for the voice
r = sr.Recognizer()
# I am loading the audio file
with sr.AudioFile(audio_path) as source:
# I am reading the audio data
audio_data = r.record(source)
# I am recognizing speech using Google
text = r.recognize_google(audio_data)
return text
except sr.UnknownValueError:
return "Could not understand audio"
except sr.RequestError as e:
return f"Error with speech recognition service: {e}"
except Exception as e:
return f"Error processing audio: {str(e)}"
"""
# Text-to-speech helper, likewise disabled (would require gTTS) for this release.
"""
def text_to_speech_output(text):
# I am checking if there is text
if not text or len(text) == 0:
return None
# I am finding the last message from assistant
last_message = None
for msg in reversed(text):
if msg["role"] == "assistant":
last_message = msg["content"]
break
if not last_message:
return None
try:
# I am cleaning the text
clean_text = re.sub(r'<.*?>', '', last_message) # I am removing HTML tags
clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', clean_text) # I am removing bold markdown
clean_text = re.sub(r'\n\n', ' ', clean_text) # I am replacing double newlines with space
clean_text = re.sub(r'```.*?```', 'Code block removed for speech.', clean_text, flags=re.DOTALL) # I am replacing code blocks
# I am creating a temporary file
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
temp_file.close()
# I am generating speech
tts = gTTS(text=clean_text, lang='en', slow=False)
tts.save(temp_file.name)
return temp_file.name
except Exception as e:
print(f"Error generating speech: {str(e)}")
return None
"""
def create_test_visualization():
    """Build a small sample bar chart (Plotly figure) to verify plotting works."""
    try:
        # Six months of made-up values are plenty for a smoke test.
        months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
        values = [10, 15, 13, 17, 20, 25]
        sample = pd.DataFrame({'Month': months, 'Value': values})
        figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        figure.update_layout(autosize=True, width=800, height=500)
        return figure
    except Exception as err:
        print(f"Error creating test visualization: {str(err)}")
        return None
def create_test_html_visualization():
    """Build the sample bar chart and return it as an embeddable HTML snippet."""
    try:
        sample = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25],
        })
        figure = px.bar(sample, x='Month', y='Value', title='Test Visualization')
        # Partial HTML (no <html> wrapper) so it can live inside a gr.HTML pane.
        return pio.to_html(figure, full_html=False)
    except Exception as err:
        print(f"Error creating test HTML visualization: {str(err)}")
        return None
def flush_databases():
    """Flush the SQLite analytics database and the ChromaDB vector store.

    Drops every user table from the SQLite database, resets (or recreates)
    the document assistant's vector store, and clears the global file
    context so stale table/file references cannot linger.

    Returns:
        str: A newline-joined status report, one line per subsystem.
    """
    global document_assistant
    global current_context
    result = []
    # --- Flush the SQLite database ---------------------------------------
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        # Enumerate all tables currently in the database.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]
        for table in tables:
            # Fix: skip SQLite-internal tables (e.g. sqlite_sequence) —
            # they cannot be dropped and would abort the whole flush.
            if table.startswith("sqlite_"):
                continue
            # Fix: quote (and escape) the identifier instead of raw
            # interpolation so unusual table names cannot break the SQL.
            quoted = table.replace('"', '""')
            cursor.execute(f'DROP TABLE IF EXISTS "{quoted}";')
        conn.commit()
        result.append("βœ… SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")
    finally:
        # Fix: the connection previously leaked when an error occurred mid-flush.
        if conn is not None:
            conn.close()
    # --- Flush ChromaDB by resetting the document assistant --------------
    try:
        if document_assistant.reset_database():
            result.append("βœ… ChromaDB cleared successfully")
        else:
            # Even if reset fails we can still reinitialize the assistant;
            # this workaround simply creates a fresh instance.
            document_assistant = DocumentAssistant()
            result.append("⚠️ ChromaDB reset partially completed - created new instance")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")
    # Reset the in-memory file context.
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return "\n".join(result)
# Monkey patch for backend.vector_db: if importing ChromaVectorDB fails with
# a NameError mentioning 'response', scrub the offending module attribute,
# reload the module, and retry the import.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    # NOTE(review): when the NameError does NOT mention 'response', it is
    # swallowed here and ChromaVectorDB stays undefined — confirm intended.
    if "response" in str(e):
        # If the error is about 'response' not being defined, fix the module
        import backend.vector_db
        # Remove the problematic attribute before reloading.
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')
        # Reload the module and re-import the class.
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
# Build a static visualization (as an HTML <img>) from a SQL query result.
def generate_visualization(result_df, query):
    """Generate a visualization based on the query and data.

    Args:
        result_df: DataFrame holding the data to plot.
        query: The user's natural-language request; keywords in it choose
            the chart type (pie/line/scatter/histogram/box/heatmap),
            defaulting to a bar chart.

    Returns:
        str | None: An HTML <img> tag embedding the chart as a base64 PNG,
        or None if anything goes wrong.
    """
    try:
        print("Visualization requested, attempting to create plot...")
        # Common figure size (a roughly square shape renders best in chat).
        fig_width = 900
        fig_height = 800
        viz_type = _detect_viz_type(query)
        print(f"Creating {viz_type} visualization...")
        numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
        # Fix: use an explicit sentinel instead of the fragile
        # "'fig' not in locals()" check the original relied on.
        fig = None
        if viz_type == 'pie' and len(result_df) <= 20:
            # Pie chart: first column = labels, second (if any) = values.
            labels = result_df.iloc[:, 0].tolist()
            values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)
            import plotly.graph_objects as go  # only needed for pie charts
            fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
            fig.update_layout(title_text='Pie Chart')
        elif viz_type == 'histogram' and len(numeric_cols) > 0:
            fig = px.histogram(result_df, x=numeric_cols[0])
            fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')
        elif viz_type == 'box' and len(numeric_cols) > 0:
            fig = px.box(result_df, y=numeric_cols[0])
            fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')
        elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
            # Correlation matrix over all numeric columns.
            corr_df = result_df[numeric_cols].corr()
            fig = px.imshow(corr_df, text_auto=True)
            fig.update_layout(title_text='Correlation Heatmap')
        elif viz_type == 'scatter' and len(numeric_cols) >= 2:
            fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
            fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')
        elif viz_type == 'line':
            x_col = result_df.columns[0]
            y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None
            if y_cols:
                fig = px.line(result_df, x=x_col, y=y_cols[0])
                fig.update_layout(
                    title_text=f'Line Chart of {y_cols[0]} over {x_col}',
                    xaxis=dict(tickangle=-45, tickmode='auto', nticks=20)
                )
            # else: fall through to the bar-chart default below.
        # Default / fallback: bar chart (also used when a branch above
        # could not build a figure for the given data).
        if fig is None:
            fig = _build_bar_chart(result_df, numeric_cols)
        # Common layout properties for every chart type.
        fig.update_layout(
            width=fig_width,
            height=fig_height,
            template="plotly_white",
            margin=dict(l=40, r=40, t=80, b=80, pad=4),  # balanced margins
            autosize=True,  # allow the plot to resize with the container
            plot_bgcolor='rgba(240,240,240,0.2)',  # light gray background
            paper_bgcolor='white'
        )
        print(f"Created figure with width={fig_width}, height={fig_height}")
        # Render to a static PNG (requires the kaleido engine).
        print("Converting figure to image...")
        img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=2)
        print("Image conversion successful")
        # Embed the PNG as a base64 data URI inside an <img> tag.
        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"
        print("HTML conversion successful")
        return f"<img src='{img_src}' width='100%' style='max-width:900px; height:800px; object-fit:contain; display:block; margin:0 auto;' />"
    except Exception as e:
        print(f"Error generating visualization: {str(e)}")
        traceback.print_exc()
        return None


def _detect_viz_type(query):
    """Map keywords in the query to a chart type (default: 'bar').

    Order matters: e.g. 'distribution' selects pie before histogram's
    'distribution of' can match, mirroring the original precedence.
    """
    q = query.lower()
    keyword_map = [
        ('pie', ['pie', 'distribution', 'proportion']),
        ('line', ['line', 'trend', 'time series']),
        ('scatter', ['scatter', 'relationship']),
        ('histogram', ['histogram', 'distribution of']),
        ('box', ['box', 'boxplot', 'outliers']),
        ('heatmap', ['heatmap', 'correlation']),
    ]
    for viz_type, keywords in keyword_map:
        if any(word in q for word in keywords):
            return viz_type
    return 'bar'


def _build_bar_chart(result_df, numeric_cols):
    """Build the default bar chart (horizontal when there are many rows)."""
    x_col = result_df.columns[0]
    y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None
    if len(result_df) > 10:
        # Many categories read better as a horizontal bar chart.
        if y_col:
            fig = px.bar(result_df, y=x_col, x=y_col, orientation='h',
                         title=f'Bar Chart of {y_col} by {x_col}')
        else:
            fig = px.bar(result_df, y=x_col, orientation='h',
                         title=f'Bar Chart of {x_col}')
    else:
        # Few categories: vertical bar chart.
        if y_col:
            fig = px.bar(result_df, x=x_col, y=y_col,
                         title=f'Bar Chart of {y_col} by {x_col}',
                         width=900, height=800)
        else:
            fig = px.bar(result_df, x=x_col,
                         title=f'Bar Chart of {x_col}',
                         width=900, height=800)
    # Improve readability: gap between bars, hide labels that don't fit.
    fig.update_layout(bargap=0.2, uniformtext_minsize=8, uniformtext_mode='hide')
    return fig
# Create Gradio interface
with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
    gr.Markdown("# πŸ€– LLM Powered Database Chatbot")
    gr.Markdown("Upload documents, ask questions, and get AI-powered responses!")

    with gr.Tab("Chat"):
        # Custom CSS so inline <img> visualizations render properly in chat.
        gr.HTML("""
<style>
.chatbot-container img {
max-width: 100%;
height: auto;
display: block;
margin: 10px 0;
}
</style>
""")
        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")
        # Information about capabilities shown under the chat window.
        gr.Markdown("""
### Capabilities:
- **Data Analysis**: Ask questions about your data and get detailed responses
- **Visualization**: Request and view graphs and charts of your data
- **Multiple File Types**: Upload PDFs, TXT, DOCX, CSV, and XLSX files for analysis
- **Natural Language Queries**: Ask questions in plain English about your documents
""")
        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                # Voice button disabled for this release.
                # voice_btn = gr.Button("🎀")
                pass  # keeps the column block syntactically valid
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")
        # Audio output disabled for this release.
        # audio_output = gr.Audio(label="Voice Response", type="filepath")
        # Voice input disabled for this release (kept as a string literal).
        """
        voice_input = gr.Audio(
        label="Voice Input",
        type="filepath",
        visible=False
        )
        """
        # Event handlers: both the Submit button and Enter in the textbox
        # route through process_text_query.
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )
        # Clearing the chat simply empties the chatbot component.
        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])
        # Voice button click handler — still in development, disabled.
        """
        voice_btn.click(
        lambda: gr.update(visible=True),
        None,
        voice_input
        )
        """
        # Voice input change handler — still in development, disabled.
        """
        voice_input.change(
        process_voice_input,
        inputs=[voice_input],
        outputs=[msg]
        )
        """
        # Text-to-speech button — still in development, disabled.
        """
        tts_btn = gr.Button("πŸ”Š Speak Response")
        tts_btn.click(
        text_to_speech_output,
        inputs=[chatbot],
        outputs=[audio_output]
        )
        """

    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        with gr.Row():
            upload_button = gr.Button("Process & Index Documents", scale=2)
            flush_db_btn_doc = gr.Button("πŸ—‘οΈ Flush All Databases", variant="stop", scale=1)
        upload_output = gr.Textbox(label="Upload Status")
        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )
        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )
        list_docs_button = gr.Button("List Indexed Documents")
        docs_output = gr.Textbox(label="Indexed Documents")
        list_docs_button.click(
            list_documents,
            inputs=[],
            outputs=[docs_output]
        )

    with gr.Tab("Settings"):
        with gr.Row():
            gr.Markdown("## Database Management")
            flush_db_btn = gr.Button("πŸ—‘οΈ Flush All Databases", variant="stop", scale=1)
        flush_result = gr.Textbox(label="Flush Result")
        flush_db_btn.click(
            flush_databases,
            inputs=[],
            outputs=[flush_result]
        )
        gr.Markdown("## System Settings")
        api_key = gr.Textbox(
            label="Groq API Key",
            placeholder="Enter your Groq API key",
            type="password",
            value=os.getenv("GROQ_API_KEY", "")
        )
        save_btn = gr.Button("Save Settings")

        def save_settings(key):
            """Persist the entered API key into the process environment."""
            try:
                os.environ["GROQ_API_KEY"] = key
                return "Settings saved!"
            except Exception as e:
                return f"Error saving settings: {str(e)}"

        save_btn.click(
            save_settings,
            inputs=[api_key],
            # NOTE(review): this Textbox is created inline and never laid
            # out, so the status message may not be visible — confirm.
            outputs=[gr.Textbox(label="Status")]
        )
        gr.Markdown("## Debugging")
        test_viz_btn = gr.Button("Test Visualization")
        test_viz_output = gr.HTML(label="Test Visualization")
        test_viz_btn.click(
            create_test_html_visualization,
            inputs=[],
            outputs=[test_viz_output]
        )

# Launch the app
if __name__ == "__main__":
    demo.launch()