# hytch / db.py
# LeonceNsh's picture
# Update db.py
# 6a73e96 verified
"""
Database connector with support for ODBC and TDS drivers.
Secure handling of credentials via environment variables.
"""
import os
import logging
from typing import Optional, Dict, Any
from contextlib import contextmanager
from urllib.parse import quote_plus
import pandas as pd
from sqlalchemy import create_engine, text, pool
from sqlalchemy.engine import Engine
from tenacity import retry, stop_after_attempt, wait_exponential
# Configure logging
# NOTE: basicConfig() runs as a side effect of importing this module. It only
# installs a handler when the root logger has none yet, so an application that
# configures logging before importing this module keeps its own setup.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by DatabaseConnector for all of its messages.
logger = logging.getLogger(__name__)
class DatabaseConnector:
    """Handles database connections and query execution.

    Connection settings are read from environment variables when the
    instance is created:

    * ``DB_HOST``, ``DB_NAME``, ``DB_USER``, ``DB_PASSWORD`` -- required.
    * ``DB_PORT`` -- defaults to ``"1433"``.
    * ``DB_DRIVER`` -- ``"odbc"`` selects pyodbc; any other value (default
      ``"tds"``) selects python-tds.
    * ``DB_ENCRYPT`` -- ``"true"`` (case-insensitive) requests encryption.

    If a required setting is missing, or the engine cannot be created,
    ``engine`` stays ``None``; the query helpers then return ``None`` or
    ``False`` instead of raising (``get_connection`` raises ``RuntimeError``).
    """

    # Lower-case substrings which, when found in a query error message, cause
    # the whole message to be replaced before it is logged.
    _SENSITIVE_MARKERS = ("password", "pwd", "uid", "user")

    def __init__(self) -> None:
        """Read the environment and try to create the SQLAlchemy engine."""
        self.engine: Optional[Engine] = None
        self.connection_string: Optional[str] = None
        self._init_engine()

    def _get_env_var(self, key: str, default: str = "", mask_log: bool = False) -> str:
        """Safely retrieve an environment variable and log what was found.

        Args:
            key: Name of the environment variable.
            default: Value used when the variable is not set.
            mask_log: When True and a value is present, log asterisks
                instead of the value.

        Returns:
            The variable's value, or ``default`` when it is not set.
        """
        value = os.getenv(key, default)
        if mask_log and value:
            logger.info("%s: %s", key, "*" * 8)
        else:
            logger.info("%s: %s", key, value if value else "(not set)")
        return value

    @staticmethod
    def _odbc_quote(value: str) -> str:
        """Brace-quote ``value`` for an ODBC connection string when needed.

        A value containing ``;``, ``{`` or ``}`` (or leading/trailing
        whitespace) would otherwise be cut short or misparsed by the driver
        manager. Inside braces a literal ``}`` is written as ``}}``.
        """
        if any(ch in value for ch in ";{}") or value != value.strip():
            return "{" + value.replace("}", "}}") + "}"
        return value

    def _build_connection_string(self) -> Optional[str]:
        """Build the SQLAlchemy connection URL from environment variables.

        Returns:
            The connection URL, or None when a required setting is missing
            or the URL could not be built.
        """
        host = self._get_env_var("DB_HOST")
        port = self._get_env_var("DB_PORT", "1433")
        database = self._get_env_var("DB_NAME")
        user = self._get_env_var("DB_USER", mask_log=True)
        password = self._get_env_var("DB_PASSWORD", mask_log=True)
        driver = self._get_env_var("DB_DRIVER", "tds")  # tds or odbc
        encrypt = self._get_env_var("DB_ENCRYPT", "false")

        if not all([host, database, user, password]):
            logger.warning("Database credentials incomplete. Demo mode will be used.")
            return None

        use_encryption = encrypt.lower() == "true"
        try:
            if driver == "odbc":
                # Hand pyodbc the exact ODBC string through SQLAlchemy's
                # ``odbc_connect`` pass-through. Supplying DRIVER/SERVER/UID
                # as ordinary URL query parameters does not work: the dialect
                # only recognises lower-case URL fields, so the server and
                # credentials would be dropped from the generated string.
                odbc_parts = [
                    "DRIVER={ODBC Driver 18 for SQL Server}",
                    f"SERVER={host},{port}",
                    f"DATABASE={self._odbc_quote(database)}",
                    f"UID={self._odbc_quote(user)}",
                    f"PWD={self._odbc_quote(password)}",
                    f"Encrypt={'yes' if use_encryption else 'no'}",
                    # NOTE(review): the server certificate is always trusted,
                    # so it is not validated even when encryption is on.
                    "TrustServerCertificate=yes",
                ]
                conn_str = "mssql+pyodbc:///?odbc_connect=" + quote_plus(
                    ";".join(odbc_parts)
                )
            else:
                # python-tds connection string; user and password are
                # percent-encoded so characters such as '@' or ':' in them
                # cannot break the URL.
                conn_str = (
                    f"mssql+pytds://{quote_plus(user)}:{quote_plus(password)}"
                    f"@{host}:{port}/{database}"
                )
                if use_encryption:
                    conn_str += "?encryption=required"
            logger.info(f"Connection string built using {driver} driver")
            return conn_str
        except Exception as e:
            # Kept broad on purpose: building the URL must never raise, e.g.
            # when an environment value cannot be encoded by quote_plus().
            logger.error(f"Error building connection string: {str(e)}")
            return None

    def _init_engine(self) -> None:
        """Initialize the SQLAlchemy engine, leaving it None on failure."""
        self.connection_string = self._build_connection_string()
        if not self.connection_string:
            logger.warning("No valid connection string. Database features disabled.")
            return
        try:
            self.engine = create_engine(
                self.connection_string,
                poolclass=pool.QueuePool,
                pool_size=5,
                max_overflow=10,
                pool_pre_ping=True,  # Verify connections before using
                echo=False,
            )
            logger.info("Database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {str(e)}")
            self.engine = None

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True,
    )
    def _ping(self) -> bool:
        """Run ``SELECT 1`` once, letting any exception propagate.

        The retry decorator sits here, not on ``test_connection``, because
        it only retries when the wrapped call raises; a method that catches
        its own errors and returns False is never retried.
        """
        with self.engine.connect() as conn:
            row = conn.execute(text("SELECT 1 AS test")).fetchone()
        return row is not None and row[0] == 1

    def test_connection(self) -> bool:
        """Test database connectivity.

        Up to three attempts are made with exponential backoff between
        them, so a failing check can take several seconds to return.

        Returns:
            True when ``SELECT 1`` returned 1; False when there is no
            engine, every attempt failed, or the result was unexpected.
        """
        if self.engine is None:
            return False
        try:
            succeeded = self._ping()
        except Exception as e:
            logger.error(f"Database connection test failed: {str(e)}")
            return False
        if succeeded:
            logger.info("Database connection test successful")
        return succeeded

    def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Optional[pd.DataFrame]:
        """Execute a SQL query and return results as a pandas DataFrame.

        Args:
            query: SQL query string with :param placeholders
            params: Dictionary of parameter values

        Returns:
            DataFrame with results or None on error
        """
        if self.engine is None:
            logger.error("No database engine available")
            return None
        try:
            with self.engine.connect() as conn:
                result = pd.read_sql_query(
                    text(query),
                    conn,
                    params=params or {}
                )
            logger.info(f"Query executed successfully, returned {len(result)} rows")
            return result
        except Exception as e:
            # Mask any credential info in error messages: if the text
            # mentions anything credential-like, log a generic message.
            error_msg = str(e)
            lowered = error_msg.lower()
            if any(marker in lowered for marker in self._SENSITIVE_MARKERS):
                error_msg = "Database query error (credentials masked)"
            logger.error(f"Query execution failed: {error_msg}")
            return None

    def execute_scalar(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Optional[Any]:
        """Execute a query and return the first column of the first row.

        Args:
            query: SQL query string with :param placeholders
            params: Dictionary of parameter values

        Returns:
            The scalar value, or None when the query failed or the result
            was empty.
        """
        df = self.execute_query(query, params)
        if df is None or df.empty:
            return None
        return df.iloc[0, 0]

    @contextmanager
    def get_connection(self):
        """Context manager for raw database connections.

        Yields:
            An open SQLAlchemy connection, closed when the block exits.

        Raises:
            RuntimeError: If no engine is available.
        """
        if self.engine is None:
            raise RuntimeError("No database engine available")
        with self.engine.connect() as conn:
            yield conn

    def is_available(self) -> bool:
        """Check if the database is available.

        Returns:
            True when an engine exists and the connectivity test passes.
        """
        # test_connection() already handles a missing engine and never raises.
        return self.test_connection()
# Global database connector instance.
# Creating it here means that importing this module runs
# DatabaseConnector.__init__, which reads the DB_* environment variables and,
# when they are complete, calls create_engine(). If they are incomplete the
# instance is still created, with engine set to None.
db_connector = DatabaseConnector()