Spaces:
Running
Running
| """ | |
| Schema-Agnostic Database Chatbot - Streamlit Application | |
| A production-grade chatbot that connects to ANY database | |
| (MySQL, PostgreSQL, SQLite) and provides intelligent querying | |
| through RAG and Text-to-SQL. | |
| Uses Groq for FREE LLM inference! | |
| """ | |
| import os | |
| from pathlib import Path | |
| # Load .env FIRST before any other imports | |
| from dotenv import load_dotenv | |
| load_dotenv(Path(__file__).parent / ".env") | |
| import streamlit as st | |
| import uuid | |
| import time | |
| import io | |
| import csv | |
| import base64 | |
| import pandas as pd | |
| from datetime import datetime | |
# Page config must be first: this call is issued before the project modules
# below are imported, so no other Streamlit command can run ahead of it.
st.set_page_config(
    page_title="OnceDataBot",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded"
)
| # Imports | |
| from config import config, DatabaseConfig, DatabaseType | |
| from database import get_db, get_schema, get_introspector | |
| from database.connection import DatabaseConnection | |
| from llm import create_llm_client | |
| from chatbot import create_chatbot, DatabaseChatbot | |
| from memory import ChatMemory, EnhancedChatMemory | |
| from viz_utils import render_visualization | |
# Groq-hosted model identifiers (all FREE!)
GROQ_MODELS = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
    "gemma2-9b-it",
]

# UI label -> database type identifier used when building a connection config
DB_TYPES = {"MySQL": "mysql", "PostgreSQL": "postgresql", "SQLite": "sqlite"}

# UI label -> language code for multi-language responses.
# Insertion order is the order shown in the language selector.
SUPPORTED_LANGUAGES = {
    "English": "en",
    "हिन्दी (Hindi)": "hi",
    "Español (Spanish)": "es",
    "Français (French)": "fr",
    "Deutsch (German)": "de",
    "中文 (Chinese)": "zh",
    "日本語 (Japanese)": "ja",
    "한국어 (Korean)": "ko",
    "Português (Portuguese)": "pt",
    "العربية (Arabic)": "ar",
    "Русский (Russian)": "ru",
    "Italiano (Italian)": "it",
    "Nederlands (Dutch)": "nl",
    "தமிழ் (Tamil)": "ta",
    "తెలుగు (Telugu)": "te",
    "मराठी (Marathi)": "mr",
    "বাংলা (Bengali)": "bn",
    "ગુજરાતી (Gujarati)": "gu",
}
def create_custom_db_config(db_type: str, **kwargs) -> DatabaseConfig:
    """Build a ``DatabaseConfig`` from credentials typed into the UI.

    Args:
        db_type: Database type identifier; converted with ``DatabaseType``.
        **kwargs: Optional ``host``, ``port``, ``database``, ``username``,
            ``password`` and ``ssl_ca`` values. Missing text fields default
            to an empty string, ``ssl_ca`` to ``None``.

    Returns:
        The populated ``DatabaseConfig``.
    """
    # When no port is supplied, 3306 is used for "mysql" and 5432 otherwise.
    fallback_port = 3306 if db_type == "mysql" else 5432
    settings = {
        "host": kwargs.get("host", ""),
        "port": kwargs.get("port", fallback_port),
        "database": kwargs.get("database", ""),
        "username": kwargs.get("username", ""),
        "password": kwargs.get("password", ""),
        "ssl_ca": kwargs.get("ssl_ca", None),
    }
    return DatabaseConfig(db_type=DatabaseType(db_type), **settings)
def create_custom_memory(session_id: str, user_id: str, db_connection, llm_client=None,
                         enable_summarization=True, summary_threshold=10) -> EnhancedChatMemory:
    """Build an ``EnhancedChatMemory`` bound to the given database connection.

    Args:
        session_id: Identifier of the chat session.
        user_id: Identifier of the user owning the session.
        db_connection: Connection object handed to the memory as ``db_connection``.
        llm_client: Optional LLM client handed to the memory.
        enable_summarization: Forwarded unchanged to the memory.
        summary_threshold: Forwarded unchanged to the memory.

    Returns:
        The new memory instance, created with ``max_messages=20``.
    """
    memory_options = dict(
        session_id=session_id,
        user_id=user_id,
        max_messages=20,
        db_connection=db_connection,
        llm_client=llm_client,
        enable_summarization=enable_summarization,
        summary_threshold=summary_threshold,
    )
    return EnhancedChatMemory(**memory_options)
def init_session_state():
    """Seed ``st.session_state`` with a default for every key the app uses.

    Keys that already exist are left untouched, so values survive reruns.
    """
    # (key, factory) pairs. Factories are only called for missing keys, so a
    # new UUID / list / set is created only when it is actually needed.
    defaults = (
        ("session_id", lambda: str(uuid.uuid4())),
        ("messages", list),
        ("chatbot", lambda: None),
        ("initialized", lambda: False),
        ("user_id", lambda: "default"),
        ("enable_summarization", lambda: True),
        ("summary_threshold", lambda: 10),
        ("memory", lambda: None),
        ("indexed", lambda: False),
        ("db_source", lambda: "environment"),  # "environment" or "custom"
        ("custom_db_config", lambda: None),
        ("custom_db_connection", lambda: None),
        ("ignored_tables", set),
        ("response_language", lambda: "English"),
        ("favorites", list),  # message indices that are favorited
    )
    state = st.session_state
    for key, make_default in defaults:
        if key not in state:
            state[key] = make_default()
def export_results_to_csv(results: list) -> str:
    """Serialize SQL result rows to CSV text.

    Args:
        results: List of row dicts (column name -> value). May be empty or
            ``None``.

    Returns:
        CSV text with a header row, or ``""`` when there are no rows. The
        header is the union of all row keys in first-seen order; a row that
        lacks a column gets an empty cell for it.
    """
    if not results:
        return ""
    # Taking the header from the first row alone makes DictWriter raise
    # ValueError as soon as a later row carries an extra key, so collect every
    # key. dict.fromkeys keeps first-seen order and drops duplicates.
    fieldnames = list(dict.fromkeys(key for row in results for key in row))
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(results)
    return buffer.getvalue()
def export_chat_to_text() -> str:
    """Render the current conversation as a plain-text transcript.

    Reads ``messages``, ``favorites`` and ``user_id`` from
    ``st.session_state``.

    Returns:
        The transcript, or ``"No messages to export."`` when the conversation
        is empty. Favorited messages are prefixed with a star, and assistant
        messages are followed by their SQL query, query type and execution
        time when those are recorded.
    """
    messages = st.session_state.messages
    if not messages:
        return "No messages to export."

    favorites = st.session_state.favorites
    lines = [
        "=" * 50,
        "OnceDataBot Chat Export",
        f"Exported: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"User: {st.session_state.user_id}",
        "=" * 50,
        "",
    ]

    for i, msg in enumerate(messages):
        role = "🧑 User" if msg["role"] == "user" else "🤖 Assistant"
        is_favorited = "⭐ " if i in favorites else ""
        lines.append(f"{is_favorited}{role}:")
        lines.append(msg["content"])

        # A message can carry a "metadata" key whose value is None; treat that
        # the same as no metadata instead of failing on meta.get().
        meta = msg.get("metadata") if msg["role"] == "assistant" else None
        if meta:
            if meta.get("sql_query"):
                lines.append(f"\n📝 SQL Query: {meta['sql_query']}")
            if meta.get("query_type"):
                lines.append(f"📌 Query Type: {meta['query_type']}")
            if meta.get("execution_time"):
                lines.append(f"⏱️ Execution Time: {meta['execution_time']:.2f}s")

        lines.append("-" * 40)
        lines.append("")

    return "\n".join(lines)
def render_copy_button(text: str, key: str):
    """Show ``text`` as a SQL code block.

    The code block is used as a stand-in for a copy button because
    ``st.code`` ships with its own copy control. ``key`` is accepted but not
    used.
    """
    st.code(body=text, language="sql")
def render_database_config():
    """Draw the database-source picker and, for custom sources, the credential form.

    Updates ``st.session_state.db_source`` to ``"environment"`` or
    ``"custom"`` according to the radio selection.

    Returns:
        ``None`` when the environment configuration is selected; otherwise a
        dict of connection parameters (``db_type`` and ``database`` for
        SQLite, plus ``host``, ``port``, ``username``, ``password`` and
        ``ssl_ca`` for MySQL/PostgreSQL).
    """
    st.subheader("🗄️ Database Configuration")

    # Source selection
    env_label = "Use Environment Variables"
    chosen_source = st.radio(
        "Database Source",
        options=[env_label, "Custom Database"],
        index=0 if st.session_state.db_source == "environment" else 1,
        key="db_source_radio",
        help="Choose to use .env settings or enter custom credentials"
    )
    st.session_state.db_source = "environment" if chosen_source == env_label else "custom"

    if st.session_state.db_source == "environment":
        # Nothing to collect: just show what the environment provides.
        env_db_type = config.database.db_type.value.upper()
        st.info(f"📌 Using {env_db_type} from environment")
        st.caption(f"Host: {config.database.host}")
        return None

    # Custom credentials
    st.markdown("##### Enter Database Credentials")
    type_label = st.selectbox(
        "Database Type",
        options=list(DB_TYPES.keys()),
        index=0,
        key="custom_db_type"
    )
    db_type = DB_TYPES[type_label]

    if db_type == "sqlite":
        # SQLite is identified by a file path alone.
        sqlite_path = st.text_input(
            "SQLite Database File",
            value="ingested_data.db",
            key="db_sqlite_path",
            help="Path to the .db file (will be created if it doesn't exist)"
        )
        return {"db_type": db_type, "database": sqlite_path}

    # MySQL or PostgreSQL
    host_col, port_col = st.columns([3, 1])
    with host_col:
        host = st.text_input(
            "Host",
            value="",
            key="db_host_input",
            placeholder="your-database-host.com"
        )
    with port_col:
        port = st.number_input(
            "Port",
            value=3306 if db_type == "mysql" else 5432,
            min_value=1,
            max_value=65535,
            key="db_port_input"
        )
    database = st.text_input(
        "Database Name",
        value="",
        key="db_name_input",
        placeholder="your_database"
    )
    username = st.text_input(
        "Username",
        value="",
        key="db_user_input",
        placeholder="your_username"
    )
    password = st.text_input(
        "Password",
        value="",
        type="password",
        key="db_pass_input"
    )

    # Optional SSL
    with st.expander("🔒 SSL Settings (Optional)"):
        ssl_ca = st.text_input(
            "SSL CA Certificate Path",
            value="",
            key="ssl_ca_input",
            help="Path to SSL CA certificate file (for cloud databases like Aiven)"
        )

    params = {
        "db_type": db_type,
        "host": host,
        "port": int(port),
        "database": database,
        "username": username,
        "password": password,
    }
    # An empty certificate path is reported as None.
    params["ssl_ca"] = ssl_ca if ssl_ca else None
    return params
def _render_session_stats():
    """Show query, token and timing totals for the current conversation."""
    messages = st.session_state.messages
    if not messages:
        return

    st.markdown("### 📊 Session Stats")
    total_msgs = len(messages)
    sql_queries = 0
    total_tokens = 0
    exec_times = []
    for m in messages:
        if m.get("role") != "assistant":
            continue
        # "metadata" and "token_usage" can be present but None; fall back to
        # an empty dict so the stats never fail on .get().
        meta = m.get("metadata") or {}
        if meta.get("sql_query"):
            sql_queries += 1
        total_tokens += (meta.get("token_usage") or {}).get("total", 0)
        if meta.get("execution_time"):
            exec_times.append(meta["execution_time"])
    avg_time = sum(exec_times) / len(exec_times) if exec_times else 0

    col_s1, col_s2 = st.columns(2)
    col_s1.metric("Queries", sql_queries)
    col_s2.metric("Tokens", f"{total_tokens:,}")
    st.caption(f"⏱️ Avg Time: {avg_time:.2f}s | 💬 Msgs: {total_msgs}")
    st.divider()


def _render_connection_status():
    """Show the connected database type, table count and indexed document count."""
    st.subheader("📊 Status")
    if not st.session_state.initialized:
        st.warning("Not connected")
        return

    if st.session_state.custom_db_connection:
        db_type = st.session_state.custom_db_connection.db_type.value.upper()
    else:
        db_type = get_db().db_type.value.upper()
    st.success(f"Database: {db_type} ✓")

    try:
        schema = get_schema()
        st.info(f"Tables: {len(schema.tables)}")
    except Exception:
        # Best effort only. A bare ``except:`` here would also swallow
        # KeyboardInterrupt/SystemExit and Streamlit's own control-flow
        # signals, so catch Exception instead.
        st.warning("Schema not loaded")

    if st.session_state.indexed:
        from rag import get_rag_engine
        engine = get_rag_engine()
        st.info(f"Indexed Docs: {engine.document_count}")


def _render_chat_history():
    """List the user's stored sessions and load one when its button is clicked."""
    memory = st.session_state.memory
    if not memory:
        return

    st.subheader("🕰️ Chat History")
    sessions = memory.get_user_sessions()
    if not sessions:
        st.caption("No previous chats found.")
        return

    for session in sessions:
        # Highlight current session
        is_current = session["id"] == st.session_state.session_id
        icon = "🟢" if is_current else "💬"
        clicked = st.button(
            f"{icon} {session['title']}",
            key=f"hist_{session['id']}",
            use_container_width=True,
            type="secondary" if not is_current else "primary"
        )
        if not clicked or is_current:
            continue

        # Point both the app and the memory object at the selected session.
        st.session_state.session_id = session["id"]
        memory.session_id = session["id"]
        memory.messages = []  # drop the in-memory cache of the old session

        msgs = memory.load_session(session["id"])
        st.session_state.messages = msgs
        # Favorites are message indices, so they do not carry over to a
        # different message list.
        st.session_state.favorites = []

        # Repopulate the memory's message list manually so the existing
        # memory object (and its connection) is kept.
        from memory import ChatMessage
        memory.messages = [
            ChatMessage(
                role=m['role'],
                content=m['content'],
                metadata=m.get('metadata')
            ) for m in msgs
        ]
        st.rerun()


def render_sidebar():
    """Render the configuration sidebar.

    Sections, in order: session stats, user profile, response language, chat
    export, CSV ingestion, database configuration, LLM key status, connect and
    index buttons, connection status, new chat, disconnect, chat history.
    Several actions update ``st.session_state`` and then trigger
    ``st.rerun()``.
    """
    with st.sidebar:
        st.title("⚙️ Settings")

        # Session Dashboard
        _render_session_stats()

        # User Profile
        st.subheader("👤 User Profile")
        user_id = st.text_input(
            "User ID / Name",
            value=st.session_state.get("user_id", "default"),
            key="user_id_input",
            help="Your unique ID for private memory storage"
        )
        if user_id != st.session_state.get("user_id"):
            # Switching user starts a fresh session with an empty transcript.
            st.session_state.user_id = user_id
            st.session_state.session_id = str(uuid.uuid4())
            st.session_state.messages = []
            # Favorites index into the transcript that was just discarded.
            st.session_state.favorites = []
            # Recreate memory for new user
            if st.session_state.custom_db_connection:
                st.session_state.memory = create_custom_memory(
                    st.session_state.session_id,
                    user_id,
                    st.session_state.custom_db_connection,
                    st.session_state.get("llm"),
                    st.session_state.enable_summarization,
                    st.session_state.summary_threshold
                )
            elif st.session_state.initialized:
                from memory import create_enhanced_memory
                st.session_state.memory = create_enhanced_memory(
                    st.session_state.session_id,
                    user_id=user_id,
                    enable_summarization=st.session_state.enable_summarization,
                    summary_threshold=st.session_state.summary_threshold
                )
            if st.session_state.memory:
                st.session_state.memory.clear_user_history()
            st.rerun()
        st.divider()

        # Language Selection
        st.subheader("🌐 Response Language")
        language_labels = list(SUPPORTED_LANGUAGES.keys())
        selected_language = st.selectbox(
            "Select Language",
            options=language_labels,
            index=language_labels.index(st.session_state.response_language),
            key="language_selector",
            help="Choose the language for chatbot responses"
        )
        if selected_language != st.session_state.response_language:
            st.session_state.response_language = selected_language
            st.toast(f"🌐 Language changed to {selected_language}")
        st.divider()

        # Chat export (only when there is something to export)
        if st.session_state.messages:
            st.download_button(
                label="📄 Export Chat",
                data=export_chat_to_text(),
                file_name=f"chat_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
                mime="text/plain",
                use_container_width=True,
                help="Download your chat conversation as a text file"
            )
            st.divider()

        # CSV Ingestion Section
        st.subheader("📥 Ingest CSV Data")
        uploaded_files = st.file_uploader(
            "Upload CSV(s) to create database",
            type=["csv"],
            accept_multiple_files=True,
            help="Your CSVs will be converted to tables in a local SQLite database"
        )
        if uploaded_files:
            if st.button("🚀 Upload & Initialize", use_container_width=True):
                with st.spinner("Processing CSVs..."):
                    success_count = 0
                    table_names = []
                    for uploaded_file in uploaded_files:
                        # On failure ingest_csv returns the error text in the
                        # second slot instead of a table name.
                        success, name, rows = ingest_csv(uploaded_file)
                        if success:
                            success_count += 1
                            table_names.append(name)
                        else:
                            st.error(f"Failed to ingest {uploaded_file.name}: {name}")
                    if success_count > 0:
                        st.success(f"Successfully ingested {success_count} file(s) as tables: {', '.join(table_names)}")
                        # Now initialize chatbot with this SQLite DB
                        sqlite_params = {
                            "db_type": "sqlite",
                            "database": "ingested_data.db"
                        }
                        # initialize_chatbot only honours custom params when
                        # db_source is "custom", so switch it for the call and
                        # restore the previous value if initialization fails.
                        old_source = st.session_state.db_source
                        st.session_state.db_source = "custom"
                        init_success = initialize_chatbot(sqlite_params, None, None)
                        if not init_success:
                            st.session_state.db_source = old_source
                        else:
                            st.rerun()
        st.divider()

        # Database Configuration
        custom_db_params = render_database_config()
        st.divider()

        # LLM Configuration
        st.subheader("🤖 LLM Configuration")
        # Show status of API key
        if os.getenv("GROQ_API_KEY"):
            st.success("✓ API Key configured")
        else:
            st.warning("⚠️ GROQ_API_KEY not set in environment")
        st.divider()

        # Initialize Button
        if st.button("🚀 Connect & Initialize", use_container_width=True, type="primary"):
            with st.spinner("Connecting to database..."):
                success = initialize_chatbot(custom_db_params, None, None)
                if success:
                    st.success("✅ Connected!")
                    st.rerun()

        # Index Button (after initialization)
        if st.session_state.initialized:
            if st.button("📚 Index Text Data", use_container_width=True):
                with st.spinner("Indexing text data..."):
                    index_data()
                    st.success("✅ Indexed!")
                    st.rerun()
        st.divider()

        # Status
        _render_connection_status()

        # New Chat
        if st.button("➕ New Chat", use_container_width=True, type="secondary"):
            if st.session_state.memory:
                st.session_state.memory.clear()
            st.session_state.messages = []
            # Favorites index into the transcript that was just discarded.
            st.session_state.favorites = []
            st.session_state.session_id = str(uuid.uuid4())
            current_user = st.session_state.get("user_id", "default")
            if st.session_state.custom_db_connection:
                st.session_state.memory = create_custom_memory(
                    st.session_state.session_id,
                    current_user,
                    st.session_state.custom_db_connection,
                    st.session_state.get("llm"),
                    st.session_state.enable_summarization,
                    st.session_state.summary_threshold
                )
            elif st.session_state.initialized:
                from memory import create_enhanced_memory
                st.session_state.memory = create_enhanced_memory(
                    st.session_state.session_id,
                    user_id=current_user,
                    enable_summarization=st.session_state.enable_summarization,
                    summary_threshold=st.session_state.summary_threshold
                )
                # create_enhanced_memory is not given the LLM client here, so
                # attach it afterwards when one is available.
                if st.session_state.get("llm"):
                    st.session_state.memory.set_llm_client(st.session_state.llm)
            st.rerun()

        # Disconnect button (when using custom DB)
        if st.session_state.initialized and st.session_state.db_source == "custom":
            if st.button("🔌 Disconnect", use_container_width=True):
                if st.session_state.custom_db_connection:
                    st.session_state.custom_db_connection.close()
                    st.session_state.custom_db_connection = None
                st.session_state.chatbot = None
                st.session_state.initialized = False
                st.session_state.indexed = False
                st.session_state.memory = None
                st.success("Disconnected!")
                st.rerun()
        st.divider()

        # Chat History Section
        _render_chat_history()
def initialize_chatbot(custom_db_params=None, api_key=None, model=None) -> bool:
    """Initialize the chatbot with either the environment or a custom database.

    Args:
        custom_db_params: Connection parameters as produced by
            ``render_database_config``. Only used when
            ``st.session_state.db_source`` is ``"custom"``; otherwise the
            environment-based flow runs.
        api_key: Groq API key; falls back to the ``GROQ_API_KEY`` environment
            variable.
        model: Groq model name; falls back to ``GROQ_MODEL`` and then to
            ``"llama-3.3-70b-versatile"``.

    Returns:
        ``True`` when the chatbot, LLM client and memory were stored in
        ``st.session_state``; ``False`` after reporting the problem with
        ``st.error``. Exceptions are caught and reported, not raised.
    """
    try:
        # Get API key
        groq_api_key = api_key or os.getenv("GROQ_API_KEY", "")
        groq_model = model or os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
        if not groq_api_key:
            st.error("GROQ_API_KEY not configured. Please enter your API key.")
            return False

        # Create LLM client
        llm = create_llm_client("groq", api_key=groq_api_key, model=groq_model)

        # Create database connection
        if custom_db_params and st.session_state.db_source == "custom":
            # Validate custom params
            db_type = custom_db_params.get("db_type", "mysql")
            if db_type != "sqlite":
                required = ("host", "database", "username")
                if not all(custom_db_params.get(field) for field in required):
                    st.error("Please fill in all required database fields.")
                    return False
            else:
                if not custom_db_params.get("database"):
                    st.error("Please specify a SQLite database file path.")
                    return False

            # Create custom config and connection
            db_config = create_custom_db_config(**custom_db_params)
            custom_connection = DatabaseConnection(db_config)

            # Test connection
            success, msg = custom_connection.test_connection()
            if not success:
                st.error(f"Connection failed: {msg}")
                # This connection is never stored anywhere, so release it here
                # instead of leaking it.
                custom_connection.close()
                return False

            st.session_state.custom_db_connection = custom_connection
            st.session_state.custom_db_config = db_config

            # Build a chatbot around the custom connection. __new__ skips
            # DatabaseChatbot.__init__, so every collaborator is wired by hand.
            from chatbot import DatabaseChatbot
            from database.schema_introspector import SchemaIntrospector
            from rag import get_rag_engine
            from sql import get_sql_generator, get_sql_validator
            from router import get_query_router

            chatbot = DatabaseChatbot.__new__(DatabaseChatbot)
            chatbot.db = custom_connection
            chatbot.introspector = SchemaIntrospector()
            chatbot.introspector.db = custom_connection
            chatbot.rag_engine = get_rag_engine()
            chatbot.sql_generator = get_sql_generator(db_type)
            chatbot.sql_validator = get_sql_validator()
            chatbot.router = get_query_router()
            chatbot.llm_client = llm
            chatbot._schema_initialized = False
            chatbot._rag_initialized = False

            # Set LLM client
            chatbot.set_llm_client(llm)

            # Initialize (introspect schema)
            schema = chatbot.introspector.introspect(force_refresh=True)
            chatbot.sql_validator.set_allowed_tables(schema.table_names)
            chatbot._schema_initialized = True
            st.session_state.chatbot = chatbot
        else:
            # Use environment-based connection (existing flow)
            chatbot = create_chatbot(llm)
            chatbot.set_llm_client(llm)
            success, msg = chatbot.initialize()
            if not success:
                st.error(f"Initialization failed: {msg}")
                return False
            st.session_state.chatbot = chatbot
            st.session_state.custom_db_connection = None

        st.session_state.llm = llm
        st.session_state.initialized = True
        st.session_state.indexed = False  # Reset index status on new connection

        # Clear RAG index to ensure no data from previous DB connection persists
        if hasattr(chatbot, 'rag_engine') and hasattr(chatbot.rag_engine, 'clear_index'):
            chatbot.rag_engine.clear_index()

        # Create memory with appropriate connection
        db_conn = st.session_state.custom_db_connection or get_db()
        st.session_state.memory = create_custom_memory(
            st.session_state.session_id,
            st.session_state.user_id,
            db_conn,
            llm,
            st.session_state.enable_summarization,
            st.session_state.summary_threshold
        )
        return True
    except Exception as e:
        # Top-level boundary for the connect action: show the error and the
        # traceback in the UI rather than letting the script run fail.
        st.error(f"Error: {str(e)}")
        import traceback
        st.error(traceback.format_exc())
        return False
def _csv_table_name(filename: str) -> str:
    """Derive a lowercase alphanumeric/underscore table name from a file name."""
    stem = Path(filename).stem.replace(" ", "_").replace("-", "_").lower()
    cleaned = "".join(c for c in stem if c.isalnum() or c == "_")
    # Prefix names that are empty or do not start with a letter. The emptiness
    # check matters: a name made only of stripped characters (e.g. "###.csv")
    # would otherwise fail on cleaned[0].
    if not cleaned or not cleaned[0].isalpha():
        cleaned = "t_" + cleaned
    return cleaned


def ingest_csv(uploaded_file):
    """Ingest a CSV file into the local ``ingested_data.db`` SQLite database.

    Args:
        uploaded_file: File-like object with ``seek``, a ``name`` attribute
            and content readable by ``pandas.read_csv``.

    Returns:
        ``(True, table_name, row_count)`` on success, or
        ``(False, error_message, 0)`` on any failure. An existing table with
        the same name is replaced. The function does not raise.
    """
    try:
        # Imported inside the try so a missing dependency is reported through
        # the normal failure tuple instead of escaping to the caller.
        from sqlalchemy import create_engine

        # 1. Read CSV (rewind first in case the upload was read before)
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file)

        # 2. Clean table name from filename
        table_name = _csv_table_name(uploaded_file.name)

        # 3. Create/Connect to SQLite DB
        db_path = "ingested_data.db"
        engine = create_engine(f"sqlite:///{db_path}")
        try:
            # 4. Write to DB
            df.to_sql(table_name, engine, if_exists='replace', index=False)
        finally:
            # Release the engine's connections whether or not the write worked.
            engine.dispose()
        return True, table_name, len(df)
    except Exception as e:
        return False, str(e), 0
def index_data():
    """Index text data from the database, reporting progress in the UI.

    Does nothing when no chatbot is stored in ``st.session_state``. On
    completion sets ``st.session_state.indexed`` to ``True``.
    """
    chatbot = st.session_state.chatbot
    if not chatbot:
        return

    progress = st.progress(0)
    status = st.empty()

    # Get schema from the chatbot's own introspector
    schema = chatbot.introspector.introspect()
    total_tables = len(schema.tables)
    indexed = 0

    def progress_callback(table_name, docs):
        """Advance the progress bar by one table and show its document count."""
        nonlocal indexed
        indexed += 1
        # Keep the fraction inside [0, 1]: the bar rejects larger values, and
        # the callback count is not guaranteed here to match total_tables
        # (which may also be zero).
        fraction = indexed / total_tables if total_tables else 1.0
        progress.progress(min(fraction, 1.0))
        status.text(f"Indexed {table_name}: {docs} documents")

    total_docs = chatbot.index_text_data(progress_callback)
    st.session_state.indexed = True
    status.text(f"Total: {total_docs} documents indexed")
def render_schema_explorer():
    """Render the schema explorer: a per-table list and a relationship diagram.

    Does nothing until the chatbot is initialized. The table list lets the
    user tick tables in or out of ``st.session_state.ignored_tables``; the
    diagram tab draws tables and foreign-key edges as a Graphviz chart.
    Errors are shown with ``st.error`` instead of being raised.
    """
    if not st.session_state.initialized:
        return

    with st.expander("📋 Database Schema", expanded=False):
        try:
            schema = st.session_state.chatbot.introspector.introspect()
            ignored = st.session_state.ignored_tables
            tab_list, tab_erd = st.tabs(["📋 Table List", "🕸️ Schema Diagram"])

            with tab_list:
                st.markdown("Uncheck tables to exclude them from the chat context.")
                for table_name, table_info in schema.tables.items():
                    col1, col2 = st.columns([0.05, 0.95])
                    with col1:
                        active = st.checkbox(
                            "Use",
                            value=table_name not in ignored,
                            key=f"use_{table_name}",
                            label_visibility="collapsed",
                            help=f"Include {table_name} in chat analysis"
                        )
                        if not active:
                            ignored.add(table_name)
                        else:
                            ignored.discard(table_name)
                    with col2:
                        with st.container():
                            # Only a missing count is shown as "?"; an empty
                            # table must read "0 rows".
                            row_count = table_info.row_count
                            shown_rows = '?' if row_count is None else row_count
                            st.markdown(f"**{table_name}** ({shown_rows} rows)")
                            cols = []
                            for col in table_info.columns:
                                pk = "🔑" if col.is_primary_key else ""
                                txt = "📝" if col.is_text_type else ""
                                cols.append(f"`{col.name}` {col.data_type} {pk}{txt}")
                            st.caption(" | ".join(cols))
                    st.divider()

            with tab_erd:
                if len(schema.tables) > 50:
                    st.warning("⚠️ Too many tables to visualize effectively (limit: 50).")
                else:
                    try:
                        # Build Graphviz DOT string
                        dot = ['digraph Database {']
                        dot.append(' rankdir=LR;')
                        dot.append(' node [shape=box, style="filled,rounded", fillcolor="#f0f2f6", fontname="Arial", fontsize=10];')
                        dot.append(' edge [fontname="Arial", fontsize=9, color="#666666"];')

                        # Add nodes (tables); ignored tables are greyed out
                        for table_name in schema.tables:
                            if table_name not in ignored:
                                dot.append(f' "{table_name}" [label="{table_name}", fillcolor="#e1effe", color="#1e40af"];')
                            else:
                                dot.append(f' "{table_name}" [label="{table_name} (ignored)", fillcolor="#f3f4f6", color="#9ca3af", fontcolor="#9ca3af"];')

                        # Add edges (relationships)
                        has_edges = False
                        for table_name, table_info in schema.tables.items():
                            for col_name, ref_str in table_info.foreign_keys.items():
                                # ref_str format: "referenced_table.referenced_column"
                                if "." in ref_str:
                                    ref_table = ref_str.split(".")[0]
                                    # Only draw if both tables exist in our schema list
                                    if ref_table in schema.tables:
                                        dot.append(f' "{table_name}" -> "{ref_table}" [label="{col_name}"];')
                                        has_edges = True
                        dot.append('}')

                        graph_code = "\n".join(dot)
                        st.graphviz_chart(graph_code, width="stretch")
                        if not has_edges:
                            st.info("No foreign key relationships detected in the schema metadata.")
                    except Exception as e:
                        st.error(f"Could not render diagram: {e}")
        except Exception as e:
            st.error(f"Error loading schema: {e}")
# Styles for the token-usage boxes shown under assistant messages.
_TOKEN_USAGE_CSS = """
<style>
.token-box {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 12px;
padding: 12px 16px;
color: white;
text-align: center;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
margin: 4px 0;
}
.token-box-input {
background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
box-shadow: 0 4px 15px rgba(17, 153, 142, 0.3);
}
.token-box-output {
background: linear-gradient(135deg, #ee0979 0%, #ff6a00 100%);
box-shadow: 0 4px 15px rgba(238, 9, 121, 0.3);
}
.token-box-total {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
}
.token-label {
font-size: 11px;
text-transform: uppercase;
letter-spacing: 1px;
opacity: 0.9;
margin-bottom: 4px;
}
.token-value {
font-size: 20px;
font-weight: 700;
}
</style>
"""


def _render_token_usage(usage):
    """Show input/output/total token counts in a collapsed expander."""
    total = usage.get('total', 0)
    with st.expander(f"📊 Token Usage ({total:,} total)", expanded=False):
        st.markdown(_TOKEN_USAGE_CSS, unsafe_allow_html=True)
        # (CSS modifier class, label, key in the usage dict) for each box
        boxes = (
            ("token-box-input", "📥 Input Tokens", "input"),
            ("token-box-output", "📤 Output Tokens", "output"),
            ("token-box-total", "📊 Total Tokens", "total"),
        )
        for column, (css_class, label, usage_key) in zip(st.columns(3), boxes):
            with column:
                st.markdown(f"""
<div class="token-box {css_class}">
<div class="token-label">{label}</div>
<div class="token-value">{usage.get(usage_key, 0):,}</div>
</div>
""", unsafe_allow_html=True)


def _render_assistant_details(meta, index):
    """Show token usage, query info, SQL, visualization and CSV export for one reply."""
    # token_usage can be stored as None; only render it when it is present.
    usage = meta.get("token_usage")
    if usage is not None:
        _render_token_usage(usage)

    if meta.get("query_type"):
        # Show query type and execution time on same line
        info_text = f"Query type: {meta['query_type']}"
        if meta.get("execution_time"):
            info_text += f" • ⏱️ {meta['execution_time']:.2f}s"
        st.caption(info_text)

    # SQL Query expander
    if meta.get("sql_query"):
        with st.expander("🛠️ SQL Query & Details"):
            st.code(meta["sql_query"], language="sql")
            # Visualizations and CSV export, only when there are results
            if meta.get("sql_results"):
                render_visualization(meta["sql_results"], f"viz_{index}")
                csv_data = export_results_to_csv(meta["sql_results"])
                if csv_data:
                    st.download_button(
                        label="📊 Export to CSV",
                        data=csv_data,
                        file_name=f"query_results_{index}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                        mime="text/csv",
                        key=f"csv_export_{index}",
                        help="Download query results as CSV file"
                    )


def render_chat_interface():
    """Render the main chat interface.

    Draws the title, the schema explorer and the stored conversation (with a
    favorite toggle and details for assistant replies), then handles new
    input: the prompt is sent to the chatbot, the reply and its metadata are
    appended to ``st.session_state.messages`` and to the memory object, and
    the script is rerun. Errors while answering are shown with ``st.error``.
    """
    st.title("🤖 OnceDataBot")
    st.caption("Schema-agnostic chatbot • MySQL | PostgreSQL • Powered by Groq (FREE!)")

    # Schema explorer
    render_schema_explorer()

    # Chat container
    chat_container = st.container()
    with chat_container:
        # Display messages
        for i, msg in enumerate(st.session_state.messages):
            is_assistant = msg["role"] == "assistant"
            with st.chat_message(msg["role"]):
                # Create columns for message and favorite button
                msg_col, fav_col = st.columns([0.95, 0.05])
                with msg_col:
                    st.markdown(msg["content"])
                with fav_col:
                    # Favorite button for assistant messages
                    if is_assistant:
                        is_favorited = i in st.session_state.favorites
                        if st.button(
                            "⭐" if is_favorited else "☆",
                            key=f"fav_{i}",
                            help="Click to favorite/unfavorite this response"
                        ):
                            if is_favorited:
                                st.session_state.favorites.remove(i)
                            else:
                                st.session_state.favorites.append(i)
                            st.rerun()

                # Show metadata for assistant messages. A "metadata" key whose
                # value is None is treated the same as no metadata.
                meta = msg.get("metadata") if is_assistant else None
                if meta:
                    _render_assistant_details(meta, i)

    # Chat input
    if prompt := st.chat_input("Ask about your data..."):
        if not st.session_state.initialized:
            st.error("Please connect to a database first!")
            return

        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get response
        with st.spinner("Thinking..."):
            try:
                # Add memory interaction
                if st.session_state.memory:
                    st.session_state.memory.add_message("user", prompt)

                # Track execution time
                start_time = time.time()
                response = st.session_state.chatbot.chat(
                    prompt,
                    st.session_state.memory,
                    ignored_tables=list(st.session_state.ignored_tables),
                    language=st.session_state.response_language
                )
                execution_time = time.time() - start_time

                # Create metadata dict
                metadata = {
                    "query_type": response.query_type,
                    "sql_query": response.sql_query,
                    "sql_results": response.sql_results,
                    "token_usage": response.token_usage,
                    "execution_time": execution_time
                }

                # Save to session state
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response.answer,
                    "metadata": metadata
                })
                # Set flag to auto-read the latest response
                st.session_state.auto_read_latest = True

                # Save to active memory
                if st.session_state.memory:
                    st.session_state.memory.add_message("assistant", response.answer)
                st.rerun()
            except Exception as e:
                st.error(f"An error occurred: {e}")
                import traceback
                st.error(traceback.format_exc())
def main():
    """Run the app: set up state, auto-connect once, then draw the UI."""
    init_session_state()

    # Try the environment database a single time per session; the marker key
    # stops later reruns from retrying.
    first_load = "auto_connect_attempted" not in st.session_state
    if first_load:
        st.session_state.auto_connect_attempted = True
        if st.session_state.db_source == "environment" and initialize_chatbot():
            st.toast("✅ Auto-connected to database!")

    render_sidebar()
    render_chat_interface()
# Script entry point: run the app when this file is executed directly.
if __name__ == "__main__":
    main()