DB_Chatbot / chatbot.py
Vanshcc's picture
Upload chatbot.py
7c3bad4 verified
"""
Chatbot Core - Main orchestrator for the schema-agnostic database chatbot.
Combines all components:
- Schema introspection
- Query routing
- RAG retrieval
- SQL generation & execution
- Response generation
"""
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from database import get_db, get_schema, get_introspector
from rag import get_rag_engine
from sql import get_sql_generator, get_sql_validator
from llm import create_llm_client, LLMClient
from router import get_query_router, QueryType
from memory import ChatMemory, EnhancedChatMemory, create_memory
logger = logging.getLogger(__name__)
@dataclass
class ChatResponse:
    """Response from the chatbot.

    ``sources`` and ``token_usage`` may be omitted (or passed as ``None``);
    ``__post_init__`` replaces them with a fresh empty list and a fresh
    zeroed usage dict, so readers never need a ``None`` check for them.
    """

    # Text shown to the user.
    answer: str
    # Route label set by DatabaseChatbot, e.g. "sql", "rag", "hybrid",
    # "hybrid_fallback", "general", "memory" or "error".
    query_type: str
    # Provenance entries; normalised to [] when None.
    sources: Optional[List[Dict[str, Any]]] = None
    # The sanitized SQL that was executed, if any.
    sql_query: Optional[str] = None
    # Leading rows of the SQL result, if any.
    sql_results: Optional[List[Dict]] = None
    # Error description when the request failed.
    error: Optional[str] = None
    # Token counters with keys "input", "output", "total"; zeroed when None.
    token_usage: Optional[Dict[str, int]] = None

    def __post_init__(self):
        # Build new containers per instance so responses never share state.
        if self.sources is None:
            self.sources = []
        if self.token_usage is None:
            self.token_usage = {"input": 0, "output": 0, "total": 0}
class DatabaseChatbot:
    """Main chatbot class orchestrating all components.

    Combines schema introspection, query routing, RAG retrieval, SQL
    generation / validation / execution and LLM answer generation behind a
    single ``chat()`` entry point.
    """

    RESPONSE_PROMPT = """You are a helpful database assistant. Answer the user's question based on the provided context.
IMPORTANT: Use the conversation history to understand follow-up questions. If the user refers to "it", "that", "the product", etc., look at the previous messages to understand what they're referring to.
{context}
USER QUESTION: {question}
INSTRUCTIONS:
- Answer ONLY based on the provided context AND conversation history
- Do NOT use outside knowledge, general assumptions, or hallucinate facts
- If the context doesn't contain the answer, explicitly state that the information is not available in the database
- Resolve pronouns using previous messages
- Be concise but complete
- Format data nicely
{language_instruction}
INTERACTION GUIDELINES:
- If the SQL results show a list (e.g., top products) and hit the limit (5, 10, or 50), MENTION this and ASK the user if they want to see more or a specific number.
Example: "Here are the top 5 products... Would you like to see the top 10?"
- If the user's question was broad (e.g., "Show me products") and you're showing a limited set, ASK if they want to filter by a specific attribute (e.g., "Would you like to filter by category or price?").
- If the answer is "0 results" for a "top/best" query, suggest looking at the data generally.
- IF SUBJECTIVE INFERENCE WAS USED (e.g., inferred "summer" = sandals), EXPLAIN THIS to the user.
Example: "I found these products that match 'summer' (based on being Sandals or breathability)..."
YOUR RESPONSE:"""

    # Explicit memory commands must START the message ("save this",
    # "remember that my size is 7", "please memorize my name", "remember: X").
    # Anchoring (plus word boundaries) keeps ordinary data questions that
    # merely contain these verbs -- "Which store has the most sales",
    # "show the record count", "restore ..." -- from being swallowed as memory
    # commands. "record"/"store" are common nouns in database questions, so
    # they only count when followed by an explicit object marker.
    _MEMORY_COMMAND_RE = re.compile(
        r"^\s*(?:please\s+)?"
        r"(?:save|remember|memorize"
        r"|(?:record|store)(?=\s*(?::|(?:this|that|my)\b|to\s+(?:main\s+)?memory\b)))\b"
        r"\s*(?:(?:this|that)\b|to\s+(?:main\s+)?memory\b)?"
        r"\s*(?:that\b)?\s*:?\s*(.*)$",
        re.IGNORECASE | re.DOTALL,
    )

    # Colloquial phrasings that mark a memory command anywhere in the message,
    # e.g. "my shoe size is 7, save to memory".
    _MEMORY_PHRASES = ("save to memory", "remember this", "memorize this", "save my", "remember my")

    # Locates the verb inside a colloquial memory phrase.
    _MEMORY_TRIGGER_RE = re.compile(r"\b(?:save|remember|memorize)\b", re.IGNORECASE)

    # Filler that can follow the verb without carrying content
    # ("... remember this", "... save to memory.").
    _MEMORY_FILLER_RE = re.compile(
        r"^[\s,:]*(?:(?:this|that)\b|(?:to|in)\s+(?:main\s+|permanent\s+)?memory\b)?[\s,:.!]*",
        re.IGNORECASE,
    )

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """Wire up the database, introspector, RAG engine, SQL tools and router.

        Args:
            llm_client: Optional LLM client. When given, it is propagated to the
                SQL generator and router exactly like ``set_llm_client`` does.
        """
        self.db = get_db()
        self.introspector = get_introspector()
        self.rag_engine = get_rag_engine()
        # Pass database type to SQL generator for dialect-specific SQL
        db_type = self.db.db_type.value
        self.sql_generator = get_sql_generator(db_type)
        self.sql_validator = get_sql_validator()
        self.router = get_query_router()
        self.llm_client = llm_client
        # Without this, a client passed to the constructor was only stored on
        # self.llm_client and never handed to the SQL generator / router.
        if llm_client is not None:
            self.set_llm_client(llm_client)
        self._schema_initialized = False
        self._rag_initialized = False

    def set_llm_client(self, llm_client: LLMClient):
        """Configure the LLM client on the chatbot, SQL generator and router."""
        self.llm_client = llm_client
        self.sql_generator.set_llm_client(llm_client)
        self.router.set_llm_client(llm_client)

    @staticmethod
    def _base_language(language: str) -> str:
        """Extract the base language name from a display name.

        e.g. "हिन्दी (Hindi)" -> "Hindi"; names without parentheses are returned unchanged.
        """
        if "(" in language and ")" in language:
            return language.split("(")[1].rstrip(")")
        return language

    def _get_language_instruction(self, language: str) -> str:
        """Generate language instruction for the response prompt.

        Args:
            language: The target language name (e.g., 'Hindi', 'Spanish')

        Returns:
            A formatted instruction string for the LLM ("" for English).
        """
        if language == "English":
            return ""  # No special instruction needed for English
        base_language = self._base_language(language)
        return f"\n- **IMPORTANT: Respond ENTIRELY in {base_language}**. Translate your response to {base_language}. Keep technical terms (like table names, column names, SQL) as-is, but explain everything else in {base_language}."

    def initialize(self) -> Tuple[bool, str]:
        """Initialize the chatbot by introspecting the database.

        Tests the connection, force-refreshes the schema and restricts the SQL
        validator to the discovered tables.

        Returns:
            ``(True, message)`` on success, ``(False, reason)`` otherwise.
            Exceptions are logged (with traceback) rather than raised.
        """
        try:
            success, msg = self.db.test_connection()
            if not success:
                return False, f"Database connection failed: {msg}"
            schema = self.introspector.introspect(force_refresh=True)
            # Configure SQL validator with discovered tables
            self.sql_validator.set_allowed_tables(schema.table_names)
            self._schema_initialized = True
            return True, f"Initialized with {len(schema.tables)} tables"
        except Exception as e:
            logger.exception("Initialization failed: %s", e)
            return False, str(e)

    @staticmethod
    def _quote_identifier(name: str, db_type: str) -> str:
        """Quote a table/column identifier for *db_type*, escaping embedded quotes."""
        if db_type == "mysql":
            return "`" + name.replace("`", "``") + "`"
        return '"' + name.replace('"', '""') + '"'

    def _fetch_and_index(self, table_name: str, query: str, text_cols: List[str], pk: Optional[str]) -> int:
        """Run *query* and hand the rows to the RAG engine; return what index_table reports."""
        rows = self.db.execute_query(query)
        return self.rag_engine.index_table(table_name, rows, text_cols, pk)

    def index_text_data(self, progress_callback=None) -> int:
        """Index the text columns of every table for RAG.

        Reads up to 1000 rows per table (text columns plus the first primary key
        column when there is one). Tables that fail to load are logged and skipped.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(table_name, docs)`` after each indexed table.

        Returns:
            The summed count reported by the RAG engine across all tables.

        Raises:
            RuntimeError: If ``initialize()`` has not succeeded yet.
        """
        if not self._schema_initialized:
            raise RuntimeError("Chatbot not initialized. Call initialize() first.")
        # Use the instance's introspector which might be patched for custom DB
        schema = self.introspector.introspect()
        db_type = self.db.db_type.value
        total_docs = 0
        for table_name, table_info in schema.tables.items():
            text_cols = [c.name for c in table_info.text_columns]
            if not text_cols:
                continue
            pk = table_info.primary_keys[0] if table_info.primary_keys else None
            cols_to_select = text_cols + ([pk] if pk and pk not in text_cols else [])
            # Quote columns as well as the table: introspected names may be
            # case-sensitive or contain special characters.
            column_list = ", ".join(self._quote_identifier(c, db_type) for c in cols_to_select)
            quoted_table = self._quote_identifier(table_name, db_type)
            query = f"SELECT {column_list} FROM {quoted_table} LIMIT 1000"
            try:
                docs = self._fetch_and_index(table_name, query, text_cols, pk)
            except Exception as e:
                docs = None
                # Fallback for PostgreSQL if table not found (often due to search_path/schema issues)
                error_text = f"{type(e).__name__}: {e}"
                if db_type == "postgresql" and "UndefinedTable" in error_text:
                    logger.warning("Initial query failed for %s, trying 'public' schema prefix...", table_name)
                    fallback_query = f"SELECT {column_list} FROM public.{quoted_table} LIMIT 1000"
                    try:
                        docs = self._fetch_and_index(table_name, fallback_query, text_cols, pk)
                    except Exception as e2:
                        logger.error("Fallback query also failed for %s: %s", table_name, e2)
                if docs is None:
                    logger.warning("Failed to index %s: %s", table_name, e)
                    continue
            total_docs += docs
            if progress_callback:
                progress_callback(table_name, docs)
        self.rag_engine.save()
        self._rag_initialized = True
        return total_docs

    @staticmethod
    def _empty_usage() -> Dict[str, int]:
        """Return a fresh zeroed token-usage dict."""
        return {"input": 0, "output": 0, "total": 0}

    @staticmethod
    def _usage_from(llm_response) -> Dict[str, int]:
        """Build a token-usage dict from an LLM response's token counters."""
        return {
            "input": llm_response.input_tokens,
            "output": llm_response.output_tokens,
            "total": llm_response.total_tokens,
        }

    @staticmethod
    def _accumulate_usage(total: Dict[str, int], extra: Optional[Dict[str, int]]) -> Dict[str, int]:
        """Add *extra*'s counters into *total* in place and return *total*."""
        if extra:
            for key in ("input", "output", "total"):
                total[key] = total.get(key, 0) + extra.get(key, 0)
        return total

    @classmethod
    def _parse_memory_command(cls, query: str) -> Optional[str]:
        """Detect a "save/remember ..." memory command.

        Returns:
            ``None`` if *query* is not a memory command; ``""`` if it is one
            without explicit content (meaning: save the previous exchange);
            otherwise the text to store, with the user's original casing.
        """
        text = query.strip()
        # Questions ("Do you remember my size?") ask for information; they are
        # not instructions to store something.
        if not text or text.endswith("?"):
            return None
        match = cls._MEMORY_COMMAND_RE.match(text)
        if match:
            # Command at the start of the message: whatever follows is the content.
            return match.group(1).strip()
        if not any(phrase in text.lower() for phrase in cls._MEMORY_PHRASES):
            return None
        # Phrase embedded later in the message, e.g. "my shoe size is 7, save to memory".
        trigger = cls._MEMORY_TRIGGER_RE.search(text)
        if trigger is None:
            return ""
        after = cls._MEMORY_FILLER_RE.sub("", text[trigger.end():], count=1).strip()
        before = text[:trigger.start()].strip().strip(",").strip()
        return after or before

    def _handle_memory_command(self, content: str, memory: ChatMemory) -> ChatResponse:
        """Save *content* (or, if empty, the previous exchange) to permanent memory."""
        if content:
            is_ok, msg = memory.save_permanent_context(content)
            if is_ok:
                return ChatResponse(answer=f"💾 I've saved to your permanent memory: '{content}'", query_type="memory")
            return ChatResponse(answer=f"❌ Failed to save to permanent memory: {msg}", query_type="memory")

        if len(memory.messages) < 2:
            return ChatResponse(answer="⚠️ Nothing previous to save. Tell me something to remember first!", query_type="memory")

        # Skip the newest message when looking for the previous exchange.
        earlier = memory.messages[:-1]
        last_ai_msg = next((m for m in reversed(earlier) if m.role == "assistant"), None)
        last_user_msg = next((m for m in reversed(earlier) if m.role == "user"), None)
        if not (last_ai_msg and last_user_msg):
            return ChatResponse(answer="⚠️ I couldn't find a clear previous exchange to save. Try saying 'Remember that [fact]'.", query_type="memory")

        context_str = f"User: {last_user_msg.content} | AI: {last_ai_msg.content}"
        is_ok, msg = memory.save_permanent_context(context_str)
        if is_ok:
            return ChatResponse(answer="💾 I've saved our last exchange to your permanent memory.", query_type="memory")
        return ChatResponse(answer=f"❌ Failed to save to permanent memory: {msg}", query_type="memory")

    def chat(self, query: str, memory: Optional[ChatMemory] = None, ignored_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Process a user query and return a response.

        Memory commands ("save this", "remember that ...") are handled directly
        when *memory* is provided. Everything else is routed to the RAG, SQL,
        hybrid or general handler, and the routing tokens are added to the
        handler's token usage.

        Args:
            query: The user's question
            memory: Optional chat memory for context and permanent-memory commands
            ignored_tables: Tables to exclude from queries
            language: Preferred response language (default: English)

        Returns:
            A ChatResponse. Unexpected exceptions are logged and returned as a
            ``query_type="error"`` response instead of being raised.
        """
        if not self._schema_initialized:
            return ChatResponse(answer="Chatbot not initialized.", query_type="error",
                                error="Call initialize() first")
        if not self.llm_client:
            return ChatResponse(answer="LLM not configured.", query_type="error",
                                error="Configure LLM client first")
        try:
            # Use instance introspector
            schema = self.introspector.introspect()
            schema_context = schema.to_context_string(ignored_tables=ignored_tables)

            # Restrict RAG retrieval and the SQL validator to non-ignored tables.
            allowed_tables = None
            if ignored_tables:
                ignored = set(ignored_tables)
                allowed_tables = [t for t in schema.table_names if t not in ignored]
                self.sql_validator.set_allowed_tables(allowed_tables)
            else:
                self.sql_validator.set_allowed_tables(schema.table_names)

            if memory:
                memory_content = self._parse_memory_command(query)
                if memory_content is not None:
                    return self._handle_memory_command(memory_content, memory)

            # Get chat history for context
            history = memory.get_context_messages(5) if memory else []

            routing = self.router.route(query, schema_context, history)
            routing_usage = routing.token_usage or self._empty_usage()

            if routing.query_type == QueryType.RAG:
                response = self._handle_rag(query, history, allowed_tables, language)
            elif routing.query_type == QueryType.SQL:
                response = self._handle_sql(query, schema_context, history, allowed_tables, language)
            elif routing.query_type == QueryType.HYBRID:
                response = self._handle_hybrid(query, schema_context, history, allowed_tables, language)
            else:
                response = self._handle_general(query, history, language)

            # Add routing tokens to total
            if response.token_usage:
                self._accumulate_usage(response.token_usage, routing_usage)
            else:
                response.token_usage = routing_usage
            return response
        except Exception as e:
            logger.exception("Chat error: %s", e)
            return ChatResponse(answer=f"Error: {str(e)}", query_type="error", error=str(e))

    def _handle_rag(self, query: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Answer from semantic-search context (top 5 chunks) over the indexed text data."""
        if self.rag_engine.document_count == 0:
            # Routing tokens are added by chat(); this handler itself used none.
            return ChatResponse(
                answer="⚠️ **I can't answer this yet.**\n\nThis looks like a semantic question (searching for meaning/concepts), but you haven't **indexed the text data** yet.\n\nPlease click the **'📚 Index Text Data'** button in the sidebar to enable this functionality.",
                query_type="error",
                error="RAG index is empty",
                token_usage=self._empty_usage(),
            )
        context = self.rag_engine.get_context(query, top_k=5, table_filter=allowed_tables)
        prompt = self.RESPONSE_PROMPT.format(
            context=f"RELEVANT DATA:\n{context}",
            question=query,
            language_instruction=self._get_language_instruction(language),
        )
        messages = self._construct_messages("You are a helpful database assistant.", history, prompt)
        response = self.llm_client.chat(messages)
        return ChatResponse(answer=response.content, query_type="rag",
                            sources=[{"type": "semantic_search", "context": context[:500]}],
                            token_usage=self._usage_from(response))

    def _handle_sql(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Generate, validate and run SQL, then have the LLM explain the results.

        An empty result falls back to semantic search, but only when a RAG
        index exists. Otherwise the LLM is told that no rows were found.
        """
        sql, gen_response = self.sql_generator.generate(query, schema_context, history)
        total_usage = self._usage_from(gen_response)

        is_valid, msg, sanitized_sql = self.sql_validator.validate(sql)
        if not is_valid:
            return ChatResponse(answer=f"Could not generate safe query: {msg}",
                                query_type="sql", error=msg, token_usage=total_usage)

        try:
            results = self.db.execute_query(sanitized_sql)
        except Exception as e:
            return ChatResponse(answer=f"Query execution failed: {e}",
                                query_type="sql", sql_query=sanitized_sql, error=str(e),
                                token_usage=total_usage)

        # SMART FALLBACK: an empty result may be a semantic mismatch (e.g. wrong
        # column). Only try RAG when there is an index. With an empty index, the
        # RAG handler would just return its "please index" error.
        if not results and self.rag_engine.document_count > 0:
            logger.info("SQL returned no results for query: '%s'. Falling back to RAG.", query)
            rag_response = self._handle_rag(query, history, allowed_tables, language)
            rag_response.answer = f"I couldn't find a direct match using a database query, but here is what I found in the indexed text data:\n\n{rag_response.answer}"
            rag_response.query_type = "hybrid_fallback"
            rag_response.sql_query = sanitized_sql
            if rag_response.token_usage:
                self._accumulate_usage(rag_response.token_usage, total_usage)
            else:
                rag_response.token_usage = total_usage
            return rag_response

        context = f"SQL QUERY:\n{sanitized_sql}\n\nRESULTS:\n{self._format_results(results)}"
        prompt = self.RESPONSE_PROMPT.format(
            context=context,
            question=query,
            language_instruction=self._get_language_instruction(language),
        )
        messages = self._construct_messages("You are a helpful database assistant.", history, prompt)
        final_response = self.llm_client.chat(messages)
        self._accumulate_usage(total_usage, self._usage_from(final_response))
        return ChatResponse(answer=final_response.content, query_type="sql",
                            sql_query=sanitized_sql, sql_results=results[:10],
                            token_usage=total_usage)

    def _handle_hybrid(self, query: str, schema_context: str, history: List[Dict], allowed_tables: Optional[List[str]] = None, language: str = "English") -> ChatResponse:
        """Answer from semantic-search context (top 3) plus best-effort SQL results.

        SQL failures are logged and the answer is produced from the semantic
        context alone.
        """
        rag_context = self.rag_engine.get_context(query, top_k=3, table_filter=allowed_tables)
        sql_context = ""
        sql_query = None
        sql_results = None
        total_usage = self._empty_usage()
        try:
            sql, gen_response = self.sql_generator.generate(query, schema_context, history)
            self._accumulate_usage(total_usage, self._usage_from(gen_response))
            is_valid, _, sanitized_sql = self.sql_validator.validate(sql)
            if is_valid:
                results = self.db.execute_query(sanitized_sql)
                sql_context = f"\nSQL RESULTS:\n{self._format_results(results)}"
                sql_query = sanitized_sql
                # Expose rows the same way _handle_sql does.
                sql_results = results[:10]
        except Exception as e:
            logger.warning("SQL part of hybrid failed: %s", e)

        context = f"SEMANTIC SEARCH RESULTS:\n{rag_context}{sql_context}"
        prompt = self.RESPONSE_PROMPT.format(
            context=context,
            question=query,
            language_instruction=self._get_language_instruction(language),
        )
        messages = self._construct_messages("You are a helpful database assistant.", history, prompt)
        final_response = self.llm_client.chat(messages)
        self._accumulate_usage(total_usage, self._usage_from(final_response))
        return ChatResponse(answer=final_response.content, query_type="hybrid",
                            sql_query=sql_query, sql_results=sql_results,
                            token_usage=total_usage)

    def _construct_messages(self, system_instruction: str, history: List[Dict], user_content: str) -> List[Dict]:
        """Build the LLM message list.

        System messages found in *history* (e.g. injected by memory) are folded
        into the single leading system prompt. The remaining history follows in
        order, then *user_content* as the final user message.
        """
        additional_context = ""
        filtered_history = []
        for msg in history:
            if msg.get("role") == "system":
                additional_context += f"\n\n{msg.get('content')}"
            else:
                filtered_history.append(msg)
        messages = [{"role": "system", "content": f"{system_instruction}{additional_context}"}]
        messages.extend(filtered_history)
        messages.append({"role": "user", "content": user_content})
        return messages

    def _handle_general(self, query: str, history: List[Dict], language: str = "English") -> ChatResponse:
        """Handle conversation that needs neither SQL nor semantic search."""
        language_suffix = ""
        if language != "English":
            language_suffix = f"\n- Respond entirely in {self._base_language(language)}."
        # Use a strict prompt for general conversation as well to prevent hallucinations
        strict_system_prompt = (
            "You are a helpful database assistant.\n"
            "INSTRUCTIONS:\n"
            "- Answer ONLY based on the conversation history and any context provided within it.\n"
            "- Do NOT use outside knowledge, general assumptions, or hallucinate facts.\n"
            "- If the answer is not in the history or context, state that you don't have that information.\n"
            f"- Be concise.{language_suffix}"
        )
        messages = self._construct_messages(strict_system_prompt, history, query)
        response = self.llm_client.chat(messages)
        return ChatResponse(answer=response.content, query_type="general",
                            token_usage=self._usage_from(response))

    def _format_results(self, results: List[Dict], max_rows: int = 10) -> str:
        """Render up to *max_rows* result rows as a pipe-separated text table.

        Headers come from the first row's keys, each cell is truncated to 50
        characters, and a trailer notes how many rows were omitted.
        """
        if not results:
            return "No results found."
        rows = results[:max_rows]
        header = " | ".join(list(rows[0].keys()))
        lines = [header, "-" * len(header)]
        for row in rows:
            lines.append(" | ".join(str(v)[:50] for v in row.values()))
        if len(results) > max_rows:
            lines.append(f"... and {len(results) - max_rows} more rows")
        return "\n".join(lines)

    def get_schema_summary(self) -> str:
        """Get a summary of the database schema."""
        if not self._schema_initialized:
            return "Schema not loaded."
        return self.introspector.introspect().to_context_string()
def create_chatbot(llm_client: Optional[LLMClient] = None) -> DatabaseChatbot:
    """Factory for ``DatabaseChatbot``.

    Args:
        llm_client: Optional LLM client forwarded to the constructor; when
            ``None``, configure one later with ``set_llm_client``.

    Returns:
        A new ``DatabaseChatbot`` instance.
    """
    chatbot = DatabaseChatbot(llm_client=llm_client)
    return chatbot