| """Tabular file schema introspection (Parquet / CSV / XLSX). |
| |
| Loads each file fully into pandas and records per-column dtype, stats, and up to 3 sample values. For XLSX, each sheet becomes a Table. |
| Files are expected to live in Azure Blob (location_ref like az_blob://{user_id}/{document_id}). |
| |
| Table.name convention (executor contract) |
| ----------------------------------------- |
| CSV / Parquet → Table.name = filename stem (e.g. "sales_data"). |
| Parquet blob was uploaded without a sheet suffix, so the |
| executor must call parquet_blob_name(uid, did, sheet_name=None). |
| XLSX → Table.name = sheet_name (e.g. "Sheet1"). |
| Executor calls parquet_blob_name(uid, did, table.name). |
| """ |
|
|
| import asyncio |
| import hashlib |
| from collections.abc import Callable, Coroutine |
| from datetime import UTC, datetime |
| from io import BytesIO |
| from pathlib import Path |
| from typing import Any |
|
|
| import pandas as pd |
|
|
| from src.middlewares.logging import get_logger |
|
|
| from ..models import Column, ColumnStats, DataType, Source, Table |
| from ..pii_detector import PIIDetector |
| from .base import BaseIntrospector |
|
|
# Module-level logger, named after this introspector.
logger = get_logger("tabular_introspector")

# URI scheme prefix that every location_ref accepted by introspect() must start with.
_AZ_BLOB_PREFIX = "az_blob://"
|
|
|
|
def _stable_id(prefix: str, *parts: str) -> str:
    """Return ``prefix`` followed by a 12-hex-char fingerprint of ``parts``.

    The parts are joined with "/" and UTF-8 encoded before hashing, so equal
    inputs always produce the same identifier. SHA-1 serves only as a
    fingerprint here (``usedforsecurity=False``), not as a security primitive.
    """
    joined = "/".join(parts)
    digest = hashlib.sha1(joined.encode("utf-8"), usedforsecurity=False)
    short = digest.hexdigest()[:12]
    return f"{prefix}{short}"
|
|
|
|
def _map_pandas_type(dtype: Any) -> DataType:
    """Map a pandas/numpy dtype to one of the coarse ``DataType`` labels.

    Matching is done on the lowercased string form of ``dtype``, so both dtype
    objects and dtype strings are accepted.

    Returns:
        "int", "decimal", "bool", "datetime", "date", or "string" (the
        fallback for anything unrecognised).
    """
    label = str(dtype).lower()
    # Interval dtypes print as e.g. "interval[int64, right]"; the word itself
    # contains "int" (and the subtype may contain "datetime"), so without this
    # guard they would be misreported as a scalar int/datetime column.
    if label.startswith("interval"):
        return "string"
    if "int" in label:
        return "int"
    if "float" in label or "decimal" in label:
        return "decimal"
    if "bool" in label:
        return "bool"
    # "datetime" must be tested before "date" because it contains it.
    if "datetime" in label:
        return "datetime"
    if "date" in label:
        return "date"
    return "string"
|
|
|
|
def _normalize(v: Any) -> Any:
    """Coerce non-JSON-native scalars to types that survive the jsonb round-trip.

    - ``None`` is returned unchanged.
    - numpy scalars are unwrapped with ``.item()``.
    - ``datetime``, ``date`` and ``time`` values become ISO-8601 strings.
    - Anything else is returned as-is.
    """
    # Local import: only these two names are needed, and only in this helper.
    from datetime import date, time

    if v is None:
        return None

    # numpy is optional here; when it is missing there are no numpy scalars
    # to unwrap, so the check is simply skipped.
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(v, np.generic):
        return v.item()

    # ``datetime`` (and therefore pandas Timestamp) is a ``date`` subclass, so
    # this one check covers timestamps as well as plain dates and times, which
    # object-dtype columns can carry.
    if isinstance(v, (date, time)):
        return v.isoformat()
    return v
|
|
|
|
class TabularIntrospector(BaseIntrospector):
    """Read column names, dtypes, and sample values from Parquet/CSV/XLSX.

    Heavy I/O dependencies (`fetch_doc`, `fetch_blob`) are injectable so unit
    tests can pass mocks without triggering Settings or DB construction.
    """

    def __init__(
        self,
        fetch_doc: Callable[[str], Coroutine[Any, Any, Any]] | None = None,
        fetch_blob: Callable[[str], Coroutine[Any, Any, bytes]] | None = None,
    ) -> None:
        """Create an introspector.

        Args:
            fetch_doc: Async callable mapping a document id to its document
                record (or ``None``). Defaults to the Postgres lookup.
            fetch_blob: Async callable mapping a blob name to the file bytes.
                Defaults to the Azure Blob download.
        """
        self._pii = PIIDetector()
        self._fetch_doc = fetch_doc or self._default_fetch_doc
        self._fetch_blob = fetch_blob or self._default_fetch_blob

    @staticmethod
    async def _default_fetch_doc(document_id: str) -> Any:
        """Load the document row with the given id, or ``None`` if absent."""
        # Imported lazily so constructing the class does not pull in the DB layer.
        from sqlalchemy import select

        from src.db.postgres.connection import AsyncSessionLocal
        from src.db.postgres.models import Document as DBDocument

        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(DBDocument).where(DBDocument.id == document_id)
            )
            return result.scalar_one_or_none()

    @staticmethod
    async def _default_fetch_blob(blob_name: str) -> bytes:
        """Download the named blob and return its raw bytes."""
        # Imported lazily so constructing the class does not pull in the storage layer.
        from src.storage.az_blob.az_blob import blob_storage

        return await blob_storage.download_file(blob_name)

    async def introspect(self, location_ref: str) -> Source:
        """Introspect the tabular file behind ``location_ref``.

        Args:
            location_ref: Reference of the form
                ``az_blob://{user_id}/{document_id}``.

        Returns:
            A ``Source`` of type "tabular" holding one ``Table`` per file
            (CSV/Parquet) or per sheet (XLSX).

        Raises:
            ValueError: If ``location_ref`` is malformed, the document is not
                found, or the document's file type is unsupported.
        """
        if not location_ref.startswith(_AZ_BLOB_PREFIX):
            raise ValueError(
                f"TabularIntrospector expects 'az_blob://...' location_ref, "
                f"got {location_ref!r}"
            )
        rest = location_ref[len(_AZ_BLOB_PREFIX):]
        # user_id is only validated for presence; the lookup uses document_id.
        user_id, _, document_id = rest.partition("/")
        if not user_id or not document_id:
            raise ValueError(
                f"location_ref must be 'az_blob://{{user_id}}/{{document_id}}', "
                f"got {location_ref!r}"
            )

        doc = await self._fetch_doc(document_id)
        if doc is None:
            raise ValueError(f"Document {document_id!r} not found")

        logger.info(
            "introspecting tabular source",
            document_id=document_id,
            file_type=doc.file_type,
            filename=doc.filename,
        )

        content = await self._fetch_blob(doc.blob_name)

        # pandas parsing is blocking, so it runs in a worker thread to keep
        # the event loop free.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, content, doc.file_type, doc.filename, document_id
        )

        return Source(
            source_id=document_id,
            source_type="tabular",
            name=doc.filename,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(
        self,
        content: bytes,
        file_type: str,
        filename: str,
        document_id: str,
    ) -> list[Table]:
        """Parse ``content`` with pandas and build the ``Table`` list.

        CSV and Parquet yield a single table named after the filename stem;
        XLSX yields one table per sheet, named after the sheet.

        Raises:
            ValueError: If ``file_type`` is not "csv", "xlsx" or "parquet".
        """
        if file_type == "xlsx":
            # sheet_name=None makes read_excel return every sheet, keyed by name.
            sheets: dict[str, pd.DataFrame] = pd.read_excel(
                BytesIO(content), sheet_name=None
            )
            return [
                self._build_table(df, document_id, sheet_name, sheet_name=sheet_name)
                for sheet_name, df in sheets.items()
            ]

        if file_type == "csv":
            df = pd.read_csv(BytesIO(content))
        elif file_type == "parquet":
            df = pd.read_parquet(BytesIO(content))
        else:
            raise ValueError(
                f"Unsupported file_type {file_type!r} for tabular introspection"
            )
        return [
            self._build_table(df, document_id, Path(filename).stem, sheet_name=None)
        ]

    def _build_table(
        self,
        df: pd.DataFrame,
        document_id: str,
        table_name: str,
        sheet_name: str | None,
    ) -> Table:
        """Build a ``Table`` (with its columns) from one DataFrame.

        The table id is derived from ``document_id`` and, when present,
        ``sheet_name``, so it is stable across runs.
        """
        id_parts = (document_id, sheet_name) if sheet_name else (document_id,)
        # Iterate with items() rather than df[label]: it yields one Series per
        # physical column even when labels repeat. Labels are stringified
        # because spreadsheet/Parquet headers may be ints or timestamps, and
        # the id hashing joins them as strings.
        columns = [
            self._to_column(series, document_id, sheet_name, str(label))
            for label, series in df.items()
        ]
        return Table(
            table_id=_stable_id("t_", *id_parts),
            name=table_name,
            row_count=len(df),
            columns=columns,
            foreign_keys=[],
        )

    def _to_column(
        self,
        series: pd.Series,
        document_id: str,
        sheet_name: str | None,
        col_name: str,
    ) -> Column:
        """Describe one column: dtype, nullability, samples and summary stats.

        Up to 3 non-null sample values are kept. ``top_values`` lists the
        distinct non-null values when there are at most 10 of them. min/max
        are computed for numeric and datetime columns, mean/median for numeric
        ones.

        If the PII detector flags the column, every field that carries raw
        cell values (``sample_values``, ``stats.top_values``, ``stats.min``,
        ``stats.max``) is cleared; aggregate stats are kept.
        """
        id_parts = (
            (document_id, sheet_name, col_name)
            if sheet_name
            else (document_id, col_name)
        )

        # Drop nulls once and reuse the result for samples and stats.
        non_null = series.dropna()
        has_values = len(non_null) > 0

        sample_values: list[Any] | None = [
            _normalize(v) for v in non_null.head(3).tolist()
        ] or None

        is_numeric = pd.api.types.is_numeric_dtype(series)
        is_dt = pd.api.types.is_datetime64_any_dtype(series)
        wants_range = (is_numeric or is_dt) and has_values
        wants_mean = is_numeric and has_values

        distinct_count = int(non_null.nunique())
        # drop_duplicates() keeps first-occurrence order like unique(), but
        # stays a Series, so tolist() returns the same scalar types used for
        # sample_values above (Timestamps for datetime columns).
        top_values = (
            [_normalize(v) for v in non_null.drop_duplicates().tolist()]
            if distinct_count <= 10
            else None
        )

        mean_value = float(non_null.mean()) if wants_mean else None
        median_value = float(non_null.median()) if wants_mean else None
        stats = ColumnStats(
            min=_normalize(non_null.min()) if wants_range else None,
            max=_normalize(non_null.max()) if wants_range else None,
            mean=mean_value,
            median=median_value,
            distinct_count=distinct_count,
            top_values=top_values,
        )

        column = Column(
            column_id=_stable_id("c_", *id_parts),
            name=col_name,
            data_type=_map_pandas_type(series.dtype),
            nullable=bool(series.isnull().any()),
            pii_flag=False,
            sample_values=sample_values,
            stats=stats,
        )
        # The detector sees the fully populated column; redaction happens after.
        if not self._pii.detect(column):
            return column

        # min/max/top_values are literal cell values, so they would leak the
        # same data that clearing sample_values is meant to hide.
        redacted_stats = ColumnStats(
            min=None,
            max=None,
            mean=mean_value,
            median=median_value,
            distinct_count=distinct_count,
            top_values=None,
        )
        return column.model_copy(
            update={
                "pii_flag": True,
                "sample_values": None,
                "stats": redacted_stats,
            }
        )
|
|
|
|
# Shared module-level instance; built with no arguments, so it uses the
# default document and blob fetchers.
tabular_introspector = TabularIntrospector()
|
|