Praneshrajan15's picture
Deploy DataForge playground API
791c076 verified
"""Z3-backed candidate verifier for Week 3 repairs."""
from __future__ import annotations
import enum
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, Field
from z3 import ( # type: ignore[import-untyped]
And,
Bool,
ForAll,
Function,
Implies,
Int,
IntSort,
IntVal,
RealSort,
RealVal,
Solver,
StringSort,
StringVal,
sat,
unknown,
unsat,
)
from dataforge.repairers.base import ProposedFix
from dataforge.table import (
TableLike,
cell_value,
column_names,
copy_table,
row_count,
set_cell_value,
)
from dataforge.verifier.explain import explain_unsat_core
from dataforge.verifier.schema import DomainBound, FunctionalDependency, Schema
# Callable stored as `_ColumnEncoding.function`: invoked with a Z3 row-index
# expression and returns the Z3 expression for that column at that row.
Z3ExprFactory = Callable[[Any], Any]
# Callable stored as `_ColumnEncoding.value_factory`: converts one raw cell
# value into a Z3 constant of the column's sort.
# NOTE(review): the alias declares a `str` argument, but callers pass whatever
# `cell_value` returns — confirm that raw cells are always strings.
Z3ValueFactory = Callable[[str], Any]
class VerificationVerdict(enum.Enum):
    """Three-way verdict produced by the verifier gate."""

    # The candidate fix may be applied.
    ACCEPT = "accept"
    # The candidate fix must not be applied.
    REJECT = "reject"
    # The verifier could not reach a decision either way.
    UNKNOWN = "unknown"
class VerificationResult(BaseModel):
    """Immutable outcome of running the Week 3 verifier gate on a candidate."""

    # Instances are frozen: fields cannot be reassigned after construction.
    model_config = {"frozen": True}

    # Overall decision for the candidate fix.
    verdict: VerificationVerdict
    # Non-empty human-readable explanation of the verdict.
    reason: str = Field(min_length=1)
    # Labels of the tracked constraints in the solver's unsat core; empty
    # unless the caller supplies them.
    unsat_core: tuple[str, ...] = Field(default_factory=tuple)
@dataclass(frozen=True)
class _ColumnEncoding:
    """Immutable bundle describing how one table column is represented in Z3."""

    # Column name as it appears in the table.
    name: str
    # Normalised schema type name for the column.
    column_type: str
    # Callable mapping a Z3 row-index expression to the column's value there.
    function: Z3ExprFactory
    # Callable converting a raw cell value into a Z3 constant.
    value_factory: Z3ValueFactory
class SchemaToSMT:
    """Compile candidate-local constraints from a schema and working dataframe.

    For one proposed cell update the verifier encodes the touched column, plus
    every column linked to it through a functional dependency, as Z3 functions
    from row index to value. Each row is pinned to its current value (with the
    candidate value substituted into the target cell); the column's domain
    bounds and the relevant functional dependencies are then asserted as
    tracked constraints so that an unsat core can name the violated rule.
    """

    # Normalised (stripped, lower-cased) schema type names grouped by Z3 sort.
    _INTEGER_TYPES = frozenset({"int", "integer"})
    _REAL_TYPES = frozenset({"float", "decimal", "real"})
    _STRING_TYPES = frozenset({"str", "string"})

    def __init__(self, schema: Schema, df: TableLike, *, timeout_ms: int = 200) -> None:
        """Store the schema, the table to check against, and the solver timeout.

        Args:
            schema: Source of column types, domain bounds and functional
                dependencies.
            df: Table whose current cell values are encoded for the solver.
            timeout_ms: Value passed to the solver's ``timeout`` option.
        """
        self._schema = schema
        self._df = df
        self._timeout_ms = timeout_ms

    def verify_fix(self, proposed_fix: ProposedFix) -> VerificationResult:
        """Return whether a candidate fix satisfies schema constraints.

        Args:
            proposed_fix: Candidate whose ``fix`` describes a single cell update.

        Returns:
            ``REJECT`` for non-update operations, out-of-range rows, unknown
            columns, or an unsatisfiable constraint set (with the unsat core);
            ``ACCEPT`` when the solver reports sat; ``UNKNOWN`` when a column
            type, a cell value or a domain bound cannot be encoded, or when the
            solver does not reach a verdict.
        """
        fix = proposed_fix.fix
        if fix.operation != "update":
            return VerificationResult(
                verdict=VerificationVerdict.REJECT,
                reason="Only cell updates are supported by the verifier.",
            )
        row = fix.row
        column = fix.column
        if row < 0 or row >= row_count(self._df):
            return VerificationResult(
                verdict=VerificationVerdict.REJECT,
                reason=f"Row {row} is out of bounds for the input file.",
            )
        if column not in column_names(self._df):
            return VerificationResult(
                verdict=VerificationVerdict.REJECT,
                reason=f"Column '{column}' does not exist in the input file.",
            )

        # Only dependencies that mention the edited column can be affected.
        relevant_fds = tuple(
            fd
            for fd in self._schema.functional_dependencies
            if column == fd.dependent or column in fd.determinant
        )
        relevant_columns = {column}
        for fd in relevant_fds:
            relevant_columns.update(fd.determinant)
            relevant_columns.add(fd.dependent)

        try:
            encodings = {
                name: self._build_column_encoding(name) for name in sorted(relevant_columns)
            }
        except ValueError as exc:
            return VerificationResult(
                verdict=VerificationVerdict.UNKNOWN,
                reason=str(exc),
            )

        solver = Solver()
        solver.set(timeout=self._timeout_ms, unsat_core=True)
        try:
            self._add_value_assignments(solver, encodings, proposed_fix)
            # Bounds on a column without a numeric encoding raise ValueError and
            # are reported as UNKNOWN rather than escaping as a solver error.
            for bound in self._schema.domain_bounds_for(column):
                self._track_domain_bound(solver, encodings[column], proposed_fix, bound)
        except ValueError as exc:
            return VerificationResult(
                verdict=VerificationVerdict.UNKNOWN,
                reason=str(exc),
            )
        for fd in relevant_fds:
            self._track_fd_constraint(solver, encodings, proposed_fix, fd)

        result = solver.check()
        if result == sat:
            return VerificationResult(
                verdict=VerificationVerdict.ACCEPT,
                reason="The candidate fix satisfied all tracked verifier constraints.",
            )
        if result == unsat:
            unsat_core = tuple(str(label) for label in solver.unsat_core())
            return VerificationResult(
                verdict=VerificationVerdict.REJECT,
                reason=explain_unsat_core(unsat_core, self._schema),
                unsat_core=unsat_core,
            )
        if result == unknown:
            return VerificationResult(
                verdict=VerificationVerdict.UNKNOWN,
                reason=f"Solver returned unknown: {solver.reason_unknown()}",
            )
        return VerificationResult(
            verdict=VerificationVerdict.UNKNOWN,
            reason="Solver returned an unrecognized status.",
        )

    @staticmethod
    def _real_value(raw: Any) -> Any:
        """Convert a raw cell value into a Z3 real constant.

        Raises:
            ValueError: If ``raw`` cannot be parsed as a float, or parses to
                NaN or an infinity, which have no real-number numeral.
            TypeError: If ``raw`` is of a type ``float`` does not accept.
        """
        import math

        number = float(raw)
        if not math.isfinite(number):
            raise ValueError(f"Non-finite value {number!r} cannot be encoded as a real.")
        return RealVal(str(number))

    def _build_column_encoding(self, column: str) -> _ColumnEncoding:
        """Build the Z3 function and value converter for ``column``.

        The schema type is normalised (stripped, lower-cased) and defaults to
        ``"str"`` when the schema reports no type for the column.

        Raises:
            ValueError: If the schema type has no supported Z3 sort.
        """
        column_type = (self._schema.column_type(column) or "str").strip().lower()
        # Use the column name verbatim so that distinct columns (for example
        # "a b" and "a_b") never share one Z3 function symbol.
        function_name = f"col_{column}"
        if column_type in self._INTEGER_TYPES:
            return _ColumnEncoding(
                name=column,
                column_type=column_type,
                function=Function(function_name, IntSort(), IntSort()),
                value_factory=lambda raw: IntVal(int(raw)),
            )
        if column_type in self._REAL_TYPES:
            return _ColumnEncoding(
                name=column,
                column_type=column_type,
                function=Function(function_name, IntSort(), RealSort()),
                value_factory=self._real_value,
            )
        if column_type in self._STRING_TYPES:
            return _ColumnEncoding(
                name=column,
                column_type=column_type,
                function=Function(function_name, IntSort(), StringSort()),
                value_factory=lambda raw: StringVal(str(raw)),
            )
        raise ValueError(f"Unsupported schema type '{column_type}' for column '{column}'.")

    def _add_value_assignments(
        self,
        solver: Solver,
        encodings: dict[str, _ColumnEncoding],
        proposed_fix: ProposedFix,
    ) -> None:
        """Pin every encoded column to its per-row value in the table.

        The target cell of ``proposed_fix`` is pinned to the proposed new value
        instead of its current value. These assertions are added untracked, so
        they never appear in an unsat core.

        Raises:
            ValueError: If any cell value cannot be converted to the column's
                Z3 sort.
        """
        total_rows = row_count(self._df)
        target_row = proposed_fix.fix.row
        target_column = proposed_fix.fix.column
        for column, encoding in encodings.items():
            for index in range(total_rows):
                if index == target_row and column == target_column:
                    raw_value = proposed_fix.fix.new_value
                else:
                    raw_value = cell_value(self._df, index, column)
                try:
                    z3_value = encoding.value_factory(raw_value)
                except (TypeError, ValueError) as exc:
                    raise ValueError(
                        f"Could not encode value '{raw_value}' for column '{column}' "
                        f"as type '{encoding.column_type}'."
                    ) from exc
                solver.add(encoding.function(IntVal(index)) == z3_value)

    def _track_domain_bound(
        self,
        solver: Solver,
        encoding: _ColumnEncoding,
        proposed_fix: ProposedFix,
        bound: DomainBound,
    ) -> None:
        """Assert the min/max limits of ``bound`` on the candidate row.

        Each limit is tracked under its own ``domain::...`` label. Thresholds
        are always encoded as exact reals: Z3 promotes an integer-sorted column
        expression to real for the comparison, so a fractional bound on an
        integer column is compared exactly instead of being truncated.

        Raises:
            ValueError: If the column is not encoded with a numeric sort.
        """
        if encoding.column_type not in self._INTEGER_TYPES | self._REAL_TYPES:
            raise ValueError(
                f"Domain bounds are not supported for column '{encoding.name}' "
                f"of type '{encoding.column_type}'."
            )
        row = proposed_fix.fix.row
        row_expr = encoding.function(IntVal(row))
        if bound.min_value is not None:
            label = Bool(f"domain::{bound.column}::min::row::{row}")
            threshold = RealVal(str(bound.min_value))
            formula = row_expr >= threshold if bound.inclusive_min else row_expr > threshold
            solver.assert_and_track(formula, label)
        if bound.max_value is not None:
            label = Bool(f"domain::{bound.column}::max::row::{row}")
            threshold = RealVal(str(bound.max_value))
            formula = row_expr <= threshold if bound.inclusive_max else row_expr < threshold
            solver.assert_and_track(formula, label)

    def _track_fd_constraint(
        self,
        solver: Solver,
        encodings: dict[str, _ColumnEncoding],
        proposed_fix: ProposedFix,
        fd: FunctionalDependency,
    ) -> None:
        """Assert that the candidate row respects ``fd`` against every row.

        The constraint states that any in-range row agreeing with the candidate
        row on all determinant columns must also agree on the dependent column.
        It is tracked under a single ``fd::...`` label.
        """
        # Use a universally-quantified implication over all valid other rows.
        other_row = Int("other_row")
        bounds_guard = And(other_row >= 0, other_row < row_count(self._df))
        candidate_row = IntVal(proposed_fix.fix.row)
        determinant_equal = And(
            *[
                encodings[column].function(candidate_row) == encodings[column].function(other_row)
                for column in fd.determinant
            ]
        )
        dependent = encodings[fd.dependent].function
        dependent_equal = dependent(candidate_row) == dependent(other_row)
        determinant_label = "+".join(fd.determinant)
        label = Bool(f"fd::{determinant_label}::{fd.dependent}::row::{proposed_fix.fix.row}")
        solver.assert_and_track(
            ForAll([other_row], Implies(bounds_guard, Implies(determinant_equal, dependent_equal))),
            label,
        )
class SMTVerifier:
    """Compatibility wrapper over the Week 3 `SchemaToSMT` verifier."""

    def verify(
        self,
        df: TableLike,
        fixes: list[ProposedFix],
        schema: Schema | None = None,
    ) -> VerificationResult:
        """Check a sequence of candidate fixes against ``df``.

        With a schema, fixes are verified one at a time on a private copy of
        the table, and each accepted fix is applied to that copy before the
        next one is checked; the first non-accepting result is returned as-is.
        Without a schema, only row and column existence are checked.
        """
        if schema is not None:
            scratch = copy_table(df)
            for candidate in fixes:
                outcome = SchemaToSMT(schema, scratch).verify_fix(candidate)
                if outcome.verdict is not VerificationVerdict.ACCEPT:
                    return outcome
                # Later fixes must be judged against the already-repaired table.
                set_cell_value(
                    scratch, candidate.fix.row, candidate.fix.column, candidate.fix.new_value
                )
            return VerificationResult(
                verdict=VerificationVerdict.ACCEPT,
                reason="All proposed fixes passed the SMT verifier.",
            )

        # Structural-only path: no schema means no SMT constraints to check.
        total_rows = row_count(df)
        for candidate in fixes:
            target = candidate.fix
            if not 0 <= target.row < total_rows:
                return VerificationResult(
                    verdict=VerificationVerdict.REJECT,
                    reason=f"Row {target.row} is out of bounds for the input file.",
                )
            if target.column not in column_names(df):
                return VerificationResult(
                    verdict=VerificationVerdict.REJECT,
                    reason=f"Column '{target.column}' does not exist in the input file.",
                )
        return VerificationResult(
            verdict=VerificationVerdict.ACCEPT,
            reason="All proposed fixes passed structural verification.",
        )