| """Schema introspection and per-column profiling for a user's database. |
| |
| Identifiers (table/column names) are quoted via the engine's dialect preparer, |
| which handles reserved words, mixed case, and embedded quotes correctly across |
| dialects. Values used in SQL come from SQLAlchemy inspection of the DB itself, |
| not user input. |
| """ |
|
|
| from typing import Optional |
|
|
| import pandas as pd |
| from sqlalchemy import Float, Integer, Numeric, inspect |
| from sqlalchemy.engine import Engine |
|
|
| from src.middlewares.logging import get_logger |
|
|
logger = get_logger("db_extractor")


# profile_column reports the most frequent values of a column only when its
# distinct/row ratio is greater than 0 and at most this fraction.
TOP_VALUES_THRESHOLD = 0.05


# Dialects where `PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY col)` works as a
# plain aggregate, so profile_column can put it in the same SELECT as
# COUNT/MIN/MAX/AVG. SQL Server ("mssql") is deliberately NOT listed: there
# PERCENTILE_CONT is an analytic function that requires an OVER clause, and the
# aggregate form makes the whole stats query fail for every numeric column.
_MEDIAN_DIALECTS = frozenset({"postgresql", "snowflake"})
|
|
|
|
def _supports_median(engine: Engine) -> bool:
    """Report whether ``engine``'s dialect name is listed in ``_MEDIAN_DIALECTS``.

    profile_column consults this before adding a PERCENTILE_CONT median to its
    aggregate query.
    """
    dialect_name = engine.dialect.name
    return dialect_name in _MEDIAN_DIALECTS
|
|
|
|
def _head_query(
    engine: Engine,
    select_clause: str,
    from_clause: str,
    n: int,
    order_by: str = "",
) -> str:
    """Build a "first n rows" SELECT in the syntax of ``engine``'s dialect.

    SQL Server gets ``SELECT TOP n ...``; every other dialect gets a trailing
    ``LIMIT n``. ``order_by`` (e.g. ``"ORDER BY x DESC"``) is inserted verbatim
    right after the FROM clause, and surrounding whitespace is stripped from
    the final string.
    """
    body = f"{select_clause} FROM {from_clause} {order_by}"
    if engine.dialect.name == "mssql":
        sql = f"SELECT TOP {n} {body}"
    else:
        sql = f"SELECT {body} LIMIT {n}"
    return sql.strip()
|
|
|
|
def _qi(engine: Engine, name: str) -> str:
    """Quote ``name`` as an identifier with the engine's dialect preparer.

    A dotted name is treated as ``schema.table``: it is split at the FIRST dot
    and each half is quoted on its own. This means an unqualified identifier
    that itself contains a dot is split too.
    """
    quote = engine.dialect.identifier_preparer.quote
    head, sep, tail = name.partition(".")
    if not sep:
        return quote(name)
    return f"{quote(head)}.{quote(tail)}"
|
|
|
|
def get_schema(
    engine: Engine, exclude_tables: Optional[frozenset[str]] = None
) -> dict[str, list[dict]]:
    """Reflect every table the inspector lists into a column-description map.

    Args:
        engine: SQLAlchemy engine to inspect.
        exclude_tables: Table names to skip; ``None`` skips nothing.

    Returns:
        ``{table_name: [column_info, ...]}``, where each ``column_info`` has:
        ``name``; ``type`` (``str()`` of the reflected type); ``is_numeric``
        (type is an Integer/Numeric/Float); ``is_primary_key``; and
        ``foreign_key`` (``"referred_table.referred_column"`` or ``None``).
    """
    skip = exclude_tables or frozenset()
    inspector = inspect(engine)
    result = {}

    for table_name in inspector.get_table_names():
        if table_name in skip:
            continue

        pk = inspector.get_pk_constraint(table_name)
        primary = set(pk["constrained_columns"]) if pk else set()

        # Local column -> "table.column" it references. If a column appears in
        # several FKs, the last one reported wins.
        references = {
            local: f"{fk['referred_table']}.{remote}"
            for fk in inspector.get_foreign_keys(table_name)
            for local, remote in zip(fk["constrained_columns"], fk["referred_columns"])
        }

        result[table_name] = [
            {
                "name": column["name"],
                "type": str(column["type"]),
                "is_numeric": isinstance(column["type"], (Integer, Numeric, Float)),
                "is_primary_key": column["name"] in primary,
                "foreign_key": references.get(column["name"]),
            }
            for column in inspector.get_columns(table_name)
        ]

    logger.info("extracted schema", table_count=len(result))
    return result
|
|
|
|
def get_row_count(engine: Engine, table_name: str) -> int:
    """Return ``SELECT COUNT(*)`` for ``table_name`` (quoted via ``_qi``) as an int."""
    query = f"SELECT COUNT(*) FROM {_qi(engine, table_name)}"
    frame = pd.read_sql(query, engine)
    return int(frame.iloc[0, 0])
|
|
|
|
def profile_column(
    engine: Engine,
    table_name: str,
    col_name: str,
    is_numeric: bool,
    row_count: int,
) -> dict:
    """Profile one column: null/distinct counts, numeric stats, top and sample values.

    Args:
        engine: SQLAlchemy engine for the database being profiled.
        table_name: Table to read. Quoted with ``_qi``, so a dotted name is
            treated as ``schema.table``.
        col_name: Column to profile. Quoted as a single identifier; column
            names are never schema-qualified, so a dot inside the name is kept.
        is_numeric: Also compute ``min``/``max``/``mean``, plus ``median`` when
            ``_supports_median`` is true for the engine.
        row_count: Precomputed ``COUNT(*)`` of the table. Used as the
            denominator of ``distinct_ratio``; ``<= 0`` returns an empty
            profile without running any query.

    Returns:
        Dict with ``null_count``, ``distinct_count``, ``distinct_ratio``
        (rounded to 4 places) and ``sample_values`` (up to 5 values, NULLs not
        filtered). Numeric columns add ``min``, ``max``, ``mean`` and possibly
        ``median``. Columns whose distinct ratio lies in
        ``(0, TOP_VALUES_THRESHOLD]`` add ``top_values``: up to 10
        ``(value, count)`` pairs, most frequent first.
    """
    if row_count <= 0:
        return {
            "null_count": 0,
            "distinct_count": 0,
            "distinct_ratio": 0.0,
            "sample_values": [],
        }

    qt = _qi(engine, table_name)
    # Quote the column as ONE identifier. _qi would split "a.b" into "a"."b",
    # which the database reads as column b of a table named a.
    qc = engine.dialect.identifier_preparer.quote(col_name)
    want_median = is_numeric and _supports_median(engine)

    # All scalar stats are gathered in a single aggregate query.
    select_cols = [
        f"COUNT(*) - COUNT({qc}) AS nulls",
        f"COUNT(DISTINCT {qc}) AS distincts",
    ]
    if is_numeric:
        # SQL Server's AVG returns the argument's type for integer input,
        # truncating the mean and erroring on overflow; average as FLOAT there.
        avg_arg = f"CAST({qc} AS FLOAT)" if engine.dialect.name == "mssql" else qc
        select_cols += [
            f"MIN({qc}) AS min_val",
            f"MAX({qc}) AS max_val",
            f"AVG({avg_arg}) AS mean_val",
        ]
        if want_median:
            select_cols.append(
                f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY {qc}) AS median_val"
            )
    stats_df = pd.read_sql(f"SELECT {', '.join(select_cols)} FROM {qt}", engine)
    # to_dict("records") keeps each column's own dtype. Taking the row with
    # .iloc[0] coerces it to one common dtype, so an integer MIN/MAX came back
    # as a float whenever the AVG column was a float.
    stats = stats_df.to_dict(orient="records")[0]

    distinct_count = int(stats["distincts"])
    distinct_ratio = distinct_count / row_count  # row_count > 0 here

    profile = {
        "null_count": int(stats["nulls"]),
        "distinct_count": distinct_count,
        "distinct_ratio": round(distinct_ratio, 4),
    }
    if is_numeric:
        profile["min"] = stats["min_val"]
        profile["max"] = stats["max_val"]
        profile["mean"] = stats["mean_val"]
        if want_median:
            profile["median"] = stats["median_val"]

    # Low-cardinality columns: report their most frequent values.
    if 0 < distinct_ratio <= TOP_VALUES_THRESHOLD:
        # Order by the aggregate rather than the "cnt" alias, and read the
        # result by position: a column literally named cnt would make the
        # alias ambiguous (and duplicate the DataFrame column label).
        top_sql = _head_query(
            engine,
            select_clause=f"{qc}, COUNT(*) AS cnt",
            from_clause=f"{qt} GROUP BY {qc}",
            n=10,
            order_by="ORDER BY COUNT(*) DESC",
        )
        top = pd.read_sql(top_sql, engine)
        profile["top_values"] = list(
            zip(top.iloc[:, 0].tolist(), top.iloc[:, 1].tolist())
        )

    sample = pd.read_sql(_head_query(engine, qc, qt, 5), engine)
    profile["sample_values"] = sample.iloc[:, 0].tolist()
    return profile
|
|
|
|
def profile_table(engine: Engine, table_name: str, columns: list[dict]) -> list[dict]:
    """Profile each column of ``table_name`` and render its summary text.

    Returns ``[{"col": column_info, "profile": dict, "text": str}, ...]`` in the
    order of ``columns``, or ``[]`` (after logging) when the table has no rows.
    A column whose profiling or text rendering raises is logged with the error
    message and left out, so one bad column doesn't abort the whole table.
    """
    row_count = get_row_count(engine, table_name)
    if row_count == 0:
        logger.info("skipping empty table", table=table_name)
        return []

    profiled = []
    for col in columns:
        try:
            profile = profile_column(
                engine, table_name, col["name"], col.get("is_numeric", False), row_count
            )
            text = build_text(table_name, row_count, col, profile)
        except Exception as exc:
            # Deliberate best-effort boundary: record the failure, keep going.
            logger.error(
                "column profiling failed",
                table=table_name,
                column=col["name"],
                error=str(exc),
            )
        else:
            profiled.append({"col": col, "profile": profile, "text": text})
    return profiled
|
|
|
|
def fetch_sample_row(engine: Engine, table_name: str) -> Optional[dict]:
    """Return one row of ``table_name`` as a ``{column: value}`` dict, or None if empty.

    Uses ``_qi`` for dialect-correct quoting and ``_head_query`` for TOP/LIMIT.
    No ORDER BY is applied, so which row comes back is up to the database.
    """
    sql = _head_query(engine, "*", _qi(engine, table_name), 1)
    df = pd.read_sql(sql, engine)
    if df.empty:
        return None
    # Convert column by column: df.iloc[0] would first coerce the whole row to
    # one common dtype (e.g. an integer id rendered as 1.0 next to a float
    # column).
    return df.to_dict(orient="records")[0]
|
|
|
|
def build_table_chunk(
    table_name: str,
    row_count: int,
    columns: list[dict],
    column_profiles: list[dict],
    sample_row: Optional[dict],
) -> str:
    """Render the table-level summary text. Pure formatting, no DB access.

    Output lines, in this order (those marked optional appear only when there
    is something to show):
        Table: {name} ({row_count} rows)
        Primary key: {pk_cols}                        (optional)
        Foreign keys: {col} -> {target}, ...           (optional)
        Columns ({n}): {col1}, {col2}, ...
        Numeric ranges: {col} [{min}-{max}], ...       (optional)
        Sample row: {dict}                             (optional)

    Args:
        columns: Column descriptors carrying ``name`` and optionally
            ``is_primary_key`` / ``foreign_key`` (the shape get_schema emits).
        column_profiles: ``{"col": ..., "profile": ...}`` entries as returned
            by profile_table. They are reused here rather than re-querying.
            Only numeric columns whose min and max are both non-None
            contribute a range.
        sample_row: Row dict to print, or None to omit that line.
    """

    def numeric_ranges():
        for entry in column_profiles:
            col, profile = entry["col"], entry["profile"]
            if not col.get("is_numeric"):
                continue
            low, high = profile.get("min"), profile.get("max")
            if low is not None and high is not None:
                yield f"{col['name']} [{low}-{high}]"

    out = [f"Table: {table_name} ({row_count} rows)"]

    primary = [c["name"] for c in columns if c.get("is_primary_key")]
    if primary:
        out.append("Primary key: " + ", ".join(primary))

    foreign = [f"{c['name']} -> {c['foreign_key']}" for c in columns if c.get("foreign_key")]
    if foreign:
        out.append("Foreign keys: " + ", ".join(foreign))

    names = [c["name"] for c in columns]
    out.append(f"Columns ({len(names)}): " + ", ".join(names))

    ranges = list(numeric_ranges())
    if ranges:
        out.append("Numeric ranges: " + ", ".join(ranges))

    if sample_row is not None:
        out.append(f"Sample row: {sample_row}")

    return "\n".join(out)
|
|
|
|
def build_text(table_name: str, row_count: int, col: dict, profile: dict) -> str:
    """Render the per-column summary text for one profiled column.

    ``col`` supplies ``name``, ``type`` and optional ``is_primary_key`` /
    ``foreign_key`` markers; a primary-key marker wins over a foreign-key one.
    ``profile`` is a dict shaped like profile_column's output. The min/max/mean,
    median and top-values lines appear only when those entries are present
    (median also only when not None). No trailing newline.
    """
    if col.get("is_primary_key"):
        key_label = " [PRIMARY KEY]"
    elif col.get("foreign_key"):
        key_label = f" [FK -> {col['foreign_key']}]"
    else:
        key_label = ""

    lines = [
        f"Table: {table_name} ({row_count} rows)",
        f"Column: {col['name']} ({col['type']}){key_label}",
        f"Null count: {profile['null_count']}",
        f"Distinct count: {profile['distinct_count']} ({profile['distinct_ratio']:.1%})",
    ]
    if "min" in profile:
        lines.append(f"Min: {profile['min']}, Max: {profile['max']}")
        lines.append(f"Mean: {profile['mean']}")
    if profile.get("median") is not None:
        lines.append(f"Median: {profile['median']}")
    if "top_values" in profile:
        top = ", ".join(f"{value} ({count})" for value, count in profile["top_values"])
        lines.append(f"Top values: {top}")
    lines.append(f"Sample values: {profile['sample_values']}")
    return "\n".join(lines)
|
|