# DB_Chatbot / config.py
# Vanshcc's picture
# Upload 7 files
# b404e8f verified
"""
Configuration module for the Schema-Agnostic Database Chatbot.
This module handles all configuration including:
- Database connection settings (MySQL, PostgreSQL, SQLite)
- LLM provider settings (Groq / OpenAI / Local LLaMA)
- Embedding model configuration
- Security settings
"""
import os
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional, List
from urllib.parse import quote
# Load the .env file that sits next to this module BEFORE any os.getenv call,
# so the environment lookups made by the configuration classes below see it.
from dotenv import load_dotenv
env_path = Path(__file__).parent / ".env"
load_dotenv(env_path)
class DatabaseType(Enum):
    """Database backends that can be selected.

    Each member's value is the lowercase backend name; looking a member up
    by value (``DatabaseType("sqlite")``) is how a configured string is
    turned into a member.
    """

    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    SQLITE = "sqlite"
class LLMProvider(Enum):
    """LLM backends that can be selected.

    Each member's value is the lowercase provider name used when looking a
    member up by value, e.g. ``LLMProvider("openai")``.
    """

    GROQ = "groq"  # flagged as free by the original author
    OPENAI = "openai"
    LOCAL_LLAMA = "local_llama"
class EmbeddingProvider(Enum):
    """Embedding backends that can be selected.

    Each member's value is the lowercase provider name used when looking a
    member up by value, e.g. ``EmbeddingProvider("openai")``.
    """

    OPENAI = "openai"
    SENTENCE_TRANSFORMERS = "sentence_transformers"
@dataclass
class DatabaseConfig:
    """Connection settings for MySQL, PostgreSQL, or SQLite.

    Field defaults are read from environment variables each time an instance
    is created. The generic ``DB_*`` variable is consulted first and the
    legacy ``MYSQL_*`` variable is the fallback.

    Attributes:
        db_type: Backend selected by ``DB_TYPE`` (case-insensitive; unset or
            blank means MySQL).
        host: Server host name. Not used by this class for SQLite.
        port: Server port. When neither ``DB_PORT`` nor ``MYSQL_PORT`` is
            set (or ``0`` is passed explicitly), the backend's standard port
            is filled in: 3306 for MySQL, 5432 for PostgreSQL.
        database: Database name, or the database file path for SQLite.
        username: Login name. Not used by this class for SQLite.
        password: Login password. Not used by this class for SQLite.
        ssl_ca: Optional path to a CA certificate; when set, TLS parameters
            are appended to the connection string.

    Raises:
        ValueError: On construction, if ``DB_TYPE`` is not a supported
            backend name or the port variable is not an integer.
    """

    # Database type (mysql, postgresql, sqlite). Blank values fall back to
    # MySQL instead of failing the enum lookup on an empty string.
    db_type: DatabaseType = field(
        default_factory=lambda: DatabaseType(
            os.getenv("DB_TYPE", "").strip().lower() or "mysql"
        )
    )

    # Common connection settings (for MySQL/PostgreSQL)
    host: str = field(
        default_factory=lambda: os.getenv("DB_HOST", os.getenv("MYSQL_HOST", ""))
    )
    # 0 means "not specified"; __post_init__ swaps in the backend default.
    # Blank variables are skipped so an empty DB_PORT= line does not crash int().
    port: int = field(
        default_factory=lambda: int(
            os.getenv("DB_PORT") or os.getenv("MYSQL_PORT") or 0
        )
    )
    database: str = field(
        default_factory=lambda: os.getenv("DB_DATABASE", os.getenv("MYSQL_DATABASE", ""))
    )
    username: str = field(
        default_factory=lambda: os.getenv("DB_USERNAME", os.getenv("MYSQL_USERNAME", ""))
    )
    password: str = field(
        default_factory=lambda: os.getenv("DB_PASSWORD", os.getenv("MYSQL_PASSWORD", ""))
    )

    # SSL configuration
    ssl_ca: Optional[str] = field(
        default_factory=lambda: os.getenv("DB_SSL_CA", os.getenv("MYSQL_SSL_CA", None))
    )

    # Standard server ports per backend. Left unannotated on purpose so the
    # dataclass machinery does not treat it as an instance field.
    _DEFAULT_PORTS = {
        DatabaseType.MYSQL: 3306,
        DatabaseType.POSTGRESQL: 5432,
    }

    def __post_init__(self) -> None:
        """Fill in the backend's standard port when none was specified."""
        if not self.port:
            # 3306 stays the fallback for backends without an entry (SQLite),
            # matching the historical default.
            self.port = self._DEFAULT_PORTS.get(self.db_type, 3306)

    @property
    def connection_string(self) -> str:
        """Build the SQLAlchemy connection URL for the selected backend.

        The username and password are percent-encoded so characters such as
        ``@``, ``:`` or ``/`` cannot be mistaken for URL delimiters. The CA
        path is percent-encoded as well, leaving ``/``, ``\\`` and ``:``
        untouched so ordinary paths are emitted unchanged.

        Returns:
            A ``sqlite:///``, ``postgresql+psycopg2://`` or
            ``mysql+pymysql://`` URL.
        """
        if self.db_type == DatabaseType.SQLITE:
            # SQLite connection string (e.g. sqlite:///database.db).
            # A bare name resolves relative to the current directory; a value
            # starting with / or \ is an absolute path.
            return f"sqlite:///{self.database}"

        # safe="" also encodes "/" and "+", so the result decodes correctly
        # whether the consumer uses unquote() or unquote_plus().
        user = quote(self.username, safe="")
        secret = quote(self.password, safe="")
        target = f"{user}:{secret}@{self.host}:{self.port}/{self.database}"
        ca_path = quote(self.ssl_ca, safe="/\\:") if self.ssl_ca else ""

        if self.db_type == DatabaseType.POSTGRESQL:
            base_url = f"postgresql+psycopg2://{target}"
            if ca_path:
                return f"{base_url}?sslmode=verify-full&sslrootcert={ca_path}"
            return base_url

        # MySQL (default)
        base_url = f"mysql+pymysql://{target}"
        if ca_path:
            return f"{base_url}?ssl_ca={ca_path}"
        return base_url

    def is_configured(self) -> bool:
        """Return True when every setting the backend requires is non-empty.

        SQLite only needs ``database``; MySQL and PostgreSQL need ``host``,
        ``database``, ``username`` and ``password``.
        """
        if self.db_type == DatabaseType.SQLITE:
            return bool(self.database)
        return all([self.host, self.database, self.username, self.password])

    @property
    def is_mysql(self) -> bool:
        """Check if using MySQL."""
        return self.db_type == DatabaseType.MYSQL

    @property
    def is_postgresql(self) -> bool:
        """Check if using PostgreSQL."""
        return self.db_type == DatabaseType.POSTGRESQL

    @property
    def is_sqlite(self) -> bool:
        """Check if using SQLite."""
        return self.db_type == DatabaseType.SQLITE
@dataclass
class LLMConfig:
    """LLM configuration for query routing and response generation.

    Field defaults are read from environment variables each time an instance
    is created.

    Attributes:
        provider: Backend selected by ``LLM_PROVIDER`` (case-insensitive;
            unset or blank means OpenAI).
        openai_api_key: API key from ``OPENAI_API_KEY``.
        openai_model: Model name from ``OPENAI_MODEL``.
        local_model_path: Model location from ``LOCAL_MODEL_PATH``.
        local_model_name: Model name from ``LOCAL_MODEL_NAME``.
        temperature: Sampling temperature for generation.
        max_tokens: Generation token limit.
        groq_api_key: API key from ``GROQ_API_KEY``.
        groq_model: Model name from ``GROQ_MODEL``.

    Raises:
        ValueError: On construction, if ``LLM_PROVIDER`` is not a supported
            provider name.
    """

    # Normalised the same way as DB_TYPE so "OpenAI" or " groq " are accepted.
    provider: LLMProvider = field(
        default_factory=lambda: LLMProvider(
            os.getenv("LLM_PROVIDER", "").strip().lower() or "openai"
        )
    )

    # OpenAI settings
    openai_api_key: str = field(default_factory=lambda: os.getenv("OPENAI_API_KEY", ""))
    openai_model: str = field(default_factory=lambda: os.getenv("OPENAI_MODEL", "gpt-4o-mini"))

    # Local LLaMA settings
    local_model_path: str = field(
        default_factory=lambda: os.getenv("LOCAL_MODEL_PATH", "")
    )
    local_model_name: str = field(
        default_factory=lambda: os.getenv("LOCAL_MODEL_NAME", "llama-2-7b-chat")
    )

    # Generation parameters
    temperature: float = 0.1  # Low temperature for more deterministic outputs
    max_tokens: int = 1024

    # Groq settings. Declared after the pre-existing fields so positional
    # construction of this dataclass keeps its original argument order.
    groq_api_key: str = field(default_factory=lambda: os.getenv("GROQ_API_KEY", ""))
    groq_model: str = field(
        default_factory=lambda: os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
    )

    def is_configured(self) -> bool:
        """Return True when the selected provider has its required setting.

        OpenAI and Groq each need their own API key; the local LLaMA
        provider needs ``local_model_path``.
        """
        if self.provider == LLMProvider.OPENAI:
            return bool(self.openai_api_key)
        if self.provider == LLMProvider.GROQ:
            # Previously fell through to the local-model check, so a Groq
            # setup could only validate if LOCAL_MODEL_PATH happened to be set.
            return bool(self.groq_api_key)
        return bool(self.local_model_path)
@dataclass
class EmbeddingConfig:
    """Embedding model configuration for RAG.

    Field defaults are read from environment variables each time an instance
    is created.

    Attributes:
        provider: Backend selected by ``EMBEDDING_PROVIDER``
            (case-insensitive; unset or blank means sentence_transformers).
        openai_embedding_model: Model name for the OpenAI provider.
        st_model_name: Sentence Transformers model from ``EMBEDDING_MODEL``.
        embedding_dim: Vector size from ``EMBEDDING_DIM``; defaults to 384,
            the size noted for all-MiniLM-L6-v2.

    Raises:
        ValueError: On construction, if ``EMBEDDING_PROVIDER`` is not a
            supported provider name or ``EMBEDDING_DIM`` is not an integer.
    """

    # Normalised the same way as DB_TYPE so mixed-case values are accepted.
    provider: EmbeddingProvider = field(
        default_factory=lambda: EmbeddingProvider(
            os.getenv("EMBEDDING_PROVIDER", "").strip().lower()
            or "sentence_transformers"
        )
    )

    # OpenAI embedding settings
    openai_embedding_model: str = "text-embedding-3-small"

    # Sentence Transformers settings
    st_model_name: str = field(
        default_factory=lambda: os.getenv(
            "EMBEDDING_MODEL",
            "sentence-transformers/all-MiniLM-L6-v2"
        )
    )

    # Embedding dimensions (varies by model). Overridable because the model
    # itself is overridable; set EMBEDDING_DIM whenever EMBEDDING_MODEL or the
    # provider is changed to one with a different vector size.
    embedding_dim: int = field(
        default_factory=lambda: int(os.getenv("EMBEDDING_DIM") or 384)
    )
@dataclass
class SecurityConfig:
    """Values governing SQL validation and execution.

    Attributes:
        allowed_operations: Whitelisted statement types (SELECT only).
        forbidden_keywords: Keywords that should never appear in a query.
        max_result_rows: Upper bound on the number of rows returned.
        default_limit: LIMIT value used when a query specifies none.
    """

    # Whitelist of SQL operations: nothing but SELECT.
    allowed_operations: List[str] = field(default_factory=lambda: ["SELECT"])

    # Dangerous keywords; a fresh list is built for every instance.
    forbidden_keywords: List[str] = field(
        default_factory=lambda: [
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "CREATE",
            "ALTER",
            "TRUNCATE",
            "GRANT",
            "REVOKE",
            "EXECUTE",
            "EXEC",
            "INTO OUTFILE",
            "INTO DUMPFILE",
            "LOAD_FILE",
            "INFORMATION_SCHEMA.USER_PRIVILEGES",
        ]
    )

    # Row cap for results.
    max_result_rows: int = 100

    # Fallback LIMIT clause value.
    default_limit: int = 50
@dataclass
class RAGConfig:
    """Settings for RAG (Retrieval-Augmented Generation).

    Attributes:
        faiss_index_path: Location of the FAISS index.
        top_k: How many top results to retrieve.
        similarity_threshold: Minimum similarity score counted as relevant.
        text_column_types: Column type names considered for RAG.
        min_text_length: Minimum character length for a column to qualify.
        chunk_size: Chunk size for long text documents.
        chunk_overlap: Overlap between chunks.
    """

    # FAISS index location.
    faiss_index_path: str = "./faiss_index"

    # Retrieval depth.
    top_k: int = 5

    # Relevance cut-off.
    similarity_threshold: float = 0.3

    # Textual column types shared across database backends; a fresh list is
    # built for every instance.
    text_column_types: List[str] = field(
        default_factory=lambda: [
            # MySQL types
            "TEXT",
            "MEDIUMTEXT",
            "LONGTEXT",
            "TINYTEXT",
            "VARCHAR",
            "CHAR",
            # PostgreSQL types
            "CHARACTER VARYING",
            "CHARACTER",
        ]
    )

    # Columns below this character length are not considered for RAG.
    min_text_length: int = 50

    # Chunking of long text documents.
    chunk_size: int = 500
    chunk_overlap: int = 50
@dataclass
class ChatConfig:
    """Settings for chat history and memory.

    Attributes:
        max_session_messages: Size of the short-term (in-session) memory.
        memory_table_name: Name of the long-term memory table, which is
            created if it does not exist.
        context_messages: Count of recent messages included in the context.
    """

    # Short-term, per-session memory.
    max_session_messages: int = 20

    # Long-term memory table.
    memory_table_name: str = "_chatbot_memory"

    # Recent messages carried into the context.
    context_messages: int = 5
class AppConfig:
    """Top-level container bundling every configuration section.

    Instantiating it builds one object per section (database, llm,
    embedding, security, rag, chat); ``validate`` reports which of the
    database and LLM sections are incomplete.
    """

    def __init__(self):
        # One freshly constructed object per configuration section.
        self.database = DatabaseConfig()
        self.llm = LLMConfig()
        self.embedding = EmbeddingConfig()
        self.security = SecurityConfig()
        self.rag = RAGConfig()
        self.chat = ChatConfig()

    def validate(self) -> tuple[bool, List[str]]:
        """Check the database and LLM sections for completeness.

        Returns:
            tuple: ``(is_valid, messages)`` where ``messages`` holds one
            entry per incomplete section and ``is_valid`` is True only when
            that list is empty.
        """
        problems = []

        database_ready = self.database.is_configured()
        if not database_ready:
            backend_label = self.database.db_type.value.upper()
            problems.append(
                f"{backend_label} configuration incomplete. Check DB_* environment variables."
            )

        llm_ready = self.llm.is_configured()
        if not llm_ready:
            provider_name = self.llm.provider.value
            problems.append(
                f"LLM configuration incomplete for provider: {provider_name}. "
                "Check API keys or model paths."
            )

        return (not problems, problems)

    @classmethod
    def from_env(cls) -> "AppConfig":
        """Alternate constructor: build the configuration from the environment."""
        return cls()
# Global configuration instance, built once when this module is imported.
# Environment variables are read at this point, so they must be set (or be
# present in the .env file) before the first import of this module.
config: AppConfig = AppConfig.from_env()