Spaces:
Running
Running
| """Constitution parsing and compiled-rule registry for the safety layer.""" | |
| from __future__ import annotations | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, Any, Literal | |
| import yaml | |
| from dataforge.repairers.base import ProposedFix | |
| from dataforge.verifier.schema import Schema | |
| if TYPE_CHECKING: | |
| from dataforge.safety.filter import SafetyContext | |
| else: # pragma: no cover | |
| SafetyContext = Any | |
# Names of the top-level constitution sections; each compiled rule records
# the section it was declared under.
RuleTier = Literal["hard_never", "soft_require_confirm", "soft_prefer"]

# Callable shapes stored in the registries further down this module:
#   SinglePredicate  - judges one proposed fix (with optional schema and context)
#   BatchPredicate   - judges a whole list of proposed fixes at once
#   PreferenceScorer - assigns an integer score to one proposed fix
SinglePredicate = Callable[[ProposedFix, Schema | None, SafetyContext], bool]
BatchPredicate = Callable[[list[ProposedFix]], bool]
PreferenceScorer = Callable[[ProposedFix, Schema | None, SafetyContext], int]
class ConstitutionError(ValueError):
    """Signal that a constitution document could not be compiled.

    Raised when the parsed YAML has the wrong shape, when a rule entry lacks
    required fields, or when a rule names a predicate or scorer that is not
    registered in this module.
    """
def _levenshtein_distance(left: str, right: str) -> int:
    """Count the single-character edits needed to turn ``left`` into ``right``.

    Insertions, deletions and substitutions each cost 1. Only two rows of the
    dynamic-programming table are held at any time.

    Args:
        left: Source string.
        right: Target string.

    Returns:
        The edit distance; ``0`` when the strings are equal.
    """
    if left == right:
        return 0
    if not (left and right):
        # Exactly one side is empty here, so the distance is the other's length.
        return len(left) + len(right)

    # prev_row[k] is the distance between the processed prefix of ``left``
    # and the first k characters of ``right``.
    prev_row = list(range(len(right) + 1))
    for row_index, source_char in enumerate(left, start=1):
        row = [row_index]
        for col_index, target_char in enumerate(right, start=1):
            substitution = prev_row[col_index - 1] + (0 if source_char == target_char else 1)
            insertion = row[col_index - 1] + 1
            deletion = prev_row[col_index] + 1
            row.append(min(insertion, deletion, substitution))
        prev_row = row
    return prev_row[-1]
def _pii_overwrite(
    proposed_fix: ProposedFix,
    schema: Schema | None,
    context: SafetyContext,
) -> bool:
    """Tell whether the fix targets a column listed in ``schema.pii_columns``.

    Args:
        proposed_fix: Candidate fix; only ``fix.column`` is read.
        schema: Schema supplying ``pii_columns``, or ``None`` when unavailable.
        context: Unused; accepted to match the ``SinglePredicate`` signature.

    Returns:
        ``True`` when a schema is given and the column is among its PII
        columns, ``False`` otherwise (including when ``schema`` is ``None``).
    """
    del context
    if schema is None:
        return False
    target_column = proposed_fix.fix.column
    return target_column in schema.pii_columns
def _row_delete(
    proposed_fix: ProposedFix,
    schema: Schema | None,
    context: SafetyContext,
) -> bool:
    """Tell whether the fix's operation is ``"delete_row"``.

    Args:
        proposed_fix: Candidate fix; only ``fix.operation`` is read.
        schema: Unused; accepted to match the ``SinglePredicate`` signature.
        context: Unused; accepted to match the ``SinglePredicate`` signature.

    Returns:
        ``True`` exactly when ``proposed_fix.fix.operation == "delete_row"``.
    """
    del schema, context
    operation = proposed_fix.fix.operation
    return operation == "delete_row"
def _aggregate_sensitive(
    proposed_fix: ProposedFix,
    schema: Schema | None,
    context: SafetyContext,
) -> bool:
    """Tell whether the schema reports aggregate dependencies for the fix's column.

    Args:
        proposed_fix: Candidate fix; only ``fix.column`` is read.
        schema: Schema queried through ``aggregate_dependencies_for``, or
            ``None`` when unavailable.
        context: Unused; accepted to match the ``SinglePredicate`` signature.

    Returns:
        ``True`` when a schema is given and its dependency lookup for the
        column is truthy, ``False`` otherwise.
    """
    del context
    if schema is None:
        return False
    dependencies = schema.aggregate_dependencies_for(proposed_fix.fix.column)
    return bool(dependencies)
def _conflicting_cell_writes(fixes: list[ProposedFix]) -> bool:
    """Return whether multiple proposed fixes target the same cell differently.

    A cell is identified by ``(fix.row, fix.column)``. Two fixes conflict when
    they address the same cell but carry unequal ``fix.new_value``.

    Args:
        fixes: Proposed fixes to inspect, in any order.

    Returns:
        ``True`` as soon as a conflicting pair is found, ``False`` otherwise
        (including for an empty list).
    """
    first_value_by_cell: dict[tuple[int, str], str] = {}
    for proposed in fixes:
        cell = (proposed.fix.row, proposed.fix.column)
        value = proposed.fix.new_value
        # setdefault tests key *presence*, so a recorded value of None is still
        # compared against later writes instead of being mistaken for
        # "cell not seen yet" (which made detection depend on input order).
        if first_value_by_cell.setdefault(cell, value) != value:
            return True
    return False
def _minimal_edit_distance(
    proposed_fix: ProposedFix,
    schema: Schema | None,
    context: SafetyContext,
) -> int:
    """Score a fix by the edit distance between its old and new values.

    Args:
        proposed_fix: Candidate fix; ``fix.old_value`` and ``fix.new_value``
            are compared.
        schema: Unused; accepted to match the ``PreferenceScorer`` signature.
        context: Unused; accepted to match the ``PreferenceScorer`` signature.

    Returns:
        The Levenshtein distance from ``old_value`` to ``new_value``.
    """
    del schema, context
    before = proposed_fix.fix.old_value
    after = proposed_fix.fix.new_value
    return _levenshtein_distance(before, after)
# Registries mapping the names a constitution file may reference to their
# implementations. The rule builders look names up here and reject unknown ones.

# Predicates evaluated against one proposed fix at a time.
_SINGLE_PREDICATES: dict[str, SinglePredicate] = {
    "pii_overwrite": _pii_overwrite,
    "row_delete": _row_delete,
    "aggregate_sensitive": _aggregate_sensitive,
}

# Predicates evaluated against a whole list of proposed fixes.
_BATCH_PREDICATES: dict[str, BatchPredicate] = {
    "conflicting_cell_writes": _conflicting_cell_writes,
}

# Scorers referenced by preference rules.
_SCORERS: dict[str, PreferenceScorer] = {
    "minimal_edit_distance": _minimal_edit_distance,
}
@dataclass(frozen=True)
class CompiledSingleRule:
    """Compiled single-fix safety rule.

    Declared as a frozen dataclass so it can be constructed with keyword
    arguments (as ``_build_single_rule`` does) and cannot be mutated afterwards.
    """

    # Identifier taken from the rule's ``id`` entry.
    rule_id: str
    # Human-readable text taken from the rule's ``description`` entry.
    description: str
    # Constitution section the rule was declared under.
    tier: RuleTier
    # Callable looked up by name in the single-fix predicate registry.
    predicate: SinglePredicate
    # Optional values copied from the rule's ``override_flag`` and
    # ``confirm_flag`` entries; how they are consumed is not shown in this module.
    override_flag: str | None = None
    confirm_flag: str | None = None
@dataclass(frozen=True)
class CompiledBatchRule:
    """Compiled batch safety rule.

    Declared as a frozen dataclass so it can be constructed with keyword
    arguments (as ``_build_batch_rule`` does) and cannot be mutated afterwards.
    """

    # Identifier taken from the rule's ``id`` entry.
    rule_id: str
    # Human-readable text taken from the rule's ``description`` entry.
    description: str
    # Constitution section the rule was declared under.
    tier: RuleTier
    # Callable looked up by name in the batch predicate registry; it receives
    # the full list of proposed fixes.
    predicate: BatchPredicate
@dataclass(frozen=True)
class CompiledPreferenceRule:
    """Compiled candidate-preference rule.

    Declared as a frozen dataclass so it can be constructed with keyword
    arguments (as ``_build_preference_rule`` does) and cannot be mutated
    afterwards.
    """

    # Identifier taken from the rule's ``id`` entry.
    rule_id: str
    # Human-readable text taken from the rule's ``description`` entry.
    description: str
    # Constitution section the rule was declared under.
    tier: RuleTier
    # Callable looked up by name in the scorer registry; returns an int score.
    scorer: PreferenceScorer
@dataclass(frozen=True)
class Constitution:
    """Compiled constitution with rule registries by scope.

    Declared as a frozen dataclass so it can be constructed with keyword
    arguments (as ``load_constitution`` does) and cannot be mutated afterwards.
    """

    # Rules whose predicate judges one proposed fix at a time.
    single_rules: tuple[CompiledSingleRule, ...]
    # Rules whose predicate judges a whole list of proposed fixes.
    batch_rules: tuple[CompiledBatchRule, ...]
    # Rules that score candidate fixes.
    preference_rules: tuple[CompiledPreferenceRule, ...]
def default_constitution_path() -> Path:
    """Locate the default constitution file relative to this module.

    Returns:
        ``constitutions/default.yaml`` under the directory two levels above
        the directory containing this (resolved) source file.
    """
    module_file = Path(__file__).resolve()
    base_dir = module_file.parents[2]
    return base_dir.joinpath("constitutions", "default.yaml")
def _expect_mapping(payload: object, *, message: str) -> dict[str, object]:
    """Pass ``payload`` through if it is a ``dict``; otherwise reject it.

    Args:
        payload: Parsed YAML value to check.
        message: Text for the error raised when ``payload`` is not a ``dict``.

    Returns:
        The very same ``payload`` object, now known to be a ``dict``.

    Raises:
        ConstitutionError: If ``payload`` is not a ``dict``.
    """
    if isinstance(payload, dict):
        return payload
    raise ConstitutionError(message)
def _string_field(payload: dict[str, object], key: str) -> str:
    """Return ``payload[key]`` as a stripped string, or ``""`` if absent or null.

    An explicit YAML null arrives as ``None``; passing it through ``str()``
    would produce the literal text ``"None"`` and slip past the emptiness
    checks in the rule builders, so it is treated like a missing key.
    """
    value = payload.get(key)
    return "" if value is None else str(value).strip()


def _build_single_rule(payload: dict[str, object], tier: RuleTier) -> CompiledSingleRule:
    """Compile one single-fix rule entry.

    Args:
        payload: Rule mapping with ``id``, ``description`` and ``predicate``
            entries, plus optional ``override_flag`` / ``confirm_flag``.
        tier: Constitution section the entry was declared under.

    Returns:
        The compiled rule bound to its registered predicate.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing, null or
            blank, or if ``predicate`` is not a registered single-fix predicate.
    """
    rule_id = _string_field(payload, "id")
    description = _string_field(payload, "description")
    predicate_name = _string_field(payload, "predicate")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid rule entry for tier '{tier}'.")
    predicate = _SINGLE_PREDICATES.get(predicate_name)
    if predicate is None:
        raise ConstitutionError(f"Unknown predicate '{predicate_name}' in rule '{rule_id}'.")
    # Falsy flag values (missing, null, empty string) are normalised to None.
    override_flag = payload.get("override_flag")
    confirm_flag = payload.get("confirm_flag")
    return CompiledSingleRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        predicate=predicate,
        override_flag=str(override_flag) if override_flag else None,
        confirm_flag=str(confirm_flag) if confirm_flag else None,
    )


def _build_batch_rule(payload: dict[str, object], tier: RuleTier) -> CompiledBatchRule:
    """Compile one batch rule entry.

    Args:
        payload: Rule mapping with ``id``, ``description`` and ``predicate``.
        tier: Constitution section the entry was declared under.

    Returns:
        The compiled rule bound to its registered batch predicate.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing, null or
            blank, or if ``predicate`` is not a registered batch predicate.
    """
    rule_id = _string_field(payload, "id")
    description = _string_field(payload, "description")
    predicate_name = _string_field(payload, "predicate")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid batch rule entry for tier '{tier}'.")
    predicate = _BATCH_PREDICATES.get(predicate_name)
    if predicate is None:
        raise ConstitutionError(f"Unknown predicate '{predicate_name}' in rule '{rule_id}'.")
    return CompiledBatchRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        predicate=predicate,
    )


def _build_preference_rule(payload: dict[str, object], tier: RuleTier) -> CompiledPreferenceRule:
    """Compile one preference rule entry.

    Args:
        payload: Rule mapping with ``id``, ``description`` and ``scorer``.
        tier: Constitution section the entry was declared under.

    Returns:
        The compiled rule bound to its registered scorer.

    Raises:
        ConstitutionError: If ``id`` or ``description`` is missing, null or
            blank, if no ``scorer`` is declared, or if the scorer name is not
            registered.
    """
    rule_id = _string_field(payload, "id")
    description = _string_field(payload, "description")
    scorer_name = _string_field(payload, "scorer")
    if not rule_id or not description:
        raise ConstitutionError(f"Invalid preference rule entry for tier '{tier}'.")
    if not scorer_name:
        raise ConstitutionError(f"Preference rule '{rule_id}' must declare a scorer.")
    scorer = _SCORERS.get(scorer_name)
    if scorer is None:
        raise ConstitutionError(f"Unknown scorer '{scorer_name}' in rule '{rule_id}'.")
    return CompiledPreferenceRule(
        rule_id=rule_id,
        description=description,
        tier=tier,
        scorer=scorer,
    )
def load_constitution(path: Path) -> Constitution:
    """Load and compile a constitution YAML file.

    The document root must be a mapping. Its optional ``hard_never``,
    ``soft_require_confirm`` and ``soft_prefer`` keys each hold a list of rule
    mappings. Every ``soft_prefer`` entry is compiled as a preference rule.
    Entries in the other two tiers are compiled as batch rules when their
    ``scope`` is ``batch`` (case-insensitive, surrounding whitespace ignored)
    and as single-fix rules otherwise.

    Args:
        path: Location of the UTF-8 encoded YAML file.

    Returns:
        The compiled constitution; rules keep the order in which tiers are
        processed and entries appear in the file.

    Raises:
        ConstitutionError: If the root is not a mapping, a tier is not a
            list, a rule entry is not a mapping, or a rule builder rejects
            an entry.
    """
    raw_payload = yaml.safe_load(path.read_text(encoding="utf-8"))
    # Only an empty document (parsed as None) stands for "no rules". Other
    # falsy roots such as [], 0 or "" are not mappings and must be rejected
    # rather than silently read as an empty constitution.
    root = _expect_mapping(
        {} if raw_payload is None else raw_payload,
        message="Constitution must be a YAML mapping.",
    )

    single_rules: list[CompiledSingleRule] = []
    batch_rules: list[CompiledBatchRule] = []
    preference_rules: list[CompiledPreferenceRule] = []

    for tier in ("hard_never", "soft_require_confirm", "soft_prefer"):
        raw_rules = root.get(tier, [])
        if not isinstance(raw_rules, list):
            raise ConstitutionError(f"Tier '{tier}' must be a YAML list.")
        for raw_rule in raw_rules:
            payload = _expect_mapping(
                raw_rule, message=f"Rule entries in '{tier}' must be mappings."
            )
            if tier == "soft_prefer":
                # Preference rules ignore ``scope`` entirely.
                preference_rules.append(_build_preference_rule(payload, tier))
                continue
            scope = str(payload.get("scope", "single")).strip().lower()
            if scope == "batch":
                batch_rules.append(_build_batch_rule(payload, tier))
            else:
                # Any scope other than "batch" is compiled as a single-fix rule.
                single_rules.append(_build_single_rule(payload, tier))

    return Constitution(
        single_rules=tuple(single_rules),
        batch_rules=tuple(batch_rules),
        preference_rules=tuple(preference_rules),
    )