SVashishta1
Fix: Add direct handling for tip queries and make schema instructions more explicit
2f13356
| import os | |
| import sys | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| import tempfile | |
| import pandas as pd | |
| import sqlite3 | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_groq import ChatGroq | |
| import plotly.express as px | |
| import time | |
| import plotly.io as pio | |
| import traceback | |
| import base64 | |
| from io import BytesIO | |
| import re | |
| import importlib.util | |
# Load variables from a .env file into the process environment; GROQ_API_KEY is
# read from the environment when the LLM client is created below.
load_dotenv()
# Put the project root (the parent of this file's directory) on sys.path so the
# `backend` package can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.main import DocumentAssistant
# Shared assistant that indexes non-CSV uploads and answers document questions;
# flush_databases may replace this instance.
# NOTE(review): CHROMA_DB_PATH is exported further down in this file, after this
# instance is constructed — if DocumentAssistant reads that variable at
# construction time it will not see the value; confirm.
document_assistant = DocumentAssistant()
# Shared Groq chat model used for SQL generation, result interpretation and
# general questions. The model id can be overridden with the GROQ_MODEL
# environment variable (e.g. when a hosted model is retired); it defaults to
# llama3-8b-8192 as before.
llm = ChatGroq(
    model=os.getenv("GROQ_MODEL", "llama3-8b-8192"),
    temperature=0,  # keep generated SQL as repeatable as possible
    max_tokens=None,
    timeout=None,
    max_retries=2,
    verbose=True,
    api_key=os.getenv("GROQ_API_KEY"),
)
# Data directory next to this file; holds the SQLite database for CSV uploads
# and the ChromaDB store.
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
os.makedirs(DATA_DIR, exist_ok=True)

# SQLite database that uploaded CSVs are loaded into (always as table `data_tab`).
DB_PATH = os.path.join(DATA_DIR, "csv_data.db")

# ChromaDB directory, exported through the CHROMA_DB_PATH environment variable.
CHROMA_DB_DIR = os.path.join(DATA_DIR, "chroma_db")
os.makedirs(CHROMA_DB_DIR, exist_ok=True)
os.environ["CHROMA_DB_PATH"] = CHROMA_DB_DIR

# What the chat is currently working with: file_type is "csv", "pdf" or None;
# table_name is set only for CSV uploads.
current_context = {
    "file_type": None,
    "file_name": None,
    "table_name": None,
}

# Holds the most recently generated plot (None until one exists).
current_plot = None
# Prompt that turns a schema-annotated user question into one SQLite query over
# `data_tab`; process_text_query uses it for non-visualization questions and
# passes the raw model output through clean_sql_query.
# NOTE(review): rule 7 below suggests window functions for percentiles, while
# visualization_prompt tells the model SQLite has no window functions — the two
# prompts disagree; confirm which is intended.
query_prompt = ChatPromptTemplate.from_template("""
You are a SQL expert. Given a question about data in a table, write a SQLite-compatible SQL query to answer the question.
CRITICAL RULES:
1. ONLY use columns that are EXPLICITLY provided in the context. DO NOT invent or assume columns exist if they are not listed.
2. If the user asks about a column that doesn't exist, use a similar column from the available ones or explain that the data doesn't contain that information.
3. ALWAYS double-check that every column in your query is in the list of available columns.
Technical guidelines:
4. Use SQLite syntax (not PostgreSQL or MySQL)
5. For date functions, use strftime() instead of EXTRACT
- Example: strftime('%Y', date_column) instead of EXTRACT(YEAR FROM date_column)
6. SQLite doesn't have TRUNCATE function, use CAST((column / bin_size) AS INT) * bin_size instead
7. For percentiles, use window functions or approximate methods
8. Keep queries efficient and focused on answering the specific question
9. Always use 'data_tab' as the table name
10. Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
""")
# Prompt for chart requests: asks the model for SQL that fetches the data to be
# plotted. process_text_query fills {viz_type} with the detected chart type
# ("bar" when none was named).
# NOTE(review): process_text_query builds box-plot and heatmap SQL itself without
# calling the model, so the box/heatmap guidance here is presumably only seen
# when another chart type was detected — confirm it is still needed.
# NOTE(review): SQLite has supported window functions since 3.25.0, so the "no
# window functions" statement in rule 3 may be outdated for the bundled SQLite
# (check sqlite3.sqlite_version).
visualization_prompt = ChatPromptTemplate.from_template("""
You are a data visualization expert. Given a question about visualizing data, write a SQLite-compatible SQL query that will retrieve the appropriate data for the visualization.
Important guidelines for SQLite syntax:
1. Use strftime() for date functions:
- Year: strftime('%Y', date_column)
- Month: strftime('%m', date_column)
- Day: strftime('%d', date_column)
- Hour: strftime('%H', date_column)
2. For histograms and binning:
- Use: CAST((column / bin_size) AS INT) * bin_size
- Example: CAST((trip_distance / 0.5) AS INT) * 0.5 AS distance_bin
3. For box plots:
- SQLite doesn't support PERCENTILE_CONT or window functions
- Simply return the raw data column: SELECT column_name FROM data_tab
- The application will calculate quartiles and outliers
4. For heatmaps:
- Return raw data for correlation analysis
- Example: SELECT numeric_col1, numeric_col2, numeric_col3 FROM data_tab
5. Always use 'data_tab' as the table name
6. IMPORTANT: Return ONLY the SQL query without any markdown formatting, explanations, or code blocks
Question: {question}
Visualization type: {viz_type}
""")
# Chat prompt used after a CSV query has run: the model explains the result to
# the user, given the original question, the executed SQL and the result table
# rendered as text ({data_summary}).
interpret_prompt = ChatPromptTemplate.from_messages(
[
("system", """You are an experienced data analyst. Provide a concise, natural language answer based on the given data summary.
If relevant, give key statistics, trends, or patterns. Be clear about what the data shows and doesn't show.
If the SQL query had to use alternative columns because the exact ones requested weren't available, explain this clearly to the user.
For example, if they asked about 'fare_amount' but the dataset has 'fare' or 'total_fare' instead, mention this substitution."""),
("human", "Question: {question}\nSQL Query: {sql_query}\nData Summary:\n{data_summary}")
]
)
def clean_sql_query(query_text):
    """Extract a bare SQL statement from raw LLM output.

    Handles markdown code fences (with an optional ``sql``/``sqlite`` tag in any
    letter case, closed or left unterminated), conversational lead-ins such as
    "Here is the SQL query:", ``--`` line comments and trailing semicolons.

    Args:
        query_text: Raw model output; may be None or empty.

    Returns:
        The cleaned SQL text, or ``"SELECT * FROM data_tab LIMIT 10;"`` when
        nothing usable is left.
    """
    fallback_query = "SELECT * FROM data_tab LIMIT 10;"
    if not query_text:
        return fallback_query

    if "```" in query_text:
        fenced = re.findall(r"```(?:sqlite|sql)?\s*(.*?)```", query_text, re.DOTALL | re.IGNORECASE)
        if fenced:
            query_text = fenced[0]
        else:
            # Unterminated fence: drop the markers (and tag) so no backticks reach SQLite.
            query_text = re.sub(r"```(?:sqlite|sql)?", " ", query_text, flags=re.IGNORECASE)

    query_text = query_text.strip()

    lead_ins = (
        "here is the sql query",
        "here is the sqlite query",
        "here is a query",
        "here's the sql query",
        "the sql query is",
        "sql query:",
    )
    if query_text.lower().startswith(lead_ins):
        keywords = r"(select|with|create|insert|update|delete)\b"
        # Prefer a keyword that starts a line or follows a colon, so words such as
        # "to select" inside the lead-in are not mistaken for the query start;
        # otherwise fall back to the first keyword anywhere.
        start = (
            re.search(r"(?:^|(?<=:))[ \t]*" + keywords, query_text, re.IGNORECASE | re.MULTILINE)
            or re.search(r"\b" + keywords, query_text, re.IGNORECASE)
        )
        if start:
            query_text = query_text[start.start(1):]

    # Remove "--" line comments.
    query_text = re.sub(r'--.*?(\n|$)', ' ', query_text)

    query_text = query_text.strip().rstrip(';').strip()
    return query_text or fallback_query
# Chart types and the phrases that select them. Types are checked in insertion
# order, so the first matching type wins (e.g. "distribution" selects 'pie').
_VIZ_KEYWORDS = {
    'bar': ['bar chart', 'bar graph', 'bar plot', 'barchart', 'bargraph'],
    'line': ['line chart', 'line graph', 'line plot', 'linechart', 'trend', 'trends', 'time series'],
    'pie': ['pie chart', 'pie graph', 'pie plot', 'piechart', 'distribution', 'proportion'],
    'histogram': ['histogram', 'distribution of', 'frequency distribution'],
    'box': ['box plot', 'boxplot', 'box and whisker', 'outliers', 'quartiles'],
    'heatmap': ['heatmap', 'heat map', 'correlation matrix', 'correlation heatmap'],
    'scatter': ['scatter', 'scatter plot', 'relationship between', 'correlation between'],
}

# Words that mark a message as a visualization request. Explicit chart names are
# included so that e.g. "heatmap of the numeric columns" counts as a chart request.
_VIZ_TRIGGER_WORDS = (
    'plot', 'graph', 'chart', 'visualize', 'visualization', 'trend', 'show me',
    'histogram', 'heatmap', 'heat map', 'scatter',
)

# Declared SQLite column types counted as numeric (case-insensitive substring match).
_NUMERIC_TYPE_MARKERS = ('INT', 'REAL', 'FLOA', 'NUM')

# Maximum number of result rows rendered into the chat reply and the LLM prompt.
# Raw-data queries (box plots) return every non-null value of a column, which
# would otherwise flood the chat and can exceed the model's context window.
_MAX_RESULT_ROWS = 100

# Fallback query used when no better SQL can be built.
_PREVIEW_QUERY = "SELECT * FROM data_tab LIMIT 10;"


def _quote_identifier(name):
    """Return *name* quoted as an SQLite identifier, doubling embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _detect_visualization(query):
    """Classify *query* as a chart request.

    Returns:
        ``(is_visualization, viz_type)``; ``viz_type`` is a key of
        ``_VIZ_KEYWORDS`` or None when no specific chart type is named.
    """
    lowered = query.lower()
    if not any(word in lowered for word in _VIZ_TRIGGER_WORDS):
        return False, None
    viz_type = next(
        (vtype for vtype, keywords in _VIZ_KEYWORDS.items()
         if any(keyword in lowered for keyword in keywords)),
        None,
    )
    return True, viz_type


def _tip_columns(columns):
    """Return column names containing "tip" or "gratuity" (case-insensitive), in table order."""
    return [col for col in columns if "tip" in col.lower() or "gratuity" in col.lower()]


def _raw_data_query(viz_type, query, column_types):
    """Build SQL that returns raw numeric data for a 'box' or 'heatmap' chart.

    Args:
        viz_type: 'box' or 'heatmap'.
        query: The user message; for box plots a numeric column named in it is preferred.
        column_types: ``(name, declared_type)`` pairs from ``PRAGMA table_info``.
    """
    numeric_cols = [
        name for name, declared in column_types
        if any(marker in (declared or '').upper() for marker in _NUMERIC_TYPE_MARKERS)
    ]
    if viz_type == 'box':
        if not numeric_cols:
            return _PREVIEW_QUERY
        lowered = query.lower()
        # Prefer a numeric column the user mentioned; otherwise the first numeric one.
        target = next((col for col in numeric_cols if col.lower() in lowered), numeric_cols[0])
        quoted = _quote_identifier(target)
        return f"SELECT {quoted} FROM data_tab WHERE {quoted} IS NOT NULL;"
    if len(numeric_cols) < 2:
        return _PREVIEW_QUERY
    # At most 10 columns and 1000 rows keep the correlation matrix cheap and readable.
    cols_str = ", ".join(_quote_identifier(col) for col in numeric_cols[:10])
    return f"SELECT {cols_str} FROM data_tab WHERE {_quote_identifier(numeric_cols[0])} IS NOT NULL LIMIT 1000;"


def _missing_column_message(column_name, available_columns):
    """Build the user-facing message for a missing column, suggesting similarly named ones."""
    name_parts = [part.lower() for part in column_name.split('_') if len(part) > 2]
    similar_columns = [
        col for col in available_columns
        if column_name.lower() in col.lower() or any(part in col.lower() for part in name_parts)
    ]
    available = ', '.join(available_columns)
    if similar_columns:
        return f"Column '{column_name}' doesn't exist in the current dataset. Did you mean one of these? {', '.join(similar_columns)}\n\nAvailable columns are: {available}"
    return f"Column '{column_name}' doesn't exist in the current dataset. Available columns are: {available}"


def _answer_csv_query(conn, query, history, is_visualization, viz_type):
    """Generate, run and explain SQL for *query* over the open connection *conn*.

    Returns the same ``(output, history)`` pair as ``process_text_query``;
    unexpected exceptions propagate to the caller.
    """
    def reply(message):
        # Record the assistant's answer in the newest history pair.
        history[-1][1] = message
        return message, history

    lowered = query.lower()
    cursor = conn.cursor()

    # Read the schema first so every later step works from the real column names.
    cursor.execute(f"PRAGMA table_info({_quote_identifier(current_context['table_name'])});")
    columns_info = cursor.fetchall()
    columns = [info[1] for info in columns_info]
    column_types = [(info[1], info[2]) for info in columns_info]

    # "Highest tip" style questions are answered straight from the schema, without the LLM.
    if any(phrase in lowered for phrase in ("highest tip", "largest tip", "maximum tip")):
        tip_columns = _tip_columns(columns)
        if not tip_columns:
            return reply(f"I couldn't find any columns related to tips in the dataset. Available columns are: {', '.join(columns)}")
        print(f"Found tip-related columns: {tip_columns}")
        sql_query = f"SELECT MAX({_quote_identifier(tip_columns[0])}) AS highest_tip FROM data_tab"
        highest_tip = pd.read_sql_query(sql_query, conn).iloc[0, 0]
        return reply(f"The highest tip in the dataset is {highest_tip}.")

    if is_visualization and viz_type in ('box', 'heatmap'):
        # These charts need raw numeric data, so the SQL is built here rather than by the LLM.
        sql_query = _raw_data_query(viz_type, query, column_types)
    else:
        columns_str = ", ".join(f"{col} ({typ})" for col, typ in column_types)
        sample_df = pd.read_sql_query("SELECT * FROM data_tab LIMIT 3;", conn)
        sample_data = sample_df.to_string(index=False, max_rows=3)
        question_with_context = f"""
IMPORTANT: ONLY use the exact columns listed below. DO NOT use any columns not explicitly listed here.
The table 'data_tab' has these columns with their types:
{columns_str}
Available columns (exact names): {', '.join(columns)}
Here's a sample of the data:
{sample_data}
User question: {query}
Remember to ONLY use the columns listed above. If the question seems to require a column that doesn't exist, use the most relevant existing column instead or explain that the data doesn't contain that information.
"""
        if is_visualization:
            raw_sql = llm.invoke(visualization_prompt.format(
                question=question_with_context,
                viz_type=viz_type or "bar"
            )).content
        else:
            raw_sql = llm.invoke(query_prompt.format(question=question_with_context)).content
        sql_query = clean_sql_query(raw_sql)

    # SQLite itself is the authority on which columns exist; handle its
    # "no such column" error instead of pre-parsing the SQL text.
    try:
        result_df = pd.read_sql_query(sql_query, conn)
    except Exception as e:
        error_message = str(e)
        if "no such column" not in error_message.lower():
            return reply(f"Error executing query: {error_message}")
        if "tip" in lowered or "gratuity" in lowered:
            # The LLM invented a column for a tip question: fall back to the real tip column.
            tip_columns = _tip_columns(columns)
            if not tip_columns:
                return reply(f"I couldn't find a column related to tips or gratuity. Available columns are: {', '.join(columns)}")
            sql_query = f"SELECT MAX({_quote_identifier(tip_columns[0])}) AS highest_tip FROM data_tab"
            result_df = pd.read_sql_query(sql_query, conn)
        else:
            column_name = error_message.split("no such column: ")[-1].strip("'").strip('"')
            return reply(_missing_column_message(column_name, columns))

    # Cap the rendered rows; dimensions are appended when rows were cut so the
    # reader (and the LLM) knows the table is truncated.
    df_str = result_df.to_string(
        max_rows=_MAX_RESULT_ROWS,
        show_dimensions=len(result_df) > _MAX_RESULT_ROWS,
    )
    analysis = llm.invoke(interpret_prompt.format(
        question=query,
        sql_query=sql_query,
        data_summary=df_str
    )).content

    comprehensive_response = f"""
### SQL Query:
```sql
{sql_query}
```
### Results:
```
{df_str}
```
### Analysis:
{analysis}
"""

    if is_visualization:
        # Charts are drawn from the full result, not the truncated text rendering.
        viz_html = generate_visualization(result_df, query)
        if viz_html:
            history[-1][1] = comprehensive_response
            return viz_html, history
    return reply(comprehensive_response)


def process_text_query(query, history):
    """Answer a chat message and record it in the chat history.

    The message is routed to one of three back ends:

    * a loaded CSV (``current_context["file_type"] == "csv"``): the LLM writes SQL
      against ``data_tab``, the SQL is executed and the result explained, and a
      chart is rendered when the message asks for one;
    * indexed documents: ``document_assistant.process_query``;
    * otherwise: a general-knowledge LLM answer.

    Args:
        query: The user's message. Empty/None messages are ignored.
        history: Chat history as a list of ``[user, assistant]`` pairs. A new
            pair is appended and its assistant half filled in (mutated in place).

    Returns:
        ``(output, history)``. ``output`` is the reply text, or the chart's HTML
        when a visualization was rendered (the text reply still goes into
        ``history``). Processing errors are reported as text rather than raised.
    """
    if not query:
        return "", history

    history.append([query, None])

    def reply(message):
        # Record the assistant's answer in the newest history pair.
        history[-1][1] = message
        return message, history

    if current_context["file_type"] == "csv" and current_context["table_name"]:
        is_visualization, viz_type = _detect_visualization(query)
        conn = None
        try:
            conn = sqlite3.connect(DB_PATH)
            return _answer_csv_query(conn, query, history, is_visualization, viz_type)
        except Exception as e:
            return reply(f"Error processing query: {str(e)}")
        finally:
            # Release the connection on every path, including early returns.
            if conn is not None:
                conn.close()

    if document_assistant.get_all_documents():
        try:
            return reply(document_assistant.process_query(query))
        except Exception as e:
            return reply(f"Error processing query: {str(e)}")

    # No CSV or documents loaded: answer from the model's general knowledge.
    try:
        general_prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful assistant that provides clear, informative responses. Use your knowledge to answer the user's question concisely."),
            ("human", "{question}")
        ])
        return reply(llm.invoke(general_prompt.format(question=query)).content)
    except Exception as e:
        return reply(f"Error processing query: {str(e)}")
def process_file_upload(files):
    """Load uploaded files into the app's back ends and report what happened.

    CSV files are loaded into the SQLite database at ``DB_PATH`` as table
    ``data_tab`` (replacing any previous table); every other file is handed to
    ``document_assistant.upload_document``. ``current_context`` is reset and
    then describes the last file that loaded successfully.

    Args:
        files: Uploaded files — objects exposing their path as ``.name`` (Gradio
            temp files) or plain path strings / ``os.PathLike`` objects.

    Returns:
        Newline-joined status lines; per-file failures are reported, not raised.
    """
    global current_context
    if not files:
        return "No files uploaded"

    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }

    file_info = []
    for file in files:
        # Depending on the Gradio version/config, uploads arrive as paths or temp-file wrappers.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        if file_ext == '.csv':
            try:
                # Parse first so an unreadable CSV fails before any connection is opened.
                df = pd.read_csv(file_path)
                conn = sqlite3.connect(DB_PATH)
                try:
                    # Favour import speed over crash durability for the bulk load.
                    conn.execute("PRAGMA synchronous = OFF")
                    conn.execute("PRAGMA journal_mode = MEMORY")
                    df.to_sql('data_tab', conn, if_exists='replace', index=False)
                    current_context = {
                        "file_type": "csv",
                        "file_name": file_name,
                        "table_name": "data_tab"  # Always use data_tab as the table name
                    }
                    cursor = conn.cursor()
                    cursor.execute("PRAGMA table_info(data_tab);")
                    columns = [f"{col[1]} ({col[2]})" for col in cursor.fetchall()]
                    cursor.execute("SELECT COUNT(*) FROM data_tab;")
                    row_count = cursor.fetchone()[0]
                finally:
                    conn.close()
                file_info.extend([
                    "✅ CSV File Successfully Loaded",
                    "📊 Table Name: data_tab",
                    f"📄 Source File: {file_name}",
                    f"📈 Total Rows: {row_count:,}",
                    f"📋 Columns: {', '.join(columns)}",
                ])
            except Exception as e:
                file_info.append(f"❌ Error loading CSV {file_name}: {str(e)}")
        else:
            # PDFs and other documents go to the document assistant for indexing.
            try:
                result = document_assistant.upload_document(file_path)
                current_context = {
                    "file_type": "pdf",
                    "file_name": file_name,
                    "table_name": None
                }
                file_info.extend([
                    "✅ Document Successfully Processed",
                    f"📄 File: {file_name}",
                    f"📚 Chunks: {result['chunks']}",
                    result['message'],
                ])
            except Exception as e:
                file_info.append(f"❌ Error processing document {file_name}: {str(e)}")

    return "\n".join(file_info)
| # Function commented out as it's no longer used | |
| # def list_documents(): | |
| # """List all indexed documents""" | |
| # try: | |
| # docs = document_assistant.get_all_documents() | |
| # if not docs: | |
| # return "No documents indexed yet." | |
| # | |
| # result = "Indexed Documents:\n\n" | |
| # for doc in docs: | |
| # result += f"- {doc['filename']} ({doc['file_type']})\n" | |
| # | |
| # return result | |
| # except Exception as e: | |
| # return f"Error listing documents: {str(e)}" | |
def clear_context():
    """Forget which file the chat is working with.

    Only ``current_context`` is reset; no stored data is deleted (that is what
    ``flush_databases`` does).

    Returns:
        A one-message chat history confirming the reset, as ``[[message, None]]``.
    """
    global current_context
    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return [["Context cleared. You can now upload new documents or CSV files.", None]]
def flush_databases():
    """Drop all user tables from the CSV database and reset the ChromaDB store.

    Also resets ``current_context``. Each step is best-effort: failures are
    reported in the returned status text rather than raised.

    Returns:
        Newline-joined status lines, one per database.
    """
    global document_assistant
    global current_context
    result = []

    # Flush the SQLite database.
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        cursor = conn.cursor()
        # Skip SQLite's internal tables: some of them (e.g. sqlite_sequence)
        # cannot be dropped and would abort the whole flush.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        for (table_name,) in cursor.fetchall():
            # Quote the name so tables with spaces or keyword names are dropped too.
            quoted = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"DROP TABLE IF EXISTS {quoted};")
        conn.commit()
        result.append("✅ SQLite database cleared successfully")
    except Exception as e:
        result.append(f"❌ Error clearing SQLite database: {str(e)}")
    finally:
        if conn is not None:
            conn.close()

    # Flush ChromaDB by resetting the document assistant.
    try:
        if document_assistant.reset_database():
            result.append("✅ ChromaDB cleared successfully")
        else:
            # Reset reported failure: fall back to a fresh DocumentAssistant instance.
            document_assistant = DocumentAssistant()
            result.append("⚠️ ChromaDB reset partially completed - created new instance")
    except Exception as e:
        result.append(f"❌ Error clearing ChromaDB: {str(e)}")

    current_context = {
        "file_type": None,
        "file_name": None,
        "table_name": None
    }
    return "\n".join(result)
# Best-effort guard around importing ChromaVectorDB from backend.vector_db
# (ChromaVectorDB is not referenced in the visible part of this file).
# NOTE(review): despite the original note, this runs here, after backend.main
# was already imported near the top of the file.
# NOTE(review): when a module raises during import, Python removes it from
# sys.modules, so `import backend.vector_db` below presumably re-raises the
# same NameError and the delattr/reload steps never run — confirm whether this
# workaround is still needed.
# NOTE(review): a NameError whose message does not mention "response" is
# swallowed silently, leaving ChromaVectorDB undefined.
try:
    from backend.vector_db import ChromaVectorDB
except NameError as e:
    if "response" in str(e):
        # If the error is about 'response' not being defined, fix the module
        import backend.vector_db
        # Remove the problematic module-level attribute, if present
        if hasattr(backend.vector_db, 'response'):
            delattr(backend.vector_db, 'response')
        # Reload the module
        importlib.reload(backend.vector_db)
        from backend.vector_db import ChromaVectorDB
def _value_column(result_df, x_col, numeric_cols):
    """Choose the column to plot against *x_col*.

    Prefers a numeric column other than *x_col* (so ``passenger_count, avg_fare``
    plots avg_fare rather than passenger_count against itself), then any numeric
    column, then the second column; None if there is none.
    """
    other_numeric = [col for col in numeric_cols if col != x_col]
    if other_numeric:
        return other_numeric[0]
    if numeric_cols:
        return numeric_cols[0]
    if len(result_df.columns) > 1:
        return result_df.columns[1]
    return None


def generate_visualization(result_df, query):
    """Render *result_df* as a PNG chart chosen from keywords in *query*.

    The chart type is picked from the message (pie, line, scatter, histogram,
    box, heatmap; bar by default) and falls back to a bar chart when the data
    does not suit the requested type.

    Args:
        result_df: Query result to plot.
        query: The user's message, used only to pick the chart type.

    Returns:
        An HTML snippet embedding the chart as a base64 PNG ``<img>``, or None
        if anything fails (the error is printed with its traceback).
    """
    try:
        print("Visualization requested, attempting to create plot...")
        fig_width = 1200  # export size in pixels, before the 3x image scale below
        fig_height = 800

        lowered = query.lower()
        # First matching keyword group wins; bar is the default.
        if any(word in lowered for word in ['pie', 'distribution', 'proportion']):
            viz_type = 'pie'
        elif any(word in lowered for word in ['line', 'trend', 'time series']):
            viz_type = 'line'
        elif any(word in lowered for word in ['scatter', 'relationship']):
            viz_type = 'scatter'
        elif any(word in lowered for word in ['histogram', 'distribution of']):
            viz_type = 'histogram'
        elif any(word in lowered for word in ['box', 'boxplot', 'outliers']):
            viz_type = 'box'
        elif any(word in lowered for word in ['heatmap', 'correlation']):
            viz_type = 'heatmap'
        else:
            viz_type = 'bar'
        print(f"Creating {viz_type} visualization...")

        numeric_cols = result_df.select_dtypes(include=['number']).columns.tolist()
        fig = None

        if viz_type == 'pie' and len(result_df) <= 20:
            import plotly.graph_objects as go
            labels = result_df.iloc[:, 0].tolist()
            values = result_df.iloc[:, 1].tolist() if len(result_df.columns) > 1 else [1] * len(result_df)
            fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
            fig.update_layout(title_text='Pie Chart')
        elif viz_type == 'histogram' and numeric_cols:
            fig = px.histogram(result_df, x=numeric_cols[0])
            fig.update_layout(title_text=f'Histogram of {numeric_cols[0]}')
        elif viz_type == 'box' and numeric_cols:
            fig = px.box(result_df, y=numeric_cols[0])
            fig.update_layout(title_text=f'Box Plot of {numeric_cols[0]}')
        elif viz_type == 'heatmap' and len(numeric_cols) >= 2:
            corr_df = result_df[numeric_cols].corr()
            fig = px.imshow(corr_df, text_auto=True)
            fig.update_layout(title_text='Correlation Heatmap')
        elif viz_type == 'scatter' and len(numeric_cols) >= 2:
            fig = px.scatter(result_df, x=numeric_cols[0], y=numeric_cols[1])
            fig.update_layout(title_text=f'Scatter Plot of {numeric_cols[0]} vs {numeric_cols[1]}')
        elif viz_type == 'line':
            x_col = result_df.columns[0]
            y_col = _value_column(result_df, x_col, numeric_cols)
            if y_col is not None:
                fig = px.line(result_df, x=x_col, y=y_col)
                fig.update_layout(
                    title_text=f'Line Chart of {y_col} over {x_col}',
                    xaxis=dict(
                        tickangle=-45,
                        tickmode='auto',
                        nticks=20
                    )
                )

        if fig is None:
            # Bar chart: the default, and the fallback when the requested type doesn't fit the data.
            x_col = result_df.columns[0]
            y_col = _value_column(result_df, x_col, numeric_cols)
            # More than 10 categories read better as horizontal bars.
            horizontal = len(result_df) > 10
            if y_col is not None:
                title = f'Bar Chart of {y_col} by {x_col}'
                if horizontal:
                    fig = px.bar(result_df, y=x_col, x=y_col, orientation='h', title=title)
                else:
                    fig = px.bar(result_df, x=x_col, y=y_col, title=title)
            else:
                title = f'Bar Chart of {x_col}'
                if horizontal:
                    fig = px.bar(result_df, y=x_col, orientation='h', title=title)
                else:
                    fig = px.bar(result_df, x=x_col, title=title)
            fig.update_layout(
                bargap=0.2,
                uniformtext_minsize=8,
                uniformtext_mode='hide'
            )

        fig.update_layout(
            width=fig_width,
            height=fig_height,
            template="plotly_white",
            margin=dict(l=40, r=40, t=80, b=80, pad=4),
            autosize=True,
            plot_bgcolor='rgba(240,240,240,0.2)',
            paper_bgcolor='white',
            font=dict(size=12)
        )
        # Hover labels only matter in interactive renderings; the figure is
        # exported as a static PNG below.
        fig.update_traces(
            hovertemplate="%{x}: %{y}<extra></extra>",
            hoverlabel=dict(
                bgcolor="white",
                font_size=12,
                font_family="Arial"
            )
        )
        print(f"Created figure with width={fig_width}, height={fig_height}")

        print("Converting figure to image...")
        img_bytes = pio.to_image(fig, format="png", width=fig_width, height=fig_height, scale=3)
        print("Image conversion successful")

        encoded = base64.b64encode(img_bytes).decode("ascii")
        img_src = f"data:image/png;base64,{encoded}"
        print("HTML conversion successful")

        return f"""
<div class="visualization-wrapper">
<img src='{img_src}'
style='max-width:100%; height:auto; display:block; margin:0 auto;'
alt='Data Visualization' />
</div>
"""
    except Exception as e:
        print(f"Error generating visualization: {str(e)}")
        traceback.print_exc()
        return None
# Create Gradio interface
with gr.Blocks(title="LLM Powered Database Chatbot") as demo:
    gr.Markdown("# 🤖 LLM Powered Database Chatbot")
    gr.Markdown("Upload documents, ask questions, and get AI-powered responses!")
    # Session state (gr.State) holding the HTML of the most recent visualization;
    # it is read by the download button and the visualization panel.
    current_visualization = gr.State(None)
    with gr.Tab("Chat & Visualizations"):
        # Use a custom CSS to ensure images are displayed properly
        gr.HTML("""
<style>
.chatbot-container img {
max-width: 100%;
height: auto;
display: block;
margin: 10px 0;
}
.visualization-container {
min-height: 500px;
max-height: 800px;
overflow: auto;
padding: 20px;
background-color: #f8f9fa;
border-radius: 8px;
}
.visualization-container img {
max-width: 100%;
height: auto;
display: block;
margin: 0 auto;
}
</style>
""")
        with gr.Row():
            # Left column: chat history, input box and chat controls.
            with gr.Column(scale=1):
                chatbot = gr.Chatbot(height=500, elem_classes="chatbot-container")
                with gr.Row():
                    with gr.Column(scale=8):
                        msg = gr.Textbox(
                            placeholder="Ask a question about your documents...",
                            show_label=False
                        )
                    # Narrow empty spacer column beside the textbox.
                    with gr.Column(scale=1):
                        pass
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.Button("Clear")
                    clear_context_btn = gr.Button("Clear Context")
            # Right column: rendered visualization and its controls.
            with gr.Column(scale=1):
                visualization_output = gr.HTML(
                    label="Visualization",
                    elem_classes="visualization-container"
                )
                with gr.Row():
                    clear_viz_btn = gr.Button("🗑️ Clear Visualization")
                    download_btn = gr.Button("📥 Download Visualization")
                save_status = gr.Textbox(label="Save Status", visible=False)
                # Hidden until a visualization is decoded for download.
                download_img = gr.Image(visible=False, type="pil", label="Download Image")
        # Add information about capabilities
        gr.Markdown("""
### Capabilities:
- **Data Analysis**: Ask questions about your data and get detailed responses
- **Visualization**: Request and view graphs and charts of your data
- **Multiple File Types**: Upload PDFs, TXT, DOCX, CSV, and XLSX files for analysis
- **Natural Language Queries**: Ask questions in plain English about your documents
""")

        def clear_visualization():
            """Blank the visualization panel and reset the stored visualization."""
            return "", ""

        def download_visualization(viz_html):
            """Decode the PNG embedded in *viz_html* into the hidden image component.

            Args:
                viz_html: HTML previously stored in ``current_visualization``;
                    may be empty/None when nothing has been visualized yet.

            Returns:
                A single ``gr.update`` for ``download_img``. It shows the decoded
                image when a base64 PNG data URI is found in a ``src`` attribute.
                Otherwise (no HTML, no image, undecodable data) it hides the
                component and clears its value.
            """
            hidden = gr.update(value=None, visible=False)
            if not viz_html:
                return hidden
            # Accept either quote style around the data URI.
            img_data_match = re.search(
                r"""src=["']data:image/png;base64,([^"']+)["']""", viz_html
            )
            if not img_data_match:
                return hidden
            try:
                from PIL import Image
                image_data = base64.b64decode(img_data_match.group(1))
                image = Image.open(BytesIO(image_data))
            except (ImportError, ValueError, OSError) as e:
                # ValueError covers binascii.Error from bad base64; OSError covers
                # PIL's UnidentifiedImageError for bytes that are not an image.
                print(f"Error downloading visualization: {str(e)}")
                return hidden
            return gr.update(value=image, visible=True)

        clear_viz_btn.click(
            clear_visualization,
            outputs=[visualization_output, current_visualization]
        )
        # One output: the handler returns a single update carrying both the image
        # and its visibility. The previous duplicate-output wiring broke on the
        # empty-state `return None`.
        download_btn.click(
            download_visualization,
            inputs=[current_visualization],
            outputs=[download_img]
        )

        def process_text_query_with_visualization(query, history, current_viz):
            """Answer *query* and remember any visualization in the response.

            Args:
                query: Text from the message box; empty or whitespace-only
                    input is ignored.
                history: Current chatbot history.
                current_viz: HTML of the last visualization (session state).

            Returns:
                ``("", new_history, viz)``. The empty string clears the textbox.
                ``viz`` is the new response when it contains an ``<img src=``
                tag, otherwise the previous visualization is kept.
            """
            if not query or not query.strip():
                return "", history, current_viz
            # Process the query and get the response
            response, new_history = process_text_query(query, history)
            # Only a text response can carry embedded <img> HTML; guarding avoids
            # a TypeError from the `in` check if the backend returns None.
            if isinstance(response, str) and "<img src=" in response:
                current_viz = response
            return "", new_history, current_viz

        def _wire_query_event(event_listener):
            """Attach the query handler, then refresh the visualization panel."""
            event_listener(
                process_text_query_with_visualization,
                inputs=[msg, chatbot, current_visualization],
                outputs=[msg, chatbot, current_visualization]
            ).then(
                lambda viz: viz if viz else "",  # Update visualization tab
                inputs=[current_visualization],
                outputs=[visualization_output]
            )

        # Both the Submit button and pressing Enter in the textbox send the query.
        _wire_query_event(submit_btn.click)
        _wire_query_event(msg.submit)
        clear_btn.click(lambda: None, None, chatbot, queue=False)
        clear_context_btn.click(clear_context, None, chatbot, queue=False)
    with gr.Tab("Document Upload"):
        file_upload = gr.File(
            label="Upload Documents",
            file_types=[".pdf", ".txt", ".docx", ".csv", ".xlsx"],
            file_count="multiple"
        )
        with gr.Row():
            upload_button = gr.Button("Process & Index Documents", scale=2)
            flush_db_btn_doc = gr.Button("🗑️ Flush All Databases", variant="stop", scale=1)
        upload_output = gr.Textbox(label="Upload Status")
        upload_button.click(
            process_file_upload,
            inputs=[file_upload],
            outputs=[upload_output]
        )
        flush_db_btn_doc.click(
            flush_databases,
            inputs=[],
            outputs=[upload_output]
        )
def env_flag(name, default):
    """Read environment variable *name* as a boolean.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        True for "1", "true", "yes" or "on" (case-insensitive, surrounding
        whitespace ignored). False for any other non-blank value.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Launch the app
if __name__ == "__main__":
    # Defaults reproduce the previous hard-coded settings; each can now be
    # overridden from the environment without editing code. Per the Gradio
    # docs, share=True creates a public link and "0.0.0.0" listens on all
    # interfaces, so consider turning both off outside local development.
    demo.launch(
        share=env_flag("GRADIO_SHARE", True),
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        show_error=True,
        debug=env_flag("GRADIO_DEBUG", True),
    )