Spaces:
Running
Running
| """ | |
| Database connection and session management. | |
| SQLite with WAL mode for concurrent read/write support. | |
| """ | |
| import logging | |
| from contextlib import contextmanager | |
| from typing import Generator | |
| from sqlalchemy import create_engine, event, text | |
| from sqlalchemy.orm import sessionmaker, Session, declarative_base | |
| from app.settings import get_settings | |
| logger = logging.getLogger(__name__) | |
| # SQLAlchemy declarative base | |
| Base = declarative_base() | |
| # Global engine and session factory (lazy initialized) | |
| _engine = None | |
| _SessionLocal = None | |
def get_engine():
    """Get or create the global database engine (lazy singleton).

    Reads the database URL from settings on first call. For SQLite,
    configures connect args and registers a connection-time listener that
    applies WAL mode and performance pragmas; for server backends
    (e.g. PostgreSQL/Supabase) configures a bounded connection pool.

    Returns:
        sqlalchemy.engine.Engine: the shared engine instance.
    """
    global _engine
    if _engine is None:
        settings = get_settings()
        database_url = settings.database_url
        # Determine if SQLite
        is_sqlite = database_url.startswith("sqlite")
        # Engine configuration
        engine_kwargs = {
            # Echo SQL only in debug mode
            "echo": settings.log_level == "DEBUG",
            # Validate pooled connections before use (avoids stale handles)
            "pool_pre_ping": True,
        }
        if is_sqlite:
            # SQLite-specific settings
            engine_kwargs["connect_args"] = {
                # Allow the connection to be used across threads
                "check_same_thread": False,
                # Seconds to wait on a locked database file
                "timeout": 30,
            }
        else:
            # PostgreSQL (Supabase) - connection pooling
            engine_kwargs["pool_size"] = 5
            engine_kwargs["max_overflow"] = 10
            engine_kwargs["pool_timeout"] = 30
        _engine = create_engine(database_url, **engine_kwargs)
        # SQLite WAL mode and pragmas
        if is_sqlite:
            # BUG FIX: the pragma hook was previously defined but never
            # registered, so none of the pragmas were actually applied.
            # Register it to run on every new DBAPI connection.
            @event.listens_for(_engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                # WAL mode for concurrent reads
                cursor.execute("PRAGMA journal_mode=WAL")
                # Busy timeout (ms) - wait for locks instead of immediate failure
                cursor.execute("PRAGMA busy_timeout=5000")
                # Synchronous mode - balance between safety and speed
                cursor.execute("PRAGMA synchronous=NORMAL")
                # Foreign keys enforcement
                cursor.execute("PRAGMA foreign_keys=ON")
                # Memory-mapped I/O (faster reads)
                cursor.execute("PRAGMA mmap_size=268435456")  # 256MB
                cursor.close()

            logger.info("SQLite configured with WAL mode and performance pragmas")
        # Log sanitized URL (hide password)
        import re
        safe_url = re.sub(r'://[^:]+:[^@]+@', '://***:***@', database_url.split('?')[0])
        logger.info(f"Database engine created: {safe_url}")
    return _engine
def get_session_factory() -> sessionmaker:
    """Return the process-wide session factory, creating it on first use."""
    global _SessionLocal
    # Fast path: factory already built
    if _SessionLocal is not None:
        return _SessionLocal
    _SessionLocal = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=get_engine(),
    )
    return _SessionLocal
def SessionLocal() -> Session:
    """Open and return a fresh database session from the shared factory."""
    return get_session_factory()()
def get_db() -> Generator[Session, None, None]:
    """
    Yield a database session, committing on success and rolling back on error.

    This is a plain generator (no ``@contextmanager`` decorator), suitable as
    a FastAPI-style dependency; wrap with ``contextlib.contextmanager`` for
    ``with``-statement usage.  NOTE(review): ``contextmanager`` is imported at
    module top but unused — confirm whether the decorator was intended here.

    The session is always closed, even when the caller raises.
    """
    session = SessionLocal()
    try:
        yield session
        # Commit only if the caller's block completed without raising
        session.commit()
    except Exception:
        # Undo any partial work before propagating to the caller
        session.rollback()
        raise
    finally:
        session.close()
def _ensure_columns(conn, is_sqlite, table, columns):
    """Add any missing columns to *table* (idempotent).

    Args:
        conn: open SQLAlchemy connection.
        is_sqlite: True when the backend is SQLite, which lacks
            ``ADD COLUMN IF NOT EXISTS`` and needs a manual existence check.
        table: name of the table to alter.
        columns: iterable of ``(column_name, column_ddl)`` pairs, e.g.
            ``("ai_stance", "VARCHAR(20) DEFAULT 'NEUTRAL'")``.
    """
    if is_sqlite:
        # SQLite: inspect the table and add only the missing columns.
        result = conn.execute(text(f"PRAGMA table_info({table})"))
        existing = {row[1] for row in result.fetchall()}
        for name, ddl in columns:
            if name not in existing:
                conn.execute(
                    text(f"ALTER TABLE {table} ADD COLUMN {name} {ddl}")
                )
    else:
        # PostgreSQL supports IF NOT EXISTS directly.
        for name, ddl in columns:
            conn.execute(
                text(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {name} {ddl}")
            )
    conn.commit()


def _run_migrations(engine):
    """
    Run necessary database migrations for schema changes.
    These are safe to run multiple times (idempotent).

    Each migration is best-effort: failures are logged at debug level and
    never abort startup (e.g. when the target table does not exist yet —
    ``init_db``'s ``create_all`` will then create it with all columns).
    """
    # Derive the dialect from the engine itself rather than re-reading
    # settings, so the function works for whatever engine it is handed.
    is_sqlite = engine.dialect.name == "sqlite"
    # (table, columns, human-readable label) for each migration batch.
    migrations = [
        (
            "ai_commentaries",
            [("ai_stance", "VARCHAR(20) DEFAULT 'NEUTRAL'")],
            "ai_stance column",
        ),
        (
            "pipeline_run_metrics",
            [
                ("articles_scored_v2", "INTEGER"),
                ("llm_parse_fail_count", "INTEGER"),
                ("escalation_count", "INTEGER"),
                ("fallback_count", "INTEGER"),
            ],
            "V2 Stage-2 metric columns",
        ),
        (
            "pipeline_run_metrics",
            [
                ("tft_embeddings_computed", "INTEGER"),
                ("tft_trained", "BOOLEAN DEFAULT FALSE"),
                ("tft_val_loss", "FLOAT"),
                ("tft_sharpe", "FLOAT"),
                ("tft_directional_accuracy", "FLOAT"),
                ("tft_snapshot_generated", "BOOLEAN DEFAULT FALSE"),
            ],
            "TFT-ASRO metric columns",
        ),
    ]
    with engine.connect() as conn:
        for table, columns, label in migrations:
            try:
                _ensure_columns(conn, is_sqlite, table, columns)
                logger.info(f"Migration: Ensured {label} exist in {table}")
            except Exception as e:
                # Best-effort: the table may not exist yet.
                logger.debug(f"Migration check for {label}: {e}")
def init_db():
    """Create all ORM tables (idempotent) and apply schema migrations.

    Safe to call multiple times: table creation uses CREATE IF NOT EXISTS
    semantics and the follow-up migrations are themselves idempotent.
    """
    # Side-effect import: registers every model class on Base.metadata
    from app import models  # noqa: F401

    engine = get_engine()
    Base.metadata.create_all(bind=engine)
    logger.info("Database tables created/verified")
    # Bring pre-existing tables up to date with any newer columns
    _run_migrations(engine)
def get_db_type() -> str:
    """Return the database type (sqlite, postgresql, etc.).

    Classifies the configured database URL by scheme prefix; any
    unrecognized scheme yields "unknown".
    """
    url = get_settings().database_url
    for scheme in ("sqlite", "postgresql", "mysql"):
        if url.startswith(scheme):
            return scheme
    return "unknown"
def check_db_connection() -> bool:
    """Test database connectivity.

    Returns:
        True when a trivial query succeeds, False otherwise.
    """
    try:
        # A throwaway SELECT exercises the full connect path
        with get_engine().connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:
        logger.error(f"Database connection check failed: {e}")
        return False
    return True