"""
PostgreSQL database connection and session management for Spa microservice.

Following TMS pattern with SQLAlchemy async engine.

Importing this module has side effects: it validates ``settings.POSTGRES_URI``
(raising ``ValueError`` when it is empty), logs the masked connection details,
builds the optional SSL context and creates the module-level ``async_engine``
and ``async_session`` factory.
"""
import asyncio
import logging
import ssl
from contextlib import asynccontextmanager

from sqlalchemy import MetaData, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.core.config import settings

logger = logging.getLogger(__name__)

# Database URL validation: fail fast at import time rather than on first query.
DATABASE_URI = settings.POSTGRES_URI
if not DATABASE_URI:
    logger.error("POSTGRES_URI is empty or missing from settings")
    raise ValueError("POSTGRES_URI is not set. Check environment variables.")


# Parse and log connection details (with masked password)
def mask_connection_string(uri: str) -> str:
    """Return *uri* with its password replaced by ``***`` for safe logging.

    The user name (if any) is kept and ``:***`` is always inserted before the
    ``@`` separator, so the output does not reveal whether a password was set.

    Args:
        uri: Connection string, e.g.
            ``postgresql+asyncpg://user:secret@host:5432/db``.

    Returns:
        The masked string; ``"EMPTY"`` for an empty/None input; the input
        unchanged when it contains no ``@`` (no credentials section); or
        ``"INVALID_FORMAT"`` if it cannot be processed (e.g. not a string).
    """
    if not uri:
        return "EMPTY"
    try:
        if "@" not in uri:
            return uri
        # Everything before the first "@" is the (scheme +) credentials part.
        credentials, host_db = uri.split("@", 1)
        prefix = ""
        if "://" in credentials:
            scheme, credentials = credentials.split("://", 1)
            prefix = f"{scheme}://"
        # Keep only the user name; everything after the first ":" is secret.
        # Applies to scheme-less URIs too, so their password is never echoed.
        user = credentials.split(":", 1)[0]
        return f"{prefix}{user}:***@{host_db}"
    except Exception:
        return "INVALID_FORMAT"


def parse_connection_details(uri: str) -> dict:
    """Extract host, port and database name from a connection string.

    The host section is whatever follows the credentials (``...@``) or, when
    there are none, the ``scheme://`` prefix. The port defaults to ``"5432"``
    when absent, and any ``?query`` suffix is dropped from the database name.
    All values are strings.

    Returns:
        ``{"host": ..., "port": ..., "database": ...}``, with every value set
        to ``"unknown"`` when the URI cannot be parsed (e.g. no ``/database``
        part or a non-string input).
    """
    unknown = {"host": "unknown", "port": "unknown", "database": "unknown"}
    try:
        if "@" in uri:
            _, host_db = uri.split("@", 1)
        else:
            host_db = uri.split("://", 1)[-1]
        if "/" not in host_db:
            return unknown
        host_port, database = host_db.split("/", 1)
        host, _, port = host_port.partition(":")
        return {
            "host": host,
            "port": port or "5432",
            "database": database.split("?")[0],
        }
    except Exception:
        return unknown


masked_uri = mask_connection_string(DATABASE_URI)
conn_details = parse_connection_details(DATABASE_URI)

logger.info(
    "PostgreSQL connection configured",
    extra={
        "connection_string": masked_uri,
        "host": conn_details["host"],
        "port": conn_details["port"],
        "database": conn_details["database"],
        "ssl_mode": settings.POSTGRES_SSL_MODE,
    },
)

# Build connect args including optional SSL. These are passed straight to the
# DBAPI driver's connect(); statement_cache_size=0 disables the driver-side
# prepared-statement cache.
CONNECT_ARGS = {
    "server_settings": {
        "application_name": "cuatrolabs-spa-ms",
        "jit": "off",
    },
    "command_timeout": 60,
    "statement_cache_size": 0,
}


def _build_ssl_context(ssl_mode: str) -> ssl.SSLContext:
    """Build a client SSL context for a libpq-style ``sslmode`` value.

    * ``verify-full``: verify the server certificate chain AND hostname.
    * ``verify-ca``: verify the server certificate chain, not the hostname.
    * anything else (``require``/``prefer``/``allow``): encrypt only, with no
      server certificate verification. Unrecognized values log a warning.

    In every mode the optional client certificate/key from settings is loaded.
    A failure to load them is logged as a warning and otherwise ignored.

    Args:
        ssl_mode: Lower-cased SSL mode, never ``"disable"``.
    """
    if ssl_mode in ("verify-full", "verify-ca"):
        root_cert = settings.POSTGRES_SSL_ROOT_CERT
        context = (
            ssl.create_default_context(cafile=root_cert)
            if root_cert
            else ssl.create_default_context()
        )
        # Only verify-full also checks that the certificate matches the host;
        # both modes require a certificate that chains to a trusted CA.
        context.check_hostname = ssl_mode == "verify-full"
        context.verify_mode = ssl.CERT_REQUIRED
    else:
        if ssl_mode not in ("require", "prefer", "allow"):
            logger.warning(
                "Unrecognized PostgreSQL SSL mode; using encryption without "
                "certificate verification",
                extra={"ssl_mode": ssl_mode},
            )
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    if settings.POSTGRES_SSL_CERT and settings.POSTGRES_SSL_KEY:
        try:
            context.load_cert_chain(
                certfile=settings.POSTGRES_SSL_CERT,
                keyfile=settings.POSTGRES_SSL_KEY,
            )
        except Exception as exc:
            logger.warning(
                "Failed to load client SSL cert/key for PostgreSQL", exc_info=exc
            )
    return context


mode = (settings.POSTGRES_SSL_MODE or "disable").lower()
if mode != "disable":
    ssl_context: ssl.SSLContext = _build_ssl_context(mode)
    CONNECT_ARGS["ssl"] = ssl_context
    logger.info("PostgreSQL SSL enabled", extra={"ssl_mode": settings.POSTGRES_SSL_MODE})

# Create async engine with connection pool settings
async_engine = create_async_engine(
    DATABASE_URI,
    echo=settings.DEBUG,
    future=True,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
    connect_args=CONNECT_ARGS,
)

# Create async session factory; expire_on_commit=False keeps ORM objects
# readable after commit without triggering a refresh.
async_session = sessionmaker(
    async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Metadata for table creation
metadata = MetaData()

logger.info("PostgreSQL configuration loaded successfully")


async def connect_to_database() -> None:
    """Verify database connectivity when the application starts.

    Runs ``SELECT 1`` up to ``settings.POSTGRES_CONNECT_MAX_RETRIES`` times (at
    least once). The wait between attempts starts at
    ``POSTGRES_CONNECT_INITIAL_DELAY_MS`` and is multiplied by
    ``POSTGRES_CONNECT_BACKOFF_MULTIPLIER`` after each failure, capped at 30 s.

    Raises:
        Exception: the error from the last failed attempt once all attempts
            are exhausted.
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print("[POSTGRES] Starting Database Connection")
    print(banner)
    print(f"[POSTGRES] Connection String: {masked_uri}")
    print(f"[POSTGRES] Host: {conn_details['host']}")
    print(f"[POSTGRES] Port: {conn_details['port']}")
    print(f"[POSTGRES] Database: {conn_details['database']}")
    print(f"[POSTGRES] SSL Mode: {settings.POSTGRES_SSL_MODE}")
    print(f"{banner}\n")

    logger.info(
        "Starting PostgreSQL connection attempts",
        extra={
            "max_retries": settings.POSTGRES_CONNECT_MAX_RETRIES,
            "initial_delay_ms": settings.POSTGRES_CONNECT_INITIAL_DELAY_MS,
            "backoff_multiplier": settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER,
        },
    )

    # Always make at least one attempt: with a non-positive setting the loop
    # would otherwise never run and the final `raise` would raise None.
    max_attempts = max(1, settings.POSTGRES_CONNECT_MAX_RETRIES)
    delay = settings.POSTGRES_CONNECT_INITIAL_DELAY_MS / 1000.0
    last_error = None

    for attempt in range(1, max_attempts + 1):
        try:
            async with async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
        except Exception as exc:
            last_error = exc
            logger.warning(
                "PostgreSQL connection attempt failed",
                extra={
                    "attempt": attempt,
                    "max_attempts": max_attempts,
                    "retry_delay_ms": int(delay * 1000),
                    "error_type": type(exc).__name__,
                    "error_message": str(exc)[:200],
                },
            )
            print(f"[POSTGRES] ❌ Connection attempt {attempt}/{max_attempts} failed")
            if attempt < max_attempts:
                print(f"[POSTGRES] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
                delay = min(delay * settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER, 30.0)
        else:
            logger.info("Successfully connected to PostgreSQL database")
            print(f"[POSTGRES] ✅ Connection successful after {attempt} attempt(s)")
            return

    logger.error(
        "Failed to connect to PostgreSQL after all retries",
        extra={
            "total_attempts": max_attempts,
            "final_error_type": type(last_error).__name__,
            "final_error": str(last_error)[:500],
        },
        exc_info=last_error,
    )
    print(f"[POSTGRES] ❌ FATAL: Failed to connect after {max_attempts} attempts")
    raise last_error


async def disconnect_from_database() -> None:
    """Dispose of the engine's connection pool when the application shuts down.

    Raises:
        Exception: any error from ``async_engine.dispose()``, after logging it.
    """
    try:
        await async_engine.dispose()
    except Exception:
        logger.exception("Error disconnecting from PostgreSQL database")
        raise
    logger.info("Successfully disconnected from PostgreSQL database")


async def enforce_trans_schema() -> None:
    """Validate that every registered table uses the ``trans`` schema, then
    ensure that schema exists in the database.

    Raises:
        ValueError: if any table in ``Base.metadata`` has a schema other than
            ``trans``.
        Exception: any database error raised while creating the schema.
    """
    try:
        from app.core.database import Base

        # Import all models here when they are created so they register on
        # Base.metadata before validation, e.g.:
        # from app.module_name.models.model import ModelName

        # Validate before touching the database, so no transaction is held
        # open for (or rolled back by) a validation failure.
        non_trans_tables = [
            f"{table.name} (schema: {table.schema})"
            for table in Base.metadata.tables.values()
            if table.schema != "trans"
        ]
        if non_trans_tables:
            error_msg = f"❌ SCHEMA VIOLATION: The following tables are not using 'trans' schema: {', '.join(non_trans_tables)}"
            logger.error(error_msg)
            print(f"\n{'=' * 80}")
            print(f"[SCHEMA ERROR] {error_msg}")
            print(f"{'=' * 80}\n")
            raise ValueError(error_msg)

        async with async_engine.begin() as conn:
            await conn.execute(text("CREATE SCHEMA IF NOT EXISTS trans"))
        logger.info("✅ TRANS schema exists")

        logger.info("✅ All Spa tables correctly use 'trans' schema")
        print("[SCHEMA] ✅ All Spa tables correctly use 'trans' schema")
    except Exception:
        logger.exception("Error enforcing TRANS schema")
        raise


async def create_tables() -> None:
    """Create all tables defined on ``Base.metadata`` after enforcing that
    they use the ``trans`` schema.

    Raises:
        Exception: any validation or database error, after logging it.
    """
    try:
        # First enforce schema compliance
        await enforce_trans_schema()

        from app.core.database import Base

        # Import all models here when they are created, e.g.:
        # from app.module_name.models.model import ModelName

        async with async_engine.begin() as conn:
            # create_all is synchronous, so it runs via run_sync.
            await conn.run_sync(Base.metadata.create_all)
        logger.info("✅ Database tables created successfully in TRANS schema")
        print("[SCHEMA] ✅ Database tables created successfully in TRANS schema")
    except Exception:
        logger.exception("Error creating database tables")
        raise


# Optional re-exports of raw-connection helpers from app.postgres. When that
# module cannot be imported, these wrappers are simply not defined.
try:
    from app.postgres import (
        connect_to_postgres as _connect_to_postgres,
        close_postgres_connection as _close_postgres_connection,
        get_postgres_connection as _get_postgres_connection,
        release_postgres_connection as _release_postgres_connection,
        is_postgres_connected as _is_postgres_connected,
    )
except Exception as exc:  # best-effort: the helpers are optional
    logger.debug(
        "app.postgres helpers unavailable; raw-connection wrappers not defined",
        exc_info=exc,
    )
else:

    async def connect_to_postgres() -> None:
        """Delegate to ``app.postgres.connect_to_postgres``."""
        await _connect_to_postgres()

    async def close_postgres_connection() -> None:
        """Delegate to ``app.postgres.close_postgres_connection``."""
        await _close_postgres_connection()

    async def get_postgres_connection():
        """Delegate to ``app.postgres.get_postgres_connection``."""
        return await _get_postgres_connection()

    async def release_postgres_connection(conn) -> None:
        """Delegate to ``app.postgres.release_postgres_connection``."""
        await _release_postgres_connection(conn)

    def is_postgres_connected() -> bool:
        """Delegate to ``app.postgres.is_postgres_connected``."""
        return _is_postgres_connected()


@asynccontextmanager
async def get_postgres_session():
    """
    Get PostgreSQL session context manager.

    The caller is responsible for committing. If the body raises, the
    session is rolled back and the exception re-raised. The session is
    always closed on exit. Yields None if the session factory is unset.

    Usage:
        async with get_postgres_session() as session:
            await session.execute(...)
            await session.commit()
    """
    if not async_session:
        logger.warning("PostgreSQL session maker not initialized")
        yield None
        return

    session = async_session()
    try:
        yield session
    except Exception as exc:
        await session.rollback()
        logger.error("Session error, rolled back: %s", exc)
        raise
    finally:
        await session.close()


async def get_db():
    """
    FastAPI dependency for getting database session.

    Commits after the request handler finishes without error. Otherwise it
    rolls back and re-raises. The surrounding ``async with`` closes the
    session.

    Usage:
        @app.get("/endpoint")
        async def endpoint(db: AsyncSession = Depends(get_db)):
            ...
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise