ishaq101's picture
feat/Catalog Retrieval System (#1)
6bff5d9
"""Database schema introspection (Postgres / MySQL / Supabase).
Reads information_schema for tables/columns/types, samples ~100 rows per table
for `sample_values` and basic stats. Description fields are left empty —
the planner relies on names + samples + stats directly.
Reuses Phase 1 utilities (`database_client_service`, `db_credential_encryption`,
`db_pipeline_service.engine_scope`, `extractor.get_schema/profile_column/get_row_count`)
to avoid reimplementation. The cleanup PR will move those into `security/` and
`pipeline/db_pipeline/` respectively.
"""
import asyncio
import hashlib
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from src.database_client.database_client_service import database_client_service
from src.db.postgres.connection import AsyncSessionLocal
from src.middlewares.logging import get_logger
from src.pipeline.db_pipeline import db_pipeline_service
from src.pipeline.db_pipeline.extractor import (
get_row_count,
get_schema,
profile_column,
)
from src.utils.db_credential_encryption import decrypt_credentials_dict
from ..models import Column, ColumnStats, DataType, ForeignKey, Source, Table
from ..pii_detector import PIIDetector
from .base import BaseIntrospector
# Module logger (structured: call sites pass context as keyword arguments).
logger = get_logger("db_introspector")
# Scheme every location_ref accepted by DatabaseIntrospector must start with;
# the remainder of the string is the DatabaseClient id.
_DBCLIENT_PREFIX = "dbclient://"
def _stable_id(prefix: str, *parts: str) -> str:
    """Return ``prefix`` followed by a 12-hex-char digest of ``parts``.

    The parts are joined with "/" and hashed with SHA-1. The result is a pure
    function of the names passed in, so re-introspecting the same table/column
    yields the same ID, which keeps cached IRs valid. Renaming the table/column
    yields a *new* ID.

    Because of the "/" join, parts that themselves contain "/" can collide,
    e.g. ("a/b", "c") and ("a", "b/c") map to the same ID.

    SHA-1 is used only as a fingerprint, not for security
    (``usedforsecurity=False``).
    """
    joined = "/".join(parts)
    digest = hashlib.sha1(joined.encode("utf-8"), usedforsecurity=False)
    return prefix + digest.hexdigest()[:12]
def _map_sql_type(sql_type: str) -> DataType:
    """Map a stringified SQLAlchemy type to a Catalog DataType.

    Matches on substrings of the upper-cased SQLAlchemy type repr
    (e.g. 'INTEGER', 'NUMERIC(10, 2)', 'TIMESTAMP', 'BOOLEAN'). Conservative:
    unknown types fall back to "string" so the column is at least addressable.

    Args:
        sql_type: ``str()`` of a SQLAlchemy column type.

    Returns:
        One of "int", "decimal", "bool", "datetime", "date", "json", "string".
    """
    s = sql_type.upper()
    # These type names merely *contain* "INT" but are not integers. Without
    # this guard, INTERVAL (a duration) and geometric POINT / MULTIPOINT would
    # be reported as "int".
    if "INTERVAL" in s or "POINT" in s:
        return "string"
    if "INT" in s:
        return "int"
    if any(k in s for k in ("FLOAT", "NUMERIC", "DECIMAL", "REAL", "DOUBLE")):
        return "decimal"
    if "BOOL" in s:
        return "bool"
    # Must precede the DATE check, since "DATETIME" also contains "DATE".
    if "TIMESTAMP" in s or "DATETIME" in s:
        return "datetime"
    if "DATE" in s:
        return "date"
    if "JSON" in s:
        return "json"
    return "string"
def _normalize(v: Any) -> Any:
    """Coerce a profiled scalar into a value that survives a JSON/jsonb round-trip.

    Conversions:
      - ``Decimal`` -> ``float``.
      - numpy scalars (``np.generic``) -> the equivalent Python scalar
        (numpy is optional; this step is skipped if it is not installed).
      - ``date`` / ``datetime`` / ``time`` -> ISO-8601 string.
      - NaN and +/-inf (float, numpy or Decimal) -> ``None``. JSON has no
        spelling for them: Python's ``json`` module emits a non-standard
        ``NaN`` token by default, which strict parsers such as Postgres jsonb
        reject.

    Anything else is returned unchanged.
    """
    import math
    from datetime import date, time

    if v is None:
        return None
    if isinstance(v, Decimal):
        # Check before float(): float(Decimal("sNaN")) raises ValueError.
        if not v.is_finite():
            return None
        v = float(v)
    else:
        try:
            import numpy as np
        except ImportError:
            np = None
        if np is not None and isinstance(v, np.generic):
            # .item() may yield a float NaN or a date/datetime, so fall
            # through to the checks below instead of returning immediately.
            v = v.item()
    if isinstance(v, float) and not math.isfinite(v):
        return None
    # datetime is a subclass of date, so this branch covers both.
    if isinstance(v, (date, time)):
        return v.isoformat()
    return v
class DatabaseIntrospector(BaseIntrospector):
    """Connect to a user DB, read its schema, and profile every column.

    Resolves a ``dbclient://<client_id>`` reference to a stored
    ``DatabaseClient`` and decrypts its credentials. It then uses the Phase 1
    extractor helpers (``get_schema`` / ``get_row_count`` /
    ``profile_column``) to build catalog ``Table`` / ``Column`` objects.
    A failure on a single table or column is logged and that item is skipped;
    it does not abort the whole introspection.
    """

    def __init__(self) -> None:
        # One detector, reused for every column this instance builds.
        self._pii = PIIDetector()

    async def introspect(self, location_ref: str) -> Source:
        """Build a ``Source`` for the database client named by *location_ref*.

        Args:
            location_ref: ``"dbclient://<client_id>"``.

        Returns:
            A ``Source`` with ``source_type="schema"``, ``source_id`` equal to
            the client id, and tables read from the live database.

        Raises:
            ValueError: if the reference has the wrong scheme, has an empty
                client_id, or names a client that does not exist.
        """
        if not location_ref.startswith(_DBCLIENT_PREFIX):
            raise ValueError(
                f"DatabaseIntrospector expects 'dbclient://...' location_ref, "
                f"got {location_ref!r}"
            )
        client_id = location_ref.removeprefix(_DBCLIENT_PREFIX)
        if not client_id:
            raise ValueError("location_ref is missing client_id after 'dbclient://'")

        # The app-DB session is scoped to the lookup + credential read only, so
        # no pooled connection is held during the slow remote introspection.
        async with AsyncSessionLocal() as session:
            client = await database_client_service.get(session, client_id)
            if client is None:
                raise ValueError(f"DatabaseClient {client_id!r} not found")
            creds = decrypt_credentials_dict(client.credentials)

        logger.info(
            "introspecting db source",
            client_id=client_id,
            db_type=client.db_type,
            name=client.name,
        )
        # SQLAlchemy inspect() + pandas read_sql are synchronous — run in a
        # threadpool so the event loop stays free.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, client.db_type, creds
        )
        return Source(
            source_id=client_id,
            source_type="schema",
            name=client.name,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(self, db_type: str, creds: dict) -> list[Table]:
        """Read the schema and build a ``Table`` for every table that can be
        profiled. Blocking; intended to run off the event loop."""
        tables: list[Table] = []
        with db_pipeline_service.engine_scope(db_type, creds) as engine:
            schema = get_schema(engine)
            for table_name, cols in schema.items():
                table = self._build_table(engine, table_name, cols)
                if table is not None:
                    tables.append(table)
        return tables

    def _build_table(
        self, engine: Any, table_name: str, cols: list[dict[str, Any]]
    ) -> Table | None:
        """Profile one table's columns; return None if its row count fails."""
        try:
            row_count = get_row_count(engine, table_name)
        except Exception as e:
            logger.error(
                "row_count failed; skipping table",
                table=table_name,
                error=str(e),
            )
            return None

        columns: list[Column] = []
        # Schema entries of the columns that made it into `columns`.
        profiled_cols: list[dict[str, Any]] = []
        for col in cols:
            try:
                profile = profile_column(
                    engine,
                    table_name,
                    col["name"],
                    col.get("is_numeric", False),
                    row_count,
                    is_temporal=col.get("is_temporal", False),
                )
            except Exception as e:
                logger.error(
                    "profile_column failed; skipping column",
                    table=table_name,
                    column=col["name"],
                    error=str(e),
                )
                continue
            columns.append(self._to_column(table_name, col, profile))
            profiled_cols.append(col)

        return Table(
            table_id=_stable_id("t_", table_name),
            name=table_name,
            row_count=row_count,
            columns=columns,
            # Only profiled columns: a skipped column has no Column entry, so
            # an FK from it would reference a column_id absent from the table.
            foreign_keys=self._extract_foreign_keys(table_name, profiled_cols),
        )

    @staticmethod
    def _extract_foreign_keys(
        table_name: str, cols: list[dict[str, Any]]
    ) -> list[ForeignKey]:
        """Convert extractor's `foreign_key: 'target_table.target_col'` strings
        into ForeignKey objects.

        IDs come from ``_stable_id`` over the names, the same scheme used for
        table_id / column_id. Malformed references (no "." or an empty side)
        are skipped.
        """
        fks: list[ForeignKey] = []
        for col in cols:
            fk_str = col.get("foreign_key")
            if not fk_str:
                continue
            # The column is always the last dotted segment, so rpartition keeps
            # it intact for a qualified 'schema.table.col'. partition would
            # give table='schema', col='table.col'.
            # NOTE(review): whether a qualified target_table then matches a
            # table_id depends on whether get_schema keys are qualified.
            target_table, _, target_col = fk_str.rpartition(".")
            if not target_table or not target_col:
                continue
            fks.append(
                ForeignKey(
                    column_id=_stable_id("c_", table_name, col["name"]),
                    target_table_id=_stable_id("t_", target_table),
                    target_column_id=_stable_id("c_", target_table, target_col),
                )
            )
        return fks

    def _to_column(
        self, table_name: str, col: dict[str, Any], profile: dict[str, Any]
    ) -> Column:
        """Build a catalog ``Column`` from a schema entry and its profile.

        If the PII detector flags the column, ``pii_flag`` is set and every
        field carrying literal cell values is cleared: ``sample_values`` and
        the stats ``min`` / ``max`` / ``median`` / ``top_values``. Only
        aggregate stats (``mean``, ``distinct_count``) are kept.
        """
        name = col["name"]
        sample_values: list[Any] | None = [
            _normalize(v) for v in (profile.get("sample_values") or [])
        ] or None
        top_values: list[Any] | None = [
            _normalize(v) for v, _cnt in (profile.get("top_values") or [])
        ] or None
        column = Column(
            column_id=_stable_id("c_", table_name, name),
            name=name,
            data_type=_map_sql_type(str(col["type"])),
            nullable=True,  # nullable not surfaced by extractor; default permissive
            pii_flag=False,
            sample_values=sample_values,
            stats=ColumnStats(
                min=_normalize(profile.get("min")),
                max=_normalize(profile.get("max")),
                mean=_normalize(profile.get("mean")),
                median=_normalize(profile.get("median")),
                distinct_count=profile.get("distinct_count"),
                top_values=top_values,
            ),
        )
        if not self._pii.detect(column):
            return column
        # min/max/median/top_values are real record values (e.g. actual
        # emails). Clearing only sample_values would still leak them.
        scrubbed_stats = column.stats.model_copy(
            update={"min": None, "max": None, "median": None, "top_values": None}
        )
        return column.model_copy(
            update={"pii_flag": True, "sample_values": None, "stats": scrubbed_stats}
        )
# Module-level shared instance, constructed at import time (which also builds
# its PIIDetector). Presumably the entry point other modules import — verify.
database_introspector = DatabaseIntrospector()