"""
PostgreSQL database connection and session management for SCM microservice.

Following TMS pattern with SQLAlchemy async engine.

Importing this module has side effects: it validates ``settings.POSTGRES_URI``,
logs the (masked) connection target, builds the optional SSL context and
creates the async engine and session factory.
"""
import asyncio
import logging
import ssl

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import MetaData, text

from app.core.config import settings

logger = logging.getLogger(__name__)

# Database URL validation
DATABASE_URI = settings.POSTGRES_URI
if not DATABASE_URI:
    logger.error("POSTGRES_URI is empty or missing from settings")
    raise ValueError("POSTGRES_URI is not set. Check environment variables.")


# Parse and log connection details (with masked password)
def mask_connection_string(uri: str) -> str:
    """Mask password in connection string for safe logging.

    Expected format: ``postgresql+asyncpg://user:password@host:port/database``.

    Args:
        uri: Connection string, possibly containing credentials.

    Returns:
        ``"EMPTY"`` for a falsy value, ``"INVALID_FORMAT"`` for a non-string
        value, the URI unchanged when it contains no ``@`` (no credentials),
        otherwise the URI with everything between the user name and the last
        ``@`` replaced by ``:***``.
    """
    if not uri:
        return "EMPTY"
    try:
        # Split on the LAST "@": a password may itself contain an unencoded
        # "@", and splitting on the first one would leak the password tail.
        credentials, at_sign, host_db = uri.rpartition("@")
        if not at_sign:
            return uri
        scheme, scheme_sep, userinfo = credentials.partition("://")
        if not scheme_sep:
            # No scheme present: the whole prefix is "user[:password]".
            scheme, userinfo = "", credentials
        # Keep only the user name; anything after the first ":" is secret.
        user = userinfo.partition(":")[0]
        prefix = f"{scheme}://" if scheme_sep else ""
        return f"{prefix}{user}:***@{host_db}"
    except (AttributeError, TypeError):
        # Not a str (e.g. bytes or a URL object) - never echo it back.
        return "INVALID_FORMAT"


def parse_connection_details(uri: str) -> dict:
    """Parse connection string to extract host, port, database.

    The result is used for logging only, so parsing is lenient and never
    raises.

    Args:
        uri: Connection string such as
            ``postgresql+asyncpg://user:password@host:port/database?opt=1``.

    Returns:
        Dict with string values under ``"host"``, ``"port"`` and
        ``"database"``. The port defaults to ``"5432"`` when absent; any
        component that cannot be determined is ``"unknown"``.
    """
    details = {"host": "unknown", "port": "unknown", "database": "unknown"}
    if not uri or not isinstance(uri, str):
        return details

    # Everything after the last "@" is host[:port][/database]; using the last
    # "@" keeps password fragments out of the reported host.
    _, at_sign, location = uri.rpartition("@")
    if not at_sign:
        # No credentials: strip the scheme if there is one.
        _, scheme_sep, remainder = uri.partition("://")
        location = remainder if scheme_sep else uri

    host_port, slash, path = location.partition("/")

    if host_port.startswith("["):
        # Bracketed IPv6 literal, e.g. "[::1]:5432".
        closing = host_port.find("]")
        if closing == -1:
            host, port = host_port, ""
        else:
            host = host_port[: closing + 1]
            port = host_port[closing + 1:].lstrip(":")
    else:
        host, _, port = host_port.partition(":")

    details["host"] = host
    details["port"] = port or "5432"
    if slash:
        details["database"] = path.split("?")[0]  # Remove query params
    return details


masked_uri = mask_connection_string(DATABASE_URI)
conn_details = parse_connection_details(DATABASE_URI)

logger.info(
    "PostgreSQL connection configured",
    extra={
        "connection_string": masked_uri,
        "host": conn_details["host"],
        "port": conn_details["port"],
        "database": conn_details["database"],
        "ssl_mode": settings.POSTGRES_SSL_MODE,
    },
)

# Build connect args including optional SSL
CONNECT_ARGS = {
    "server_settings": {
        "application_name": "cuatrolabs-scm-ms",
        "jit": "off",
    },
    "command_timeout": 60,
    "statement_cache_size": 0,
}


def _build_ssl_context(ssl_mode: str) -> ssl.SSLContext:
    """Build the SSL context matching a libpq-style ``sslmode`` value.

    Args:
        ssl_mode: Lower-cased SSL mode; must not be ``"disable"``.

    Returns:
        * ``verify-full`` - verify the certificate chain and the host name.
        * ``verify-ca``   - verify the certificate chain only.
        * anything else (e.g. ``require``) - encrypt without verification.
    """
    if ssl_mode in ("verify-full", "verify-ca"):
        if settings.POSTGRES_SSL_ROOT_CERT:
            context = ssl.create_default_context(cafile=settings.POSTGRES_SSL_ROOT_CERT)
        else:
            context = ssl.create_default_context()
        if settings.POSTGRES_SSL_CERT and settings.POSTGRES_SSL_KEY:
            # Best effort: a broken client cert is reported but must not stop
            # the service from attempting a server-verified connection.
            try:
                context.load_cert_chain(
                    certfile=settings.POSTGRES_SSL_CERT,
                    keyfile=settings.POSTGRES_SSL_KEY,
                )
            except Exception as e:
                logger.warning("Failed to load client SSL cert/key for PostgreSQL", exc_info=e)
        context.check_hostname = ssl_mode == "verify-full"
        context.verify_mode = ssl.CERT_REQUIRED
        return context

    # sslmode=require: encrypt but don't verify server cert.
    # check_hostname must be cleared before verify_mode can be CERT_NONE.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


mode = (settings.POSTGRES_SSL_MODE or "disable").lower()
if mode != "disable":
    ssl_context: ssl.SSLContext = _build_ssl_context(mode)
    CONNECT_ARGS["ssl"] = ssl_context
    logger.info("PostgreSQL SSL enabled", extra={"ssl_mode": settings.POSTGRES_SSL_MODE})

# Create async engine with connection pool settings
async_engine = create_async_engine(
    DATABASE_URI,
    echo=settings.DEBUG,  # Enable SQL logging in debug mode
    future=True,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
    connect_args=CONNECT_ARGS,
)

# Create async session factory
async_session = sessionmaker(
    async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Metadata for table creation
metadata = MetaData()

logger.info("PostgreSQL configuration loaded successfully")


# ────────────────────────────────────────────────────────────────────────────────
# Lifecycle helpers
# ────────────────────────────────────────────────────────────────────────────────
async def connect_to_database() -> None:
    """Initialize database connection when the application starts.

    Runs ``SELECT 1`` against the engine, retrying with exponential backoff
    (delay capped at 30 seconds) according to the
    ``POSTGRES_CONNECT_MAX_RETRIES``, ``POSTGRES_CONNECT_INITIAL_DELAY_MS``
    and ``POSTGRES_CONNECT_BACKOFF_MULTIPLIER`` settings.

    Raises:
        Exception: The error from the last failed attempt, once every attempt
            has failed.
    """
    banner = "=" * 70

    # Log connection details at startup (once per worker)
    print(f"\n{banner}")
    print("[POSTGRES] Starting Database Connection")
    print(banner)
    print(f"[POSTGRES] Connection String: {masked_uri}")
    print(f"[POSTGRES] Host: {conn_details['host']}")
    print(f"[POSTGRES] Port: {conn_details['port']}")
    print(f"[POSTGRES] Database: {conn_details['database']}")
    print(f"[POSTGRES] SSL Mode: {settings.POSTGRES_SSL_MODE}")
    print(f"{banner}\n")

    # Log connection attempt start
    logger.info(
        "Starting PostgreSQL connection attempts",
        extra={
            "max_retries": settings.POSTGRES_CONNECT_MAX_RETRIES,
            "initial_delay_ms": settings.POSTGRES_CONNECT_INITIAL_DELAY_MS,
            "backoff_multiplier": settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER,
        },
    )
    print(f"[POSTGRES] Attempting to connect (max retries: {settings.POSTGRES_CONNECT_MAX_RETRIES})...")

    attempts = 0
    # Always try at least once: with a non-positive setting the loop would be
    # skipped and there would be no error object to raise afterwards.
    max_attempts = max(1, settings.POSTGRES_CONNECT_MAX_RETRIES)
    delay = settings.POSTGRES_CONNECT_INITIAL_DELAY_MS / 1000.0
    last_error = None

    while attempts < max_attempts:
        try:
            async with async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Successfully connected to PostgreSQL database")
            print(f"[POSTGRES] ✅ Connection successful after {attempts + 1} attempt(s)")
            return
        except Exception as e:
            last_error = e
            attempts += 1
            error_msg = str(e)
            error_type = type(e).__name__
            logger.warning(
                "PostgreSQL connection attempt failed",
                extra={
                    "attempt": attempts,
                    "max_attempts": max_attempts,
                    "retry_delay_ms": int(delay * 1000),
                    "error_type": error_type,
                    "error_message": error_msg[:200],  # Truncate long errors
                },
            )
            print(f"[POSTGRES] ❌ Connection attempt {attempts}/{max_attempts} failed")
            print(f"[POSTGRES] Error: {error_type}: {error_msg[:150]}")
            if attempts < max_attempts:
                print(f"[POSTGRES] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
                delay = min(delay * settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER, 30.0)

    # All attempts failed
    logger.error(
        "Failed to connect to PostgreSQL after all retries",
        extra={
            "total_attempts": attempts,
            "final_error_type": type(last_error).__name__,
            "final_error": str(last_error)[:500],
        },
        exc_info=last_error,
    )
    print(f"[POSTGRES] ❌ FATAL: Failed to connect after {attempts} attempts")
    print(f"[POSTGRES] Last error: {type(last_error).__name__}: {str(last_error)[:200]}")
    print("[POSTGRES] Please check:")
    print(f"[POSTGRES] 1. Database host is reachable: {conn_details['host']}")
    print("[POSTGRES] 2. Database credentials are correct")
    print(f"[POSTGRES] 3. SSL mode is appropriate: {settings.POSTGRES_SSL_MODE}")
    print(f"[POSTGRES] 4. Firewall allows connections to port {conn_details['port']}")
    raise last_error


async def disconnect_from_database() -> None:
    """Close database connection when the application shuts down.

    Raises:
        Exception: Whatever ``async_engine.dispose()`` raised, after logging it.
    """
    try:
        await async_engine.dispose()
        logger.info("Successfully disconnected from PostgreSQL database")
    except Exception:
        logger.exception("Error disconnecting from PostgreSQL database")
        raise


async def enforce_trans_schema() -> None:
    """Enforce that all tables use the TRANS schema and validate schema compliance.

    Creates the ``trans`` schema if it does not exist, imports every model
    module so its tables are present in ``Base.metadata``, then checks each
    registered table's schema.

    Raises:
        ValueError: If any registered table is not in the ``trans`` schema.
        Exception: Any database or import error, after logging it.
    """
    try:
        async with async_engine.begin() as conn:
            # Ensure trans schema exists
            await conn.execute(text("CREATE SCHEMA IF NOT EXISTS trans"))
            logger.info("✅ TRANS schema exists")

            # Validate that all models use trans schema
            from app.core.database import Base

            # Import all models to ensure they're registered with Base
            # (imported for their side effect only, hence the unused names).
            from app.catalogues.models.model import CatalogueRef  # noqa: F401
            from app.inventory.stock.models.model import ScmStock, ScmStockLedger, ScmStockAdjustment  # noqa: F401
            from app.purchases.orders.models.model import ScmPo, ScmPoItem, ScmPoStatusLog  # noqa: F401
            from app.purchases.receipts.models.model import ScmGrn, ScmGrnItem, ScmGrnIssue  # noqa: F401
            from app.trade_sales.models.model import ScmTradeShipment, ScmTradeShipmentItem  # noqa: F401
            from app.trade_relationships.models.model import ScmTradeRelationship  # noqa: F401

            # Validate schema compliance
            non_trans_tables = [
                f"{table.name} (schema: {table.schema})"
                for table in Base.metadata.tables.values()
                if table.schema != 'trans'
            ]

            if non_trans_tables:
                error_msg = f"❌ SCHEMA VIOLATION: The following tables are not using 'trans' schema: {', '.join(non_trans_tables)}"
                logger.error(error_msg)
                print(f"\n{'='*80}")
                print(f"[SCHEMA ERROR] {error_msg}")
                print(f"{'='*80}\n")
                raise ValueError(error_msg)

            logger.info("✅ All SCM tables correctly use 'trans' schema")
            print("[SCHEMA] ✅ All SCM tables correctly use 'trans' schema")
    except Exception:
        logger.exception("Error enforcing TRANS schema")
        raise


async def create_tables() -> None:
    """Create all tables defined in models after enforcing schema compliance.

    Raises:
        Exception: Any error from schema enforcement or table creation, after
            logging it.
    """
    try:
        # First enforce schema compliance
        await enforce_trans_schema()

        from app.core.database import Base

        async with async_engine.begin() as conn:
            # Create all tables (schema already validated)
            await conn.run_sync(Base.metadata.create_all)
        logger.info("✅ Database tables created successfully in TRANS schema")
        print("[SCHEMA] ✅ Database tables created successfully in TRANS schema")
    except Exception:
        logger.exception("Error creating database tables")
        raise


# Optional thin wrappers around ``app.postgres``. They are defined only when
# that module imports cleanly; a failure here is deliberately non-fatal.
try:
    from app.postgres import (
        connect_to_postgres as _connect_to_postgres,
        close_postgres_connection as _close_postgres_connection,
        get_postgres_connection as _get_postgres_connection,
        release_postgres_connection as _release_postgres_connection,
        is_postgres_connected as _is_postgres_connected,
    )

    async def connect_to_postgres() -> None:
        """Delegate to ``app.postgres.connect_to_postgres``."""
        await _connect_to_postgres()

    async def close_postgres_connection() -> None:
        """Delegate to ``app.postgres.close_postgres_connection``."""
        await _close_postgres_connection()

    async def get_postgres_connection():
        """Delegate to ``app.postgres.get_postgres_connection`` and return its result."""
        return await _get_postgres_connection()

    async def release_postgres_connection(conn) -> None:
        """Delegate to ``app.postgres.release_postgres_connection``."""
        await _release_postgres_connection(conn)

    def is_postgres_connected() -> bool:
        """Delegate to ``app.postgres.is_postgres_connected``."""
        return _is_postgres_connected()
except ImportError:
    logger.debug("app.postgres not available; raw connection wrappers are not defined")
except Exception:
    # Still best-effort, but leave a trace of why the wrappers are missing.
    logger.warning("Failed to set up app.postgres connection wrappers", exc_info=True)