"""Shared helpers for DataForge CLI commands.""" from __future__ import annotations from collections.abc import Iterable from pathlib import Path from typing import cast import pandas as pd import typer import yaml from dataforge.verifier.schema import ( AggregateDependency, AggregateLiteral, DomainBound, FunctionalDependency, Schema, ) def schema_from_mapping(raw_mapping: object) -> Schema: """Build a Schema from a raw YAML mapping-like payload. Args: raw_mapping: Untrusted YAML-decoded value. Returns: Parsed Schema object. Raises: typer.BadParameter: If the payload is not a mapping. """ if raw_mapping is None: mapping: dict[str, object] = {} elif isinstance(raw_mapping, dict): mapping = raw_mapping else: raise typer.BadParameter("Schema payload must be a YAML mapping.") columns: dict[str, str] = {} raw_columns = mapping.get("columns", {}) if isinstance(raw_columns, dict): columns = {str(key): str(value) for key, value in raw_columns.items()} fds: list[FunctionalDependency] = [] raw_fds = mapping.get("functional_dependencies", []) if isinstance(raw_fds, list): for raw_fd in raw_fds: if not isinstance(raw_fd, dict): continue raw_determinant = raw_fd.get("determinant", []) determinant_values = ( tuple(str(value) for value in raw_determinant) if isinstance(raw_determinant, Iterable) and not isinstance(raw_determinant, (str, bytes)) else () ) fds.append( FunctionalDependency( determinant=determinant_values, dependent=str(raw_fd.get("dependent", "")), ) ) raw_pii_columns = mapping.get("pii_columns", []) pii_columns = ( frozenset(str(value) for value in raw_pii_columns) if isinstance(raw_pii_columns, Iterable) and not isinstance(raw_pii_columns, (str, bytes)) else frozenset() ) bounds: list[DomainBound] = [] raw_bounds = mapping.get("domain_bounds", {}) if isinstance(raw_bounds, dict): for column, bound_payload in raw_bounds.items(): if not isinstance(bound_payload, dict): continue bounds.append( DomainBound( column=str(column), min_value=( float(bound_payload["min"]) if bound_payload.get("min") is not None else None ), max_value=( float(bound_payload["max"]) if bound_payload.get("max") is not None else None ), inclusive_min=bool(bound_payload.get("inclusive_min", True)), inclusive_max=bool(bound_payload.get("inclusive_max", True)), ) ) aggregate_dependencies: list[AggregateDependency] = [] raw_aggregates = mapping.get("aggregate_dependencies", []) if isinstance(raw_aggregates, list): for raw_dependency in raw_aggregates: if not isinstance(raw_dependency, dict): continue raw_aggregate = str(raw_dependency.get("aggregate", "")).lower() if raw_aggregate not in {"sum", "avg"}: continue raw_group_by = raw_dependency.get("group_by", []) group_by = ( tuple(str(value) for value in raw_group_by) if isinstance(raw_group_by, Iterable) and not isinstance(raw_group_by, (str, bytes)) else () ) aggregate_dependencies.append( AggregateDependency( source_column=str(raw_dependency.get("source_column", "")), aggregate=cast(AggregateLiteral, raw_aggregate), target_column=str(raw_dependency.get("target_column", "")), group_by=group_by, ) ) return Schema( columns=columns, functional_dependencies=tuple(fds), pii_columns=pii_columns, domain_bounds=tuple(bounds), aggregate_dependencies=tuple(aggregate_dependencies), ) def load_schema(schema_path: Path) -> Schema: """Load a Schema from a YAML file. Args: schema_path: Path to the YAML schema file. Returns: Parsed Schema object. Raises: typer.BadParameter: If the schema file is malformed or unreadable. """ try: raw = yaml.safe_load(schema_path.read_text(encoding="utf-8")) except OSError as exc: raise typer.BadParameter(f"Could not read schema file '{schema_path}': {exc}") from exc if raw is not None and not isinstance(raw, dict): raise typer.BadParameter(f"Schema file '{schema_path}' must be a YAML mapping.") return schema_from_mapping(raw) def read_csv(path: Path) -> pd.DataFrame: """Read a CSV using conservative string-preserving defaults. Args: path: CSV path. Returns: A DataFrame with string-preserved values. """ return pd.read_csv(path, dtype=str, keep_default_na=False, na_filter=False)