File size: 8,044 Bytes
6bff5d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""Tabular file schema introspection (Parquet / CSV / XLSX).

Loads the whole file into a DataFrame and records, per column, the dtype, up to
3 non-null sample values and summary stats. For XLSX, each sheet becomes a Table.
Files are expected to live in Azure Blob (location_ref like az_blob://{user_id}/{document_id}).

Table.name convention (executor contract)
-----------------------------------------
  CSV / Parquet  → Table.name = filename stem (e.g. "sales_data").
                   Parquet blob was uploaded without a sheet suffix, so the
                   executor must call parquet_blob_name(uid, did, sheet_name=None).
  XLSX           → Table.name = sheet_name (e.g. "Sheet1").
                   Executor calls parquet_blob_name(uid, did, table.name).
"""

import asyncio
import hashlib
from collections.abc import Callable, Coroutine
from datetime import UTC, datetime
from io import BytesIO
from pathlib import Path
from typing import Any

import pandas as pd

from src.middlewares.logging import get_logger

from ..models import Column, ColumnStats, DataType, Source, Table
from ..pii_detector import PIIDetector
from .base import BaseIntrospector

# Module-level logger, obtained from the project's logging middleware.
logger = get_logger("tabular_introspector")

# Scheme prefix every location_ref handled by TabularIntrospector.introspect
# must start with; the remainder is parsed as "{user_id}/{document_id}".
_AZ_BLOB_PREFIX = "az_blob://"


def _stable_id(prefix: str, *parts: str) -> str:
    """Return a deterministic id: *prefix* followed by 12 hex digits.

    The digits are the first 12 characters of the SHA-1 hex digest of the
    parts joined with "/" (UTF-8 encoded), so equal inputs always produce the
    same id. SHA-1 serves only as a fingerprint here, hence
    ``usedforsecurity=False``.
    """
    joined = "/".join(parts)
    digest = hashlib.sha1(joined.encode("utf-8"), usedforsecurity=False)
    fingerprint = digest.hexdigest()[:12]
    return f"{prefix}{fingerprint}"


def _map_pandas_type(dtype: Any) -> DataType:
    """Map a pandas dtype (or its name) to the schema's coarse data type.

    Matching is done by substring on the lower-cased ``str(dtype)``. Rules are
    tried in order and the first hit wins; anything unmatched is "string".
    """
    dtype_name = str(dtype).lower()
    # Order matters: "datetime" has to be tested before its substring "date".
    ordered_rules: tuple[tuple[tuple[str, ...], DataType], ...] = (
        (("int",), "int"),
        (("float", "decimal"), "decimal"),
        (("bool",), "bool"),
        (("datetime",), "datetime"),
        (("date",), "date"),
    )
    for markers, mapped in ordered_rules:
        if any(marker in dtype_name for marker in markers):
            return mapped
    return "string"


def _normalize(v: Any) -> Any:
    """Coerce non-JSON-native scalars to types that survive the jsonb round-trip.

    - ``None`` is returned unchanged.
    - NumPy scalars are unwrapped to whatever Python object ``.item()`` yields.
    - ``date`` / ``datetime`` values (including ones obtained by unwrapping a
      NumPy scalar) are rendered as ISO-8601 strings.
    - Every other value is returned as-is.
    """
    # Local import: the module only imports the ``datetime`` class, and ``date``
    # is needed to recognise plain calendar dates as well.
    import datetime as dt

    if v is None:
        return None
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        if isinstance(v, np.generic):
            # Unwrap, then fall through instead of returning: np.datetime64 can
            # unwrap to a date/datetime, which still needs ISO formatting.
            v = v.item()
    # ``datetime`` subclasses ``date``, so a single check covers both.
    if isinstance(v, dt.date):
        return v.isoformat()
    return v


class TabularIntrospector(BaseIntrospector):
    """Read column names, dtypes, sample values and basic stats from Parquet/CSV/XLSX.

    Heavy I/O dependencies (`fetch_doc`, `fetch_blob`) are injectable so unit
    tests can pass mocks without triggering Settings or DB construction.
    """

    def __init__(
        self,
        fetch_doc: Callable[[str], Coroutine[Any, Any, Any]] | None = None,
        fetch_blob: Callable[[str], Coroutine[Any, Any, bytes]] | None = None,
    ) -> None:
        """Create an introspector.

        Args:
            fetch_doc: Async callable mapping a document id to its document
                record (or ``None``). Defaults to ``_default_fetch_doc``.
            fetch_blob: Async callable mapping a blob name to the blob's bytes.
                Defaults to ``_default_fetch_blob``.
        """
        self._pii = PIIDetector()
        self._fetch_doc = fetch_doc or self._default_fetch_doc
        self._fetch_blob = fetch_blob or self._default_fetch_blob

    @staticmethod
    async def _default_fetch_doc(document_id: str) -> Any:
        """Load the document row with the given id, or ``None`` if there is none."""
        # Imported lazily so constructing the class does not pull in the DB layer.
        from sqlalchemy import select

        from src.db.postgres.connection import AsyncSessionLocal
        from src.db.postgres.models import Document as DBDocument

        async with AsyncSessionLocal() as session:
            result = await session.execute(
                select(DBDocument).where(DBDocument.id == document_id)
            )
            return result.scalar_one_or_none()

    @staticmethod
    async def _default_fetch_blob(blob_name: str) -> bytes:
        """Download the named blob through the project's blob storage client."""
        # Imported lazily so constructing the class does not pull in storage setup.
        from src.storage.az_blob.az_blob import blob_storage

        return await blob_storage.download_file(blob_name)

    async def introspect(self, location_ref: str) -> Source:
        """Introspect the tabular document referenced by *location_ref*.

        Args:
            location_ref: Reference of the form
                ``az_blob://{user_id}/{document_id}``.

        Returns:
            A ``Source`` of type "tabular" whose tables describe the file.

        Raises:
            ValueError: If *location_ref* is malformed, the document cannot be
                found, or its file type is not csv / xlsx / parquet.
        """
        if not location_ref.startswith(_AZ_BLOB_PREFIX):
            raise ValueError(
                f"TabularIntrospector expects 'az_blob://...' location_ref, "
                f"got {location_ref!r}"
            )
        rest = location_ref[len(_AZ_BLOB_PREFIX):]
        user_id, _, document_id = rest.partition("/")
        if not user_id or not document_id:
            raise ValueError(
                f"location_ref must be 'az_blob://{{user_id}}/{{document_id}}', "
                f"got {location_ref!r}"
            )

        doc = await self._fetch_doc(document_id)
        if doc is None:
            raise ValueError(f"Document {document_id!r} not found")

        logger.info(
            "introspecting tabular source",
            document_id=document_id,
            file_type=doc.file_type,
            filename=doc.filename,
        )

        content = await self._fetch_blob(doc.blob_name)

        # Parsing is synchronous pandas work; run it in a worker thread so the
        # event loop is not blocked.
        tables: list[Table] = await asyncio.to_thread(
            self._introspect_sync, content, doc.file_type, doc.filename, document_id
        )

        return Source(
            source_id=document_id,
            source_type="tabular",
            name=doc.filename,
            location_ref=location_ref,
            updated_at=datetime.now(UTC),
            tables=tables,
        )

    def _introspect_sync(
        self,
        content: bytes,
        file_type: str,
        filename: str,
        document_id: str,
    ) -> list[Table]:
        """Parse *content* according to *file_type* and build its tables.

        CSV and Parquet yield a single table named after the filename stem;
        XLSX yields one table per sheet, named after the sheet.

        Raises:
            ValueError: If *file_type* is not "csv", "xlsx" or "parquet".
        """
        buffer = BytesIO(content)
        if file_type == "xlsx":
            # sheet_name=None makes pandas return every sheet as {name: frame}.
            sheets: dict[str, pd.DataFrame] = pd.read_excel(buffer, sheet_name=None)
            return [
                self._build_table(df, document_id, sheet_name, sheet_name=sheet_name)
                for sheet_name, df in sheets.items()
            ]
        if file_type == "csv":
            df = pd.read_csv(buffer)
        elif file_type == "parquet":
            df = pd.read_parquet(buffer)
        else:
            raise ValueError(
                f"Unsupported file_type {file_type!r} for tabular introspection"
            )
        return [self._build_table(df, document_id, Path(filename).stem, sheet_name=None)]

    def _build_table(
        self,
        df: pd.DataFrame,
        document_id: str,
        table_name: str,
        sheet_name: str | None,
    ) -> Table:
        """Describe one DataFrame as a ``Table``.

        The table id is derived from the document id and, when given, the sheet
        name, so it is stable across runs.
        """
        id_parts = (document_id, sheet_name) if sheet_name else (document_id,)
        # Column labels are not guaranteed to be strings (e.g. numeric or date
        # headers in a spreadsheet); coerce them so id hashing and the Column
        # name always receive text. df.items() yields one Series per column.
        columns = [
            self._to_column(series, document_id, sheet_name, str(label))
            for label, series in df.items()
        ]
        return Table(
            table_id=_stable_id("t_", *id_parts),
            name=table_name,
            row_count=len(df),
            columns=columns,
            foreign_keys=[],
        )

    def _to_column(
        self,
        series: pd.Series,
        document_id: str,
        sheet_name: str | None,
        col_name: str,
    ) -> Column:
        """Describe one Series as a ``Column`` with samples, stats and a PII flag.

        - Up to 3 non-null sample values are recorded.
        - min/max are computed for numeric and datetime columns, mean/median
          for numeric ones, and only when the column has non-null values.
        - top_values lists the distinct non-null values when there are at most 10.
        - If the PII detector flags the column, every raw cell value
          (sample_values, min, max, top_values) is removed from the result.
        """
        id_parts = (
            (document_id, sheet_name, col_name) if sheet_name else (document_id, col_name)
        )

        non_null = series.dropna()
        sample_values: list[Any] | None = [
            _normalize(v) for v in non_null.head(3).tolist()
        ] or None

        is_numeric = pd.api.types.is_numeric_dtype(series)
        is_dt = pd.api.types.is_datetime64_any_dtype(series)
        distinct_count = int(series.nunique())
        top_values = (
            [_normalize(v) for v in non_null.unique().tolist()]
            if distinct_count <= 10
            else None
        )
        has_values = len(non_null) > 0
        wants_range = (is_numeric or is_dt) and has_values
        wants_mean = is_numeric and has_values
        mean_value = float(non_null.mean()) if wants_mean else None
        median_value = float(non_null.median()) if wants_mean else None
        stats = ColumnStats(
            min=_normalize(non_null.min()) if wants_range else None,
            max=_normalize(non_null.max()) if wants_range else None,
            mean=mean_value,
            median=median_value,
            distinct_count=distinct_count,
            top_values=top_values,
        )

        column = Column(
            column_id=_stable_id("c_", *id_parts),
            name=col_name,
            data_type=_map_pandas_type(series.dtype),
            nullable=bool(series.isnull().any()),
            pii_flag=False,
            sample_values=sample_values,
            stats=stats,
        )
        if self._pii.detect(column):
            # Dropping sample_values alone is not enough: min/max/top_values
            # hold raw cell values as well. Keep only the aggregate figures.
            redacted_stats = ColumnStats(
                min=None,
                max=None,
                mean=mean_value,
                median=median_value,
                distinct_count=distinct_count,
                top_values=None,
            )
            return column.model_copy(
                update={
                    "pii_flag": True,
                    "sample_values": None,
                    "stats": redacted_stats,
                }
            )
        return column


# Shared module-level instance, built with the default document/blob fetchers
# (no overrides are passed to the constructor).
tabular_introspector = TabularIntrospector()