File size: 8,777 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""Database schema introspection (Postgres / MySQL / Supabase).

Reads information_schema for tables/columns/types, samples ~100 rows per table
for `sample_values` and basic stats. Description fields are left empty —
the planner relies on names + samples + stats directly.

Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`,
`db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`)
to avoid reimplementation. The cleanup PR will move those into `security/` and
`pipeline/db_pipeline/` respectively.
"""

import asyncio
import hashlib
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any

from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import AsyncSessionLocal
from src.middlewares.logging import get_logger
from src.pipeline.db_pipeline import db_pipeline_service
from src.pipeline.db_pipeline.extractor import (
    get_row_count,
    get_schema,
    profile_column,
)
from src.utils.db_credential_encryption import decrypt_credentials_dict

from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table
from ..pii_detector import PIIDetector
from .base import BaseIntrospector

# Module-level logger registered under the name "db_introspector".
logger = get_logger("db_introspector")

# Scheme prefix that marks a location_ref as pointing at a stored database
# client; the text after the prefix is used as the client id.
_DBCLIENT_PREFIX = "dbclient://"


def _stable_id(prefix: str, *parts: str) -> str:
    """Build a deterministic short identifier from `prefix` and `parts`.

    The parts are joined with "/", UTF-8 encoded and hashed with SHA-1; the
    first 12 hex characters of the digest are appended to `prefix`. Identical
    inputs always produce the identical ID, which is what lets cached IRs keep
    referring to the same table/column.

    SHA-1 is used purely as a fingerprint here, never for security
    (hence ``usedforsecurity=False``).
    """
    joined = "/".join(parts)
    digest = hashlib.sha1(joined.encode("utf-8"), usedforsecurity=False)
    short = digest.hexdigest()[:12]
    return f"{prefix}{short}"


def _map_sql_type(sql_type: str) -> DataType:
    """Map a stringified SQLAlchemy type to a Catalog DataType.

    Matching is case-insensitive and substring-based on the type's string
    form (e.g. 'INTEGER', 'NUMERIC(10, 2)', 'TIMESTAMP', 'BOOLEAN'). The
    checks run in a fixed order, so the first matching family wins.
    Conservative — unknowns fall back to "string" so the column is at least
    addressable.

    Args:
        sql_type: String form of the column's SQL type.

    Returns:
        One of "int", "decimal", "bool", "datetime", "date", "json" or
        "string".
    """
    s = sql_type.upper()
    # "INT" is also a substring of INTERVAL and (MULTI)POINT, neither of which
    # holds integers; let those fall through to the "string" fallback.
    if "INT" in s and "INTERVAL" not in s and "POINT" not in s:
        return "int"
    if any(tok in s for tok in ("FLOAT", "NUMERIC", "DECIMAL", "REAL", "DOUBLE")):
        return "decimal"
    if "BOOL" in s:
        return "bool"
    # Must be tested before the plain "DATE" check: DATETIME contains DATE.
    if "TIMESTAMP" in s or "DATETIME" in s:
        return "datetime"
    if "DATE" in s:
        return "date"
    if "JSON" in s:
        return "json"
    return "string"


def _normalize(v: Any) -> Any:
    """Coerce non-JSON-native scalars to types that survive the jsonb
    round-trip when the catalog is persisted.

    Conversions applied:
      * ``None`` is returned unchanged.
      * ``Decimal`` becomes ``float``.
      * numpy scalars (``np.generic``) are unwrapped with ``.item()``; the
        unwrapped value is then subject to the date/time rule below.
      * ``date``, ``datetime`` and ``time`` become ISO-8601 strings.
      * Anything else is returned unchanged.

    numpy is optional: when it cannot be imported the numpy step is skipped.
    """
    if v is None:
        return None
    if isinstance(v, Decimal):
        return float(v)

    # Imported locally: the module only imports `datetime` (and `UTC`) from
    # the datetime package, and numpy may not be installed at all.
    from datetime import date, time

    try:
        import numpy as np
    except ImportError:
        np = None

    if np is not None and isinstance(v, np.generic):
        # Do not return yet: np.datetime64 can unwrap to a datetime/date that
        # still needs ISO formatting below.
        v = v.item()

    # `datetime` is a subclass of `date`, so this covers all three types.
    if isinstance(v, (date, time)):
        return v.isoformat()
    return v


class DatabaseIntrospector(BaseIntrospector):
    """Connect to user DB → read its schema → profile every column.

    `introspect` resolves a ``dbclient://<client_id>`` reference to a stored
    database client, decrypts its credentials and builds a `Source` whose
    tables carry per-column samples and statistics. Columns flagged by the
    PII detector are emitted without any raw cell values.
    """

    def __init__(self) -> None:
        # One detector instance is reused for every column.
        self._pii = PIIDetector()

    async def introspect(self, location_ref: str) -> Source:
        """Introspect the database behind `location_ref`.

        Args:
            location_ref: Reference of the form ``dbclient://<client_id>``.

        Returns:
            A `Source` with ``source_type="schema"`` containing one `Table`
            per table that could be introspected.

        Raises:
            ValueError: If `location_ref` lacks the ``dbclient://`` prefix,
                carries no client id, or names a client that does not exist.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DatabaseIntrospector expects 'dbclient://...' location_ref, "
                f"got {location_ref!r}"
            )
        client_id = location_ref.removeprefix(_DBCLIENT_PREFIX)
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")

        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
        if client is None:
            raise ValueError(f"DatabaseClient {client_id!r} not found")

        creds = decrypt_credentials_dict(client.credentials)
        logger.info(
            "introspecting db source",
            client_id=client_id,
            db_type=client.db_type,
            name=client.name,
        )

        # SQLAlchemy inspect() + pandas read_sql are synchronous — run in a
        # threadpool so the event loop stays free.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, client.db_type, creds
        )

        return Source(
            source_id=client_id,
            source_type="schema",
            name=client.name,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]:
        """Blocking part of the introspection; meant to run off the event loop.

        Opens an engine for (`db_type`, `creds`), reads the schema, and builds
        one `Table` per entry. Failures are best-effort: a table whose row
        count cannot be read is skipped entirely, and a column whose profiling
        fails is skipped while the rest of its table is kept. Both cases are
        logged.

        Args:
            db_type: Database type string taken from the stored client.
            creds: Decrypted credentials for that client.

        Returns:
            The tables that were successfully introspected.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            schema = get_schema(engine)
            tables: list[Table] = []
            for table_name, cols in schema.items():
                try:
                    row_count = get_row_count(engine, table_name)
                except Exception as e:
                    # Deliberate best-effort: one broken table must not abort
                    # the whole source.
                    logger.error(
                        "row_count failed; skipping table",
                        table=table_name,
                        error=str(e),
                    )
                    continue

                columns: list[Column] = []
                for col in cols:
                    try:
                        profile = profile_column(
                            engine,
                            table_name,
                            col["name"],
                            col.get("is_numeric", False),
                            row_count,
                            is_temporal=col.get("is_temporal", False),
                        )
                    except Exception as e:
                        # Same best-effort policy at column granularity.
                        logger.error(
                            "profile_column failed; skipping column",
                            table=table_name,
                            column=col["name"],
                            error=str(e),
                        )
                        continue
                    columns.append(self._to_column(table_name, col, profile))

                foreign_keys = self._extract_foreign_keys(table_name, cols)

                tables.append(
                    Table(
                        table_id=_stable_id("t_", table_name),
                        name=table_name,
                        row_count=row_count,
                        columns=columns,
                        foreign_keys=foreign_keys,
                    )
                )
        return tables

    @staticmethod
    def _extract_foreign_keys(
        table_name: str, cols: list[dict[str, Any]]
    ) -> list[ForeignKey]:
        """Convert extractor's `foreign_key: 'target_table.target_col'` strings
        into ForeignKey objects with stable IDs (derived deterministically from
        names — same scheme used to generate table_id / column_id elsewhere).

        Columns without a `foreign_key` entry, or whose entry does not split
        into a non-empty table and column around the first ".", are ignored.

        Args:
            table_name: Name of the table that owns `cols`.
            cols: Column dicts as returned by the schema extractor.

        Returns:
            One `ForeignKey` per column carrying a usable `foreign_key` value.
        """
        fks: list[ForeignKey] = []
        for col in cols:
            fk_str = col.get("foreign_key")
            if not fk_str:
                continue
            target_table, _, target_col = fk_str.partition(".")
            if not target_table or not target_col:
                continue
            fks.append(
                ForeignKey(
                    column_id=_stable_id("c_", table_name, col["name"]),
                    target_table_id=_stable_id("t_", target_table),
                    target_column_id=_stable_id("c_", target_table, target_col),
                )
            )
        return fks

    def _to_column(
        self, table_name: str, col: dict[str, Any], profile: dict[str, Any]
    ) -> Column:
        """Build a catalog `Column` from a schema entry and its profile.

        Sample values and statistics are normalised to JSON-friendly scalars.
        The column is first built with its samples and statistics and handed
        to the PII detector. If it is flagged, the returned copy has
        `pii_flag=True` and carries no raw cell values: `sample_values` is
        cleared and so are `top_values`, `min` and `max` in the stats, since
        each of those is a verbatim value from the column.

        Args:
            table_name: Name of the owning table (used for the stable ID).
            col: Column dict from the schema extractor (`name`, `type`, ...).
            profile: Profiling result for the column.

        Returns:
            The resulting `Column`.
        """
        name = col["name"]
        # Empty lists collapse to None so "no samples" has a single encoding.
        sample_values: list[Any] | None = [
            _normalize(v) for v in (profile.get("sample_values") or [])
        ] or None

        # top_values arrives as (value, count) pairs; only the values are kept.
        top_raw = profile.get("top_values") or []
        top_values: list[Any] | None = [
            _normalize(v) for v, _cnt in top_raw
        ] or None

        mean = _normalize(profile.get("mean"))
        median = _normalize(profile.get("median"))
        distinct_count = profile.get("distinct_count")

        column = Column(
            column_id=_stable_id("c_", table_name, name),
            name=name,
            data_type=_map_sql_type(str(col["type"])),
            nullable=True,  # nullable not surfaced by extractor; default permissive
            pii_flag=False,
            sample_values=sample_values,
            stats=ColumnStats(
                min=_normalize(profile.get("min")),
                max=_normalize(profile.get("max")),
                mean=mean,
                median=median,
                distinct_count=distinct_count,
                top_values=top_values,
            ),
        )
        if not self._pii.detect(column):
            return column

        # Dropping only sample_values would still leak raw values through the
        # stats, so rebuild them without the verbatim-value fields.
        scrubbed_stats = ColumnStats(
            min=None,
            max=None,
            mean=mean,
            median=median,
            distinct_count=distinct_count,
            top_values=None,
        )
        return column.model_copy(
            update={
                "pii_flag": True,
                "sample_values": None,
                "stats": scrubbed_stats,
            }
        )


# Module-level instance, constructed once when this module is imported.
database_introspector = DatabaseIntrospector()