| """Constitution parsing and compiled-rule registry for the safety layer.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from dataclasses import dataclass |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Any, Literal |
|
|
| import yaml |
|
|
| from dataforge.repairers.base import ProposedFix |
| from dataforge.verifier.schema import Schema |
|
|
| if TYPE_CHECKING: |
| from dataforge.safety.filter import SafetyContext |
| else: |
| SafetyContext = Any |
|
|
# Tiers a rule may be declared under in a constitution file.
RuleTier = Literal["hard_never", "soft_require_confirm", "soft_prefer"]

# Per-fix predicate: (fix, optional schema, safety context) -> bool.
SinglePredicate = Callable[[ProposedFix, Schema | None, SafetyContext], bool]
# Per-fix scorer: (fix, optional schema, safety context) -> integer score.
PreferenceScorer = Callable[[ProposedFix, Schema | None, SafetyContext], int]
# Whole-batch predicate: (list of fixes) -> bool.
BatchPredicate = Callable[[list[ProposedFix]], bool]
|
|
|
|
class ConstitutionError(ValueError):
    """Signal a malformed constitution file or a reference to an unknown rule.

    Derives from :class:`ValueError`, so generic ``except ValueError`` handlers
    also catch it.
    """
|
|
|
|
| def _levenshtein_distance(left: str, right: str) -> int: |
| """Return the Levenshtein edit distance between two strings.""" |
| if left == right: |
| return 0 |
| if not left: |
| return len(right) |
| if not right: |
| return len(left) |
|
|
| previous = list(range(len(right) + 1)) |
| for i, left_char in enumerate(left, start=1): |
| current = [i] |
| for j, right_char in enumerate(right, start=1): |
| insert_cost = current[j - 1] + 1 |
| delete_cost = previous[j] + 1 |
| replace_cost = previous[j - 1] + (left_char != right_char) |
| current.append(min(insert_cost, delete_cost, replace_cost)) |
| previous = current |
| return previous[-1] |
|
|
|
|
| def _pii_overwrite( |
| proposed_fix: ProposedFix, |
| schema: Schema | None, |
| context: SafetyContext, |
| ) -> bool: |
| """Return whether a fix touches a column marked as PII.""" |
| del context |
| return schema is not None and proposed_fix.fix.column in schema.pii_columns |
|
|
|
|
| def _row_delete( |
| proposed_fix: ProposedFix, |
| schema: Schema | None, |
| context: SafetyContext, |
| ) -> bool: |
| """Return whether a proposed fix is deleting a row.""" |
| del schema, context |
| return proposed_fix.fix.operation == "delete_row" |
|
|
|
|
| def _aggregate_sensitive( |
| proposed_fix: ProposedFix, |
| schema: Schema | None, |
| context: SafetyContext, |
| ) -> bool: |
| """Return whether a fix edits a column used as an aggregate source.""" |
| del context |
| return schema is not None and bool(schema.aggregate_dependencies_for(proposed_fix.fix.column)) |
|
|
|
|
| def _conflicting_cell_writes(fixes: list[ProposedFix]) -> bool: |
| """Return whether multiple proposed fixes target the same cell differently.""" |
| seen: dict[tuple[int, str], str] = {} |
| for fix in fixes: |
| key = (fix.fix.row, fix.fix.column) |
| existing = seen.get(key) |
| if existing is not None and existing != fix.fix.new_value: |
| return True |
| seen[key] = fix.fix.new_value |
| return False |
|
|
|
|
def _minimal_edit_distance(
    proposed_fix: ProposedFix,
    schema: Schema | None,
    context: SafetyContext,
) -> int:
    """Score a candidate fix by the Levenshtein distance between its old and new values.

    The schema and context are not consulted.
    """
    del schema, context  # Unused; accepted to match the PreferenceScorer signature.
    fix = proposed_fix.fix
    return _levenshtein_distance(fix.old_value, fix.new_value)
|
|
|
|
# Registries mapping the names used in constitution YAML (``predicate:`` /
# ``scorer:`` keys) to the Python callables that implement them.

# Predicates evaluated against one proposed fix at a time.
_SINGLE_PREDICATES: dict[str, SinglePredicate] = {
    "pii_overwrite": _pii_overwrite,
    "row_delete": _row_delete,
    "aggregate_sensitive": _aggregate_sensitive,
}

# Predicates evaluated against a whole batch of proposed fixes.
_BATCH_PREDICATES: dict[str, BatchPredicate] = {
    "conflicting_cell_writes": _conflicting_cell_writes,
}

# Scorers used by ``soft_prefer`` rules.
_SCORERS: dict[str, PreferenceScorer] = {
    "minimal_edit_distance": _minimal_edit_distance,
}
|
|
|
|
@dataclass(frozen=True)
class CompiledSingleRule:
    """Immutable safety rule that is checked against one proposed fix at a time."""

    # Rule identifier.
    rule_id: str
    # Human-readable explanation of the rule.
    description: str
    # Tier the rule was declared under.
    tier: RuleTier
    # Callable ``(fix, schema_or_None, context) -> bool``.
    predicate: SinglePredicate
    # Optional flag names attached to the rule; ``None`` when not given.
    override_flag: str | None = None
    confirm_flag: str | None = None
|
|
|
|
@dataclass(frozen=True)
class CompiledBatchRule:
    """Immutable safety rule that is checked against a whole batch of fixes."""

    # Rule identifier.
    rule_id: str
    # Human-readable explanation of the rule.
    description: str
    # Tier the rule was declared under.
    tier: RuleTier
    # Callable ``(list_of_fixes) -> bool``.
    predicate: BatchPredicate
|
|
|
|
@dataclass(frozen=True)
class CompiledPreferenceRule:
    """Immutable rule that assigns an integer score to a candidate fix."""

    # Rule identifier.
    rule_id: str
    # Human-readable explanation of the rule.
    description: str
    # Tier the rule was declared under.
    tier: RuleTier
    # Callable ``(fix, schema_or_None, context) -> int``.
    scorer: PreferenceScorer
|
|
|
|
@dataclass(frozen=True)
class Constitution:
    """Immutable bundle of compiled rules, grouped by how they are evaluated."""

    # Rules checked against each proposed fix individually.
    single_rules: tuple[CompiledSingleRule, ...]
    # Rules checked against a whole batch of proposed fixes.
    batch_rules: tuple[CompiledBatchRule, ...]
    # Rules that score candidate fixes.
    preference_rules: tuple[CompiledPreferenceRule, ...]
|
|
|
|
def default_constitution_path() -> Path:
    """Locate the ``constitutions/default.yaml`` file shipped with the package.

    The path is computed relative to this module's resolved location: two
    directory levels above the directory containing this file.
    """
    module_file = Path(__file__).resolve()
    return module_file.parents[2].joinpath("constitutions", "default.yaml")
|
|
|
|
| def _expect_mapping(payload: object, *, message: str) -> dict[str, object]: |
| if not isinstance(payload, dict): |
| raise ConstitutionError(message) |
| return payload |
|
|
|
|
def _build_single_rule(payload: dict[str, object], tier: RuleTier) -> CompiledSingleRule:
    """Compile one single-fix rule entry from a constitution tier.

    Args:
        payload: The rule's YAML mapping. ``id``, ``description`` and
            ``predicate`` are required; ``override_flag`` and ``confirm_flag``
            are optional.
        tier: The tier the entry was declared under.

    Returns:
        The compiled rule, with its predicate resolved from ``_SINGLE_PREDICATES``.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing or blank, no
            predicate is declared, or the predicate name is not registered.
    """

    def _text(key: str) -> str:
        # Treat an explicit YAML ``null`` as missing: ``str(None)`` would yield
        # the non-empty string "None" and slip past the checks below.
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    def _optional_text(key: str) -> str | None:
        # Falsy values (absent, null, false, "", 0) mean "no flag", as before;
        # surrounding whitespace is stripped like the required fields.
        value = payload.get(key)
        if not value:
            return None
        return str(value).strip() or None

    rule_id = _text("id")
    description = _text("description")
    predicate_name = _text("predicate")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid rule entry for tier '{tier}'.")
    if not predicate_name:
        raise ConstitutionError(f"Rule '{rule_id}' must declare a predicate.")
    predicate = _SINGLE_PREDICATES.get(predicate_name)
    if predicate is None:
        raise ConstitutionError(f"Unknown predicate '{predicate_name}' in rule '{rule_id}'.")
    return CompiledSingleRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        predicate=predicate,
        override_flag=_optional_text("override_flag"),
        confirm_flag=_optional_text("confirm_flag"),
    )
|
|
|
|
def _build_batch_rule(payload: dict[str, object], tier: RuleTier) -> CompiledBatchRule:
    """Compile one batch-scoped rule entry from a constitution tier.

    Args:
        payload: The rule's YAML mapping; ``id``, ``description`` and
            ``predicate`` are required.
        tier: The tier the entry was declared under.

    Returns:
        The compiled rule, with its predicate resolved from ``_BATCH_PREDICATES``.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing or blank, no
            predicate is declared, or the predicate name is not registered.
    """

    def _text(key: str) -> str:
        # Treat an explicit YAML ``null`` as missing: ``str(None)`` would yield
        # the non-empty string "None" and slip past the checks below.
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    rule_id = _text("id")
    description = _text("description")
    predicate_name = _text("predicate")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid batch rule entry for tier '{tier}'.")
    if not predicate_name:
        raise ConstitutionError(f"Batch rule '{rule_id}' must declare a predicate.")
    predicate = _BATCH_PREDICATES.get(predicate_name)
    if predicate is None:
        raise ConstitutionError(f"Unknown predicate '{predicate_name}' in rule '{rule_id}'.")
    return CompiledBatchRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        predicate=predicate,
    )
|
|
|
|
def _build_preference_rule(payload: dict[str, object], tier: RuleTier) -> CompiledPreferenceRule:
    """Compile one candidate-preference rule entry from a constitution tier.

    Args:
        payload: The rule's YAML mapping; ``id``, ``description`` and
            ``scorer`` are required.
        tier: The tier the entry was declared under.

    Returns:
        The compiled rule, with its scorer resolved from ``_SCORERS``.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing or blank, no
            scorer is declared, or the scorer name is not registered.
    """

    def _text(key: str) -> str:
        # Treat an explicit YAML ``null`` as missing: ``str(None)`` would yield
        # the non-empty string "None" and slip past the checks below.
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    rule_id = _text("id")
    description = _text("description")
    scorer_name = _text("scorer")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid preference rule entry for tier '{tier}'.")
    if not scorer_name:
        raise ConstitutionError(f"Preference rule '{rule_id}' must declare a scorer.")
    scorer = _SCORERS.get(scorer_name)
    if scorer is None:
        raise ConstitutionError(f"Unknown scorer '{scorer_name}' in rule '{rule_id}'.")
    return CompiledPreferenceRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        scorer=scorer,
    )
|
|
|
|
@lru_cache(maxsize=8)
def load_constitution(path: Path | str) -> Constitution:
    """Load and compile a constitution YAML file.

    The document must be a mapping. Its optional ``hard_never``,
    ``soft_require_confirm`` and ``soft_prefer`` keys each hold a list of rule
    mappings; other top-level keys are ignored.

    - Entries under ``soft_prefer`` compile to preference rules.
    - In the other tiers, an entry with ``scope: batch`` compiles to a batch
      rule, and any other scope (default ``single``) compiles to a single-fix
      rule.

    Results are memoized per ``path`` argument (at most 8 entries). A file
    edited after its first load is therefore not re-read until
    ``load_constitution.cache_clear()`` is called.

    Args:
        path: Location of the YAML file. A ``str`` is accepted and converted
            to :class:`~pathlib.Path`.

    Returns:
        The compiled constitution.

    Raises:
        ConstitutionError: If the file is not valid YAML, is not a mapping, a
            tier is not a list, or any rule entry is invalid.
        OSError: If the file cannot be read.
    """
    source = Path(path)
    text = source.read_text(encoding="utf-8")
    try:
        raw_payload = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        # Malformed YAML is a malformed constitution; report it through the
        # module's own error type so callers need only one except clause.
        raise ConstitutionError(f"Constitution '{source}' is not valid YAML: {exc}") from exc
    # An empty document loads as None and means "no rules". Any other
    # non-mapping, including falsy ones such as [] or 0, is rejected instead of
    # being silently compiled into an empty constitution.
    root = _expect_mapping(
        {} if raw_payload is None else raw_payload,
        message="Constitution must be a YAML mapping.",
    )

    single_rules: list[CompiledSingleRule] = []
    batch_rules: list[CompiledBatchRule] = []
    preference_rules: list[CompiledPreferenceRule] = []

    # NOTE(review): unknown top-level keys (e.g. a misspelled tier name) are
    # ignored silently; consider rejecting them if the format has no other keys.
    for tier in ("hard_never", "soft_require_confirm", "soft_prefer"):
        raw_rules = root.get(tier)
        if raw_rules is None:
            # Key absent, or present with no value (``hard_never:``): no rules.
            continue
        if not isinstance(raw_rules, list):
            raise ConstitutionError(f"Tier '{tier}' must be a YAML list.")
        for raw_rule in raw_rules:
            payload = _expect_mapping(
                raw_rule, message=f"Rule entries in '{tier}' must be mappings."
            )
            if tier == "soft_prefer":
                # Preference rules come in one kind; ``scope`` is not consulted.
                preference_rules.append(_build_preference_rule(payload, tier))
                continue
            scope = str(payload.get("scope", "single")).strip().lower()
            if scope == "batch":
                batch_rules.append(_build_batch_rule(payload, tier))
            else:
                single_rules.append(_build_single_rule(payload, tier))

    return Constitution(
        single_rules=tuple(single_rules),
        batch_rules=tuple(batch_rules),
        preference_rules=tuple(preference_rules),
    )
|
|