Spaces:
Sleeping
Sleeping
| """Shared helpers for DataForge CLI commands.""" | |
| from __future__ import annotations | |
| from collections.abc import Iterable | |
| from pathlib import Path | |
| from typing import cast | |
| import pandas as pd | |
| import typer | |
| import yaml | |
| from dataforge.verifier.schema import ( | |
| AggregateDependency, | |
| AggregateLiteral, | |
| DomainBound, | |
| FunctionalDependency, | |
| Schema, | |
| ) | |
def schema_from_mapping(raw_mapping: object) -> Schema:
    """Build a Schema from a raw YAML mapping-like payload.

    Unknown or malformed sub-sections are skipped leniently (e.g. a
    non-list ``functional_dependencies`` or a non-mapping entry inside it),
    except for non-numeric domain bounds, which are reported as CLI errors.

    Args:
        raw_mapping: Untrusted YAML-decoded value. ``None`` is treated as an
            empty mapping.

    Returns:
        Parsed Schema object.

    Raises:
        typer.BadParameter: If the payload is not a mapping, or if a domain
            bound ``min``/``max`` value is present but not numeric.
    """
    if raw_mapping is None:
        mapping: dict[str, object] = {}
    elif isinstance(raw_mapping, dict):
        mapping = raw_mapping
    else:
        raise typer.BadParameter("Schema payload must be a YAML mapping.")

    raw_columns = mapping.get("columns", {})
    columns: dict[str, str] = (
        {str(key): str(value) for key, value in raw_columns.items()}
        if isinstance(raw_columns, dict)
        else {}
    )

    return Schema(
        columns=columns,
        functional_dependencies=_parse_functional_dependencies(
            mapping.get("functional_dependencies", [])
        ),
        pii_columns=frozenset(_string_items(mapping.get("pii_columns", []))),
        domain_bounds=_parse_domain_bounds(mapping.get("domain_bounds", {})),
        aggregate_dependencies=_parse_aggregate_dependencies(
            mapping.get("aggregate_dependencies", [])
        ),
    )


def _string_items(raw: object) -> tuple[str, ...]:
    """Coerce a YAML list-like value into a tuple of strings.

    A bare ``str``/``bytes`` scalar, or any non-iterable value, yields ``()``.
    This means e.g. ``determinant: id`` is not split into ``("i", "d")``.
    """
    if isinstance(raw, Iterable) and not isinstance(raw, (str, bytes)):
        return tuple(str(item) for item in raw)
    return ()


def _parse_functional_dependencies(raw_fds: object) -> tuple[FunctionalDependency, ...]:
    """Build FunctionalDependency objects from the ``functional_dependencies`` list.

    A non-list payload yields ``()``; non-mapping entries are skipped.
    """
    if not isinstance(raw_fds, list):
        return ()
    return tuple(
        FunctionalDependency(
            determinant=_string_items(raw_fd.get("determinant", [])),
            dependent=str(raw_fd.get("dependent", "")),
        )
        for raw_fd in raw_fds
        if isinstance(raw_fd, dict)
    )


def _optional_bound(payload: dict[object, object], key: str, column: str) -> float | None:
    """Return ``payload[key]`` as a float, or None when the key is absent or null.

    Raises:
        typer.BadParameter: If the value is present but not numeric.
    """
    raw_value = payload.get(key)
    if raw_value is None:
        return None
    message = (
        f"Domain bound '{key}' for column '{column}' must be numeric, "
        f"got {raw_value!r}."
    )
    # Only scalar YAML values can be converted; lists, mappings, dates, etc. are
    # rejected with a CLI-friendly error instead of a raw TypeError traceback.
    if not isinstance(raw_value, (int, float, str)):
        raise typer.BadParameter(message)
    try:
        return float(raw_value)
    except ValueError as exc:
        raise typer.BadParameter(message) from exc


def _parse_domain_bounds(raw_bounds: object) -> tuple[DomainBound, ...]:
    """Build DomainBound objects from the ``domain_bounds`` column -> payload mapping.

    A non-mapping payload yields ``()``; per-column payloads that are not
    mappings are skipped. Missing ``inclusive_min``/``inclusive_max`` default
    to True.

    Raises:
        typer.BadParameter: If a ``min``/``max`` value is not numeric.
    """
    if not isinstance(raw_bounds, dict):
        return ()
    bounds: list[DomainBound] = []
    for raw_column, bound_payload in raw_bounds.items():
        if not isinstance(bound_payload, dict):
            continue
        column = str(raw_column)
        bounds.append(
            DomainBound(
                column=column,
                min_value=_optional_bound(bound_payload, "min", column),
                max_value=_optional_bound(bound_payload, "max", column),
                inclusive_min=bool(bound_payload.get("inclusive_min", True)),
                inclusive_max=bool(bound_payload.get("inclusive_max", True)),
            )
        )
    return tuple(bounds)


def _parse_aggregate_dependencies(raw_aggregates: object) -> tuple[AggregateDependency, ...]:
    """Build AggregateDependency objects from the ``aggregate_dependencies`` list.

    A non-list payload yields ``()``; non-mapping entries and entries whose
    ``aggregate`` is not ``sum``/``avg`` (case-insensitive) are skipped.
    """
    if not isinstance(raw_aggregates, list):
        return ()
    dependencies: list[AggregateDependency] = []
    for raw_dependency in raw_aggregates:
        if not isinstance(raw_dependency, dict):
            continue
        raw_aggregate = str(raw_dependency.get("aggregate", "")).lower()
        # Unsupported aggregate kinds are dropped leniently rather than rejected.
        if raw_aggregate not in {"sum", "avg"}:
            continue
        dependencies.append(
            AggregateDependency(
                source_column=str(raw_dependency.get("source_column", "")),
                aggregate=cast(AggregateLiteral, raw_aggregate),
                target_column=str(raw_dependency.get("target_column", "")),
                group_by=_string_items(raw_dependency.get("group_by", [])),
            )
        )
    return tuple(dependencies)
def load_schema(schema_path: Path) -> Schema:
    """Load a Schema from a YAML file.

    Args:
        schema_path: Path to the YAML schema file.

    Returns:
        Parsed Schema object. An empty file yields an empty Schema.

    Raises:
        typer.BadParameter: If the file cannot be read, is not valid UTF-8,
            is not valid YAML, or does not contain a top-level mapping.
    """
    try:
        text = schema_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # UnicodeDecodeError is a ValueError, not an OSError, so it needs its own entry.
        raise typer.BadParameter(f"Could not read schema file '{schema_path}': {exc}") from exc

    try:
        raw = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        # Malformed YAML must surface as a CLI usage error, not a traceback.
        raise typer.BadParameter(
            f"Schema file '{schema_path}' is not valid YAML: {exc}"
        ) from exc

    # Checked here (as well as in schema_from_mapping) so the message names the file.
    if raw is not None and not isinstance(raw, dict):
        raise typer.BadParameter(f"Schema file '{schema_path}' must be a YAML mapping.")
    return schema_from_mapping(raw)
def read_csv(path: Path) -> pd.DataFrame:
    """Load *path* as a DataFrame whose cells are kept as raw strings.

    Type inference is disabled (``dtype=str``) and NA detection is switched
    off, so values such as ``"001"``, ``"NA"`` or empty fields are preserved
    verbatim instead of being coerced to numbers or NaN.

    Args:
        path: Location of the CSV file to read.

    Returns:
        The parsed DataFrame with string-valued cells.
    """
    frame = pd.read_csv(
        path,
        na_filter=False,
        keep_default_na=False,
        dtype=str,
    )
    return frame