"""
Database connection and session management.
SQLite with WAL mode for concurrent read/write support.
"""
import logging
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from app.settings import get_settings
# Module-level logger for connection and migration diagnostics.
logger = logging.getLogger(__name__)
# SQLAlchemy declarative base
# All ORM model classes in app.models inherit from this.
Base = declarative_base()
# Global engine and session factory (lazy initialized)
# Populated on first call to get_engine() / get_session_factory().
_engine = None
_SessionLocal = None
def get_engine():
    """Get or create the database engine.

    Lazily builds a singleton SQLAlchemy engine from ``settings.database_url``.
    SQLite URLs get cross-thread connect args plus WAL/pragma tuning applied
    via a "connect" event listener; any other backend (e.g. PostgreSQL /
    Supabase) gets a bounded connection pool.

    Returns:
        The shared SQLAlchemy engine instance.
    """
    global _engine
    if _engine is None:
        settings = get_settings()
        database_url = settings.database_url
        # Determine if SQLite
        is_sqlite = database_url.startswith("sqlite")
        # Engine configuration
        engine_kwargs = {
            # Echo SQL statements only when log level is exactly "DEBUG".
            "echo": settings.log_level == "DEBUG",
            # Validate pooled connections before handing them out.
            "pool_pre_ping": True,
        }
        if is_sqlite:
            # SQLite-specific settings
            engine_kwargs["connect_args"] = {
                # Disable SQLite's same-thread check so pooled connections
                # can be used from multiple threads.
                "check_same_thread": False,
                # Seconds to wait on a locked database before erroring.
                "timeout": 30,
            }
        else:
            # PostgreSQL (Supabase) - connection pooling
            engine_kwargs["pool_size"] = 5
            engine_kwargs["max_overflow"] = 10
            engine_kwargs["pool_timeout"] = 30
            engine_kwargs["pool_recycle"] = 300
        _engine = create_engine(database_url, **engine_kwargs)
        # SQLite WAL mode and pragmas
        if is_sqlite:
            # Listener fires once per new DBAPI connection.
            @event.listens_for(_engine, "connect")
            def set_sqlite_pragma(dbapi_connection, connection_record):
                cursor = dbapi_connection.cursor()
                # WAL mode for concurrent reads
                cursor.execute("PRAGMA journal_mode=WAL")
                # Busy timeout (ms) - wait for locks instead of immediate failure
                cursor.execute("PRAGMA busy_timeout=5000")
                # Synchronous mode - balance between safety and speed
                cursor.execute("PRAGMA synchronous=NORMAL")
                # Foreign keys enforcement
                cursor.execute("PRAGMA foreign_keys=ON")
                # Memory-mapped I/O (faster reads)
                cursor.execute("PRAGMA mmap_size=268435456")  # 256MB
                cursor.close()
            logger.info("SQLite configured with WAL mode and performance pragmas")
        # Log sanitized URL (hide password)
        import re
        safe_url = re.sub(r'://[^:]+:[^@]+@', '://***:***@', database_url.split('?')[0])
        logger.info(f"Database engine created: {safe_url}")
    return _engine
def get_session_factory() -> sessionmaker:
    """Return the process-wide session factory, creating it on first use.

    Sessions produced by the factory require explicit commit/flush
    (autocommit and autoflush are both disabled).
    """
    global _SessionLocal
    if _SessionLocal is not None:
        return _SessionLocal
    _SessionLocal = sessionmaker(
        bind=get_engine(),
        autocommit=False,
        autoflush=False,
    )
    return _SessionLocal
def SessionLocal() -> Session:
    """Create a new database session."""
    # Obtain the (lazily created) factory and instantiate a session from it.
    return get_session_factory()()
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """
    Yield a database session scoped to a ``with`` block.

    The transaction is committed when the block exits normally; on any
    exception it is rolled back and the exception re-raised. The session
    is always closed afterwards.
    """
    db = SessionLocal()
    try:
        yield db
        # Commit inside try so a failing commit is also rolled back below.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def _ensure_columns(conn, table, columns, is_sqlite):
    """Idempotently add missing columns to *table* and commit.

    Args:
        conn: An open SQLAlchemy connection.
        table: Target table name.
        columns: List of ``(name, type_sql)`` pairs; ``type_sql`` may include
            a DEFAULT clause.
        is_sqlite: True when the backend is SQLite.
    """
    if is_sqlite:
        # SQLite doesn't support ADD COLUMN IF NOT EXISTS — check via PRAGMA.
        result = conn.execute(text(f"PRAGMA table_info({table})"))
        existing = {row[1] for row in result.fetchall()}
        for name, type_sql in columns:
            if name not in existing:
                conn.execute(
                    text(f"ALTER TABLE {table} ADD COLUMN {name} {type_sql}")
                )
    else:
        # PostgreSQL supports IF NOT EXISTS natively.
        for name, type_sql in columns:
            conn.execute(
                text(
                    f"ALTER TABLE {table} "
                    f"ADD COLUMN IF NOT EXISTS {name} {type_sql}"
                )
            )
    conn.commit()


def _run_migrations(engine):
    """
    Run necessary database migrations for schema changes.
    These are safe to run multiple times (idempotent).

    Each migration group is applied independently and best-effort: a failure
    (e.g. the table not existing yet on first boot) is logged at DEBUG and
    does not block the remaining groups.
    """
    settings = get_settings()
    is_sqlite = settings.database_url.startswith("sqlite")
    # (table, [(column, type_sql), ...], human-readable description)
    migrations = [
        (
            "ai_commentaries",
            [("ai_stance", "VARCHAR(20) DEFAULT 'NEUTRAL'")],
            "ai_stance column",
        ),
        (
            "pipeline_run_metrics",
            [
                ("articles_scored_v2", "INTEGER"),
                ("llm_parse_fail_count", "INTEGER"),
                ("escalation_count", "INTEGER"),
                ("fallback_count", "INTEGER"),
            ],
            "V2 Stage-2 metric columns",
        ),
        (
            "pipeline_run_metrics",
            [
                ("tft_embeddings_computed", "INTEGER"),
                ("tft_trained", "BOOLEAN DEFAULT FALSE"),
                ("tft_val_loss", "FLOAT"),
                ("tft_sharpe", "FLOAT"),
                ("tft_directional_accuracy", "FLOAT"),
                ("tft_snapshot_generated", "BOOLEAN DEFAULT FALSE"),
            ],
            "TFT-ASRO metric columns",
        ),
    ]
    with engine.connect() as conn:
        for table, columns, description in migrations:
            try:
                _ensure_columns(conn, table, columns, is_sqlite)
                logger.info(f"Migration: Ensured {description} exist")
            except Exception as e:
                # Best-effort: most commonly the table hasn't been created yet.
                logger.debug(f"Migration check for {description}: {e}")
def init_db():
    """
    Initialize the database - create all tables.
    Safe to call multiple times (uses CREATE IF NOT EXISTS).
    Also runs any necessary migrations for schema changes.
    """
    # Importing app.models registers every model class on Base's metadata.
    from app import models  # noqa: F401
    db_engine = get_engine()
    Base.metadata.create_all(bind=db_engine)
    logger.info("Database tables created/verified")
    # Run migrations for existing tables
    _run_migrations(db_engine)
def get_db_type() -> str:
"""Return the database type (sqlite, postgresql, etc.)."""
settings = get_settings()
url = settings.database_url
if url.startswith("sqlite"):
return "sqlite"
elif url.startswith("postgresql"):
return "postgresql"
elif url.startswith("mysql"):
return "mysql"
else:
return "unknown"
def check_db_connection() -> bool:
    """Test database connectivity."""
    try:
        # A trivial round-trip query proves the connection works.
        with get_engine().connect() as connection:
            connection.execute(text("SELECT 1"))
            return True
    except Exception as e:
        logger.error(f"Database connection check failed: {e}")
        return False
|