MukeshKapoor25's picture
refactor(sql): reduce connection pool size and remove PostgreSQL optimizations
3715d0b
"""
PostgreSQL database connection and session management for POS microservice.
Following SCM pattern with SQLAlchemy async engine.
"""
import logging
import ssl
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlalchemy import MetaData, text
from app.core.config import settings
logger = logging.getLogger(__name__)

DATABASE_URI = settings.POSTGRES_URI
# Fail fast at import time: nothing below can work without a connection string.
if not DATABASE_URI:
    logger.error("POSTGRES_URI is empty or missing from settings")
    raise ValueError("POSTGRES_URI is not set. Check environment variables.")
# Parse and log connection details (with masked password)
def mask_connection_string(uri: str) -> str:
    """Return ``uri`` with its password replaced by ``***`` for safe logging.

    Expected format: ``scheme://user:password@host:port/database``.

    The credentials are separated from the host at the LAST ``@`` so that a
    password containing a literal ``@`` is masked in full instead of being
    partially written to the log. Only the user name (text before the first
    ``:`` of the credentials) is kept.

    Returns:
        ``"EMPTY"`` for a falsy ``uri``; the URI unchanged when it contains no
        ``@`` (no credentials to mask); ``"INVALID_FORMAT"`` if it cannot be
        processed (e.g. a non-string value).
    """
    if not uri:
        return "EMPTY"
    try:
        if "@" not in uri:
            return uri
        credentials, host_db = uri.rsplit("@", 1)
        if "://" in credentials:
            protocol, user_pass = credentials.split("://", 1)
            user = user_pass.split(":", 1)[0]
            return f"{protocol}://{user}:***@{host_db}"
        # No scheme: still drop everything after the user name so the
        # password is never echoed.
        user = credentials.split(":", 1)[0]
        return f"{user}:***@{host_db}"
    except Exception:
        return "INVALID_FORMAT"
def parse_connection_details(uri: str) -> dict:
    """Extract host, port and database name from a connection string for logging.

    Accepts ``scheme://[user[:password]@]host[:port]/database[?query]``.
    Credentials are split off at the LAST ``@`` so that an ``@`` inside the
    password does not corrupt the host. A bracketed IPv6 literal
    (``[::1]:5432``) keeps its internal colons.

    Returns:
        A dict with string values under ``"host"``, ``"port"`` and
        ``"database"``. The port defaults to ``"5432"`` when absent and any
        query string is stripped from the database. If the URI cannot be
        parsed, every value is ``"unknown"``.
    """
    unknown = {"host": "unknown", "port": "unknown", "database": "unknown"}
    try:
        if "@" in uri:
            host_db = uri.rsplit("@", 1)[1]
        elif "://" in uri:
            # No credentials: everything after the scheme is host/database.
            host_db = uri.split("://", 1)[1]
        else:
            return unknown
        if "/" not in host_db:
            return unknown
        host_port, database = host_db.split("/", 1)
        # A port follows the LAST colon; a value ending in "]" is a bare
        # IPv6 literal without a port.
        if ":" in host_port and not host_port.endswith("]"):
            host, port = host_port.rsplit(":", 1)
        else:
            host, port = host_port, "5432"
        return {
            "host": host,
            "port": port,
            "database": database.split("?")[0],  # Remove query params
        }
    except Exception:
        return unknown
# Computed once at import time and reused by the startup banner in
# connect_to_database(); only the masked form of the URI is ever logged.
masked_uri = mask_connection_string(DATABASE_URI)
conn_details = parse_connection_details(DATABASE_URI)

logger.info(
    "PostgreSQL connection configured",
    extra=dict(
        connection_string=masked_uri,
        host=conn_details["host"],
        port=conn_details["port"],
        database=conn_details["database"],
        ssl_mode=settings.POSTGRES_SSL_MODE,
    ),
)
# Build asyncpg connect args, including optional SSL.
CONNECT_ARGS = {
    "server_settings": {
        "application_name": "cuatrolabs-pos-ms",
        # Turn off PostgreSQL JIT compilation for this service's sessions.
        "jit": "off",
    },
    "command_timeout": 60,
    # 0 disables asyncpg's prepared-statement cache. Presumably this is for
    # compatibility with a transaction-pooling proxy (e.g. PgBouncer); confirm.
    "statement_cache_size": 0,
}

# POSTGRES_SSL_MODE uses libpq sslmode names. Anything other than "disable"
# enables TLS:
#   verify-full -> verify the certificate chain AND the server host name
#   verify-ca   -> verify the certificate chain only
#   otherwise   -> (e.g. "require") encrypt without verifying the server cert
mode = (settings.POSTGRES_SSL_MODE or "disable").lower()
if mode != "disable":
    ssl_context: ssl.SSLContext
    if mode in ("verify-full", "verify-ca"):
        # Trust the configured root CA, or the system store when none is set.
        ssl_context = (
            ssl.create_default_context(cafile=settings.POSTGRES_SSL_ROOT_CERT)
            if settings.POSTGRES_SSL_ROOT_CERT
            else ssl.create_default_context()
        )
        # verify-ca must still verify the chain. If it fell through to the
        # branch below, certificate checking would be silently disabled.
        ssl_context.check_hostname = mode == "verify-full"
        ssl_context.verify_mode = ssl.CERT_REQUIRED
    else:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be cleared before verify_mode can be CERT_NONE.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    # Present a client certificate (mutual TLS) in any SSL mode when one is
    # configured. A load failure is logged and TLS continues without it.
    if settings.POSTGRES_SSL_CERT and settings.POSTGRES_SSL_KEY:
        try:
            ssl_context.load_cert_chain(
                certfile=settings.POSTGRES_SSL_CERT,
                keyfile=settings.POSTGRES_SSL_KEY,
            )
        except Exception as e:
            logger.warning("Failed to load client SSL cert/key for PostgreSQL", exc_info=e)
    CONNECT_ARGS["ssl"] = ssl_context
    logger.info("PostgreSQL SSL enabled", extra={"ssl_mode": settings.POSTGRES_SSL_MODE})
# Shared async engine. Pool: 10 persistent connections plus up to 20 overflow,
# a 30 s wait for a free connection, connections recycled after 3600 s and
# pre-pinged on checkout. SQL echo follows the DEBUG setting.
async_engine = create_async_engine(
    DATABASE_URI,
    connect_args=CONNECT_ARGS,
    echo=settings.DEBUG,
    future=True,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
)

# Factory producing AsyncSession instances bound to async_engine. Loaded
# objects are not expired on commit (expire_on_commit=False).
async_session = sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Metadata for table creation
metadata = MetaData()

logger.info("PostgreSQL configuration loaded successfully")
# ────────────────────────────────────────────────────────────────────────────────
# Lifecycle helpers
# ────────────────────────────────────────────────────────────────────────────────
async def connect_to_database() -> None:
    """Verify PostgreSQL connectivity at application startup, retrying with backoff.

    Prints a startup banner. It then runs ``SELECT 1`` through ``async_engine``
    up to ``settings.POSTGRES_CONNECT_MAX_RETRIES`` times, always at least once.

    The wait between attempts starts at
    ``settings.POSTGRES_CONNECT_INITIAL_DELAY_MS``. After each failure it is
    multiplied by ``settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER``, capped at
    30 seconds.

    Raises:
        Exception: the error from the last failed attempt, once every attempt
            has failed.
    """
    import asyncio

    # Log connection details at startup (once per worker)
    rule = "=" * 70
    print(f"\n{rule}")
    print("[POSTGRES] Starting Database Connection")
    print(rule)
    print(f"[POSTGRES] Connection String: {masked_uri}")
    print(f"[POSTGRES] Host: {conn_details['host']}")
    print(f"[POSTGRES] Port: {conn_details['port']}")
    print(f"[POSTGRES] Database: {conn_details['database']}")
    print(f"[POSTGRES] SSL Mode: {settings.POSTGRES_SSL_MODE}")
    print(f"{rule}\n")

    logger.info(
        "Starting PostgreSQL connection attempts",
        extra={
            "max_retries": settings.POSTGRES_CONNECT_MAX_RETRIES,
            "initial_delay_ms": settings.POSTGRES_CONNECT_INITIAL_DELAY_MS,
            "backoff_multiplier": settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER,
        },
    )
    print(f"[POSTGRES] Attempting to connect (max retries: {settings.POSTGRES_CONNECT_MAX_RETRIES})...")

    # The setting is treated as the TOTAL number of attempts. Clamp it to at
    # least one: with 0 (or a negative value) the loop would never run, and
    # the final `raise last_error` would raise None (TypeError).
    max_attempts = max(1, settings.POSTGRES_CONNECT_MAX_RETRIES)
    delay = settings.POSTGRES_CONNECT_INITIAL_DELAY_MS / 1000.0
    last_error = None

    for attempt in range(1, max_attempts + 1):
        try:
            async with async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            logger.info("Successfully connected to PostgreSQL database")
            print(f"[POSTGRES] βœ… Connection successful after {attempt} attempt(s)")
            return
        except Exception as e:
            last_error = e
            error_msg = str(e)
            error_type = type(e).__name__
            logger.warning(
                "PostgreSQL connection attempt failed",
                extra={
                    "attempt": attempt,
                    "max_attempts": max_attempts,
                    "retry_delay_ms": int(delay * 1000),
                    "error_type": error_type,
                    "error_message": error_msg[:200],  # Truncate long errors
                },
            )
            print(f"[POSTGRES] ❌ Connection attempt {attempt}/{max_attempts} failed")
            print(f"[POSTGRES] Error: {error_type}: {error_msg[:150]}")
            # No sleep after the final attempt.
            if attempt < max_attempts:
                print(f"[POSTGRES] Retrying in {delay:.2f}s...")
                await asyncio.sleep(delay)
                delay = min(delay * settings.POSTGRES_CONNECT_BACKOFF_MULTIPLIER, 30.0)

    # All attempts failed
    logger.error(
        "Failed to connect to PostgreSQL after all retries",
        extra={
            "total_attempts": max_attempts,
            "final_error_type": type(last_error).__name__,
            "final_error": str(last_error)[:500],
        },
        exc_info=last_error,
    )
    print(f"[POSTGRES] ❌ FATAL: Failed to connect after {max_attempts} attempts")
    print(f"[POSTGRES] Last error: {type(last_error).__name__}: {str(last_error)[:200]}")
    print("[POSTGRES] Please check:")
    print(f"[POSTGRES] 1. Database host is reachable: {conn_details['host']}")
    print("[POSTGRES] 2. Database credentials are correct")
    print(f"[POSTGRES] 3. SSL mode is appropriate: {settings.POSTGRES_SSL_MODE}")
    print(f"[POSTGRES] 4. Firewall allows connections to port {conn_details['port']}")
    raise last_error
async def disconnect_from_database() -> None:
    """Dispose of ``async_engine`` when the application shuts down.

    Any error raised by ``dispose()`` is logged with its traceback and then
    re-raised to the caller.
    """
    try:
        await async_engine.dispose()
    except Exception:
        logger.exception("Error disconnecting from PostgreSQL database")
        raise
    else:
        logger.info("Successfully disconnected from PostgreSQL database")
async def enforce_trans_schema() -> None:
    """Ensure the ``trans`` schema exists and every registered table uses it.

    Issues ``CREATE SCHEMA IF NOT EXISTS trans`` in its own transaction. It
    then imports the POS models so their tables are registered on
    ``Base.metadata`` and checks each table's ``schema``.

    Raises:
        ValueError: if any table in ``Base.metadata`` has a schema other
            than ``'trans'``. Every exception, including this one, is logged
            via ``logger.exception`` before being re-raised.
    """
    try:
        async with async_engine.begin() as conn:
            await conn.execute(text("CREATE SCHEMA IF NOT EXISTS trans"))
            logger.info("βœ… TRANS schema exists")

        from app.core.database import Base
        # Imported so the POS models register their tables on Base.metadata.
        from app.sync.models import CustomerRef, StaffRef, CatalogueServiceRef

        offenders = [
            f"{tbl.name} (schema: {tbl.schema})"
            for tbl in Base.metadata.tables.values()
            if tbl.schema != 'trans'
        ]
        if offenders:
            error_msg = f"❌ SCHEMA VIOLATION: The following tables are not using 'trans' schema: {', '.join(offenders)}"
            logger.error(error_msg)
            banner = "=" * 80
            print(f"\n{banner}")
            print(f"[SCHEMA ERROR] {error_msg}")
            print(f"{banner}\n")
            raise ValueError(error_msg)

        logger.info("βœ… All POS tables correctly use 'trans' schema")
        print("[SCHEMA] βœ… All POS tables correctly use 'trans' schema")
    except Exception:
        logger.exception("Error enforcing TRANS schema")
        raise
async def create_tables() -> None:
    """Create every table registered on ``Base.metadata``.

    ``enforce_trans_schema()`` runs first, so a table outside the ``trans``
    schema raises ``ValueError`` before ``create_all`` is reached. Errors are
    logged with a traceback and re-raised.
    """
    try:
        await enforce_trans_schema()

        from app.core.database import Base
        # Imported so the POS models register their tables on Base.metadata.
        from app.sync.models import CustomerRef, StaffRef, CatalogueServiceRef

        async with async_engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        logger.info("βœ… Database tables created successfully in TRANS schema")
        print("[SCHEMA] βœ… Database tables created successfully in TRANS schema")
    except Exception:
        logger.exception("Error creating database tables")
        raise
# ────────────────────────────────────────────────────────────────────────────────
# Legacy compatibility functions
# ────────────────────────────────────────────────────────────────────────────────
async def connect_to_postgres() -> None:
    """Backward-compatible alias; delegates to ``connect_to_database()``."""
    return await connect_to_database()
async def close_postgres_connection() -> None:
    """Backward-compatible alias; delegates to ``disconnect_from_database()``."""
    return await disconnect_from_database()
def get_postgres_engine():
    """Return the module-level engine created by ``create_async_engine``."""
    engine = async_engine
    return engine
@asynccontextmanager
async def get_postgres_session():
    """
    Yield a session from ``async_session`` and always close it afterwards.

    If the ``async with`` body raises, the session is rolled back, the error
    is logged and re-raised. Committing is left to the caller. If
    ``async_session`` is falsy, a warning is logged and ``None`` is yielded.

    Usage:
        async with get_postgres_session() as session:
            await session.execute(...)
            await session.commit()
    """
    if async_session:
        session = async_session()
        try:
            yield session
        except Exception as exc:
            await session.rollback()
            logger.error(f"Session error, rolled back: {exc}")
            raise
        finally:
            await session.close()
    else:
        logger.warning("PostgreSQL session maker not initialized")
        yield None
async def execute_postgres_query(query: str, params: dict = None):
    """
    Execute a raw SQL statement in its own session and commit it.

    Args:
        query: SQL query string. It is wrapped in ``sqlalchemy.text`` so that
            ``:name`` placeholders are bound from ``params``. An
            already-constructed SQLAlchemy executable is passed through
            unchanged.
        params: Query parameters dict; ``None`` means no parameters.

    Returns:
        The result of ``session.execute``, or ``None`` when no engine or
        session is available.
    """
    if not async_engine:
        logger.warning("PostgreSQL engine not available, skipping query")
        return None
    # Session.execute() rejects plain strings ("Textual SQL expression ...
    # should be explicitly declared as text(...)"), so wrap them here.
    statement = text(query) if isinstance(query, str) else query
    async with get_postgres_session() as session:
        if session is None:
            return None
        result = await session.execute(statement, params or {})
        await session.commit()
        return result