| """Canonical schema models shared by detectors, safety, and the verifier.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Literal |
|
|
| from pydantic import ConfigDict, Field |
| from pydantic.dataclasses import dataclass |
|
|
# Closed set of aggregate names accepted by ``AggregateDependency.aggregate``.
AggregateLiteral = Literal["sum", "avg"]

# Shared pydantic configuration applied to every dataclass in this module;
# ``frozen=True`` makes instances reject attribute assignment after creation.
_CONFIG = ConfigDict(frozen=True)
|
|
|
|
@dataclass(config=_CONFIG, kw_only=True)
class FunctionalDependency:
    """A declared functional dependency between columns.

    The ``determinant`` columns are declared to determine the ``dependent``
    column. Instances are frozen and every field must be passed by keyword.

    Attributes:
        determinant: Names of the determining columns; the tuple must hold
            at least one entry.
        dependent: Name of the determined column; must be a non-empty string.
    """

    determinant: tuple[str, ...] = Field(min_length=1)
    dependent: str = Field(min_length=1)
|
|
|
|
@dataclass(config=_CONFIG, kw_only=True)
class DomainBound:
    """Numeric lower/upper limits declared for a single column.

    Instances are frozen and every field must be passed by keyword.
    No ordering between ``min_value`` and ``max_value`` is checked here.

    Attributes:
        column: Name of the bounded column; must be a non-empty string.
        min_value: Lower limit, or ``None`` (the default) when not given --
            presumably meaning "no lower limit"; confirm against consumers.
        max_value: Upper limit, or ``None`` (the default) when not given --
            presumably meaning "no upper limit"; confirm against consumers.
        inclusive_min: Flag marking the lower limit as inclusive; defaults
            to ``True``. How it is applied is up to the code reading it.
        inclusive_max: Flag marking the upper limit as inclusive; defaults
            to ``True``. How it is applied is up to the code reading it.
    """

    column: str = Field(min_length=1)
    min_value: float | None = None
    max_value: float | None = None
    inclusive_min: bool = True
    inclusive_max: bool = True
|
|
|
|
@dataclass(config=_CONFIG, kw_only=True)
class AggregateDependency:
    """Record that a source column feeds an aggregate stored in another column.

    Instances are frozen and every field must be passed by keyword.

    Attributes:
        source_column: Name of the column used as the aggregate's input;
            must be a non-empty string.
        target_column: Name of the column associated with the aggregate
            result; must be a non-empty string.
        aggregate: Aggregate name, restricted to the values allowed by
            ``AggregateLiteral``.
        group_by: Column names the aggregate is grouped by; defaults to an
            empty tuple.
    """

    source_column: str = Field(min_length=1)
    target_column: str = Field(min_length=1)
    aggregate: AggregateLiteral
    group_by: tuple[str, ...] = Field(default_factory=tuple)
|
|
|
|
@dataclass(config=_CONFIG, kw_only=True)
class Schema:
    """Optional, user-declared description of a dataset.

    Every field has an empty default, so ``Schema()`` is a valid "nothing
    declared" value. Instances are frozen and fields are keyword-only.

    Attributes:
        columns: Mapping of column name to its declared type string.
        functional_dependencies: Declared ``FunctionalDependency`` entries.
        pii_columns: Set of column names -- presumably those holding
            personally identifiable information; confirm against consumers.
        domain_bounds: Declared ``DomainBound`` entries; several entries may
            name the same column.
        aggregate_dependencies: Declared ``AggregateDependency`` entries.
    """

    columns: dict[str, str] = Field(default_factory=dict)
    functional_dependencies: tuple[FunctionalDependency, ...] = Field(default_factory=tuple)
    pii_columns: frozenset[str] = Field(default_factory=frozenset)
    domain_bounds: tuple[DomainBound, ...] = Field(default_factory=tuple)
    aggregate_dependencies: tuple[AggregateDependency, ...] = Field(default_factory=tuple)

    def column_type(self, column: str) -> str | None:
        """Look up the declared type of ``column``; ``None`` if undeclared."""
        declared_types = self.columns
        return declared_types.get(column)

    def domain_bounds_for(self, column: str) -> tuple[DomainBound, ...]:
        """Collect every domain bound whose ``column`` equals ``column``.

        Declaration order is preserved; the result is empty when no bound
        names the column.
        """
        matching = [entry for entry in self.domain_bounds if entry.column == column]
        return tuple(matching)

    def aggregate_dependencies_for(self, column: str) -> tuple[AggregateDependency, ...]:
        """Collect aggregate dependencies that use ``column`` as their source.

        Only ``source_column`` is compared; ``target_column`` and ``group_by``
        are not consulted. Declaration order is preserved.
        """
        return tuple(
            filter(
                lambda entry: entry.source_column == column,
                self.aggregate_dependencies,
            )
        )
|
|