File size: 5,149 Bytes
2ba0613 c7d59cf 2ba0613 c7d59cf 2ba0613 c7d59cf 2ba0613 c7d59cf 2ba0613 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """Service for managing user-registered external database connections."""
import uuid
from typing import List, Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db.postgres.models import DatabaseClient
from src.middlewares.logging import get_logger
from src.utils.db_credential_encryption import (
decrypt_credentials_dict,
encrypt_credentials_dict,
)
# Module-level logger obtained from the project's logging helper
# (src.middlewares.logging.get_logger).
logger = get_logger("database_client_service")
# Fields that identify the same physical database per db_type.
_CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
"postgres": ("host", "port", "database"),
"supabase": ("host", "port", "database"),
"mysql": ("host", "port", "database"),
"sqlserver": ("host", "port", "database"),
"bigquery": ("project_id", "dataset_id"),
"snowflake": ("account", "warehouse", "database"),
}
class DatabaseClientService:
    """Service for managing user-registered external database connections.

    Connections are stored as ``DatabaseClient`` rows whose ``credentials``
    column holds the output of ``encrypt_credentials_dict``. Every method
    takes the ``AsyncSession`` to run against; the mutating methods
    (``create``, ``update``, ``delete``) commit it themselves.
    """

    @staticmethod
    def _identity_value(value: object) -> Optional[str]:
        """Normalize one identity field for duplicate comparison.

        ``None`` and ``""`` both mean "not provided" and map to ``None``.
        Any other value is compared by its ``str`` form, so a port stored as
        ``5432`` matches one submitted as ``"5432"``.
        """
        if value is None:
            return None
        text = str(value)
        return text or None

    async def _find_duplicate(
        self,
        db: AsyncSession,
        user_id: str,
        db_type: str,
        credentials: dict,
    ) -> Optional[DatabaseClient]:
        """Return an existing client if it points to the same physical database.

        Two connections are the same when every field listed in
        ``_CONNECTION_IDENTITY_KEYS[db_type]`` compares equal after
        ``_identity_value`` normalization.

        Returns ``None`` in these cases:

        * ``db_type`` has no identity keys.
        * The submitted credentials provide none of the identity fields.
          Without this check, an all-missing identity would match every
          existing client that also lacks those fields.
        * No stored client matches.

        Note: this decrypts the stored credentials of every client the user
        has of this ``db_type``.
        """
        identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
        if not identity_keys:
            return None

        wanted = tuple(
            self._identity_value(credentials.get(key)) for key in identity_keys
        )
        if all(value is None for value in wanted):
            # Nothing identifies the target database; treat it as new rather
            # than letting None == None match an unrelated record.
            return None

        result = await db.execute(
            select(DatabaseClient)
            .where(
                DatabaseClient.user_id == user_id,
                DatabaseClient.db_type == db_type,
            )
            # Oldest first, so the same record is returned on every call if
            # several duplicates already exist.
            .order_by(DatabaseClient.created_at)
        )
        for existing in result.scalars().all():
            decrypted = decrypt_credentials_dict(existing.credentials)
            found = tuple(
                self._identity_value(decrypted.get(key)) for key in identity_keys
            )
            if found == wanted:
                return existing
        return None

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        If the user already has a connection to the same physical database
        (as decided by ``_find_duplicate``), that existing record is returned
        unchanged. The supplied ``name`` and ``credentials`` are NOT applied
        to it.

        Otherwise a new record is created with ``status="active"`` and
        credentials encrypted via ``encrypt_credentials_dict``. The session
        is committed and the new record refreshed before it is returned.
        """
        existing = await self._find_duplicate(db, user_id, db_type, credentials)
        if existing:
            logger.info(
                f"Duplicate connection detected, returning existing client {existing.id}"
            )
            return existing

        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await db.commit()
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return all database clients for a user, newest first.

        No status filter is applied, so active and inactive clients are both
        included.
        """
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        # Materialize as a list so the return value matches the annotation.
        return list(result.scalars().all())

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return a single database client by its ID, or ``None`` if absent.

        The lookup is not scoped to a user; any ownership check is the
        caller's responsibility.
        """
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only arguments that are not ``None`` are applied. New credentials are
        encrypted with ``encrypt_credentials_dict`` before being stored.

        Returns the refreshed client after commit, or ``None`` if no client
        has ``client_id``.
        """
        client = await self.get(db, client_id)
        if not client:
            return None
        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status
        await db.commit()
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        Returns ``True`` if a row was deleted, ``False`` if none matched.
        """
        # `delete` below is the module-level sqlalchemy.delete, not this
        # method: class attributes are not in scope inside a method body.
        result = await db.execute(
            delete(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        await db.commit()
        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted
# Shared module-level instance. DatabaseClientService holds no per-instance
# state (no __init__, no attributes set on self), so a single instance can be
# reused; presumably callers import this rather than instantiating their own.
database_client_service: DatabaseClientService = DatabaseClientService()
|