import os
import sys
import tempfile
import time
import sqlite3
import re
import base64
import traceback
import importlib.util
from io import BytesIO

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.io as pio
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

# Load environment variables (must happen before reading GROQ_API_KEY below)
load_dotenv()

# Add parent directory to path to import backend modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant

# Initialize the document assistant
document_assistant = DocumentAssistant()

# Initialize the LLM using the llama3-8b-8192 model from Groq
llm = ChatGroq(
    model="llama3-8b-8192",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY"),
)

# Create the data directory once and derive every data path from it.
# (Fix: DB_PATH's directory was previously created before DATA_DIR was
# defined, running os.makedirs twice on the same path.)
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# Database path for CSV data
DB_PATH = os.path.join(DATA_DIR, "csv_data.db")

# Create chroma_db directory if it doesn't exist
CHROMA_DB_DIR = os.path.join(DATA_DIR, "chroma_db")
os.makedirs(CHROMA_DB_DIR, exist_ok=True)

# Set environment variables for ChromaDB
os.environ["CHROMA_DB_PATH"] = CHROMA_DB_DIR

# Current context to track what we're working with:
# file_type is "csv" or "pdf", table_name is always "data_tab" for CSVs.
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None,
}

# Define the prompt with examples for SQL query generation
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.

Important guidelines:
1. Use SQLite syntax (not PostgreSQL or MySQL)
2. For date functions, use strftime() instead of EXTRACT
   - Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
3. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
4. For percentiles, use window functions or approximate methods
5. Keep queries efficient and focused on answering the specific question
6. Always use 'data_tab' as the table name
7. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks

Question: {question}
""")

# Define the prompt for interpreting the SQL query result
interpret_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary. If relevant, give key statistics, trends, or patterns."),
        ("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}"),
    ]
)

# Unused prompt kept for reference (visualization SQL generation):
# visualization_prompt = ChatPromptTemplate.from_template("""
# You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
#
# Important guidelines for SQLite syntax:
# 1. Use strftime() for date functions:
#    - Year: strftime('%Y', date_column)
#    - Month: strftime('%m', date_column)
#    - Day: strftime('%d', date_column)
#    - Hour: strftime('%H', date_column)
#
# 2. For histograms and binning:
#    - Use: CAST((column / bin_size) AS INT) * bin_size
#    - Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
#
# 3. For box plots:
#    - SQLite doesn't support PERCENTILE_CONT or window functions
#    - Simply return the raw data column: SELECT column_name FROM data_tab
#    - The application will calculate quartiles and outliers
#
# 4. For heatmaps:
#    - Return raw data for correlation analysis
#    - Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
#
# 5.
Always use 'data_tab' as the table name # # 6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks # # Question: {question} # Visualization type: {viz_type} # """) # Add this helper function to clean SQL queries def clean_sql_query(query_text): """Clean SQL query text by removing markdown formatting and comments""" # Check if input is None or empty if not query_text: return "SELECT * FROM data_tab LIMIT 10;" # Remove markdown code blocks if "```" in query_text: # Extract content between code blocks pattern = r"```(?:sql)?(.*?)```" matches = re.findall(pattern, query_text, re.DOTALL) if matches: query_text = matches[0].strip() # Remove any "Here is the SQL query" text that might precede the query prefixes = [ "here is the sql query", "here is the sqlite query", "here is a query", "here's the sql query", "the sql query is", "sql query:" ] for prefix in prefixes: if query_text.lower().startswith(prefix): # Find the first occurrence of "SELECT", "WITH", etc. 
sql_keywords = ["select", "with", "create", "insert", "update", "delete"] positions = [query_text.lower().find(keyword) for keyword in sql_keywords] positions = [pos for pos in positions if pos != -1] if positions: start_pos = min(positions) query_text = query_text[start_pos:] # Remove SQL comments query_text = re.sub(r'--.*?(\n|$)', ' ', query_text) # Remove trailing semicolon if present query_text = query_text.strip().rstrip(';') # Ensure the query is not empty if not query_text.strip(): return "SELECT * FROM data_tab LIMIT 10;" return query_text def process_text_query(query, history): """Process a text query and update chat history""" if not query: return "", history # Add the user's query to history history.append([query, None]) start_time = time.time() # Define visualization keywords at the beginning viz_keywords = { 'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'], 'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'], 'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'], 'histogram': ['histogram', 'distribution of', 'frequency distribution'], 'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'], 'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'], 'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between'] } # Check if this is a visualization request is_visualization = any(word in query.lower() for word in ['plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me']) # Determine visualization type from query viz_type = None if is_visualization: for vtype, keywords in viz_keywords.items(): if any(keyword in query.lower() for keyword in keywords): viz_type = vtype break # Check if we're in CSV context or have documents loaded if current_context["file_type"] == "csv" and current_context["table_name"]: try: # Connect to the database conn = sqlite3.connect(DB_PATH) # Get 
column information for context cursor = conn.cursor() cursor.execute(f"PRAGMA table_info({current_context['table_name']});") columns = [info[1] for info in cursor.fetchall()] columns_str = ", ".join(columns) # Create question with context question_with_context = f"The table 'data_tab' has columns: {columns_str}. {query}" # Special handling for visualization types that need raw data if is_visualization and viz_type in ['box', 'heatmap']: # For box plots and heatmaps, we need raw data if viz_type == 'box': # For box plots, we need a single numeric column numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';" cursor = conn.cursor() cursor.execute(numeric_cols_query) numeric_cols = [row[0] for row in cursor.fetchall()] if numeric_cols: # Find the relevant numeric column based on the query target_col = None for col in numeric_cols: if col.lower() in query.lower(): target_col = col break # If no specific column is mentioned, use the first numeric column if not target_col and numeric_cols: target_col = numeric_cols[0] # Generate a simple query to get the raw data sql_query = f"SELECT {target_col} FROM data_tab WHERE {target_col} IS NOT NULL;" else: # No numeric columns found sql_query = "SELECT * FROM data_tab LIMIT 10;" elif viz_type == 'heatmap': # For heatmaps, we need multiple numeric columns numeric_cols_query = "SELECT name FROM pragma_table_info('data_tab') WHERE type LIKE '%INT%' OR type LIKE '%REAL%' OR type LIKE '%FLOA%' OR type LIKE '%NUM%';" cursor = conn.cursor() cursor.execute(numeric_cols_query) numeric_cols = [row[0] for row in cursor.fetchall()] if len(numeric_cols) >= 2: # Use all numeric columns (up to a reasonable limit) cols_to_use = numeric_cols[:10] # Limit to 10 columns for performance cols_str = ", ".join(cols_to_use) sql_query = f"SELECT {cols_str} FROM data_tab WHERE {numeric_cols[0]} IS NOT NULL LIMIT 1000;" else: sql_query = "SELECT * FROM 
data_tab LIMIT 10;" else: # For other queries, use the LLM to generate SQL sql_query = llm.invoke(query_prompt.format(question=question_with_context)).content sql_query = clean_sql_query(sql_query) # Execute the query result_df = pd.read_sql_query(sql_query, conn) # Close the connection conn.close() # Format the dataframe as a string table for display df_str = result_df.to_string() # Generate text response data_summary = result_df.to_string() analysis = llm.invoke(interpret_prompt.format( question=query, sql_query=sql_query, data_summary=data_summary )).content # Create a comprehensive response that includes: # 1. SQL Query # 2. Results as a table # 3. Analysis of the results comprehensive_response = f""" ### SQL Query: ```sql {sql_query} ``` ### Results: ``` {df_str} ``` ### Analysis: {analysis} """ # Generate visualization if requested if is_visualization: viz_html = generate_visualization(result_df, query) if viz_html: # Add the visualization to history history[-1][1] = comprehensive_response return viz_html, history # If no visualization or visualization failed, return text response history[-1][1] = comprehensive_response return comprehensive_response, history except Exception as e: error_msg = f"Error processing query: {str(e)}" history[-1][1] = error_msg return error_msg, history elif document_assistant.get_all_documents(): # Handle document queries try: response = document_assistant.process_query(query) history[-1][1] = response return response, history except Exception as e: error_msg = f"Error processing query: {str(e)}" history[-1][1] = error_msg return error_msg, history else: # Handle general queries with LLM when no documents are loaded try: # Create a general knowledge context prompt general_prompt = ChatPromptTemplate.from_messages([ ("system", "You are a helpful assistant that provides clear, informative responses. 
Use your knowledge to answer the user's question concisely."), ("human", "{question}") ]) # Get response from LLM response = llm.invoke(general_prompt.format(question=query)).content # Add the response to history history[-1][1] = response return response, history except Exception as e: error_msg = f"Error processing query: {str(e)}" history[-1][1] = error_msg return error_msg, history def process_file_upload(files): """Process uploaded files and index them""" if not files: return "No files uploaded" global current_context # Clear existing context current_context = { "file_type": None, "file_name": None, "table_name": None } file_info = [] for file in files: file_path = file.name file_name = os.path.basename(file_path) file_ext = os.path.splitext(file_name)[1].lower() if file_ext == '.csv': try: # Create table name from filename table_name = os.path.splitext(file_name)[0].replace(' ', '_').lower() # Load CSV into SQLite conn = sqlite3.connect(DB_PATH) # Configure SQLite for faster imports conn.execute("PRAGMA synchronous = OFF") conn.execute("PRAGMA journal_mode = MEMORY") # Read the CSV and load it into SQLite df = pd.read_csv(file_path) df.to_sql('data_tab', conn, if_exists='replace', index=False) # Update current context current_context = { "file_type": "csv", "file_name": file_name, "table_name": "data_tab" # Always use data_tab as the table name } # Get column info cursor = conn.cursor() cursor.execute("PRAGMA table_info(data_tab);") columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()] # Get row count cursor.execute("SELECT COUNT(*) FROM data_tab;") row_count = cursor.fetchone()[0] conn.close() file_info.append("✅ CSV File Successfully Loaded") file_info.append(f"📊 Table Name: data_tab") file_info.append(f"📄 Source File: {file_name}") file_info.append(f"📈 Total Rows: {row_count:,}") file_info.append(f"📋 Columns: {', '.join(columns)}") except Exception as e: file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}") else: # Process PDF or other 
document types try: result = document_assistant.upload_document(file_path) # Update current context current_context = { "file_type": "pdf", "file_name": file_name, "table_name": None } file_info.append("✅ Document Successfully Processed") file_info.append(f"📄 File: {file_name}") file_info.append(f"📚 Chunks: {result['chunks']}") file_info.append(result['message']) except Exception as e: file_info.append(f"❌ Error processing document {file_name}: {str(e)}") return "\n".join(file_info) # Function commented out as it's no longer used # def list_documents(): # """List all indexed documents""" # try: # docs = document_assistant.get_all_documents() # if not docs: # return "No documents indexed yet." # # result = "Indexed Documents:\n\n" # for doc in docs: # result += f"- {doc['filename']} ({doc['file_type']})\n" # # return result # except Exception as e: # return f"Error listing documents: {str(e)}" def clear_context(): """Clear the current context""" global current_context try: # Reset the context current_context = { "file_type": None, "file_name": None, "table_name": None } return [["Context cleared. 
You can now upload new documents or CSV files.", None]] except Exception as e: return [[f"Error clearing context: {str(e)}", None]] def flush_databases(): """Flush ChromaDB and SQLite databases""" global document_assistant global current_context result = [] # Flush SQLite database try: conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() # Get all tables cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = cursor.fetchall() # Drop all tables for table in tables: cursor.execute(f"DROP TABLE IF EXISTS {table[0]};") conn.commit() conn.close() result.append("✅ SQLite database cleared successfully") except Exception as e: result.append(f"❌ Error clearing SQLite database: {str(e)}") # Flush ChromaDB by resetting the document assistant try: success = document_assistant.reset_database() if success: result.append("✅ ChromaDB cleared successfully") else: # Even if reset fails, we can still reinitialize the document assistant # This is a workaround that creates a fresh instance document_assistant = DocumentAssistant() result.append("⚠️ ChromaDB reset partially completed - created new instance") except Exception as e: result.append(f"❌ Error clearing ChromaDB: {str(e)}") # Reset current context current_context = { "file_type": None, "file_name": None, "table_name": None } return "\n".join(result) # At the beginning of app.py, after the imports # Add this code to monkey patch the vector_db module try: from backend.vector_db import ChromaVectorDB except NameError as e: if "response" in str(e): # If the error is about 'response' not being defined, fix the module import backend.vector_db # Remove the problematic code if hasattr(backend.vector_db, 'response'): delattr(backend.vector_db, 'response') # Reload the module importlib.reload(backend.vector_db) from backend.vector_db import ChromaVectorDB # Add this function to app.py def generate_visualization(result_df, query): """Generate a visualization based on the query and data""" try: print("Visualization 
requested, attempting to create plot...") # Set common figure parameters fig_width = 1200 # Increased for better quality fig_height = 800 # Maintain aspect ratio # Determine visualization type from query viz_type = 'bar' # Default if any(word in query.lower() for word in ['pie', 'distribution', 'proportion']): viz_type = 'pie' elif any(word in query.lower() for word in ['line', 'trend', 'time series']): viz_type = 'line' elif any(word in query.lower() for word in ['scatter', 'relationship']): viz_type = 'scatter' elif any(word in query.lower() for word in ['histogram', 'distribution of']): viz_type = 'histogram' elif any(word in query.lower() for word in ['box', 'boxplot', 'outliers']): viz_type = 'box' elif any(word in query.lower() for word in ['heatmap', 'correlation']): viz_type = 'heatmap' print(f"Creating {viz_type} visualization...") # Find numeric columns numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist() # Create basic visualization based on type if viz_type == 'pie' and len(result_df) <= 20: # Simple pie chart labels = result_df.iloc[:, 0].tolist() values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df) import plotly.graph_objects as go fig = go.Figure(data=[go.Pie(labels=labels, values=values)]) fig.update_layout(title_text='Pie Chart') elif viz_type == 'histogram' and len(numeric_cols) > 0: # Simple histogram import plotly.express as px fig = px.histogram(result_df, x=numeric_cols[0]) fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}') elif viz_type == 'box' and len(numeric_cols) > 0: # Simple box plot import plotly.express as px fig = px.box(result_df, y=numeric_cols[0]) fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}') elif viz_type == 'heatmap' and len(numeric_cols) >= 2: # Simple heatmap import plotly.express as px # Create correlation matrix corr_df = result_df[numeric_cols].corr() fig = px.imshow(corr_df, text_auto=True) 
fig.update_layout(title_text='Correlation Heatmap') elif viz_type == 'scatter' and len(numeric_cols) >= 2: # Simple scatter plot import plotly.express as px fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1]) fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}') elif viz_type == 'line': # Simple line chart import plotly.express as px x_col = result_df.columns[0] y_cols = numeric_cols if numeric_cols else [result_df.columns[1]] if len(result_df.columns) > 1 else None if y_cols: fig = px.line(result_df, x=x_col, y=y_cols[0]) fig.update_layout( title_text=f'Line Chart of {y_cols[0]} over {x_col}', xaxis=dict( tickangle=-45, tickmode='auto', nticks=20 ) ) else: # Fallback to bar chart viz_type = 'bar' if viz_type == 'bar' or 'fig' not in locals(): # Simple bar chart (default) import plotly.express as px x_col = result_df.columns[0] y_col = numeric_cols[0] if numeric_cols else result_df.columns[1] if len(result_df.columns) > 1 else None # Check if we have many categories (more than 10) if len(result_df) > 10: # Use horizontal bar chart for many categories if y_col: fig = px.bar( result_df, y=x_col, # Swap x and y for horizontal orientation x=y_col, orientation='h', # Horizontal orientation title=f'Bar Chart of {y_col} by {x_col}' ) else: fig = px.bar( result_df, y=x_col, # Swap x and y for horizontal orientation orientation='h', # Horizontal orientation title=f'Bar Chart of {x_col}' ) else: # Use vertical bar chart for fewer categories if y_col: fig = px.bar( result_df, x=x_col, y=y_col, title=f'Bar Chart of {y_col} by {x_col}' ) else: fig = px.bar( result_df, x=x_col, title=f'Bar Chart of {x_col}' ) # Improve bar chart layout fig.update_layout( bargap=0.2, # Increase gap between bars uniformtext_minsize=8, # Minimum text size uniformtext_mode='hide' # Hide text if it doesn't fit ) # Set common layout properties fig.update_layout( width=fig_width, height=fig_height, template="plotly_white", margin=dict(l=40, r=40, t=80, 
b=80, pad=4), # Balanced margins autosize=True, # Allow the plot to resize with the container plot_bgcolor='rgba(240,240,240,0.2)', # Light gray background paper_bgcolor='white', font=dict(size=12) # Increase font size ) # Add hover information fig.update_traces( hovertemplate="%{x}: %{y}", hoverlabel=dict( bgcolor="white", font_size=12, font_family="Arial" ) ) print(f"Created figure with width={fig_width}, height={fig_height}") # Convert to image with higher quality print("Converting figure to image...") img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=3) # Increased scale for better quality print("Image conversion successful") # Encode as base64 import base64 encoded = base64.b64encode(img_bytes).decode("ascii") img_src = f"data:image/png;base64,{encoded}" print("HTML conversion successful") # Return the HTML img tag with responsive sizing return f"""
Data Visualization
""" except Exception as e: import traceback print(f"Error generating visualization: {str(e)}") traceback.print_exc() return None # Create Gradio interface with gr.Blocks(title="LLM Powered Database Chatbot") as demo: gr.Markdown("# 🤖 LLM Powered Database Chatbot") gr.Markdown("Upload documents, ask questions, and get AI-powered responses!") # Add a global variable to store the current visualization current_visualization = gr.State(None) with gr.Tab("Chat & Visualizations"): # Use a custom CSS to ensure images are displayed properly gr.HTML(""" """) with gr.Row(): with gr.Column(scale=1): chatbot = gr.Chatbot(height=500, elem_classes="chatbot-container") with gr.Row(): with gr.Column(scale=8): msg = gr.Textbox( placeholder="Ask a question about your documents...", show_label=False ) with gr.Column(scale=1): pass with gr.Row(): submit_btn = gr.Button("Submit") clear_btn = gr.Button("Clear") clear_context_btn = gr.Button("Clear Context") with gr.Column(scale=1): visualization_output = gr.HTML( label="Visualization", elem_classes="visualization-container" ) with gr.Row(): clear_viz_btn = gr.Button("🗑️ Clear Visualization") download_btn = gr.Button("📥 Download Visualization") save_status = gr.Textbox(label="Save Status", visible=False) download_img = gr.Image(visible=False, type="pil", label="Download Image") # Add information about capabilities gr.Markdown(""" ### Capabilities: - **Data Analysis**: Ask questions about your data and get detailed responses - **Visualization**: Request and view graphs and charts of your data - **Multiple File Types**: Upload PDFs, TXT, DOCX, CSV, and XLSX files for analysis - **Natural Language Queries**: Ask questions in plain English about your documents """) def clear_visualization(): return "", "" def download_visualization(viz_html): if not viz_html: return None try: # Extract the base64 image data from the HTML img_data_match = re.search(r'src=\'data:image/png;base64,([^\']+)\'', viz_html) if img_data_match: # Get the base64 data 
img_data = img_data_match.group(1) # Convert base64 to image import base64 from io import BytesIO from PIL import Image image_data = base64.b64decode(img_data) image = Image.open(BytesIO(image_data)) return image, gr.update(visible=True) else: return None, gr.update(visible=False) except Exception as e: print(f"Error downloading visualization: {str(e)}") return None, gr.update(visible=False) clear_viz_btn.click( clear_visualization, outputs=[visualization_output, current_visualization] ) download_btn.click( download_visualization, inputs=[current_visualization], outputs=[download_img, download_img] ) # Update the process_text_query function to handle visualizations def process_text_query_with_visualization(query, history, current_viz): """Process a text query and update chat history and visualization""" if not query: return "", history, current_viz # Process the query and get the response response, new_history = process_text_query(query, history) # Check if the response contains a visualization if "