# Source: auth/adaptiveauth/core/database.py
# Last commit: c92a083 by Piyush1225
#   "fix: auto-fix postgresql:// to postgresql+psycopg2:// for SQLAlchemy 2.x"
"""
AdaptiveAuth Core - Database Module
Database engine, session management, and utilities.
"""
from typing import Generator, Optional
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from contextlib import contextmanager
from ..config import get_settings
from ..models import Base
# Process-wide database singletons. Both start unset; get_engine() and
# get_session_local() populate them lazily on first use, and
# reset_database_connection() clears them again.
_engine = _SessionLocal = None
def _fix_db_url(url: str) -> str:
"""SQLAlchemy 2.x requires postgresql+psycopg2:// not postgresql://"""
if url.startswith("postgres://") or url.startswith("postgresql://"):
url = url.replace("postgres://", "postgresql+psycopg2://", 1)
url = url.replace("postgresql://", "postgresql+psycopg2://", 1)
return url
def get_engine(database_url: Optional[str] = None, echo: bool = False):
    """Return the shared SQLAlchemy engine, building it on the first call.

    Both arguments only matter on the call that actually creates the engine.
    Once it exists, it is returned as-is and the arguments are ignored.

    Args:
        database_url: URL to connect to. Falls back to
            ``settings.DATABASE_URL`` when not given. PostgreSQL schemes are
            normalized via ``_fix_db_url``.
        echo: Enables SQL echo. It is OR-ed with ``settings.DATABASE_ECHO``,
            so either one can turn echo on.

    Returns:
        The module-level engine stored in ``_engine``.
    """
    global _engine
    if _engine is not None:
        return _engine

    settings = get_settings()
    resolved_url = _fix_db_url(database_url or settings.DATABASE_URL)
    echo_sql = echo or settings.DATABASE_ECHO

    # SQLite rejects connection use from other threads by default; relax that
    # so pooled connections can be handed to any thread.
    sqlite_args = {"check_same_thread": False}
    extra_args = sqlite_args if resolved_url.startswith("sqlite") else {}

    _engine = create_engine(
        resolved_url,
        connect_args=extra_args,
        echo=echo_sql,
        pool_pre_ping=True,
        pool_recycle=3600,
    )
    return _engine
def get_session_local(database_url: Optional[str] = None):
    """Return the shared session factory, creating it lazily.

    Args:
        database_url: Forwarded to ``get_engine``. It only takes effect if
            neither the factory nor the engine has been created yet.

    Returns:
        A ``sessionmaker`` bound to the shared engine, with autoflush and
        autocommit disabled.
    """
    global _SessionLocal
    if _SessionLocal is not None:
        return _SessionLocal

    _SessionLocal = sessionmaker(
        bind=get_engine(database_url),
        autoflush=False,
        autocommit=False,
    )
    return _SessionLocal
def init_database(database_url: Optional[str] = None, drop_all: bool = False):
    """Create every table registered on ``Base.metadata``.

    Args:
        database_url: Forwarded to ``get_engine``. It is honored only if the
            shared engine has not been created yet.
        drop_all: When True, drop all registered tables first. This is
            destructive: existing data in those tables is lost.

    Returns:
        The shared engine the schema was applied to.
    """
    engine = get_engine(database_url)
    metadata = Base.metadata
    if drop_all:
        metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)
    return engine
def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency that yields one session and closes it afterwards.

    No commit is issued here. Callers must commit their own changes. Any
    uncommitted work is discarded when the session is closed.
    """
    session_factory = get_session_local()
    # Session.__exit__ calls close(), so the session is released even if the
    # consumer of this generator raises or stops early.
    with session_factory() as db:
        yield db
@contextmanager
def get_db_context():
    """Yield a session in a transactional ``with`` scope.

    The session is committed when the ``with`` body finishes normally. If
    any ``Exception`` occurs, including one raised by the commit itself, the
    session is rolled back and the exception is re-raised. The session is
    always closed on exit.
    """
    session_factory = get_session_local()
    with session_factory() as db:  # __exit__ closes the session
        try:
            yield db
            # Commit stays inside the try so a failed commit is rolled back.
            db.commit()
        except Exception:
            db.rollback()
            raise
def reset_database_connection():
    """Dispose the shared engine and forget the cached engine/session factory.

    The next call to ``get_engine`` or ``get_session_local`` rebuilds them,
    which lets tests point the module at a different database.
    """
    global _engine, _SessionLocal
    if _engine is not None:
        # Close pooled connections before dropping the reference.
        _engine.dispose()
    _engine, _SessionLocal = None, None
class DatabaseManager:
    """Engine and session holder for an explicitly supplied database URL.

    Unlike the module-level helpers, which share one process-wide engine,
    each instance owns its own engine and session factory. Both are created
    lazily on first access and discarded by ``close()``.
    """

    def __init__(self, database_url: str, echo: bool = False):
        """Store connection settings. No connection is opened here.

        Args:
            database_url: SQLAlchemy database URL. PostgreSQL schemes are
                normalized via ``_fix_db_url`` when the engine is built.
            echo: Passed to ``create_engine`` as its ``echo`` flag.
        """
        self.database_url = database_url
        self.echo = echo
        self._engine = None
        self._SessionLocal = None

    @property
    def engine(self):
        """Return this manager's engine, creating it on first access."""
        if self._engine is None:
            url = _fix_db_url(self.database_url)
            connect_args = {}
            if url.startswith("sqlite"):
                # SQLite rejects cross-thread connection use by default;
                # relax that so pooled connections can serve any thread.
                connect_args["check_same_thread"] = False
            self._engine = create_engine(
                url,
                connect_args=connect_args,
                echo=self.echo,
                pool_pre_ping=True,
                # Same recycle window as the module-level get_engine(), so
                # custom managers pool connections the same way.
                pool_recycle=3600,
            )
        return self._engine

    @property
    def session_local(self):
        """Return the session factory bound to ``self.engine``, creating it lazily."""
        if self._SessionLocal is None:
            self._SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )
        return self._SessionLocal

    def init_tables(self, drop_all: bool = False):
        """Create all tables registered on ``Base.metadata``.

        Args:
            drop_all: When True, drop all registered tables first. This is
                destructive: existing data in those tables is lost.
        """
        if drop_all:
            Base.metadata.drop_all(bind=self.engine)
        Base.metadata.create_all(bind=self.engine)

    def get_session(self) -> Generator[Session, None, None]:
        """Yield a new session and close it when the generator finishes.

        No commit is issued. Callers must commit their own changes.
        """
        db = self.session_local()
        try:
            yield db
        finally:
            db.close()

    @contextmanager
    def session_scope(self):
        """Yield a session in a transactional ``with`` scope.

        Commits on normal exit. Rolls back and re-raises on any ``Exception``,
        including a failed commit. Always closes the session.
        """
        db = self.session_local()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def close(self):
        """Dispose the engine, if one was created, and reset cached objects.

        A later access to ``engine`` or ``session_local`` rebuilds them.
        """
        if self._engine is not None:
            self._engine.dispose()
        self._engine = None
        self._SessionLocal = None