| """Database schema introspection (Postgres / MySQL / Supabase). |
| |
| Reads information_schema for tables/columns/types, samples ~100 rows per table |
| for `sample_values` and basic stats. Description fields are left empty — |
| the planner relies on names + samples + stats directly. |
| |
| Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`, |
| `db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`) |
| to avoid reimplementation. The cleanup PR will move those into `security/` and |
| `pipeline/db_pipeline/` respectively. |
| """ |
|
|
| import asyncio |
| import hashlib |
| from datetime import UTC, datetime |
| from decimal import Decimal |
| from typing import Any |
|
|
| from src.database_client.database_client_service import database_client_service |
| from src.db.postgres.connection import AsyncSessionLocal |
| from src.middlewares.logging import get_logger |
| from src.pipeline.db_pipeline import db_pipeline_service |
| from src.pipeline.db_pipeline.extractor import ( |
| get_row_count, |
| get_schema, |
| profile_column, |
| ) |
| from src.utils.db_credential_encryption import decrypt_credentials_dict |
|
|
| from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table |
| from ..pii_detector import PIIDetector |
| from .base import BaseIntrospector |
|
|
# Module-level logger for this introspector.
logger = get_logger("db_introspector")

# Scheme every location_ref handled by DatabaseIntrospector must start with;
# the remainder of the string is the DatabaseClient id.
_DBCLIENT_PREFIX: str = "dbclient://"
|
|
|
|
def _stable_id(prefix: str, *parts: str) -> str:
    """Return a short deterministic identifier for ``parts``.

    The parts are joined with "/" and hashed with SHA-1. The first 12 hex
    digits of the digest are appended to ``prefix``. The output depends only
    on the inputs, so the same names always yield the same ID across runs.

    SHA-1 serves purely as an identifier hash here, not as a security
    primitive (hence ``usedforsecurity=False``).
    """
    joined_key = "/".join(parts)
    digest = hashlib.sha1(joined_key.encode("utf-8"), usedforsecurity=False)
    short_hex = digest.hexdigest()[:12]
    return f"{prefix}{short_hex}"
|
|
|
|
def _map_sql_type(sql_type: str) -> DataType:
    """Map a stringified SQLAlchemy type to a Catalog DataType.

    Matches on substrings of the upper-cased SQLAlchemy type repr (e.g.
    'INTEGER', 'TIMESTAMP', 'BOOLEAN'). Some type names merely *contain* a
    keyword without belonging to that family, so they are excluded first.
    Conservative: unknowns fall back to "string", so the column is at least
    addressable.

    Args:
        sql_type: ``str()`` of a SQLAlchemy column type, e.g. ``"BIGINT"``
            or ``"NUMERIC(10, 2)"``.

    Returns:
        One of "int", "decimal", "bool", "datetime", "date", "json" or
        "string".
    """
    s = sql_type.upper()
    # INTERVAL (a duration) and POINT (geometry, e.g. "GEOMETRY(POINT,4326)")
    # both contain "INT" but are not integers; without this guard the "INT"
    # check below labelled them "int".
    if "INTERVAL" in s or "POINT" in s:
        return "string"
    if "INT" in s:
        return "int"
    if any(k in s for k in ("FLOAT", "NUMERIC", "DECIMAL", "REAL", "DOUBLE")):
        return "decimal"
    if "BOOL" in s:
        return "bool"
    # Must precede the DATE check: "DATETIME" also contains "DATE".
    if "TIMESTAMP" in s or "DATETIME" in s:
        return "datetime"
    if "DATE" in s:
        return "date"
    if "JSON" in s:
        return "json"
    return "string"
|
|
|
|
def _normalize(v: Any) -> Any:
    """Coerce non-JSON-native scalars to types that survive a JSON round-trip.

    The catalog is persisted as jsonb, so values are converted as follows:

    * ``Decimal`` becomes ``float``.
    * numpy scalars are unwrapped via ``.item()``, and the result is
      normalized again (``np.datetime64.item()`` yields ``datetime``/``date``).
    * ``date``/``datetime``/``time`` become ISO-8601 strings.
    * ``UUID`` becomes ``str``.

    Everything else (including ``None``) is returned unchanged.
    """
    # Local imports keep this helper self-contained; the module only imports
    # the `datetime` class itself.
    from datetime import date, time
    from uuid import UUID

    if v is None:
        return None
    if isinstance(v, Decimal):
        return float(v)
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(v, np.generic):
        # .item() returns a plain Python scalar (never np.generic), so this
        # recursion terminates; it may be a datetime/date needing conversion.
        return _normalize(v.item())
    # `datetime` subclasses `date`, so this covers DATE and TIMESTAMP values.
    if isinstance(v, (date, time)):
        return v.isoformat()
    if isinstance(v, UUID):
        return str(v)
    return v
|
|
|
|
class DatabaseIntrospector(BaseIntrospector):
    """Connect to a user DB, read its schema, and profile every column.

    Connection details come from the DatabaseClient referenced by a
    ``dbclient://<client_id>`` location_ref. Schema reading, row counts and
    per-column profiling are delegated to the extractor helpers
    (``get_schema``, ``get_row_count``, ``profile_column``) and run in a worker
    thread via ``asyncio.to_thread``.
    """

    def __init__(self) -> None:
        self._pii = PIIDetector()

    async def introspect(self, location_ref: str) -> Source:
        """Introspect the database referenced by ``location_ref``.

        Args:
            location_ref: ``"dbclient://<client_id>"``.

        Returns:
            A ``Source`` with ``source_type="schema"`` whose ``source_id`` is
            the client id and whose tables come from ``_introspect_sync``.

        Raises:
            ValueError: if ``location_ref`` lacks the ``dbclient://`` prefix,
                has no client id, or names a DatabaseClient that is not found.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DatabaseIntrospector expects 'dbclient://...' location_ref, "
                f"got {location_ref!r}"
            )
        client_id = location_ref.removeprefix(_DBCLIENT_PREFIX)
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")

        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            # Copy everything needed off the ORM object while the session is
            # open, so nothing touches it after the session is closed.
            db_type = client.db_type
            client_name = client.name
            creds = decrypt_credentials_dict(client.credentials)

        logger.info(
            "introspecting db source",
            client_id=client_id,
            db_type=db_type,
            name=client_name,
        )

        # _introspect_sync is blocking, so it runs in a worker thread. It runs
        # after the app-DB session is released, so no pooled connection is
        # held for the whole (potentially long) introspection.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, db_type, creds
        )

        return Source(
            source_id=client_id,
            source_type="schema",
            name=client_name,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]:
        """Read the schema and profile every column (blocking).

        Failures are handled per object, so one bad object does not abort
        the whole run:

        * A table whose row count cannot be read is logged and skipped.
        * A column whose profiling fails is logged and skipped.
        """
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            schema = get_schema(engine)
            tables: list[Table] = []
            for table_name, cols in schema.items():
                try:
                    row_count = get_row_count(engine, table_name)
                except Exception as e:
                    logger.error(
                        "row_count failed; skipping table",
                        table=table_name,
                        error=str(e),
                    )
                    continue

                columns: list[Column] = []
                for col in cols:
                    try:
                        profile = profile_column(
                            engine,
                            table_name,
                            col["name"],
                            col.get("is_numeric", False),
                            row_count,
                            is_temporal=col.get("is_temporal", False),
                        )
                    except Exception as e:
                        logger.error(
                            "profile_column failed; skipping column",
                            table=table_name,
                            column=col["name"],
                            error=str(e),
                        )
                        continue
                    columns.append(self._to_column(table_name, col, profile))

                # Only keep FKs whose source column survived profiling;
                # otherwise ForeignKey.column_id would point at a column that
                # is absent from this table's `columns`.
                kept_column_ids = {c.column_id for c in columns}
                foreign_keys = [
                    fk
                    for fk in self._extract_foreign_keys(table_name, cols)
                    if fk.column_id in kept_column_ids
                ]

                tables.append(
                    Table(
                        table_id=_stable_id("t_", table_name),
                        name=table_name,
                        row_count=row_count,
                        columns=columns,
                        foreign_keys=foreign_keys,
                    )
                )
            return tables

    @staticmethod
    def _extract_foreign_keys(
        table_name: str, cols: list[dict[str, Any]]
    ) -> list[ForeignKey]:
        """Convert extractor ``foreign_key: 'target_table.target_col'`` strings
        into ForeignKey objects.

        IDs are derived with ``_stable_id``, the same scheme used for
        table_id / column_id elsewhere in this module. Entries that are
        missing or lack either part are ignored.
        """
        fks: list[ForeignKey] = []
        for col in cols:
            fk_str = col.get("foreign_key")
            if not fk_str:
                continue
            # Splits on the first "." only.
            target_table, _, target_col = fk_str.partition(".")
            if not target_table or not target_col:
                continue
            fks.append(
                ForeignKey(
                    column_id=_stable_id("c_", table_name, col["name"]),
                    target_table_id=_stable_id("t_", target_table),
                    target_column_id=_stable_id("c_", target_table, target_col),
                )
            )
        return fks

    def _to_column(
        self, table_name: str, col: dict[str, Any], profile: dict[str, Any]
    ) -> Column:
        """Build a catalog Column from an extractor column dict and its profile.

        Profile values are passed through ``_normalize`` so they survive JSON
        persistence. If the PII detector flags the column, every field that
        carries raw cell values is dropped: sample_values, top_values, min,
        max and median.
        """
        name = col["name"]
        sample_values: list[Any] | None = [
            _normalize(v) for v in (profile.get("sample_values") or [])
        ] or None

        # top_values arrive as (value, count) pairs; only the values are kept.
        top_values: list[Any] | None = [
            _normalize(v) for v, _cnt in (profile.get("top_values") or [])
        ] or None

        stats_kwargs: dict[str, Any] = {
            "min": _normalize(profile.get("min")),
            "max": _normalize(profile.get("max")),
            "mean": _normalize(profile.get("mean")),
            "median": _normalize(profile.get("median")),
            # Normalized like the other stats in case the profiler returns a
            # numpy integer.
            "distinct_count": _normalize(profile.get("distinct_count")),
            "top_values": top_values,
        }

        column = Column(
            column_id=_stable_id("c_", table_name, name),
            name=name,
            data_type=_map_sql_type(str(col["type"])),
            # NOTE(review): hard-coded; confirm whether the extractor's column
            # dict exposes nullability that should be used here.
            nullable=True,
            pii_flag=False,
            sample_values=sample_values,
            stats=ColumnStats(**stats_kwargs),
        )
        if self._pii.detect(column):
            # Besides sample_values, top_values/min/max/median are actual cell
            # values, so clearing only sample_values still leaked PII.
            redacted_stats = ColumnStats(
                **{
                    **stats_kwargs,
                    "min": None,
                    "max": None,
                    "median": None,
                    "top_values": None,
                }
            )
            return column.model_copy(
                update={
                    "pii_flag": True,
                    "sample_values": None,
                    "stats": redacted_stats,
                }
            )
        return column
|
|
|
|
# Module-level shared instance of the introspector.
database_introspector: DatabaseIntrospector = DatabaseIntrospector()
|
|