"""
Chatbot Core - Main orchestrator for the schema-agnostic database chatbot.
Combines all components:
- Schema introspection
- Query routing
- RAG retrieval
- SQL generation & execution
- Response generation
"""
import logging
import re
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Tuple

from database import get_db, get_schema, get_introspector
from llm import create_llm_client, LLMClient
from memory import ChatMemory, EnhancedChatMemory, create_memory
from rag import get_rag_engine
from router import get_query_router, QueryType
from sql import get_sql_generator, get_sql_validator
| logger = logging.getLogger(__name__) | |
@dataclass
class ChatResponse:
    """Response from the chatbot.

    Carries the generated answer plus diagnostic metadata (query type,
    retrieval sources, the executed SQL and its rows, any error, and the
    LLM token usage accumulated while producing the answer).

    NOTE: the @dataclass decorator is required — all call sites construct
    this class with keyword arguments, and __post_init__ only runs for
    dataclasses.
    """
    answer: str
    query_type: str
    sources: Optional[List[Dict[str, Any]]] = None
    sql_query: Optional[str] = None
    sql_results: Optional[List[Dict]] = None
    error: Optional[str] = None
    token_usage: Optional[Dict[str, int]] = None

    def __post_init__(self):
        # Normalize None into fresh containers so every instance exposes a
        # usable list/dict (and mutable defaults are never shared).
        if self.sources is None:
            self.sources = []
        if self.token_usage is None:
            self.token_usage = {"input": 0, "output": 0, "total": 0}
class DatabaseChatbot:
    """Main chatbot class orchestrating all components.

    Per-query pipeline (see chat()):
      1. Introspect (or reuse) the database schema and build its context string.
      2. Intercept explicit memory commands ("remember that ...", "save this").
      3. Route the query to RAG, SQL, hybrid, or general conversation handling.
      4. Generate the final answer with the LLM, accumulating token usage.
    """

    # Prompt template for the final answer-generation LLM call.
    # This text is runtime behavior — edit with care.
    RESPONSE_PROMPT = """You are a helpful database assistant. Answer the user's question based on the provided context.
IMPORTANT: Use the conversation history to understand follow-up questions. If the user refers to "it", "that", "the product", etc., look at the previous messages to understand what they're referring to.
{context}
USER QUESTION: {question}
INSTRUCTIONS:
- Answer ONLY based on the provided context AND conversation history
- Do NOT use outside knowledge, general assumptions, or hallucinate facts
- If the context doesn't contain the answer, explicitly state that the information is not available in the database
- Resolve pronouns using previous messages
- Be concise but complete
- Format data nicely
{language_instruction}
INTERACTION GUIDELINES:
- If the SQL results show a list (e.g., top products) and hit the limit (5, 10, or 50), MENTION this and ASK the user if they want to see more or a specific number.
Example: "Here are the top 5 products... Would you like to see the top 10?"
- If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
- If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
- IF SUBJECTIVE INFERENCE WAS USED (e.g., inferred "summer" = sandals), EXPLAIN THIS to the user.
Example: "I found these products that match 'summer' (based on being Sandals or breathability)..."
YOUR RESPONSE:"""

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """Wire up database, introspection, RAG, SQL, and routing components.

        Args:
            llm_client: Optional pre-configured LLM client; may also be
                supplied later via set_llm_client().
        """
        self.db = get_db()
        self.introspector = get_introspector()
        self.rag_engine = get_rag_engine()
        # Pass database type to SQL generator for dialect-specific SQL
        db_type = self.db.db_type.value
        self.sql_generator = get_sql_generator(db_type)
        self.sql_validator = get_sql_validator()
        self.router = get_query_router()
        self.llm_client = llm_client
        self._schema_initialized = False
        self._rag_initialized = False

    def set_llm_client(self, llm_client: LLMClient):
        """Configure the LLM client for this chatbot and its sub-components."""
        self.llm_client = llm_client
        # The generator and router make their own LLM calls; keep them in sync.
        self.sql_generator.set_llm_client(llm_client)
        self.router.set_llm_client(llm_client)

    def _get_language_instruction(self, language: str) -> str:
        """Generate language instruction for the response prompt.

        Args:
            language: The target language name (e.g., 'Hindi', 'Spanish'),
                possibly in display form like 'हिन्दी (Hindi)'.

        Returns:
            A formatted instruction string for the LLM (empty for English).
        """
        if language == "English":
            return ""  # No special instruction needed for English
        # Extract the base language name from display name,
        # e.g., "हिन्दी (Hindi)" -> "Hindi"
        base_language = language
        if "(" in language and ")" in language:
            base_language = language.split("(")[1].rstrip(")")
        return f"\n- **IMPORTANT: Respond ENTIRELY in {base_language}**. Translate your response to {base_language}. Keep technical terms (like table names, column names, SQL) as-is, but explain everything else in {base_language}."

    def initialize(self) -> Tuple[bool, str]:
        """Initialize the chatbot by introspecting the database.

        Returns:
            (success, message) — message is human-readable in both cases.
        """
        try:
            # Test connection before doing any work
            success, msg = self.db.test_connection()
            if not success:
                return False, f"Database connection failed: {msg}"
            # Introspect schema (force a fresh read on every init)
            schema = self.introspector.introspect(force_refresh=True)
            # Configure SQL validator with discovered tables
            self.sql_validator.set_allowed_tables(schema.table_names)
            self._schema_initialized = True
            return True, f"Initialized with {len(schema.tables)} tables"
        except Exception as e:
            # logger.exception records the traceback alongside the message
            logger.exception(f"Initialization failed: {e}")
            return False, str(e)

    def index_text_data(self, progress_callback=None) -> int:
        """Index all text data for RAG.

        Args:
            progress_callback: Optional callable(table_name, doc_count)
                invoked after each table is indexed.

        Returns:
            Total number of documents indexed across all tables.

        Raises:
            RuntimeError: If initialize() has not been called successfully.
        """
        if not self._schema_initialized:
            raise RuntimeError("Chatbot not initialized. Call initialize() first.")
        # Use the instance's introspector which might be patched for custom DB
        schema = self.introspector.introspect()
        total_docs = 0
        for table_name, table_info in schema.tables.items():
            text_cols = [c.name for c in table_info.text_columns]
            if not text_cols:
                continue  # nothing worth embedding in this table
            pk = table_info.primary_keys[0] if table_info.primary_keys else None
            cols_to_select = text_cols + ([pk] if pk else [])
            # Quote table name based on DB-specific rules to handle case
            # sensitivity and special chars
            if self.db.db_type.value == "mysql":
                quoted_table = f"`{table_name}`"
            else:
                quoted_table = f'"{table_name}"'
            # Build the query once (the original built it twice identically).
            query = f"SELECT {', '.join(cols_to_select)} FROM {quoted_table} LIMIT 1000"
            try:
                rows = self.db.execute_query(query)
                docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
                total_docs += docs
                if progress_callback:
                    progress_callback(table_name, docs)
            except Exception as e:
                # Fallback mechanism for PostgreSQL if table not found
                # (often due to schema search-path issues)
                if self.db.db_type.value == "postgresql" and "UndefinedTable" in str(e):
                    try:
                        logger.warning(f"Initial query failed for {table_name}, trying 'public' schema prefix...")
                        fallback_query = f"SELECT {', '.join(cols_to_select)} FROM public.\"{table_name}\" LIMIT 1000"
                        rows = self.db.execute_query(fallback_query)
                        docs = self.rag_engine.index_table(table_name, rows, text_cols, pk)
                        total_docs += docs
                        if progress_callback:
                            progress_callback(table_name, docs)
                        continue  # Success with fallback
                    except Exception as e2:
                        logger.error(f"Fallback query also failed for {table_name}: {e2}")
                # Best-effort indexing: skip the table and keep going.
                logger.warning(f"Failed to index {table_name}: {e}")
        self.rag_engine.save()
        self._rag_initialized = True
        return total_docs

    def chat(self, query: str, memory: Optional[ChatMemory] = None, ignored_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Process a user query and return a response.

        Args:
            query: The user's question
            memory: Optional chat memory for context
            ignored_tables: Tables to exclude from queries
            language: Preferred response language (default: English)

        Returns:
            A ChatResponse; errors are reported via query_type="error"
            rather than raised.
        """
        if not self._schema_initialized:
            return ChatResponse(answer="Chatbot not initialized.", query_type="error",
                                error="Call initialize() first")
        if not self.llm_client:
            return ChatResponse(answer="LLM not configured.", query_type="error",
                                error="Configure LLM client first")
        try:
            # Use instance introspector (cached after initialize())
            schema = self.introspector.introspect()
            schema_context = schema.to_context_string(ignored_tables=ignored_tables)
            # Calculate allowed tables for RAG and Validator
            allowed_tables = None
            if ignored_tables:
                allowed_tables = [t for t in schema.table_names if t not in ignored_tables]
                # Update validator to only allow these tables
                self.sql_validator.set_allowed_tables(allowed_tables)
            else:
                self.sql_validator.set_allowed_tables(schema.table_names)
            # Check for memory commands using regex for flexibility.
            # Captures patterns like "save this", "remember that my size is 7",
            # "please memorize my name".
            save_pattern = re.compile(r"(?:please\s+)?(?:save|remember|memorize|record|store)\s+(?:this|that|to\s+(?:main\s+)?memory)?\s*(?:that)?\s*:?\s*(.*)", re.IGNORECASE)
            match = save_pattern.search(query.strip())
            # Additional check for colloquial "save to memory" or "memory: X" phrasings
            is_memory_phrase = any(phrase in query.lower() for phrase in ["save to memory", "remember this", "memorize this", "save my", "remember my"])
            is_command = bool(match) or is_memory_phrase
            if is_command and memory:
                # Prioritize explicit content from the regex match
                content_to_save = match.group(1).strip() if (match and match.group(1)) else ""
                if not content_to_save:
                    # Regex was too strict but is_memory_phrase matched,
                    # e.g. "my shoe size is 7, save to memory".
                    # FIX: locate the keyword on a lowercased copy but slice
                    # the ORIGINAL string, so the saved fact keeps the user's
                    # casing (the old code stored query.lower() fragments).
                    lowered = query.lower()
                    if "save" in lowered:
                        content_to_save = query[:lowered.index("save")].strip().strip(",").strip()
                    elif "remember" in lowered:
                        content_to_save = query[lowered.index("remember") + len("remember"):].strip()
                # If we have content, save it
                if content_to_save:
                    is_ok, msg = memory.save_permanent_context(content_to_save)
                    if is_ok:
                        return ChatResponse(answer=f"💾 I've saved to your permanent memory: '{content_to_save}'", query_type="memory")
                    else:
                        return ChatResponse(answer=f"❌ Failed to save to permanent memory: {msg}", query_type="memory")
                # If no content (e.g. "Save this"), save the previous conversation turn
                elif len(memory.messages) >= 2:
                    # Grab the most recent assistant/user pair before the command itself
                    last_ai_msg = next((m for m in reversed(memory.messages[:-1]) if m.role == "assistant"), None)
                    last_user_msg = next((m for m in reversed(memory.messages[:-1]) if m.role == "user"), None)
                    if last_ai_msg and last_user_msg:
                        context_str = f"User: {last_user_msg.content} | AI: {last_ai_msg.content}"
                        is_ok, msg = memory.save_permanent_context(context_str)
                        if is_ok:
                            return ChatResponse(answer="💾 I've saved our last exchange to your permanent memory.", query_type="memory")
                        else:
                            return ChatResponse(answer=f"❌ Failed to save to permanent memory: {msg}", query_type="memory")
                    else:
                        return ChatResponse(answer="⚠️ I couldn't find a clear previous exchange to save. Try saying 'Remember that [fact]'.", query_type="memory")
                else:
                    return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")
            # Get chat history for context
            history = memory.get_context_messages(5) if memory else []
            # Route the query
            routing = self.router.route(query, schema_context, history)
            # Initial usage from routing
            routing_usage = routing.token_usage or {"input": 0, "output": 0, "total": 0}
            # Dispatch on the routed query type
            if routing.query_type == QueryType.RAG:
                response = self._handle_rag(query, history, allowed_tables, language)
            elif routing.query_type == QueryType.SQL:
                response = self._handle_sql(query, schema_context, history, allowed_tables, language)
            elif routing.query_type == QueryType.HYBRID:
                response = self._handle_hybrid(query, schema_context, history, allowed_tables, language)
            else:
                response = self._handle_general(query, history, language)
            # Add routing tokens to the handler's total
            if response.token_usage:
                response.token_usage["input"] += routing_usage.get("input", 0)
                response.token_usage["output"] += routing_usage.get("output", 0)
                response.token_usage["total"] += routing_usage.get("total", 0)
            else:
                response.token_usage = routing_usage
            return response
        except Exception as e:
            # Top-level boundary: log with traceback, surface error to caller
            logger.exception(f"Chat error: {e}")
            return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))

    def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Handle RAG-based (semantic search) query."""
        # Check if we have any indexed data
        if self.rag_engine.document_count == 0:
            # Routing tokens were already consumed upstream; they are merged
            # in chat(), so return zeroed usage here.
            usage = {"input": 0, "output": 0, "total": 0}
            return ChatResponse(
                answer="⚠️ **I can't answer this yet.**\n\nThis looks like a semantic question (searching for meaning/concepts), but you haven't **indexed the text data** yet.\n\nPlease click the **'📚 Index Text Data'** button in the sidebar to enable this functionality.",
                query_type="error",
                error="RAG index is empty",
                token_usage=usage
            )
        context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
        # Get language instruction
        language_instruction = self._get_language_instruction(language)
        prompt = self.RESPONSE_PROMPT.format(
            context=f"RELEVANT DATA:\n{context}",
            question=query,
            language_instruction=language_instruction
        )
        messages = self._construct_messages(
            "You are a helpful database assistant.",
            history,
            prompt
        )
        response = self.llm_client.chat(messages)
        usage = {
            "input": response.input_tokens,
            "output": response.output_tokens,
            "total": response.total_tokens
        }
        return ChatResponse(answer=response.content, query_type="rag",
                            sources=[{"type": "semantic_search", "context": context[:500]}],
                            token_usage=usage)

    def _handle_sql(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Handle SQL-based query: generate, validate, execute, summarize."""
        sql, gen_response = self.sql_generator.generate(query, schema_context, history)
        # Initial usage from SQL generation
        total_usage = {
            "input": gen_response.input_tokens,
            "output": gen_response.output_tokens,
            "total": gen_response.total_tokens
        }
        # Validate SQL before touching the database
        is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
        if not is_valid:
            return ChatResponse(answer=f"Could not generate safe query: {msg}",
                                query_type="sql", error=msg, token_usage=total_usage)
        # Execute query
        try:
            results = self.db.execute_query(sanitized_sql)
        except Exception as e:
            return ChatResponse(answer=f"Query execution failed: {e}",
                                query_type="sql", sql_query=sanitized_sql, error=str(e),
                                token_usage=total_usage)
        # SMART FALLBACK: If SQL returns nothing, it might be a semantic issue
        # (e.g. wrong column) — try RAG as a fallback.
        if not results:
            logger.info(f"SQL returned no results for query: '{query}'. Falling back to RAG.")
            rag_response = self._handle_rag(query, history, allowed_tables, language)
            # Combine the info for the user
            rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the product descriptions:\n\n{rag_response.answer}"
            rag_response.query_type = "hybrid_fallback"
            rag_response.sql_query = sanitized_sql
            # Add usage from SQL gen to RAG usage
            if rag_response.token_usage:
                rag_response.token_usage["input"] += total_usage["input"]
                rag_response.token_usage["output"] += total_usage["output"]
                rag_response.token_usage["total"] += total_usage["total"]
            else:
                rag_response.token_usage = total_usage
            return rag_response
        # Generate response with language instruction
        language_instruction = self._get_language_instruction(language)
        context = f"SQL QUERY:\n{sanitized_sql}\n\nRESULTS:\n{self._format_results(results)}"
        prompt = self.RESPONSE_PROMPT.format(
            context=context,
            question=query,
            language_instruction=language_instruction
        )
        messages = self._construct_messages(
            "You are a helpful database assistant.",
            history,
            prompt
        )
        final_response = self.llm_client.chat(messages)
        # Add usage from final response
        total_usage["input"] += final_response.input_tokens
        total_usage["output"] += final_response.output_tokens
        total_usage["total"] += final_response.total_tokens
        return ChatResponse(answer=final_response.content, query_type="sql",
                            sql_query=sanitized_sql, sql_results=results[:10],
                            token_usage=total_usage)

    def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Handle hybrid RAG + SQL query (SQL part is best-effort)."""
        # Get RAG context
        rag_context = self.rag_engine.get_context(query, top_k=3, table_filter=allowed_tables)
        # Try SQL as well
        sql_context = ""
        sql_query = None
        total_usage = {"input": 0, "output": 0, "total": 0}
        try:
            sql, gen_response = self.sql_generator.generate(query, schema_context, history)
            # Accumulate usage
            total_usage["input"] += gen_response.input_tokens
            total_usage["output"] += gen_response.output_tokens
            total_usage["total"] += gen_response.total_tokens
            is_valid, _, sanitized_sql = self.sql_validator.validate(sql)
            if is_valid:
                results = self.db.execute_query(sanitized_sql)
                sql_context = f"\nSQL RESULTS:\n{self._format_results(results)}"
                sql_query = sanitized_sql
        except Exception as e:
            # SQL failure is non-fatal here — RAG context still produces an answer
            logger.debug(f"SQL part of hybrid failed: {e}")
        # Get language instruction
        language_instruction = self._get_language_instruction(language)
        context = f"SEMANTIC SEARCH RESULTS:\n{rag_context}{sql_context}"
        prompt = self.RESPONSE_PROMPT.format(
            context=context,
            question=query,
            language_instruction=language_instruction
        )
        messages = self._construct_messages(
            "You are a helpful database assistant.",
            history,
            prompt
        )
        final_response = self.llm_client.chat(messages)
        # Add final usage
        total_usage["input"] += final_response.input_tokens
        total_usage["output"] += final_response.output_tokens
        total_usage["total"] += final_response.total_tokens
        return ChatResponse(answer=final_response.content, query_type="hybrid", sql_query=sql_query, token_usage=total_usage)

    def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
        """Construct message list, merging system messages from history.

        System-role entries in history (e.g. injected permanent memory) are
        folded into the single leading system prompt; other roles pass through
        in order, followed by the new user message.
        """
        additional_context = ""
        filtered_history = []
        for msg in history:
            if msg.get("role") == "system":
                additional_context += f"\n\n{msg.get('content')}"
            else:
                filtered_history.append(msg)
        full_system_prompt = f"{system_instruction}{additional_context}"
        messages = [{"role": "system", "content": full_system_prompt}]
        messages.extend(filtered_history)
        messages.append({"role": "user", "content": user_content})
        return messages

    def _handle_general(self, query: str, history: List[Dict], language: str = "English") -> ChatResponse:
        """Handle general conversation (no DB retrieval)."""
        # Get language instruction (unused directly, mirrors other handlers)
        language_instruction = self._get_language_instruction(language)
        # Build language suffix for system prompt
        language_suffix = ""
        if language != "English":
            base_language = language
            if "(" in language and ")" in language:
                base_language = language.split("(")[1].rstrip(")")
            language_suffix = f"\n- Respond entirely in {base_language}."
        # Use a strict prompt for general conversation as well to prevent hallucinations
        strict_system_prompt = (
            "You are a helpful database assistant.\n"
            "INSTRUCTIONS:\n"
            "- Answer ONLY based on the conversation history and any context provided within it.\n"
            "- Do NOT use outside knowledge, general assumptions, or hallucinate facts.\n"
            "- If the answer is not in the history or context, state that you don't have that information.\n"
            f"- Be concise.{language_suffix}"
        )
        messages = self._construct_messages(
            strict_system_prompt,
            history,
            query
        )
        response = self.llm_client.chat(messages)
        usage = {
            "input": response.input_tokens,
            "output": response.output_tokens,
            "total": response.total_tokens
        }
        return ChatResponse(answer=response.content, query_type="general", token_usage=usage)

    def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
        """Format SQL results as a simple pipe-separated text table.

        Args:
            results: Row dicts as returned by the database layer.
            max_rows: Maximum number of rows to render; the remainder is
                summarized as "... and N more rows".
        """
        if not results:
            return "No results found."
        rows = results[:max_rows]
        lines = []
        # Header from the first row's keys
        headers = list(rows[0].keys())
        lines.append(" | ".join(headers))
        lines.append("-" * len(lines[0]))
        # Rows (cell values truncated to 50 chars)
        for row in rows:
            values = [str(v)[:50] for v in row.values()]
            lines.append(" | ".join(values))
        if len(results) > max_rows:
            lines.append(f"... and {len(results) - max_rows} more rows")
        return "\n".join(lines)

    def get_schema_summary(self) -> str:
        """Get a human-readable summary of the database schema."""
        if not self._schema_initialized:
            return "Schema not loaded."
        return self.introspector.introspect().to_context_string()
def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
    """Factory: build a DatabaseChatbot, optionally with a pre-set LLM client."""
    return DatabaseChatbot(llm_client=llm_client)