"""
Database connection and session management.
Uses async SQLAlchemy with asyncpg driver for PostgreSQL,
or aiosqlite for local/free-tier deployment.
"""
import logging
import os
import uuid
from typing import AsyncIterator

from sqlalchemy import CHAR, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.types import TypeDecorator

from app.config import get_settings
class GUID(TypeDecorator):
    """
    Portable UUID column type.

    On PostgreSQL the column is the native ``UUID`` type (``as_uuid=True``);
    on every other dialect it is a ``CHAR(32)`` holding the UUID's hex digits.
    """

    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the concrete column type to use for ``dialect``."""
        if dialect.name != 'postgresql':
            return dialect.type_descriptor(CHAR(32))
        # Imported lazily so the PostgreSQL dialect is only loaded when used.
        from sqlalchemy.dialects.postgresql import UUID as PG_UUID
        return dialect.type_descriptor(PG_UUID(as_uuid=True))

    def process_bind_param(self, value, dialect):
        """Convert a Python value into the form sent to the database."""
        # None and PostgreSQL-bound values are handed over untouched.
        if value is None or dialect.name == 'postgresql':
            return value
        # Other dialects store the 32-character hex form; non-UUID input
        # is parsed with uuid.UUID first.
        parsed = value if isinstance(value, uuid.UUID) else uuid.UUID(value)
        return parsed.hex

    def process_result_value(self, value, dialect):
        """Convert a value read from the database back into Python."""
        # None and PostgreSQL results are returned as received.
        if value is None or dialect.name == 'postgresql':
            return value
        # Other dialects: anything that is not already a UUID is parsed into one.
        return value if isinstance(value, uuid.UUID) else uuid.UUID(value)
settings = get_settings()

# Resolve the database URL. When none is configured (None or empty string),
# fall back to a local SQLite file so the app works without an external DB.
_db_url = settings.database_url
if not _db_url:
    # The fallback database lives in a "data" directory one level above the
    # directory containing this file; create it if it does not exist yet.
    _db_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
    os.makedirs(_db_dir, exist_ok=True)
    _db_url = f"sqlite+aiosqlite:///{_db_url_dir}/fairrelay.db" if False else f"sqlite+aiosqlite:///{_db_dir}/fairrelay.db"

# Derive the backend from the URL itself rather than from which branch ran:
# an explicitly configured SQLite URL must also be treated as SQLite, so it
# gets the SQLite engine settings below instead of the pool-tuning arguments.
_is_sqlite = str(_db_url).startswith("sqlite")
# Build the engine options first, then create the async engine in one place.
_engine_options = {"echo": settings.debug, "future": True}
if _is_sqlite:
    # SQLite: only a driver-level connect argument, no pool sizing.
    _engine_options["connect_args"] = {"check_same_thread": False}
else:
    # Any other backend: explicit connection-pool configuration.
    _engine_options.update(
        pool_size=20,
        max_overflow=10,
        pool_recycle=3600,
        pool_pre_ping=True,
    )
engine = create_async_engine(_db_url, **_engine_options)

# Factory for AsyncSession objects bound to the engine above.
# expire_on_commit=False is passed so that committing does not expire
# the attributes of objects already loaded in the session.
async_session_maker = async_sessionmaker(
    engine, class_=AsyncSession, expire_on_commit=False
)
class Base(DeclarativeBase):
    """Declarative base class shared by all SQLAlchemy models."""
async def get_db() -> AsyncIterator[AsyncSession]:
    """
    Dependency that provides a database session.

    This is an async generator, so it is annotated as an iterator of
    sessions rather than as a single ``AsyncSession``.

    Yields:
        A session created by ``async_session_maker``.

    Raises:
        Exception: any exception raised while the session is in use, or by
            the commit itself, is re-raised after the session is rolled back.
    """
    # The ``async with`` block closes the session on exit (success, error or
    # cancellation), so no separate ``finally: close()`` is needed.
    async with async_session_maker() as session:
        try:
            yield session
            # Reached only when the consumer finished without raising.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
async def check_db_health() -> bool:
    """
    Check database connectivity by running a simple query.

    Opens a connection from the module-level ``engine`` and executes
    ``SELECT 1``.

    Returns:
        True if the connection was opened and the query executed,
        False if any exception occurred. Exceptions are never propagated.
    """
    try:
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception:
        # A health probe must not raise, but the cause should not be lost:
        # log it with the traceback before reporting "unhealthy".
        logging.getLogger(__name__).warning(
            "Database health check failed", exc_info=True
        )
        return False
    return True
async def init_db() -> None:
    """
    Initialize database tables.

    Runs ``Base.metadata.create_all`` through ``run_sync`` on a connection
    obtained from ``engine.begin()``.
    """
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)
|