"""Service for managing user-registered external database connections."""

import uuid
from typing import List, Optional

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.db.postgres.models import DatabaseClient
from src.middlewares.logging import get_logger
from src.utils.db_credential_encryption import (
    decrypt_credentials_dict,
    encrypt_credentials_dict,
)

logger = get_logger("database_client_service")

# Fields that identify the same physical database per db_type.
# A db_type missing from this mapping is never de-duplicated.
_CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
    "postgres": ("host", "port", "database"),
    "supabase": ("host", "port", "database"),
    "mysql": ("host", "port", "database"),
    "sqlserver": ("host", "port", "database"),
    "bigquery": ("project_id", "dataset_id"),
    "snowflake": ("account", "warehouse", "database"),
}


def _identity_value(credentials: dict, key: str) -> Optional[str]:
    """Return ``credentials[key]`` normalized for identity comparison.

    Values are compared by their string form so that, e.g., a port given as
    ``5432`` and one given as ``"5432"`` are treated as equal. (Whether the
    encryption round trip preserves value types is not visible from this
    module, so both sides are normalized.) Missing or ``None`` values stay
    ``None``.
    """
    value = credentials.get(key)
    return None if value is None else str(value)


class DatabaseClientService:
    """Service for managing user-registered external database connections."""

    async def _find_duplicate(
        self,
        db: AsyncSession,
        user_id: str,
        db_type: str,
        credentials: dict,
    ) -> Optional[DatabaseClient]:
        """Return an existing client if it points to the same physical database.

        Args:
            db: Open async session used for the lookup.
            user_id: Owner whose clients are searched.
            db_type: Connection type; selects the identity keys to compare.
            credentials: Plain-text (unencrypted) credentials of the new
                connection.

        Returns:
            The first of the user's clients of the same ``db_type`` whose
            decrypted identity fields match ``credentials``, or ``None`` if
            there is no match, ``db_type`` has no identity keys, or
            ``credentials`` carries none of the identity fields.
        """
        identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
        if not identity_keys:
            return None

        wanted = tuple(_identity_value(credentials, k) for k in identity_keys)
        # With no identifying value at all there is nothing meaningful to
        # match on; otherwise every other identity-less record of this type
        # would compare equal (None == None) and be returned as a "duplicate".
        if all(value is None for value in wanted):
            return None

        result = await db.execute(
            select(DatabaseClient).where(
                DatabaseClient.user_id == user_id,
                DatabaseClient.db_type == db_type,
            )
        )
        for existing in result.scalars().all():
            # Stored credentials are encrypted, so each candidate must be
            # decrypted before its identity fields can be compared.
            decrypted = decrypt_credentials_dict(existing.credentials)
            candidate = tuple(_identity_value(decrypted, k) for k in identity_keys)
            if candidate == wanted:
                return existing
        return None

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        If a connection to the same physical database already exists for this
        user, the existing record is returned unchanged (the ``name`` argument
        is ignored in that case) instead of creating a duplicate. Credentials
        are encrypted before being stored, and new clients start with status
        ``"active"``.

        NOTE(review): the duplicate check and the insert are separate steps
        with no lock or unique constraint visible here, so two concurrent
        calls may still both insert.

        Returns:
            The newly created client, or the pre-existing duplicate.
        """
        existing = await self._find_duplicate(db, user_id, db_type, credentials)
        if existing:
            logger.info(
                f"Duplicate connection detected, returning existing client {existing.id}"
            )
            return existing

        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await db.commit()
        # Reload server-populated columns (e.g. timestamps) after the commit.
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return every database client owned by ``user_id``, newest first.

        Clients are returned regardless of their status.
        """
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        # ScalarResult.all() is typed as a Sequence; materialize a real list
        # so the declared return type holds.
        return list(result.scalars().all())

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return a single database client by its ID, or ``None`` if absent."""
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only arguments that are not ``None`` are applied, so a field cannot be
        cleared to ``None`` through this method. New credentials are encrypted
        before being stored. No duplicate check is performed on update.

        Returns:
            The refreshed client, or ``None`` if no client has ``client_id``.
        """
        client = await self.get(db, client_id)
        if not client:
            return None

        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status

        await db.commit()
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        Returns:
            ``True`` if a row was deleted, ``False`` if no client had
            ``client_id``. The transaction is committed in either case.
        """
        # Inside this method body, ``delete`` resolves to the module-level
        # SQLAlchemy construct, not to this method (class scope is skipped).
        result = await db.execute(
            delete(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        await db.commit()
        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted


database_client_service = DatabaseClientService()