|
|
|
|
|
import builtins
|
|
|
import functools
|
|
|
import logging
|
|
|
import math
|
|
|
import operator
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
|
|
import sympy
|
|
|
|
|
|
import torch
|
|
|
import torch.fx
|
|
|
import torch.fx.traceback as fx_traceback
|
|
|
from torch._dynamo.exc import TorchDynamoException
|
|
|
from torch._dynamo.utils import dynamo_timed
|
|
|
from torch.fx.node import Argument, Target
|
|
|
from torch.utils._sympy.interp import sympy_interp
|
|
|
|
|
|
|
|
|
# Module-level logger; used below for debug, info and warning messages
# emitted while building and checking translation validation problems.
log = logging.getLogger(__name__)
|
|
|
|
|
|
try:
|
|
|
import z3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def z3str(e: z3.ExprRef) -> str:
    """Render a Z3 expression as a compact, Lisp-like prefix string.

    The expression is simplified first. Integer and rational literals are
    printed verbatim and uninterpreted constants by name. Nested '+' / '*'
    chains are flattened into a single n-ary form. A negated '=', '<=' or
    '>=' is printed with the inverse operator. Int/Real conversions are
    elided, and a conversion wrapping '/' is printed as 'idiv'.

    Raises:
        ValueError: if the simplified expression is not a function
            application.
    """
    assert z3.is_expr(e), f"unsupported expression type: {e}"

    def render_children(node: z3.ExprRef) -> list[str]:
        # Render every direct child of 'node', in order.
        return [z3str(node.arg(idx)) for idx in range(node.num_args())]

    simplified = z3.simplify(e)

    if not z3.is_app(simplified):
        raise ValueError(f"can't print Z3 expression: {simplified}")

    if z3.is_int_value(simplified) or z3.is_rational_value(simplified):
        return simplified.as_string()

    decl = simplified.decl()
    kind = decl.kind()
    head = str(decl)
    operands = render_children(simplified)

    if kind == z3.Z3_OP_POWER:
        head = "pow"
    elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
        # Flatten nested applications of the same associative operator,
        # e.g. (+ a (+ b c)) is printed as (+ a b c).
        def flatten(node):
            if z3.is_app(node) and node.decl().kind() == kind:
                flat = []
                for idx in range(node.num_args()):
                    flat.extend(flatten(node.arg(idx)))
                return flat
            return [z3str(node)]

        operands = flatten(simplified)
    elif kind == z3.Z3_OP_NOT:
        assert simplified.num_args() == 1
        inner = simplified.arg(0)
        assert z3.is_app(inner)
        inner_kind = inner.decl().kind()

        # Print Not(a op b) as (inverse-op a b) for the comparisons below.
        inverse_ops = {
            z3.Z3_OP_EQ: "!=",
            z3.Z3_OP_LE: ">",
            z3.Z3_OP_GE: "<",
        }
        if inner_kind in inverse_ops:
            head = inverse_ops[inner_kind]
            operands = render_children(inner)
    elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
        assert simplified.num_args() == 1
        inner_str = z3str(simplified.arg(0))
        # Conversions are not printed; an int conversion of a division is
        # shown as integer division.
        return ("(idiv" + inner_str[2:]) if inner_str.startswith("(/") else inner_str
    elif kind == z3.Z3_OP_UNINTERPRETED:
        assert simplified.num_args() == 0
        return str(decl)

    rendered = " ".join([head, *operands])
    return f"({rendered.rstrip()})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bitwise_op(bitwise_func, bool_func):
    """Build a Z3 implementation of a bitwise operator (as a method).

    If ``bool_func`` is given and every argument is a Z3 boolean, the
    boolean connective is applied directly. Otherwise each integer argument
    is converted to a 64-bit bit-vector, the arguments are combined with
    ``bitwise_func``, and the result is converted back to an integer.

    The result bit-vector is read as a signed (two's complement) value.
    That matches Python/int64 semantics for negative results such as
    ``-1 | 0 == -1`` and ``-8 >> 1 == -4``. ``z3.BV2Int``'s default unsigned
    reading would yield values near 2**64 instead.
    """

    @functools.wraps(bitwise_func)
    def wrapper(self, *args):
        if bool_func is not None and all(
            isinstance(arg, z3.BoolRef) for arg in args
        ):
            return bool_func(*args)

        wrapped_args = tuple(z3.Int2BV(a, 64) for a in args)
        return z3.BV2Int(bitwise_func(*wrapped_args), is_signed=True)

    return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class _Z3Ops:
    """Z3 implementations of the arithmetic and bitwise operations found in
    FX graphs and SymPy expressions, following Python semantics.

    Operations that are only defined on part of their domain (division, pow,
    sqrt) register the corresponding side condition on ``validator`` through
    ``add_assertion``.
    """

    validator: "TranslationValidator"

    @staticmethod
    def to_real(x: z3.ArithRef) -> z3.ArithRef:
        """Return ``x`` as a Real-sorted term (unchanged if already real)."""
        return x if x.is_real() else z3.ToReal(x)

    @staticmethod
    def to_int(x: z3.ArithRef) -> z3.ArithRef:
        """Return ``x`` as an Int-sorted term (unchanged if already integer).

        ``z3.ToInt`` rounds towards negative infinity.
        """
        return x if x.is_int() else z3.ToInt(x)

    def sym_sum(
        self, args: Union[list[z3.ArithRef], tuple[z3.ArithRef, ...]]
    ) -> z3.ArithRef:
        # Start from a Z3 zero so that an empty sum still produces a Z3 term;
        # plain sum() would return the Python int 0.
        return sum(args, z3.IntVal(0))

    def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
        """Real division; records that the denominator must be non-zero."""
        self.validator.add_assertion(denominator != 0)
        return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator)

    def floor(self, number: z3.ArithRef) -> z3.ArithRef:
        # Int terms are already integral; Real ones go through z3.ToInt.
        return _Z3Ops.to_int(number)

    def floordiv(
        self, numerator: z3.ArithRef, denominator: z3.ArithRef
    ) -> z3.ArithRef:
        """Python '//': floor of the real quotient.

        The result is Real-sorted when either operand is real, like Python
        returning a float for float operands.
        """
        cast_result_to_real = numerator.is_real() or denominator.is_real()
        result = _Z3Ops.to_int(self.div(numerator, denominator))
        return _Z3Ops.to_real(result) if cast_result_to_real else result

    def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
        """Python ``math.ceil``; the result is always Int-sorted.

        When ``number`` is already integral, the result is ``floor(number)``,
        which equals ``number``, rather than ``number`` itself. This keeps a
        Real input from turning the whole If-term Real-sorted.
        """
        return z3.If(
            self.floor(number) < number,
            self.floor(number + 1),
            self.floor(number),
        )

    def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
        """Round towards zero: floor for non-negative values, ceil otherwise."""
        return z3.If(number >= 0, self.floor(number), self.ceil(number))

    def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
        return z3.If(a > b, a, b)

    def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
        return z3.If(a < b, a, b)

    def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
        """Python '%': the result has the sign of ``q``, since it is derived
        from floor division."""
        return p - self.floordiv(p, q) * q

    def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
        # Restrict the domain to base != 0 or exp > 0. NOTE(review): this is
        # stricter than Python (0 ** 0 == 1), presumably because Z3 leaves
        # 0 ** exp unspecified for exp <= 0.
        self.validator.add_assertion(z3.Or(base != 0, exp > 0))
        return base**exp

    def sqrt(self, number: z3.ArithRef) -> z3.ArithRef:
        """Square root as ``number ** 0.5``; requires ``number >= 0``."""
        # Compute the power over the reals, even for Int inputs.
        number = _Z3Ops.to_real(number)
        # Only defined for non-negative inputs.
        self.validator.add_assertion(number >= 0)
        return number**0.5

    def abs(self, number: z3.ArithRef) -> z3.ArithRef:
        return z3.Abs(number)

    def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef:
        """Python builtin ``round``: round half to even.

        By default this is "round half up", ``floor(x + 0.5)``. Values with
        ``x mod 2 == 0.5`` (..., -1.5, 0.5, 2.5, ...) are exactly halfway
        above an even integer, so they use "round half down",
        ``ceil(x - 0.5)``, instead.
        """
        return z3.If(
            self.mod(number, z3.IntVal(2)) == 0.5,
            self.ceil(number - 0.5),
            self.floor(number + 0.5),
        )

    bitwise_and = _bitwise_op(operator.and_, z3.And)
    bitwise_or = _bitwise_op(operator.or_, z3.Or)
    lshift = _bitwise_op(operator.lshift, None)
    rshift = _bitwise_op(operator.rshift, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def z3op(op: Callable, validator: "TranslationValidator") -> Callable:
    """Translate an FX node target into a callable that works on Z3 terms.

    Every returned callable first "lifts" its Python scalar arguments (bool,
    int, float and their SymPy counterparts) into Z3 values. Targets with a
    dedicated entry in ``replacement_map`` use that Z3 implementation; any
    other target is lifted as-is and applied to the Z3 terms directly.
    ``torch._assert`` is returned unchanged.

    Args:
        op: the FX node target to translate.
        validator: receives side-condition assertions (e.g. non-zero
            divisors) emitted by the Z3 implementations.
    """
    # Ops whose integer arguments must be interpreted as booleans.
    boolean_ops = {operator.not_}
    as_bool = op in boolean_ops

    # Lifts the function into 'z3.ExprRef' domain.
    def lift(func):
        def wrap(a) -> z3.ExprRef:
            if isinstance(a, (z3.ArithRef, z3.BoolRef)):
                return a
            # bool must be tested before int, since bool subclasses int.
            if isinstance(a, bool) or (as_bool and isinstance(a, int)):
                return z3.BoolVal(bool(a))
            if isinstance(a, (int, sympy.Integer)):
                return z3.IntVal(int(a))
            if isinstance(a, (float, sympy.Float)):
                return z3.RealVal(float(a))
            raise ValueError(f"can't lift type: {type(a)}")

        @functools.wraps(func)
        def wrapper(*args):
            # A single list/tuple argument (e.g. for torch.sym_sum) is lifted
            # element-wise and passed on as one tuple.
            if len(args) == 1 and isinstance(args[0], (list, tuple)):
                wrapped_args = (tuple(wrap(a) for a in args[0]),)
            else:
                wrapped_args = tuple(wrap(a) for a in args)
            return func(*wrapped_args)

        return wrapper

    ops = _Z3Ops(validator)
    replacement_map = {
        # Operator module.
        operator.not_: lift(z3.Not),
        operator.and_: lift(ops.bitwise_and),
        operator.or_: lift(ops.bitwise_or),
        operator.lshift: lift(ops.lshift),
        operator.rshift: lift(ops.rshift),
        operator.floordiv: lift(ops.floordiv),
        operator.truediv: lift(ops.div),
        operator.mod: lift(ops.mod),
        operator.abs: lift(ops.abs),
        builtins.round: lift(ops.round_to_int),
        # Math module.
        math.ceil: lift(ops.ceil),
        math.floor: lift(ops.floor),
        math.trunc: lift(ops.trunc),
        # Torch module.
        torch.sym_float: lift(ops.to_real),
        torch.sym_max: lift(ops.max),
        torch.sym_min: lift(ops.min),
        torch.sym_sum: lift(ops.sym_sum),
        # The condition is a Z3 boolean, usually symbolic. A Python
        # conditional ('t if b else f') would call bool() on it, which z3
        # rejects for symbolic terms, so build an if-then-else term.
        torch.sym_ite: lift(z3.If),
        torch._sym_sqrt: lift(ops.sqrt),
        # Not lifted: the interpreter handles torch._assert itself.
        torch._assert: torch._assert,
    }
    return replacement_map[op] if op in replacement_map else lift(op)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PopulateValidator(torch.fx.Interpreter):
    """Runs a (ShapeEnv) FX graph over Z3 terms to populate a validator.

    Placeholders become the validator's Z3 variables, other function calls
    are translated with ``z3op``, and the argument of each ``torch._assert``
    call is registered as a source expression.
    """

    def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"):
        # Keep the validator reachable from the node handlers below.
        self.validator = validator

        # The Interpreter runs GraphModules, so wrap the bare graph. An empty
        # dict is used as root, so the graph cannot reference module
        # attributes.
        super().__init__(
            torch.fx.GraphModule(root={}, graph=graph),
            garbage_collect_values=True,
        )

    def placeholder(
        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
    ) -> Any:
        # Each placeholder's meta holds the SymPy symbol it stands for; map it
        # to the corresponding Z3 variable.
        current_meta = fx_traceback.get_current_meta()
        return self.validator.z3var(current_meta["symbol"])

    def call_function(
        self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
    ) -> Any:
        if target == torch._assert:
            # Assertions are not executed; their condition is recorded as a
            # source expression instead.
            assert len(args) == 1, (
                f"expected 1 argument on assertion. Got: {len(args)} "
            )
            self.validator.add_source_expr(args[0])
            return None

        # Any other target runs through its Z3 translation.
        return super().call_function(z3op(target, self.validator), args, kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SympyToZ3:
    """Translate SymPy expressions into Z3 terms via ``sympy_interp``.

    This object is passed to ``sympy_interp`` as the analysis, which
    presumably looks up handlers by operation name (e.g. ``floordiv``,
    ``ceil_to_int``). Simple handlers are resolved by ``__getattr__``. Z3
    variables for free symbols come from ``validator.symbols``.
    """

    # Handlers implemented directly by the same-named 'operator' function.
    OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"}

    def __init__(
        self,
        validator: "TranslationValidator",
    ) -> None:
        self._validator = validator
        self._ops = _Z3Ops(self._validator)

    def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
        # Only int64, double and bool constants are supported.
        if dtype is torch.int64:
            return z3.IntVal(int(value))
        if dtype is torch.double:
            return z3.RealVal(float(value))
        if dtype is torch.bool:
            return z3.BoolVal(bool(value))
        raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")

    def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
        if dtype == torch.float64:
            # to_real leaves Real terms untouched; z3.ToReal itself only
            # accepts Int terms.
            return _Z3Ops.to_real(x)
        raise NotImplementedError(f"to_dtype {dtype} NYI")

    def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
        # Round towards zero. z3.ToInt alone rounds towards -inf, which is
        # wrong for negative non-integral values (trunc(-1.5) == -1). The
        # final to_int only fixes the sort; the value is already integral.
        return _Z3Ops.to_int(self._ops.trunc(x))

    def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
        # Make sure the (integral) result is Int-sorted.
        return _Z3Ops.to_int(self._ops.round_to_int(x))

    def int_truediv(
        self, numerator: z3.ArithRef, denominator: z3.ArithRef
    ) -> z3.ArithRef:
        return self._ops.div(numerator, denominator)

    def truediv(
        self, numerator: z3.ArithRef, denominator: z3.ArithRef
    ) -> z3.ArithRef:
        return self._ops.div(numerator, denominator)

    def floordiv(
        self, numerator: z3.ArithRef, denominator: z3.ArithRef
    ) -> z3.ArithRef:
        return self._ops.floordiv(numerator, denominator)

    def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
        # 'div' is mapped to floor division.
        return self._ops.floordiv(numerator, denominator)

    def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
        return self._ops.pow(base, exp)

    def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
        return self._ops.pow(base, exp)

    def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
        return self._ops.mod(p, q)

    def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
        # Make sure the (integral) result is Int-sorted.
        return _Z3Ops.to_int(self._ops.ceil(x))

    def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef:
        return self._ops.floor(x)

    def __getattr__(self, name: str) -> Any:
        # Private/dunder lookups (e.g. '_ops' before __init__ ran, or copy
        # probing '__setstate__') must not reach the code below, which reads
        # self._ops and would recurse forever.
        if name.startswith("_"):
            raise AttributeError(f"unhandled operator: {name}")

        REPLACEMENT = {
            "and_": z3.And,
            "or_": z3.Or,
            "not_": z3.Not,
            "bitwise_and": self._ops.bitwise_and,
            "bitwise_or": self._ops.bitwise_or,
            "lshift": self._ops.lshift,
            "rshift": self._ops.rshift,
            "floor": self._ops.floor,
            "ceil": self._ops.ceil,
            "minimum": self._ops.min,
            "maximum": self._ops.max,
        }

        if name in REPLACEMENT:
            return REPLACEMENT[name]
        if name in self.OPERATOR_HANDLES:
            return getattr(operator, name)
        raise AttributeError(f"unhandled operator: {name}")

    def run(self, expr: sympy.Basic) -> z3.ExprRef:
        """Translate ``expr`` using the validator's symbol-to-Z3 mapping."""
        return sympy_interp(self, self._validator.symbols, expr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranslationValidator:
    """Checks, with Z3, that the target guards imply the source guards.

    Source expressions come from ``torch._assert`` nodes (``add_source_expr``)
    and target expressions from SymPy guards (``add_target_expr``).
    Assertions (``add_assertion``) constrain the domain, e.g. non-zero
    divisors. ``validate`` searches for an assignment satisfying every
    assertion and target expression while violating some source expression.
    If one exists, it raises ValidationException.
    """

    def __init__(self) -> None:
        log.debug("new instance")

        # Z3 variable created for each SymPy symbol (see add_var).
        self.symbols: dict[sympy.Symbol, z3.ExprRef] = {}

        # Boolean conditions of the source (torch._assert arguments).
        self._source_exprs: set[z3.BoolRef] = set()

        # Boolean conditions of the target (translated SymPy guards), plus
        # 'var > 0' for integer symbols SymPy knows to be positive.
        self._target_exprs: set[z3.BoolRef] = set()

        # Side conditions always added to the solver.
        self._assertions: set[z3.BoolRef] = set()

    def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef:
        """Return the existing Z3 variable for ``symbol`` (asserts it exists)."""
        assert symbol in self.symbols, f"Z3 variable not found for: {symbol}"
        return self.symbols[symbol]

    def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef:
        """Return the Z3 variable for ``symbol``, creating it on first use.

        ``type`` selects the Z3 sort: int -> Int, float -> Real, bool -> Bool.

        Raises:
            RuntimeError: for any other ``type``.
        """
        if symbol in self.symbols:
            return self.symbols[symbol]

        log.debug("new variable: %s (%s)", symbol.name, type.__name__)

        if type is int:
            var = z3.Int(symbol.name)

            # Carry SymPy's positivity assumption over to Z3.
            if symbol.is_positive:
                self._target_exprs.add(var > 0)
        elif type is float:
            var = z3.Real(symbol.name)
        elif type is bool:
            var = z3.Bool(symbol.name)
        else:
            raise RuntimeError(f"unsupported type for Z3 variable: {type}")

        self.symbols[symbol] = var
        return var

    def _check_freesymbols(self, e: sympy.Basic) -> None:
        """Assert every free symbol of ``e`` already has a Z3 variable."""
        for s in e.free_symbols:
            assert isinstance(s, sympy.Symbol)
            # z3var asserts the symbol is known.
            self.z3var(s)

    def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef:
        """Translate SymPy ``e`` to Z3, asserting the result is boolean."""
        z3expr = SympyToZ3(self).run(e)
        assert isinstance(z3expr, z3.BoolRef), (
            f"expected boolean expression. Got: {z3expr}"
        )
        return z3expr

    def add_source_expr(self, e: z3.BoolRef) -> None:
        if e not in self._source_exprs:
            log.debug("add source guard: %s", z3str(e))
        self._source_exprs.add(e)

    def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None:
        self._check_freesymbols(e)
        z3expr = self.to_z3_boolean_expr(e)
        # Compare the translated term: the set holds Z3 terms, so checking the
        # SymPy expression would always report it as new.
        if z3expr not in self._target_exprs:
            log.debug("add target guard: %s", z3str(z3expr))
        self._target_exprs.add(z3expr)

    def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
        """Add a side condition, given as a Z3 boolean or a SymPy expression."""
        if isinstance(e, sympy.Basic):
            self._check_freesymbols(e)
            ref = self.to_z3_boolean_expr(e)
        else:
            ref = e
        assert isinstance(ref, z3.BoolRef)
        if ref not in self._assertions:
            log.debug("add assertion: %s", z3str(ref))
        self._assertions.add(ref)

    def validate(self) -> None:
        """Run ``_validate`` under dynamo_timed instrumentation."""
        with dynamo_timed("TranslationValidator.validate"):
            return self._validate()

    def _validate(self) -> None:
        if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
            # Nothing to compare without both sides.
            return None

        # QF_NRA: quantifier-free nonlinear real arithmetic.
        solver = z3.SolverFor("QF_NRA")
        solver.set(timeout=translation_validation_timeout())

        for assertion in self._assertions:
            solver.add(assertion)

        # Look for a counterexample to "targets imply sources": all target
        # expressions hold, but at least one source expression is false.
        solver.add(z3.Not(z3.And(*self._source_exprs)))
        solver.add(*self._target_exprs)

        log.debug("translation validation: start")
        r = solver.check()
        if r == z3.sat:
            model = solver.model()
            raise ValidationException(
                model,
                self._assertions,
                self._target_exprs,
                failed_source_exprs=[
                    # model_completion=True gives unconstrained variables a
                    # value, so evaluation yields a concrete boolean that
                    # 'not' can handle instead of a symbolic term.
                    inp
                    for inp in self._source_exprs
                    if not model.evaluate(inp, model_completion=True)
                ],
            )
        elif r == z3.unknown:
            # Z3 could not decide (e.g. the timeout expired). This is only
            # reported, not treated as a failure.
            log.warning(
                "translation validation: could not validate: got z3.unknown"
            )
        else:
            # No counterexample exists: the target implies the source.
            assert r == z3.unsat
            log.debug("translation validation: success")
|
|
|
|
|
|
except ImportError:
    # z3 is optional. Without it, only the config helpers and exception
    # types are exported.
    _HAS_Z3 = False

    __all__ = [
        "translation_validation_enabled",
        "translation_validation_timeout",
        "ValidationException",
        "BisectValidationException",
    ]

else:
    # z3 was imported: also export the Z3-based translation validation
    # machinery.
    _HAS_Z3 = True

    __all__ = [
        "z3str",
        "z3op",
        "PopulateValidator",
        "SympyToZ3",
        "TranslationValidator",
        "translation_validation_enabled",
        "translation_validation_timeout",
        "ValidationException",
        "BisectValidationException",
    ]
|
|
|
|
|
|
from torch.fx.experimental import _config as config
|
|
|
|
|
|
|
|
|
def translation_validation_enabled() -> bool:
    """Return whether translation validation should run.

    It runs only when z3 is importable and ``config.translation_validation``
    is set. Requesting it without z3 installed fails the assertion in
    ``_assert_z3_installed_if_tv_set``.
    """
    # Fail loudly if validation is requested but z3 is missing.
    _assert_z3_installed_if_tv_set()
    if not _HAS_Z3:
        return _HAS_Z3
    return config.translation_validation
|
|
|
|
|
|
|
|
|
def translation_validation_timeout() -> int:
    """Return the solver timeout from ``config.translation_validation_timeout``."""
    timeout = config.translation_validation_timeout
    return timeout
|
|
|
|
|
|
|
|
|
def _assert_z3_installed_if_tv_set():
    """Assert that z3 is available whenever translation validation is on."""
    # Equivalent to '_HAS_Z3 or not config.translation_validation'; config is
    # only consulted when z3 is missing.
    assert not (not _HAS_Z3 and config.translation_validation), (
        "translation validation requires Z3 package. Please, either install "
        "z3-solver or disable translation validation."
    )
|
|
|
|
|
|
|
|
|
class ValidationException(TorchDynamoException):
    """Raised when Z3 finds a counterexample during translation validation.

    ``msg`` is a fixed summary. ``details`` lists, each section sorted, the
    Z3 model, the assertions, the target expressions and the source
    expressions the model falsifies.
    """

    def __init__(self, model, assertions, target_exprs, failed_source_exprs):
        assert _HAS_Z3

        def bulleted(lines) -> str:
            # One " ==> "-prefixed line per item.
            return "\n".join(f" ==> {line}" for line in lines)

        model_str = bulleted(sorted(f"{sym}: {model[sym]}" for sym in model))
        assertions_str = bulleted(sorted(z3str(a) for a in assertions))
        target_exprs_str = bulleted(sorted(z3str(t) for t in target_exprs))
        failed_source_exprs_str = bulleted(
            sorted(z3str(s) for s in failed_source_exprs)
        )

        self.msg = "translation validation failed."
        sections = (
            ("Model", model_str),
            ("Assertions", assertions_str),
            ("Target Expressions", target_exprs_str),
            ("Failed Source Expressions", failed_source_exprs_str),
        )
        self.details = "\n\n".join(f"{title}:\n{body}" for title, body in sections)

    def __str__(self):
        return f"{self.msg}\n\n{self.details}"
|
|
|
|
|
|
|
|
|
class BisectValidationException(TorchDynamoException):
    """A validation failure pinned (by bisection) to a single traced node.

    ``details`` shows the failing node followed by the ``details`` of the
    underlying ValidationException.
    """

    def __init__(self, validation_exc, expr, failed_action, traced_node):
        self.msg = f"translation validation failed when {failed_action}: {expr}"
        node_text = traced_node.format_node()
        self.details = (
            "Failure occurred while running node:\n"
            f"{node_text}\n"
            "\n"
            f"{validation_exc.details}"
        )

    def __str__(self):
        return f"{self.msg}\n\n{self.details}"
|
|
|
|
|
|
|
|
|
|
|
|
# Check at import time: fail immediately if translation validation is
# enabled in the config but z3 could not be imported.
_assert_z3_installed_if_tv_set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bisect(shape_env):
    """Locate the ShapeEnv event at which translation validation first fails.

    Validation is first run on the current state of ``shape_env``. If it
    passes, this only logs and returns. If it fails and bisection is not
    possible (events were not recorded, or
    ``config.translation_validation_no_bisect`` is set), that
    ValidationException is re-raised. Otherwise the ``torch._assert`` nodes
    of ``shape_env.graph`` are binary-searched. For each probed node, the
    recorded events up to the node's own event are replayed into a fresh
    ShapeEnv and validated.

    Raises:
        ValidationException: when validation fails and bisection is skipped.
        BisectValidationException: describing the first failing
            evaluate_expr / runtime-assert event.
    """
    from torch.fx.experimental.recording import (
        FakeTensorMeta,
        replay_shape_env_events,
        ShapeEnvEvent,
    )
    from torch.fx.experimental.symbolic_shapes import (
        CURRENT_NODE_KEY,
        ShapeEnv,
        SHAPEENV_EVENT_KEY,
    )

    events = shape_env.events

    # Return the recorded event that produced 'node'; its index into
    # 'events' is stored in node.meta[SHAPEENV_EVENT_KEY].
    def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent:
        assert SHAPEENV_EVENT_KEY in node.meta
        return events[node.meta[SHAPEENV_EVENT_KEY]]

    # Rebuild 'fake' so that its symbolic values are bound to 'shape_env'.
    # Plain ints are returned as-is, and SymInt/SymFloat get their node
    # rebound with with_shape_env. A FakeTensorMeta is rebuilt from its
    # converted sizes, strides and storage offset, keeping is_nested.
    def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any:
        if isinstance(fake, int):
            return fake
        if isinstance(fake, torch.SymInt):
            return torch.SymInt(fake.node.with_shape_env(shape_env))
        if isinstance(fake, torch.SymFloat):
            return torch.SymFloat(fake.node.with_shape_env(shape_env))
        assert isinstance(fake, FakeTensorMeta)
        return FakeTensorMeta(
            tuple(new_with_shape_env(shape_env, s) for s in fake.size()),
            tuple(new_with_shape_env(shape_env, s) for s in fake.stride()),
            new_with_shape_env(shape_env, fake.storage_offset()),
            fake.is_nested,
        )

    # Run shape_env.produce_guards over the tracked fakes (rebound to
    # 'shape_env'). Return the ValidationException it raised, or None if it
    # completed without one.
    def check_shapeenv_fails(
        shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]]
    ) -> Optional[ValidationException]:
        assert tracked_fakes is not None
        try:
            shape_env.produce_guards(
                [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes],
                [a.source for a in tracked_fakes],
                input_contexts=[a.symbolic_context for a in tracked_fakes],
            )
            return None
        except ValidationException as e:
            return e

    # Replay every recorded event up to and including the one that produced
    # 'node' into a fresh ShapeEnv, and check whether validation fails there.
    def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]:
        number = node.meta[SHAPEENV_EVENT_KEY]
        # Reconstruct the ShapeEnv as of event 'number'.
        shape_env = replay_shape_env_events(events[: number + 1])
        shape_env.graph.lint()
        return check_shapeenv_fails(shape_env, events[number].tracked_fakes)

    last_exception = check_shapeenv_fails(
        shape_env, shape_env._snapshot_tracked_fakes()
    )

    if not last_exception:
        # The current state validates: there is nothing to bisect.
        log.info("translation validation succeeded: no errors found.")
        return

    if not shape_env.should_record_events or config.translation_validation_no_bisect:
        # Bisection needs the recorded events, or it was disabled via config.
        # Report the failure without locating it.
        raise last_exception

    # Cache of check_node_fails results, keyed by index into 'assert_nodes'.
    exception = {}

    # Only torch._assert nodes are candidates for the failing point.
    assert_nodes = [
        node for node in shape_env.graph.nodes if node.target == torch._assert
    ]

    # Binary search for the first failing assert node. The last node is
    # checked up front and expected to fail. NOTE(review): the search assumes
    # that once validation fails for some event prefix, it also fails for
    # every longer prefix.
    left, mid, right = 0, 0, len(assert_nodes) - 1
    exception[right] = check_node_fails(assert_nodes[right])

    while left < right:
        mid = (left + right) // 2

        node = assert_nodes[mid]
        log.debug("bisecting at %s: %s", mid, get_node_event(node))

        # Does validation already fail when replaying up to this node?
        exception[mid] = check_node_fails(node)

        if exception[mid]:
            right = mid
        else:
            left = mid + 1

    # 'left' ends on an index that was checked (the initial 'right' or a
    # failing 'mid'); it must actually have failed.
    assert left in exception and isinstance(exception[left], ValidationException)

    node = assert_nodes[left]
    event = get_node_event(node)

    if event.is_evaluate_expr():
        failed_action = "evaluating"
    else:
        assert event.is_defer_runtime_assert(), f"unexpected event type: {event}"
        failed_action = "adding runtime assert"

    # The SymPy expression being evaluated/asserted is the event's second
    # positional argument.
    args = event.args
    assert args is not None
    assert len(args) >= 2, (
        f"bisecting expects {event.name} to have at least 2 positional arguments. "
        f"Got: {len(args)}"
    )
    assert isinstance(args[1], sympy.Basic), (
        f"bisecting expects {event.name} to have a SymPy expression as its second argument. "
        f"Got: {type(args[1])}"
    )

    raise BisectValidationException(
        exception[left],
        expr=args[1],
        failed_action=failed_action,
        traced_node=node.meta[CURRENT_NODE_KEY],
    )
|
|
|
|