# Source: Agentic-Service-Data-Eyond / src/database_client/database_client_service.py
# Last commit d310770 (Rifqi Hafizuddin): "[NOTICKET] add duplicate check for storing database"
# File size: 5.15 kB
"""Service for managing user-registered external database connections."""
import uuid
from typing import List, Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db.postgres.models import DatabaseClient
from src.middlewares.logging import get_logger
from src.utils.db_credential_encryption import (
decrypt_credentials_dict,
encrypt_credentials_dict,
)
# Module-level logger, obtained from the project's get_logger helper under the
# name "database_client_service"; used for create/update/delete audit messages.
logger = get_logger("database_client_service")
# Credential keys that identify one physical database, per db_type.
# DatabaseClientService._find_duplicate treats two connections of the same
# user and db_type as duplicates when their values for all of these keys
# are equal. A db_type missing from this map is never deduplicated.
_CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
    # Host-addressed SQL engines all share the same identity shape.
    **dict.fromkeys(
        ("postgres", "supabase", "mysql", "sqlserver"),
        ("host", "port", "database"),
    ),
    "bigquery": ("project_id", "dataset_id"),
    "snowflake": ("account", "warehouse", "database"),
}
class DatabaseClientService:
    """Service for managing user-registered external database connections.

    Every method receives the caller's ``AsyncSession``. The mutating methods
    (``create``, ``update``, ``delete``) commit on that session before
    returning. Credentials are stored encrypted and only decrypted in memory
    for duplicate detection.
    """

    async def _find_duplicate(
        self,
        db: AsyncSession,
        user_id: str,
        db_type: str,
        credentials: dict,
    ) -> Optional[DatabaseClient]:
        """Return an existing client if it points to the same physical database.

        Identity is decided by the credential keys listed for ``db_type`` in
        ``_CONNECTION_IDENTITY_KEYS``. A stored client of the same user and
        ``db_type`` matches when its decrypted credentials agree with
        ``credentials`` on every one of those keys.

        Args:
            db: Active async session.
            user_id: Owner whose existing clients are searched.
            db_type: Connection type; types without identity keys are never
                deduplicated.
            credentials: Plaintext credentials of the connection being added.

        Returns:
            The first matching ``DatabaseClient``, or ``None`` when there is no
            match or identity cannot be determined.
        """
        identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
        if not identity_keys:
            return None
        # If the new credentials carry none of the identity fields there is
        # nothing to compare. Comparing anyway would make every stored record
        # that also lacks them "match" via None == None, collapsing unrelated
        # connections into a single client.
        if all(credentials.get(k) is None for k in identity_keys):
            return None

        result = await db.execute(
            select(DatabaseClient).where(
                DatabaseClient.user_id == user_id,
                DatabaseClient.db_type == db_type,
            )
        )
        for existing in result.scalars().all():
            # Stored credentials are encrypted at rest, so each candidate is
            # decrypted before its identity fields can be compared.
            decrypted = decrypt_credentials_dict(existing.credentials)
            if all(decrypted.get(k) == credentials.get(k) for k in identity_keys):
                return existing
        return None

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        If a connection to the same physical database already exists for this
        user, the existing record is returned instead of creating a duplicate.
        In that case ``name`` and ``credentials`` are not applied to it, and
        nothing is added or committed.

        Credentials are encrypted before being stored. New clients get a random
        UUID4 string id and ``status="active"``.

        Returns:
            The newly created (and refreshed) client, or the existing duplicate.
        """
        existing = await self._find_duplicate(db, user_id, db_type, credentials)
        if existing:
            logger.info(
                f"Duplicate connection detected, returning existing client {existing.id}"
            )
            return existing

        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await db.commit()
        # Reload server-side values (e.g. created_at) onto the instance.
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return every database client owned by ``user_id``, newest first.

        No status filter is applied, so both active and inactive clients are
        included. Ordering is by ``created_at`` descending.
        """
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        return result.scalars().all()

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return the database client with id ``client_id``, or ``None``.

        The lookup is by id only; it is not scoped to a user.
        """
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only non-None fields are updated. Credentials are re-encrypted if
        provided. Unlike ``create``, no duplicate-connection check is run when
        the credentials change.

        Returns:
            The refreshed client, or ``None`` if ``client_id`` does not exist
            (in which case nothing is committed).
        """
        client = await self.get(db, client_id)
        if not client:
            return None
        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status
        await db.commit()
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        The transaction is committed even when no row matched.

        Returns:
            ``True`` if a row was deleted, ``False`` otherwise.
        """
        # Inside this body ``delete`` resolves to the module-level
        # ``sqlalchemy.delete`` import. Class attributes (including this
        # method) are not in scope for name lookup within a method.
        result = await db.execute(
            delete(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        await db.commit()
        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted
# Shared module-level instance. DatabaseClientService holds no per-instance
# state, so a single object can serve every caller.
database_client_service: DatabaseClientService = DatabaseClientService()