# app/db.py
"""
Database connection and session management.
SQLite with WAL mode for concurrent read/write support.
"""
import logging
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from app.settings import get_settings
# Module-level logger, named after this module for filterable output.
logger = logging.getLogger(__name__)
# SQLAlchemy declarative base; ORM model classes elsewhere in the app subclass this.
Base = declarative_base()
# Process-wide engine and session factory, lazily created on first use
# by get_engine() / get_session_factory().
_engine = None
_SessionLocal = None
def get_engine():
    """Return the process-wide database engine, building it on first call.

    SQLite gets cross-thread connections plus WAL/performance pragmas applied
    on every new DBAPI connection; any other backend (PostgreSQL/Supabase)
    gets a bounded connection pool instead.
    """
    global _engine
    if _engine is not None:
        return _engine

    settings = get_settings()
    database_url = settings.database_url
    is_sqlite = database_url.startswith("sqlite")

    # Common engine options; echo SQL only when the app runs at DEBUG level.
    kwargs = {
        "echo": settings.log_level == "DEBUG",
        "pool_pre_ping": True,
    }
    if is_sqlite:
        # Allow sessions to cross threads and wait up to 30s for file locks.
        kwargs["connect_args"] = {"check_same_thread": False, "timeout": 30}
    else:
        # Server-side database: modest pool with overflow headroom.
        kwargs.update(pool_size=5, max_overflow=10, pool_timeout=30)

    _engine = create_engine(database_url, **kwargs)

    if is_sqlite:
        @event.listens_for(_engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            # Applied per-connection: WAL for concurrent readers, a 5s busy
            # wait instead of immediate lock errors, NORMAL sync as a
            # safety/speed balance, FK enforcement, and 256MB mmap reads.
            cursor = dbapi_connection.cursor()
            for pragma in (
                "PRAGMA journal_mode=WAL",
                "PRAGMA busy_timeout=5000",
                "PRAGMA synchronous=NORMAL",
                "PRAGMA foreign_keys=ON",
                "PRAGMA mmap_size=268435456",
            ):
                cursor.execute(pragma)
            cursor.close()
        logger.info("SQLite configured with WAL mode and performance pragmas")

    # Strip query string and mask credentials before logging the URL.
    import re
    safe_url = re.sub(r'://[^:]+:[^@]+@', '://***:***@', database_url.split('?')[0])
    logger.info(f"Database engine created: {safe_url}")
    return _engine
def get_session_factory() -> sessionmaker:
    """Return the shared session factory, creating it on first use.

    Sessions are configured with explicit commit/flush semantics
    (autocommit and autoflush both disabled).
    """
    global _SessionLocal
    if _SessionLocal is None:
        _SessionLocal = sessionmaker(
            bind=get_engine(),
            autocommit=False,
            autoflush=False,
        )
    return _SessionLocal
def SessionLocal() -> Session:
    """Open and return a fresh session from the shared factory."""
    return get_session_factory()()
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """Yield a database session, committing on success.

    Any exception (including a failed commit) triggers a rollback and is
    re-raised; the session is always closed.
    """
    db = SessionLocal()
    try:
        yield db
        # Commit inside the try so a commit failure also rolls back.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def _ensure_columns(conn, is_sqlite: bool, table: str, columns, success_msg: str, label: str) -> None:
    """Best-effort, idempotent ``ALTER TABLE ... ADD COLUMN`` for missing columns.

    Args:
        conn: Open SQLAlchemy connection.
        is_sqlite: Whether the backend is SQLite.
        table: Target table name (trusted, code-defined — never user input).
        columns: Iterable of ``(column_name, column_ddl)`` pairs.
        success_msg: INFO message logged when the migration applies/verifies.
        label: Short label used in the DEBUG message on failure.

    SQLite lacks ``ADD COLUMN IF NOT EXISTS``, so existing columns are read
    from ``PRAGMA table_info`` first; PostgreSQL uses ``IF NOT EXISTS``
    directly. Failures (e.g. table not created yet) are logged at DEBUG and
    swallowed so application startup never breaks.
    """
    try:
        if is_sqlite:
            result = conn.execute(text(f"PRAGMA table_info({table})"))
            existing = {row[1] for row in result.fetchall()}
            missing = [(name, ddl) for name, ddl in columns if name not in existing]
            for name, ddl in missing:
                conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}"))
            if missing:
                conn.commit()
                logger.info(success_msg)
        else:
            for name, ddl in columns:
                conn.execute(
                    text(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {name} {ddl}")
                )
            conn.commit()
            logger.info(success_msg)
    except Exception as e:
        logger.debug(f"Migration check for {label}: {e}")


def _run_migrations(engine):
    """
    Run necessary database migrations for schema changes.
    These are safe to run multiple times (idempotent).
    """
    settings = get_settings()
    is_sqlite = settings.database_url.startswith("sqlite")
    with engine.connect() as conn:
        # Migration: ai_stance column on ai_commentaries
        _ensure_columns(
            conn, is_sqlite, "ai_commentaries",
            [("ai_stance", "VARCHAR(20) DEFAULT 'NEUTRAL'")],
            "Migration: Ensured ai_stance column exists in ai_commentaries",
            "ai_stance",
        )
        # Migration: V2 stage-2 metrics columns
        _ensure_columns(
            conn, is_sqlite, "pipeline_run_metrics",
            [
                ("articles_scored_v2", "INTEGER"),
                ("llm_parse_fail_count", "INTEGER"),
                ("escalation_count", "INTEGER"),
                ("fallback_count", "INTEGER"),
            ],
            "Migration: Ensured V2 Stage-2 metric columns exist",
            "V2 metric columns",
        )
        # Migration: TFT-ASRO metric columns on pipeline_run_metrics
        _ensure_columns(
            conn, is_sqlite, "pipeline_run_metrics",
            [
                ("tft_embeddings_computed", "INTEGER"),
                ("tft_trained", "BOOLEAN DEFAULT FALSE"),
                ("tft_val_loss", "FLOAT"),
                ("tft_sharpe", "FLOAT"),
                ("tft_directional_accuracy", "FLOAT"),
                ("tft_snapshot_generated", "BOOLEAN DEFAULT FALSE"),
            ],
            "Migration: Ensured TFT-ASRO metric columns exist",
            "TFT metric columns",
        )
def init_db():
    """Initialize the database: create all tables, then run migrations.

    Safe to call repeatedly — table creation is CREATE-IF-NOT-EXISTS and the
    migrations are idempotent.
    """
    # Importing the models module registers every table on Base.metadata.
    from app import models  # noqa: F401

    engine = get_engine()
    Base.metadata.create_all(bind=engine)
    logger.info("Database tables created/verified")
    # Bring pre-existing tables up to the current schema.
    _run_migrations(engine)
def get_db_type(url=None) -> str:
    """Return the database type ("sqlite", "postgresql", "mysql", or "unknown").

    Args:
        url: Optional database URL to classify. Defaults to the configured
            ``settings.database_url``, preserving the original call signature.

    Returns:
        The backend name inferred from the URL scheme prefix.
    """
    if url is None:
        url = get_settings().database_url
    # Map URL-scheme prefixes to backend names; first match wins.
    for prefix in ("sqlite", "postgresql", "mysql"):
        if url.startswith(prefix):
            return prefix
    return "unknown"
def check_db_connection() -> bool:
    """Return True when a trivial query succeeds, False otherwise.

    Connection failures are logged at ERROR level rather than raised.
    """
    try:
        with get_engine().connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:
        logger.error(f"Database connection check failed: {e}")
        return False
    return True