import os
import sys
import gradio as gr
from dotenv import load_dotenv
import tempfile
import pandas as pd
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
import plotly.express as px
import time
import plotly.io as pio
import traceback
import base64
from io import BytesIO
import speech_recognition as sr
from gtts import gTTS
import re
import importlib.util

# Load environment variables
load_dotenv()

# Add parent directory to path to import backend modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.main import DocumentAssistant

# Initialize the document assistant
document_assistant = DocumentAssistant()

# Initialize the LLM using the llama3-8b-8192 model from Groq
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY")
)

# Database path for CSV data
DB_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "csv_data.db")
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)

# Current context to track what we're working with
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None
}

# Add a global variable to store the current plot
current_plot = None

# Define the prompt with examples for SQL query generation
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
""")

# Define the prompt for interpreting the SQL query result
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
    ]
)

# Add this after the query_prompt definition
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.

Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
   - Year: strftime('%Y', date_column)
   - Month: strftime('%m', date_column)
   - Day: strftime('%d', date_column)
   - Hour: strftime('%H', date_column)
2. For histograms and binning:
   - Use: CAST((column / bin_size) AS INT) * bin_size
   - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
3. For box plots:
   - SQLite doesn't support PERCENTILE_CONT or window functions
   - Simply return the raw data column: SELECT column_name FROM data_tab
   - The application will calculate quartiles and outliers
4. For heatmaps:
   - Return raw data for correlation analysis
   - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
5. Always use 'data_tab' as the table name
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
Visualization type: {viz_type}
""")


# Add this helper function to clean SQL queries
def clean_sql_query(query_text):
    """Clean SQL query text by removing markdown formatting and comments"""
    # Check if input is None or empty
    if not query_text:
        return "SELECT * FROM data_tab LIMIT 10;"

    # Remove markdown code blocks
    if "```" in query_text:
        # Extract content between code blocks
        pattern = r"```(?:sql)?(.*?)```"
        matches = re.findall(pattern, query_text, re.DOTALL)
        if matches:
            query_text = matches[0].strip()

    # Remove any "Here is the SQL query" text that might precede the query
    prefixes = [
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:"
    ]
    for prefix in prefixes:
        if query_text.lower().startswith(prefix):
            # Find the first occurrence of "SELECT", "WITH", etc.
            sql_keywords = ["select", "with", "create", "insert", "update", "delete"]
            positions = [query_text.lower().find(keyword) for keyword in sql_keywords]
            positions = [pos for pos in positions if pos != -1]
            if positions:
                start_pos = min(positions)
                query_text = query_text[start_pos:]

    # Remove SQL comments
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)

    # Remove trailing semicolon if present
    query_text = query_text.strip().rstrip(';')

    # Ensure the query is not empty
    if not query_text.strip():
        return "SELECT * FROM data_tab LIMIT 10;"

    return query_text


def process_text_query(query, history):
    """Process a text query and update chat history"""
    if not query:
        return "", history

    # Add the user's query to history
    history.append({"role": "user", "content": query})

    start_time = time.time()

    # Define visualization keywords at the beginning
    viz_keywords = {
        'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
        'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
        'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
        'histogram': ['histogram', 'distribution of', 'frequency distribution'],
        'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
        'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
        'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between']
    }

    # Check if this is a visualization request
    is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me'])

    # Determine visualization type from query
    viz_type = None
    if is_visualization:
        for vtype, keywords in viz_keywords.items():
            if any(keyword in query.lower() for keyword in keywords):
                viz_type = vtype
                break

    # Check if we're in CSV context
    if current_context["file_type"] == "csv" and current_context["table_name"]:
        try:
            # Connect to the database
            conn = sqlite3.connect(DB_PATH)

            # Get column information for context
            cursor = conn.cursor()
            cursor.execute(f"PRAGMA table_info({current_context['table_name']});")
            columns = [info[1] for info in cursor.fetchall()]
            columns_str = ", ".join(columns)

            # Numeric columns of the table; populated by the raw-data branches below
            # (and, as a fallback, from the query result before plotting)
            numeric_cols = []

            # Create question with context
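            # For illustration only (hypothetical columns), the grounded question the
            # model receives looks like:
            #   "The table 'data_tab' has columns: trip_distance, fare_amount, pickup_date.
            #    What is the average fare per trip?"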
{query}" # Special handling for visualization types that need raw data if is_visualization and viz_type in ['box', 'heatmap']: # For box plots and heatmaps, we need raw data if viz_type == 'box': # For box plots, we need a single numeric column numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';" cursor = conn.cursor() cursor.execute(numeric_cols_query) numeric_cols = [row[0] for row in cursor.fetchall()] if numeric_cols: # Find the relevant numeric column based on the query target_col = None for col in numeric_cols: if col.lower() in query.lower(): target_col = col break # If no specific column is mentioned, use the first numeric column if not target_col and numeric_cols: target_col = numeric_cols[0] # Generate a simple query to get the raw data sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;" else: # No numeric columns found sql_query = "SELECT * FROM data_tab LIMIT 10;" elif viz_type == 'heatmap': # For heatmaps, we need multiple numeric columns numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';" cursor = conn.cursor() cursor.execute(numeric_cols_query) numeric_cols = [row[0] for row in cursor.fetchall()] if len(numeric_cols) >= 2: # Use all numeric columns (up to a reasonable limit) cols_to_use = numeric_cols[:10] # Limit to 10 columns for performance cols_str = ", ".join(cols_to_use) sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;" else: # Not enough numeric columns sql_query = "SELECT * FROM data_tab LIMIT 10;" else: # Generate SQL query using LLM ai_msg = query_prompt | llm raw_sql_query = ai_msg.invoke({"question": question_with_context}).content.strip() # Clean the SQL query sql_query = clean_sql_query(raw_sql_query) print(f"Generated SQL Query: {sql_query}") try: # Execute the query result_df = pd.read_sql_query(sql_query, conn) # Generate data summary if not result_df.empty: data_summary = result_df.describe(include='all').to_string() # For small result sets, include the actual data if len(result_df) <= 10: data_summary += f"\n\nFull Results:\n{result_df.to_string()}" else: data_summary += f"\n\nFirst 5 rows:\n{result_df.head(5).to_string()}" else: data_summary = "No relevant data found." 
                # Generate interpretation
                answer_chain = interpret_prompt | llm
                interpretation = answer_chain.invoke({
                    "question": query,
                    "sql_query": sql_query,
                    "data_summary": data_summary
                }).content.strip()

                # Create the response
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n"

                if not result_df.empty:
                    if len(result_df) > 10:
                        response += f"**Results (first 5 of {len(result_df)} rows):**\n```\n{result_df.head(5).to_string()}\n```\n\n"
                    else:
                        response += f"**Results:**\n```\n{result_df.to_string()}\n```\n\n"
                else:
                    response += "**No results found.**\n\n"

                response += f"**Analysis:**\n{interpretation}"

                # Add visualization if requested
                if is_visualization and not result_df.empty:
                    try:
                        print("Visualization requested, attempting to create plot...")

                        # Fall back to the numeric columns of the query result when the
                        # raw-data branches above did not populate numeric_cols
                        if not numeric_cols:
                            numeric_cols = result_df.select_dtypes(include='number').columns.tolist()

                        # Set common figure parameters
                        fig_width = 1000
                        fig_height = 700

                        # Create the appropriate visualization based on type
                        if viz_type == 'pie' and len(result_df) <= 20:
                            # For pie charts, we need a category column and a value column
                            category_col = result_df.columns[0]
                            value_col = numeric_cols[0] if numeric_cols else result_df.columns[1]

                            # Handle case where all columns are numeric
                            if len(numeric_cols) == len(result_df.columns):
                                category_col = result_df.index.name or 'index'
                                result_df = result_df.reset_index()

                            fig = px.pie(
                                result_df,
                                names=category_col,
                                values=value_col,
                                title=f"Distribution of {value_col} by {category_col}",
                                hole=0.3,  # Donut chart for better readability
                                color_discrete_sequence=px.colors.qualitative.Pastel
                            )

                        elif viz_type == 'histogram' and len(result_df.columns) > 0:
                            # For histograms, we need at least one column
                            # Find the best column for histogram (prefer numeric)
                            if numeric_cols:
                                x_col = numeric_cols[0]
                            else:
                                x_col = result_df.columns[0]

                            # Check if data is already binned
                            if len(result_df) <= 30 and ('bin' in result_df.columns or 'range' in result_df.columns):
                                # Data is pre-binned, use a bar chart
                                bin_col = 'bin' if 'bin' in result_df.columns else 'range'
                                count_col = 'count' if 'count' in result_df.columns else numeric_cols[0] if numeric_cols else result_df.columns[1]

                                fig = px.bar(
                                    result_df,
                                    x=bin_col,
                                    y=count_col,
                                    title=f"Histogram of {x_col}",
                                    labels={bin_col: x_col, count_col: 'Frequency'},
                                    color_discrete_sequence=['#636EFA']
                                )
                            else:
                                # Create a proper histogram from raw data
                                fig = px.histogram(
                                    result_df,
                                    x=x_col,
                                    title=f"Distribution of {x_col}",
                                    nbins=20,
                                    marginal="box",  # Add a box plot on the margin
                                    color_discrete_sequence=['#636EFA'],
                                    opacity=0.8
                                )
                                # Improve histogram layout
                                fig.update_layout(
                                    bargap=0.1,  # Gap between bars
                                    xaxis_title=x_col,
                                    yaxis_title='Frequency',
                                    showlegend=True
                                )

                        elif viz_type == 'box' and numeric_cols:
                            # For box plots, we need to handle the data differently
                            # SQLite doesn't support window functions for percentiles
                            # So we'll calculate the box plot statistics in Python

                            # Get the numeric column to plot
                            x_col = numeric_cols[0]

                            # Create a box plot using plotly express
                            fig = px.box(
                                result_df,
                                y=x_col,
                                title=f"Box Plot of {x_col}",
                                points="outliers",  # Only show outlier points
                                color_discrete_sequence=['#636EFA']
                            )

                            # Add a strip plot (individual points) on the side for better visualization
                            fig.add_trace(
                                px.strip(result_df, y=x_col, color_discrete_sequence=['#FECB52']).data[0]
                            )

                        elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
                            # For heatmaps, we need at least 2 numeric columns
                            # If we have many numeric columns, create a correlation matrix
                            if len(numeric_cols) >= 3:
                                # Create a correlation matrix
                                # First, drop any rows with NaN values in numeric columns
                                clean_df = result_df[numeric_cols].dropna()
                                if len(clean_df) > 1:  # Need at least 2 rows for correlation
                                    corr_df = clean_df.corr()

                                    # Round to 2 decimal places for display
                                    corr_df = corr_df.round(2)

                                    fig = px.imshow(
                                        corr_df,
                                        title="Correlation Heatmap",
                                        color_continuous_scale='RdBu_r',
                                        text_auto=True,  # Show correlation values
                                        aspect="auto",
                                        zmin=-1, zmax=1  # Set limits for correlation values
                                    )

                                    # Improve heatmap layout
                                    fig.update_layout(
                                        xaxis_title="Features",
                                        yaxis_title="Features",
                                        coloraxis_colorbar=dict(
                                            title="Correlation",
                                            thicknessmode="pixels", thickness=20,
                                            lenmode="pixels", len=300,
                                            yanchor="top", y=1,
                                            ticks="outside"
                                        )
                                    )
                                else:
                                    # Not enough data for correlation
                                    fig = px.bar(
                                        pd.DataFrame({'Message': ['Not enough data for heatmap']}),
                                        title="Cannot create heatmap - insufficient data"
                                    )
                            else:
                                # If we only have 2 numeric columns, create a 2D histogram
                                x_col = numeric_cols[0]
                                y_col = numeric_cols[1]

                                # Create a 2D histogram (heatmap)
                                fig = px.density_heatmap(
                                    result_df,
                                    x=x_col,
                                    y=y_col,
                                    title=f"Density Heatmap of {x_col} vs {y_col}",
                                    color_continuous_scale='Viridis',
                                    nbinsx=20, nbinsy=20,
                                    marginal_x="histogram",  # Add histograms on the margins
                                    marginal_y="histogram"
                                )

                                # Improve heatmap layout
                                fig.update_layout(
                                    xaxis_title=x_col,
                                    yaxis_title=y_col,
                                    coloraxis_colorbar=dict(
                                        title="Count",
                                        thicknessmode="pixels", thickness=20,
                                        lenmode="pixels", len=300,
                                        yanchor="top", y=1,
                                        ticks="outside"
                                    )
                                )

                        elif viz_type == 'scatter' and len(numeric_cols) >= 2:
                            # For scatter plots, we need at least 2 numeric columns
                            x_col = numeric_cols[0]
                            y_col = numeric_cols[1]

                            # Add a third dimension (size) if available
                            size_col = numeric_cols[2] if len(numeric_cols) > 2 else None

                            # Add a color dimension if available
                            if len(result_df.columns) > len(numeric_cols):
                                # Find a categorical column for color
                                categorical_cols = [col for col in result_df.columns if col not in numeric_cols]
                                color_col = categorical_cols[0] if categorical_cols else None
                            else:
                                color_col = None

                            # Create scatter plot with enhanced features
                            fig = px.scatter(
                                result_df,
                                x=x_col,
                                y=y_col,
                                size=size_col,
                                color=color_col,  # Add color dimension if available
                                title=f"Relationship between {x_col} and {y_col}",
                                opacity=0.7,
                                size_max=15,  # Maximum marker size
                                color_discrete_sequence=px.colors.qualitative.Plotly
                            )

                            # Add a trend line
                            if pd.api.types.is_numeric_dtype(result_df[x_col]) and pd.api.types.is_numeric_dtype(result_df[y_col]):
                                fig.update_layout(
                                    shapes=[
                                        dict(
                                            type='line',
                                            xref='x', yref='y',
                                            x0=result_df[x_col].min(), y0=result_df[y_col].min(),
                                            x1=result_df[x_col].max(), y1=result_df[y_col].max(),
                                            line=dict(color='red', width=2, dash='dash')
                                        )
                                    ]
                                )

                            # Improve scatter plot layout
                            fig.update_layout(
                                xaxis_title=x_col,
                                yaxis_title=y_col,
                                showlegend=True,
                                legend=dict(
                                    title=color_col if color_col else "",
                                    orientation="h",
                                    yanchor="bottom", y=1.02,
                                    xanchor="right", x=1
                                )
                            )

                        elif viz_type == 'line':
                            # For line charts, determine the x-axis (preferably a date/time column)
                            time_cols = [col for col in result_df.columns if any(time_word in col.lower() for time_word in ['date', 'time', 'month', 'year', 'day'])]

                            if time_cols:
                                x_col = time_cols[0]
                            else:
                                x_col = result_df.columns[0]

                            # Determine y-axis columns (numeric columns)
                            y_cols = numeric_cols[:3]  # Use up to 3 numeric columns
                            if not y_cols and len(result_df.columns) > 1:
                                # If no numeric columns, use the second column
                                y_cols = [result_df.columns[1]]

                            fig = px.line(
                                result_df,
                                x=x_col,
                                y=y_cols,
                                title="Time Series Analysis",
                                markers=True,  # Add markers at each data point
                                color_discrete_sequence=px.colors.qualitative.Plotly
                            )

                            # Add range slider for time series
                            fig.update_layout(
                                xaxis=dict(
                                    rangeslider=dict(visible=True),
                                    type='category' if not pd.api.types.is_datetime64_any_dtype(result_df[x_col]) else '-'
                                )
                            )

                        else:  # Default to bar chart
                            # For bar charts, use the first column as x and numeric columns as y
                            x_col = result_df.columns[0]

                            # Determine y-axis columns (numeric columns)
                            if numeric_cols and x_col not in numeric_cols:
                                y_cols = numeric_cols[:3]  # Use up to 3 numeric columns
                            elif len(result_df.columns) > 1:
                                y_cols = [result_df.columns[1]]
                            else:
                                y_cols = ['value']
                                result_df['value'] = 1  # Default value if no suitable column

                            fig = px.bar(
                                result_df,
                                x=x_col,
                                y=y_cols[0],  # Use only the first y column for bar charts
                                title="Data Visualization",
                                color_discrete_sequence=['#636EFA']
                            )

                        # Improve figure layout for all chart types
                        fig.update_layout(
                            autosize=True,
                            width=fig_width,
                            height=fig_height,
                            margin=dict(l=50, r=50, b=100, t=100, pad=4),
                            template="plotly_white",
                            font=dict(size=14),
                            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
                            plot_bgcolor='rgba(240,240,240,0.2)',  # Light gray background
                            paper_bgcolor='white'
                        )

                        # Convert the figure to an image and encode it as base64
                        img_bytes = fig.to_image(format="png", width=fig_width, height=fig_height, scale=2)
                        encoded = base64.b64encode(img_bytes).decode("ascii")
                        img_src = f"data:image/png;base64,{encoded}"

                        # Add the image directly to the response with increased size
                        response += f'\n\n<img src="{img_src}" alt="{viz_type} visualization" width="{fig_width}">\n\n'

                        # Add note about visualization
                        response += f"\n\n**A {viz_type} visualization has been generated and is displayed above.**"

                    except Exception as viz_error:
                        print(f"Visualization error: {str(viz_error)}")
                        traceback.print_exc()

            except Exception as e:
                response = f"**SQL Query:**\n```sql\n{sql_query}\n```\n\n**Error executing query:** {str(e)}"

            conn.close()

        except Exception as e:
            response = f"Error processing query: {str(e)}"
    else:
        # For non-CSV queries, use the document assistant
        try:
            response = document_assistant.process_query(query)
        except Exception as e:
            response = f"Error processing document query: {str(e)}"

    # Calculate processing time
    processing_time = time.time() - start_time
    response += f"\n\n(Query processed in {processing_time:.2f} seconds)"

    # Add the response to history
    history.append({"role": "assistant", "content": response})

    return "", history


def process_file_upload(files):
    """Process uploaded files and index them"""
    if not files:
        return "No files uploaded"

    global current_context

    # Clear existing context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []

    for file in files:
        file_path = file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                # Create table name from filename
                table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower()

                # Load CSV into SQLite
                conn = sqlite3.connect(DB_PATH)

                # Configure SQLite for faster imports
                conn.execute("PRAGMA synchronous = OFF")
                conn.execute("PRAGMA journal_mode = MEMORY")

                # Read the CSV and load it into SQLite
                df = pd.read_csv(file_path)
                df.to_sql('data_tab', conn, if_exists='replace', index=False)

                # Update current context
                current_context = {
                    "file_type": "csv",
                    "file_name": file_name,
                    "table_name": "data_tab"  # Always use data_tab as the table name
                }

                # Get column info
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(data_tab);")
                columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]

                # Get row count
                cursor.execute("SELECT COUNT(*) FROM data_tab;")
                row_count = cursor.fetchone()[0]

                conn.close()
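                # Note: to_sql(..., if_exists='replace') above overwrites 'data_tab' on
                # every upload, so only the most recently loaded CSV remains queryable.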
file_info.append("✅ CSV File Successfully Loaded") file_info.append(f"📊 Table Name: data_tab") file_info.append(f"📄 Source File: {file_name}") file_info.append(f"📈 Total Rows: {row_count:,}") file_info.append(f"📋 Columns: {', '.join(columns)}") except Exception as e: file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}") else: # Process PDF or other document types try: result = document_assistant.upload_document(file_path) # Update current context current_context = { "file_type": "pdf", "file_name": file_name, "table_name": None } file_info.append("✅ Document Successfully Processed") file_info.append(f"📄 File: {file_name}") file_info.append(f"📚 Chunks: {result['chunks']}") file_info.append(result['message']) except Exception as e: file_info.append(f"❌ Error processing document {file_name}: {str(e)}") return "\n".join(file_info) def list_documents(): """List all indexed documents""" try: docs = document_assistant.get_all_documents() if not docs: return "No documents indexed yet." result = "Indexed Documents:\n\n" for doc in docs: result += f"- {doc['filename']} ({doc['file_type']})\n" return result except Exception as e: return f"Error listing documents: {str(e)}" def clear_context(): """Clear the current context""" global current_context try: # Reset the context current_context = { "file_type": None, "file_name": None, "table_name": None } return [{"role": "assistant", "content": "Context cleared. You can now upload new documents or CSV files."}] except Exception as e: return [{"role": "assistant", "content": f"Error clearing context: {str(e)}"}] def process_voice_input(audio_path): """Process voice input and return transcribed text""" if audio_path is None: return "No audio recorded" try: # Initialize recognizer r = sr.Recognizer() # Load the audio file with sr.AudioFile(audio_path) as source: # Read the audio data audio_data = r.record(source) # Recognize speech using Google Speech Recognition text = r.recognize_google(audio_data) return text except sr.UnknownValueError: return "Could not understand audio" except sr.RequestError as e: return f"Error with speech recognition service: {e}" except Exception as e: return f"Error processing audio: {str(e)}" def text_to_speech_output(text): """Convert text to speech""" if not text or len(text) == 0: return None # Extract the last assistant message last_message = None for msg in reversed(text): if msg["role"] == "assistant": last_message = msg["content"] break if not last_message: return None try: # Clean the text (remove markdown and HTML) clean_text = re.sub(r'<.*?>', '', last_message) # Remove HTML tags clean_text = re.sub(r'\*\*(.*?)\*\*', r'\1', clean_text) # Remove bold markdown clean_text = re.sub(r'\n\n', ' ', clean_text) # Replace double newlines with space clean_text = re.sub(r'```.*?```', 'Code block removed for speech.', clean_text, flags=re.DOTALL) # Replace code blocks # Create a temporary file temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") temp_file.close() # Generate speech tts = gTTS(text=clean_text, lang='en', slow=False) tts.save(temp_file.name) return temp_file.name except Exception as e: print(f"Error generating speech: {str(e)}") return None def create_test_visualization(): """Create a test visualization to verify plotting works""" try: # Create sample data data = pd.DataFrame({ 'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'], 'Value': [10, 15, 13, 17, 20, 25] }) # Create a simple bar chart fig = px.bar(data, x='Month', y='Value', title='Test Visualization') # Configure the figure 
        fig.update_layout(
            autosize=True,
            width=800,
            height=500
        )

        return fig
    except Exception as e:
        print(f"Error creating test visualization: {str(e)}")
        return None


def create_test_html_visualization():
    """Create a test HTML visualization to verify plotting works"""
    try:
        # Create sample data
        data = pd.DataFrame({
            'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun'],
            'Value': [10, 15, 13, 17, 20, 25]
        })

        # Create a simple bar chart
        fig = px.bar(data, x='Month', y='Value', title='Test Visualization')

        # Convert to HTML
        html = pio.to_html(fig, full_html=False)
        return html
    except Exception as e:
        print(f"Error creating test HTML visualization: {str(e)}")
        return None


def flush_databases():
    """Flush ChromaDB and SQLite databases"""
    result = []

    # Flush SQLite database
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()

        # Get all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        # Drop all tables
        for table in tables:
            cursor.execute(f"DROP TABLE IF EXISTS {table[0]};")

        conn.commit()
        conn.close()
        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")

    # Flush ChromaDB by resetting the document assistant
    try:
        success = document_assistant.reset_database()
        if success:
            result.append("✅ ChromaDB cleared successfully")
        else:
            result.append("⚠️ ChromaDB reset may not have been complete")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    # Reset current context
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    return "\n".join(result)


# At the beginning of app.py, after the imports
# Add this code to monkey patch the vector_db module
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        # If the error is about 'response' not being defined, fix the module
        import backend.vector_db

        # Remove the problematic code
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')

        # Reload the module
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB


# Create Gradio interface
with gr.Blocks(title="AI Document Analysis & Voice Assistant") as demo:
    gr.Markdown("# 🤖 AI Document Analysis & Voice Assistant")
    gr.Markdown("Upload documents, ask questions, and get voice responses!")

    with gr.Tab("Chat"):
        # Use a custom CSS to ensure images are displayed properly
        gr.HTML(""" """)

        chatbot = gr.Chatbot(height=500, type="messages", elem_classes="chatbot-container")

        with gr.Row():
            with gr.Column(scale=8):
                msg = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False
                )
            with gr.Column(scale=1):
                voice_btn = gr.Button("🎤")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
            clear_context_btn = gr.Button("Clear Context")

        audio_output = gr.Audio(label="Voice Response", type="filepath")

        # Voice input
        voice_input = gr.Audio(
            label="Voice Input",
            type="filepath",
            visible=False
        )

        # Event handlers
        submit_btn.click(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        msg.submit(
            process_text_query,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot]
        )

        clear_btn.click(lambda: None, None, [chatbot], queue=False)
        clear_context_btn.click(clear_context, inputs=[], outputs=[chatbot])

        voice_btn.click(
            lambda: gr.update(visible=True),
            None,
            voice_input
        )

        voice_input.change(
            process_voice_input,
            inputs=[voice_input],
            outputs=[msg]
        )

        # Add TTS functionality
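        # The button below passes the full chat history to text_to_speech_output, which
        # strips markdown/HTML from the latest assistant reply and returns the path of a
        # gTTS-generated MP3 for the Audio component above.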
Response") tts_btn.click( text_to_speech_output, inputs=[chatbot], outputs=[audio_output] ) with gr.Tab("Document Upload"): file_upload = gr.File( label="Upload Documents", file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"], file_count="multiple" ) with gr.Row(): upload_button = gr.Button("Process & Index Documents", scale=2) flush_db_btn_doc = gr.Button("🗑️ Flush All Databases", variant="stop", scale=1) upload_output = gr.Textbox(label="Upload Status") upload_button.click( process_file_upload, inputs=[file_upload], outputs=[upload_output] ) flush_db_btn_doc.click( flush_databases, inputs=[], outputs=[upload_output] ) list_docs_button = gr.Button("List Indexed Documents") docs_output = gr.Textbox(label="Indexed Documents") list_docs_button.click( list_documents, inputs=[], outputs=[docs_output] ) with gr.Tab("Settings"): with gr.Row(): gr.Markdown("## Database Management") flush_db_btn = gr.Button("🗑️ Flush All Databases", variant="stop", scale=1) flush_result = gr.Textbox(label="Flush Result") flush_db_btn.click( flush_databases, inputs=[], outputs=[flush_result] ) gr.Markdown("## System Settings") api_key = gr.Textbox( label="Groq API Key", placeholder="Enter your Groq API key", type="password", value=os.getenv("GROQ_API_KEY", "") ) save_btn = gr.Button("Save Settings") def save_settings(key): try: os.environ["GROQ_API_KEY"] = key return "Settings saved!" except Exception as e: return f"Error saving settings: {str(e)}" save_btn.click( save_settings, inputs=[api_key], outputs=[gr.Textbox(label="Status")] ) gr.Markdown("## Debugging") test_viz_btn = gr.Button("Test Visualization") test_viz_output = gr.HTML(label="Test Visualization") test_viz_btn.click( create_test_html_visualization, inputs=[], outputs=[test_viz_output] ) # Launch the app if __name__ == "__main__": demo.launch()