Rifqi Hafizuddin
[KM-564] Edit column description in catalog to reduce token & ingestion time
"""Database schema introspection (Postgres / MySQL / Supabase).

Reads information_schema for tables/columns/types, samples ~100 rows per table
for `sample_values` and basic stats. Description fields are left empty — the
planner relies on names + samples + stats directly.

Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`,
`db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`)
to avoid reimplementation. A follow-up cleanup PR is planned to relocate these
helpers under `security/` and `pipeline/db_pipeline/`.
"""
| import asyncio | |
| import hashlib | |
| from datetime import UTC, datetime | |
| from decimal import Decimal | |
| from typing import Any | |
| from src.database_client.database_client_service import database_client_service | |
| from src.db.postgres.connection import AsyncSessionLocal | |
| from src.middlewares.logging import get_logger | |
| from src.pipeline.db_pipeline import db_pipeline_service | |
| from src.pipeline.db_pipeline.extractor import ( | |
| get_row_count, | |
| get_schema, | |
| profile_column, | |
| ) | |
| from src.utils.db_credential_encryption import decrypt_credentials_dict | |
| from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table | |
| from ..pii_detector import PIIDetector | |
| from .base import BaseIntrospector | |
# Scheme prefix of location refs that name a stored DatabaseClient by id,
# i.e. "dbclient://<client_id>".
_DBCLIENT_PREFIX: str = "dbclient://"

logger = get_logger("db_introspector")
def _stable_id(prefix: str, *parts: str) -> str:
    """Return ``prefix`` followed by a 12-hex-char digest of ``parts``.

    The parts are joined with ``"/"`` and hashed with SHA-1, so identical
    inputs yield identical IDs across runs and processes (unlike the built-in
    ``hash()``, which is salted per process). The ID depends only on the given
    parts; callers in this module pass table/column names.

    SHA-1 serves purely as an identifier here, not for security, hence
    ``usedforsecurity=False``.
    """
    payload = "/".join(parts).encode("utf-8")
    digest = hashlib.sha1(payload, usedforsecurity=False).hexdigest()
    return prefix + digest[:12]
def _map_sql_type(sql_type: str) -> DataType:
    """Map a stringified SQLAlchemy type to a Catalog DataType.

    Matches on substrings of the upper-cased SQLAlchemy type repr (e.g.
    'INTEGER', 'TIMESTAMP', 'BOOLEAN'). Conservative — unknown types fall
    back to "string" so the column is at least addressable.

    Non-scalar types (arrays, range types) and types whose names merely
    contain the letters "INT" (INTERVAL, geometry POINT) are not integers and
    map to "string" rather than to their element / look-alike type.
    """
    s = sql_type.upper()
    # "INTEGER[]", "ARRAY(...)", "INT4RANGE", "DATERANGE": containers, not
    # scalars of their element type, so numeric/date operators don't apply.
    if s.endswith("[]") or "ARRAY" in s or "RANGE" in s:
        return "string"
    # Plain substring matching on "INT" would also hit INTERVAL and POINT.
    if "INT" in s and "INTERVAL" not in s and "POINT" not in s:
        return "int"
    if any(tok in s for tok in ("FLOAT", "NUMERIC", "DECIMAL", "REAL", "DOUBLE")):
        return "decimal"
    if "BOOL" in s:
        return "bool"
    if "TIMESTAMP" in s or "DATETIME" in s:
        return "datetime"
    # Checked after DATETIME so that "DATETIME" is not classified as "date".
    if "DATE" in s:
        return "date"
    if "JSON" in s:
        return "json"
    return "string"
def _normalize(v: Any) -> Any:
    """Coerce a profiled value into something that survives a JSON/jsonb round-trip.

    - ``None`` passes through.
    - ``Decimal`` → ``float``; NaN / sNaN / ±Infinity → ``None``.
    - NumPy scalars → the equivalent Python scalar, which is then normalized
      again (e.g. a NaN float or a ``datetime`` from ``datetime64``).
    - ``float`` NaN / ±inf → ``None``: JSON (and Postgres jsonb) cannot
      represent them.
    - ``datetime`` / ``date`` / ``time`` → ISO-8601 string; pandas ``NaT``
      (a missing timestamp) → ``None``.
    - Anything else is returned unchanged.
    """
    import math
    from datetime import date, time

    if v is None:
        return None
    if isinstance(v, Decimal):
        # sNaN cannot even be converted to float, so test finiteness first.
        if not v.is_finite():
            return None
        v = float(v)  # may still overflow to ±inf; handled by the float branch
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(v, np.generic):
        # .item() yields a plain Python scalar that may itself need coercion.
        return _normalize(v.item())
    if isinstance(v, float):
        return v if math.isfinite(v) else None
    # `datetime` subclasses `date`, so this branch covers timestamps too.
    if isinstance(v, (date, time)):
        # pandas NaT is a datetime subclass that, like NaN, is unequal to itself.
        return None if v != v else v.isoformat()
    return v
class DatabaseIntrospector(BaseIntrospector):
    """Connect to a user DB, read its schema and profile every column.

    Each column is profiled (sample values + basic stats) through the Phase 1
    extractor. Columns flagged by ``PIIDetector`` have their raw values
    stripped before they enter the catalog.
    """

    def __init__(self) -> None:
        self._pii = PIIDetector()

    async def introspect(self, location_ref: str) -> Source:
        """Build a ``Source`` for the database referenced by ``location_ref``.

        Args:
            location_ref: ``"dbclient://<client_id>"``.

        Returns:
            A ``Source`` with ``source_type="schema"`` whose tables were read
            from the live database.

        Raises:
            ValueError: if ``location_ref`` is malformed or no DatabaseClient
                exists for the client id.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DatabaseIntrospector expects 'dbclient://...' location_ref, "
                f"got {location_ref!r}"
            )
        client_id = location_ref[len(_DBCLIENT_PREFIX):]
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")

        # Keep our own session open only long enough to load the client row;
        # the introspection below can be slow and should not keep the session
        # (and any connection it checked out) alive meanwhile.
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            db_type = client.db_type
            client_name = client.name
            creds = decrypt_credentials_dict(client.credentials)

        logger.info(
            "introspecting db source",
            client_id=client_id,
            db_type=db_type,
            name=client_name,
        )
        # SQLAlchemy inspect() + pandas read_sql are synchronous — run in a
        # threadpool so the event loop stays free.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, db_type, creds
        )
        return Source(
            source_id=client_id,
            source_type="schema",
            name=client_name,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]:
        """Read the schema and build a ``Table`` for every readable table.

        Runs in a worker thread (see ``introspect``). Tables and columns
        that fail are logged and skipped by ``_introspect_table``.
        """
        tables: list[Table] = []
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            schema = get_schema(engine)
            for table_name, cols in schema.items():
                table = self._introspect_table(engine, table_name, cols)
                if table is not None:
                    tables.append(table)
        return tables

    def _introspect_table(
        self, engine: Any, table_name: str, cols: list[dict[str, Any]]
    ) -> Table | None:
        """Profile one table; return None when its row count cannot be read.

        Columns whose profiling raises are logged and omitted (best effort).
        """
        try:
            row_count = get_row_count(engine, table_name)
        except Exception as e:
            logger.error(
                "row_count failed; skipping table",
                table=table_name,
                error=str(e),
            )
            return None

        columns: list[Column] = []
        profiled_cols: list[dict[str, Any]] = []
        for col in cols:
            try:
                profile = profile_column(
                    engine,
                    table_name,
                    col["name"],
                    col.get("is_numeric", False),
                    row_count,
                    is_temporal=col.get("is_temporal", False),
                )
            except Exception as e:
                logger.error(
                    "profile_column failed; skipping column",
                    table=table_name,
                    column=col["name"],
                    error=str(e),
                )
                continue
            columns.append(self._to_column(table_name, col, profile))
            profiled_cols.append(col)

        return Table(
            table_id=_stable_id("t_", table_name),
            name=table_name,
            row_count=row_count,
            columns=columns,
            # Only columns that made it into `columns`, so no ForeignKey
            # references a column_id that is absent from this table.
            foreign_keys=self._extract_foreign_keys(table_name, profiled_cols),
        )

    @staticmethod
    def _extract_foreign_keys(
        table_name: str, cols: list[dict[str, Any]]
    ) -> list[ForeignKey]:
        """Convert extractor's `foreign_key: 'target_table.target_col'` strings
        into ForeignKey objects.

        IDs are derived with ``_stable_id`` from names, the same scheme used
        for table_id / column_id. Malformed references (missing table or
        column part) are skipped.
        """
        fks: list[ForeignKey] = []
        for col in cols:
            fk_str = col.get("foreign_key")
            if not fk_str:
                continue
            target_table, _, target_col = fk_str.partition(".")
            if not target_table or not target_col:
                continue
            fks.append(
                ForeignKey(
                    column_id=_stable_id("c_", table_name, col["name"]),
                    target_table_id=_stable_id("t_", target_table),
                    target_column_id=_stable_id("c_", target_table, target_col),
                )
            )
        return fks

    def _to_column(
        self, table_name: str, col: dict[str, Any], profile: dict[str, Any]
    ) -> Column:
        """Build a catalog ``Column`` from an extractor column dict + profile.

        All profiled values go through ``_normalize``. When ``PIIDetector``
        flags the column, ``pii_flag`` is set and every raw value is dropped:
        the sample values and the value-bearing stats (min, max, median,
        top_values).
        """
        name = col["name"]
        sample_values: list[Any] | None = [
            _normalize(v) for v in (profile.get("sample_values") or [])
        ] or None
        # top_values arrive as (value, count) pairs; only the values are kept.
        top_values: list[Any] | None = [
            _normalize(v) for v, _cnt in (profile.get("top_values") or [])
        ] or None
        column = Column(
            column_id=_stable_id("c_", table_name, name),
            name=name,
            data_type=_map_sql_type(str(col["type"])),
            nullable=True,  # nullable not surfaced by extractor; default permissive
            pii_flag=False,
            sample_values=sample_values,
            stats=ColumnStats(
                min=_normalize(profile.get("min")),
                max=_normalize(profile.get("max")),
                mean=_normalize(profile.get("mean")),
                median=_normalize(profile.get("median")),
                # Normalized like the other stats (e.g. NumPy ints → int).
                distinct_count=_normalize(profile.get("distinct_count")),
                top_values=top_values,
            ),
        )
        if not self._pii.detect(column):
            return column
        # min/max/median/top_values are drawn from the same column data as
        # the samples, so they would leak the PII if left in place.
        redacted_stats = column.stats.model_copy(
            update={"min": None, "max": None, "median": None, "top_values": None}
        )
        return column.model_copy(
            update={"pii_flag": True, "sample_values": None, "stats": redacted_stats}
        )
# Module-level introspector instance; its only state is the PIIDetector.
database_introspector: DatabaseIntrospector = DatabaseIntrospector()