Spaces:
Running
Running
| """ | |
| PostgreSQL database connection and session management for SCM microservice. | |
| Following TMS pattern with SQLAlchemy async engine. | |
| """ | |
| import logging | |
| import ssl | |
| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy import MetaData, text | |
| from app.core.config import settings | |
logger = logging.getLogger(__name__)

# Fail fast at import time: nothing below can work without a database URL.
DATABASE_URI = settings.POSTGRES_URI
if not DATABASE_URI:
    logger.error("POSTGRES_URI is empty or missing from settings")
    raise ValueError("POSTGRES_URI is not set. Check environment variables.")
| # Parse and log connection details (with masked password) | |
def mask_connection_string(uri: str) -> str:
    """Return *uri* with its password replaced by ``***`` for safe logging.

    Expected format: ``postgresql+asyncpg://user:password@host:port/database``.

    Args:
        uri: Connection string to mask.

    Returns:
        ``"EMPTY"`` for an empty/None value. The URI unchanged when it has no
        ``@`` (no inline credentials). ``"INVALID_FORMAT"`` if masking fails
        (e.g. a non-string value). Otherwise ``[scheme://]user:***@rest``.
        The user name is kept; everything after its first ``:`` up to the
        first ``@`` is hidden.
    """
    if not uri:
        return "EMPTY"
    try:
        if "@" not in uri:
            # No userinfo section, so there is no inline password to hide.
            return uri
        # Credentials end at the first '@' (they are expected to be URL-encoded).
        credentials, host_db = uri.split("@", 1)
        scheme, sep, user_pass = credentials.partition("://")
        if sep:
            prefix = f"{scheme}://"
        else:
            # Scheme-less URI: the whole prefix is "user[:password]". It must
            # be masked as well rather than echoed verbatim into the logs.
            prefix, user_pass = "", credentials
        user = user_pass.split(":", 1)[0]
        return f"{prefix}{user}:***@{host_db}"
    except Exception:
        # Deliberate catch-all: a logging helper must never break startup.
        return "INVALID_FORMAT"
def parse_connection_details(uri: str) -> dict:
    """Extract host, port and database name from a connection URI for logging.

    Handles URIs with or without ``user[:password]@`` credentials and
    bracketed IPv6 hosts (``[::1]:5432``). Query parameters are stripped from
    the database name.

    Args:
        uri: Connection string, e.g. ``postgresql+asyncpg://u:p@host:5432/db``.

    Returns:
        A dict with string values under ``"host"``, ``"port"`` and
        ``"database"``. ``"port"`` defaults to ``"5432"`` when absent or empty.
        If no ``/database`` part is present, or *uri* is not a string, every
        value is ``"unknown"``.
    """
    unknown = {"host": "unknown", "port": "unknown", "database": "unknown"}
    try:
        # Drop the scheme, then any credentials up to the first '@'.
        rest = uri.split("://", 1)[-1]
        host_db = rest.split("@", 1)[-1]
        if "/" not in host_db:
            return unknown
        host_port, database = host_db.split("/", 1)
        if host_port.startswith("["):
            # Bracketed IPv6 literal: the port follows the closing bracket.
            host, _, after = host_port[1:].partition("]")
            port = after[1:] if after.startswith(":") else ""
        elif ":" in host_port:
            host, port = host_port.split(":", 1)
        else:
            host, port = host_port, ""
        return {
            "host": host,
            "port": port or "5432",
            "database": database.split("?", 1)[0],  # Remove query params
        }
    except (AttributeError, TypeError):
        # Non-string input (e.g. None); fall back to the "unknown" details.
        return unknown
# Compute a password-free view of the connection target once at import time;
# both values are reused by the startup banner in connect_to_database().
masked_uri = mask_connection_string(DATABASE_URI)
conn_details = parse_connection_details(DATABASE_URI)

logger.info(
    "PostgreSQL connection configured",
    extra={
        "connection_string": masked_uri,
        **{field: conn_details[field] for field in ("host", "port", "database")},
        "ssl_mode": settings.POSTGRES_SSL_MODE,
    },
)
# Build asyncpg connect args, including an optional SSL context.
CONNECT_ARGS = {
    "server_settings": {
        "application_name": "cuatrolabs-scm-ms",
        "jit": "off",
    },
    "command_timeout": 60,
    # NOTE(review): disabling asyncpg's statement cache is presumably for
    # compatibility with a transaction-pooling proxy (e.g. PgBouncer) — confirm.
    "statement_cache_size": 0,
}


def _build_ssl_context(ssl_mode: str) -> ssl.SSLContext:
    """Build the client SSL context for a libpq-style ``sslmode`` value.

    - ``verify-full``: verify the server certificate chain AND hostname.
    - ``verify-ca``: verify the server certificate chain, but not the hostname.
    - anything else (e.g. ``require``): encrypt without verifying the server.

    The CA bundle comes from ``settings.POSTGRES_SSL_ROOT_CERT`` (system
    defaults when unset). A client certificate/key pair is loaded when both
    ``settings.POSTGRES_SSL_CERT`` and ``settings.POSTGRES_SSL_KEY`` are set.
    A load failure is only logged as a warning.
    """
    if ssl_mode in ("verify-full", "verify-ca"):
        root_cert = settings.POSTGRES_SSL_ROOT_CERT
        context = (
            ssl.create_default_context(cafile=root_cert)
            if root_cert
            else ssl.create_default_context()
        )
        # verify-ca must still validate the chain; previously it fell through
        # to the CERT_NONE branch and performed no verification at all.
        context.check_hostname = ssl_mode == "verify-full"
        context.verify_mode = ssl.CERT_REQUIRED
    else:
        # sslmode=require: encrypt but don't verify server cert.
        # NOTE(review): "allow"/"prefer" also land here, which makes SSL
        # mandatory rather than opportunistic — confirm that is intended.
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    if settings.POSTGRES_SSL_CERT and settings.POSTGRES_SSL_KEY:
        try:
            context.load_cert_chain(
                certfile=settings.POSTGRES_SSL_CERT,
                keyfile=settings.POSTGRES_SSL_KEY,
            )
        except Exception as e:
            logger.warning("Failed to load client SSL cert/key for PostgreSQL", exc_info=e)
    return context


mode = (settings.POSTGRES_SSL_MODE or "disable").strip().lower()
if mode != "disable":
    ssl_context: ssl.SSLContext = _build_ssl_context(mode)
    CONNECT_ARGS["ssl"] = ssl_context
    logger.info("PostgreSQL SSL enabled", extra={"ssl_mode": settings.POSTGRES_SSL_MODE})
# Shared async engine. The pool holds up to 10 persistent connections plus
# 20 overflow connections. Connections are pinged before use and recycled
# after 3600 seconds.
async_engine = create_async_engine(
    DATABASE_URI,
    connect_args=CONNECT_ARGS,
    echo=settings.DEBUG,  # SQL statement logging only in debug mode
    future=True,
    pool_pre_ping=True,
    pool_recycle=3600,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
)

# Factory for AsyncSession objects bound to the engine above; ORM objects stay
# usable after commit because expiry on commit is turned off.
async_session = sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Standalone MetaData collection for table definitions.
metadata = MetaData()

logger.info("PostgreSQL configuration loaded successfully")
# ------------------------------------------------------------------------------
# Lifecycle helpers
# ------------------------------------------------------------------------------
async def connect_to_database() -> None:
    """Verify PostgreSQL connectivity at application startup, retrying with backoff.

    Executes ``SELECT 1`` on a connection from ``async_engine``. After each
    failure it sleeps, then multiplies the delay by
    ``settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER`` (capped at 30 seconds).
    The initial delay is ``settings.POSTGRES_CONNECT_INITIAL_DELAY_MS``.
    ``settings.POSTGRES_CONNECT_MAX_RETRIES`` is the total number of attempts.
    Values below 1 are treated as 1, so at least one attempt is always made.

    Progress is reported through ``logger`` and on stdout.

    Raises:
        Exception: the error from the final failed attempt once every attempt
            has been used.
    """
    import asyncio

    # Log connection details at startup (once per worker)
    divider = "=" * 70
    print(f"\n{divider}")
    print("[POSTGRES] Starting Database Connection")
    print(divider)
    print(f"[POSTGRES] Connection String: {masked_uri}")
    print(f"[POSTGRES] Host: {conn_details['host']}")
    print(f"[POSTGRES] Port: {conn_details['port']}")
    print(f"[POSTGRES] Database: {conn_details['database']}")
    print(f"[POSTGRES] SSL Mode: {settings.POSTGRES_SSL_MODE}")
    print(f"{divider}\n")

    # Log connection attempt start
    logger.info(
        "Starting PostgreSQL connection attempts",
        extra={
            "max_retries": settings.POSTGRES_CONNECT_MAX_RETRIES,
            "initial_delay_ms": settings.POSTGRES_CONNECT_INITIAL_DELAY_MS,
            "backoff_multiplier": settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER,
        },
    )
    print(f"[POSTGRES] Attempting to connect (max retries: {settings.POSTGRES_CONNECT_MAX_RETRIES})...")

    # A non-positive setting used to skip the loop entirely and end in
    # ``raise None`` (a TypeError) without ever contacting the database.
    max_attempts = max(1, settings.POSTGRES_CONNECT_MAX_RETRIES)
    delay = settings.POSTGRES_CONNECT_INITIAL_DELAY_MS / 1000.0
    attempts = 0
    last_error = None

    while attempts < max_attempts:
        try:
            async with async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Successfully connected to PostgreSQL database")
            print(f"[POSTGRES] β Connection successful after {attempts + 1} attempt(s)")
            return
        except Exception as e:
            last_error = e
            attempts += 1
            error_msg = str(e)
            error_type = type(e).__name__
            logger.warning(
                "PostgreSQL connection attempt failed",
                extra={
                    "attempt": attempts,
                    "max_attempts": max_attempts,
                    "retry_delay_ms": int(delay * 1000),
                    "error_type": error_type,
                    "error_message": error_msg[:200],  # Truncate long errors
                },
            )
            print(f"[POSTGRES] β Connection attempt {attempts}/{max_attempts} failed")
            print(f"[POSTGRES] Error: {error_type}: {error_msg[:150]}")
            if attempts < max_attempts:
                print(f"[POSTGRES] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
                delay = min(delay * settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER, 30.0)

    # All attempts failed; max_attempts >= 1 guarantees last_error is set.
    logger.error(
        "Failed to connect to PostgreSQL after all retries",
        extra={
            "total_attempts": attempts,
            "final_error_type": type(last_error).__name__,
            "final_error": str(last_error)[:500],
        },
        exc_info=last_error,
    )
    print(f"[POSTGRES] β FATAL: Failed to connect after {attempts} attempts")
    print(f"[POSTGRES] Last error: {type(last_error).__name__}: {str(last_error)[:200]}")
    print("[POSTGRES] Please check:")
    print(f"[POSTGRES] 1. Database host is reachable: {conn_details['host']}")
    print("[POSTGRES] 2. Database credentials are correct")
    print(f"[POSTGRES] 3. SSL mode is appropriate: {settings.POSTGRES_SSL_MODE}")
    print(f"[POSTGRES] 4. Firewall allows connections to port {conn_details['port']}")
    raise last_error
async def disconnect_from_database() -> None:
    """Dispose of the engine's connection pool when the application shuts down.

    Any error from ``async_engine.dispose()`` is logged with its traceback
    and then re-raised to the caller.
    """
    try:
        await async_engine.dispose()
    except Exception:
        logger.exception("Error disconnecting from PostgreSQL database")
        raise
    else:
        logger.info("Successfully disconnected from PostgreSQL database")
async def enforce_trans_schema() -> None:
    """Validate that every registered table uses the ``trans`` schema, then ensure it exists.

    All SCM model modules are imported so their tables are registered on
    ``Base.metadata``. Any table whose ``schema`` is not ``'trans'`` is
    reported and aborts with ``ValueError``. Validation runs before touching
    the database. That way no connection is held during the imports, and
    ``CREATE SCHEMA IF NOT EXISTS trans`` is only committed (and logged) for a
    compliant model set.

    Raises:
        ValueError: if any registered table is not in the ``trans`` schema.
        Exception: any import or database error; it is logged and re-raised.
    """
    try:
        from app.core.database import Base

        # Import all models so they're registered with Base (side effect only).
        from app.catalogues.models.model import CatalogueRef  # noqa: F401
        from app.inventory.stock.models.model import ScmStock, ScmStockLedger, ScmStockAdjustment  # noqa: F401
        from app.purchases.orders.models.model import ScmPo, ScmPoItem, ScmPoStatusLog  # noqa: F401
        from app.purchases.receipts.models.model import ScmGrn, ScmGrnItem, ScmGrnIssue  # noqa: F401
        from app.trade_sales.models.model import ScmTradeShipment, ScmTradeShipmentItem  # noqa: F401
        from app.trade_relationships.models.model import ScmTradeRelationship  # noqa: F401

        # Validate schema compliance before opening a transaction.
        non_trans_tables = [
            f"{table.name} (schema: {table.schema})"
            for table in Base.metadata.tables.values()
            if table.schema != 'trans'
        ]
        if non_trans_tables:
            error_msg = f"β SCHEMA VIOLATION: The following tables are not using 'trans' schema: {', '.join(non_trans_tables)}"
            logger.error(error_msg)
            print(f"\n{'='*80}")
            print(f"[SCHEMA ERROR] {error_msg}")
            print(f"{'='*80}\n")
            raise ValueError(error_msg)

        # Ensure trans schema exists (committed when the block exits cleanly).
        async with async_engine.begin() as conn:
            await conn.execute(text("CREATE SCHEMA IF NOT EXISTS trans"))
        logger.info("β TRANS schema exists")

        logger.info("β All SCM tables correctly use 'trans' schema")
        print("[SCHEMA] β All SCM tables correctly use 'trans' schema")
    except Exception:
        logger.exception("Error enforcing TRANS schema")
        raise
async def create_tables() -> None:
    """Run the ``trans`` schema check, then create the tables on ``Base.metadata``.

    ``enforce_trans_schema()`` runs first. Then ``Base.metadata.create_all``
    is executed through ``run_sync`` inside a transaction. Any failure is
    logged with its traceback and re-raised.
    """
    try:
        await enforce_trans_schema()

        from app.core.database import Base

        async with async_engine.begin() as connection:
            # Schema compliance was validated by enforce_trans_schema() above.
            await connection.run_sync(Base.metadata.create_all)

        logger.info("β Database tables created successfully in TRANS schema")
        print("[SCHEMA] β Database tables created successfully in TRANS schema")
    except Exception:
        logger.exception("Error creating database tables")
        raise
# Optional thin async/sync wrappers around the helpers in ``app.postgres``.
# When that module cannot be imported, these names are simply not defined in
# this module. The failure is logged instead of being swallowed silently.
try:
    from app.postgres import (
        connect_to_postgres as _connect_to_postgres,
        close_postgres_connection as _close_postgres_connection,
        get_postgres_connection as _get_postgres_connection,
        release_postgres_connection as _release_postgres_connection,
        is_postgres_connected as _is_postgres_connected,
    )
except ImportError:
    # Module absent or circular import: treated as optional.
    logger.debug("app.postgres not importable; postgres helper wrappers not exported", exc_info=True)
except Exception:
    # Anything else is a real error inside app.postgres worth surfacing.
    logger.warning("Failed to import app.postgres; postgres helper wrappers not exported", exc_info=True)
else:

    async def connect_to_postgres() -> None:
        """Delegate to ``app.postgres.connect_to_postgres``."""
        await _connect_to_postgres()

    async def close_postgres_connection() -> None:
        """Delegate to ``app.postgres.close_postgres_connection``."""
        await _close_postgres_connection()

    async def get_postgres_connection():
        """Delegate to ``app.postgres.get_postgres_connection`` and return its result."""
        return await _get_postgres_connection()

    async def release_postgres_connection(conn) -> None:
        """Delegate to ``app.postgres.release_postgres_connection`` for *conn*."""
        await _release_postgres_connection(conn)

    def is_postgres_connected() -> bool:
        """Delegate to ``app.postgres.is_postgres_connected``."""
        return _is_postgres_connected()