"""Database schema introspection (Postgres / MySQL / Supabase). Reads information_schema for tables/columns/types, samples ~100 rows per table for `sample_values` and basic stats. Description fields are left empty — the planner relies on names + samples + stats directly. Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`, `db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`) to avoid reimplementation. The cleanup PR will move those into `security/` and `pipeline/db_pipeline/` respectively. """ import asyncio import hashlib from datetime import UTC, datetime from decimal import Decimal from typing import Any from src.database_client.database_client_service import database_client_service from src.db.postgres.connection import AsyncSessionLocal from src.middlewares.logging import get_logger from src.pipeline.db_pipeline import db_pipeline_service from src.pipeline.db_pipeline.extractor import ( get_row_count, get_schema, profile_column, ) from src.utils.db_credential_encryption import decrypt_credentials_dict from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table from ..pii_detector import PIIDetector from .base import BaseIntrospector logger = get_logger("db_introspector") _DBCLIENT_PREFIX = "dbclient://" def _stable_id(prefix: str, *parts: str) -> str: """Deterministic short ID from joined parts. Survives renames at the `name` field while preserving identity for cached IRs. Hash is non-cryptographic (identifier only). """ h = hashlib.sha1( "/".join(parts).encode("utf-8"), usedforsecurity=False ).hexdigest()[:12] return f"{prefix}{h}" def _map_sql_type(sql_type: str) -> DataType: """Map a stringified SQLAlchemy type to a Catalog DataType. Matches on substring of the SQLAlchemy type repr (e.g. 'INTEGER', 'TIMESTAMP', 'BOOLEAN'). Conservative — unknowns fall back to "string" so the column is at least addressable. """ s = sql_type.upper() if "INT" in s: return "int" if "FLOAT" in s or "NUMERIC" in s or "DECIMAL" in s or "REAL" in s or "DOUBLE" in s: return "decimal" if "BOOL" in s: return "bool" if "TIMESTAMP" in s or "DATETIME" in s: return "datetime" if "DATE" in s: return "date" if "JSON" in s: return "json" return "string" def _normalize(v: Any) -> Any: """Coerce non-JSON-native scalars (Decimal, numpy, datetime) to types that survive the jsonb round-trip when the catalog is persisted. """ if v is None: return None if isinstance(v, Decimal): return float(v) try: import numpy as np if isinstance(v, np.generic): return v.item() except ImportError: pass if isinstance(v, datetime): return v.isoformat() return v class DatabaseIntrospector(BaseIntrospector): """Connect to user DB → read information_schema → sample 100 rows/table.""" def __init__(self) -> None: self._pii = PIIDetector() async def introspect(self, location_ref: str) -> Source: if not location_ref.startswith(_DBCLIENT_PREFIX): raise ValueError( f"DatabaseIntrospector expects 'dbclient://...' location_ref, " f"got {location_ref!r}" ) client_id = location_ref[len(_DBCLIENT_PREFIX):] if not client_id: raise ValueError("location_ref is missing client_id after 'dbclient://'") async with AsyncSessionLocal() as session: client = await database_client_service.get(session, client_id) if client is None: raise ValueError(f"DatabaseClient {client_id!r} not found") creds = decrypt_credentials_dict(client.credentials) logger.info( "introspecting db source", client_id=client_id, db_type=client.db_type, name=client.name, ) # SQLAlchemy inspect() + pandas read_sql are synchronous — run in a # threadpool so the event loop stays free. tables: list[Table] = await asyncio.to_thread( self._introspect_sync, client.db_type, creds ) return Source( source_id=client_id, source_type="schema", name=client.name, location_ref=location_ref, updated_at=datetime.now(UTC), tables=tables, ) def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]: with db_pipeline_service.engine_scope(db_type, creds) as engine: schema = get_schema(engine) tables: list[Table] = [] for table_name, cols in schema.items(): try: row_count = get_row_count(engine, table_name) except Exception as e: logger.error( "row_count failed; skipping table", table=table_name, error=str(e), ) continue columns: list[Column] = [] for col in cols: try: profile = profile_column( engine, table_name, col["name"], col.get("is_numeric", False), row_count, is_temporal=col.get("is_temporal", False), ) except Exception as e: logger.error( "profile_column failed; skipping column", table=table_name, column=col["name"], error=str(e), ) continue columns.append(self._to_column(table_name, col, profile)) foreign_keys = self._extract_foreign_keys(table_name, cols) tables.append( Table( table_id=_stable_id("t_", table_name), name=table_name, row_count=row_count, columns=columns, foreign_keys=foreign_keys, ) ) return tables @staticmethod def _extract_foreign_keys( table_name: str, cols: list[dict[str, Any]] ) -> list[ForeignKey]: """Convert extractor's `foreign_key: 'target_table.target_col'` strings into ForeignKey objects with stable IDs (derived deterministically from names — same scheme used to generate table_id / column_id elsewhere). """ fks: list[ForeignKey] = [] for col in cols: fk_str = col.get("foreign_key") if not fk_str: continue target_table, _, target_col = fk_str.partition(".") if not target_table or not target_col: continue fks.append( ForeignKey( column_id=_stable_id("c_", table_name, col["name"]), target_table_id=_stable_id("t_", target_table), target_column_id=_stable_id("c_", target_table, target_col), ) ) return fks def _to_column( self, table_name: str, col: dict[str, Any], profile: dict[str, Any] ) -> Column: name = col["name"] sample_values: list[Any] | None = [ _normalize(v) for v in (profile.get("sample_values") or []) ] or None top_raw = profile.get("top_values") or [] top_values: list[Any] | None = [ _normalize(v) for v, _cnt in top_raw ] or None column = Column( column_id=_stable_id("c_", table_name, name), name=name, data_type=_map_sql_type(str(col["type"])), nullable=True, # nullable not surfaced by extractor; default permissive pii_flag=False, sample_values=sample_values, stats=ColumnStats( min=_normalize(profile.get("min")), max=_normalize(profile.get("max")), mean=_normalize(profile.get("mean")), median=_normalize(profile.get("median")), distinct_count=profile.get("distinct_count"), top_values=top_values, ), ) if self._pii.detect(column): return column.model_copy(update={"pii_flag": True, "sample_values": None}) return column database_introspector = DatabaseIntrospector()