File size: 5,149 Bytes
2ba0613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d59cf
 
 
 
 
 
 
 
 
 
 
2ba0613
 
 
c7d59cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ba0613
 
 
 
 
 
 
 
 
 
c7d59cf
 
2ba0613
 
c7d59cf
 
 
 
 
 
 
2ba0613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Service for managing user-registered external database connections."""

import uuid
from typing import List, Optional

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.db.postgres.models import DatabaseClient
from src.middlewares.logging import get_logger
from src.utils.db_credential_encryption import (
    decrypt_credentials_dict,
    encrypt_credentials_dict,
)

# Module-level logger for this service, obtained from the project's get_logger helper.
logger = get_logger("database_client_service")


    # Fields that identify the same physical database per db_type.
_CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
    "postgres": ("host", "port", "database"),
    "supabase": ("host", "port", "database"),
    "mysql": ("host", "port", "database"),
    "sqlserver": ("host", "port", "database"),
    "bigquery": ("project_id", "dataset_id"),
    "snowflake": ("account", "warehouse", "database"),
}


class DatabaseClientService:
    """Service for managing user-registered external database connections.

    Every method receives the caller's ``AsyncSession``. The writing methods
    (``create``, ``update``, ``delete``) commit that session themselves.
    Credentials are encrypted with ``encrypt_credentials_dict`` before being
    stored and decrypted with ``decrypt_credentials_dict`` when compared.
    """

    @staticmethod
    def _normalize_identity_value(value: object) -> Optional[str]:
        """Return a canonical form of one identity field for comparison.

        Values are compared as whitespace-stripped strings, so a port given as
        ``5432`` matches one stored as ``"5432"``. ``None`` (field absent) is
        preserved, so an absent field only matches another absent field.
        """
        if value is None:
            return None
        return str(value).strip()

    async def _find_duplicate(
        self,
        db: AsyncSession,
        user_id: str,
        db_type: str,
        credentials: dict,
    ) -> Optional[DatabaseClient]:
        """Return an existing client if it points to the same physical database.

        Args:
            db: Session used for the lookup.
            user_id: Only this user's clients are considered.
            db_type: Only clients of this type are considered. It also selects
                the identity fields from ``_CONNECTION_IDENTITY_KEYS``.
            credentials: Plain-text (not yet encrypted) credentials of the
                connection being registered.

        Returns:
            The first matching client, or ``None`` when nothing matches or when
            ``db_type`` has no identity fields configured.
        """
        identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
        if not identity_keys:
            return None

        # Normalize the incoming identity once instead of once per candidate row.
        wanted = tuple(
            self._normalize_identity_value(credentials.get(key))
            for key in identity_keys
        )

        result = await db.execute(
            select(DatabaseClient).where(
                DatabaseClient.user_id == user_id,
                DatabaseClient.db_type == db_type,
            )
        )
        for existing in result.scalars().all():
            # Stored credentials are encrypted, so the comparison is done in
            # Python after decryption rather than inside the SQL query.
            stored = decrypt_credentials_dict(existing.credentials)
            candidate = tuple(
                self._normalize_identity_value(stored.get(key))
                for key in identity_keys
            )
            if candidate == wanted:
                return existing
        return None

    async def create(
        self,
        db: AsyncSession,
        user_id: str,
        name: str,
        db_type: str,
        credentials: dict,
    ) -> DatabaseClient:
        """Register a new database client connection.

        If a connection to the same physical database already exists for this
        user, the existing record is returned unchanged instead of creating a
        duplicate. Its name, status and stored credentials are not updated.
        Credentials are encrypted before being stored.

        Args:
            db: Session to use. It is committed when a new record is inserted.
            user_id: Owner of the connection.
            name: Display name for the new connection.
            db_type: Engine type, e.g. a key of ``_CONNECTION_IDENTITY_KEYS``.
            credentials: Plain-text credentials. They are encrypted on store.

        Returns:
            The newly created client (status ``"active"``) or the existing
            duplicate.
        """
        # NOTE(review): check-then-insert is not atomic; two concurrent calls
        # may both miss the duplicate unless the table enforces uniqueness
        # (not visible here).
        existing = await self._find_duplicate(db, user_id, db_type, credentials)
        if existing is not None:
            logger.info(
                f"Duplicate connection detected, returning existing client {existing.id}"
            )
            return existing

        client = DatabaseClient(
            id=str(uuid.uuid4()),
            user_id=user_id,
            name=name,
            db_type=db_type,
            credentials=encrypt_credentials_dict(credentials),
            status="active",
        )
        db.add(client)
        await db.commit()
        # Reload server-populated columns after the commit.
        await db.refresh(client)
        logger.info(f"Created database client {client.id} for user {user_id}")
        return client

    async def get_user_clients(
        self,
        db: AsyncSession,
        user_id: str,
    ) -> List[DatabaseClient]:
        """Return all of a user's database clients, regardless of status.

        Results are ordered by ``created_at`` descending (newest first).
        """
        result = await db.execute(
            select(DatabaseClient)
            .where(DatabaseClient.user_id == user_id)
            .order_by(DatabaseClient.created_at.desc())
        )
        return result.scalars().all()

    async def get(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> Optional[DatabaseClient]:
        """Return a single database client by its ID, or ``None`` if not found.

        No ownership check is done here; callers must verify ``user_id``.
        """
        result = await db.execute(
            select(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        return result.scalars().first()

    async def update(
        self,
        db: AsyncSession,
        client_id: str,
        name: Optional[str] = None,
        credentials: Optional[dict] = None,
        status: Optional[str] = None,
    ) -> Optional[DatabaseClient]:
        """Update an existing database client connection.

        Only non-None fields are updated. When ``credentials`` is provided, it
        replaces the stored credentials entirely (it is not merged) and is
        re-encrypted.

        Returns:
            The updated client, or ``None`` if ``client_id`` does not exist.
        """
        client = await self.get(db, client_id)
        if not client:
            return None

        if name is not None:
            client.name = name
        if credentials is not None:
            client.credentials = encrypt_credentials_dict(credentials)
        if status is not None:
            client.status = status

        await db.commit()
        await db.refresh(client)
        logger.info(f"Updated database client {client_id}")
        return client

    async def delete(
        self,
        db: AsyncSession,
        client_id: str,
    ) -> bool:
        """Permanently delete a database client connection.

        Returns:
            ``True`` if a row was deleted, ``False`` if ``client_id`` did not exist.
        """
        # ``delete`` below is the module-level ``sqlalchemy.delete`` import, not
        # this method: class-scope names are not visible inside method bodies.
        result = await db.execute(
            delete(DatabaseClient).where(DatabaseClient.id == client_id)
        )
        await db.commit()
        deleted = result.rowcount > 0
        if deleted:
            logger.info(f"Deleted database client {client_id}")
        return deleted


# Module-level instance of the service. DatabaseClientService defines no
# __init__ and its methods keep no per-instance state, so one instance can be shared.
database_client_service = DatabaseClientService()