# DB_Chatbot/database/schema_introspector.py
"""
Dynamic Schema Introspection Module - Multi-Database Support.
This module is the CORE of the schema-agnostic design.
It dynamically discovers:
- All tables in the database
- All columns with their data types
- Primary keys and foreign keys
- Text-like columns for RAG indexing
- Relationships between tables
Supports MySQL, PostgreSQL, and SQLite.
NEVER hardcodes any table or column names.
"""
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from sqlalchemy import text, inspect
from sqlalchemy.engine import Engine
from .connection import get_db
logger = logging.getLogger(__name__)
@dataclass
class ColumnInfo:
    """Information about a single database column.

    Populated by SchemaIntrospector from the database's system catalogs
    (information_schema on MySQL/PostgreSQL, PRAGMA table_info on SQLite).
    """
    name: str
    data_type: str
    is_nullable: bool
    is_primary_key: bool
    max_length: Optional[int] = None
    default_value: Optional[str] = None
    comment: Optional[str] = None

    # Class-level, deduplicated type-name sets (the original lists repeated
    # 'text' and 'json'). frozenset gives O(1) membership tests and is built
    # once instead of on every property access. Unannotated assignments are
    # NOT dataclass fields, so these stay plain class attributes.
    _TEXT_TYPES = frozenset({
        # MySQL
        'text', 'mediumtext', 'longtext', 'tinytext', 'varchar', 'char', 'json',
        # PostgreSQL
        'character varying', 'character', 'jsonb',
    })
    _NUMERIC_TYPES = frozenset({
        # Common across databases
        'int', 'integer', 'bigint', 'smallint', 'tinyint',
        'decimal', 'numeric', 'float', 'double', 'real',
        # PostgreSQL specific
        'double precision', 'serial', 'bigserial', 'smallserial',
    })

    @staticmethod
    def _base_type(data_type: str) -> str:
        """Strip a length/precision suffix, e.g. 'VARCHAR(255)' -> 'varchar'."""
        return data_type.lower().split('(')[0].strip()

    @property
    def is_text_type(self) -> bool:
        """Check if this column contains text data suitable for RAG."""
        return self._base_type(self.data_type) in self._TEXT_TYPES

    @property
    def is_numeric(self) -> bool:
        """Check if this column contains numeric data."""
        return self._base_type(self.data_type) in self._NUMERIC_TYPES
@dataclass
class TableInfo:
    """Complete information about a single database table."""
    name: str
    columns: List[ColumnInfo] = field(default_factory=list)
    primary_keys: List[str] = field(default_factory=list)
    # Maps a local column name to "referenced_table.referenced_column".
    foreign_keys: Dict[str, str] = field(default_factory=dict)
    row_count: Optional[int] = None
    comment: Optional[str] = None

    @property
    def text_columns(self) -> List[ColumnInfo]:
        """Columns whose data type is text-like (candidates for RAG indexing)."""
        return [column for column in self.columns if column.is_text_type]

    @property
    def column_names(self) -> List[str]:
        """Names of every column, in catalog order."""
        return [column.name for column in self.columns]

    def get_column(self, name: str) -> Optional[ColumnInfo]:
        """Look up a column by name (case-insensitive); None when absent."""
        target = name.lower()
        for column in self.columns:
            if column.name.lower() == target:
                return column
        return None
@dataclass
class SchemaInfo:
    """Complete database schema information, keyed by table name."""
    database_name: str
    tables: Dict[str, TableInfo] = field(default_factory=dict)

    @property
    def table_names(self) -> List[str]:
        """Names of every discovered table."""
        return list(self.tables)

    @property
    def all_text_columns(self) -> List[tuple]:
        """All text columns across all tables as (table, column) tuples."""
        return [
            (table_name, column.name)
            for table_name, table_info in self.tables.items()
            for column in table_info.text_columns
        ]

    def to_context_string(self, ignored_tables: Optional[List[str]] = None) -> str:
        """
        Generate a natural language description of the schema.
        This is used as context for the LLM.
        """
        lines = [
            f"Database: {self.database_name}",
            "",
            "Available Tables:",
            "-" * 40,
        ]
        for table_name, table_info in self.tables.items():
            # Caller may blacklist tables it does not want described.
            if ignored_tables and table_name in ignored_tables:
                continue
            lines.append(f"\nTable: {table_name}")
            if table_info.comment:
                lines.append(f"  Description: {table_info.comment}")
            if table_info.row_count is not None:
                lines.append(f"  Approximate rows: {table_info.row_count}")
            lines.append("  Columns:")
            for column in table_info.columns:
                markers = " [PRIMARY KEY]" if column.is_primary_key else ""
                markers += " (nullable)" if column.is_nullable else " (required)"
                lines.append(f"    - {column.name}: {column.data_type}{markers}")
                if column.comment:
                    lines.append(f"      Comment: {column.comment}")
            if table_info.foreign_keys:
                lines.append("  Foreign Keys:")
                for source, target in table_info.foreign_keys.items():
                    lines.append(f"    - {source} -> {target}")
        return "\n".join(lines)

    def to_sql_ddl(self) -> str:
        """
        Generate SQL-like DDL representation of the schema.
        Useful for SQL generation context.
        """
        ddl_lines = []
        for table_name, table_info in self.tables.items():
            column_defs = []
            for column in table_info.columns:
                parts = [f"  {column.name} {column.data_type}"]
                if column.is_primary_key:
                    parts.append("PRIMARY KEY")
                if not column.is_nullable:
                    parts.append("NOT NULL")
                column_defs.append(" ".join(parts))
            ddl_lines.append(f"CREATE TABLE {table_name} (")
            ddl_lines.append(",\n".join(column_defs))
            ddl_lines.append(");\n")
        return "\n".join(ddl_lines)
class SchemaIntrospector:
    """
    Dynamically introspects database schema.
    This is the key component that enables schema-agnostic operation.
    It queries database system catalogs to discover the complete schema.
    Supports MySQL, PostgreSQL, and SQLite.
    """
    # System tables to exclude from introspection
    SYSTEM_TABLES = {
        '_chatbot_memory',  # Our own chat history table
        '_chatbot_permanent_memory_v2',
        '_chatbot_user_summaries',
        'schema_migrations',
        'flyway_schema_history',
        # Vector store internal tables
        'chunks',
        'embeddings',
        'vectors'
    }

    def __init__(self, engine: Optional[Engine] = None):
        """
        Initialize the introspector.

        Args:
            engine: SQLAlchemy engine. Uses global connection if not provided.
        """
        # Fix: the original accepted `engine` but silently discarded it; it is
        # now retained on the instance. Queries still go through the shared
        # connection wrapper because it supplies execute_query()/db_type.
        # TODO(review): route queries through `engine` when one is supplied.
        self.engine = engine
        self.db = get_db()
        self._cached_schema: Optional[SchemaInfo] = None

    @staticmethod
    def _quote_sqlite_identifier(name: str) -> str:
        """
        Quote an identifier for direct interpolation into SQLite SQL.

        PRAGMA statements cannot take bound parameters, so the table name must
        be spliced into the statement text. Double-quoting (with embedded
        quotes doubled) keeps names containing spaces, reserved words, or
        quote characters from breaking the statement. Names originate from
        sqlite_master, so this is defensive rather than untrusted-input
        sanitization.
        """
        return '"' + name.replace('"', '""') + '"'

    def introspect(self, force_refresh: bool = False) -> SchemaInfo:
        """
        Perform complete schema introspection.

        Args:
            force_refresh: If True, bypass cache and re-introspect

        Returns:
            SchemaInfo object with complete schema details
        """
        if self._cached_schema is not None and not force_refresh:
            return self._cached_schema
        logger.info("Starting schema introspection...")
        # Get database name
        db_name = self._get_database_name()
        # Get all user tables
        tables = self._get_tables()
        schema = SchemaInfo(database_name=db_name)
        for table_name in tables:
            # Skip our own bookkeeping / vector-store tables.
            if table_name in self.SYSTEM_TABLES:
                continue
            # Also skip tables that start with underscore (internal tables)
            if table_name.startswith('_chatbot'):
                continue
            table_info = self._introspect_table(table_name)
            if table_info:
                schema.tables[table_name] = table_info
        self._cached_schema = schema
        logger.info(f"Schema introspection complete. Found {len(schema.tables)} tables.")
        return schema

    def _get_database_name(self) -> str:
        """Get the current database name (or "unknown" on failure)."""
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                result = self.db.execute_query("SELECT current_database() as db_name")
                return result[0]['db_name'] if result else "unknown"
            elif db_type.value == "sqlite":
                # SQLite has no server-side database name; use a fixed label.
                return "sqlite_main"
            else:  # MySQL
                result = self.db.execute_query("SELECT DATABASE() as db_name")
                return result[0]['db_name'] if result else "unknown"
        except Exception as e:
            logger.error(f"Error getting database name: {e}")
            return "unknown"

    def _get_tables(self) -> List[str]:
        """
        Get all user tables from the database.
        Uses database-specific queries for comprehensive discovery.
        Returns an empty list on error.
        """
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                query = """
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    AND table_type = 'BASE TABLE'
                    ORDER BY table_name
                """
                result = self.db.execute_query(query)
                return [row['table_name'] for row in result]
            elif db_type.value == "sqlite":
                # sqlite_master lists all objects; filter to user tables.
                query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
                result = self.db.execute_query(query)
                return [row['name'] for row in result]
            else:  # MySQL
                query = """
                    SELECT TABLE_NAME
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_TYPE = 'BASE TABLE'
                    ORDER BY TABLE_NAME
                """
                result = self.db.execute_query(query)
                return [row['TABLE_NAME'] for row in result]
        except Exception as e:
            logger.error(f"Error getting tables: {e}")
            return []

    def _introspect_table(self, table_name: str) -> Optional[TableInfo]:
        """
        Get complete information about a specific table.

        Args:
            table_name: Name of the table to introspect

        Returns:
            TableInfo object or None if introspection fails
        """
        try:
            columns = self._get_columns(table_name)
            primary_keys = self._get_primary_keys(table_name)
            foreign_keys = self._get_foreign_keys(table_name)
            # Approximate row count (fast estimation on MySQL/PostgreSQL).
            row_count = self._get_row_count(table_name)
            # Table comment (not available in SQLite).
            comment = self._get_table_comment(table_name)
            # Mark primary key columns (column queries other than SQLite's
            # PRAGMA cannot determine this on their own).
            for col in columns:
                col.is_primary_key = col.name in primary_keys
            return TableInfo(
                name=table_name,
                columns=columns,
                primary_keys=primary_keys,
                foreign_keys=foreign_keys,
                row_count=row_count,
                comment=comment
            )
        except Exception as e:
            logger.error(f"Error introspecting table {table_name}: {e}")
            return None

    def _get_columns(self, table_name: str) -> List[ColumnInfo]:
        """Get all columns for a table; empty list on error."""
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                query = """
                    SELECT
                        column_name,
                        data_type,
                        is_nullable,
                        column_default,
                        character_maximum_length,
                        col_description(
                            (SELECT oid FROM pg_class WHERE relname = :table_name),
                            ordinal_position
                        ) as column_comment
                    FROM information_schema.columns
                    WHERE table_schema = 'public'
                    AND table_name = :table_name
                    ORDER BY ordinal_position
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                columns = []
                for row in result:
                    columns.append(ColumnInfo(
                        name=row['column_name'],
                        data_type=row['data_type'],
                        is_nullable=row['is_nullable'] == 'YES',
                        is_primary_key=False,  # Will be set later
                        max_length=row['character_maximum_length'],
                        default_value=row['column_default'],
                        comment=row.get('column_comment')
                    ))
                return columns
            elif db_type.value == "sqlite":
                # PRAGMA cannot take bound parameters, so the (quoted)
                # identifier must be interpolated.
                query = f"PRAGMA table_info({self._quote_sqlite_identifier(table_name)})"
                result = self.db.execute_query(query)
                columns = []
                for row in result:
                    # SQLite PRAGMA result: cid, name, type, notnull, dflt_value, pk
                    columns.append(ColumnInfo(
                        name=row['name'],
                        data_type=row['type'],
                        is_nullable=row['notnull'] == 0,
                        is_primary_key=row['pk'] > 0,
                        max_length=None,  # Extract from type string if needed
                        default_value=row['dflt_value'],
                        comment=None
                    ))
                return columns
            else:  # MySQL
                query = """
                    SELECT
                        COLUMN_NAME,
                        COLUMN_TYPE,
                        IS_NULLABLE,
                        COLUMN_DEFAULT,
                        CHARACTER_MAXIMUM_LENGTH,
                        COLUMN_COMMENT
                    FROM INFORMATION_SCHEMA.COLUMNS
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    ORDER BY ORDINAL_POSITION
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                columns = []
                for row in result:
                    columns.append(ColumnInfo(
                        name=row['COLUMN_NAME'],
                        data_type=row['COLUMN_TYPE'],
                        is_nullable=row['IS_NULLABLE'] == 'YES',
                        is_primary_key=False,  # Will be set later
                        max_length=row['CHARACTER_MAXIMUM_LENGTH'],
                        default_value=row['COLUMN_DEFAULT'],
                        comment=row['COLUMN_COMMENT'] if row['COLUMN_COMMENT'] else None
                    ))
                return columns
        except Exception as e:
            logger.error(f"Error getting columns for {table_name}: {e}")
            return []

    def _get_primary_keys(self, table_name: str) -> List[str]:
        """Get primary key column names for a table; empty list on error."""
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                query = """
                    SELECT a.attname as column_name
                    FROM pg_index i
                    JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
                    WHERE i.indrelid = CAST(:table_name AS regclass)
                    AND i.indisprimary
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return [row['column_name'] for row in result]
            elif db_type.value == "sqlite":
                query = f"PRAGMA table_info({self._quote_sqlite_identifier(table_name)})"
                result = self.db.execute_query(query)
                # pk > 0 marks primary key columns (value is 1-based position).
                return [row['name'] for row in result if row['pk'] > 0]
            else:  # MySQL
                query = """
                    SELECT COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    AND CONSTRAINT_NAME = 'PRIMARY'
                    ORDER BY ORDINAL_POSITION
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return [row['COLUMN_NAME'] for row in result]
        except Exception as e:
            logger.error(f"Error getting primary keys for {table_name}: {e}")
            return []

    def _get_foreign_keys(self, table_name: str) -> Dict[str, str]:
        """
        Get foreign key relationships for a table.

        Returns:
            Mapping of local column -> "referenced_table.referenced_column";
            empty dict on error.
        """
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                query = """
                    SELECT
                        kcu.column_name,
                        ccu.table_name AS foreign_table_name,
                        ccu.column_name AS foreign_column_name
                    FROM information_schema.table_constraints AS tc
                    JOIN information_schema.key_column_usage AS kcu
                        ON tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                    JOIN information_schema.constraint_column_usage AS ccu
                        ON ccu.constraint_name = tc.constraint_name
                        AND ccu.table_schema = tc.table_schema
                    WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return {
                    row['column_name']: f"{row['foreign_table_name']}.{row['foreign_column_name']}"
                    for row in result
                }
            elif db_type.value == "sqlite":
                query = f"PRAGMA foreign_key_list({self._quote_sqlite_identifier(table_name)})"
                result = self.db.execute_query(query)
                # SQLite PRAGMA result: id, seq, table, from, to, on_update, on_delete, match
                return {
                    row['from']: f"{row['table']}.{row['to']}"
                    for row in result
                }
            else:  # MySQL
                query = """
                    SELECT
                        COLUMN_NAME,
                        REFERENCED_TABLE_NAME,
                        REFERENCED_COLUMN_NAME
                    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                    AND REFERENCED_TABLE_NAME IS NOT NULL
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return {
                    row['COLUMN_NAME']: f"{row['REFERENCED_TABLE_NAME']}.{row['REFERENCED_COLUMN_NAME']}"
                    for row in result
                }
        except Exception as e:
            logger.error(f"Error getting foreign keys for {table_name}: {e}")
            return {}

    def _get_row_count(self, table_name: str) -> Optional[int]:
        """
        Get approximate row count for a table.
        Uses different strategies per database: statistics tables on
        PostgreSQL/MySQL (fast, approximate), an exact COUNT(*) on SQLite.
        Returns None on error or when no estimate is available.
        """
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                # Use pg_stat_user_tables for fast estimation
                query = """
                    SELECT n_live_tup as row_count
                    FROM pg_stat_user_tables
                    WHERE relname = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return result[0]['row_count'] if result else None
            elif db_type.value == "sqlite":
                query = f"SELECT COUNT(*) as row_count FROM {self._quote_sqlite_identifier(table_name)}"
                result = self.db.execute_query(query)
                return result[0]['row_count'] if result else 0
            else:  # MySQL
                # TABLE_ROWS is an estimate for InnoDB tables.
                query = """
                    SELECT TABLE_ROWS
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                return result[0]['TABLE_ROWS'] if result else None
        except Exception as e:
            logger.error(f"Error getting row count for {table_name}: {e}")
            return None

    def _get_table_comment(self, table_name: str) -> Optional[str]:
        """Get table comment/description; None when unavailable or on error."""
        db_type = self.db.db_type
        try:
            if db_type.value == "postgresql":
                query = """
                    SELECT obj_description(CAST(:table_name AS regclass), 'pg_class') as table_comment
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                comment = result[0]['table_comment'] if result else None
                # Normalize empty strings to None.
                return comment if comment else None
            elif db_type.value == "sqlite":
                # SQLite doesn't conveniently support table comments
                return None
            else:  # MySQL
                query = """
                    SELECT TABLE_COMMENT
                    FROM INFORMATION_SCHEMA.TABLES
                    WHERE TABLE_SCHEMA = DATABASE()
                    AND TABLE_NAME = :table_name
                """
                result = self.db.execute_query(query, {"table_name": table_name})
                comment = result[0]['TABLE_COMMENT'] if result else None
                return comment if comment else None
        except Exception as e:
            logger.error(f"Error getting table comment for {table_name}: {e}")
            return None

    def get_text_columns_for_rag(self, min_length: int = 50) -> List[Dict[str, Any]]:
        """
        Get all text columns suitable for RAG indexing.

        Args:
            min_length: Minimum max_length for varchar columns to be considered

        Returns:
            List of dicts with table name, column name, and metadata
        """
        schema = self.introspect()
        text_columns = []
        for table_name, table_info in schema.tables.items():
            for col in table_info.columns:
                if col.is_text_type:
                    # Skip very short varchar columns (unlikely to hold prose).
                    if col.max_length and col.max_length < min_length:
                        continue
                    text_columns.append({
                        "table": table_name,
                        "column": col.name,
                        "data_type": col.data_type,
                        "primary_keys": table_info.primary_keys,
                        "max_length": col.max_length
                    })
        return text_columns

    def refresh_cache(self) -> SchemaInfo:
        """Force refresh the cached schema."""
        return self.introspect(force_refresh=True)
# Process-wide introspector, created lazily by get_introspector().
_introspector: Optional[SchemaIntrospector] = None


def get_introspector() -> SchemaIntrospector:
    """Get or create the global schema introspector."""
    global _introspector
    if _introspector is not None:
        return _introspector
    _introspector = SchemaIntrospector()
    return _introspector
def get_schema() -> SchemaInfo:
    """Convenience function to get the current schema."""
    introspector = get_introspector()
    return introspector.introspect()