| """Z3-backed candidate verifier for Week 3 repairs.""" |
|
|
| from __future__ import annotations |
|
|
| import enum |
| from collections.abc import Callable |
| from dataclasses import dataclass |
| from typing import Any |
|
|
| from pydantic import BaseModel, Field |
| from z3 import ( |
| And, |
| Bool, |
| ForAll, |
| Function, |
| Implies, |
| Int, |
| IntSort, |
| IntVal, |
| RealSort, |
| RealVal, |
| Solver, |
| StringSort, |
| StringVal, |
| sat, |
| unknown, |
| unsat, |
| ) |
|
|
| from dataforge.repairers.base import ProposedFix |
| from dataforge.table import ( |
| TableLike, |
| cell_value, |
| column_names, |
| copy_table, |
| row_count, |
| set_cell_value, |
| ) |
| from dataforge.verifier.explain import explain_unsat_core |
| from dataforge.verifier.schema import DomainBound, FunctionalDependency, Schema |
|
|
# Callable that maps a Z3 row-index term to the Z3 term for that row's cell
# (in this module: the uninterpreted function created for a column).
Z3ExprFactory = Callable[[Any], Any]
# Callable that converts one raw cell value into a Z3 literal of the column's sort.
Z3ValueFactory = Callable[[str], Any]
|
|
|
|
class VerificationVerdict(enum.Enum):
    """Outcome categories reported by the verifier gate."""

    # The candidate satisfied every constraint the verifier tracked.
    ACCEPT = "accept"
    # The candidate is structurally invalid or violates a tracked constraint.
    REJECT = "reject"
    # The verifier could not decide (unsupported encoding or solver gave up).
    UNKNOWN = "unknown"
|
|
|
|
class VerificationResult(BaseModel):
    """Typed, frozen outcome returned by the Week 3 verifier gate."""

    # Mark the model as frozen so a result is not modified after it is built.
    model_config = {"frozen": True}

    # Overall decision for the candidate fix.
    verdict: VerificationVerdict
    # Human-readable explanation of the decision; must be a non-empty string.
    reason: str = Field(min_length=1)
    # Labels of the tracked constraints in the solver's unsat core; empty unless
    # the caller supplies them.
    unsat_core: tuple[str, ...] = Field(default_factory=tuple)
|
|
|
|
@dataclass(frozen=True)
class _ColumnEncoding:
    """Bundle of Z3 helpers describing how a single column is encoded."""

    # Column name as it appears in the table.
    name: str
    # Normalised (stripped, lower-cased) schema type string for the column.
    column_type: str
    # Maps a Z3 row-index term to the Z3 term for that row's cell.
    function: Z3ExprFactory
    # Converts one raw cell value into a Z3 literal of the column's sort.
    value_factory: Z3ValueFactory
|
|
|
|
class SchemaToSMT:
    """Compile candidate-local constraints from a schema and working dataframe.

    For one proposed cell update, the verifier encodes the updated column and
    every column that takes part in a functional dependency mentioning it as a
    Z3 function from row index to cell value. The current table contents, with
    the proposed value substituted in, are asserted as plain facts. The schema's
    domain bounds and functional dependencies are added as *tracked* assertions
    so that an unsatisfiable result can be reported through the unsat core.
    """

    # Schema type spellings (compared after strip()/lower()), grouped by Z3 sort.
    # Sharing these sets keeps column encoding and bound handling in agreement.
    _INTEGER_TYPES = frozenset({"int", "integer"})
    _REAL_TYPES = frozenset({"float", "decimal", "real"})
    _STRING_TYPES = frozenset({"str", "string"})

    def __init__(self, schema: Schema, df: TableLike, *, timeout_ms: int = 200) -> None:
        """Store the schema, the table to read from, and the solver timeout.

        Args:
            schema: Source of column types, domain bounds and functional
                dependencies.
            df: Table whose cells are read each time ``verify_fix`` runs; only a
                reference is kept, no copy is made.
            timeout_ms: Value passed to the solver's ``timeout`` option.
        """
        self._schema = schema
        self._df = df
        self._timeout_ms = timeout_ms

    def verify_fix(self, proposed_fix: ProposedFix) -> VerificationResult:
        """Return whether a candidate fix satisfies schema constraints.

        Args:
            proposed_fix: Candidate whose ``fix`` carries ``operation``, ``row``,
                ``column`` and ``new_value``.

        Returns:
            ``REJECT`` for non-update operations, an out-of-range row, an unknown
            column, or an unsatisfiable constraint set (with the unsat core
            labels attached); ``ACCEPT`` when the solver reports ``sat``;
            ``UNKNOWN`` when a column type, cell value or domain bound cannot be
            encoded, or when the solver does not reach a decision.
        """
        fix = proposed_fix.fix
        if fix.operation != "update":
            return self._reject("Only cell updates are supported by the verifier.")

        row = fix.row
        column = fix.column
        if row < 0 or row >= row_count(self._df):
            return self._reject(f"Row {row} is out of bounds for the input file.")
        if column not in column_names(self._df):
            return self._reject(f"Column '{column}' does not exist in the input file.")

        # Only dependencies that mention the edited column can be affected.
        relevant_fds = tuple(
            fd
            for fd in self._schema.functional_dependencies
            if column == fd.dependent or column in fd.determinant
        )
        relevant_columns = {column}
        for fd in relevant_fds:
            relevant_columns.update(fd.determinant)
            relevant_columns.add(fd.dependent)

        try:
            encodings = {
                name: self._build_column_encoding(name) for name in sorted(relevant_columns)
            }
        except ValueError as exc:
            return self._unknown(str(exc))

        solver = Solver()
        solver.set(timeout=self._timeout_ms, unsat_core=True)

        # Un-encodable cell values and bounds on unsupported column types both
        # surface as ValueError and mean "cannot decide", not "violates".
        try:
            self._add_value_assignments(solver, encodings, proposed_fix)
            for bound in self._schema.domain_bounds_for(column):
                self._track_domain_bound(solver, encodings[column], proposed_fix, bound)
        except ValueError as exc:
            return self._unknown(str(exc))

        for fd in relevant_fds:
            self._track_fd_constraint(solver, encodings, proposed_fix, fd)

        result = solver.check()
        if result == sat:
            return VerificationResult(
                verdict=VerificationVerdict.ACCEPT,
                reason="The candidate fix satisfied all tracked verifier constraints.",
            )
        if result == unsat:
            unsat_core = tuple(str(label) for label in solver.unsat_core())
            return VerificationResult(
                verdict=VerificationVerdict.REJECT,
                reason=explain_unsat_core(unsat_core, self._schema),
                unsat_core=unsat_core,
            )
        if result == unknown:
            return self._unknown(f"Solver returned unknown: {solver.reason_unknown()}")
        return self._unknown("Solver returned an unrecognized status.")

    @staticmethod
    def _reject(reason: str) -> VerificationResult:
        """Build a REJECT result that carries no unsat core."""
        return VerificationResult(verdict=VerificationVerdict.REJECT, reason=reason)

    @staticmethod
    def _unknown(reason: str) -> VerificationResult:
        """Build an UNKNOWN result with the given explanation."""
        return VerificationResult(verdict=VerificationVerdict.UNKNOWN, reason=reason)

    @staticmethod
    def _int_value(raw: Any) -> Any:
        """Convert a raw cell value to a Z3 integer literal."""
        return IntVal(int(raw))

    @staticmethod
    def _real_value(raw: Any) -> Any:
        """Convert a raw cell value to a Z3 real literal.

        Raises:
            ValueError: If the value is not a number or is NaN/infinite; such
                values have no Z3 numeral, and raising here lets the caller
                report the cell as un-encodable.
        """
        number = float(raw)
        # NaN is the only float unequal to itself; infinities have infinite magnitude.
        if number != number or abs(number) == float("inf"):
            raise ValueError(f"non-finite value {number!r} cannot be encoded as a real")
        return RealVal(str(number))

    @staticmethod
    def _string_value(raw: Any) -> Any:
        """Convert a raw cell value to a Z3 string literal."""
        return StringVal(str(raw))

    def _build_column_encoding(self, column: str) -> _ColumnEncoding:
        """Create the Z3 function and literal factory for one column.

        A column with no schema type is treated as ``"str"``.

        Raises:
            ValueError: If the schema type is not one of the supported spellings.
        """
        column_type = (self._schema.column_type(column) or "str").strip().lower()
        if column_type in self._INTEGER_TYPES:
            value_sort, value_factory = IntSort(), self._int_value
        elif column_type in self._REAL_TYPES:
            value_sort, value_factory = RealSort(), self._real_value
        elif column_type in self._STRING_TYPES:
            value_sort, value_factory = StringSort(), self._string_value
        else:
            raise ValueError(f"Unsupported schema type '{column_type}' for column '{column}'.")
        # The column name is used verbatim: rewriting characters (e.g. ' ' -> '_')
        # could give two distinct columns the same function symbol and make
        # their cell assignments contradict each other.
        return _ColumnEncoding(
            name=column,
            column_type=column_type,
            function=Function(f"col_{column}", IntSort(), value_sort),
            value_factory=value_factory,
        )

    def _add_value_assignments(
        self,
        solver: Solver,
        encodings: dict[str, _ColumnEncoding],
        proposed_fix: ProposedFix,
    ) -> None:
        """Assert every cell of every encoded column, with the fix applied.

        These facts are added untracked, so they never appear in the unsat core.

        Raises:
            ValueError: If a cell value cannot be converted to the column's type.
        """
        fix = proposed_fix.fix
        total_rows = row_count(self._df)
        for column, encoding in encodings.items():
            for index in range(total_rows):
                raw_value = cell_value(self._df, index, column)
                if index == fix.row and column == fix.column:
                    # Evaluate the table as it would look after the update.
                    raw_value = fix.new_value
                try:
                    z3_value = encoding.value_factory(raw_value)
                except (TypeError, ValueError) as exc:
                    raise ValueError(
                        f"Could not encode value '{raw_value}' for column '{column}' "
                        f"as type '{encoding.column_type}'."
                    ) from exc
                solver.add(encoding.function(IntVal(index)) == z3_value)

    def _bound_threshold(self, encoding: _ColumnEncoding, raw_bound: Any) -> Any:
        """Return the Z3 numeral for a domain bound without losing precision.

        Integer columns get an integer numeral only when the bound is a whole
        number. A fractional bound is kept as a real numeral so that, for
        example, ``min=0.5`` is not truncated to ``0``.
        """
        if encoding.column_type in self._INTEGER_TYPES:
            try:
                whole = int(raw_bound)
            except (TypeError, ValueError, OverflowError):
                whole = None
            if whole is not None and whole == raw_bound:
                return IntVal(whole)
        return RealVal(str(raw_bound))

    def _track_domain_bound(
        self,
        solver: Solver,
        encoding: _ColumnEncoding,
        proposed_fix: ProposedFix,
        bound: DomainBound,
    ) -> None:
        """Add tracked min/max assertions for the edited cell.

        Labels have the form ``domain::<column>::<min|max>::row::<row>``.

        Raises:
            ValueError: If a bound is set on a column that is not numeric.
        """
        if bound.min_value is None and bound.max_value is None:
            return
        if encoding.column_type not in (self._INTEGER_TYPES | self._REAL_TYPES):
            # Comparing a non-numeric term with a numeral is a sort error in the
            # solver; report it as something this verifier cannot decide.
            raise ValueError(
                f"Domain bounds are not supported for column '{encoding.name}' "
                f"of type '{encoding.column_type}'."
            )

        row = proposed_fix.fix.row
        cell = encoding.function(IntVal(row))
        if bound.min_value is not None:
            lower = self._bound_threshold(encoding, bound.min_value)
            formula = cell >= lower if bound.inclusive_min else cell > lower
            solver.assert_and_track(formula, Bool(f"domain::{bound.column}::min::row::{row}"))
        if bound.max_value is not None:
            upper = self._bound_threshold(encoding, bound.max_value)
            formula = cell <= upper if bound.inclusive_max else cell < upper
            solver.assert_and_track(formula, Bool(f"domain::{bound.column}::max::row::{row}"))

    def _track_fd_constraint(
        self,
        solver: Solver,
        encodings: dict[str, _ColumnEncoding],
        proposed_fix: ProposedFix,
        fd: FunctionalDependency,
    ) -> None:
        """Add a tracked assertion that the edited row respects one dependency.

        The assertion states: for every in-range row, if it agrees with the
        edited row on all determinant columns, it also agrees on the dependent
        column. The label has the form
        ``fd::<det1+det2+...>::<dependent>::row::<row>``.
        """
        row = proposed_fix.fix.row
        candidate_row = IntVal(row)
        other_row = Int("other_row")
        # Restrict the quantifier to real rows; outside that range the column
        # functions are unconstrained.
        in_table = And(other_row >= 0, other_row < row_count(self._df))

        same_determinant = And(
            *[
                encodings[name].function(candidate_row) == encodings[name].function(other_row)
                for name in fd.determinant
            ]
        )
        dependent_fn = encodings[fd.dependent].function
        same_dependent = dependent_fn(candidate_row) == dependent_fn(other_row)

        determinant_label = "+".join(fd.determinant)
        label = Bool(f"fd::{determinant_label}::{fd.dependent}::row::{row}")
        solver.assert_and_track(
            ForAll([other_row], Implies(in_table, Implies(same_determinant, same_dependent))),
            label,
        )
|
|
|
|
class SMTVerifier:
    """Backward-compatible facade in front of the Week 3 `SchemaToSMT` verifier."""

    def verify(
        self,
        df: TableLike,
        fixes: list[ProposedFix],
        schema: Schema | None = None,
    ) -> VerificationResult:
        """Check candidate fixes against ``df`` and return the first failure.

        Without a schema only row and column existence are checked. With a
        schema each fix goes through ``SchemaToSMT.verify_fix`` in order, and an
        accepted fix is written into a private copy of the table before the next
        one is checked. ``df`` itself is never modified by this method.
        """
        if schema is not None:
            return self._verify_against_schema(df, fixes, schema)
        return self._verify_structure(df, fixes)

    @staticmethod
    def _verify_structure(df: TableLike, fixes: list[ProposedFix]) -> VerificationResult:
        """Reject the first fix whose row or column is absent from the table."""
        total_rows = row_count(df)
        for candidate in fixes:
            target = candidate.fix
            if target.row < 0 or target.row >= total_rows:
                return VerificationResult(
                    verdict=VerificationVerdict.REJECT,
                    reason=f"Row {target.row} is out of bounds for the input file.",
                )
            if target.column not in column_names(df):
                return VerificationResult(
                    verdict=VerificationVerdict.REJECT,
                    reason=f"Column '{target.column}' does not exist in the input file.",
                )
        return VerificationResult(
            verdict=VerificationVerdict.ACCEPT,
            reason="All proposed fixes passed structural verification.",
        )

    @staticmethod
    def _verify_against_schema(
        df: TableLike,
        fixes: list[ProposedFix],
        schema: Schema,
    ) -> VerificationResult:
        """Run each fix through the SMT verifier, applying accepted ones in turn."""
        scratch = copy_table(df)
        gate = SchemaToSMT(schema, scratch)
        for candidate in fixes:
            outcome = gate.verify_fix(candidate)
            if outcome.verdict != VerificationVerdict.ACCEPT:
                # Stop at the first fix that is not accepted and report it as-is.
                return outcome
            target = candidate.fix
            set_cell_value(scratch, target.row, target.column, target.new_value)
            gate = SchemaToSMT(schema, scratch)
        return VerificationResult(
            verdict=VerificationVerdict.ACCEPT,
            reason="All proposed fixes passed the SMT verifier.",
        )
|
|