"""
PostgreSQL database connection and session management for POS microservice.
Following SCM pattern with SQLAlchemy async engine.

Importing this module validates ``settings.POSTGRES_URI``, builds the asyncpg
``connect_args`` (including optional TLS), and creates the module-level
``async_engine`` and ``async_session`` factory.
"""
import logging
import re
import ssl
from contextlib import asynccontextmanager
from urllib.parse import urlsplit

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import MetaData, text

from app.core.config import settings

logger = logging.getLogger(__name__)

# Database URL validation
DATABASE_URI = settings.POSTGRES_URI
if not DATABASE_URI:
    logger.error("POSTGRES_URI is empty or missing from settings")
    raise ValueError("POSTGRES_URI is not set. Check environment variables.")

# Matches a ``password=`` / ``sslpassword=`` query parameter value so that it
# can be masked as well as the userinfo password.
_QUERY_PASSWORD_RE = re.compile(r"(?i)(password=)[^&#]*")


# Parse and log connection details (with masked password)
def mask_connection_string(uri: str) -> str:
    """Return *uri* with any password replaced by ``***`` for safe logging.

    Handles ``scheme://user:password@host:port/database?query`` and the
    scheme-less ``user:password@host/...`` form. The credentials/host split
    uses the LAST ``@`` so that a password containing an unencoded ``@`` is
    still fully masked. ``password=`` query parameters are masked too.

    Returns:
        ``"EMPTY"`` for an empty/falsy value, ``"INVALID_FORMAT"`` if the value
        cannot be processed (e.g. it is not a string), otherwise the masked URI.
    """
    if not uri:
        return "EMPTY"
    try:
        userinfo_part, at, host_db = uri.rpartition("@")
        if not at:
            # No credentials section; only a query-string password could leak.
            return _QUERY_PASSWORD_RE.sub(r"\1***", uri)

        protocol, scheme_sep, user_pass = userinfo_part.partition("://")
        if scheme_sep:
            prefix = f"{protocol}://"
        else:
            # Scheme-less form: the whole left side is "user[:password]".
            prefix, user_pass = "", userinfo_part

        user = user_pass.split(":", 1)[0]
        host_db = _QUERY_PASSWORD_RE.sub(r"\1***", host_db)
        return f"{prefix}{user}:***@{host_db}"
    except Exception:
        # Logging helper: never let a malformed value break start-up.
        return "INVALID_FORMAT"


def parse_connection_details(uri: str) -> dict:
    """Extract host, port and database name from *uri* for logging.

    Uses ``urllib.parse.urlsplit``, so passwords containing ``@`` and bracketed
    IPv6 hosts are handled.

    Returns:
        A dict with string values ``host``, ``port`` (``"5432"`` when the URI
        omits it) and ``database`` (query parameters stripped). If the host or
        the ``/database`` path is missing, or the URI cannot be parsed, every
        value is ``"unknown"``.
    """
    unknown = {"host": "unknown", "port": "unknown", "database": "unknown"}
    try:
        parts = urlsplit(uri)
        host = parts.hostname
        port = parts.port  # raises ValueError for a non-numeric/out-of-range port
    except (TypeError, ValueError, AttributeError):
        return unknown
    if not host or not parts.path.startswith("/"):
        return unknown
    return {
        "host": host,
        "port": str(port) if port is not None else "5432",
        "database": parts.path[1:],  # urlsplit already separated the query string
    }


masked_uri = mask_connection_string(DATABASE_URI)
conn_details = parse_connection_details(DATABASE_URI)

logger.info(
    "PostgreSQL connection configured",
    extra={
        "connection_string": masked_uri,
        "host": conn_details["host"],
        "port": conn_details["port"],
        "database": conn_details["database"],
        "ssl_mode": settings.POSTGRES_SSL_MODE,
    },
)

# Build connect args including optional SSL
CONNECT_ARGS = {
    "server_settings": {
        "application_name": "cuatrolabs-pos-ms",
        "jit": "off",
    },
    "command_timeout": 60,
    "statement_cache_size": 0,
}

# SSL mode handling:
#   disable      -> no TLS
#   verify-full  -> TLS, verify certificate chain AND hostname
#   verify-ca    -> TLS, verify certificate chain only
#   anything else (e.g. require/prefer/allow) -> TLS without certificate
#                   verification; TLS is always used, there is no plaintext fallback.
mode = (settings.POSTGRES_SSL_MODE or "disable").lower()
if mode != "disable":
    ssl_context: ssl.SSLContext
    if mode in ("verify-full", "verify-ca"):
        ssl_context = ssl.create_default_context(
            cafile=settings.POSTGRES_SSL_ROOT_CERT or None
        )
        # verify-ca must still validate the chain; only the hostname check is skipped.
        ssl_context.check_hostname = mode == "verify-full"
        ssl_context.verify_mode = ssl.CERT_REQUIRED
    else:
        # sslmode=require: encrypt but don't verify server cert
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False  # must be cleared before CERT_NONE
        ssl_context.verify_mode = ssl.CERT_NONE

    # Client certificate authentication is independent of server verification,
    # so load the cert/key for any TLS mode when both are configured.
    if settings.POSTGRES_SSL_CERT and settings.POSTGRES_SSL_KEY:
        try:
            ssl_context.load_cert_chain(
                certfile=settings.POSTGRES_SSL_CERT,
                keyfile=settings.POSTGRES_SSL_KEY,
            )
        except Exception as e:
            # Best-effort: continue without a client cert, but make it visible.
            logger.warning("Failed to load client SSL cert/key for PostgreSQL", exc_info=e)

    CONNECT_ARGS["ssl"] = ssl_context
    logger.info("PostgreSQL SSL enabled", extra={"ssl_mode": settings.POSTGRES_SSL_MODE})

async_engine = create_async_engine(
    DATABASE_URI,
    echo=settings.DEBUG,
    future=True,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
    connect_args=CONNECT_ARGS,
)

# Create async session factory
async_session = sessionmaker(
    async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Metadata for table creation
metadata = MetaData()

logger.info("PostgreSQL configuration loaded successfully")
# ────────────────────────────────────────────────────────────────────────────────
# Lifecycle helpers
# ────────────────────────────────────────────────────────────────────────────────

async def connect_to_database() -> None:
    """Verify PostgreSQL connectivity at application start-up, retrying on failure.

    Prints and logs the password-masked connection details, then runs
    ``SELECT 1`` through ``async_engine``. Up to
    ``settings.POSTGRES_CONNECT_MAX_RETRIES`` attempts are made (always at
    least one). Between attempts it sleeps; the delay starts at
    ``POSTGRES_CONNECT_INITIAL_DELAY_MS`` and is multiplied by
    ``POSTGRES_CONNECT_BACKOFF_MULTIPLIER`` after each retry, capped at 30s.

    Raises:
        Exception: the error raised by the last failed attempt.
    """
    import asyncio

    # Log connection details at startup (once per worker)
    print(f"\n{'='*70}")
    print("[POSTGRES] Starting Database Connection")
    print(f"{'='*70}")
    print(f"[POSTGRES] Connection String: {masked_uri}")
    print(f"[POSTGRES] Host: {conn_details['host']}")
    print(f"[POSTGRES] Port: {conn_details['port']}")
    print(f"[POSTGRES] Database: {conn_details['database']}")
    print(f"[POSTGRES] SSL Mode: {settings.POSTGRES_SSL_MODE}")
    print(f"{'='*70}\n")

    # Log connection attempt start
    logger.info(
        "Starting PostgreSQL connection attempts",
        extra={
            "max_retries": settings.POSTGRES_CONNECT_MAX_RETRIES,
            "initial_delay_ms": settings.POSTGRES_CONNECT_INITIAL_DELAY_MS,
            "backoff_multiplier": settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER,
        },
    )
    print(f"[POSTGRES] Attempting to connect (max retries: {settings.POSTGRES_CONNECT_MAX_RETRIES})...")

    attempts = 0
    # Clamp to at least one attempt: with a non-positive setting the loop would
    # never run and the final ``raise last_error`` would be ``raise None``
    # (a TypeError that hides the real configuration problem).
    max_attempts = max(1, settings.POSTGRES_CONNECT_MAX_RETRIES)
    delay = settings.POSTGRES_CONNECT_INITIAL_DELAY_MS / 1000.0
    last_error = None

    while attempts < max_attempts:
        try:
            async with async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
        except Exception as e:
            last_error = e
            attempts += 1
            error_msg = str(e)
            error_type = type(e).__name__
            logger.warning(
                "PostgreSQL connection attempt failed",
                extra={
                    "attempt": attempts,
                    "max_attempts": max_attempts,
                    "retry_delay_ms": int(delay * 1000),
                    "error_type": error_type,
                    "error_message": error_msg[:200],  # Truncate long errors
                },
            )
            print(f"[POSTGRES] ❌ Connection attempt {attempts}/{max_attempts} failed")
            print(f"[POSTGRES] Error: {error_type}: {error_msg[:150]}")
            # No sleep after the final attempt.
            if attempts < max_attempts:
                print(f"[POSTGRES] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
                delay = min(delay * settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER, 30.0)
        else:
            logger.info("Successfully connected to PostgreSQL database")
            print(f"[POSTGRES] ✅ Connection successful after {attempts + 1} attempt(s)")
            return

    # All attempts failed
    logger.error(
        "Failed to connect to PostgreSQL after all retries",
        extra={
            "total_attempts": attempts,
            "final_error_type": type(last_error).__name__,
            "final_error": str(last_error)[:500],
        },
        exc_info=last_error,
    )
    print(f"[POSTGRES] ❌ FATAL: Failed to connect after {attempts} attempts")
    print(f"[POSTGRES] Last error: {type(last_error).__name__}: {str(last_error)[:200]}")
    print("[POSTGRES] Please check:")
    print(f"[POSTGRES] 1. Database host is reachable: {conn_details['host']}")
    print("[POSTGRES] 2. Database credentials are correct")
    print(f"[POSTGRES] 3. SSL mode is appropriate: {settings.POSTGRES_SSL_MODE}")
    print(f"[POSTGRES] 4. Firewall allows connections to port {conn_details['port']}")
    raise last_error


async def disconnect_from_database() -> None:
    """Dispose of the engine's connection pool when the application shuts down.

    Raises:
        Exception: any error from ``async_engine.dispose()`` (logged, then re-raised).
    """
    try:
        await async_engine.dispose()
        logger.info("Successfully disconnected from PostgreSQL database")
    except Exception:
        logger.exception("Error disconnecting from PostgreSQL database")
        raise


async def enforce_trans_schema() -> None:
    """Ensure the ``trans`` schema exists and every registered table lives in it.

    Runs ``CREATE SCHEMA IF NOT EXISTS trans``, imports the POS models so they
    are registered on ``Base.metadata``, and checks that each table's schema is
    ``'trans'``. Everything runs inside one ``async_engine.begin()`` block.

    Raises:
        ValueError: if any registered table uses a schema other than ``'trans'``.
        Exception: any database/import error (logged, then re-raised).
    """
    try:
        async with async_engine.begin() as conn:
            # Ensure trans schema exists
            await conn.execute(text("CREATE SCHEMA IF NOT EXISTS trans"))
            logger.info("✅ TRANS schema exists")

            # Validate that all models use trans schema
            from app.core.database import Base
            # Imported for their side effect of registering tables on Base.metadata.
            from app.sync.models import CustomerRef, StaffRef, CatalogueServiceRef  # noqa: F401

            non_trans_tables = [
                f"{table.name} (schema: {table.schema})"
                for table in Base.metadata.tables.values()
                if table.schema != 'trans'
            ]

            if non_trans_tables:
                error_msg = f"❌ SCHEMA VIOLATION: The following tables are not using 'trans' schema: {', '.join(non_trans_tables)}"
                logger.error(error_msg)
                print(f"\n{'='*80}")
                print(f"[SCHEMA ERROR] {error_msg}")
                print(f"{'='*80}\n")
                raise ValueError(error_msg)

            logger.info("✅ All POS tables correctly use 'trans' schema")
            print("[SCHEMA] ✅ All POS tables correctly use 'trans' schema")
    except Exception:
        logger.exception("Error enforcing TRANS schema")
        raise


async def create_tables() -> None:
    """Create all tables registered on ``Base.metadata`` after schema enforcement.

    Calls ``enforce_trans_schema()`` first, then runs ``Base.metadata.create_all``
    on the async engine.

    Raises:
        Exception: any validation or database error (logged, then re-raised).
    """
    try:
        # First enforce schema compliance
        await enforce_trans_schema()

        from app.core.database import Base
        # Imported for their side effect of registering tables on Base.metadata.
        from app.sync.models import CustomerRef, StaffRef, CatalogueServiceRef  # noqa: F401

        async with async_engine.begin() as conn:
            # Create all tables (schema already validated)
            await conn.run_sync(Base.metadata.create_all)

        logger.info("✅ Database tables created successfully in TRANS schema")
        print("[SCHEMA] ✅ Database tables created successfully in TRANS schema")
    except Exception:
        logger.exception("Error creating database tables")
        raise


# ────────────────────────────────────────────────────────────────────────────────
# Legacy compatibility functions
# ────────────────────────────────────────────────────────────────────────────────

async def connect_to_postgres() -> None:
    """Legacy function name - calls connect_to_database"""
    await connect_to_database()


async def close_postgres_connection() -> None:
    """Legacy function name - calls disconnect_from_database"""
    await disconnect_from_database()


def get_postgres_engine():
    """Return the module-level ``async_engine`` instance."""
    return async_engine


@asynccontextmanager
async def get_postgres_session():
    """
    Get PostgreSQL session context manager.

    The caller is responsible for committing. If the body raises, the session
    is rolled back, the error is logged and re-raised. The session is always
    closed on exit.

    Usage:
        async with get_postgres_session() as session:
            await session.execute(...)
            await session.commit()
    """
    # async_session is created unconditionally at import time, so this guard is
    # presumably unreachable; kept for backward compatibility (yields None).
    if not async_session:
        logger.warning("PostgreSQL session maker not initialized")
        yield None
        return

    session = async_session()
    try:
        yield session
    except Exception as e:
        await session.rollback()
        logger.error("Session error, rolled back: %s", e)
        raise
    finally:
        await session.close()


async def execute_postgres_query(query: str, params: dict = None):
    """
    Execute a raw SQL statement in its own session and commit it.

    Args:
        query: SQL query string (use ``:name`` bind parameters), or an
            already-constructed SQLAlchemy executable such as ``text(...)``.
            Plain strings are wrapped in ``text()`` because SQLAlchemy 2.x
            refuses to execute bare strings.
        params: Query parameters dict

    Returns:
        The result of ``session.execute``, or ``None`` if no engine/session
        is available.
    """
    if not async_engine:
        logger.warning("PostgreSQL engine not available, skipping query")
        return None

    statement = text(query) if isinstance(query, str) else query

    async with get_postgres_session() as session:
        if session is None:
            return None
        result = await session.execute(statement, params or {})
        await session.commit()
        return result