| """Service for managing user-registered external database connections.""" |
|
|
| import uuid |
| from typing import List, Optional |
|
|
| from sqlalchemy import delete, select |
| from sqlalchemy.ext.asyncio import AsyncSession |
|
|
| from src.db.postgres.models import DatabaseClient |
| from src.middlewares.logging import get_logger |
| from src.utils.db_credential_encryption import ( |
| decrypt_credentials_dict, |
| encrypt_credentials_dict, |
| ) |
|
|
# Module logger. The service below uses it to record client creation, updates,
# deletions, and detected duplicate connections.
logger = get_logger("database_client_service")
|
|
|
|
| |
# Credential fields that identify the physical database behind a connection,
# keyed by db_type. Two connections of the same db_type whose values for all
# of these fields match are treated as duplicates. Other fields (e.g. the
# password) do not participate. A db_type missing from this mapping is never
# deduplicated.
_CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
    # Server-based engines: identified by server address plus selected database.
    **dict.fromkeys(
        ("postgres", "supabase", "mysql", "sqlserver"),
        ("host", "port", "database"),
    ),
    "bigquery": ("project_id", "dataset_id"),
    "snowflake": ("account", "warehouse", "database"),
}
|
|
|
|
class DatabaseClientService:
    """Service for managing user-registered external database connections.

    The service holds no state of its own; every method operates on the
    ``AsyncSession`` passed in and commits its own changes. Credentials are
    stored encrypted (via ``encrypt_credentials_dict``) and are decrypted
    within this service only for duplicate detection.
    """

    @staticmethod
    def _identity_values_match(key: str, stored: object, incoming: object) -> bool:
        """Return True if two values of identity field ``key`` denote the same target.

        Values that compare equal as-is always match. A missing value (``None``)
        only matches another missing value. Otherwise both values are compared
        as whitespace-stripped strings, so a port stored as ``"5432"`` matches
        one submitted as ``5432``. ``host`` is compared case-insensitively
        because DNS names are case-insensitive.
        """
        if stored == incoming:
            return True
        if stored is None or incoming is None:
            return False
        stored_text = str(stored).strip()
        incoming_text = str(incoming).strip()
        if key == "host":
            return stored_text.casefold() == incoming_text.casefold()
        return stored_text == incoming_text

    async def _find_duplicate(
        self,
        db: AsyncSession,
        user_id: str,
        db_type: str,
        credentials: dict,
    ) -> Optional[DatabaseClient]:
        """Return an existing client of this user that points to the same physical database.

        A stored client counts as a duplicate when it has the same ``db_type``
        and every identity key listed in ``_CONNECTION_IDENTITY_KEYS[db_type]``
        matches (see ``_identity_values_match``). Non-identity fields such as
        passwords are ignored. Returns ``None`` when ``db_type`` has no identity
        keys, so such connections are never deduplicated.
        """
        identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
        if not identity_keys:
            return None

        result = await db.execute(
            select(DatabaseClient).where(
                DatabaseClient.user_id == user_id,
                DatabaseClient.db_type == db_type,
            )
        )
        for existing in result.scalars().all():
            # create() stores credentials encrypted, so decrypt before comparing.
            decrypted = decrypt_credentials_dict(existing.credentials)
            if all(
                self._identity_values_match(key, decrypted.get(key), credentials.get(key))
                for key in identity_keys
            ):
                return existing
        return None

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        If this user already has a connection to the same physical database
        (per ``_find_duplicate``), that existing record is returned unchanged:
        the supplied ``name`` and ``credentials`` are NOT applied to it.
        Otherwise a new record with status ``"active"`` is created, with the
        credentials encrypted before being stored, and committed.
        """
        existing = await self._find_duplicate(db, user_id, db_type, credentials)
        if existing:
            logger.info(
                f"Duplicate connection detected, returning existing client {existing.id}"
            )
            return existing

        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await db.commit()
        # Reload server-populated columns (e.g. created_at) after the commit.
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return all of a user's database clients, regardless of status, newest first."""
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        # scalars().all() yields a Sequence; materialize to honor the List annotation.
        return list(result.scalars().all())

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return a single database client by its ID, or ``None`` if it does not exist."""
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only arguments that are not ``None`` are applied. New credentials
        replace the stored ones entirely and are encrypted before storage.
        Returns the refreshed client, or ``None`` if ``client_id`` is unknown.
        """
        client = await self.get(db, client_id)
        if not client:
            return None

        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status

        await db.commit()
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        Returns True if a row was deleted, False if no client had ``client_id``.
        """
        # `delete` here resolves to sqlalchemy's module-level import, not this
        # method: class attributes are not in scope inside method bodies.
        result = await db.execute(
            delete(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        await db.commit()
        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted
|
|
|
|
# Shared module-level instance. DatabaseClientService defines no __init__ and
# stores no attributes, so a single instance can be reused by importers.
database_client_service = DatabaseClientService()
| |