Spaces:
Running
Running
| """ | |
| AdaptiveAuth Core - Database Module | |
| Database engine, session management, and utilities. | |
| """ | |
| from typing import Generator, Optional | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker, Session | |
| from contextlib import contextmanager | |
| from ..config import get_settings | |
| from ..models import Base | |
# Process-wide engine and session factory. Both start empty; get_engine() and
# get_session_local() populate them on first use, and
# reset_database_connection() clears them again.
_engine, _SessionLocal = None, None
| def _fix_db_url(url: str) -> str: | |
| """SQLAlchemy 2.x requires postgresql+psycopg2:// not postgresql://""" | |
| if url.startswith("postgres://") or url.startswith("postgresql://"): | |
| url = url.replace("postgres://", "postgresql+psycopg2://", 1) | |
| url = url.replace("postgresql://", "postgresql+psycopg2://", 1) | |
| return url | |
def get_engine(database_url: Optional[str] = None, echo: bool = False):
    """Return the shared SQLAlchemy engine, creating it on the first call.

    The arguments are only consulted when no engine exists yet. Every later
    call returns the cached engine unchanged.

    Args:
        database_url: Explicit URL. When falsy, ``settings.DATABASE_URL`` is
            used instead. The URL is normalised via ``_fix_db_url``.
        echo: Enable SQL logging. It is OR-ed with ``settings.DATABASE_ECHO``.

    Returns:
        The cached engine (created with ``pool_pre_ping=True`` and
        ``pool_recycle=3600``).
    """
    global _engine
    if _engine is not None:
        return _engine

    cfg = get_settings()
    target_url = _fix_db_url(database_url or cfg.DATABASE_URL)
    # sqlite3 connections may only be used by their creating thread unless
    # check_same_thread is disabled.
    extra_args = (
        {"check_same_thread": False} if target_url.startswith("sqlite") else {}
    )
    _engine = create_engine(
        target_url,
        connect_args=extra_args,
        echo=echo or cfg.DATABASE_ECHO,
        pool_pre_ping=True,
        pool_recycle=3600,
    )
    return _engine
def get_session_local(database_url: Optional[str] = None):
    """Return the shared session factory, building it on the first call.

    Args:
        database_url: Forwarded to get_engine(), but only when the factory
            does not exist yet.

    Returns:
        A ``sessionmaker`` bound to the shared engine, configured with
        ``autocommit=False`` and ``autoflush=False``.
    """
    global _SessionLocal
    if _SessionLocal is None:
        _SessionLocal = sessionmaker(
            bind=get_engine(database_url),
            autoflush=False,
            autocommit=False,
        )
    return _SessionLocal
def init_database(database_url: Optional[str] = None, drop_all: bool = False):
    """Create every table registered on ``Base.metadata``.

    Args:
        database_url: Passed to get_engine(). It only matters if no engine
            exists yet.
        drop_all: When true, drop all of those tables first. This is
            destructive: their existing data is lost.

    Returns:
        The engine the tables were created on.
    """
    bound = get_engine(database_url)
    metadata = Base.metadata
    if drop_all:
        metadata.drop_all(bind=bound)
    metadata.create_all(bind=bound)
    return bound
def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency: yield one session from the shared factory.

    No commit or rollback is performed here. The session is closed once the
    consumer is done with it, including when the consumer raises.
    """
    session = get_session_local()()
    try:
        yield session
    finally:
        session.close()
@contextmanager
def get_db_context():
    """Provide a transactional session for use in a ``with`` block.

    Without ``@contextmanager`` this was a bare generator function, and
    ``with get_db_context()`` raised ``TypeError``.

    Example::

        with get_db_context() as db:
            db.add(obj)

    Yields:
        A new session from the shared factory (see get_session_local()).

    Raises:
        Re-raises any exception from the ``with`` body or from ``commit()``
        after rolling the transaction back.
    """
    SessionLocal = get_session_local()
    db = SessionLocal()
    try:
        yield db
        # Only reached when the with-body finished without raising.
        db.commit()
    except Exception:
        # Undo the partial transaction before propagating the error.
        db.rollback()
        raise
    finally:
        db.close()
def reset_database_connection():
    """Dispose of the shared engine and forget the cached session factory.

    The next get_engine() / get_session_local() call builds fresh objects.
    This is mainly useful in tests.
    """
    global _engine, _SessionLocal
    current = _engine
    if current:
        current.dispose()
        _engine = None
    _SessionLocal = None
class DatabaseManager:
    """Engine/session manager bound to one explicit database URL.

    Unlike the module-level helpers, each instance owns its own engine and
    session factory, so several databases can be used side by side. Both
    are created lazily on first access.

    Args:
        database_url: Database URL, normalised via ``_fix_db_url`` when the
            engine is built.
        echo: Enable SQL logging on this instance's engine.
    """

    def __init__(self, database_url: str, echo: bool = False):
        self.database_url = database_url
        self.echo = echo
        self._engine = None  # built lazily by the ``engine`` property
        self._SessionLocal = None  # built lazily by ``session_local``

    @property
    def engine(self):
        """This instance's engine, created on first access.

        Declared as a property because the rest of the class uses it as an
        attribute (``bind=self.engine``). As a plain method, a bound method
        object would have been passed to ``sessionmaker``/``create_all``.
        """
        if self._engine is None:
            url = _fix_db_url(self.database_url)
            connect_args = {}
            if url.startswith("sqlite"):
                # sqlite3 connections are tied to their creating thread by default.
                connect_args["check_same_thread"] = False
            self._engine = create_engine(
                url,
                connect_args=connect_args,
                echo=self.echo,
                pool_pre_ping=True,
                # Same recycle interval as the module-level get_engine().
                pool_recycle=3600,
            )
        return self._engine

    @property
    def session_local(self):
        """The ``sessionmaker`` bound to ``engine``, created on first access.

        Calling ``self.session_local()`` therefore produces a new session.
        """
        if self._SessionLocal is None:
            self._SessionLocal = sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self.engine,
            )
        return self._SessionLocal

    def init_tables(self, drop_all: bool = False):
        """Create every table in ``Base.metadata`` on this instance's engine.

        Args:
            drop_all: When true, drop those tables first (destroys their data).
        """
        engine = self.engine
        if drop_all:
            Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)

    def get_session(self) -> Generator[Session, None, None]:
        """Yield one session and close it afterwards.

        The session is not committed or rolled back here. Its shape suits a
        FastAPI-style dependency.
        """
        db = self.session_local()
        try:
            yield db
        finally:
            db.close()

    @contextmanager
    def session_scope(self):
        """Transactional ``with`` scope for this instance's database.

        Commits when the block exits normally. On error it rolls back and
        re-raises. The session is always closed.

        Example::

            with manager.session_scope() as db:
                db.add(obj)
        """
        db = self.session_local()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def close(self):
        """Dispose of the engine (if built) and drop the session factory."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
        self._SessionLocal = None