Spaces:
Running
Running
File size: 4,902 Bytes
7d369c8 c92a083 7d369c8 c92a083 7d369c8 c92a083 7d369c8 c92a083 7d369c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
"""
AdaptiveAuth Core - Database Module
Database engine, session management, and utilities.
"""
from typing import Generator, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from contextlib import contextmanager
from ..config import get_settings
from ..models import Base
# Module-level singletons shared by the helper functions in this module.
# _engine: engine built lazily by get_engine(); None until first use and
# again after reset_database_connection().
_engine = None
# _SessionLocal: session factory built lazily by get_session_local(); cleared
# together with the engine by reset_database_connection().
_SessionLocal = None
def _fix_db_url(url: str) -> str:
    """Normalize a PostgreSQL URL so it names the psycopg2 driver explicitly.

    A leading ``postgres://`` or ``postgresql://`` scheme is rewritten to
    ``postgresql+psycopg2://``. The legacy ``postgres://`` alias is not a
    dialect name that SQLAlchemy 1.4+ recognizes, so it must be rewritten
    before the URL reaches ``create_engine``. Only the scheme prefix is
    touched; any other URL (SQLite, a URL that already names a driver, ...)
    is returned unchanged.

    Args:
        url: Database URL as supplied by configuration or the caller.

    Returns:
        The URL with its scheme normalized, or ``url`` itself when no
        rewrite applies.
    """
    # Rewrite by prefix rather than str.replace(): replace() targets the first
    # occurrence anywhere in the string, so a "postgres://" embedded later in
    # the URL (password, query string) could be rewritten by mistake.
    for scheme in ("postgres://", "postgresql://"):
        if url.startswith(scheme):
            return "postgresql+psycopg2://" + url[len(scheme):]
    return url
def get_engine(database_url: Optional[str] = None, echo: bool = False):
    """Return the shared module-level engine, building it on first use.

    The engine is cached in the module global ``_engine``. Once it exists,
    both arguments are ignored and the cached engine is returned as-is.

    Args:
        database_url: URL to connect to; falls back to
            ``settings.DATABASE_URL`` when empty or ``None``.
        echo: Enable engine echo; also enabled when
            ``settings.DATABASE_ECHO`` is truthy.

    Returns:
        The cached SQLAlchemy engine.
    """
    global _engine

    # Fast path: an earlier call already built the engine.
    if _engine is not None:
        return _engine

    cfg = get_settings()
    target_url = _fix_db_url(database_url or cfg.DATABASE_URL)
    echo_sql = echo or cfg.DATABASE_ECHO

    # SQLite gets check_same_thread=False so a connection may be used from a
    # thread other than the one that opened it.
    is_sqlite = target_url.startswith("sqlite")
    driver_args = {"check_same_thread": False} if is_sqlite else {}

    _engine = create_engine(
        target_url,
        connect_args=driver_args,
        echo=echo_sql,
        pool_pre_ping=True,
        pool_recycle=3600,
    )
    return _engine
def get_session_local(database_url: Optional[str] = None):
    """Return the shared session factory, building it on first use.

    The factory is cached in the module global ``_SessionLocal`` and is bound
    to the engine returned by ``get_engine(database_url)``. Once the factory
    exists, ``database_url`` is ignored.

    Args:
        database_url: Optional URL forwarded to ``get_engine`` when the
            factory has to be created.

    Returns:
        The cached ``sessionmaker`` instance.
    """
    global _SessionLocal

    # Fast path: the factory was already created.
    if _SessionLocal is not None:
        return _SessionLocal

    _SessionLocal = sessionmaker(
        bind=get_engine(database_url),
        autoflush=False,
        autocommit=False,
    )
    return _SessionLocal
def init_database(database_url: Optional[str] = None, drop_all: bool = False):
    """Create the tables registered on ``Base.metadata``.

    Args:
        database_url: Optional URL forwarded to ``get_engine``.
        drop_all: When true, drop every table known to ``Base.metadata``
            before creating them again.

    Returns:
        The engine the tables were created on.
    """
    engine = get_engine(database_url)
    schema = Base.metadata

    if drop_all:
        schema.drop_all(bind=engine)
    schema.create_all(bind=engine)

    return engine
def get_db() -> Generator[Session, None, None]:
    """Yield a session from the shared factory and close it afterwards.

    Written as a generator so it can serve as a FastAPI dependency. The
    session is closed when the generator is finalized; no commit or rollback
    is issued here.

    Yields:
        A session created by the factory from ``get_session_local()``.
    """
    session = get_session_local()()
    try:
        yield session
    finally:
        # Runs whether the consumer finished normally or raised.
        session.close()
@contextmanager
def get_db_context():
    """Provide a session that commits on success and rolls back on error.

    The session comes from the shared factory. If the ``with`` body (or the
    commit itself) raises an ``Exception``, the session is rolled back and the
    exception is re-raised. The session is always closed on exit.

    Yields:
        A session created by the factory from ``get_session_local()``.
    """
    session = get_session_local()()
    try:
        yield session
        # Commit stays inside the try so a failing commit is rolled back too.
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def reset_database_connection():
    """Dispose of the shared engine and forget the cached factory.

    After this call the next ``get_engine()`` / ``get_session_local()`` builds
    fresh objects. Useful for testing.
    """
    global _engine, _SessionLocal

    if _engine:
        _engine.dispose()

    # Clear both caches so they are rebuilt lazily on next use.
    _engine = _SessionLocal = None
class DatabaseManager:
    """Self-contained engine and session factory for one database URL.

    Unlike the module-level helpers, each instance keeps its own engine and
    session factory, both created lazily on first access.
    """

    def __init__(self, database_url: str, echo: bool = False):
        """Store the connection settings; nothing is connected yet.

        Args:
            database_url: URL of the database this manager talks to.
            echo: Value passed as ``echo`` when the engine is created.
        """
        self.database_url = database_url
        self.echo = echo
        # Both are built on demand by the properties below.
        self._engine = None
        self._SessionLocal = None

    @property
    def engine(self):
        """Return this manager's engine, creating it on first access."""
        if self._engine is not None:
            return self._engine

        normalized_url = _fix_db_url(self.database_url)

        # SQLite gets check_same_thread=False so a connection may be used
        # from a thread other than the one that opened it.
        driver_args = {}
        if normalized_url.startswith("sqlite"):
            driver_args["check_same_thread"] = False

        self._engine = create_engine(
            normalized_url,
            connect_args=driver_args,
            echo=self.echo,
            pool_pre_ping=True,
        )
        return self._engine

    @property
    def session_local(self):
        """Return this manager's session factory, creating it on first access."""
        if self._SessionLocal is not None:
            return self._SessionLocal

        self._SessionLocal = sessionmaker(
            bind=self.engine,
            autoflush=False,
            autocommit=False,
        )
        return self._SessionLocal

    def init_tables(self, drop_all: bool = False):
        """Create the tables registered on ``Base.metadata``.

        Args:
            drop_all: When true, drop every table known to ``Base.metadata``
                before creating them again.
        """
        schema = Base.metadata
        if drop_all:
            schema.drop_all(bind=self.engine)
        schema.create_all(bind=self.engine)

    def get_session(self) -> Generator[Session, None, None]:
        """Yield a session from this manager and close it afterwards.

        No commit or rollback is issued here.
        """
        session = self.session_local()
        try:
            yield session
        finally:
            session.close()

    @contextmanager
    def session_scope(self):
        """Provide a session that commits on success and rolls back on error.

        If the ``with`` body (or the commit itself) raises an ``Exception``,
        the session is rolled back and the exception is re-raised. The
        session is always closed on exit.
        """
        session = self.session_local()
        try:
            yield session
            # Commit stays inside the try so a failing commit is rolled back.
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def close(self):
        """Dispose of the engine and forget the cached session factory."""
        if self._engine:
            self._engine.dispose()

        # Clear both caches so a later access rebuilds them.
        self._engine = self._SessionLocal = None
|