diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c78c804235c93faa920e293e20d539373c7da3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/exported_program.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20844ff5d57c2c0e700b03b2ff9bd28594fedee0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de6055e3924a4713003767b24def76a43bf5480 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf37517a9e3cd55ad92b4f56bfbace302df96153 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2762105611f0b5194f9a96f693347a813c568f47 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6db640e16e38e436116a3bbad1ed171a2ba14d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69d063ceb1aa5086c501644b68ef397cfc862eb3 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55934b0b5c2bc5db73752a7e42b74217437f6e82 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e991d6bd07091c7c7304659a24075186c27893e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44b4d1ccb35de5e6ad3dc5850f0b17c5bc9f2c80 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6247e5afaecb1699c7f3a1b8992387cf31c6a986 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cc8918ab2d1ecfcb77e608a1cd55c3f9cfa94ca Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..7eedd1498c94be9d09922cd33c0191ae5bb354d0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py 
@@ -0,0 +1,231 @@ +import math +import operator +import traceback +from functools import partial +from typing import Callable, Dict, List, NamedTuple, Set + +import sympy + +import torch +import torch.fx +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, ProxyValue, PassResult +from torch.utils._sympy.value_ranges import ValueRanges +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + +__all__ = ["InputDim"] + + +class InputDim(NamedTuple): + input_name: str + dim: int + + +def _convert_to_int(val): + # Convert simple sympy Integers into concrete int + if val == sympy.oo: + return math.inf + if val == -sympy.oo: + return -math.inf + if isinstance(val, sympy.Integer): + return int(val) + raise RuntimeError( + "Export constraints cannot be non-integer expressions" + ) + + +def _convert_range_to_int(range: ValueRanges): + assert isinstance(range, ValueRanges) + min_val = _convert_to_int(range.lower) + max_val = _convert_to_int(range.upper) + return min_val, max_val + + +class _AddRuntimeAssertionsForInlineConstraintsPass(_ExportPassBaseDeprecatedDoNotUse): + def __init__( + self, + range_constraints: Dict[sympy.Symbol, ValueRanges], + ): + super().__init__() + self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints + self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set() + self.counter = 0 + + def _assert_range_constraint(self, proxy, lower, upper, assert_msg): + if lower > -math.inf: + self._insert_assert_async(operator.ge, proxy, lower, assert_msg) + + if upper < math.inf: + self._insert_assert_async(operator.le, proxy, upper, assert_msg) + + def _insert_assert_async(self, operator, lower, upper, assert_msg): + """ + Inserts assert_async call_function nodes in the graph. This function is + called **during** the interpreter-based pass. + """ + self.counter += 1 + cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata()) + cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata()) + super().call_operator( + torch.ops.aten._assert_async.msg, + (cmp_tensor, assert_msg), + {}, + self._create_dummy_node_metadata(), + ) + + def call_operator(self, op, args, kwargs, meta) -> ProxyValue: + ret = super().call_operator(op, args, kwargs, meta) + if "val" not in meta: + return ret + + val = meta["val"] + + # In general, we may have to deal the case such as: ret[1].shape[0]. + # We need first find out what symbols require assertion, then we need to follow the path + # from ret to the symbol, construct the proxies along the way and construct the messages + # piece-wise at the same time. + # + # We use post-order traversal to collect all the proxies callbacks needed, construct + # the error message callbacks, and at the top-level traversal tree we execute all the callbacks. + # We need the callbacks because, in order to call the function to create a proxy for shape[0], we + # need the proxy for shape, which further requires the proxy for ret[1], etc. 
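        # Illustrative walk-through (hypothetical values, not part of the upstream
        # source): suppose `ret` is a tensor of shape (u0, 3) where the unbacked
        # symbol u0 carries the range constraint ValueRanges(1, 5).  The traversal
        # below queues one range callback for u0, wraps it in a sym_size callback
        # for dim 0, and the top-level loop then emits roughly:
        #
        #   %sym_size  = aten.sym_size.int(%ret, 0)
        #   %ge        = operator.ge(%sym_size, 1)
        #   %ge_t      = aten.scalar_tensor.default(%ge)
        #   %assert_ge = aten._assert_async.msg(%ge_t, "<ret>.shape[0] is outside of inline constraint [1, 5].")
        #
        # plus the symmetric operator.le(%sym_size, 5) sequence for the upper bound.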
+ def add_assertions(val): + call_backs: List[Callable] = [] + messages: List[str] = [] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + symbol = val.node.expr + if symbol in self.existing_inline_assertions: + return call_backs, messages + if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): + if symbol in self._asserts_generated_unbacked_symbols: + return call_backs, messages + # We only care about unbacked symints for these inline + # constraints, which are prefixed with 'u' + constraint = self.range_constraints[symbol] + min_val, max_val = _convert_range_to_int(constraint) + assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." + call_backs.append( + partial(self._assert_range_constraint, lower=min_val, upper=max_val) + ) + messages.append(assert_msg) + self._asserts_generated_unbacked_symbols.add(symbol) + + elif isinstance(val, torch.Tensor): + for i, sym in enumerate(val.shape): + cbs, msgs = add_assertions(sym) + for cb, msg in zip(cbs, msgs): + def sym_size_cb(proxy, assert_msg, dim): + dim_proxy = super( + _AddRuntimeAssertionsForInlineConstraintsPass, + self + ).call_operator( + torch.ops.aten.sym_size.int, + (proxy, dim), + {}, + self._create_dummy_node_metadata(), + ) + cb(proxy=dim_proxy, assert_msg=assert_msg) + call_backs.append(partial(sym_size_cb, dim=i)) + messages.append(f".shape[{i}]" + msg) + return call_backs, messages + + callbacks, messages = add_assertions(val) + for cb, msg in zip(callbacks, messages): + cb(proxy=ret, assert_msg=f"{ret.node}" + msg) + return ret + + def call(self, graph_module): + self.existing_inline_assertions = _get_existing_inline_assertions( + graph_module, self.range_constraints + ) + + # Add runtime asserts for inline constraints + val = super().call(graph_module) + + # Sometimes this pass would return a wrong graph where we have mismatched + # node names in signature. Before we fix it, let's just skip it. + if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: + return PassResult(graph_module, False) + + # Populate the stack trace with dummy vals to respect IR + for node in val.graph_module.graph.nodes: + if not node.meta.get("stack_trace", None): + node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) + + return PassResult(val.graph_module, val.modified) + + +def _get_existing_inline_assertions( + graph_module: torch.fx.GraphModule, + range_constraints: Dict[sympy.Symbol, ValueRanges], +) -> Dict[sympy.Symbol, ValueRanges]: + existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {} + + for module in graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + # Find all the existing inline assertions. 
They will look something like: + # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {}) + # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {}) + # %scalar_tensor = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%ge,), kwargs = {}) + # %_assert_async = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor, "..."), kwargs = {}) + for node in module.graph.nodes: + if node.target != torch.ops.aten._assert_async.msg: + continue + + scalar_tensor_arg = node.args[0] + if not ( + scalar_tensor_arg.op == "call_function" and + scalar_tensor_arg.target == torch.ops.aten.scalar_tensor.default + ): + continue + + compare_arg = scalar_tensor_arg.args[0] + if not ( + compare_arg.op == "call_function" and + compare_arg.target in (operator.le, operator.ge) and + len(compare_arg.args) == 2 + ): + continue + + compare_op = compare_arg.target + maybe_symint_arg, compare_int = compare_arg.args + + # x >= 0 will sometimes be canonicalized to -x <= 0, so in some + # cases the operation before the comparison is to multiply by -1. We + # can undo the canonicalization here + if ( + maybe_symint_arg.op == "call_function" and + maybe_symint_arg.target == operator.mul and + maybe_symint_arg.args[0] == -1 + ): + maybe_symint_arg = maybe_symint_arg.args[1] + compare_op = operator.ge + compare_int = -1 * compare_int + + if not ( + "val" in maybe_symint_arg.meta and + isinstance(maybe_symint_arg.meta["val"], torch.SymInt) + ): + continue + + symint = maybe_symint_arg.meta["val"].node.expr + if not isinstance(symint, sympy.Symbol): + continue + + if symint not in range_constraints: + raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") + + found_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) + + if compare_arg.target == operator.le: + existing_inline_assertions[symint] = ValueRanges( + lower=found_range.lower, upper=compare_int + ) + elif compare_arg.target == operator.ge: + existing_inline_assertions[symint] = ValueRanges( + lower=compare_int, upper=found_range.upper + ) + + return existing_inline_assertions diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcf5adaca5b0b478db87e71633f5136b54969b2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py @@ -0,0 +1,94 @@ +import copy +from typing import Dict, Optional, Tuple, List + +import torch +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._ops import OpOverload + +aten = torch.ops.aten + +_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = { + aten.sym_constrain_range.default: aten._functional_sym_constrain_range, + aten._assert_async.msg: aten._functional_assert_async.msg, +} + + +class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Functionalize ops with side effect in graph module by replacing the op with + functional version of it. 
A new dependency token (`dep_token`) will be + created and propagated through functional ops to output. + For example: + ``` + def f(x): + sym_constrain_range(x.shape[0], min=1, max=3) + return x.add(3) + ``` + Will be transformed to: + ``` + def f(x): + dep_token0 = _make_dep_token() + dep_token1 = _functional_sym_constrain_range( + x.shape[0], min=1, max=3, dep_token=dep_token0 + ) + + return x.add(3), dep_token1 + ``` + """ + + def __init__(self) -> None: + super().__init__() + self._dep_token: Optional[ProxyValue] = None + self._next_dep_token_index: Optional[int] = None + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Early return if no non-functional assertions. + if not any( + n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS + for n in graph_module.graph.nodes + ): + return PassResult(graph_module=graph_module, modified=False) + + gm = copy.deepcopy(graph_module) + self._dep_token = None + self._next_dep_token_index = None + return super().call(gm) + + def call_operator( + self, + op: OpOverload, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: + return super().call_operator(op, args, kwargs, meta) + + if self._dep_token is None: + self._dep_token = super().call_operator( + aten._make_dep_token, + args=(), + kwargs={}, + meta=self._create_dummy_node_metadata(), + ) + self._dep_token.node.name = "dep_token0" + self._next_dep_token_index = 1 + + self._dep_token = super().call_operator( + _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op], + args=args, + kwargs={**kwargs, "dep_token": self._dep_token}, + meta=meta, + ) + assert self._next_dep_token_index is not None + self._dep_token.node.name = f"dep_token{self._next_dep_token_index}" + self._next_dep_token_index += 1 + + return self._dep_token + + def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue: + assert self._dep_token is not None + + return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..109a96d7b4bd3672660b1271b4d72e7fbb6b982f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_sym_size_ops_pass.py @@ -0,0 +1,18 @@ +from typing import Dict + +import torch + +replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = { + torch.ops.aten.sym_size: torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default, +} + + +def _replace_sym_size_ops_pass(gm: torch.fx.GraphModule): + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.target in replacements: + node.target = replacements[node.target] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py new file mode 100644 index 
0000000000000000000000000000000000000000..f32b442733eb98d49aea4d766ef0e727243ebeeb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py @@ -0,0 +1,71 @@ +from typing import Dict, Optional, Set + +import torch +from torch._ops import OpOverload, OpOverloadPacket, HigherOrderOperator +from torch._export.error import InternalError +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse + + +__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"] + + +_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = { + torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default, +} + +# TODO (tmanlaibaatar) remove this after https://github.com/pytorch/pytorch/pull/100749 +_BLACK_LISTED_OPS: Set[OpOverloadPacket] = { + torch.ops.aten.sym_size, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_numel, +} + +def is_view_op(schema: torch._C.FunctionSchema) -> bool: + if len(schema.arguments) == 0: + return False + alias_info = schema.arguments[0].alias_info + return (alias_info is not None) and (not alias_info.is_write) + + +def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]: + if is_view_op(schema) and schema.name.startswith("aten::"): + view_op_name = schema.name.split("::")[1] + view_op_overload = ( + schema.overload_name + if schema.overload_name != "" + else "default" + ) + view_copy_op_name = view_op_name + "_copy" + if not hasattr(torch.ops.aten, view_copy_op_name): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name) + + if not hasattr(view_copy_op_overload_packet, view_op_overload): + raise InternalError(f"{schema.name} is missing a view_copy variant") + + return getattr(view_copy_op_overload_packet, view_op_overload) + + return None + + +class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse): + """ + Our backend expects pure functional operators. For efficiency + purposes, we keep view ops around while functionalizing the exported + program. This pass replaces view ops with view copy ops for backends that + need AOT memory planning. 
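    Illustrative example (assuming the standard aten operator registry): the
    schema of ``aten.view.default`` marks its first argument as a non-writing
    alias, so

        get_view_copy_of_view_op(torch.ops.aten.view.default._schema)

    resolves to ``torch.ops.aten.view_copy.default``, while
    ``aten._unsafe_view.default`` is remapped directly through
    ``_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS``.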
+ """ + def call_operator(self, op, args, kwargs, meta): + if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: + return super().call_operator( + (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta + ) + + if op in _BLACK_LISTED_OPS or isinstance(op, HigherOrderOperator): + return super().call_operator(op, args, kwargs, meta) + + if view_copy_op := get_view_copy_of_view_op(op._schema): + return super().call_operator(view_copy_op, args, kwargs, meta) + + return super().call_operator(op, args, kwargs, meta) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059dd5049d700d0e12d98b480440dc68d28e4656 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63249a8efa9608d6f66674f229584a6c2343da0e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a469c8df37770125f21b18cacefda3c4498f583 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/upgrade.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6c4c0ceae1663f09fbfe1901f3cee64230809c9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/schema.yaml @@ -0,0 +1,389 @@ +# @generated by update_schema.py +# checksum<<4c9986f3aba283b1746995fff8fe7005b370c7e288adec65c03030349a4bab60>> +Argument: + kind: union + fields: + as_none: + type: Tuple[()] + as_tensor: + type: TensorArgument + as_tensors: + type: List[TensorArgument] + as_int: + type: int + as_ints: + type: List[int] + as_float: + type: float + as_floats: + type: List[float] + as_string: + type: str + as_strings: + type: List[str] + as_sym_int: + type: SymIntArgument + as_sym_ints: + type: List[SymIntArgument] + as_scalar_type: + type: ScalarType + as_memory_format: + type: MemoryFormat + as_layout: + type: Layout + as_device: + type: Device + as_bool: + type: bool + as_bools: + type: List[bool] + as_sym_bool: + type: 
SymBoolArgument + as_sym_bools: + type: List[SymBoolArgument] + as_graph: + type: GraphArgument + as_optional_tensors: + type: List[OptionalTensorArgument] + as_custom_obj: + type: CustomObjArgument + as_operator: + type: str +BufferMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str +CustomObjArgument: + kind: struct + fields: + name: + type: str + class_fqn: + type: str +Device: + kind: struct + fields: + type: + type: str + index: + type: Optional[int] + default: None +ExportedProgram: + kind: struct + fields: + graph_module: + type: GraphModule + opset_version: + type: Dict[str, int] + range_constraints: + type: Dict[str, RangeConstraint] + schema_version: + type: SchemaVersion + dialect: + type: str +GradientToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +GradientToUserInputSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +Graph: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + nodes: + type: List[Node] + tensor_values: + type: Dict[str, TensorMeta] + sym_int_values: + type: Dict[str, SymInt] + sym_bool_values: + type: Dict[str, SymBool] + is_single_tensor_return: + type: bool + default: 'False' + custom_obj_values: + type: Dict[str, CustomObjArgument] + default: '{}' +GraphArgument: + kind: struct + fields: + name: + type: str + graph: + type: Graph +GraphModule: + kind: struct + fields: + graph: + type: Graph + signature: + type: GraphSignature + module_call_graph: + type: List[ModuleCallEntry] +GraphSignature: + kind: struct + fields: + input_specs: + type: List[InputSpec] + output_specs: + type: List[OutputSpec] +InputSpec: + kind: union + fields: + user_input: + type: UserInputSpec + parameter: + type: InputToParameterSpec + buffer: + type: InputToBufferSpec + tensor_constant: + type: InputToTensorConstantSpec + custom_obj: + type: InputToCustomObjSpec +InputToBufferSpec: + kind: struct + fields: + arg: + type: TensorArgument + buffer_name: + type: str + persistent: + type: bool +InputToCustomObjSpec: + kind: struct + fields: + arg: + type: CustomObjArgument + custom_obj_name: + type: str +InputToParameterSpec: + kind: struct + fields: + arg: + type: TensorArgument + parameter_name: + type: str +InputToTensorConstantSpec: + kind: struct + fields: + arg: + type: TensorArgument + tensor_constant_name: + type: str +Layout: + kind: enum + fields: + Unknown: 0 + SparseCoo: 1 + SparseCsr: 2 + SparseCsc: 3 + SparseBsr: 4 + SparseBsc: 5 + _mkldnn: 6 + Strided: 7 +LossOutputSpec: + kind: struct + fields: + arg: + type: TensorArgument +MemoryFormat: + kind: enum + fields: + Unknown: 0 + ContiguousFormat: 1 + ChannelsLast: 2 + ChannelsLast3d: 3 + PreserveFormat: 4 +ModuleCallEntry: + kind: struct + fields: + fqn: + type: str + signature: + type: Optional[ModuleCallSignature] + default: None +ModuleCallSignature: + kind: struct + fields: + inputs: + type: List[Argument] + outputs: + type: List[Argument] + in_spec: + type: str + out_spec: + type: str +NamedArgument: + kind: struct + fields: + name: + type: str + arg: + type: Argument +Node: + kind: struct + fields: + target: + type: str + inputs: + type: List[NamedArgument] + outputs: + type: List[Argument] + metadata: + type: Dict[str, str] +OptionalTensorArgument: + kind: union + fields: + as_tensor: + type: str + as_none: + type: Tuple[()] +OutputSpec: + kind: union + fields: + user_output: + type: UserOutputSpec + loss_output: + type: LossOutputSpec + 
buffer_mutation: + type: BufferMutationSpec + gradient_to_parameter: + type: GradientToParameterSpec + gradient_to_user_input: + type: GradientToUserInputSpec + user_input_mutation: + type: UserInputMutationSpec +RangeConstraint: + kind: struct + fields: + min_val: + type: int + max_val: + type: int +ScalarType: + kind: enum + fields: + UNKNOWN: 0 + BYTE: 1 + CHAR: 2 + SHORT: 3 + INT: 4 + LONG: 5 + HALF: 6 + FLOAT: 7 + DOUBLE: 8 + COMPLEXHALF: 9 + COMPLEXFLOAT: 10 + COMPLEXDOUBLE: 11 + BOOL: 12 + BFLOAT16: 13 +SchemaVersion: + kind: struct + fields: + major: + type: int + minor: + type: int +SymBool: + kind: union + fields: + as_expr: + type: SymExpr + as_bool: + type: bool +SymBoolArgument: + kind: union + fields: + as_name: + type: str + as_bool: + type: bool +SymExpr: + kind: struct + fields: + expr_str: + type: str + hint: + type: Optional[SymExprHint] + default: None +SymExprHint: + kind: union + fields: + as_int: + type: int + as_float: + type: float + as_bool: + type: bool +SymInt: + kind: union + fields: + as_expr: + type: SymExpr + as_int: + type: int +SymIntArgument: + kind: union + fields: + as_name: + type: str + as_int: + type: int +TensorArgument: + kind: struct + fields: + name: + type: str +TensorMeta: + kind: struct + fields: + dtype: + type: ScalarType + sizes: + type: List[SymInt] + requires_grad: + type: bool + device: + type: Device + strides: + type: List[SymInt] + storage_offset: + type: SymInt + layout: + type: Layout +UserInputMutationSpec: + kind: struct + fields: + arg: + type: TensorArgument + user_input_name: + type: str +UserInputSpec: + kind: struct + fields: + arg: + type: Argument +UserOutputSpec: + kind: struct + fields: + arg: + type: Argument +SCHEMA_VERSION: +- 5 +- 1 +TREESPEC_VERSION: 1 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py new file mode 100644 index 0000000000000000000000000000000000000000..c34917f3dd074cf50e3ab2e030f9730c3d4333a9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/serde/upgrade.py @@ -0,0 +1,201 @@ +import logging +from collections import defaultdict +from typing import Tuple, Dict, Optional, List + +import torch +from torch.export import export +from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse +from torch._export.pass_infra.node_metadata import NodeMetadata +from torch._export.pass_infra.proxy_value import ProxyValue +from torch._subclasses import FakeTensor +from torch.fx.node import Target, Argument +from torch.library import Library +from torch.utils._pytree import tree_unflatten +import torch._export.exported_program as ep +import re + +lib = Library("aten", "FRAGMENT") +impl_lib = Library("aten", "IMPL") + +log = logging.getLogger(__name__) + + +def get_target_version(versioned_upgrader_name: str) -> int: + """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is + upgrading to version 4.""" + if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name): + raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid") + + return int(versioned_upgrader_name.split('_')[-1]) + 1 + + +def get_upgraders() -> Dict[str, Tuple[str, str]]: + """Getting upgraders entry map and operator version map and merge them into one dict.""" + upgraders = torch._C._get_upgraders_entry_map() + op_version_map = torch._C._get_operator_version_map() + output: 
Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] + for opname, entry_list in op_version_map.items(): + if not entry_list: + raise RuntimeError(f"Op version map has an empty entry for opname {opname}") + entry = entry_list[0] + old_schema = entry.old_schema + upgrader_name = entry.upgrader_name + upgrader_str = upgraders.get(upgrader_name, None) + if not upgrader_str: + raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}") + output[upgrader_name] = (old_schema, upgrader_str) + return output + + +class GraphModuleOpUpgrader: + """This upgrader is able to upgrade the old version of ops in a given GraphModule, if all upgraders are available. + To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass it into this upgrader. In + __init__() it does the following: + 1. parse the upgrader list and reorder for upgrading purpose. + 2. register old versions of operators as custom ops. + 3. prepare upgrader passes. + + In `upgrade()` API run these upgrader passes. + + An example of op_upgraders input: + { + "aten::div__Scalar_0_3": ( # versioned op name + "div._Scalar(self: Tensor, other: Scalar)", # old schema + ''' + def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor: # upgrader in literal string + if (self.is_floating_point() or isinstance(other, float)): + return self.true_divide_(other) + return self.divide_(other, rounding_mode='trunc') + ''', + ), + }, + + Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the + original TorchScript upgrader). + """ + + class UpgraderPass(_ExportPassBaseDeprecatedDoNotUse): + def __init__(self, old_target: Target, new_target: Target): + super().__init__() + self.old_target = old_target + self.new_target = new_target + + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op == self.old_target: + return super().call_operator(self.new_target, args, kwargs, meta) + return super().call_operator(op, args, kwargs, meta) + + def __init__( + self, + compiler_opset_version: Optional[Dict[str, int]] = None, + model_opset_version: Optional[Dict[str, int]] = None, + op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None, + ): + self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders + self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {} + self.model_opset_version = model_opset_version if model_opset_version else {} + self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes( + self._parse_upgraders(self.op_upgraders)) + + def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]: + """Reorder op_upgraders by version number, return an ordered list of tuples, containing old op schema as well + as the upgrader function string literal.""" + # TODO(larryliu0820): Add support for custom ops + op_namespace = "aten" + if not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version: + return [] + model_ver = self.model_opset_version[op_namespace] + curr_ver = self.compiler_opset_version[op_namespace] + + # key is the target version. div__Scalar_0_3 should have a key of 4. 
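        # Illustrative example (hypothetical opset versions): for the upgrader
        # named "div__Scalar_0_3", get_target_version() returns int("3") + 1 == 4,
        # so the entry below is keyed by 4.  With model_ver == 3 and curr_ver == 5,
        # the loop that follows collects the upgraders keyed 4 and 5, in that order.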
+ versioned_upgraders: Dict[int, Tuple[str, str]] = {get_target_version(name): v for name, v in + op_upgraders.items()} + target_upgraders: List[Tuple[str, str]] = [] + # we need all upgraders from model_ver + 1 to curr_ver, inclusively + for ver in range(model_ver + 1, curr_ver + 1): + if ver in versioned_upgraders: + target_upgraders.append(versioned_upgraders[ver]) + else: + # we may be able to get away with missing upgraders, if that operator is missing from given graph + # module. + log.warning("Missing an upgrader to upgrade to version {ver}.", extra={"ver": ver}) + + return target_upgraders + + @staticmethod + def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]: + """Given a list of upgraders, loop through it from lower version to higher version and create passes for all + upgraders. se torch.Library API to register old ops. Op name will be + __. Register upgraders as CompositeImplicitAutograd kernels. For example: + + lib = Library("aten", "FRAGMENT") + lib.define(old_schema) + + impl_lib = Library("aten", "IMPL") + impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd") + + @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the + upgrader function literal text. + @:return upgrader passes, order matters + """ + + upgrader_passes = [] + + def register_old_op(name: str, schema: str, impl_str: str): + """Registers an old version operator using impl_name as old op name.""" + lib.define(schema) + try: + exec(impl_str) + except Exception as e: + raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e + impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd") + + for (schema, upgrader_str) in upgraders: + upgrader_name = upgrader_str.split('(')[0].split(' ')[-1] + op_name = schema.split('(')[0].split("::")[-1] + schema = schema.replace(op_name, upgrader_name) + try: + register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str) + except RuntimeError as e: + if "with the same name and overload name multiple times" in str(e): + print(f"Registering {upgrader_name} multiple times") + else: + raise RuntimeError from e + old_op_target = getattr(torch.ops.aten, upgrader_name).default + # for example, the operator instance of "aten::div" is torch.op.aten.div.default. We need to append the + # "default" at the end. + op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2]) + new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name) + # Note that the graph will have op names in the graph, but actually they are of old versions. + upgrader_passes.append( + GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target)) + + return upgrader_passes + + def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram: + """Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of + operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the + upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the + upgrader. 
After all passes are applied, the exported program will be upgraded to the target version.""" + if not self.upgrader_passes: + return exported_program + + args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"] + args_real_tensors = [torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg for + arg in args] + assert exported_program.call_spec.in_spec is not None + args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec) + assert kwargs == {} + + for _pass in self.upgrader_passes: + upgraded_program = exported_program._transform_do_not_use(_pass) + # NB: we have to retrace the graph_module instead of ep because of some failure. + exported_program = export(upgraded_program.module(), args, kwargs) + + return exported_program diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa575dc78c065326ab70673cac6baf4237cb0c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_prims_common/wrappers.py @@ -0,0 +1,401 @@ +import inspect +import warnings +from functools import wraps +from itertools import chain + +from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple + +import torch +import torch._prims_common as utils +from torch._prims_common import ( + CustomOutParamAnnotation, + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_flatten, tree_unflatten + + +@overload +def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + pass + + +@overload +def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType: + pass + + +@overload +def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence: + pass + + +@overload +def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None: + pass + + +# TODO: implement ref.cast with an option to enforce safe casting +def _maybe_convert_to_dtype(a, dtype): + if isinstance(a, TensorLike): + if a.dtype != dtype: + return a.to(dtype) + return a + if isinstance(a, Number): + return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type] + if isinstance(a, Sequence): + return tuple(_maybe_convert_to_dtype(x, dtype) for x in a) + # Passthrough None because some functions wrapped with type promotion + # wrapper might have optional args + if a is None: + return None + + raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!") + + +def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: + if not isinstance(a, Number): + msg = f"Found unknown type {type(a)} when trying to convert scalars!" + raise ValueError(msg) + if not utils.is_weakly_lesser_type(type(a), typ): + msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!" + raise ValueError(msg) + + return typ(a) + + +def _annotation_has_type(*, typ, annotation): + if hasattr(annotation, "__args__"): + for a in annotation.__args__: + if _annotation_has_type(typ=typ, annotation=a): + return True + return False + + return typ is annotation + + +class elementwise_type_promotion_wrapper: + """ + Adds elementwise type promotion to a Python reference implementation. 
+ + Takes two kwargs, type_promoting_args and type_promotion_kind. + + type_promoting_args must be a string Sequence specifiying the argument names of all + arguments that participate in type promotion (and should be type promoted). If the + arg specifies a Sequence-type then every element of the Sequence will participate in + type promotion. + + type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND. + See its documentation for details. + + The return_dtype will be coerced to the wrapped function's dtype arg if it is available and + not None. + + Other type promotion behavior, like validating the Python type of scalar arguments, must + be handled separately. + """ + + def __init__( + self, + *, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, + type_promoting_args: Optional[Sequence[str]] = None, + ): + self.type_promoting_arg_names = type_promoting_args + self.type_promotion_kind = type_promotion_kind + + def __call__(self, fn: Callable) -> Callable: + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + type_promoting_args = tuple( + bound.arguments[x] + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + ) + + flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) + compute_dtype, result_dtype = utils.elementwise_dtypes( + *flattened_type_promoting_args, + type_promotion_kind=self.type_promotion_kind, + ) + + promoted_args = { + x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + } + bound.arguments.update(promoted_args) + + result = fn(**bound.arguments) + + # Override the return_dtype if a dtype arg is present and not None + if "dtype" in bound.arguments: + maybe_dtype = bound.arguments["dtype"] + if maybe_dtype: # dtype cannot be None + result_dtype = maybe_dtype + + if isinstance(result, TensorLike): + return _maybe_convert_to_dtype(result, result_dtype) + if isinstance(result, Sequence): + return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result) + raise AssertionError(f"Unhandled result type: {type(result)}") + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn + + +# Returns True if resize is necessary +def _resize_output_check(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return False + if out.numel() != 0: + msg = ( + f"An output with one or more elements was resized since it had shape {str(out.shape)} " + "which does not match the required output shape {str(shape)}. " + "This behavior is deprecated, and in a future PyTorch release outputs will not " + "be resized unless they have zero elements. " + "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." 
+ ) + warnings.warn(msg) + return True + + +# TODO: handle tuples of tensors +def _maybe_resize_out(out: TensorLikeType, shape: ShapeType): + if _resize_output_check(out, shape): + return out.resize_(shape) + else: + return out + + +def _safe_copy_out( + *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False +): + # Checks same device + if copy_from.device != copy_to.device: + msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format( + copy_from.device, copy_to.device + ) + raise RuntimeError(msg) + + # Checks safe cast + if exact_dtype: + torch._check( + copy_from.dtype == copy_to.dtype, + lambda: f"Expected out tensor to have dtype {copy_from.dtype} " + f"but got {copy_to.dtype} instead", + ) + else: + torch._check( + utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype), + lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, " + "but this can't be cast because it is not safe!", + ) + + return copy_to.copy_(copy_from) + + +def out_wrapper(*out_names: str, exact_dtype: bool = False, pass_is_out: bool = False): + # The wrapped function needs to convert the output parameters to ensure + # compatibility between the Python API (which always uses "out" as the + # parameter name and may be a tuple) and the Aten API (which may have + # multiple output parameters and use different parameter names such as + # "grad_input", "indices" or "values".) + + default_out_names = ("out",) + if len(out_names) == 0: + # Use default in out name + out_names = default_out_names + + is_tensor = len(out_names) == 1 + + def _out_wrapper(fn: Callable) -> Callable: + """ + Adds the out parameter to a Python reference. + """ + out_type = ( + TensorLikeType + if is_tensor + else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))] + ) + return_type = ( + TensorLikeType + if is_tensor + else NamedTuple( + f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] + ) + ) + + sig = inspect.signature(fn) + factory_kwargs = ("device", "dtype") + is_factory_fn = all(p in sig.parameters for p in factory_kwargs) + + @wraps(fn) + def _fn(*args, out=None, **kwargs): + if is_factory_fn and out is not None: + for k in factory_kwargs: + out_attr = getattr(out, k) + if k not in kwargs: + kwargs[k] = out_attr + if pass_is_out: + result = fn(*args, is_out=(out is not None), **kwargs) + else: + result = fn(*args, **kwargs) + assert ( + isinstance(result, TensorLike) + and is_tensor + or isinstance(result, Tuple) # type: ignore[arg-type] + and len(result) == len(out_names) + ) + if out is not None: + # Naively you might expect this assert to be true, but + # it's not: + # + # assert type(out) == type(result) + # + # The reason is that functions under this wrapper can + # get registered to the Meta dispatch key, and that + # means they can be executed in a context where tensor + # subclasses are disabled (with no_dispatch), which is a + # handy way for an is-a tensor subclass (e.g., + # FakeTensor) to have the normal meta backend create a + # meta tensor, to be wrapped once it gets returned. + # In this situation, you will get a FakeTensor as + # the output tensor, but not the result--which will + # be a normal meta tensor, but this is perfectly + # harmless. 
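                # (Descriptive note: both branches below do the same work per
                # tensor: resize the out tensor in-place if its shape does not
                # already match the result, then copy the result into it with
                # _safe_copy_out, which also validates device and dtype
                # compatibility; the tuple branch additionally checks that `out`
                # and `result` have the same length.)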
+ if is_tensor: + assert isinstance(out, TensorLike) + # These two operations are done in-place + _maybe_resize_out(out, result.shape) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + assert isinstance(out, Tuple) # type: ignore[arg-type] + torch._check_type( + len(out) == len(result), + lambda: f"expected tuple of {len(result)} elements but got {len(out)}", + ) + for r, o in zip(result, out): + # These two operations are done in-place + _maybe_resize_out(o, r.shape) + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + out = result + # mypy does not see through the definition of out_type given that it's in a different scope + return out if is_tensor else return_type(*out) # type: ignore[operator] + + out_param = inspect.Parameter( + "out", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_type, + ) + # Mark that the function now returns a tuple + assert isinstance(sig.return_annotation, str) or sig.return_annotation in ( + sig.empty, + out_type, + ) + params = chain(sig.parameters.values(), (out_param,)) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=return_type # type: ignore[arg-type] + ) + + _fn.__annotations__ = fn.__annotations__ + _fn.__annotations__["out"] = out_type + _fn.__annotations__["return"] = return_type + + # In the special case of having a single tensor out parameter with a + # name other than out, add a special annotation to name the parameter + if is_tensor and out_names != default_out_names: + _fn.__annotations__[CustomOutParamAnnotation] = out_names[0] + + # Add an indicator attribute that can be used in special cases + # where having a function wrapped by `out_wrapper` is not desirable e.g. + # jit + _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] + + return _fn + + return _out_wrapper + + +def _maybe_remove_out_wrapper(fn: Callable): + return inspect.unwrap( + fn, + stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"), + ) + + +def backwards_not_supported(prim): + def redispatch_prim(args, kwargs): + with torch._C._AutoDispatchBelowAutograd(): + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + return prim(*args, **kwargs) + + class BackwardsNotSupported(torch.autograd.Function): + @staticmethod + def forward(ctx, args_spec, *flat_args): + args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type] + return redispatch_prim(args, kwargs) + + @staticmethod + def backward(ctx, *args): + raise RuntimeError("backwards not supported on prim") + + @wraps(prim) + def _autograd_impl(*args, **kwargs): + flat_args, args_spec = tree_flatten((args, kwargs)) + if torch.is_grad_enabled() and any( + a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) + ): + # TODO: There is a subtle bug here: prims like copy_to + # return their input argument after mutating it; and custom + # autograd function will incorrectly turn the result into + # a view which will fail test_python_ref_executor tests. + # At the moment, we sidestep this by observing that the + # unit tests don't ever try to run the executor with + # autograd, so we don't exercise the buggy case, but if + # you ever want to feed autograd through this, be aware + # of it! 
We need a way of properly implementing autograd + # for mutating operations in Python to do this. + return BackwardsNotSupported.apply(args_spec, *flat_args) + else: + return redispatch_prim(args, kwargs) + + return _autograd_impl + + +# TODO: when tracing this will add torch tensors and not TensorMeta objects +# to the trace -- we should fix this by adding a tracing context and NumberMeta classes +# TODO: this wrapper is currently untested +def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable: + """ + Allows unary operators that accept tensors to work with Python numbers. + """ + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], Number): + dtype = utils.type_to_dtype(type(args[0])) + args_ = list(args) + args_[0] = torch.tensor(args[0], dtype=dtype) + result = fn(*args_, **kwargs) + assert isinstance(result, torch.Tensor) + return result.item() + + return fn(*args, **kwargs) + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..b26dbadc006823cdbb2fb9f9cfce537336abf842 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite_fx.py @@ -0,0 +1,1025 @@ +""" +This module contains tooling to compare weights and activations +across models. Example usage:: + + import copy + import torch + import torch.ao.quantization.quantize_fx as quantize_fx + import torch.ao.ns._numeric_suite_fx as ns + + m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval() + mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig}) + # We convert a copy because we need the original prepared model + # to be available for comparisons, and `quantize_fx.convert_fx` is inplace. + mq = quantize_fx.convert_fx(copy.deepcopy(mp)) + + # + # Comparing weights + # + + # extract weight pairs + weight_comparison = ns.extract_weights('a', mp, 'b', mq) + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, + 'sqnr') + + # weight_comparison contains the weights from `mp` and `mq` stored + # in pairs, and can be used for further analysis. + + + # + # Comparing activations, with error propagation + # + + # add loggers + mp_ns, mq_ns = ns.add_loggers( + 'a', copy.deepcopy(mp), + 'b', copy.deepcopy(mq), + ns.OutputLogger) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_ns(datum) + mq_ns(datum) + + # extract intermediate activations + act_comparison = ns.extract_logger_info( + mp_ns, mq_ns, ns.OutputLogger, 'b') + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, + 'sqnr') + + # act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. 
+ + # + # Comparing activations, without error propagation + # + + # create shadow model + mp_shadows_mq = ns.add_shadow_loggers( + 'a', copy.deepcopy(mp), + 'b', copy.deepcopy(mq), + ns.OutputLogger) + + # send an example datum to capture intermediate activations + datum = torch.randn(1, 1, 1, 1) + mp_shadows_mq(datum) + + # extract intermediate activations + shadow_act_comparison = ns.extract_shadow_logger_info( + mp_shadows_mq, ns.OutputLogger, 'b') + + # add SQNR for each comparison, inplace + ns.extend_logger_results_with_comparison( + shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr, + 'sqnr') + + # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored + # in pairs, and can be used for further analysis. + +""" + +import collections + +import torch +import torch.nn as nn +import torch.ao.quantization.quantize_fx as quantize_fx +from torch.fx import GraphModule +from torch.fx.graph import Node +from torch.ao.ns.fx.mappings import ( + get_base_name_to_sets_of_related_ops, +) +from torch.ao.ns.fx.graph_matcher import ( + get_matching_subgraph_pairs, + get_type_a_related_to_b, +) + +from .fx.weight_utils import ( + extract_weight_from_node, +) + +from .fx.graph_passes import ( + add_loggers_to_model, + create_a_shadows_b, +) + +from .fx.utils import ( + rekey_logger_info_on_node_name_of_model, + maybe_add_missing_fqns, + get_target_type_str, +) + +from .fx.ns_types import ( + NSSingleResultValuesType, + NSResultsType, + NSNodeTargetType, +) +from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.fx.match_utils import _find_matches +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr +from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization import QConfigMapping +from torch.ao.ns.fx.n_shadows_utils import ( + OutputProp, + _get_dedup_subgraphs, + SHADOW_WRAPPER_NODE_NAME_PREFIX, + group_results_by_subgraph, + create_results_comparison, + print_n_shadows_summary, + create_n_transformed_and_logged_copies_of_subgraph, + create_add_loggers_graph, + extract_weight_comparison, +) +from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping + +from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type + +RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + +class OutputLogger(nn.Module): + """ + Base class for capturing intermediate values. + """ + stats: List[torch.Tensor] + stats_rnn: List[RNNReturnType] + + # Mark as impure so that calls to it will not be removed during DCE. + _is_impure = True + + def __init__( + self, + ref_node_name: str, + prev_node_name: str, + model_name: str, + ref_name: str, + prev_node_target_type: str, + ref_node_target_type: str, + results_type: str, + index_within_arg: int, + index_of_arg: int, + fqn: Optional[str], + qconfig_str: Optional[str] = '', + ): + super().__init__() + self.stats: List[torch.Tensor] = [] + self.stats_rnn: List[RNNReturnType] = [] + + # name of the node which was responsible for adding this logger + # Note: + # - if we are logging node outputs, this is the same as prev_node_name + # - if we are logging node inputs, this is the name of the node + # whose input this logger is logging. 
+ # + # example, where logger1 is logging input of op1 and logger2 is logging + # the output of op1: + # + # x1 -> logger1 -> op1 -> logger2 -> x2 + # + # in this example, + # - logger1's prev_node_name is x1 and ref_node_name is op1 + # - logger2's prev_node_name is op1 and ref_node_name is op1 + self.ref_node_name = ref_node_name + # name of the node whose output this Logger is capturing + self.prev_node_name = prev_node_name + + # name of the model from which the node originated from + self.model_name = model_name + # reference name, used to match loggers from separate models + # to each other + self.ref_name = ref_name + # type of the target of the node whose output this logger is logging + self.prev_node_target_type = prev_node_target_type + # type of the target of the node which was responsible for adding this + # logger + self.ref_node_target_type = ref_node_target_type + # what kind of values are inside of stats + self.results_type = results_type + # index of this node within the arg of the input/output node + # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 + self.index_within_arg = index_within_arg + # index of this node within the args of the input/output node + # for example, in add(x1, x2), x2 would have index_of_arg == 1 + self.index_of_arg = index_of_arg + # fully qualified name + self.fqn = fqn + # if loggers are added before prepare_fx, but we do not want + # collect results of calibration, only results after convert_fx + # so, we add a flag to control whether this logger collects data + self.enabled = True + # string representation of qconfig + self.qconfig_str = qconfig_str + # this can be turned off to reduce memory usage during calibration + self.save_activations = True + + # Note: cannot annotate the type of x because TorchScript does not support + # the Union type. + def forward(self, x): + """ + """ # blank docblock to make autodoc happy + # TODO(future PR): consider designing this better, as the difference + # between these two flags is subtle and not obvious. 
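        # For reference (see the attribute comments in __init__): `enabled`
        # switches the logger off entirely, while `save_activations` only
        # controls whether raw activations are stored, and can be turned off
        # to reduce memory usage during calibration.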
+ if not self.enabled: + return x + if not self.save_activations: + return x + # TODO(future PR): consider refactoring this to better reuse the parent + # class + if isinstance(x, torch.Tensor): + self.stats.append(x.detach()) + elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2: + new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach())) + self.stats_rnn.append(new_res) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != 'training') and not k.startswith('_') + } + return f"OutputLogger({clean_dict})" + + +class OutputComparisonLogger(OutputLogger): + """ + Same as OutputLogger, but also requires the original activation + in order to calculate the comparison at calibration time + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO(future PR): make the comparison function configurable + self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr + self.comparison_fn_name = 'sqnr' + # precalculated comparisons of logger output versus reference + self.comparisons = [] + # precalculated comparisons function + + def forward(self, x, x_ref): + """ + """ # blank docblock to make autodoc happy + if not self.enabled: + return x + assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported' + if self.save_activations: + # save the activation, for debugging + self.stats.append(x.detach()) + # save the comparison + self.comparisons.append(self.comparison_fn(x, x_ref)) + return x + + def __repr__(self): + clean_dict = { + k: v + for k, v in self.__dict__.items() + # skip nn.Module keys + if (k != 'training') and not k.startswith('_') + } + return f"OutputComparisonLogger({clean_dict})" + + +class NSTracer(quantize_fx.QuantizationTracer): + """ + Just like a regular FX quantization tracer, but treats observers and fake_quantize + modules as leaf modules. 
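    In practice this means observer and fake quantize submodules show up as
    single `call_module` nodes in the traced graph instead of being traced
    through.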
+ """ + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: + """ + """ # blank docblock to make autodoc happy + if isinstance(m, torch.ao.quantization.ObserverBase): + return True + elif isinstance(m, torch.ao.quantization.FakeQuantizeBase): + return True + return super().is_leaf_module(m, module_qualified_name) + + +def _extract_weights_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument: List[Tuple[Node, str]], + results: NSResultsType, + op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None, +) -> None: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model") + for node, ref_name in nodes_and_names_to_instrument: + res_type = NSSingleResultValuesType.WEIGHT.value + extracted_weight = extract_weight_from_node( + node, model, op_to_type_to_weight_extraction_fn) + if extracted_weight: + if ref_name not in results: + results[ref_name] = {res_type: {}} + results[ref_name][res_type][model_name] = [extracted_weight] + + +def _extract_weights_impl( + model_name_a: str, + gm_a: GraphModule, + model_name_b: str, + gm_b: GraphModule, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None, +) -> NSResultsType: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, + unmatchable_types_map) + + # split the subgraph pairs into one data structure for each model + nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = [] + nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = [] + for match_name, match in matched_subgraph_pairs.items(): + subgraph_a, subgraph_b = match + nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name)) + nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name)) + + # populate the results, one model at a time + results: NSResultsType = {} + _extract_weights_one_model( + model_name_a, gm_a, nodes_and_names_to_instrument_a, results, + op_to_type_to_weight_extraction_fn) + _extract_weights_one_model( + model_name_b, gm_b, nodes_and_names_to_instrument_b, results, + op_to_type_to_weight_extraction_fn) + + # fill in missing fqn entries + maybe_add_missing_fqns(results) + + # rekey on names of nodes in gm_b + results = rekey_logger_info_on_node_name_of_model(results, model_name_b) + + return results + + +def extract_weights( + model_name_a: str, + model_a: nn.Module, + model_name_b: str, + model_b: nn.Module, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None, +) -> NSResultsType: + """ + Extract weights from model A and model B, and return a comparison. 
+ + Args: + model_name_a: string name of model A to use in results + model_a: model A + model_name_b: string name of model B to use in results + model_b: model B + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + op_to_type_to_weight_extraction_fn: optional override of function which extracts weight + from a type, subject to change + + Return: + NSResultsType, containing the weight comparisons + """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights") + if base_name_to_sets_of_related_ops is None: + base_name_to_sets_of_related_ops = \ + get_base_name_to_sets_of_related_ops() + type_a_related_to_b = \ + get_type_a_related_to_b(base_name_to_sets_of_related_ops) + + # TODO(future PR): expose these + skipped_module_names: List[str] = [] + skipped_module_classes: List[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope') + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope') + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _extract_weights_impl( + model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops, + unmatchable_types_map, op_to_type_to_weight_extraction_fn) + + +def _add_loggers_one_model( + model_name: str, + model: GraphModule, + nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]], + nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]], + logger_cls: Callable, +) -> nn.Module: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model") + + # TODO(future PR): do not observe nodes we do not care + # about (both fp32, denylist, etc) + node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {} + node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {} + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs: + node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type) + for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs: + node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type) + + model = add_loggers_to_model( + model, node_to_instrument_inputs_to_ref_name, + node_to_instrument_outputs_to_ref_name, logger_cls, model_name) + return model + + +def _add_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, +) -> Tuple[nn.Module, nn.Module]: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, + base_name_to_sets_of_related_ops, unmatchable_types_map) + nodes_and_names_to_instrument_inputs_a = [] + nodes_and_names_to_instrument_inputs_b = [] + nodes_and_names_to_instrument_outputs_a = 
[] + nodes_and_names_to_instrument_outputs_b = [] + for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items(): + ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a) + ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b) + # Note: for matching inputs we use start_node, such as observing + # the input of linear in linear-relu + if should_log_inputs: + nodes_and_names_to_instrument_inputs_a.append( + (subgraph_a.start_node, match_name, ref_node_type_a)) + nodes_and_names_to_instrument_inputs_b.append( + (subgraph_b.start_node, match_name, ref_node_type_b)) + # Note: for matching activations we always use end_node, + # such as observing the output of relu in linear-relu + nodes_and_names_to_instrument_outputs_a.append( + (subgraph_a.end_node, match_name, ref_node_type_a)) + nodes_and_names_to_instrument_outputs_b.append( + (subgraph_b.end_node, match_name, ref_node_type_b)) + + new_model_a = _add_loggers_one_model( + name_a, gm_a, nodes_and_names_to_instrument_inputs_a, + nodes_and_names_to_instrument_outputs_a, logger_cls) + new_model_b = _add_loggers_one_model( + name_b, gm_b, nodes_and_names_to_instrument_inputs_b, + nodes_and_names_to_instrument_outputs_b, logger_cls) + return (new_model_a, new_model_b) + + +def add_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs : bool = False, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, +) -> Tuple[nn.Module, nn.Module]: + """ + Instrument model A and model B with loggers. + + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + + Return: + Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace. 
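    Note: `should_log_inputs` (not shown in the module-level example) controls
    whether loggers are also attached to the inputs of each matched subgraph,
    in addition to its outputs. An illustrative call, reusing `mp` and `mq`
    from the module-level docstring::

        mp_ns, mq_ns = add_loggers(
            'fp32_prepared', copy.deepcopy(mp),
            'int8', copy.deepcopy(mq),
            OutputLogger, should_log_inputs=True)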
+ """ + + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers") + # TODO(future PR): expose these + skipped_module_names: List[str] = [] + skipped_module_classes: List[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope') + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope') + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_loggers_impl( + name_a, gm_a, name_b, gm_b, logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + unmatchable_types_map=unmatchable_types_map) + + +def _extract_logger_info_one_model( + model: nn.Module, + results: NSResultsType, + logger_cls: Callable, +) -> None: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model") + for gm_name, mod in model.named_modules(): + # TODO(future PR): better check when scripted + is_logger = ( + isinstance(mod, logger_cls) # type: ignore[arg-type] + or ( + isinstance(mod, torch.jit.RecursiveScriptModule) + and mod.original_name == 'OutputLogger' + ) + ) + if is_logger: + key = mod.ref_name + if key not in results: + results[key] = {} + assert mod.model_name not in results[key], \ + f"{mod.model_name} is already present in results" + if mod.results_type not in results[key]: + results[key][mod.results_type] = {} + if mod.model_name not in results[key][mod.results_type]: + results[key][mod.results_type][mod.model_name] = [] + stats_to_use = mod.stats + if len(mod.stats_rnn) > 0: + stats_to_use = mod.stats_rnn + data = { + 'type': mod.results_type, + 'values': stats_to_use, + 'ref_node_name': mod.ref_node_name, + 'ref_node_target_type': mod.ref_node_target_type, + 'prev_node_name': mod.prev_node_name, + 'prev_node_target_type': mod.prev_node_target_type, + 'index_within_arg': mod.index_within_arg, + 'index_of_arg': mod.index_of_arg, + 'fqn': mod.fqn, + 'qconfig_str': mod.qconfig_str, + } + if hasattr(mod, 'comparisons'): + data['comparisons'] = mod.comparisons + data['comparison_fn_name'] = mod.comparison_fn_name + else: + data['comparisons'] = [] + data['comparison_fn_name'] = '' + results[key][mod.results_type][mod.model_name].append(data) + # ensure the list stays sorted + results[key][mod.results_type][mod.model_name].sort( + key=lambda res: + f"{res['index_of_arg']}:{res['index_within_arg']}" + ) + + +# TODO(future PR): align on naming +# this is equivalent of just the comparison extraction part of `ns.compare_model_outputs` +def extract_logger_info( + model_a: nn.Module, + model_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in `model_a` and `model_b`, and extract the logged + information. 
+ + Args: + model_a: model A + model_b: model B + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info") + results: NSResultsType = {} + for model in (model_a, model_b): + _extract_logger_info_one_model(model, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names) + return results + + +def _add_shadow_loggers_impl( + name_a: str, + gm_a: GraphModule, + name_b: str, + gm_b: GraphModule, + logger_cls: Callable, + should_log_inputs: bool, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, +) -> nn.Module: + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl") + matched_subgraph_pairs = get_matching_subgraph_pairs( + gm_a, gm_b, base_name_to_sets_of_related_ops, + unmatchable_types_map) + gm_a_shadows_b = create_a_shadows_b( + name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls, + should_log_inputs=should_log_inputs, + node_type_to_io_type_map=node_type_to_io_type_map) + return gm_a_shadows_b + + +def add_shadow_loggers( + name_a: str, + model_a: nn.Module, + name_b: str, + model_b: nn.Module, + logger_cls: Callable, + should_log_inputs: bool = False, + base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, + unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None, +) -> nn.Module: + """ + Instrument model A and model B with shadow loggers. 
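    Unlike `add_loggers`, this returns a single combined model (model A
    shadowing model B) rather than two separately instrumented models.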
+ + Args: + name_a: string name of model A to use in results + model_a: model A + name_b: string name of model B to use in results + model_b: model B + logger_cls: class of Logger to use + should_log_inputs: whether to log inputs + base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change + unmatchable_types_map: optional override of unmatchable types, subject to change + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers") + # TODO(future PR): expose these + skipped_module_names: List[str] = [] + skipped_module_classes: List[Callable] = [] + tracer_a = NSTracer(skipped_module_names, skipped_module_classes) + tracer_b = NSTracer(skipped_module_names, skipped_module_classes) + gm_a = GraphModule(model_a, tracer_a.trace(model_a)) + maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope') + if maybe_model_a_node_name_to_scope is not None: + gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope + gm_b = GraphModule(model_b, tracer_b.trace(model_b)) + maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope') + if maybe_model_b_node_name_to_scope is not None: + gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope + return _add_shadow_loggers_impl( + name_a, gm_a, name_b, gm_b, logger_cls, + should_log_inputs=should_log_inputs, + base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops, + node_type_to_io_type_map=node_type_to_io_type_map, + unmatchable_types_map=unmatchable_types_map) + + +def extract_shadow_logger_info( + model_a_shadows_b: nn.Module, + logger_cls: Callable, + model_name_to_use_for_layer_names: str, +) -> NSResultsType: + """ + Traverse all loggers in a shadow model, and extract the logged + information. + + Args: + model_a_shadows_b: shadow model + logger_cls: class of Logger to use + model_name_to_use_for_layer_names: string name of model to use for + layer names in the output + + Return: + NSResultsType, containing the logged comparisons + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info") + results: NSResultsType = collections.defaultdict(dict) + _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls) + # fill in missing fqn entries + maybe_add_missing_fqns(results) + # rekey on the name of model b + results = rekey_logger_info_on_node_name_of_model( + results, model_name_to_use_for_layer_names) + return dict(results) + + +def extend_logger_results_with_comparison( + results: NSResultsType, + model_name_1: str, + model_name_2: str, + comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + comparison_name: str, +) -> None: + """ + Compares the logged values from `model_name_2` against the corresponding + values in `model_name_1`, using `comparison_fn`. Records the result + in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace. + + Args: + results: the result data structure from `extract_logger_info` or + `extract_shadow_logger_info`. 
+ model_name_1: string name of model 1 + model_name_2: string name of model 2 + comparison_fn: function to compare two Tensors + comparison_name: string name of model to use for + layer names in the output + """ + for results_type_to_results in results.values(): + for model_name_to_results in results_type_to_results.values(): + assert model_name_1 in model_name_to_results, \ + f"{model_name_1} not found in results" + assert model_name_2 in model_name_to_results, \ + f"{model_name_2} not found in results" + + results_1 = model_name_to_results[model_name_1] + results_2 = model_name_to_results[model_name_2] + + for result_2 in results_2: + index_within_arg_2 = result_2['index_within_arg'] + index_of_arg_2 = result_2['index_of_arg'] + # find corresponding result_1 + result_1 = None + for cur_result_1 in results_1: + index_within_arg_1 = cur_result_1['index_within_arg'] + index_of_arg_1 = cur_result_1['index_of_arg'] + if ( + (index_within_arg_1 == index_within_arg_2) and + (index_of_arg_1 == index_of_arg_2) + ): + result_1 = cur_result_1 + break + assert result_1 is not None + + values_1 = result_1['values'] + values_2 = result_2['values'] + result_2[comparison_name] = [] + for value_1, value_2 in zip(values_1, values_2): + comparison_result = comparison_fn(value_1, value_2) + result_2[comparison_name].append(comparison_result) + +def prepare_n_shadows_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_multi_mapping: QConfigMultiMapping, + backend_config: BackendConfig, + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, + custom_tracer: Any = None, +) -> GraphModule: + """ + Given a model with a graph with M ops such as + + + args_kwargs_m -> op_m -> output_m + + + And a set of N qconfigs for each op, creates a new model, with + each of the subgraph of `op_m` transformed into + + .. code:: + + |---------> op_m_n -> log_m_n + | / + args_kwargs_m ---------> op_m -> log_m_0 + + Where op_m_n is op_m wrapped in a submodule and transformed with + qconfig_n, and its inner graph looks like + + .. code:: + + args_m -------- op_m_prepared_with_qconfig_n -> out_m_n + / + kwargs_m --- + + This is useful for testing different quantization of multiple layers in + a single pass through the model. + + High level TODOs for future PRs: + * figure out a better way to name the output structure + * return a results data structure instead of printing it out + * add examples to docblocks + """ + + if custom_tracer is None: + tracer = quantize_fx.QuantizationTracer([], []) + else: + tracer = custom_tracer + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. 
+ modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = \ + get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: List[str] = [] + standalone_module_classes: List[Type] = [] + custom_module_classes: List[Type] = [] + matches = _find_matches( + mt.graph, modules, patterns, root_node_getter_mapping, + standalone_module_names, standalone_module_classes, custom_module_classes) + subgraphs_dedup: Dict[str, List[Node]] = \ + _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + # TODO(future PR): deduplicate repeating entries + list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = [] + for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list: + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope) + list_of_node_name_to_qconfig.append(node_name_to_qconfig) + + # For each region in the model, do the following: + # For each qconfig for that region, do the following: + # 1. create a copy of the region wrapped in a module + # 2. pass original args, original kwargs, and expected output to module + # 3. add an output comparison logger and hook it up to compare + # actual output to expected output + # 4. run `prepare_fx` on the module + for (subgraph_idx, (match_name, nodes_in_this_subgraph)) in \ + enumerate(subgraphs_dedup.items()): + create_n_transformed_and_logged_copies_of_subgraph( + mt, subgraph_idx, match_name, nodes_in_this_subgraph, + qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig, + custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type] + ) + + return mt + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _prepare_n_shadows_add_loggers_model( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> torch.nn.Module: + r""" + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + + This creates a model which provides logging for the following + problem: if we quantize `model` with `qconfig_mapping` and feed + the same input through both models, log the comparisons of + corresponding intermediate layers. + + The problem is solved with a single model. Specifically, we + partition `model` into N subgraphs, create a copy of each relevant + subgraph, wrap it in a module, apply the quantization API to that + module, and hook up loggers to measure the comparisons. + + Example starting graph: + + x0 -> op0 -> x1 -> op1 -> x2 + + Example config: quantize op0 to int8, do nothing to op1. + The following graph will be created: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog + + Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized + to int8, op1_0 is op1 (appearing in the graph twice), log is a logger, + and clog is a comparison logger. 
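    After calibration data has been sent through the returned model, results
    can be extracted with `extract_results_n_shadows_model` (note that the
    inserted loggers start disabled and can be enabled with
    `loggers_set_enabled`).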
+ """ + + tracer = quantize_fx.QuantizationTracer([], []) + mt = torch.fx.GraphModule(model, tracer.trace(model)) + # this is necessary to ensure logger FQNs get populated + mt._node_name_to_scope = tracer.node_name_to_scope + + # run example input propagation, we need this to call prepare_fx on + # individual subgraphs + output_prop = OutputProp(mt) + output_prop.propagate(*example_inputs) + + # Find the set of subgraphs in the original graph which we need to + # consider. + modules = dict(mt.named_modules(remove_duplicate=False)) + patterns = _get_pattern_to_quantize_handlers(backend_config) + root_node_getter_mapping = \ + get_fusion_pattern_to_root_node_getter(backend_config) + standalone_module_names: List[str] = [] + standalone_module_classes: List[Type] = [] + custom_module_classes: List[Type] = [] + matches = _find_matches( + mt.graph, modules, patterns, root_node_getter_mapping, + standalone_module_names, standalone_module_classes, custom_module_classes) + subgraphs_dedup: Dict[str, List[Node]] = \ + _get_dedup_subgraphs(matches) + + # generate node to qconfig for each subgraph + node_name_to_qconfig = _generate_node_name_to_qconfig( + mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope) + + # Now, mutate the graph to be the add_loggers graph with propagation + # error. + create_add_loggers_graph( + mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig) + + return mt + +# TODO(future PR): we should rethink the names of all the PNP APIs +def _n_shadows_compare_weights( + model: torch.nn.Module, + example_inputs: Any, + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig, +) -> NSResultsType: + """ + Note: this API is not recommended for wide usage, it is only + provided for customers who need to migrate from the `add_loggers` + API. + """ + qconfig_multi_mapping = \ + QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping]) + mp = prepare_n_shadows_model( + model, example_inputs, qconfig_multi_mapping, backend_config) + # passing inputs through the model is necessary to populate + # observers which observe weights with real values + mp(*example_inputs) + mq = convert_n_shadows_model(mp) + weight_comparison = extract_weight_comparison(mq) + return weight_comparison + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None: + """ + Sets the `enabled` setting on a `model`'s loggers + """ + for name, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.enabled = enabled + +# TODO(future PR): consider aligning API signature with other similar quantization +# functions (enable_fake_quant, etc) +def loggers_set_save_activations( + model: torch.nn.Module, + save_activations: bool, +) -> None: + """ + Sets the `save_activations` setting on a `model`'s loggers + """ + for name, child in model.named_modules(): + if isinstance(child, OutputLogger): + child.save_activations = save_activations + +def convert_n_shadows_model( + model: GraphModule, + custom_convert_fn: Optional[Callable] = None, + custom_convert_kwargs: Optional[Dict[str, Any]] = None +) -> GraphModule: + """ + Given a model from `prepare_n_shadows_model`, runs `convert_fx` + on each shadow submodule. 
+ """ + for node in model.graph.nodes: + # TODO(future PR): consider matching in a safer way than + # node name string match + if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX): + orig_mod = getattr(model, node.name) + if custom_convert_fn is None: + converted_mod = torch.ao.quantization.quantize_fx.convert_fx( + orig_mod) + else: + if custom_convert_kwargs is None: + custom_convert_kwargs = {} + converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs) + setattr(model, node.name, converted_mod) + + return model + +def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType: + """ + Extracts logger results from `model`. + """ + results: NSResultsType = {} + _extract_logger_info_one_model(model, results, OutputLogger) + return results + +def print_comparisons_n_shadows_model(results: NSResultsType) -> None: + """ + Prints a summary of extracted `results`. + """ + results_grouped = group_results_by_subgraph(results) + results_comparison = create_results_comparison(results_grouped) + print_n_shadows_summary(results_comparison) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a872056d16cf502e0e0c6943dec56a0c6bc4df --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/mappings.py @@ -0,0 +1,761 @@ +import operator + +import torch +import torch.nn as nn +import torch.nn.functional as F +toq = torch.ops.quantized + +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd +from torch.ao.quantization.backend_config import get_native_backend_config +import torch.ao.quantization.fx._lower_to_native_backend as \ + _lower_to_native_backend +import torch.ao.quantization.quantization_mappings as quantization_mappings + +from .ns_types import NSNodeTargetType + +from typing import Callable, Dict, List, Optional, Set, Tuple + + +def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: + # note: this set is modified below by items from backend_config + sets_of_related_ops: List[Set[NSNodeTargetType]] = [ + # conv modules + { + nn.Conv1d, + }, + { + nn.Conv2d, + }, + { + nn.Conv3d, + }, + # conv functionals + { + F.conv1d, + }, + { + F.conv2d, + }, + { + F.conv3d, + }, + # linear modules + { + nn.Linear, + }, + # linear functionals + { + F.linear, + }, + # average pool + { + nn.AvgPool1d, + torch.avg_pool1d, + }, + { + nn.AvgPool2d, + torch._C._nn.avg_pool2d, + }, + { + nn.AvgPool3d, + torch._C._nn.avg_pool3d, + }, + # adaptive average pool + { + nn.AdaptiveAvgPool1d, + F.adaptive_avg_pool1d, + }, + { + nn.AdaptiveAvgPool2d, + F.adaptive_avg_pool2d, + }, + { + nn.AdaptiveAvgPool3d, + F.adaptive_avg_pool3d, + }, + # LSTM + { + nn.LSTM, + }, + # add + { + torch.add, + operator.add, # x + y + }, + # cat + { + 
torch.cat, + }, + # mul + { + torch.mul, + operator.mul, + }, + # relu + { + F.relu, + nn.ReLU, + 'relu', + 'relu_', + torch.relu, + }, + # maxpool + { + nn.MaxPool1d, + F.max_pool1d, + }, + { + nn.MaxPool2d, + F.max_pool2d, + }, + { + nn.MaxPool3d, + F.max_pool3d, + }, + # sigmoid + { + torch.sigmoid, + 'sigmoid', + 'sigmoid_', + nn.Sigmoid, + F.sigmoid, + }, + # BatchNorm + { + nn.BatchNorm2d, + }, + { + nn.BatchNorm3d, + }, + # ConvTranspose + { + nn.ConvTranspose1d, + }, + { + nn.ConvTranspose2d, + }, + { + nn.ConvTranspose3d, + }, + # functional transposed conv + { + F.conv_transpose1d, + }, + { + F.conv_transpose2d, + }, + { + F.conv_transpose3d, + }, + # ELU + { + nn.ELU, + }, + # Embedding + { + nn.Embedding, + }, + # EmbeddingBag + { + nn.EmbeddingBag, + }, + # GroupNorm + { + nn.GroupNorm, + }, + # Hardswish + { + nn.Hardswish, + }, + # InstanceNorm + { + nn.InstanceNorm1d, + }, + { + nn.InstanceNorm2d, + }, + { + nn.InstanceNorm3d, + }, + # LayerNorm + { + nn.LayerNorm, + }, + # LeakyReLU + { + nn.LeakyReLU, + }, + # ReLU6 + { + nn.ReLU6, + F.relu6, + }, + # F.elu + { + F.elu, + }, + # F.hardswish + { + F.hardswish, + }, + # F.group_norm + { + F.group_norm, + }, + # F.instance_norm + { + F.instance_norm, + }, + # F.layer_norm + { + F.layer_norm, + }, + # F.leaky_relu + { + F.leaky_relu, + }, + # F.silu + { + nn.SiLU, + F.silu, + }, + # F.mish + { + nn.Mish, + F.mish, + }, + # F.tanh + { + nn.Tanh, + F.tanh, + torch.tanh, + 'tanh_', + 'tanh', + }, + # F.hardsigmoid + { + 'hardsigmoid_', + 'hardsigmoid', + F.hardsigmoid, + nn.Hardsigmoid, + }, + # F.hardtanh + { + nn.Hardtanh, + F.hardtanh, + F.hardtanh_, + }, + # floordiv + { + operator.floordiv, + }, + # unsqueeze + { + torch.unsqueeze, + }, + # stack + { + torch.stack, + }, + # squeeze + { + torch.squeeze, + }, + # sort + { + torch.sort, + }, + # repeat_interleave + { + torch.repeat_interleave, + }, + # min + { + torch.min, + }, + # mean + { + torch.mean, + }, + # max + { + torch.max, + }, + # transpose + { + torch.transpose, + }, + # flatten + { + torch.flatten, + }, + # clamp + { + torch.clamp, + }, + # chunk + { + torch.chunk, + }, + # interpolate + { + torch.nn.functional.interpolate, + }, + # dropout + { + nn.Dropout, + }, + # F.dropout + { + F.dropout, + }, + # matmul + { + torch.matmul, + }, + # Softmax + { + nn.Softmax, + }, + # PReLU + { + nn.PReLU, + nnq.PReLU, + }, + # F.prelu + { + F.prelu, + toq.prelu, + }, + # pixel shuffle + { + nn.PixelShuffle, + }, + { + F.pixel_shuffle, + }, + # pixel unshuffle + { + nn.PixelUnshuffle, + }, + { + F.pixel_unshuffle, + }, + # narrow + { + torch.narrow, + }, + ] + + # for each floating point op, add versions of the op added by + # backend_config + backend_config = get_native_backend_config() + + new_connections: List[Tuple[Callable, Callable]] = [ + # technical debt edge case + (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear), + ] + + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + + # pattern format: (c, (b, a)) + first_element = pattern + # look from the end, because pattern is in reverse order + while isinstance(first_element, (list, tuple)): + first_element = first_element[-1] + + if config.fused_module is not None: + # case 1: pattern fuses a pattern of ops into an op + # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d + new_connections.append((first_element, config.fused_module)) + + if config.qat_module is not None: + # case 2: pattern swaps a module into a QAT module + # example: nni.ConvReLU1d swapped into 
nniqat.ConvReLU1d + new_connections.append((first_element, config.qat_module)) + + if config.reference_quantized_module is not None: + # case 3: reference version of floating point module, such as + # nn.Conv2d and nnqr.Conv2d + new_connections.append((first_element, config.reference_quantized_module)) + + # + # Add reference module swaps from default lowering path + # + + for source_to_target in ( + _lower_to_native_backend.STATIC_LOWER_MODULE_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP, + _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP, + _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP, + ): + for source, target in source_to_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target)) + + for source_to_double_target in ( + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP, + _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP, + _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP, + ): + for source, (target1, target2) in source_to_double_target.items(): # type: ignore[attr-defined] + new_connections.append((source, target1)) + new_connections.append((source, target2)) + + # + # Add function swaps from default lowering path + # + + for source, (target1, target2) in \ + _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): + new_connections.append((source, target1)) + new_connections.append((source, target2)) + + for source_to_target in ( + _lower_to_native_backend.QBIN_OP_MAPPING, + _lower_to_native_backend.QBIN_RELU_OP_MAPPING, + quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, + ): + for source, target in source_to_target.items(): + new_connections.append((source, target)) + + # + # Add other swaps, ideally in the future this could be removed + # after the lowering code stops using these. 
+ # + for source_to_target in ( + quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, + ): + for source, target in source_to_target.items(): + new_connections.append((source, target)) + + + # add the new connections from backend_config + for item1, item2 in new_connections: + for set_of_related_ops in sets_of_related_ops: + if item1 in set_of_related_ops or item2 in set_of_related_ops: + set_of_related_ops.add(item1) + set_of_related_ops.add(item2) + break + + base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {} + + counter = 0 + for set_of_related_ops in sets_of_related_ops: + base_name = str(counter) + counter += 1 + base_name_to_sets_of_related_ops[base_name] = set_of_related_ops + + return base_name_to_sets_of_related_ops + + +def get_base_name_for_op( + base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]], + op: NSNodeTargetType, +) -> Optional[str]: + for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items(): + if op in set_of_related_ops: + return base_name + return None + + +def add_op_to_sets_of_related_ops( + base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]], + op: NSNodeTargetType, + related_op: Optional[NSNodeTargetType], +) -> None: + if related_op is not None: + for set_of_related_ops in base_name_to_sets_of_related_ops.values(): + if related_op in set_of_related_ops: + set_of_related_ops.add(op) + return + # if we got here, related_op was not found + raise AssertionError(f"{related_op} was not found") + else: + counter = 0 + while str(counter) in base_name_to_sets_of_related_ops: + counter += 1 + base_name_to_sets_of_related_ops[str(counter)] = {op} + + +# TODO(future PR): clean this up +def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]: + FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = { + F.linear, + F.conv1d, + F.conv2d, + F.conv3d, + torch.cat, + F.elu, + F.hardswish, + F.instance_norm, + F.layer_norm, + F.leaky_relu, + F.dropout, + F.silu, + F.mish, + operator.add, + torch.add, + operator.mul, + torch.mul, + torch.sum, + F.prelu, + } + + FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set() + + FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = { + toq.linear, + toq.linear_relu, + toq.conv1d, + toq.conv1d_relu, + toq.conv2d, + toq.conv2d_relu, + toq.conv3d, + toq.conv3d_relu, + toq.cat, + toq.elu, + toq.hardswish, + toq.instance_norm, + toq.layer_norm, + toq.leaky_relu, + toq.dropout, + toq.prelu, + # TODO(future PR): implement shadowing for binary ops and + # uncomment below + # toq.add, + # toq.mul, + } + + FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { + F.relu, + F.tanh, + torch.tanh, + F.sigmoid, + torch.sigmoid, + F.hardsigmoid, + operator.floordiv, + torch.adaptive_avg_pool1d, + F.adaptive_avg_pool2d, + F.adaptive_avg_pool3d, + F.dropout, + F.hardtanh, + F.hardtanh_, + F.interpolate, + F.max_pool1d, + F.max_pool2d, + F.max_pool3d, + F.relu6, + F.pixel_shuffle, + F.pixel_unshuffle, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.cat, + torch.chunk, + torch.clamp, + torch.flatten, + torch.transpose, + torch.max, + torch.mean, + torch.min, + torch.narrow, + torch.repeat_interleave, + torch.sort, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.add, + } + + MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = { + nn.Linear, + nnqat.Linear, + nnqatd.Linear, + nnqd.Linear, + torch.nn.modules.linear.NonDynamicallyQuantizableLinear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nnqat.Conv1d, + nnqat.Conv2d, + nnqat.Conv3d, + nnqat.Embedding, + 
nnqat.EmbeddingBag, + nn.LSTM, + # note: nnqd.Linear is an instance of nnq.Linear, so this + # check has to happen before the int8 module check + nnqd.LSTM, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.Dropout, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.ELU, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + nn.Hardswish, + nn.LeakyReLU, + nn.ReLU6, + nn.SiLU, + nn.Mish, + nn.Softmax, + nn.PReLU, + nni.BNReLU2d, + nni.BNReLU3d, + nni.ConvReLU1d, + nni.ConvReLU2d, + nni.ConvReLU3d, + nni.LinearReLU, + nni.LinearBn1d, + nni.ConvBn1d, + nni.ConvBn2d, + nni.ConvBn3d, + nniqat.ConvBn1d, + nniqat.ConvBn2d, + nniqat.ConvBn3d, + nniqat.ConvBnReLU1d, + nniqat.ConvBnReLU2d, + nniqat.ConvBnReLU3d, + nniqat.ConvReLU1d, + nniqat.ConvReLU2d, + nniqat.ConvReLU3d, + nniqat.LinearReLU, + nniqat.LinearBn1d, + nniqd.LinearReLU, + nni.LinearLeakyReLU, + nni.LinearTanh, + nni.ConvAdd2d, + nni.ConvAddReLU2d, + } + + MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = { + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.Dropout, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.Embedding, + nnq.EmbeddingBag, + nnq.Dropout, + nnq.Softmax, + nnq.PReLU, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + nniq.LinearLeakyReLU, + nniq.LinearTanh, + nniq.ConvAdd2d, + nniq.ConvAddReLU2d, + } + + MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { + nn.ReLU, + nn.Tanh, + nn.Sigmoid, + nn.Hardsigmoid, + nn.AdaptiveAvgPool1d, + nn.AdaptiveAvgPool2d, + nn.AdaptiveAvgPool3d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.AvgPool3d, + nn.Dropout, + nn.Hardtanh, + nn.Identity, + nn.MaxPool1d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.PixelShuffle, + nn.PixelUnshuffle, + nn.ReLU6, + } + + METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = { + 'sigmoid_', + 'sigmoid', + 'tanh_', + 'tanh', + 'hardsigmoid_', + 'hardsigmoid', + 'relu_', + 'relu', + } + + return { + 'funs_io_type_fp32': FUNS_IO_TYPE_FP32, + 'funs_io_type_fp16': FUNS_IO_TYPE_FP16, + 'funs_io_type_int8': FUNS_IO_TYPE_INT8, + 'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8, + 'mods_io_type_fp32': MODS_IO_TYPE_FP32, + 'mods_io_type_int8': MODS_IO_TYPE_INT8, + 'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8, + 'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8, + } + + +def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]: + + FUNS_UNMATCHABLE: Set[NSNodeTargetType] = { + torch.quantize_per_tensor, + operator.getitem, + } + + MODS_UNMATCHABLE: Set[NSNodeTargetType] = { + nn.Identity, + } + + METHS_UNMATCHABLE: Set[NSNodeTargetType] = { + 'to', + 'dequantize', + 'reshape', + 'view', + 'unsqueeze_', + 'unsqueeze', + 'transpose', + 'squeeze_', + 'squeeze', + 'size', + 'shape', + 'resize_', + 'repeat_interleave', + 'repeat', + 'permute', + 'numel', + 'mean', + 'detach_', + 'detach', + 'contiguous', + 'clamp', + 'chunk', + } + + return { + 'funs_unmatchable': FUNS_UNMATCHABLE, + 'mods_unmatchable': MODS_UNMATCHABLE, + 'meths_unmatchable': METHS_UNMATCHABLE, + } diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..b7eddf93e2ae49e7e7a4bd12787212b3502e5254 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/n_shadows_utils.py @@ -0,0 +1,1311 @@ +import torch +import torch.fx +from torch.fx import ( + Node, + GraphModule, + Graph, +) + +from torch.ao.ns.fx.utils import ( + # TODO(future PR): make this work correctly for methods + get_target_type_str, + get_normalized_nth_input, +) +from torch.ao.ns.fx.ns_types import ( + NSSingleResultValuesType, + NSResultsType, +) +from torch.ao.ns.fx.graph_passes import _maybe_get_fqn +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.utils import getattr_from_fqn +from torch.ao.quantization.fx.match_utils import _MatchResult +from torch.utils._pytree import tree_map + +import collections +import copy +from typing import List, Dict, Set, Tuple, Callable, Any, Optional +import operator + +SHADOW_NODE_NAME_PREFIX = 'shadow' +SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper' + +# TODO(future PR): reuse existing mapping instead of creating a new one +BINARY_FUNCTIONS = { + torch.add, + torch.Tensor.add, + operator.add, + torch.mul, + torch.Tensor.mul, + operator.mul, +} + +def _get_attr_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + +def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx): + return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" + + +class OutputProp: + """ + Output propagation (modeled from shape propagation). + + Given a GraphModule and an example input, saves the output flowing + through each node on `node.traced_result`. + + Code based on the example from + https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern + """ + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env : Dict[str, Node] = {} + + def load_arg(a): + return torch.fx.graph.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target : str): + target_atoms = target.split('.') + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == 'placeholder': + result = next(args_iter) + elif node.op == 'get_attr': + result = fetch_attr(node.target) + elif node.op == 'call_function': + result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) + elif node.op == 'call_method': + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == 'call_module': + result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) + + if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined] + node.traced_result = result + + env[node.name] = result + + return None + +def _get_dedup_subgraphs( + matches: Dict[str, _MatchResult] +) -> Dict[str, List[Node]]: + # the original matches variable is unique by node, make it unique by subgraph + # instead + seen_nodes = set() + subgraphs_dedup = {} + + # Dict items are not reversible until Python 3.8, so we hack it + # to be compatible with previous Python versions + # TODO(future PR): try 
reversed(list(matches.items())) + matches_items_reversed: List[Tuple[str, _MatchResult]] = [] + for name, cur_match in matches.items(): + matches_items_reversed.insert(0, (name, cur_match)) + + # Note: the order is important. `matches` currently provides the matches + # in reverse order. We would like to process the matches in non-reverse + # order, so that we can create an intuitive naming scheme, such as + # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)` + for name, cur_match in matches_items_reversed: # type: ignore[call-overload] + was_seen = False + for node_or_tuple in cur_match[1]: + + # Cur_match[1] has an unusual type. It says that it's a `List[Node]`, + # but it is really not. Furthermore, the contents of this field + # can change from match results of multiple nodes of the same pattern + # + # For example, for conv -> bn -> relu, we see + # match_results = { + # 'conv': (relu, [(bn, conv), relu], ...), + # 'bn': (relu, [(bn, conv), relu], ...), + # 'relu': (relu, [(bn, conv), relu], ...), + # } + # + # Ideally we should clean up the `find_matches` function to make + # this more intuitive. For the purposes of this prototype, we hack + # around it. + + if isinstance(node_or_tuple, Node): + if node_or_tuple in seen_nodes: + was_seen = True + seen_nodes.add(node_or_tuple) + + else: + assert isinstance(node_or_tuple, tuple) + for node in node_or_tuple: + assert isinstance(node, Node) + if node in seen_nodes: + was_seen = True + seen_nodes.add(node) + + if was_seen: + continue + + # Start with the unusual type, convert it to [op_0, ..., op_n] + list_of_nodes = [] + + if len(cur_match[1]) == 1: + list_of_nodes = cur_match[1] + else: + assert len(cur_match[1]) == 2 + # either (a, b), or ((a, b), c) or (c, (a, b)) + # cannot make any assumptions on order, not clear what the + # _find_matches function is doing to populate this + # TODO(future PR): make this code less confusing, see discussion + # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836 + + def _order_nodes(node_a, node_b, node_c) -> List[Node]: + nodes = [node_a, node_b, node_c] + first_node = None + mid_node = None + last_node = None + for n in nodes: + prev_n = n.args[0] + next_n = next(iter(n.users)) + if prev_n not in nodes: + first_node = n + elif next_n not in nodes: + last_node = n + else: + mid_node = n + assert first_node is not None and mid_node is not None and \ + last_node is not None + assert mid_node.args[0] is first_node + assert last_node.args[0] is mid_node + return [last_node, mid_node, first_node] + + if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node): + # (a, b) + list_of_nodes = cur_match[1] + elif isinstance(cur_match[1][0], tuple): + # ((a, b), c) + node_a, node_b = cur_match[1][0] + node_c = cur_match[1][1] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + elif isinstance(cur_match[1][1], tuple): + # (a, (b, c)) + node_a, node_b = cur_match[1][1] + node_c = cur_match[1][0] + list_of_nodes = _order_nodes(node_a, node_b, node_c) + + # [node_n, ..., node_0], note that the order is reversed + # to make it chronological for simple subgraphs + list_of_nodes.reverse() + subgraphs_dedup[name] = list_of_nodes + + return subgraphs_dedup + +def _get_logger_for_subgraph( + model: GraphModule, + first_node: Node, + last_node: Node, + subgraph_idx: int, + subgraph_candidate_idx: int, + qconfig_str: str, + logger_cls: Callable, + fqn: Optional[str], +) -> torch.nn.Module: + """ + Given a model and a linear subgraph starting from `first_node` and + 
ending with `last_node`, creates a logger for the end of this + subgraph. + """ + if fqn is None: + fqn = '' + logger_mod_orig = logger_cls( + first_node.name, # ref_node_name + last_node.name, # prev_node_name + f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}', # model_name + 'model', # ref_name + get_target_type_str(last_node, model), # prev_node_target_type + get_target_type_str(first_node, model), # ref_node_target_type + NSSingleResultValuesType.NODE_OUTPUT.value, # results_type + 0, # index_within_arg + 0, # index_of_arg + fqn, # fqn + qconfig_str, + ) + # Usually we expect the user to add loggers, then calibrate, then convert, + # and then populate loggers. This is why the loggers start disabled. + # TODO(future PR): reconsider the design to make this more intuitive. + logger_mod_orig.enabled = False + return logger_mod_orig + +def create_submodule_from_subgraph( + model: torch.nn.Module, + first_node: Node, + last_node: Node, +) -> GraphModule: + """ + Input: a model, and a linear subgraph within the model from first_node to + last_node. + + Output: a new submodule containing a copy of the subgraph, with the inputs + to the first node becoming the inputs to the submodule, and all other + nodes in the subgraph being copied. + + Example inputs: + + `model`: a module with graph + + x0 -> op1 -> x1 -> op2 -> x2 + | + arg1 + + `first_node`: op1 + `last_node`: op2 + + Example output: a new module with graph + + input1 -> op1_copy -> x1 -> op2_copy -> output1 + | + arg1 + """ + + # + # create a blank GraphModule with an empty graph + # + + class M(torch.nn.Module): + def forward(self, x): + pass + + m = M() + gm = torch.fx.symbolic_trace(m) + g = gm.graph + for node in reversed(gm.graph.nodes): + g.erase_node(node) + + # + # modify the graph to have a copy of our subgraph + # + + cur_node_orig = first_node + cur_args_orig = cur_node_orig.args + cur_kwargs_orig = cur_node_orig.kwargs + + cur_name_idx = 0 + + iteration_limit = 100 + cur_iteration = 0 + + while True: + if cur_node_orig is first_node: + # we are at the first node, we need to set up graph inputs + # TODO(future): some graphs could have placeholders which are unrelated + # to the first node, need to handle this + cur_args_copy = [] + cur_kwargs_copy = {} + seen_names: Set[str] = set() + old_name_to_new_node: Dict[str, Node] = {} + + def _add_placeholder( + g: Graph, node: Node, seen_names, old_name_to_new_node + ): + # note: for graphs starting with patterns such as `y = x + x`, we + # need to ensure we do not add multiple placeholders with the + # same name + counter = 0 + while node.name + '_' + str(counter) in seen_names: + counter += 1 + cur_name = node.name + '_' + str(counter) + seen_names.add(cur_name) + placeholder = g.placeholder(cur_name) + old_name_to_new_node[node.name] = placeholder + return placeholder + + for arg in cur_node_orig.args: + if isinstance(arg, Node): + p = _add_placeholder( + g, arg, seen_names, old_name_to_new_node) + cur_args_copy.append(p) + elif isinstance(arg, (list, tuple)): + new_arg = [] + for inner_arg in arg: + if isinstance(inner_arg, Node): + new_arg.append(_add_placeholder( + g, inner_arg, seen_names, old_name_to_new_node)) + else: + new_arg.append(inner_arg) + cur_args_copy.append(new_arg) + else: + cur_args_copy.append(arg) + + # TODO(future PR): handle non-normalized kwargs + for kwarg_name, kwarg in cur_node_orig.kwargs.items(): + if isinstance(kwarg, Node): + cur_kwargs_copy[kwarg_name] = _add_placeholder( + g, kwarg, seen_names, old_name_to_new_node) + elif isinstance(kwarg, 
(list, tuple)): + new_kwarg = [] + for inner_kwarg in kwarg: + p = _add_placeholder( + g, inner_kwarg, seen_names, old_name_to_new_node) + new_kwarg.append(p) + cur_kwargs_copy[kwarg_name] = new_kwarg + else: + cur_kwargs_copy[kwarg_name] = kwarg + + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + else: + # we are not at first node, first arg is from the previous node, + # and all other args are copied + + # the current implementation is simplistic and cannot handle + # ops with two or more arguments which need to be passed from + # the previous op, so we assert them out + assert cur_node_orig.target not in BINARY_FUNCTIONS + + # at this point in the code, cur_node_copy is pointing to the copy + # of the previous node + # TODO(future PR): this is not handling complicated graphs correctly, need to + # look at actual relationships instead of assuming sequential graph + # TODO(future PR): this is ignoring kwargs, will need to support kwargs + # for any fusion pattern which has them for a node that is not the + # first node. + cur_args_copy = [cur_node_copy] # type: ignore[has-type, possibly-undefined] # noqa: F821 + + if len(cur_node_orig.args) > 1: + for arg in cur_node_orig.args[1:]: + if isinstance(arg, torch.nn.Parameter): + new_arg = arg.clone().detach() # type: ignore[assignment] + mod_name = f"mod_{cur_name_idx}" + cur_name_idx += 1 + setattr(gm, mod_name, new_arg) + new_arg_placeholder = gm.placeholder(mod_name) + cur_args_copy.append(new_arg_placeholder) + elif isinstance(arg, (float, int, torch.dtype)): + cur_args_copy.append(arg) + else: + raise AssertionError(f'arg of type {type(arg)} not handled yet') + cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] + + # copy the node + if cur_node_orig.op == 'call_module': + orig_mod = getattr_from_fqn(model, cur_node_orig.target) # type: ignore[arg-type] + orig_mod_copy = copy.deepcopy(orig_mod) + mod_name = f"mod_{cur_name_idx}" + setattr(gm, mod_name, orig_mod_copy) + cur_name_idx += 1 + cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined] + + elif cur_node_orig.op == 'call_function': + cur_node_copy = g.call_function( + cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined] + + elif cur_node_orig.op == 'call_method': + cur_node_copy = g.call_method( + cur_node_orig.target, cur_args_copy, cur_kwargs_copy) # type: ignore[possibly-undefined] + + else: + raise AssertionError(f'{cur_node_orig.op} not supported yet') + + if cur_node_orig is last_node: + break + + # go to next node + assert len(cur_node_orig.users.keys()) == 1, \ + f'{cur_node_orig} has more than 1 users, not supported yet' + cur_node_orig = next(iter(cur_node_orig.users.keys())) + cur_args_orig = cur_node_orig.args + cur_kwargs_orig = cur_node_orig.kwargs + + cur_iteration += 1 + if cur_iteration > iteration_limit: + raise AssertionError('iteration limit exceeded') + + # set up outputs + g.output(cur_node_copy) + + gm.recompile() + return gm + +def create_one_transformed_and_logged_copy_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + subgraph_candidate_idx: int, + first_node: Node, + last_node: Node, + fqn: Optional[str], + list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], + example_inputs: Any, + last_added_shadow_node_list: List[Optional[Node]], + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, +) -> None: + """ + Given a subgraph in `mt` and a subgraph candidate idx, inserts the + 
subgraph candidate copy and instruments it with loggers. + + If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just + add a logger to the end. + + If subgraph_candidate_idx is not 0, we create a copy of the subgraph and + prepare it with `prepare_fx`. + """ + + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger + + if subgraph_candidate_idx == 0: + # idx = 0 is the floating point (original) version of the subgraph + # We keep the subgraph as is, and add a logger at the end + + qconfig_str = '' + logger_mod_orig = _get_logger_for_subgraph( + mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx, + qconfig_str, OutputLogger, fqn) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(last_node): + new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={}) + last_added_shadow_node_list[0] = new_node + + else: + # idx > 0 means we have a candidate qconfig to try, so we need + # to make a copy of the subgraph, feed it with the right inputs, + # and add a logger at the end + + # get the qconfig + # subtract one because the first candidate is the floating point + # version of the subgraph + node_name_to_qconfig = \ + list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + + # if no quantization is requested, skip + # TODO(future PR): deduplicate equivalent qconfigs that come from + # different qconfig mapping objects + if qconfig is None: + return + + qconfig_mapping = QConfigMapping().set_global(qconfig) + + # create a copy of the submodule, wrapped in a separate module + orig_mod_copy_wrapped = create_submodule_from_subgraph( + mt, first_node, last_node) + + # add a call to prepare_fx on the wrapper module + if custom_prepare_fn is None: + orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( + orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs) + else: + if custom_prepare_kwargs is None: + custom_prepare_kwargs = {} + for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]: + assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs" + prepare_kwargs: Dict[str, Any] = { + "example_inputs": example_inputs, + "qconfig_mapping": qconfig_mapping + } + prepare_kwargs.update(custom_prepare_kwargs) + orig_mod_copy_wrapped = custom_prepare_fn( + orig_mod_copy_wrapped, + **prepare_kwargs) + + # attach the wrapper to the model + attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, orig_mod_copy_wrapped) + + # add a call to the wrapper module from the parent graph + insert_after_node = last_added_shadow_node_list[0] + with mt.graph.inserting_after(insert_after_node): + # TODO(future PR): handle fusion patterns where non-first nodes + # need inputs + + # pass in all node args and kwargs + + new_args = [] + for arg in first_node.args: + if isinstance(arg, Node): + new_args.append(arg) + elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node): + for inner_arg in arg: + if isinstance(inner_arg, Node): + new_args.append(inner_arg) + + new_kwargs = {} + for name, old_kwarg in first_node.kwargs.items(): + if isinstance(old_kwarg, Node): + new_kwargs[name] = old_kwarg + elif 
isinstance(old_kwarg, (list, tuple)) and len(old_kwarg): + # TODO(future PR): clarify why we are adding kwargs to args + new_args.extend(old_kwarg) + + new_args = tuple(new_args) # type: ignore[assignment] + + new_node = mt.graph.call_module( + attr_name, args=new_args, kwargs=new_kwargs) + + # add a logger to parent graph to observe the shadow wrapper + logger_mod_orig = _get_logger_for_subgraph( + mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx, + str(qconfig), OutputComparisonLogger, fqn) + + attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) + assert not hasattr(mt, attr_name) + setattr(mt, attr_name, logger_mod_orig) + with mt.graph.inserting_after(new_node): + logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={}) + last_added_shadow_node_list[0] = logger + + mt.recompile() + +def create_n_transformed_and_logged_copies_of_subgraph( + mt: GraphModule, + subgraph_idx: int, + match_name: str, + nodes_in_this_subgraph: List[Any], + qconfig_mappings: List[QConfigMapping], + list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], + custom_prepare_fn: Optional[Callable] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, +) -> None: + """ + Given a model `mt` and a subgraph_idx, creates the needed copies + of the subgraph for all qconfigs, and instruments them with loggers. + """ + # for now, assume that + # 1. the first node has one input + # 2. the last node has one output + + # for now, ignore all subgraphs that contain non-nodes (tuples, etc) + # TODO(future PR): implement this + if any( + not isinstance(node, Node) + for node in nodes_in_this_subgraph + ): + return + + first_node = nodes_in_this_subgraph[0] + last_node = nodes_in_this_subgraph[-1] + # We used output propagation to populate example values on each + # node. Use the example values from the previous node as the input + # to the current node. + prev_node = get_normalized_nth_input(first_node, mt, 0) + if isinstance(prev_node, list): + example_inputs = [x.traced_result for x in prev_node] + elif isinstance(prev_node, tuple): + example_inputs = (x.traced_result for x in prev_node) # type: ignore[assignment] + else: + # currently some customer models do not have a traced_result in + # every node, so we have to guard for this case since we cannot + # quantize without an example input + # TODO(future PR): add a test case for this once we have an easy + # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489 + # for additional context + if hasattr(prev_node, 'traced_result'): + example_inputs = (prev_node.traced_result,) # type: ignore[attr-defined, assignment] + else: + print( + 'unable to get example input for node ' + + f'{first_node.format_node()}, skipping') + return + + # If there are no quantization configs for this subgraph, skip adding + # loggers. This reduces memory usage for models where not all layers are + # quantized. + # TODO(future): consider making this configurable + found_at_least_one_qconfig = False + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + + if subgraph_candidate_idx == 0: + # fp32 baseline does not need a qconfig + continue + + # a. we have N shadows, so len(qconfig_mappings) is N + # b. we will have the fp32 layer + N shadows, so overall number of + # (original_op) + (*shadows) will be N+1 + # c. 
since `subgraph_candidate_idx` represents (b), we need + # to subtract 1 to query from (a) + node_name_to_qconfig = \ + list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + found_at_least_one_qconfig = True + break + if not found_at_least_one_qconfig: + print('unable to find at least one qconfig for node ' + + f'{first_node.format_node()}, skipping') + return + + fqn = _maybe_get_fqn(first_node, mt) + + # We want the results to contain the subgraphs in natural order, + # and the graph to also contain shadow wrappers and shadow loggers + # in natural order. + # If we just iterate in reverse, the graph will be in natural + # order but the eventual results will be in reverse order. + # So, we keep track of the last shadow logger we added and + # always insert after it. + last_added_shadow_node_list: List[Optional[Node]] = [None] + for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): + + create_one_transformed_and_logged_copy_of_subgraph( + mt, subgraph_idx, subgraph_candidate_idx, first_node, + last_node, fqn, list_of_node_name_to_qconfig, + example_inputs, last_added_shadow_node_list, custom_prepare_fn, + custom_prepare_kwargs) + +def create_add_loggers_graph( + model: GraphModule, + subgraphs_dedup: Dict[str, List[Node]], + qconfig_mapping: QConfigMapping, + node_name_to_qconfig: Dict[str, QConfigAny], +) -> None: + r""" + Given a model, a model graph partition (currently a set of matched + subgraphs) and instructions how to transform each subgraph + (currently quantizing it according to qconfig_mapping), modifies + the model graph to create an alternate path through the original graph, + with each of the subgraphs quantized. This is useful to compare + propagation error of a transformation such as quantization. + + For example, given layer op0 and op1, there are four cases when handling op1: + 1. op0 and op1 quantized + 2. op0 and op1 unquantized + 3. op0 quantized, op1 unquantized + 4. op0 unquantized, op1 quantized + + Example input, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog op1_1 -> x2_1 ----> clog + + Example output, case 1: + + .. code:: + + x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log + \ \ \ # noqa: W605 + ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog + + """ + # TODO(future PR): move logger classes to utils to remove circular dependency + from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger + + def _get_subgraph_containing_node(node, subgraphs_dedup): + for subgraph in subgraphs_dedup.values(): + if node in subgraph: + return subgraph + return None + + # First, we need to create shadow branches, going from + # + # x0 -> op0 -> x1 -> ... + # + # + # to + # + # x0 -> op0_0 -> x1_0 -> log -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog + # + # Later, the outputs of each shadow will be rerouted to calculate + # propagation error. + + # Note: we cannot iterate over matched subgraphs because some nodes + # may not be matched. So, we iterate over nodes in the graph, and + # associate them to matched subgraphs if possible. 
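+    # The single pass below handles both cases: for a matched subgraph whose
+    # first node has a qconfig, we insert a quantized shadow copy via
+    # create_n_transformed_and_logged_copies_of_subgraph; otherwise we clone
+    # the nodes without their parameters and attach an OutputLogger and an
+    # OutputComparisonLogger. In both cases the entry and exit nodes of the
+    # shadow are recorded so the rewiring pass further below can connect
+    # consecutive shadows to each other.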
+ + nodes_to_skip = set() + # for each subgraph, save a mapping from first node of subgraph + # to first and last node of the shadow of this subgraph + orig_first_node_to_shadow_in_node = {} + orig_first_node_to_shadow_out_node = {} + # need to record original list because we will mutate the graph as we go + orig_nodes = list(model.graph.nodes) # type: ignore[union-attr, arg-type] + cur_subgraph_idx = 0 + for n in orig_nodes: + if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + insert_submodule_copy = False + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + for node_to_skip in maybe_subgraph: + nodes_to_skip.add(node_to_skip) + qconfig = node_name_to_qconfig[first_node.name] + if qconfig is not None: + insert_submodule_copy = True + else: + first_node, last_node = n, n + + if insert_submodule_copy: + match_name = first_node.name + create_n_transformed_and_logged_copies_of_subgraph( + model, cur_subgraph_idx, match_name, maybe_subgraph, + [qconfig_mapping], [node_name_to_qconfig], + None, None # type: ignore[arg-type] + ) + # find the created shadow module and record it so we + # can find it easily in step 2 + expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1" + new_shadow_mod = None + for maybe_shadow_mod in model.graph.nodes: + if maybe_shadow_mod.op == 'call_module' and \ + maybe_shadow_mod.target == expected_shadow_target: + new_shadow_mod = maybe_shadow_mod + break + assert new_shadow_mod is not None + orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod + orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod + + else: + # create a copy of the subgraph by only copying FX nodes + # but not copying any parameters, to minimize memory usage + subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \ + else [first_node] + + # add a regular logger after last_node + qconfig_str = '' + subgraph_candidate_idx = 0 + fqn = _maybe_get_fqn(first_node, model) + logger_mod_orig = _get_logger_for_subgraph( + model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx, + qconfig_str, OutputLogger, fqn) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + assert not hasattr(model, attr_name) + setattr(model, attr_name, logger_mod_orig) + insertion_point = last_node + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(last_node,), kwargs={}) + insertion_point = logger + + # create a copy of the subgraph + cur_node_orig = first_node + cur_node_copy = None + first_node_copy = None + while cur_node_orig in subgraph_to_use: + # TODO(future PR): make this support all possible args/kwargs + if cur_node_orig is first_node: + new_args = cur_node_orig.args + new_kwargs = cur_node_orig.kwargs + else: + first_arg_for_copy = cur_node_copy + new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]]) # noqa: C409 + new_kwargs = cur_node_orig.kwargs + # make a copy of cur_node_orig + with model.graph.inserting_after(insertion_point): + cur_node_copy = model.graph.create_node( + cur_node_orig.op, + cur_node_orig.target, + new_args, + new_kwargs, + # cur_node_orig.name, # TODO(future PR): set name explicitly + ) + if first_node_copy is None: + first_node_copy = cur_node_copy + # since now only linear subgraphs are supported, all nodes + # except the last one must have only one user + if cur_node_orig != last_node: + assert 
len(cur_node_orig.users.keys()) == 1 + cur_node_orig = next(iter(cur_node_orig.users.keys())) + assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX) + insertion_point = cur_node_copy + + # add a comparison logger after last_node's copy + subgraph_candidate_idx = 1 + logger_mod_orig = _get_logger_for_subgraph( + model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx, + qconfig_str, OutputComparisonLogger, fqn) + attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) + assert not hasattr(model, attr_name) + setattr(model, attr_name, logger_mod_orig) + with model.graph.inserting_after(insertion_point): + logger = model.graph.call_module( + attr_name, args=(cur_node_copy, last_node), kwargs={}) + + # save the final node so we can use it in step 2 + orig_first_node_to_shadow_in_node[first_node] = first_node_copy + orig_first_node_to_shadow_out_node[first_node] = cur_node_copy + + cur_subgraph_idx += 1 + + model.recompile() + + # Now, we go from + # + # x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ... + # \ \ \ + # -> op0_1 -> x1_1 -> clog -> op1_1 -> ... + # + # to + # + # x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ... + # \ \ + # -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ... + # + # sample values of key internal variables for the example above: + # + # orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1} + # orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1} + # + # note: for subgraphs with more than one node, in_node will be different + # compared to out_node + + + nodes_to_skip = set() + for n in orig_nodes: + if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip: + continue + + maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) + if maybe_subgraph is not None: + first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] + for node_to_skip in maybe_subgraph: + nodes_to_skip.add(node_to_skip) + else: + first_node, last_node = n, n + + def maybe_remap_node_to_shadow(node): + """ + If unshadowed `node` has a shadow version, return that. If not, + return `node`. + """ + if not isinstance(node, Node): + # handle scalars + return node + + if node.op in ('placeholder', 'get_attr'): + return node + + # Find the shadowed version of this arg from the previous + # subgraph. For this, we need to: + # 1. navigate to the first node of the previous subgraph + # 2. get the output of the shadow wrapper which has (1) as an input + + # For now, assume the arg is in matched subgraphs. In the + # future we may have to handle the case where this is not true. + prev_subgraph = _get_subgraph_containing_node( + node, subgraphs_dedup) + if prev_subgraph is None: + prev_subgraph = [node] + prev_first_node = prev_subgraph[0] + prev_shadow_output = \ + orig_first_node_to_shadow_out_node[prev_first_node] + return prev_shadow_output + + cur_shadow_input = \ + orig_first_node_to_shadow_in_node[first_node] + assert cur_shadow_input is not None + cur_shadow_input.args = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.args) + cur_shadow_input.kwargs = tree_map( + maybe_remap_node_to_shadow, cur_shadow_input.kwargs) + + model.recompile() + +def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module): + # input: shadow wrapper module + # output if shadow wrapper module has a weighted op: + # (quantize_fn, (quantize_fn_args)) + # output if shadow wrapper module doesn't have a weighted op: + # None + + # For now, assume that the weight is the second input + # to the shadow module. 
If that changes, we can fix it later. + placeholders_seen = 0 + for shadow_n in shadow_wrapper.graph.nodes: # type: ignore[union-attr] + if shadow_n.op != 'placeholder': + continue + + placeholders_seen += 1 + if placeholders_seen != 2: + continue + + # the subgraph looks like + # + # _input_scale_1 = self._input_scale_1 + # _input_zero_point_1 = self._input_zero_point_1 + # quantize_per_channel = torch.quantize_per_channel( + # w2_0, _input_scale_1, _input_zero_point_1, + # 0, torch.qint8) + # + # we have `w2_0`, and are navigating this subgraph + # to get `_input_scale_1` and `_input_zero_point_1` + + assert len(shadow_n.users) == 1 + quant_node = next(iter(shadow_n.users.keys())) + new_args: Any = None + if quant_node.target == torch.quantize_per_channel: + _weight, scale_node, zp_node, axis, dtype = quant_node.args + scale_val = getattr_from_fqn( + shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn( + shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, axis, dtype) + else: + assert quant_node.target == torch.quantize_per_tensor + _weight, scale_node, zp_node, dtype = quant_node.args + scale_val = getattr_from_fqn( + shadow_wrapper, scale_node.target) + zp_val = getattr_from_fqn( + shadow_wrapper, zp_node.target) + new_args = (scale_val, zp_val, dtype) + return (quant_node.target, new_args) + + return None + + +def extract_weight_comparison(m: GraphModule) -> NSResultsType: + + # example graph: + # + # w1 = self.w1 + # b1 = self.b1 + # linear = torch._C._nn.linear(x, w1, b1) + # shadow_0_0 = self.shadow_0_0(linear) + # shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1) + # shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear) + # + # algorithm: + # 1. for each call_function node matching our allowlist: + # 2. if corresponding shadow wrapper exists, extract the weight pair + # + # Note: this is not super robust, but that's ok because this is + # just for legacy customers who depend on the previous two-model version + # of this API. TBD if we need to make this robust. + # Note: modules are not supported, since existing customers only + # use functions. 
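+    # For each matched call_function node, the loop below reads the fp32
+    # weight from the parent module, re-quantizes it using the quantize
+    # function and qparams recovered from the corresponding shadow wrapper,
+    # computes an SQNR comparison between the fp32 and quantized weights,
+    # and stores the results under both the fp32 (`subgraph_n_0`) and
+    # quantized (`subgraph_n_1`) subgraph names.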
+ + # TODO(future PR): move this to config + weighted_ops = { + torch.nn.functional.linear, + } + + results: NSResultsType = { + 'model': {NSSingleResultValuesType.WEIGHT.value: {}} + } + + for n in m.graph.nodes: # type: ignore[union-attr] + if not (n.op == 'call_function' and n.target in weighted_ops): + continue + + # Check if we have a corresponding shadow wrapper + # TODO(future PR, if needed): support kwargs + # TODO(future PR, if needed): support multiple shadow users + first_arg = n.args[0] + shadow_wrapper_node = None + for user in first_arg.users: + # TODO(before land): fix string match + if user.op == 'call_module' and \ + user.target.startswith('shadow_wrapper'): + shadow_wrapper_node = user + break + + if shadow_wrapper_node is None: + continue + + shadow_wrapper = getattr_from_fqn( + m, shadow_wrapper_node.target) # type: ignore[arg-type] + weight_info = _get_weight_info_from_shadow_wrapper( + shadow_wrapper) + if weight_info is None: + continue + + # get weight + w_node = n.args[1] + w_obj = getattr_from_fqn(m, w_node.target).detach() + + # get a quantized version of weight + quant_fn, quant_fn_args_except_first = weight_info + new_args = (w_obj, *quant_fn_args_except_first) + w_obj_q = quant_fn(*new_args) + + # add a comparison + ref_node_name = n.name + prev_node_name = n.name + ref_node_type = get_target_type_str(n, m) + prev_node_type = ref_node_type + fqn = None + if hasattr(m, '_node_name_to_scope'): + fqn = m._node_name_to_scope[n.name][0] # type: ignore[index] + comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q) + result_fp32 = { + 'res_type': NSSingleResultValuesType.WEIGHT.value, + 'values': [w_obj], + 'prev_node_name': prev_node_name, + 'prev_node_target_type': prev_node_type, + 'ref_node_name': ref_node_name, + 'ref_node_target_type': ref_node_type, + 'index_within_arg': 0, + 'index_of_arg': 0, + 'fqn': fqn, + 'qconfig_str': '', + 'comparisons': [comparison], + 'comparison_fn_name': 'sqnr', + } + result_q = { + 'res_type': NSSingleResultValuesType.WEIGHT.value, + 'values': [w_obj_q], + 'prev_node_name': prev_node_name, + 'prev_node_target_type': prev_node_type, + 'ref_node_name': ref_node_name, + 'ref_node_target_type': ref_node_type, + 'index_within_arg': 0, + 'index_of_arg': 0, + 'fqn': fqn, + 'qconfig_str': '', + 'comparisons': [comparison], + 'comparison_fn_name': 'sqnr', + } + + # go from subgraph_n_1 to subgraph_n_0 + _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_') + name_fp32 = f"subgraph_{node_idx}_0" + name_q = f"subgraph_{node_idx}_1" + + results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \ + [result_fp32] + results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \ + [result_q] + + return results + +# TODO(future PR): redesign this to make it easier to consume outputs +def group_results_by_subgraph(results: NSResultsType) -> Any: + """ + Creates a comparison of results + + Input: + + { + 'model': { + 'node_output': { + 'subgraph_0_0': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [], ... + 'comparison_fn_name': '', + 'fqn': '...', + ], + 'subgraph_0_1': [ + 'values': [torch.tensor(...), ...], ... + 'ref_node_name': ..., + 'ref_node_target_type': ..., + 'qconfig_str': ..., + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + ], + ... 
+ }, + }, + } + + Output: + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': None, + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], ... + 'comparison_fn_name': '...', + 'fqn': '...', + }, + }, + } + + """ + subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict) + + # node_output or weight + key_to_use = next(iter(results['model'].keys())) + + for subgraph_name_with_idx, subgraph_candidate_results in \ + results['model'][key_to_use].items(): + + # convert from `subgraph_m_n` to `subgraph_m` and `n` + subgraph_str, subgraph_idx, subgraph_candidate_idx = \ + subgraph_name_with_idx.split('_') + subgraph_name = f'{subgraph_str}_{subgraph_idx}' + + subgraph_results = { + 'ref_node_name': subgraph_candidate_results[0]['ref_node_name'], + 'ref_node_target_type': subgraph_candidate_results[0]['ref_node_target_type'], + 'fqn': subgraph_candidate_results[0]['fqn'], + 'values': subgraph_candidate_results[0]['values'], + 'qconfig_str': subgraph_candidate_results[0]['qconfig_str'], + 'comparisons': subgraph_candidate_results[0]['comparisons'], + 'comparison_fn_name': subgraph_candidate_results[0]['comparison_fn_name'], + } + + subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = \ + subgraph_results + + return dict(subgraph_name_to_subgraph_results) + +# TODO(future PR): redesign this to make it easier to consume outputs +def create_results_comparison( + results_grouped, +) -> Any: + """ + Input: + + { + 'subgraph_0': { + '0': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '', + 'comparisons': [], + 'comparison_fn_name': '', + 'fqn': '...', + }, + '1': { + 'ref_node_name': '...', + 'ref_node_target_type': ..., + 'values': [torch.tensor(...), ...], + 'qconfig_str': '...', + 'comparisons': [torch.tensor(...), ...], + 'comparison_fn_name': 'sqnr', + 'fqn': '...', + }, + }, + } + + Output: + { + 'subgraph_0': { + 'ref_node_name': '...', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': 'sqnr', + 'cmp_raw': [..., ...], + 'cmp_mean': ..., + }, + ..., + }, + }, + } + """ + + results_comparison = {} + + for subgraph_name, subgraph_results in results_grouped.items(): + + candidates = {} + for subgraph_inner_name, subgraph_inner_result in subgraph_results.items(): + # skip comparing baseline to baseline + if subgraph_inner_name == '0': + continue + + # we expect the comparisons to be precalculated from + # calibration, so we just fetch them here + cmp_raw = subgraph_inner_result['comparisons'] + cmp_raw_tensor = torch.stack(cmp_raw) + + candidates[subgraph_inner_name] = { + 'qconfig_str': subgraph_inner_result['qconfig_str'], + 'comparison_fn_name': subgraph_inner_result['comparison_fn_name'], + 'cmp_raw': cmp_raw_tensor, + 'cmp_mean': torch.mean(cmp_raw_tensor), + } + + results_comparison[subgraph_name] = { + 'ref_node_name': subgraph_results['0']['ref_node_name'], + 'ref_node_target_type': subgraph_results['0']['ref_node_target_type'], + 'fqn': subgraph_results['0']['fqn'], + 'candidates': candidates, + } + + return results_comparison + +# TODO(future PR): redesign this to make it easier to consume outputs +def 
print_n_shadows_summary( + results_comparison, +) -> None: + """ + Input: + + { + 'subgraph_0': { + 'ref_node_name': 'linear1', + 'ref_node_target_type': '...', + 'fqn': '...', + 'candidates': { + '1': { + 'qconfig_str': ..., + 'comparison_fn_name': ..., + 'cmp_raw': [45.0, 55.0], + 'cmp_mean': 50.0, + }, + ..., + }, + }, + } + + Prints: + + node_name | node_type | fqn | 0 | 1 | ... + linear1 | ... | ... | 45.0 | 50.0 | ... + """ + + try: + from tabulate import tabulate + except ImportError: + print("`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + return + + results = [] + for subgraph_data in results_comparison.values(): + mean_all_candidates = [ + candidate['cmp_mean'] + for candidate_name, candidate in subgraph_data['candidates'].items() + ] + + data_row = [ + subgraph_data['ref_node_name'], + subgraph_data['ref_node_target_type'], + subgraph_data['fqn'], + *mean_all_candidates, + ] + results.append(data_row) + + max_candidate_idx_len = -1 + for data_row in results: + max_candidate_idx_len = max(max_candidate_idx_len, len(data_row[1])) + candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)] + + headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers] + print(tabulate(results, headers=headers)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3c422dd4ae9d698645615f285efdae24dc278c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/ns_types.py @@ -0,0 +1,64 @@ +import enum +from typing import NamedTuple + +from torch.fx.graph import Node + +from typing import Dict, Any, List, Union, Callable + +class NSSingleResultValuesType(str, enum.Enum): + WEIGHT = 'weight' + NODE_OUTPUT = 'node_output' + NODE_INPUT = 'node_input' + +class NSSubgraph(NamedTuple): + start_node: Node + end_node: Node + base_op_node: Node + +# TODO(future PR): see if we can use typing_extensions's TypedDict instead +# to properly type the various keys +# { +# # one of NSSingleResultValuesType +# 'type': 'weight', +# # the values of type specified above +# 'values': [torch.tensor(...), ...], +# # name of the node directly before the logger +# 'prev_node_name': 'linear1', +# # type of the underlying function or module +# 'prev_node_target_type': torch.nn.functional.linear # or torch.nn.Linear, etc +# # name of the node responsible for adding this logger +# # Note: this may differ from prev_node_name if we are logging inputs +# 'ref_node_name': 'linear1', +# # index of this node within the arg of the input/output node +# # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1 +# 'index_within_arg': 0, +# # index of this node within the args of the input/output node +# # for example, in add(x1, x2), x2 would have index_of_arg == 1 +# 'index_of_arg': 0, +# # precomputed comparisons of logger values to reference values +# 'comparisons': [torch.tensor(...), ...] 
+# # name of function used for precomputed comparisons +# 'comparison_fn_name': 'sqnr', +# # string representation of qconfig responsible for creating this logger +# 'qconfig_str': 'QConfig(...)', +# } +NSSingleResultType = Dict[str, Any] + +# { +# 'layer_name_1': { # subgraph name +# 'node_output': { # results type (node_output, node_input, weight) +# 'model_name_a': # model name +# [NSSingleResultType, ...], # results, ordered by index_within_arg +# 'model_name_b': +# [NSSingleResultType, ...], +# }, +# }, +# } +# +NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]] + +# Defines the underlying target type of a node, for example: +# `F.conv1d` for a `call_function` conv node +# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module +# `'sigmoid'` for a `call_method` node calling `x.sigmoid()` +NSNodeTargetType = Union[Callable, str] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2925dfe012125f3428d156602199e1e9d840e926 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/pattern_utils.py @@ -0,0 +1,200 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +toq = torch.ops.quantized + +from torch.fx import GraphModule +from torch.fx.graph import Node + +from torch.ao.quantization.backend_config import get_native_backend_config +from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers +from torch.ao.quantization.utils import getattr_from_fqn +from .ns_types import NSNodeTargetType +from torch.ao.quantization import ( + ObserverBase, + FakeQuantizeBase, +) + +from typing import Dict, Tuple, Set, Callable, Any, Union, List + + +def get_type_a_related_to_b( + base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]], +) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]: + # TODO(future PR): allow customizations + # TODO(future PR): reuse existing quantization mappings + # TODO(future PR): add the rest of modules and ops here + type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set() + + for s in base_name_to_sets_of_related_ops.values(): + s_list = list(s) + # add every bidirectional pair + for idx_0 in range(0, len(s_list)): + for idx_1 in range(idx_0, len(s_list)): + type_a_related_to_b.add((s_list[idx_0], s_list[idx_1])) + type_a_related_to_b.add((s_list[idx_1], s_list[idx_0])) + + return type_a_related_to_b + + +NSFusionElType = Union[ + Callable, # call_function or call_module type, example: F.linear or nn.Conv2d + str, # call_method name, example: "dequantize" + Tuple[str, Any], # call_method name and first argument, example: ("to", torch.float16) +] +NSFusionType = Union[ + Tuple[NSFusionElType, NSFusionElType], + Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType], +] + +def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]: + """ + Set of potential fusions, in reverse order. The order is reversed + to match how fusion patterns are defined in quantization code. + + Fusion format: + ((fusion_op_0, fusion_op_1), base_op_idx) + + Where base_op_idx is the idx of the op we should use to match other related + ops. Note: base_op_idx is specified in non-reverse order, i.e. 
a base_op_idx + of 0 represents the first op in regular (non-reverse) order, 1 represents the + second op, etc. + """ + results: List[Tuple[NSFusionType, int]] = [] + + # Possible syntaxes: + # * single op: torch.nn.Conv2d + # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d) + # For fusions, we only care about patterns composed of multiple ops. + # TODO(future PR): allow customizations from default patterns. + all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) + + default_base_op_idx = 0 + for quant_pattern in all_quant_patterns.keys(): + # TODO: this is a temporary hack to flatten the patterns from quantization so + # that it works with the ns matcher function, maybe we should use `_is_match` + # in torch.ao.quantization.fx.match_utils to match the patterns + if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \ + isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2: + # flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1]) + + # Only patterns of multiple ops are fusions, ignore + # patterns which contain a single ops (they get matched + # without caring about fusions). + if isinstance(quant_pattern, tuple): + results.append((quant_pattern, default_base_op_idx)) # type: ignore[arg-type] + + # For each pattern, add additional patterns with observers and + # fake quants at the end. + # TODO(future PR): if needed, implement matching for a node + # having multiple output observers. + for cls in (ObserverBase, FakeQuantizeBase): + if isinstance(quant_pattern, tuple): + new_pattern = (cls, *quant_pattern) + else: + new_pattern = (cls, quant_pattern) + results.append((new_pattern, default_base_op_idx)) # type: ignore[arg-type] + + + # After this point, results contains values such as + # [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...] + + # Patterns for matching fp16 emulation are not specified in the quantization + # fusion mappings. For now, define them here. + fp16_em_base_op_idx = 1 + patterns_to_add = [ + # linear-relu fp16 emulation: + # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16 + ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,), + # Conv-BN fusion (this happens outside of quantization patterns, + # which is why it is defined separately here). + ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx), + ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx), + ] + for p in patterns_to_add: + results.append(p) # type: ignore[arg-type] + results.append(((ObserverBase, *p[0]), p[1])) # type: ignore[arg-type] + results.append(((FakeQuantizeBase, *p[0]), p[1])) # type: ignore[arg-type] + + return results + + +def end_node_matches_reversed_fusion( + end_node: Node, + reversed_fusion: NSFusionType, + gm: GraphModule, + seen_nodes: Set[Node], +) -> bool: + """ + Returns true if a pattern ending with `end_node` matches + the fusion pattern. 
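+    The match is evaluated by walking backwards from `end_node` through each
+    node's first positional arg, comparing the node's op and target against
+    the corresponding element of `reversed_fusion`; any mismatch, or
+    revisiting a node already in `seen_nodes`, fails the match.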
+ """ + cur_node = end_node + for fusion_idx in range(len(reversed_fusion)): + # each node can only belong to one matched pattern + if cur_node in seen_nodes: + return False + + cur_fusion_el = reversed_fusion[fusion_idx] + + if cur_node.op == 'call_function': + fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \ + (not isinstance(cur_fusion_el, type)) + if fusion_el_is_fun: + if cur_node.target != cur_fusion_el: + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == 'call_module': + fusion_el_is_mod = isinstance(cur_fusion_el, type) + if fusion_el_is_mod: + assert isinstance(cur_node.target, str) + target_mod = getattr_from_fqn(gm, cur_node.target) + if not isinstance(cur_fusion_el, type): + return False + if not isinstance(target_mod, cur_fusion_el): + return False + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + + elif cur_node.op == 'call_method': + fusion_el_is_meth_with_second_arg = \ + isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2 + fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str) + if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg: + if fusion_el_is_meth_without_args: + if cur_node.target != cur_fusion_el: + return False + else: + assert isinstance(cur_fusion_el, tuple) + if cur_node.target != cur_fusion_el[0]: + return False + elif len(cur_node.args) < 2: + return False + elif cur_node.args[1] != cur_fusion_el[1]: + return False + + if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node): + cur_node = cur_node.args[0] + else: + return False + else: + return False + else: + return False + + return True diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..20a005d0c8bf9441554113e9a1bb49754b415ee1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/qconfig_multi_mapping.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import copy +from typing import Any, Callable, Dict, List, Union + +import torch +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER +from torch.ao.quantization.qconfig import QConfigAny + +__all__ = ["QConfigMultiMapping"] + +_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = { + "global_qconfig": "set_global", + "object_type_qconfigs": "set_object_type", + "module_name_regex_qconfigs": "set_module_name_regex", + "module_name_qconfigs": "set_module_name", + "module_name_object_type_order_qconfigs": "set_module_name_object_type_order", +} + +def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None: + to_remove = [] + for index, cur_qconfig in enumerate(qconfig_list): + if cur_qconfig is None: + to_remove.append(index) + break + for checked_qconfig in qconfig_list[:index]: + if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig): + to_remove.append(index) + break + for index in to_remove[::-1]: + qconfig_list.pop(index) + +class QConfigMultiMapping: + """ + This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s + so that multiple QConfigs 
can be specified for each QConfig matching style. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfigs + + ``set_object_type`` : sets the QConfigs for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string + + ``set_module_name`` : sets the QConfigs for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a + single QConfig. + + Example usage:: + + qconfig_mapping = QConfigMultiMapping() + .set_global([qconfig1, qconfig2]) + .set_object_type(torch.nn.Linear, [qconfig2, qconfig3]) + .set_object_type(torch.nn.ReLU, [qconfig1]) + .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2]) + .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3]) + .set_module_name("module1", [None]) + .set_module_name("module2", [qconfig2]) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3]) + + """ + + def __init__(self): + # initialize this with 1 QConfigMapping to avoid corner cases + self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()] + + def _handle_list_size_mismatch( + self, qconfig_list: List[QConfigAny], style: str + ) -> None: + # this method handles cases where the size of qconfig_list does not match + # the size of qconfig_mappings_list. + # Issue: Consider a user inserting global_qconfig A and B first, then inserting + # qconfig C as an object_type_qconfig for conv ops. If we internally store + # 1 QConfigMapping with A and C and another with just B, then the + # second QConfigMapping will match B to conv ops (which is not wanted), since B is global. + + # we avoid this by maintaining the invariant that if any QConfigMapping + # has a qconfig style+key with a qconfig in it, all QConfigMappings must + # have either a qconfig or None for that same style+key. 
In the above + # example, a None qconfig would prevent the unwanted match in the + # second QConfigMapping + + if len(qconfig_list) > len(self.qconfig_mappings_list): + # Case: we have more qconfigs (in qconfig_list) than QConfigMappings + + # Add new QConfigMappings (initialized so we maintain the `invariant`) + + new_qconfig_mapping = QConfigMapping() + # searches other QConfigMappings for qconfig style+keys + # that need to be inserted as `None` into the new QConfigMapping + for qconfig_mapping in self.qconfig_mappings_list: + + # global_qconfig has None by default + for check_style in _QCONFIG_STYLE_ORDER[1:]: + qconfigs_dict = getattr(qconfig_mapping, check_style) + target_qconfigs_dict = getattr(new_qconfig_mapping, check_style) + for key in qconfigs_dict: + target_qconfigs_dict[key] = None + break + + # insert copies of this new QConfigMapping until all entires + # in qconfig_list can fit among the QConfigMappings + while len(qconfig_list) > len(self.qconfig_mappings_list): + self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) + else: + # Case: we have fewer qconfigs in qconfig_list than QConfigMappings + + # pad qconfig_list with `None` until length is same + while len(qconfig_list) < len(self.qconfig_mappings_list): + qconfig_list.append(None) + + # this function applies the insertion method across each QConfigMapping + def _insert_qconfig_list( + self, + style: str, + args: List[Union[str, int, Callable]], + qconfig_list: List[QConfigAny], + ) -> None: + + # we remove duplicates and None to make the ordering of qconfigs + # deterministic upon insertion. + _remove_duplicates_and_none(qconfig_list) + + self._handle_list_size_mismatch(qconfig_list, style) + method_name = _QCONFIG_STYLE_TO_METHOD[style] + for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list): + # uses QConfigMapping set method to insert qconfig + set_method = getattr(qconfig_mapping, method_name) + set_method(*args, qconfig) + + def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping: + """ + Set global QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info + """ + self._insert_qconfig_list("global_qconfig", [], global_qconfig_list) + return self + + def set_object_type( + self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set object type QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info + """ + self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list) + return self + + def set_module_name_regex( + self, module_name_regex: str, qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name_regex QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info + """ + self._insert_qconfig_list( + "module_name_regex_qconfigs", [module_name_regex], qconfig_list + ) + return self + + def set_module_name( + self, module_name: str, qconfig_list: List[QConfigAny] + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info + """ + self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list) + return self + + def set_module_name_object_type_order( + self, + module_name: str, + object_type: Callable, + index: int, + qconfig_list: List[QConfigAny], + ) -> QConfigMultiMapping: + """ + Set module_name QConfigs + see 
:func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info + """ + self._insert_qconfig_list( + "module_name_object_type_order_qconfigs", + [module_name, object_type, index], + qconfig_list, + ) + return self + + def __repr__(self): + return ( + self.__class__.__name__ + + " [" + + "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) + + "\n]" + ) + + @classmethod + def from_list_qconfig_mapping( + cls, qconfig_mapping_list: List[QConfigMapping] + ) -> QConfigMultiMapping: + """ + Creates a QConfigMultiMapping from a list of QConfigMappings + """ + new_qconfig_multi_mapping = cls() + + new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy( + qconfig_mapping_list + ) + + # we need to avoid the issue described in _handle_list_size_mismatch, + # so we reinsert all the qconfigs using the QConfigMultiMapping + # set methods + + # go through all qconfig styles + # note: global can be ignored since it is None by default + for style in _QCONFIG_STYLE_ORDER[1:]: + + # gather all key+qconfigs for current style + # into qconfig_dict_list + qconfig_dict_list: Dict[Any, List[QConfigAny]] = {} + for qconfig_mapping in qconfig_mapping_list: + qconfig_dict = getattr(qconfig_mapping, style) + for key, qconfig in qconfig_dict.items(): + if key not in qconfig_dict_list: + qconfig_dict_list[key] = [] + qconfig_dict_list[key].append(qconfig) + + # reinsert all gathered key+qconfigs + set_method_name = _QCONFIG_STYLE_TO_METHOD[style] + set_method = getattr(new_qconfig_multi_mapping, set_method_name) + for key, qconfig_list in qconfig_dict_list.items(): + if isinstance(key, tuple): + set_method(*key, qconfig_list) + else: + set_method(key, qconfig_list) + + return new_qconfig_multi_mapping diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0979b8b9275983bf4ddb12c70a81803259fa05f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd02e22ad4f43a20bf3b7045294cdca75884e3eb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuse_modules.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2287ab5413e906f47a78d4f6c24337c2bb6727e4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/observer.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a2cdf51821eed99023a3386f8bce39343ff4de Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig_mapping.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8c420f8bbe0738cd46e06852f0af797e8db0781 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f752e6a7b37503bd17cd56ede0e820ef4292b09 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_jit.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2efcc7188301f7123894386f4111022ccea1dc88 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7453beff6cd25fb6217f1ec9f3e4911c7129d62f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..36ef2ecbdcdc129db094dfdf54d876d8b9faee37 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/rewrite.py @@ -0,0 +1,600 @@ +import torch +from torch.fx import GraphModule +from ..export_utils import _WrapperModule +from ..utils import ( + get_aten_graph_module, + remove_tensor_overload_for_qdq_ops, + _replace_literals_with_new_placeholders, + _replace_literals_with_existing_placeholders, +) +from 
torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.fx.subgraph_rewriter import replace_pattern +from torch._higher_order_ops.out_dtype import out_dtype +from typing import Optional, Callable, Tuple, Any +from dataclasses import dataclass + +from functools import partial + +__all__ = [ + "reference_representation_rewrite", +] + + +_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (2, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _qdq_quantized_linear( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, + out_scale, out_zero_point, out_quant_min, out_quant_max +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8) + return out_i8 + +def _reference_quantized_linear( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, + out_scale, out_zero_point, out_quant_min, out_quant_max +): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None) + # TODO: change to mul.Scalar + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. 
Scalar values + acc_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = ( + torch.randn((2, 5), dtype=torch.float), + -128, + 127, + torch.finfo(torch.float32).eps, + torch.randint(-128, 127, (5, 5), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), +) + + +def _qdq_dynamic_quantized_linear( + x_fp32, x_quant_min, x_quant_max, x_eps, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8) + x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8) + out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32) + return out_fp32 + +def _reference_dynamic_quantized_linear( + x_fp32, x_quant_min, x_quant_max, x_eps, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, +): + x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8) + # decomposed representation for quantize_per_tensor + # TODO: use out_dtype(mul, ...) 
here when the op is ready + x_fp32 = x_fp32 / x_scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x_fp32 = torch.round(x_fp32) # fp32 + x_i32 = x_fp32.to(dtype=torch.int32) # int32 + x_i32 = x_i32 + x_zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max) # int32 + x_i8 = x_i32.to(dtype=torch.int8) + + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.linear.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None) + bias_scale = x_scale * weight_scale + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + acc_i32 = acc_i32 + bias_i32 + out_fp32 = acc_i32 * (x_scale * weight_scale) + return out_fp32 + + +_QUANTIZED_CONV2d_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-127], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _qdq_quantized_conv2d( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, + out_scale, out_zero_point, out_quant_min, out_quant_max +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8) + out_fp32 = torch.ops.aten.convolution.default( + x_fp32, weight_fp32, bias_fp32, stride, padding, dilation, transposed, output_padding, groups) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8) + return out_i8 + +def _reference_quantized_conv2d( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, + weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, + bias_fp32, + out_scale, out_zero_point, out_quant_min, out_quant_max +): + stride = [1, 1] + padding = [0, 0] + dilation = [1, 1] + transposed = False + output_padding = [0, 0] + groups = 1 + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. 
+ # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max) + weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max) + + x_i16 = x_i8.to(torch.int16) + weight_i16 = weight_i8.to(torch.int16) + # always set bias to None so that the same representation can work for the case + # no matter if bias_scale == x_scale * weight_scale or not + acc_i32 = out_dtype( + torch.ops.aten.convolution.default, + torch.int32, + x_i16 - x_zero_point, + weight_i16 - weight_zero_point, + None, stride, padding, dilation, transposed, output_padding, groups) + # Note: we are quantizing bias with these scales without signal from user, but it might be OK + bias_scale = x_scale * weight_scale + # bias quantization to int32 uses bias_scale = x_scale * weight_scale due to: + # Take the linear calculation as an example + # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(i, k)_fp32] + bias_(i)_fp32 + # Represent X, W fp32 as their dequant transforms + # A_fp32 = (A_q - A_zero_point) * A_scale + # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(i, k)_q - W_zp) * W_scale] + bias_(i)_fp32 + # Factor out X_scale and W_scale + # Out_(i, j)_fp32 = ((X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)]) + bias_(i)_fp32 + # In order to move the addition of bias_(i)_fp32 inside the parentheses, we must write + # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(i, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(i)_fp32) # noqa: B950 + # Note we had to divide bias_fp32 by X_scale * W_scale = bias_scale + # Thus bias quantization to int32 must use the scale X_scale * W_scale + + bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale) + # Unsqueeze to match broadcast dims + # Unfortunately I cannot do bias_i32.unsqueeze(0) due to literal matching nightmare + # in graph pattern replacement + bias_i32 = bias_i32.unsqueeze(-1) + bias_i32 = bias_i32.unsqueeze(-1) + acc_i32 = acc_i32 + bias_i32 + # TODO: change to mul.Scalar when we make x_scale/weight_scale etc.
Scalar values + acc_i32 = out_dtype( + torch.ops.aten.mul.Tensor, torch.int32, acc_i32, x_scale * weight_scale / out_scale) + out_zero_point + out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8) + return out_i8 + + +_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _qdq_quantized_add_relu( + x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, + out_scale, out_zero_point, quant_min, quant_max +): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8) + out_fp32 = x_fp32 + y_fp32 + out_fp32 = torch.ops.aten.relu(out_fp32) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + +def _reference_quantized_add_relu( + x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, + out_scale, out_zero_point, quant_min, quant_max +): + """ + See comments for `_reference_quantized_add` for more information on + how to derive the formula for out_i8 based on x_i8 and y_i8 + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: change this to mul.Scalar? + x_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (x_i32 - x_zero_point), (x_scale / out_scale)) + y_i32 = out_dtype(torch.ops.aten.mul.Tensor, torch.int32, (y_i32 - y_zero_point), (y_scale / out_scale)) + out_i32 = x_i32 + y_i32 + out_zero_point + # out_i32 = torch.ops.aten.clamp(out_i32, out_zero_point) + out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8) + return out_i8 + +def _qdq_quantized_add(x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, out_scale, out_zero_point, quant_min, quant_max): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8) + y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8) + out_fp32 = x_fp32 + y_fp32 + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8 + ) + return out_i8 + +def _reference_quantized_add( + x_i8, x_scale, x_zero_point, y_i8, y_scale, y_zero_point, + out_scale, out_zero_point, quant_min, quant_max +): + """ + # How to Derive the formula for out_i8 based on x_i8 and y_i8 + # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8) + + # out_i8 is quantized output, we can write down the formula for it first: +out_i8 = out_f32 / out_scale + out_zero_point (1) + + # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8 + out_f32 = x_f32 + y_f32 (2) + x_fp32 = (x_i8 - x_zero_point) * x_scale (3) + y_fp32 = (y_i8 - y_zero_point) * y_scale (4) + + # applying the above fomula to the out_i8 equation we can get the following: + out_i8 = out_fp32 / out_scale + out_zero_point # (1) + = (x_f32 + 
y_f32) / out_scale + out_zero_point # applying (2) to substitute out_fp32 with x_fp32 + y_fp32 + = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point # apply (3) and (4) + """ + x_i32 = x_i8.to(torch.int32) + y_i32 = y_i8.to(torch.int32) + # TODO: use out_dtype op + x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32) + y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32) + out_i32 = x_i32 + y_i32 + out_zero_point + quant_min = -128 + quant_max = 127 + out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8) + return out_i8 + +_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _qdq_quantized_max_pool2d( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8) + out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(x_fp32, kernel_size, stride, padding, dilation, ceil_mode) + out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor( + out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8) + return out_i8 + +def _reference_quantized_max_pool2d( + x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, out_scale, out_zero_point, out_quant_min, out_quant_max): + kernel_size = 1 + stride = 1 + padding = 0 + dilation = 1 + ceil_mode = False + # to preserve x_quant_min, x_quant_max in the graph for pattern matching + x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max) + x_i32 = x_i8.to(torch.int32) + out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default( + x_i32 - x_zero_point, + kernel_size, + stride, + padding, + dilation, + ceil_mode + ) + out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point + out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max) + out_i8 = out_fp32.to(torch.int8) + return out_i8 + +_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): + x = torch.ops.quantized_decomposed.quantize_per_tensor(x_fp32, scale, zero_point, quant_min, quant_max, torch.int8) + return x + +def _reference_quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max): + # TODO: use out_dtype(mul, ...) 
here when the op is ready + x = x_fp32 / scale # fp32 + # round modes might be different here + # pytorch is rounding to even, which is also common for most of the backends + x = torch.round(x) # fp32 + x = x.to(dtype=torch.int32) # int32 + x = x + zero_point # int32 + # clamp works for fp32, int32 and int8 dtypes + x = torch.clamp(x, quant_min, quant_max) # int32 + x = x.to(dtype=torch.int8) + return x + +_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(1, dtype=torch.float), + torch.zeros(1, dtype=torch.int), + torch.tensor([-128], dtype=torch.int), + torch.tensor([127], dtype=torch.int), +) + +def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): + x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max, torch.int8) + return x_fp32 + +def _reference_dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max): + # without using quant_min/max in clamp, the traced graph will not have quant_mi/max args. + # This results in failure to match the pattern. + # Therefore, we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + # TODO: use out_dtype op + # note: x_i8.to(torch.int32) does not work here + # TODO: debug the implementation later when torchdynamo time out issue is resolved + return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32) + +_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randn(1, 3, 3, 3, dtype=torch.float), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, +) + +def _quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max): + out_i8 = torch.ops.quantized_decomposed.quantize_per_channel( + x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_i8 + +def _reference_quantize_per_channel_int8(x_fp32, scales, zero_points, ch_axis, quant_min, quant_max): + x_fp32 = torch.transpose(x_fp32, ch_axis, -1) + out_i32 = torch.ops.aten.clamp(torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max) + out_i32 = torch.transpose(out_i32, ch_axis, -1) + return out_i32.to(torch.int8) + +_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = ( + torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8), + torch.randn(3, dtype=torch.float), + torch.zeros(3, dtype=torch.int), + 1, + -128, + 127, +) + +def _dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max): + # the following will be replaced as placeholders + out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel( + x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8 + ) + return out_fp32 + +def _reference_dequantize_per_channel_int8(x_i8, scales, zero_points, ch_axis, quant_min, quant_max): + # the following will be replaced as placeholders + # in order to preserve the quant_min/quant_max args for pattern matching (e.g. 
matching for int4 quantized ops) + # we call a torch.ops.aten.clamp here + x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max) + x_i8 = torch.transpose(x_i8, ch_axis, -1) + x_i32 = x_i8.to(torch.int32) + out_fp32 = (x_i32 - zero_points).to(torch.float) * scales + out_fp32 = torch.transpose(out_fp32, ch_axis, -1) + return out_fp32 + +def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule): + return _replace_literals_with_existing_placeholders( + gm, + exclude_literals=[-1], + literal_to_ph_idx={1: 3, -128: 4, 127: 5} + ) + + +@dataclass +class _RewriteInfo: + """Data needed for rewrite, this includes example inputs, pattern and replacement functions + and post transformation functions for the exported pattern and replacement GraphModule + """ + + # example inputs used for exporting the pattern into GraphModule + example_inputs: Tuple[Any, ...] + pattern: Callable + replacement: Callable + # post transformation on the exported pattern and replacement GraphModule + pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None + +_REWRITE_INFO_LIST = [ + _RewriteInfo( + _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_dynamic_quantized_linear), + _WrapperModule(_reference_dynamic_quantized_linear), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + -128: 1, + 127: 2, + torch.finfo(torch.float32).eps: 3 + } + ), + partial( + _replace_literals_with_existing_placeholders, + literal_to_ph_idx={ + -128: 1, + 127: 2, + torch.finfo(torch.float32).eps: 3 + } + ), + ), + _RewriteInfo( + _QUANTIZED_LINEAR_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_linear), + _WrapperModule(_reference_quantized_linear), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders, + ), + _RewriteInfo( + _QUANTIZED_CONV2d_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_conv2d), + _WrapperModule(_reference_quantized_conv2d), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add_relu), + _WrapperModule(_reference_quantized_add_relu), + ), + _RewriteInfo( + _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_add), + _WrapperModule(_reference_quantized_add), + ), + _RewriteInfo( + _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS, + _WrapperModule(_qdq_quantized_max_pool2d), + _WrapperModule(_reference_quantized_max_pool2d), + _replace_literals_with_new_placeholders, + _replace_literals_with_new_placeholders + ), + _RewriteInfo( + _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_tensor_int8), + _WrapperModule(_reference_quantize_per_tensor_int8), + ), + _RewriteInfo( + _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_tensor_int8), + _WrapperModule(_reference_dequantize_per_tensor_int8), + ), + _RewriteInfo( + _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_quantize_per_channel_int8), + _WrapperModule(_reference_quantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement + ), + _RewriteInfo( + _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS, + _WrapperModule(_dequantize_per_channel_int8), + _WrapperModule(_reference_dequantize_per_channel_int8), + _replace_ph_qdq_per_channel_replacement, + _replace_ph_qdq_per_channel_replacement 
+ ), +] + +def reference_representation_rewrite(model: GraphModule) -> GraphModule: + remove_tensor_overload_for_qdq_ops(model) + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + matches = replace_pattern(model, pattern, replacement) + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3f8e42eae6d40b37c1591aa21f58a24a74a1366 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46d7857ef27eba2f333d8b002ca7733268d9a689 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/contrib/__pycache__/_tensorboard_vis.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py new file mode 100644 index 0000000000000000000000000000000000000000..da5120d6edf180f7fbbe88ac342b4d0e4b383e50 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/config.py @@ -0,0 +1,6 @@ +# Whether to disable showing progress on compilation passes +# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +disable_progress = True + +# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy +verbose_progress = False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d92ffecd1b8cebcde41f5ea4e99ecdd589ba347a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67df4fbe11f017c9e477af621d8dba9dbd6b1d46 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ae17fad00b377d8ddd2f78466e1e1e8db4aa5d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0dbffa659462f0c147421eded796a68c81ed9d6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..887dea2990a614519ca6d0c1e81b7a0fa846b765 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..551cab26f0a5aa695bbbea710910dec5bdd46cf6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py @@ -0,0 +1,458 @@ +import functools +import itertools +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree + + +__all__ = [ + "ShapeEnvEvent", + "record_shapeenv_event", + "replay_shape_env_events", + "FakeTensorMeta", + "shape_env_check_state_equal", + "NotEqualError", +] + +# [Note: Recording ShapeEnv Events] +# ================================= +# +# What is a ShapeEnv event? +# ------------------------- +# We consider a ShapeEnv event every function call (ShapeEnv method or +# independent function) that modifies the state of the ShapeEnv instance. +# Such calls are recorded alongside their positional and keyword arguments, +# so that it may be replayed over a different ShapeEnv instance. +# +# See [Note: ShapeEnv State Equality] for what is considered the state +# of a ShapeEnv instance. +# +# What is it for? +# --------------- +# ShapeEnv events recording is used for reconstructing the ShapeEnv in an +# arbitrary state in time. 
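+# For instance, given some ShapeEnv instance `env` whose mutating methods have been recorded (an illustrative sketch; `env` is a stand-in name rather than something defined in this file), an equivalent instance can be rebuilt from the recorded events: +# +# from torch.fx.experimental.recording import replay_shape_env_events +# new_env = replay_shape_env_events(env.events) +#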
+# +# Being able to arbitrarily replay events like so is useful, mainly for +# translation validation bisection. i.e. if a ValidationException has been +# raised, find the earliest point in time where the translation validation +# fails. +# +# Besides that, it also allows us to inspect the given instance and, +# for example, check the guards that would actually be issued at that point. +# +# What kind of arguments can be stored in an event? +# ------------------------------------------------- +# There's no specific rule for what cannot be used as an argument. +# That said, pay special attention to the following cases: +# +# 1. Tensor inputs: there are some tests that check whether the inputs +# were garbage collected after execution. These will fail if there's +# an event that is holding a reference to those inputs. +# +# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that +# will be automatically replaced by the new given ShapeEnv instance. +# +# 3. SymTypes arguments: they also hold references to ShapeEnv. So, +# whenever we see them, we create a new instance, replacing the +# ShapeEnv reference. +# +# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic +# shapes. That argument must be replaced when replaying the event at +# ShapeEnvEvent.run, since it has to reference a node from the given +# instance, and not from the recorded instance. + + +# Event class for reconstructing ShapeEnv at arbitrary time. +# +# Represents a method call that mutates ShapeEnv in a way that affects the +# issued guards, when ShapeEnv.produce_guards is called. +@dataclass +class ShapeEnvEvent: + # ShapeEnv method. + f: Callable + + # Arguments and keyword arguments called with. + args: Optional[List[Any]] = None + kwargs: Optional[Dict[str, Any]] = None + + # List of tracked_fakes at the time the method was called. + tracked_fakes: Optional[List[Any]] = None + + # Name of the captured event. + # Used for special handling of particular methods. + name: Optional[str] = None + + # Replay itself, but using shape_env as self. + def run(self, shape_env=None) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + is_symbolic, + ShapeEnv, + SymTypes, + ) + + # Special handling for the constructor event. + if self.f is ShapeEnv: + assert shape_env is None and self.args is None and self.kwargs is not None + return ShapeEnv(**self.kwargs) + + assert shape_env is not None + args = list(self.args or list()) + kwargs = dict(self.kwargs or dict()) + + # Replace any argument of type ShapeEnv by the given one. + args, kwargs = pytree.tree_map_only( + ShapeEnv, lambda _: shape_env, (args, kwargs) + ) + + # Replace any argument of type SymTypes by a new instance, + # replacing its ShapeEnv reference. + args, kwargs = pytree.tree_map_only( + lambda x: isinstance(x, SymTypes) and is_symbolic(x), + lambda a: type(a)(a.node.with_shape_env(shape_env)), + (args, kwargs), + ) + + # Converts FX nodes using the mapping argument. + def maybe_convert_node(x: Any) -> Any: + if not isinstance(x, torch.fx.Node): + # Don't do anything to x if it's not an FX node. + return x + + # If, at some point, we created an FX node, it means that translation validation is on. + # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and + # we are tracking node names at shape_env.name_to_node. 
+ assert hasattr(shape_env, "name_to_node") + name_to_node = shape_env.name_to_node # type: ignore[attr-defined] + assert x.name in name_to_node + return name_to_node[x.name] + + # Replaces the value of an specific argument by the result of fn. + def replacearg(index: int, key: str, fn: Callable): + if index < len(args): + args[index] = fn(args[index]) + if key in kwargs: + kwargs[key] = fn(kwargs[key]) + + if self.is_create_fx_call_function(): + # ShapeEnv.create_fx_call_function: + # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. + # They must be replaced, since a "call_function" FX node with this tuple as argument + # will be added to the FX graph of the new shape_env. + replacearg( + index=2, + key="args", + fn=lambda args: tuple(maybe_convert_node(a) for a in args), + ) + if self.is_evaluate_expr() or self.is_defer_runtime_assert(): + # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert: + # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. + # They must be replaced, since it will be part of a "call_function" FX node for + # torch._assert, which will be added to the FX graph of the new shape_env. + replacearg(index=3, key="fx_node", fn=maybe_convert_node) + + # Actually call the method with the converted arguments. + return self.f(*args, **kwargs) + + def __str__(self) -> str: + name = self.name if self.name is not None else self.f.__name__ + return f"event: {name} ({self.args}, {self.kwargs})" + + def is_create_fx_call_function(self) -> bool: + return self.name == "_create_fx_call_function" + + def is_evaluate_expr(self) -> bool: + return self.name == "evaluate_expr" + + def is_defer_runtime_assert(self) -> bool: + return self.name == "defer_runtime_assert" + + +# Extracts a ShapeEnv instance inside args and kwargs. +# Specifically, it looks for: +# 1. ShapeEnv arguments +# 2. SymInt, SymFloat, or SymBool arguments +# If we find more than one object of any of the above types, we +# also check that the ShapeEnv instance is the same for all of them. +def _extract_shape_env_and_assert_equal(args, kwargs): + from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes + + def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: + if old is not None: + assert old is new, "call with different ShapeEnv" + return new + + shape_env = None + for val in itertools.chain(args, kwargs.values()): + if isinstance(val, ShapeEnv): + shape_env = assert_equal(shape_env, val) + if isinstance(val, SymTypes) and is_symbolic(val): + shape_env = assert_equal(shape_env, val.node.shape_env) + + return shape_env + + +# Decorator for recording the given function as a replayable event. +# +# This decorator should be used at every function that mutates the state of +# ShapeEnv in some way that affects the resulting issued guards (i.e. when +# ShapeEnv.produce_guards is called). +# +# save_tracked_fakes: saves a snapshot of the TrackedFake list. +# This is used when calling ShapeEnv.produce_guards at arbitrary points in time. +# +# When to save the list of TrackedFake? +# ===================================== +# We should save the list of TrackedFake whenever the translation validation +# bisection may actually stop and call the produce_guards method at the moment +# right after the recorded function was played. In other words, since the +# bisection bisects through torch._assert calls, we should save in all methods +# that adds a torch._assert call to the symbolic shapes FX graph. 
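+# As an illustrative sketch (signatures elided; the real methods live on ShapeEnv in symbolic_shapes.py), such a guard-affecting method is wrapped as: +# +# @record_shapeenv_event(save_tracked_fakes=True) +# def evaluate_expr(self, *args, **kwargs): +# ... +#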
+# +# At the moment, there are 2 methods that save the list: +# - ShapeEnv.evaluate_expr +# - ShapeEnv.defer_runtime_assert +def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable: + def decorator(fn: Callable) -> Callable: + assert callable(fn) + name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + if isinstance(args[0], ShapeEnv) and args[0].is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return fn(*args, **kwargs) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return fn(*args, **kwargs) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, list(args), kwargs, tracked_fakes, name=fn.__name__ + ) + self.events.append(event) + # Play the event on this ShapeEnv. + return event.run(self) + + return wrapper + + return decorator + + +# Replays the ShapeEnvEvents list. +# It assumes the first event is the constructor call. +# +# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. +def replay_shape_env_events(events): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + constructor_event = events[0] + assert constructor_event.f == ShapeEnv + + # Constructs the new ShapeEnv. + shape_env = constructor_event.run() + + for event in events[1:]: + try: + # Actually replays each event. + # We need to call create_mapping_fn every time, since the node list might + # change after each event is replayed. + event.run(shape_env) + except Exception as e: + raise RuntimeError(f"failed when running event: {event}") from e + + return shape_env + + +# FakeTensor metadata. +# This is to be used in place of FakeTensor placeholders when calling +# ShapeEnv.produce_guards. +@dataclass +class FakeTensorMeta: + tensor_size: Tuple[Union[int, torch.SymInt], ...] + tensor_stride: Tuple[Union[int, torch.SymInt], ...] + tensor_storage_offset: Union[int, torch.SymInt] + is_nested: bool + + def size(self) -> Tuple[Union[int, torch.SymInt], ...]: + return self.tensor_size + + def stride(self) -> Tuple[Union[int, torch.SymInt], ...]: + return self.tensor_stride + + def storage_offset(self) -> Union[int, torch.SymInt]: + return self.tensor_storage_offset + + def dim(self) -> int: + return len(self.tensor_size) + + @staticmethod + def from_fake(fake) -> "FakeTensorMeta": + return FakeTensorMeta( + fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested + ) + + +# [Note: ShapeEnv State Equality] +# =============================== +# +# What is considered ShapeEnv state? +# ---------------------------------- +# We consider to be the state of a ShapeEnv instance everything that +# is not in the inline tuple inside remove_nonstate_variables function. +# That is: the fields within ShapeEnv that modify the flow of execution +# of the program. 
+# +# So, for example: the replacements field might influence how an +# expression is simplified. That, in turn, may result in a guard being +# statically known (i.e. not added). +# +# On the other hand, var_to_stack only changes what is printed +# on the screen, i.e. it is used only for debugging purposes. Therefore, we +# should not consider it when comparing states. +# +# What to do on NotEqualError? +# ---------------------------- +# Here are a few possible causes for getting a NotEqualError raised: +# +# 1. New field that does not belong in the ShapeEnv state. +# For example: log field of type ShapeEnvLoggerAdapter. Different +# ShapeEnv instances will always have different ShapeEnvLoggerAdapter +# instances, i.e. equality comparison would fail. +# Solution: add it to the inlined tuple inside remove_nonstate_variables +# function inside the check_equal method. +# +# 2. New field that is not directly comparable across instances. +# For example: guards field of type List[ShapeGuard]. More specifically, +# the ShapeGuard type holds an expression and stack information +# for debugging purposes. When replaying the event on a new ShapeEnv +# instance, the stack would be different, which would trigger this error. +# Solution: add a special case to the map_value function inside +# the check_equal function. +# +# 3. Mutation of ShapeEnv inside a function that is not recorded. +# If a mutation of the state of ShapeEnv happens inside a function +# that is not recorded (or that no caller in the stack is recorded), +# then the replayed ShapeEnv won't catch that. +# Solution: decorate the function with record_shapeenv_event. + + +# Checks whether the states of two ShapeEnv instances are equal w.r.t. the guards +# returned by ShapeEnv.produce_guards. +def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): + # Collect and remove variables that don't necessarily represent the state + # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the + # instance itself. + env1_vars = vars(env1).copy() + env2_vars = vars(env2).copy() + + for v in non_state_variable_names: + if v in env1_vars: + env1_vars.pop(v) + if v in env2_vars: + env2_vars.pop(v) + + # Function for transforming the mismatched values into strings. + # Needed, since the order of dict and set entries might not be the same every time. + def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return ( + "{" + + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str)) + + "}" + ) + if isinstance(value, set): + return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}" + return str(value) + + # Compares env1_vars with env2_vars. + # Here, we allow the value of each field to be mapped, so that we appropriately + # compare the two values. + def compare_vars( + map_value: Callable[[str, Any], Any] + ) -> List[Tuple[str, str, str]]: + env1_set, env2_set = set(env1_vars), set(env2_vars) + + # First, compare the set of keys in each vars dictionary. + if env1_set != env2_set: + raise NotEqualError( + "field set mismatch:", + [ + ( + "found unique fields:", + str(sorted(env1_set - env2_set)), + str(sorted(env2_set - env1_set)), + ), + ], + ) + + # Then, sort the keys, and compare the mapped values of each key. + sorted_keys = list(env1_set) + sorted_keys.sort() + + mapped_dict = [ + (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k])) + for k in sorted_keys + ] + + # Return a list of tuples representing the fields that did not match + # alongside their respective mapped values.
+ return [ + (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2)) + for k, val1, val2 in mapped_dict + if val1 != val2 + ] + + # Accumulate the mismatching fields. + errors = compare_vars(map_value) + + if len(errors) > 0: + raise NotEqualError("field values don't match:", errors) + + +class NotEqualError(Exception): + def __init__( + self, + msg: str, + mismatched: List[Tuple[str, str, str]], + ) -> None: + details = "\n".join( + [ + "\n".join( + [ + f"==> {inner_msg}", + f" > Left: {str1}", + f" > Right: {str2}", + ] + ) + for inner_msg, str1, str2 in mismatched + ] + ) + + super().__init__( + f"""\ +ShapeEnv not equal: {msg} + +{details} +""" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..762e4340f12b49d1a9f2628ce1e011e38b8d23a1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py @@ -0,0 +1,16 @@ +class Equality: + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def __str__(self): + return f'{self.lhs} = {self.rhs}' + + def __repr__(self): + return f'{self.lhs} = {self.rhs}' + + def __eq__(self, other): + if isinstance(other, Equality): + return self.lhs == other.lhs and self.rhs == other.rhs + else: + return False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..c4abe52c8c279bf93cd05659423eaceddf023b55 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py @@ -0,0 +1,121 @@ +import ast +import inspect +import textwrap +import copy +import functools +from types import FunctionType +from typing import cast, Union, Callable, Dict, Optional, Any +from torch.fx._symbolic_trace import Tracer +from torch.fx.graph import Graph +from torch._sources import normalize_source_lines +import torch + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. 
For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + def rewrite(self, fn: FunctionType): + + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = ''.join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) + return g + # Return the correct FunctionType object + return change_func_globals(fn_compiled, globals=fn.__globals__) + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse('torch._assert()', mode='eval') + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_AnnAssign(self, node): + """ + Swap out Python's AnnAssign with an Assign node where the annotation function is called. + Example: + Original: + y: Tensor_Type(1,2,3, Dyn) = f2(x) + Output: + y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) + """ + return ast.Assign(targets=[node.target], value=ast.Call( + func=ast.Name(id='annotate', ctx=ast.Load()), + args=[node.value, node.annotation], keywords=[])) + + +class RewritingTracer(Tracer): + def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + return super().trace(_rewrite(root), concrete_args) + + +def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's `forward` as well as the `forward`s of + # all of this module's recursive descendents. Return the new, + # rewritten module hierarchy. 
+ def rewrite_module(m : torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + for k, v in orig.__dict__.items(): + if isinstance(v, torch.nn.Module): + self.__dict__[k] = copy.copy(rewrite_module(v)) + else: + self.__dict__[k] = copy.copy(v) + RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) + return RewrittenModule(m) + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b7f56f97488cc299b78178e4fb5fb0ae230af2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,4362 @@ +# mypy: ignore-errors + +""" +``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with +our symbolic shapes reasoning system that is used heavily in torch.compile. Although +this is not generally considered public API, when writing framework code in PyTorch +as well as extensions to PyTorch (e.g., in custom operator implementations), you may +need to make use of these APIs to setup dynamic shapes support appropriately. +""" + +import builtins +import collections +import functools +import inspect +import itertools +import logging +import math +import operator +import re +import sys +import threading +import traceback +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +from functools import lru_cache +from typing import ( + Any, + cast, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + TYPE_CHECKING +) +from typing_extensions import TypeAlias + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch.fx.experimental import _config as config + +from torch.fx.experimental.recording import ( + FakeTensorMeta, + ShapeEnvEvent, + record_shapeenv_event, + replay_shape_env_events, + shape_env_check_state_equal +) +from torch.fx.experimental.sym_node import SymNode, SymTypes + +# NB: The sym_* functions are used via getattr() and must be imported here. 
+from torch import SymBool, SymFloat, SymInt +from torch._guards import ShapeGuard, Source, TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._traceback import format_frame, CapturedTraceback +from torch._utils_internal import signpost_event +from torch._subclasses.meta_utils import is_sparse_any + +from torch._logging import LazyString + +if TYPE_CHECKING: + from torch._dynamo.source import TensorPropertySource + +InputList = List +DimList = List + +log = logging.getLogger(__name__) + +class GuardOnDataDependentSymNode(RuntimeError): + pass + +import sympy +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import precedence, PRECEDENCE + +aten = torch._ops.ops.aten # type: ignore[has-type] + +__all__ = [ + "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", + "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr", + "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", + "is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", + "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", + "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true", + "guard_size_oblivious", +] + +# FX node metadata keys for symbolic shape FX graph. +SHAPEENV_EVENT_KEY = "shapeenv_event" +CURRENT_NODE_KEY = "current_node" + +# These are modules that contain generic code for interacting with ShapeEnv +# which are unlikely to identify a particular interesting guard statement +@lru_cache(None) +def uninteresting_files() -> Set[str]: + import torch._inductor.sizevars + import torch._library.abstract_impl + import torch._subclasses.meta_utils + import torch._subclasses.fake_tensor + mods = [ + sys.modules[__name__], + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, + torch, + torch._inductor.sizevars, + torch._library.abstract_impl, + torch._subclasses.meta_utils, + torch._subclasses.fake_tensor, + ] + return {inspect.getfile(m) for m in mods} + +# We don't bother with the metaclass as all of the dispatching logic happens +# entirely from Python +# +# Didn't bother with ancestors for now, unlikely to have multiple modes for +# symints right now + +class ConstraintViolationError(RuntimeError): + pass + +def has_symbolic_sizes_strides(elem) -> bool: + return elem._has_symbolic_sizes_strides + +Int = Union[torch.SymInt, int] + +def create_contiguous(shape: Sequence[Int]) -> List[Int]: + strides: List[Int] = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) + return list(reversed(strides)) + +def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: + """ + Retrieve the hint for an int (based on the underlying real values as observed + at runtime). If no hint is available (e.g., because data dependent shapes), + if fallback is not None, use that instead (otherwise raise an error). 
+ """ + if isinstance(a, torch.SymInt): + return a.node.require_hint(fallback) + assert type(a) is int, a + return a + +Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] + +def has_hint(a: Scalar) -> bool: + if isinstance(a, SymTypes): + return a.node.has_hint() + return True + +def is_concrete_int(a: Union[int, SymInt]) -> bool: + r""" Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or int): Object to test if it int + """ + assert isinstance(a, (SymInt, int)) + + if isinstance(a, int): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Integer): + return True + + return False + +# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. +# So make sure only type checker evaluates this alias. +# Xref: https://www.internalfb.com/diff/D53324783 +SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" + +def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: + """ + Perform a guard on a symbolic boolean expression in a size oblivious way. + This is typically used when a non-oblivious test would result in a guard + on a data dependent value of which we don't know the value of at compile time. + When a guard is tested this way, we may diverge in behavior from how regular + PyTorch semantics would treat it. For more information, see + https://github.com/pytorch/pytorch/pull/118579 + """ + if isinstance(expr, torch.SymBool): + return expr.node.guard_size_oblivious("", 0) + else: + assert isinstance(expr, bool) + return expr + +def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: + r""" Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) + +def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: + """ + After canonicalization, we are guaranteed to have eliminated Ge/Gt relations + (rewriting them to Le/Lt, respectively). + """ + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + if isinstance(expr, tuple(opposite.keys())): + lhs = expr.rhs - expr.lhs + t = opposite[type(expr)] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + lhs = expr.lhs - expr.rhs + t = type(expr) + rhs = 0 + if isinstance(lhs, sympy.Add): + cts = [] + variables = [] + for term in lhs.args: + if term.is_number: + cts.append(term) + else: + variables.append(term) + lhs = sympy.Add(*variables) + rhs = -sympy.Add(*cts) + return t(lhs, rhs) + +def is_concrete_bool(a: Union[bool, SymBool]) -> bool: + r""" Utility to check if underlying object + in SymBool is concrete value. 
Also returns + true if integer is passed in. + Args: + a (SymBool or bool): Object to test if it bool + """ + assert isinstance(a, (SymBool, bool)) + + if isinstance(a, bool): + return True + + if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)): + return True + + return False + +def is_nested_int(s): + return isinstance(s, torch.SymInt) and s.node.is_nested_int() + +def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]: + if isinstance(val, SymTypes): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node.expr + elif isinstance(val, sympy.Basic): + yield val + elif isinstance(val, (int, float, bool)): + pass + elif is_sparse_any(val): + yield from _iterate_exprs(val.size()) + elif isinstance(val, torch.Tensor): + yield from _iterate_exprs(val.size()) + yield from _iterate_exprs(val.stride()) + yield from _iterate_exprs(val.storage_offset()) + elif isinstance(val, (tuple, list)): + for s in val: + yield from _iterate_exprs(s) + elif val is None: + pass + else: + raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") + +def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]: + if val is None: + return set() + itr = _iterate_exprs(val) + # we need at least 1 to call union, so we hand code the identity + try: + first_expr = next(itr) + except StopIteration: + return set() + + return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) + +def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool: + """Faster version of bool(free_symbols(val))""" + return not all(e.is_number for e in _iterate_exprs(val)) + +# Like free_symbols, but filtered to only report unbacked symbols +def free_unbacked_symbols(x): + # NB: keep synced with is_unbacked_symint + return {s for s in free_symbols(x) if s.name.startswith(("u", "f"))} + +# WARNING: Don't use this on Dynamo produced graphs, they don't have meta +# setup! +def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: + if ( + node.op == "placeholder" and + "val" in node.meta and + isinstance(node.meta["val"], torch.SymInt) and + isinstance(node.meta["val"].node.expr, sympy.Symbol) + ): + return node.meta["val"].node.expr + return None + +def find_symbol_binding_fx_nodes(graph): + return { + node.meta["val"].node.expr: node + for node in graph.nodes + if is_symbol_binding_fx_node(node) + } + +def definitely_true(a): + """ + Returns True only if we can tell that a is True, possibly introducing + a guard in the process. If a depends on some unbacked SymInt, we may + return False even though there may exist a possible value of the SymInt + that would cause the expression to return True. + + When is it appropriate to use definitely_true? First, if you can use + a higher level combinator like parallel_or/parallel_and, prefer using + those instead, they are definitely safe (modulo short-circuiting). + Second, it can be used if the program would behave equivalently if + definitely_true always returned False (parallel_or/parallel_and are + examples of this pattern, modulo short-circuiting). Finally, it even + be OK if the program wouldn't behave equivalently, so long as the + change is semantics preserving. It can be semantics preserving if + the program errors in more cases than it did previously (but otherwise + behaves identically), or if it changes some quantity in a way that + doesn't matter (e.g., strides often fall in this bucket.) 
+ """ + if isinstance(a, SymBool): + if a.node.has_hint(): + return guard_bool(a) + else: + return False + return bool(a) + +def definitely_false(a): + """ + Returns True only if we can tell that a is False, possibly introducing + a guard in the process. If a depends on some unbacked SymInt, we may + return False even though there may exist a possible value of the SymInt + that would cause the expression a to be False. See definitely_true + for more usage guidance. + """ + if isinstance(a, SymBool): + if a.node.has_hint(): + return not guard_bool(a) + else: + return False + return not bool(a) + +def statically_known_true(x: Union[bool, SymBool]) -> bool: + """Returns True if x can be simplified to a constant and is true. + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to true at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + + """ + if isinstance(x, SymBool): + expr = x.node.expr + shape_env = x.node.shape_env + try: + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr) + return False + assert isinstance(x, bool) + return x + + +def parallel_or(*args): + """ + Evaluate the logical OR of several arguments, avoiding guarding on + unbacked SymInts if another argument is definitely True. + """ + if any(statically_known_true(a) for a in args): + return True + if any(definitely_true(a) for a in args): + return True + return any(args) + +def parallel_and(*args): + """ + Evaluate the logical FALSE of several arguments, avoiding guarding on + unbacked SymInts if another argument is definitely False. + """ + if any(statically_known_true(torch.sym_not(a)) for a in args): + return False + if any(definitely_false(a) for a in args): + return False + return all(args) + +def sym_eq(x, y): + """ + Like ==, but when run on list/tuple, it will recursively test equality + and use sym_and to join the results together, without guarding. + """ + if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)): + if len(x) != len(y): + return False + return functools.reduce(operator.and_, map(sym_eq, x, y), True) + elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): + return x == y + else: + raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") + +def guard_scalar(a): + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + + +@record_shapeenv_event() +def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int): + upd_vr = ValueRanges(compiler_min, compiler_max) + old_vr = shape_env.var_to_range.get(s, ValueRanges.unknown()) + new_vr = shape_env.var_to_range[s] = old_vr & upd_vr + if new_vr != old_vr: + log.info("_constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper) + + +def _advise_is_size(a): + """ + Don't use this directly; use torch._check_is_size instead. + + This is a softer version of _constrain_range_for_size (with min=0, + max=Inf). Instead of forcibly constraining a variable (and erroring if we + failed to constrain it), it will simply advise us that a size is + constrained in some way. 
We will always defer a runtime assert for this + constraint if we cannot prove it at compile-time, but we only + *sometimes* learn useful extra information at compile-time with this + information. This is in contrast to constrain_range_for_size, where if + you don't call that on a fresh unbacked symint, chances are we will choke. + + TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed + code. Right now this is only really used in code with AOTAutograd trace + through, so it is not a big problem that this isn't supported, but in + principle all of this code should be Dynamo'able too. + + TODO: I didn't support min/max because I didn't have a use case where this + actually helped. In principle we can support it, it just makes the + implementation below more complicated. + """ + + # This must always succeed, because the sole allowed caller _check_is_size + # was responsible for expect_true'ing this + assert a >= 0 + + # NB: it's important not to constrain range for size for *hinted* SymInts, + # because it is not only unsound, it will immediately trip our asserts + # that hints have to be consistent with static analysis! If you somehow + # have an unbounded SymInt that later constrains to 1, this will be + # inconsistent with the range + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and not a.node.has_hint() + and isinstance(a.node.expr, sympy.Symbol) + ): + _constrain_range_for_size(a) + +@record_shapeenv_event() +def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None): + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt), "can only constrain range for SymInt" + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + + if min is None: + min = 0 + if max is None: + max = sympy.oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + f"received min={min} and max={max}" + ) + + _constrain_symbol_range( + a.node.shape_env, + a.node.expr, + compiler_min=min, + compiler_max=max, + ) + a.node.shape_env.size_like.add(a.node.expr) + + +# inclusive both ways +@record_shapeenv_event() +def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): + """ + Applies a constraint that the passed in SymInt must lie between min-max + inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning + that it can be used on unbacked SymInts). If min/max are None, we assume + that the dimension is unbounded in that direction. Repeated application + of constrain_range intersects the ranges. This is a fairly low level API + that doesn't have a lot of safety guarantees (TODO: provide higher level + APIs). + + Currently, we use this API in the following circumstance: when we allocate + an unbacked SymInt, denoting an integer quantity which is data dependent, + we ordinarily do not know anything about what values it may take. This + means that any sort of guard on it will immediately fail. However, in + many cases, we know something about the unbacked SymInt: for example, we + know that nonzero(x).size(0) must be >= 0. We use constrain_range to + narrow the possible range, declaring that negative symbols are impossible. + This permits us to definitely answer True to queries like 'nnz >= 0', even if + we don't know what the actual (hinted) value of 'nnz' is.
In fact, we + actually use constrain_range to unsoundly discharge common guards: for an + unbacked SymInt produced by nonzero, we will also assume that it is not + equal to 0/1 (even though these are perfectly possible values at runtime), + because we generally expect graphs that are valid for N=2 to also be valid + for N=1. + """ + if min is None: + min = -sympy.oo + if max is None: + max = sympy.oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + f"received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") + return + + if isinstance(a.node.expr, sympy.Integer): + if not (min <= int(a.node.expr) <= max): + raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]") + return + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but is this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + _constrain_symbol_range( + a.node.shape_env, + a.node.expr, + compiler_min=min, + compiler_max=max, + ) + + +@record_shapeenv_event() +def constrain_unify(a, b): + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + shape_env = b.node.shape_env + shape_env.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + shape_env = a.node.shape_env + if not isinstance(b, SymInt): + shape_env.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + new_var = shape_env._find(a.node.expr) + shape_env.replacements[b.node.expr] = new_var + +# Assume that a boolean is true for the purposes of subsequent symbolic +# reasoning. This will keep track of corresponding runtime checks to verify +# that the result is upheld: either as a regular guard, or as a special set +# of asserts which are triggered when an unbacked SymInt is allocated. +# +# DO NOT use this function for these cases: +# +# - This is inappropriate for "branching" conditions (where both +# true and false result in valid programs). We will always assume +# the condition evaluates true, and so it will never be possible +# to trace the false condition when you use it. For true branching +# on unbacked SymInts, you must use torch.cond; if you incorrectly +# use expect_true in this case, you will make the false branch +# unreachable (as we will simply assume that only the true branch +# is ever exercised). +# +# - This is inappropriate for situations where you know some other system +# invariant guarantees that this property holds, since you don't +# really need to insert a runtime check in that case.
Use something +# like constrain_range in that case. +# +# This API has a hitch. To avoid having to reimplement error reporting +# capabilities, this function CAN return False. The invariant is that +# the surrounding code must raise an error when this function returns +# False. This is quite low level, so we recommend using other functions +# like check() which enforce this in a more intuitive way. +# +# By the way, this name is a nod to the __builtin_expect macro, +# which is used similarly (but unlike __builtin_expect, you MUST fail +# in the unlikely branch.) (I think expect is a good name; in recent +# versions of C++, this is replaced with [[likely]], which is weaker +# and not accurate for this function!) +def expect_true(a, skip: int = 0): + if isinstance(a, SymBool): + # TODO: check perf implications of this + frame = inspect.currentframe() + for _ in range(skip + 1): # always run this loop at least once + frame = frame.f_back + return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno) + assert type(a) is bool, a + return a + +def guard_bool(a): + if isinstance(a, SymBool): + return a.node.guard_bool("", 0) # NB: uses Python backtrace + assert type(a) is bool, a + return a + +def guard_int(a): + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert type(a) is int, a + return a + +def guard_float(a): + if isinstance(a, SymFloat): + return a.node.guard_float("", 0) # NB: uses Python backtrace + assert isinstance(a, float), a + return a + +# Given a GraphModule, return all the FakeTensors for all the placeholders +def fx_placeholder_vals(gm): + return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"] + +def fx_placeholder_targets(gm): + return [n.target for n in gm.graph.nodes if n.op == "placeholder"] + +# Given a GraphModule and arguments to run it with, evaluate that the guards +# for its associated ShapeEnv are satisfied by the passed arguments. This +# WILL check for duck sizing. +def eval_guards(gm, *args, ignore_static=True): + return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static) + +def bind_symbols(gm, *args): + return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) + +def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): + """ + We assert that the bounds are either Boolean, or not finite, or can be computed + in exact prevision via rational arithmetic. + The only exception to this is the rare case when the user calls `sqrt(s0)` + sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) + """ + assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) + assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) + +class DimDynamic(Enum): + """ + Controls how to perform symbol allocation for a dimension. It is always + sound to default this to DYNAMIC, but the policies DUCK and STATIC can + result in better trace-time and compile-time performance, as they reduce + the number of allocated symbols and generally make your graph more static. + + NB: If we notice you've applied a constraint to the dimension, we will + force it to DYNAMIC for simplicity. + + DimDynamic is controlled by a variety of higher level UX features. + Currently: + + - In eager mode, the default policy is DUCK. + - The default is changed to STATIC with assume_static_by_default. 
+ - An individual dim is marked DYNAMIC if you mark_dynamic_dim. + - In export mode, the default policy is STATIC. + - An individual dim is marked DYNAMIC if you mention it as dynamic_dim + in the constraints kwarg. + """ + # Treat the dimension symbolically + DYNAMIC = 0 + # Treat the dimension symbolically, but if its hint matches another + # dynamic dimension, unify the two symbols ("duck sizing") + DUCK = 1 + # Treat the dimension statically based on its hint + STATIC = 2 + + +# NB: These constraints affect both clients and backends: given some +# constraint C, the client must pass inputs that satisfy the constraint, +# while a backend must not introduce guards BEYOND this constraint. +# For clarity, we document the implications on both sides for both the client +# and the backend. +# +# NB: These constraints are on a *single* dimension. In principle, we could +# also have multi-dimension constraints, but our guess is that this is not +# actually useful and so we are not supporting it right now. +# +# NB: Strict constraints are typically only suitable for export, as in eager +# a backend like inductor may validly introduce extra, discretionary guards +# to improve performance of code. A StrictMinMaxConstraint would be brittle +# under future optimizations performed by inductor; we don't guarantee +# eager code with StrictMinMaxConstraint will keep working in the future! + +@dataclass(frozen=True) +class Constraint: + warn_only: bool + +@dataclass(frozen=True) +class StrictMinMaxConstraint(Constraint): + """ + For clients: the size at this dimension must be within 'vr' (which + specifies a lower and upper bound, inclusive-inclusive) AND it + must be non-negative and should not be 0 or 1 (but see NB below). + + For backends: there must not be any guards on this dimension which + are not implied by the given lower and upper bound. Regardless of + the lower bound, the backend can assume the size is non-negative + and that it is not 0 or 1. + + An unbounded StrictMinMaxConstraint can be thought of as a strict version + of "RelaxedUnspecConstraint". + + NB: Export will often unsoundly assume that a graph works for 0/1, even + though at trace time we assumed size is not 0 or 1. The idea is that + if we produce a graph that works for a range of values, it will be OK + for N=0/1 too. + """ + vr: ValueRanges + + def render(self, source: Source): + """Format the constrain equation""" + # TODO: better printing for -oo and oo + return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" + +@dataclass(frozen=True) +class RelaxedUnspecConstraint(Constraint): + """ + For clients: no explicit constraint; constraint is whatever is implicitly + inferred by guards from tracing. + + For backends: there must exist at least TWO possible values for the + size at this dimension which satisfy the guards for this dimension. + + In other words, this constraint helps us distinguish between "we don't + care if this dimension specializes or not" versus "this dimension must be + unspecialized." However, this constraint doesn't say very much about what + specialization is permitted; for example, if we guard on a size being + even, this would still be acceptable under an unspec constraint. This + makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler + may add constraints to otherwise dynamic dimensions; we can't assert that + there are NO guards as this is brittle because compilers should be able to + add extra constraints. 
If you want to assert that there are no guards, + use StrictMinMaxConstraint with an unbounded ValueRanges. + """ + def render(self, source: Source): + return f"RelaxedUnspecConstraint({source.name()})" + +# NB: None here indicates the client constraint is whatever is implicitly +# inferred by guards from tracing, and that a backend can add whatever guards +# it wants (including fully specializing the value). +DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] + +@dataclass(frozen=True) +class EqualityConstraint(Constraint): + """ + Represent and decide various kinds of equality constraints between input sources. + + A "source pair" is a pair of input sources for dynamic dimensions that + are specified equal. We represent `source_pairs` in a union-find forest + so that we can efficiently check whether two such sources are transitively equal. + + A "derived equality" relates an input source to an expression over a root. + The root can be another input source, corresponding to some dynamic dimension, + or a phantom symbol that does not directly represent any dynamic dimension. We + represent `derived_equalities` involving input sources in a transitively-closed map + so that we can efficiently check whether an input source is transitively equal to + a given expression over another input source. + (NOTE: In contrast, it is easy to decide whether an input source is transitively equal + to a given expression over a phantom symbol; such expressions are already in canonical + form and so the problem reduces to symbolic expression equality.) + """ + source_pairs: List[Tuple[Source, Source]] + derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]] + phantom_symbols: List[sympy.Symbol] + + def __post_init__(self): + """Pre-processing to answer queries `is_equal` and `is_derived` below. 
+ + Example: Suppose we are given: + source_pairs [a = b, b = c] + derived_equalities [d = c + 1, e = d - 1] + We first construct a union find with source_pairs: + _parents = {a: a, b: a, c: a} + Then we compute canonical symbolic expressions, recursively applying derived_equalities + until we bottom out: + _defs = {d: c + 1, e: (c + 1) - 1 aka c} + """ + + # self._parents is a map from input sources to input sources where, conceptually, + # these are directed edges in a union-find forest + _parents: Dict[Source, Source] = {} + object.__setattr__(self, "_parents", _parents) + # self._defs is a map from input sources to "canonical" symbolic expressions, + # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., + # not derived Dims) + _defs: Dict[Source, sympy.Expr] = {} + object.__setattr__(self, "_defs", _defs) + + for source1, source2 in self.source_pairs: + # preprocess into a union-find forest + self._union(self._find(source1), self._find(source2)) + for source, root, fn in self.derived_equalities: + # preprocess into a transitively-closed map + # NOTE(avik): we reuse the union-find forest for canonicalizing input sources + if isinstance(root, sympy.Symbol): + self._defs[self._find(source)] = fn(root) + else: + self._defs[self._find(source)] = fn(self._rewrite(root)) + + def _find(self, source): + # chase edges to find the root of this equivalence class + if source in self._parents: + return self._find(self._parents[source]) + else: + return source + + def _union(self, root1, root2): + # merge two equivalence classes by adding an edge from one root to the other + if root1 != root2: + self._parents[root1] = root2 + + def _rewrite(self, src): + # always represent the given source by the root of its equivalence class + src = self._find(src) + if src in self._defs: + # simply look up the definition if it exists + # NOTE(avik): This works because definitions are always transitively-closed; + # otherwise we would have to do recursive rewriting. + return self._defs[src] + else: + # otherwise, create a symbol representing the source + return sympy.Symbol(src.name()) + + def is_equal(self, source1, source2): + return ( + # check whether source1 and source2 have the same root + self._find(source1) == self._find(source2) or + # check whether source1 is derived equal to source2 + self.is_derived(source1, source2, lambda x: x) + ) + + def is_derived(self, src, symbol_src, fn): + # check whether both src and symbol_src have the same definition + return self._rewrite(src) == fn(self._rewrite(symbol_src)) + + +def _assert_symbol_context(symbolic_context): + assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" + assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" + + +@dataclass(frozen=True) +class SymbolicContext: + """ + Data structure specifying how we should create symbols in + ``create_symbolic_sizes_strides_storage_offset``; e.g., should + they be static or dynamic. + + This is an abstract base class because we are probably going to add + another version of this that says "use exactly these SymInts, don't + allocate fresh symbols." + """ + pass + + +@dataclass(frozen=True) +class StatelessSymbolicContext(SymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. 
+ This will cause fresh symbols to be allocated. + """ + dynamic_sizes: DimList[DimDynamic] + constraint_sizes: DimList[DimConstraint] = None + # If the tensor is a view, this should be populated for the base. It contains + # information on how to allocate symbols when recursively fakeifying the base + # during view fake-ification. + view_base_context: Optional[SymbolicContext] = None + # TODO: add storage offset and stride symbolic_context + + def __post_init__(self): + if self.constraint_sizes is None: + object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes)) + + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, provided the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read on a cache + hit. + + It is the cache owner's responsibility to maintain the lifecycle of the cache + w/r/t different shape_envs, clearing, etc. + """ + tensor_source: Source = None + # Why is this keyed on int first? + # That integer is actually the id of the shape_env. This cache short-circuits symbol + # creation, and we must store it per shape env. Now, the tracing invariant is a single + # shape env per tracing context, and every new frame gets a new shape_env. So where would we have + # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events + # is invoked, and creates a new shape_env. Replaying events against this new shape_env will + # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never + # get recorded in var_to_val, etc.
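+ # As a purely illustrative sketch, a populated cache might look like + # {id(shape_env): {<source for t.size()[0]>: s0, <source for t.size()[1]>: s1}} + # for some input tensor t whose first two dimensions were allocated symbols s0 and s1.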
+ # TODO(voz): consider a weakref to the shape_env here + shape_env_to_source_to_symbol_cache : Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None + + def __post_init__(self): + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.shape_env_to_source_to_symbol_cache: + object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {}) + + +@dataclass(frozen=True) +class SubclassSymbolicContext(StatefulSymbolicContext): + """ + The correct symbolic context for a given inner tensor of a traceable tensor subclass + may differ from that of the outer symbolic context. This structure allows for this + flexibility, with inner symbolic contexts mapped via attr -> symbolic context. + """ + inner_contexts: Dict[str, SymbolicContext] = None + + def __post_init__(self): + super().__post_init__() + if self.inner_contexts is None: + self.inner_contexts = {} + + +def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: + if isinstance(val, (int, float, bool)): + return False + return val.node.is_symbolic() + +IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + +@lru_cache(256) +def safe_expand(r): + if hasattr(r, 'expand'): + try: + return sympy.expand(r) + except RecursionError: + log.warning("RecursionError in sympy.expand(%s)", r) + return r + else: + return r + +def error(): + raise AssertionError("shouldn't be hit") + + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense(sizes, strides): + return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) + +def _eval_is_non_overlapping_and_dense(sizes, strides): + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted( + zip(sizes, strides), key=operator.itemgetter(1) + ) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: + int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True)) + return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint())) + +SYMPY_INTERP = { + 'Abs': operator.abs, + 'Eq': operator.eq, + 'Ne': operator.ne, + 'Gt': operator.gt, + 'Lt': operator.lt, + 'Le': operator.le, + 'Ge': operator.ge, + 'Min': min, + 'Max': max, + 'Mod': operator.mod, + 'FloorDiv': operator.floordiv, + 'TrueDiv': operator.truediv, + 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, + 'floor': math.floor, + 'ceiling': math.ceil, + 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, + 'Round': builtins.round, + 'RoundDecimal': builtins.round, +} + + +def _lru_cache(fn, maxsize=None): + """ + Wrapper around lru_cache that clears when new info about shapes has been + updated. + + Use lru_cache if the output is always the same, regardless of the + constraints we know now (i.e. evaluate_expr) + + Use _lru_cache otherwise. 
+ + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. + """ + fn_cache = lru_cache(maxsize)(fn) + prior_version = 0 + + if config.validate_shape_env_version_key: + prior_key = None + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), \ + "ShapeEnv cache key changed without version being updated!" + + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) + + wrapper.cache_clear = fn_cache.cache_clear + wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] + return wrapper + + +# This is pretty similar to ShapeGuard but it also comes with a message, +# and is exclusively used for things that MUST be true (unlike guards, +# which can evaluate False, in which case you just choose not to use +# a particular specialization) +@dataclass(frozen=True) +class RuntimeAssert: + expr: sympy.Expr + msg: str = field(repr=False) + stack: str = field(repr=False) + + +class ShapeGuardPrinter(StrPrinter): + def __init__( + self, + symbol_to_source, + source_ref, + var_to_sources, + ): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_ref = source_ref + self.var_to_sources = var_to_sources + + def _print_Not(self, expr): + return 'not %s' % (self.parenthesize(expr.args[0], PRECEDENCE["Not"])) + + def _print_And(self, expr): + return self.stringify(expr.args, " and ", PRECEDENCE["And"]) + + def _print_Or(self, expr): + return self.stringify(expr.args, " or ", PRECEDENCE["Or"]) + + def _print_Symbol(self, expr) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + + def repr_symbol_to_source(): + return repr({ + symbol: [s.name() for s in sources] + for symbol, sources in self.symbol_to_source.items() + }) + + assert self.symbol_to_source.get(expr), ( + f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " + f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " + "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) + return self.source_ref(self.symbol_to_source[expr][0]) + + +class LoggingShapeGuardPrinter(ShapeGuardPrinter): + def __init__(self, var_to_sources): + super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) + + +class DynamicDimConstraintPrinter(StrPrinter): + """ + Printer for dynamic dim constraints. + - Instead of t.size()[d] it prints dynamic_dim(t, d) + - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. + + We use this to suggest code for specifying dynamic dim constraints. 
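+ + For example, if s0 is the symbol allocated for t.size()[0], the constraint Eq(s0, 3) + is rendered roughly as dynamic_dim(t, 0) == 3 (when no debug names are configured).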
+ """ + def __init__(self, symbol_to_source, source_name_to_debug_name): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_name_to_debug_name = source_name_to_debug_name + + def print_source(self, source) -> str: + if self.source_name_to_debug_name: + return source.name() + return f"dynamic_dim({source.base.name()}, {source.idx})" + + def _print_Symbol(self, expr) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) + return self.print_source(self.symbol_to_source[expr][0]) + + def _print_Relational(self, expr): + return '{} {} {}'.format( + self.parenthesize(expr.lhs, precedence(expr)), + expr.rel_op, + self.parenthesize(expr.rhs, precedence(expr)) + ) + + +class DimConstraints: + """ + Custom solver for a system of constraints on symbolic dimensions. + Solutions are "static" values or simplified "dynamic" constraints. + """ + + def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name): + # We try to solve systems of inequalities with 1 free variable. + self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) + # Among them, we prioritize solving for a free variable that has equalities. + # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() + # and removing a symbol from the former => removing it from the latter. + self._symbols_with_equalities: Set[sympy.Symbol] = set() + # A solution of a free variable with equalities becomes a substitution. + # We use these substitutions to simplify other constraints. + # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. + self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {} + + # In general, constraints may have // and % operations. + # Of course, // can be expressed in terms of / and %. + # Our inequality solver can handle / but not %. So we need to transform them away. + # We do so by using the values of variables as hints to evaluate %. + # For soundness we record additional congruence guards and solve them separately. + self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val + self._congruences: Set[sympy.Expr] = defaultdict(set) + + # We do not try to (directly) solve inequalities with > 1 free variables. + # NOTE: free variables in these inequalities cannot also be in _substitutions. + self._multivariate_inequalities: Set[sympy.Expr] = set() + + # We park external equalities between free variables here. + self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = [] + + # Solutions come in two forms: + # - (static) specializations + # - (dynamic) inequalities / congruences + self._static_results: Set[str] = set() + self._dynamic_results: Set[str] = set() + + # printer for solutions + self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name) + + # inconsistencies found on substituting with concrete values / static solutions + self._inconsistencies: List[str] = [] + + # symbols that are marked dynamic + self._marked_dynamic = marked_dynamic + + def rewrite_with_congruences(self, s, expr): + """ + Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. + This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. + We solve the added congruences separately (using our congruence solver, see below). 
+ """ + def mod_handler(*args): + # Suppose that we have an expression of the form b % d with free variable s. + # Using the value of s as a "hint," we can evaluate b % d to a value k. + # Then we can rewrite b % d to k while adding the guard b % d == k. + + # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF + # the original expression always evaluates to a constant value (i.e., it does not vary with s). + # In other words, + # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with + # the original expression; + # - while it may be possible to find solutions of s with the original expression that are not + # solutions with the rewritten expression, in that case the original expression cannot evaluate + # to the same value for all solutions of s. + # + # Should we be worried about this incompleteness? No, because of the following reasons: + # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech + # (i.e., "don't let perfect be the enemy of the good"). + # 2. We already have a tradition of using hints to add guards in the compiler for making progress. + # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards + # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. + # + # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. + # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we + # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! + base, divisor = args + base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) + mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return mod_reduced + + def floor_div_handler(*args): + # Suppose that we have an expression of the form b // d with free variable s. + # Using the value of s, we can evaluate b % d to a value k. + # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. + + # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d + # and eliminating b % d as above. + base, divisor = args + base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) + mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return (base - mod_reduced) / divisor + + if expr.has(Mod): + expr = expr.replace(Mod, mod_handler) + if expr.has(FloorDiv): + expr = expr.replace(FloorDiv, floor_div_handler) + return expr + + def add(self, expr) -> bool: + """Add an expression to the set of constraints. + + Return whether the expression is a trivial constraint (i.e., an obvious tautology). + """ + if expr == sympy.true: + return True + orig_expr = expr + orig_reduced = orig_expr.subs(self._var_to_val) + # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 + # It is possible that `expr` will fail the consistency check because of + # precision errors. Specifically, on substituting its free symbols with + # their concrete values, we might end up comparing floats. Until we have + # a fix for this issue, we delay raising such failures. See solve(). 
+ if orig_reduced == sympy.false: + self._inconsistencies.append(f"{orig_expr} is inconsistent!") + if isinstance(expr, sympy.Ne): + # we're not going to do anything useful with these, so drop them + return False + free_symbols = expr.free_symbols + assert free_symbols, f"Did not expect constraint with no free variables: {expr}" + if len(free_symbols) > 1: + # multivariate: record and move on + self._multivariate_inequalities.add(expr) + else: + # univariate: can solve these immediately + s = next(iter(free_symbols)) + # eliminate // and % (see documentation of `rewrite_with_congruences` above) + old_n_congruences = len(self._congruences[s]) + expr = self.rewrite_with_congruences(s, expr) + new_n_congruences = len(self._congruences[s]) + if expr == sympy.true: + return old_n_congruences == new_n_congruences + reduced = expr.subs(self._var_to_val) + if reduced == sympy.false: + self._inconsistencies.append( + f"{expr}, obtained by rewriting {orig_expr} with congruences, " + "is inconsistent!" + ) + if isinstance(expr, sympy.Eq): + # special status for symbols that have equalities (see `solve` below) + self._symbols_with_equalities.add(s) + self._univariate_inequalities[s].add(expr) + return False + + def add_equality(self, source, expr): + """Add an equality constraint""" + if expr.is_number: + # specialization, right here + self._static_results.add(f"{source.name()} == {expr}") + else: + # these will resolve to either specializations or dynamic equality constraints + self._symbolic_equivalences.append((source, expr)) + + def _reduce_congruences(self): + reduced_congruences = {} + for s, congruences in self._congruences.items(): + remainder_modulus_pairs = [] + congruences_to_check = set() + for congruence in congruences: + base, divisor = congruence.args + # We are given a congruence of the form base % divisor == 0 with a free variable s. So: + # - we transform this into an equation of the form base = divisor * tmp; + # - we solve this equation for s to get a linear solution with free variable tmp. + tmp = sympy.Symbol("tmp", integer=True) + symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) + # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear + # for how to interpret the results. + if s == symbol: + # This means the solution is of the form s = modulus*tmp + remainder. + modulus, remainder = sympy.polys.polytools.div(solution, tmp) + if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer): + # Make sure 0 <= remainder <= modulus. + remainder = remainder % modulus + remainder_modulus_pairs.append((remainder, modulus)) + continue + # This means that we did not get a unique solution to the equation. + # No problem, we will check it. + congruences_to_check.add(congruence) + # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). + # The solution will be a congruence of the form s = r mod m. + # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. 
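+ # For example, the pairs (1, 2) and (2, 3), i.e. s = 1 (mod 2) and s = 2 (mod 3), combine + # to s = 5 (mod 6), which is then emitted as the single congruence (s - 5) % 6 == 0.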
+ if remainder_modulus_pairs: + remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs) + reduced_congruences[s] = {(s - remainder) % modulus} + substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder} + reduced_congruences[s].update( + congruence for congruence in congruences_to_check + if not sympy.checksol(congruence, substitution) + ) + else: + reduced_congruences[s] = congruences_to_check + + return reduced_congruences + + def _raise_inconsistencies(self): + if self._inconsistencies: + msg = "\n".join(self._inconsistencies) + self._inconsistencies.clear() + raise ValueError(f"The following inconsistencies were found:\n{msg}") + + def _force_specialization(self, s): + val = self._var_to_val[s] + self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}") + self._substitutions[s] = val + + def _specialize_divisor_symbols(self): + for expr in self._multivariate_inequalities: + for atom in expr.atoms(FloorDiv, Mod): + _, divisor = atom.args + for s in divisor.free_symbols: + self._force_specialization(s) + + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.subs(self._substitutions)) + self._raise_inconsistencies() + self._univariate_inequalities = { + s: exprs + for s, exprs in self._univariate_inequalities.items() + if s not in self._substitutions + } + self._congruences = { + s: congruences + for s, congruences in self._congruences.items() + if s not in self._substitutions + } + + def solve(self, disable_congruences=True, disable_equivalences=True): + """Solve the system of constraint equations to find simplified constraints + """ + self._raise_inconsistencies() + # as long as there are symbols with equalities, solve for them + # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) + while self._symbols_with_equalities: + s = self._symbols_with_equalities.pop() + exprs = self._univariate_inequalities.pop(s) + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + if isinstance(solution, sympy.And): + solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution) + assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}" + symbol, val = solution.args + assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" + # because this is univariate, the solution is a specialization + self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}") + # add this as a substitution to simplify other constraints + self._substitutions[s] = val + + # simplify multivariate inequalities: some of them will now become univariate! + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.subs(s, self._substitutions[s])) + self._raise_inconsistencies() + + self._specialize_divisor_symbols() + + # solve linear congruences + # NOTE(avik): We do not need to solve them for symbols that have already been specialized. 
+ reduced_congruences = self._reduce_congruences() + for s, congruences in reduced_congruences.items(): + for congruence in congruences: + # any congruence that cannot be checked becomes a dynamic constraint as well + if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}): + if self._is_supported_congruence(congruence): + base, divisor = congruence.args + tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}" + tmp = sympy.Symbol(tmp_name, integer=True) + from torch._dynamo.source import ConstantSource + self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] + r = try_solve(sympy.Eq(base, divisor * tmp), s) + self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) + elif disable_congruences: + self._force_specialization(s) + self._univariate_inequalities.pop(s, None) + + # remaining symbols have only pure inequalities (no equalities) + for s, exprs in self._univariate_inequalities.items(): + try: + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + # because this is univariate, the solution is a dynamic (range) constraint + if isinstance(solution, sympy.Or): + solution = next(iter(arg for arg in solution.args if arg.subs(self._var_to_val))) + if isinstance(solution, sympy.And): + for arg in solution.args: + self._dynamic_results.add(self._dcp.doprint(arg)) + else: + self._dynamic_results.add(self._dcp.doprint(solution)) + except (NotImplementedError, AssertionError) as e: + log.warning("Failed to reduce inequalities: %s", e) + for expr in exprs: + self._dynamic_results.add(self._dcp.doprint(expr)) + + # simplify symbolic equivalences: some of them will now become specializations! + symbolic_equivalences = self._symbolic_equivalences + self._symbolic_equivalences = [] + for source, expr in symbolic_equivalences: + if disable_equivalences and not self._is_supported_equivalence(expr): + for s in expr.free_symbols: + self._force_specialization(s) + sexpr = self._dcp._print_Symbol(s) + self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r} + self.add_equality(source, expr.subs(self._substitutions)) + + # remaining symbolic equivalences become dynamic equality constraints + for source, expr in self._symbolic_equivalences: + self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}") + + @classmethod + def _is_supported_equivalence(cls, expr): + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + lhs, rhs = expr.args + return ( + (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs)) + ) + return isinstance(expr, sympy.Symbol) + + @classmethod + def _is_supported_congruence(cls, congruence): + base, divisor = congruence.args + # Congruences that can be currently expressed with supported Dim ops are + # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. + # This allows us to derive x as b*y - a for some Dim y. + # (See also documentation of dynamic_shapes._DerivedDim.) 
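+ # For example, a congruence like (s0 + 2) % 3 == 0 is supported (yielding s0 = 3*y - 2 + # for some Dim y), whereas (s0*s1) % 3 == 0 or s0 % s1 == 0 are not.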
+ if isinstance(base, sympy.Add): + lhs, rhs = base.args + cond = ( + (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) + ) + else: + cond = isinstance(base, sympy.Symbol) + cond = cond and isinstance(divisor, sympy.Integer) + return cond + + def forced_specializations(self): + """Returns a dictionary of the names of symbols to their specialized value + """ + def debug_name(src): + name = src.name() + if self._dcp.source_name_to_debug_name: + return f"{self._dcp.source_name_to_debug_name[name]} = {name}" + else: + return name + + return { + debug_name(self._dcp.symbol_to_source[s][0]): val + for s, val in self._substitutions.items() + if s in self._marked_dynamic + } + + def remove_redundant_dynamic_results(self): + """Remove constraints of the form 2 <= dynamic_dim(...) as 2 is the default + lower bound. + """ + candidates_for_removal = [] + dynamic_results = set() + for dc in self._dynamic_results: + # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...). + # There is no change in behavior since 2 is the default lower bound. + dc_ = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc) + if dc != dc_: + candidates_for_removal.append(dc_) + else: + dynamic_results.add(dc_) + for dc in candidates_for_removal: + # remove dynamic_dim(t, 0) as a constraint when dynamic_dim(t, 0) also + # appears as part of another constraint + found = False + for other_dc in dynamic_results: + if dc in other_dc: + found = True + if not found: + dynamic_results.add(dc) + self._dynamic_results = dynamic_results + + def prettify_results( + self, + original_signature: inspect.Signature, + constraint_violation_error=None, + forced_specializations=None, + ): + """Format a message for constraint violation erros""" + if self._dcp.source_name_to_debug_name: + def transform(s): + for k, v in self._dcp.source_name_to_debug_name.items(): + s = s.replace(k, v) + return s + + results = defaultdict(dict) + + def flip(op): + if op == "<=": + return ">=" + if op == ">=": + return "<=" + if op == "<": + return ">" + if op == ">": + return "<" + assert op == "==" + return op + + def relation_with_digit(expr, op, digit): + if op == "<=": + results[expr]["max"] = digit + elif op == "<": + results[expr]["max"] = digit - 1 + elif op == ">=": + results[expr]["min"] = digit + elif op == ">": + results[expr]["min"] = digit + 1 + else: + assert op == "==" + results[expr]["eq"] = digit + + for s in self._static_results.union(self._dynamic_results): + t = transform(s) + if t == s: + continue + left, op, right = re.split(r"( == | <= | >= | < | > )", t) + op = op.strip() + if op == "==" and left == right: + continue + if right.isdigit(): + relation_with_digit(left, op, int(right)) + elif left.isdigit(): + relation_with_digit(right, flip(op), int(left)) + else: + assert op == "==" + results[left]["eq"] = sympy.sympify(right) + + buf = "" + debug_names = set() + if forced_specializations: + debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys()) + buf += ( + f"Specializations unexpectedly required ({', '.join(debug_names)})! 
" + "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + ) + for s, val in forced_specializations.items(): + buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n" + + dims = [] + others = [] + match = None + if constraint_violation_error: + match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0]) + if match is not None: + debug_names.update(match.expand(r'\1').split(', ')) + + for k, c in sorted(results.items()): + # if k not in debug_names: + # continue + if "eq" in c: + other = c["eq"] + if isinstance(other, int): + others.append(f"{k} = None # {other}") + elif self._is_supported_equivalence(other): + s = next(iter(other.free_symbols)) + if s not in results: + modulus, remainder = sympy.polys.polytools.div(other, s) + c_min = c.get("min", 2) + min_ = math.ceil((c_min - remainder) / modulus) + c_max = c.get("max", sys.maxsize - 1) + max_ = math.floor((c_max - remainder) / modulus) + dims.append(f"{s} = Dim('{s}', min={min_}, max={max_}) # {c_min} <= {other} <= {c_max}") + others.append(f"{k} = {other}") + else: + min_ = c.get("min", None) + if min_ == 2: + min_ = None + max_ = c.get("max", None) + if min_ is not None and max_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") + elif min_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_})") + elif max_ is not None: + dims.append(f"{k} = Dim('{k}', max={max_})") + else: + dims.append(f"{k} = Dim('{k}')") + + buf += "\nSuggested fixes:\n " + buf += "\n ".join(dims + others) + + return buf + + # Note: Model inputs are wrapped as LocalSource in dynamo. + # LocalSource.name() wraps the name with L[""]. We use regular + # expression to do the replacement to avoid traversing up + # the source hierarchy manually. + def extract_and_rewrite_local(dc): + match = re.search(r"L\['(.+?)'\]", dc) + if match is None: + return + arg = match.expand(r'\1') + dc = re.sub(r"L\['(.+?)'\]", r'\1', dc) + return arg, dc + + def group(results, args_index): + groups = defaultdict(list) + for dc in results: + local = extract_and_rewrite_local(dc) + if local is None: + # This can happen, e.g., with `assume_constant_result`. + # In that case, we drop the constraint. + # TODO(avik) Maybe we should generate an assertion here? + continue + arg, dc = local + if arg in args_index: + groups[args_index[arg]].append(dc) + else: + # This can happen, e.g., with decorators that change the signature. + # In that case, we drop the constraint. Seems hard to do better. :/ + # TODO(avik) Maybe warn that `arg` in not in `signature`? 
+ continue + sorted_groups = [] + for idx, dcs in sorted(groups.items()): + _, arg = idx + sorted_groups.append((arg, sorted(dcs))) + return sorted_groups + + signature = original_signature.replace(return_annotation=inspect.Signature.empty) + args_index = {} + for i, arg in enumerate(signature.parameters.keys()): + args_index[arg] = (i, arg) + + def print_results(grouped, indent, result_fn): + nonlocal buf + + space = False + for arg, results in grouped: + if space: + buf += "\n" + else: + space = True + buf += f"\n{indent}# {arg}:" + for result in results: + buf += f"\n{indent}{result_fn(result)}" + + buf = "" + if forced_specializations: + buf += ( + "Some dynamic dimensions need to be specialized because " + "the constraints inferred for them are too complex to specify.\n" + ) + for s, val in forced_specializations.items(): + buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n" + indent = 4 * " " + if self._static_results: + grouped_static_results = group(self._static_results, args_index) + buf += "\nThe following dimensions have been specialized and CANNOT be dynamic." + buf += f"\n```\ndef specializations{str(signature)}:" + print_results( + grouped_static_results, + indent, + lambda result: f"assert {result}", + ) + buf += "\n```\n" + if self._dynamic_results: + grouped_dynamic_results = group(self._dynamic_results, args_index) + buf += "\nThe following dimensions CAN be dynamic." + buf += "\nPlease use the following code to specify the constraints they must satisfy:" + buf += f"\n```\ndef specify_constraints{str(signature)}:" + buf += f"\n{indent}return [" + print_results( + grouped_dynamic_results, + indent * 2, + lambda result: f"{result},", + ) + buf += f"\n{indent}]\n```\n" + return buf + + +TLS = threading.local() + + +class ShapeEnv: + # This is a wrapper over the actual __init__ function. + # + # Where to add a new constructor parameter to ShapeEnv? + # ===================================================== + # This __init__ function should be used only for parameters related to event recording. + # These are parameters that we don't wish to pass down the road to new ShapeEnv instances + # created from replaying events. + # + # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event + # recording, do so in the _init function. + def __init__( + self, *, + should_record_events: Optional[bool] = None, + tracked_fakes: Optional[List[Any]] = None, + **kwargs + ) -> None: + self._init(**kwargs) + + # Disable event recording when replaying. + kwargs["should_record_events"] = False + + from torch.fx.experimental.validator import translation_validation_enabled + self._translation_validation_enabled = translation_validation_enabled() + + # If not specified, enable event recording if both: + # - Translation validation is on + # - Translation validation bisection is not disabled + self.should_record_events = ( + should_record_events + if should_record_events is not None + else ( + self._translation_validation_enabled + and not config.translation_validation_no_bisect + ) + ) + + # Enable event recording check if both: + # - It should record events + # - The recording check is enabled + self.check_recorded_events = ( + self.should_record_events and config.check_shape_env_recorded_events + ) + + # This will make sure we only record the top-level function call. + self.is_recording = not self.should_record_events + # Keep track of the list of tracked fakes. 
+ self.tracked_fakes = tracked_fakes + # List of events for reconstructing ShapeEnv at arbitrary points in time. + self.events: List[ShapeEnvEvent] = ( + [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else [] + ) + + # Pro-tip: if you add new field to ShapeEnv, this affects some accept + # tests. Accept their output with: + # + # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal + # + def _init( + self, *, + allow_scalar_outputs=True, + allow_dynamic_output_shape_ops=True, + # NB: These are legacy configuration that help us make good choices + # when the constraint/dynamic dims are not explicitly passed to us. + # Ideally we will fix all call sites to be explicit and not have + # implicit choices, but this apparently was pretty involved. + assume_static_by_default=False, + # Note - On 0/1 specialization + # + # The following options affect decisions we make about eager + # specialization. Disabling them will increase trace time (as we do + # more symbolic reasoning) and can also harm the quality of generated + # code (because inductor may not be able to specialize for bounds + # being equal--although if we later respecialize because of a guard, + # your code may be just as good as it was before.) + # + # When True, eagerly specialize input sizes which have 0/1. + specialize_zero_one=True, + # When True, assume input sizes which have the same size are + # symbolically equal. + duck_shape=True, + # For debugging + co_fields=None, + # XXX Add any new settings that could affect FakeTensor evaluation + # to: torch._subclasses.fake_tensor._ShapeEnvSettings + ): + # Not directly used by ShapeEnv; indirectly used by FakeTensor + self.allow_scalar_outputs = allow_scalar_outputs + self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops + self.guards: List[ShapeGuard] = [] + # Maps symbolic ints to their original concrete values + # Currently populated from tensors + self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + # Maps symbolic ints to their min/max range. These ranges + # are conservative: the int MUST fall in the range, but the + # range may contain ints which may not actually appear in + # practice + self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} + self.source_name_to_debug_name: Dict[str, str] = {} + self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} + self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} + # Maps from sympy ints to expressions representing them + # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) + self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} + # Set holds a % b expressions that evaluate to 0. + self.divisible: Set[sympy.Expr] = set() + # Set that holds "size-like" symbols. When we perform + # "size-oblivious" tests, these can be assumed to be >= 2. + self.size_like: Set[sympy.Symbol] = set() + # Duck-shaping says that if two input tensors have the same size, + # they get assigned the same symbolic variable + self.val_to_var: Dict[int, sympy.Expr] = {} + if specialize_zero_one: + self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)} + self.unbacked_symfloat_counter = itertools.count() + self.unbacked_symint_counter = itertools.count() + # Similar to guards, but these MUST evaluate to true and can + # only be evaluated at runtime midway through (i.e., they always + # involve unbacked symints) + # + # For efficiency reasons, we index in the following way. Suppose you have + # a runtime assert i0 + i1 <= s1. 
We pick the most recently allocated + # symbol in the source expression and add the assert to the list for + # that symbol e.g., {i1: [i0 + i1 <= s1]}. + # + # We access the runtime asserts in two situations: + # + # - When we are guarding on an expression, we will attempt to + # statically evaluate it, in case the unbacked SymInts can + # simplify away. If we have a runtime assert, we may be able + # to discharge the guard entirely. We only need to attempt + # runtime asserts that mention freevars of the expression in + # question. + # + # - When we are performing codegen (in Inductor for eager, or + # when finalizing the export FX graph), we need to know what + # extra runtime asserts to insert. Whenever an unbacked + # SymInt comes into scope, all runtime asserts involving it + # become eligible for insertion (so long as all of their other + # free unbacked symbols are also in scope). We technically + # can handle any choice of key by kicking inexpressible asserts + # to the next unbacked symbol to wait on, but if we choose the + # latest key, an assert will only show up at the moment when + # we can actually codegen it. + self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {} + # This exists so we can efficiently invalidate the cache (it's used as + # part of the cache key); otherwise we'd have to iterate through + # deferred_runtime_asserts to compute its length + self.num_deferred_runtime_asserts = 0 + self.assume_static_by_default = assume_static_by_default + self.specialize_zero_one = specialize_zero_one + self.duck_shape = duck_shape + self.log = log + self.log.debug("create_env") + self.frozen = False + self.dim_constraints: Optional[DimConstraints] = None + self.counter = collections.Counter() + # Mapping from sympy.Symbol to the number of guards which mention this + # symbol + self.symbol_guard_counter = collections.Counter() + # A selection of important fields on co_field; solely used for + # signpost_event + self.co_fields = co_fields if co_fields else {} + + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + + # Cache for FX nodes. + # Maps an already built node a tuple of: + # 1. node's target + # 2. list of arguments + # This drastically reduces the size of the FX graph, avoiding + # duplicated nodes. + self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: Dict[str, sympy.Symbol] = {} + + from torch.fx.experimental.validator import translation_validation_enabled + self._translation_validation_enabled = translation_validation_enabled() + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import TranslationValidator + + self.validator = TranslationValidator() + self.graph = torch.fx.Graph() + # Create an output graph and start inserting before that. + # This is needed when 'deepcopy'-ing this object. + self.graph.inserting_before(self.graph.output(None)) + + # Mapping of each node name to the node itself. + # + # This is useful for matching an FX node from a recorded ShapeEnv.graph + # to the FX node of the ShapeEnv we are running the event on. + # + # Whenever you add a node to self.graph, you must add a mapping to this + # variable. Otherwise, the built FX graph on the replayed ShapeEnv will + # not be valid. 
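The bucketing described above files each deferred runtime assert under the most recently allocated unbacked symbol it mentions, so it becomes eligible for codegen exactly when that symbol comes into scope. A small sympy sketch of that indexing, assuming the `u<counter>` naming convention used elsewhere in this file; the helper names are illustrative.

```
import sympy
from collections import defaultdict

def unbacked_index(sym):
    return int(sym.name[1:])                 # "u3" -> 3, per the u<counter> naming

deferred = defaultdict(list)                 # symbol -> list of assert expressions

def defer_runtime_assert(expr):
    unbacked = [s for s in expr.free_symbols if s.name.startswith("u")]
    key = max(unbacked, key=unbacked_index)  # most recently allocated unbacked symbol
    deferred[key].append(expr)

u0, u1, s1 = sympy.symbols("u0 u1 s1", integer=True)
defer_runtime_assert(u0 + u1 <= s1)          # filed under u1, mirroring {i1: [i0 + i1 <= s1]}
assert list(deferred) == [u1]
```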
+ self.name_to_node: Dict[str, torch.fx.Node] = {} + + def check_equal(self, other: "ShapeEnv") -> None: + """Compare another ShapeEnv for equivalence + """ + # ShapeEnv fields that are not relevant for the outcome of + # ShapeEnv.produce_guards call: + # - Debugging variables + # - Translation validation related variables + # - Events recording related variables + non_state_variable_names = ( + "counter", + "log", + "var_to_stack", + "fx_node_cache", + "graph", + "validator", + "check_recorded_events", + "should_record_events", + "is_recording", + "tracked_fakes", + "events", + "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", + ) + + # Mapping of the value of each to-be-compared field into the values that + # should actually be compared. + # + # You should modify this if, for example, the field that holds state and + # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) + # and the stack when it was added to the set of guards. In order to compare + # it, we throw away the stack information. + def map_value(key: str, value: Any) -> Any: + if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"): + from copy import copy + + # For itertools.count(), we compare the next integer returned + # by the count iterators. Not that we need to copy the iterator + # first. Otherwise we are mutating the object. + return next(copy(value)) + elif key == "guards": + # Transform the list of ShapeGuard into a list of expressions. + return [g.expr for g in value] + elif key == "deferred_runtime_asserts": + # Transform the list of RuntimeAsserts into a list of expressions. + return {s: [ra.expr for ra in ras] for s, ras in value.items()} + elif key == "name_to_node": + # Compare just the set of keys is the same. + return set(value.keys()) + elif key == "symbol_guard_counter": + # Skip this for comparisons + return None + return value + + shape_env_check_state_equal(self, other, non_state_variable_names, map_value) + + def _snapshot_tracked_fakes(self) -> Optional[List[Any]]: + if self.tracked_fakes is None: + return None + + from torch._dynamo.variables.builder import TrackedFake + + def maybe_transform_fake(fake: TrackedFake): + inner_fake = fake.fake \ + if isinstance(fake.fake, torch.SymInt) \ + else FakeTensorMeta.from_fake(fake.fake) + # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a + # FakeTensorMeta for two reasons: + # 1. this is all the information we need when recording ShapeEnvEvents. + # 2. it works even if each TrackedFake changes its metadata. + return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] + + return [maybe_transform_fake(fake) for fake in self.tracked_fakes] + + def _last_event_index(self) -> int: + return len(self.events) - 1 + + @contextmanager + def _recording(self): + self.is_recording = True + try: + yield + finally: + self.is_recording = False + + @record_shapeenv_event() + def freeze(self): + """Freeze this ShapeEnv to stop accumulating guards + + A frozen ShapeEnv will ignore any further guards generated on it and + only emit a warning which may lead to accuracy problems. 
+ """ + self.frozen = True + + def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: + if not self._translation_validation_enabled: + return None + srcname = source.name() + if source not in self.source_to_symbol: + self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) + return self.source_to_symbol[srcname] + + def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None: + if self._translation_validation_enabled: + self.validator.add_var(symbol, type) + + def _add_target_expr(self, expr) -> None: + if self._translation_validation_enabled: + self.validator.add_target_expr(expr) + + def _add_assertion(self, expr) -> None: + if self._translation_validation_enabled: + self.validator.add_assertion(expr) + + def _check_translation_validate(self) -> None: + if self._translation_validation_enabled: + self.validator.validate() + + @record_shapeenv_event() + def _create_fx_call_function( + self, + op: Callable, + args: Tuple, + ) -> Tuple[Optional[torch.fx.Node], bool]: + # Cache this tuple in order to avoid duplicated nodes. + node_key = (op, args) + # Flags whether the returned node was cached or not. + fresh = False + + if self._translation_validation_enabled and node_key not in self.fx_node_cache: + from torch.fx.experimental.validator import z3op + + # Presence of None in the arguments implies that we should ignore this operation. + if any(a is None for a in args): + # We check if we are not mixing SymNode that should not be ignored + # (fx_node is not None) with those that should (fx_node is None). + assert all(not isinstance(a, torch.fx.Node) for a in args) + return None, fresh + + fresh = True + lifted_op = z3op(op, self.validator) + + # If translation validation is enabled, all arguments must have its + # own FX node. + assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}" + node = self.fx_node_cache[node_key] = self.graph.call_function(lifted_op, args) + self.name_to_node[node.name] = node + + return self.fx_node_cache.get(node_key, None), fresh + + def _create_fx_placeholder_and_z3var( + self, + symbol: sympy.Symbol, + type: Type, + ) -> Optional[torch.fx.Node]: + if not self._translation_validation_enabled: + return None + + node_key = (self.graph.placeholder, (symbol,)) + + # Check if we haven't added this symbol already. + # If so, skip the placeholder creation, as it + # generates invalid Python code. + if node_key not in self.fx_node_cache: + # Add a Z3 variable according to 'type'. + self._add_z3var(symbol, type) + # Create the FX placeholder out of a mangled name. + mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name)) + node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) + self.name_to_node[node.name] = node + # Attach the 'symbol' to the placeholder so that we can retrieve + # the Z3 variable later. 
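`_create_fx_call_function` above deduplicates validation-graph nodes by caching them on `(op, args)`. A minimal sketch of the same idea with plain `torch.fx`; the cache and helper names are illustrative.

```
import operator
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")
node_cache = {}

def cached_call_function(op, args):
    key = (op, args)
    if key not in node_cache:                       # only build a node the first time
        node_cache[key] = graph.call_function(op, args)
    return node_cache[key]

a = cached_call_function(operator.add, (x, 1))
b = cached_call_function(operator.add, (x, 1))      # cache hit: no duplicate node
assert a is b and len(graph.nodes) == 2             # just the placeholder and one add
```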
+ node.meta["symbol"] = symbol + + return self.fx_node_cache[node_key] + + def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: + if self._translation_validation_enabled and node is not None: + self.name_to_node.pop(node.name) + self.graph.erase_node(node) + + def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: + from torch._dynamo.utils import get_current_node + + if self.should_record_events: + node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() + node.meta[CURRENT_NODE_KEY] = get_current_node() + + def _suppress_guards_tls(self): + return getattr(TLS, "suppress_guards", False) + + @record_shapeenv_event() + def _suppress_guards_enter(self): + TLS.suppress_guards = True + + @record_shapeenv_event() + def _suppress_guards_exit(self): + TLS.suppress_guards = False + + @contextmanager + def suppress_guards(self): + """Context manager to ignore all guards generated inside""" + self._suppress_guards_enter() + try: + yield + finally: + self._suppress_guards_exit() + + def _get_key(self): + """ + Defines the current "state" of the guards we've accumulated in this ShapeEnv. + Determines when we need to invalidate our cache + """ + return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts) + + def _update_version_counter(self): + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + + def _produce_dyn_sizes(self, + ex_size: Sequence[int], + source: Source, + symbolic_context: SymbolicContext + ) -> List[sympy.Expr]: + return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context) + + def _produce_dyn_sizes_from_int_tuple(self, + tensor_size: Tuple[int], + source: Source, + symbolic_context: SymbolicContext, + ) -> List[sympy.Expr]: + assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}" + from torch._dynamo.source import TensorPropertySource, TensorProperty + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes + constraint_dims = symbolic_context.constraint_sizes + size = [] + for i, val in enumerate(tensor_size): + size.append(self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.SIZE, i), + dynamic_dims[i], + constraint_dims[i], + symbolic_context=symbolic_context + )) + return size + + def create_symbolic_sizes_strides_storage_offset( + self, + ex: torch.Tensor, + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ): + """ + Returns a list of symbolic sizes and strides for the given tensor. + We try our best to express stride in terms of the sizes, so as to not + introduce new symbolic variables. + """ + + # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). + # We create symbols in shape_env using the backed hints behind SymInt. + + # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. + # produce_guards will trigger specializations on the outer stuff + + # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). 
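`_update_version_counter` above replaces repeated comparison of the full state with a monotonically increasing integer that caches can key on. A tiny self-contained sketch of that invalidation pattern; the class and field names are illustrative.

```
class VersionedEnv:
    def __init__(self):
        self.replacements = {}
        self._version = 0
        self._prev_key = self._key()

    def _key(self):
        return (len(self.replacements),)            # cheap summary of mutable state

    def _update_version(self):
        key = self._key()
        if key != self._prev_key:                   # only bump when state really changed
            self._prev_key = key
            self._version += 1

    def add_replacement(self, sym, expr):
        self.replacements[sym] = expr
        self._update_version()

env = VersionedEnv()
before = env._version
env.add_replacement("s0", "2*s1")
env.add_replacement("s0", "2*s1")                   # no new state, no version bump
assert env._version == before + 1
```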
+ # + # It's probably good for now but it's important to note that this approach has implications for + # the original shape_env when checking guards in different order. + + # Example: + # --------- + # Consider a function "opt_f" as shown below: + + # @torch.compile() + # def opt_f(x: bool, y: Tensor): + # if x == True: + # return y + torch.randn([4]) + # else: + # return y + # Depending on the sequence of calls, we might install two different sets of guards: + + # 1. opt_f(False, y): + # - "x == False" (always works for any size y) + + # 2. opt_f(True, y): + # - Triggers recompilation and results in guards like: + # - "x == True and y.size(0) == 4" + # - (or "y.size(0) == 4 and x == True") + + # The order of checking the guards matters. In this specific example: + # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, + # we may have an unnessary shape speciliazation for y. + def maybe_specialize_sym_int_with_hint(maybe_sym) -> int: + assert isinstance(maybe_sym, (int, torch.SymInt)) + if is_symbolic(maybe_sym): + assert maybe_sym.node.shape_env is not self, \ + "expect the symbol is created from an shape env other than current one." + return maybe_sym.node.require_hint() + return maybe_sym + + ex_size = tuple(maybe_specialize_sym_int_with_hint(sz) for sz in ex.size()) + ex_stride = tuple(maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride()) + ex_storage_offset = maybe_specialize_sym_int_with_hint(ex.storage_offset()) + + return self._create_symbolic_sizes_strides_storage_offset( + ex_size, + ex_stride, + ex_storage_offset, + [_is_dim_dynamic(ex, i) for i in range(ex.dim())], + source, + symbolic_context=symbolic_context, + ) + + @record_shapeenv_event() + def _create_symbolic_sizes_strides_storage_offset( + self, + ex_size: Sequence[int], + ex_stride: Sequence[int], + ex_storage_offset: int, + is_dim_dynamic: Sequence[bool], + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ): + dim = len(ex_size) + + # Reimplement the legacy behavior + if symbolic_context is None: + constraint_dims = [None] * dim + dynamic_dims = [] + for i in range(dim): + # NB: This is encapsulation breaking! Legacy behavior was + # bad. + if is_dim_dynamic[i]: + r = DimDynamic.DYNAMIC + elif self.assume_static_by_default: + r = DimDynamic.STATIC + else: + r = DimDynamic.DUCK + dynamic_dims.append(r) + dynamic_dims = [DimDynamic.DUCK] * dim + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_dims = symbolic_context.constraint_sizes + dynamic_dims = symbolic_context.dynamic_sizes + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context + # decision here where if all sizes are static, we are going to + # specialize all of the inner strides/offset too. We don't have to + # do this, and arguably we should ALWAYS allow for dynamic offset, + # this is cheap. 
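The `opt_f` recompilation scenario sketched in the comments above can be exercised directly. A hedged, runnable variant, assuming `torch.compile` with a working default backend; the exact guards and their ordering remain internal details.

```
import torch

@torch.compile(dynamic=True)
def opt_f(flag: bool, y: torch.Tensor) -> torch.Tensor:
    if flag:
        return y + torch.randn(4)    # this branch requires y to broadcast with a length-4 tensor
    return y

opt_f(False, torch.randn(7))         # first compilation: no size requirement on y
opt_f(True, torch.randn(4))          # recompiles; the new guards also constrain y.size(0)
```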
+ # TODO: This should be DYNAMIC, using DUCK for BC + dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK + + assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}" + assert len(constraint_dims) == dim + + from torch._dynamo.source import TensorPropertySource, TensorProperty + size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) + stride: List[Optional[sympy.Expr]] = [None] * len(size) + for i, val in enumerate(ex_stride): + if val in (0, 1): + stride[i] = sympy.Integer(val) + while any(x is None for x in stride): + candidates = { + ex_size[i] * ex_stride[i]: size[i] * stride[i] + for i in range(len(size)) + if stride[i] is not None and ex_stride[i] >= 0 + } + + # iterate over unbound strides in sorted order + def _nested_int_aware_sort(tup): + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0]) + else (0, *tup) + ) + val_list = sorted( + [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None], + key=_nested_int_aware_sort, + ) + for _, i in val_list: + if stride[i] is None and ex_stride[i] in candidates: + stride[i] = candidates[ex_stride[i]] + candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i] + + if any(x is None for x in stride): + # bind the smallest unbound stride to a new variable + val, i = min( + [ + (ex_stride[i], i) + for i in range(len(stride)) + if stride[i] is None + ], key=_nested_int_aware_sort + ) + stride[i] = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.STRIDE, i), + dynamic_dim=dynamic_strides_offset, + constraint_dim=None, + symbolic_context=symbolic_context, + ) + assert all(x is not None for x in stride) + + sym_sizes = [ + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + ) + for i, (sym, hint) in enumerate(zip(size, ex_size)) + ] + sym_stride = [] + for i, stride_expr in enumerate(stride): + # NB: Don't duck size the stride; instead use the expression + # we computed + assert stride_expr is not None + sym_stride.append(self.create_symintnode( + stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i))) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=dynamic_strides_offset, + constraint_dim=None, + symbolic_context=symbolic_context + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)) + return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset + + @record_shapeenv_event() + def create_symintnode( + self, + sym: "sympy.Expr", + *, + hint: Optional[int], + source: Optional[Source] = None, + ): + """Create a SymInt value from a symbolic expression + + If you know what the current hint value of the SymInt to be created + is, pass it into hint. Otherwise, pass None and we will make our best + guess + + """ + source_name = source.name() if source else None + + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. 
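The stride-binding loop above expresses unknown strides through already-known sizes by matching `size * stride` products. A self-contained sympy sketch for a contiguous (2, 3, 4) tensor; the fallback that allocates a fresh symbol is only indicated by a comment.

```
import sympy

ex_size, ex_stride = (2, 3, 4), (12, 4, 1)
size = list(sympy.symbols("s0 s1 s2", positive=True, integer=True))
stride = [None, None, sympy.Integer(1)]            # stride 1 is specialized eagerly

while any(st is None for st in stride):
    candidates = {
        ex_size[i] * ex_stride[i]: size[i] * stride[i]
        for i in range(3)
        if stride[i] is not None
    }
    progressed = False
    for st, i in sorted((ex_stride[i], i) for i in range(3) if stride[i] is None):
        if st in candidates:
            stride[i] = candidates[st]             # express this stride via known sizes
            progressed = True
            break
    if not progressed:
        break                                      # real code would allocate a fresh symbol here

assert stride == [size[1] * size[2], size[2], 1]
```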
+ fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + if isinstance(sym, sympy.Integer): + if hint is not None: + assert int(sym) == hint + out = int(sym) + else: + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim): + """Create a SymInt wrapping a new unspecified symbol""" + return self.create_symintnode( + self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + + def create_symboolnode(self, sym: "sympy.Expr"): + """Create a SymBool object from a sympy boolean expression""" + # This function is only being used in serialization, so we do not track it + # for validation. + return SymBool(SymNode(sym, self, bool, None)) + + def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges): + is_debug = config.extended_debug_create_symbol is not None and str(symbol) in config.extended_debug_create_symbol.split(',') + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + log.info( + "%s %s [%s, %s]%s (%s)%s", + prefix, symbol, vr.lower, vr.upper, maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug + ) + + @record_shapeenv_event() + def create_unbacked_symfloat(self): + """Create a symbolic float without a hint value + """ + symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}") + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges.unknown() + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, vr) + + return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node)) + + @record_shapeenv_event() + def create_unbacked_symint(self): + """Create a symbolic integer without a hint value + """ + symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr) + + return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node)) + + def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: + """Check if a sympy symbol matches the naming convention for unbacked symbols + """ + # NB: keep synced with free_unbacked_symbols + return str(symbol).startswith("u") + + @record_shapeenv_event() + def create_unbacked_symbool(self): + """Create a symbolic boolean without a hint value + """ + symbol: sympy.Symbol = sympy.Symbol(f"u{next(self.unbacked_symint_counter)}", integer=True) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges(0, 1) + + # Create a new FX placeholder and Z3 variable for 'symbol'. 
+ fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) + + self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, vr) + + return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node)) + + @record_shapeenv_event() + def create_unspecified_symbol( + self, + val: Union[int, SymInt], + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + ) -> "sympy.Expr": + """Create a symbol with an unspecified value + + Compared to standard symbols we do not assume the value is positive, + nor do we specialze on zero or one values. + """ + # 'positive' is None for unspecified symbols, since we can't + # assume that it will be neither positive nor negative. + + # We don't want to specialize zero one val for unspecified symbol + # so that we can always get a new symbol despite val. + return self.create_symbol( + val, + source, + dynamic_dim, + constraint_dim, + positive=None, + do_not_specialize_zero_one=True, + symbolic_context=None) + + @record_shapeenv_event() + def create_symbol( + self, + val: int, + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + positive: Optional[bool] = True, + do_not_specialize_zero_one: bool = False, + symbolic_context=None, + ) -> "sympy.Expr": + """Create a new symbol which is tracked by this ShapeEnv + """ + # see note [Tensor Fakification and Symbol Caching] + source_name = source.name() + if (isinstance(symbolic_context, StatefulSymbolicContext) + and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache): + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} + + if (isinstance(symbolic_context, StatefulSymbolicContext) + and source_name + and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])): + return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] + + if do_not_specialize_zero_one: + specialize_zero_one = False + else: + specialize_zero_one = self.specialize_zero_one + + assert isinstance(source, Source), f"{type(source)} {source}" + assert not (positive and val < 0), f"positive set for negative value: {val}" + # It's always sound to allocate a symbol as DYNAMIC. 
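A toy sympy sketch of the allocation policy that `create_symbol` implements below: 0 and 1 are specialized eagerly, and duck shaping reuses one symbol per concrete value. Names are illustrative; this is not the actual allocation code.

```
import sympy

val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}   # eager 0/1 specialization
var_to_val = {}

def create_symbol(val, duck=True):
    if val in (0, 1):
        return val_to_var[val]                    # 0 and 1 never get a symbol
    if duck and val in val_to_var:
        return val_to_var[val]                    # duck shaping: reuse the existing symbol
    sym = sympy.Symbol(f"s{len(var_to_val)}", positive=True, integer=True)
    var_to_val[sym] = sympy.Integer(val)
    if duck:
        val_to_var[val] = sym
    return sym

a = create_symbol(64)
b = create_symbol(64)                             # same concrete size -> same symbol
c = create_symbol(1)                              # specialized to the constant 1
assert a is b and c == 1
```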
If the user + # constrained the symbol, force the symbolic_context to DYNAMIC, because our + # constraint code will do weird stuff if, e.g., it's duck shaped + if constraint_dim is not None: + dynamic_dim = DimDynamic.DYNAMIC + + if dynamic_dim is DimDynamic.STATIC: + out = sympy.Integer(val) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out + return out + + elif dynamic_dim is DimDynamic.DUCK: + # duck_shape can be used to globally turn off duck shaping, even + # if it was requested + duck = self.duck_shape + elif dynamic_dim is DimDynamic.DYNAMIC: + duck = False + else: + raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") + + if val in (0, 1) and specialize_zero_one: + r = self.val_to_var[val] + elif not duck or val not in self.val_to_var: + # If we're not duck shaping, we always create a new symbol + # Even if we're duck shaping, if we haven't seen this particular + # value before, we also create a new symbol + sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True) + # We always associate vars to vals + if isinstance(val, int): + self.var_to_val[sympy_expr] = sympy.Integer(val) + else: + # Only used for jagged layout nested tensors + self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff()) + + # Do the appending later, because we always want to populate this + self.var_to_sources[sympy_expr] = [] + # Create a Z3 variable for the new symbol. + self._add_z3var(sympy_expr, int) + + if duck: + # Make sure to reuse this symbol for subsequent duck shaping + self.val_to_var[val] = sympy_expr + + if isinstance(val, int): + if positive: + # Add assertions for the newly created symbols + self._add_assertion(sympy_expr > 1) + + # Apply default range, which assumes not zero-one + self.var_to_range[sympy_expr] = self._default_value_range() + else: + self.var_to_range[sympy_expr] = self._default_unspecified_value_range() + + # Small performance optimization: if we have a min-max constraint, + # we can proactively narrow to that range + if isinstance(constraint_dim, StrictMinMaxConstraint): + assert not duck + self.var_to_range[sympy_expr] &= constraint_dim.vr + + vr = self.var_to_range[sympy_expr] + + if val not in vr: + raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") + + range_str = f"[{vr.lower}, {vr.upper}]" + else: + # Skip var_range logic for SingletonInt + # Only used for jagged layout nested tensors + range_str = "" + + r = sympy_expr + + is_debug = ( + config.extended_debug_create_symbol is not None and + str(sympy_expr) in config.extended_debug_create_symbol.split(',') + ) + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "create_symbol %s = %s for %s %s%s (%s)%s", + sympy_expr, val, source.name(), range_str, + maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug + ) + + self.counter["create_symbol"] += 1 + else: + # This implements duck-shaping: input sizes that match are assigned + # the same symint + r = self.val_to_var[val] + self.log.debug("create_symbol %s duck sized %s", r, source.name()) + + if isinstance(r, sympy.Symbol): + r_sources = self.var_to_sources[r] + r_sources.append(source) + if not source.is_ephemeral() and r_sources[0].is_ephemeral(): + # prefer non-ephemeral source first since it may be guarded on later + r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] + + # 
This ensures we get zeros in symbol_guard_counts, which makes + # some queries simpler (since we will accumulate mass on 0 this + # way) + self.symbol_guard_counter[r] = 0 + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r + return r + + def _debug_name(self, source): + src_name = source.name() + return self.source_name_to_debug_name.get(src_name, src_name) + + def _render_range_for_constraint_violation(self, source, c): + if isinstance(c, StrictMinMaxConstraint): + lower, upper = c.vr.lower, c.vr.upper + default = self._default_value_range() + if lower <= default.lower: + lower = None + if upper >= default.upper: + upper = None + c_render = f"{self._debug_name(source)} = {source.name()} in the specified range" + if lower is not None and upper is not None: + c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" + elif lower is None and upper is not None: + c_render += f" {self._debug_name(source)} <= {upper}" + elif lower is not None and upper is None: + c_render += f" {lower} <= {self._debug_name(source)}" + return c_render + return c.render(source) + + def produce_guards( + self, + placeholders, + sources, + source_ref=lambda n: n.name(), + *, + input_contexts: Optional[DimList[SymbolicContext]] = None, + # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). + # (See docs on EqualityConstraint for details of the encoding.) + equalities_inputs: Optional[EqualityConstraint] = None, + _simplified=False, + # Indicates if we should produce guards for known static values. + ignore_static=True, + ) -> List[str]: + """ + Generates a list of guards strings which, when evaluated in a context that + defines tensors for all the sources, returns True or False depending + on if the guards in the list evaluated to True or not. Primarily used by Dynamo, + but this is also helpful for manual testing of guards (see + evaluate_guards_for_args) + + For convenience in testing, a source is allowed to be a str, + in which case we will assume it is a LocalSource + + simplified lets you omit duck sizing, equality and 0/1 guards. + This is useful for testing when you don't care about the boilerplate + guards, and it may be helpful for user output too (be careful though; + some equality guards are nontrivial! It would be nice to get simplified + output to print them too). It's private because it's not + intended for normal use + """ + self.log.info("produce_guards") + + # Check if we get to the same ShapeEnv state by replaying the recorded events. + # This will create a new ShapeEnv instance, and call all recorded function + # calls on this new instance. Finally, it will check whether this new instance + # has equal state. + # + # It's important that we do it in the begining of this function, since it modifies + # self.dim_constraints through its execution. Changes that happen in this method + # aren't interesting, since this is the function call we wish to reproduce at the + # end. If we wish to simply reproduce ShapeEnv instances even after this call, + # this method should also be recorded. + if self.check_recorded_events: + shape_env = replay_shape_env_events(self.events) + self.check_equal(shape_env) + + assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})" + Tensorlike = (torch.Tensor, FakeTensorMeta) + + def _create_no_constraints_context(t): + return StatelessSymbolicContext( + # Ignored; only the constraints part is relevant below. 
+ dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), + constraint_sizes=[None] * t.dim() + ) + + # Expand optional inputs, or verify invariants are upheld + if input_contexts is None: + input_contexts = [ + _create_no_constraints_context(t) if isinstance(t, Tensorlike) + else None for t in placeholders + ] + else: + assert len(input_contexts) == len(placeholders) + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): + if isinstance(t, Tensorlike): + if context is None: + input_contexts[i] = _create_no_constraints_context(t) + else: + assert isinstance(t, (SymInt, int)) + assert not isinstance(context, list) + + # It took a lot of sweat to figure out the algorithm here. Let's + # explain how it works. + # + # The ShapeEnv lifecycle looks something like this: + # + # - For each input, you either generate a fresh Sympy symbol (s0) to + # represent its value (a binding site), or you reuse some + # preexisting symbol or expression, skipping the symbol allocation + # (e.g., duck sizing to a preexisting symbol, or expressing a + # stride as a multiplication of a separate stride and size.) + # Naively, you might expect to bind a fresh Sympy symbol for + # every input, but this is fairly wasteful as most of these + # symbols immediately simplify away, and if you don't eagerly + # specialize, e.g., 0/1 symbols, you end up with very complicated + # expressions that are not optimizable in practice. + # + # - You perform some compute on these symbols, occasionally + # introducing guards on boolean expressions on these symbols. + # In particular, whenever we guard on equality (_maybe_guard_rel), + # we can simplify shapes; e.g., when s0 == s1 * 2, we can now + # replace all occurrences of s0 with s1 * 2. Sometimes, a + # boolean expression evaluation doesn't introduce a guard, as + # the guard is already entailed by the simplifications we have + # applied. + # + # - In the end, you have a bunch of replacements (saying how to + # simplify shapes) and a bunch of guards (all the equality guards + # are trivial, because they're covered by the replacements). + # + # From the ShapeEnv, we must generate a Python expression that, when + # evaluated on a set of inputs, tells us whether or not these boolean + # expressions would have evaluated in the same way. However, + # we cannot easily compute this, as we elide recording boolean + # expressions when we think they are vacuously true. Thus, we seek + # an approximation: we must generate an expression, if true, would have + # produced an "equivalent" ShapeEnv, which would answer guard + # expressions in the same way. + # + # Our notion of equivalence is a bit subtle. For example, consider + # the ShapeEnv created from an input of size (5, 4) versus (4, 4) + # (no other guards.) Duck sizing would generate (s0, s1) in the first + # case but (s0, s0) in the second. We do NOT assume that size + # variables are disjoint; so in fact a graph that assumes the input + # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not + # vice versa. However, consider an analogous case (1,) versus (2,). + # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT + # subsume the (1,) graph because we assume that any size variables + # is NOT 0/1 (and make simplifications according to this; e.g., if + # we queried s0 == 0, we would immediately return False without + # returning a guard.) + # + # So, it is perhaps easier to flip things on their head: the guard + # expressions we generate here say what simplifications are valid, + # and what are not. 
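The equality-driven simplification described above (replacing `s0` everywhere once `s0 == s1 * 2` is guarded) can be seen in isolation with sympy. This is only a sketch of the idea, not the actual `_maybe_guard_rel` machinery.

```
import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
replacements = {}

def guard_eq(lhs, rhs):
    replacements[lhs] = rhs                            # from now on, rewrite lhs as rhs

def simplify_shape(expr):
    return expr.xreplace(replacements)

guard_eq(s0, 2 * s1)
assert simplify_shape(s0 + 3) == 2 * s1 + 3            # later expressions use the replacement
assert simplify_shape(sympy.Eq(s0, 2 * s1)) == sympy.true   # the equality guard is now vacuous
```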
Below, we explain each of the guard expressions + # we generate + + # TODO: Make this more efficient by binding all the size/stride/offsets + # to locals before performing tests on them. + + from torch._dynamo.source import TensorPropertySource, TensorProperty, NegateSource + + # Actual codegen must be delayed as we don't necessarily know what + # the symbol mapping is + input_guards = [] + + symbol_to_source = collections.defaultdict(list) + symbol_to_constraints = collections.defaultdict(set) + constraint_violations : List[Tuple[bool, Callable[[], str]]] = [] + + def record_constraint_violation(warn_only, debug_name, msg, hint=None): + constraint_violations.append( + (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) + ) + + def is_dim(src): + return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE + + if equalities_inputs: + source_index = {} + for i, src in enumerate(sources): + source_index[src.name()] = i + + def get_expression(tensor_dim_src): + fake = placeholders[source_index[tensor_dim_src.base.name()]] + symint = fake.shape[tensor_dim_src.idx] + if isinstance(symint, torch.SymInt): + return symint.node.expr + else: + assert type(symint) is int, f"Expected int, got {type(symint)}" + return symint + + for src1, src2 in equalities_inputs.source_pairs: + expr1, expr2 = get_expression(src1), get_expression(src2) + # Check whether given input shape values satisfy a specified equation s = s'. + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) + if not concrete_val: + raise ConstraintViolationError( + f"{src1.name()} = {expr1.subs(self.var_to_val)}" + " is not equal to " + f"{src2.name()} = {expr2.subs(self.var_to_val)}" + ) + + for src, root, fn in equalities_inputs.derived_equalities: + expr1 = get_expression(src) + # recall that root is either a phantom symbol or an input source + expr2, debug_name = ( + (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol) + else (get_expression(root), self._debug_name(root)) + ) + expr2_ = fn(expr2) + # Check whether given input shape values satisfy a specified equation s = fn(s'). + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) + if not concrete_val: + raise ConstraintViolationError( + f"Expected input {src.name()} to be equal to " + f"{fn(sympy.Symbol(debug_name))}, " + f"where {debug_name} = {expr2.subs(self.var_to_val)}, " + f"but got {expr1.subs(self.var_to_val)}" + ) + + for phantom_symbol in equalities_inputs.phantom_symbols: + # we created additional phantom symbols that are not input shape dimensions + symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol]) + + # How do we know what the value of s0 is? Fresh variables can only be + # bound by inputs, so there MUST be some other input which binds the + # variable. If there is no such input, this is an error in our + # system. We record where all symbols come from, to help you diagnose + # why those symbols didn't occur. + # + # In fact, generally speaking it is only possible for the "outermost" + # user of a ShapeEnv to evaluate the guards, because some inputs may + # not be available to inner levels. For example, Dynamo can guard on + # tensors that never actually become graph arguments (they are + # pruned). 
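The input equalities consumed above (`source_pairs`) are what users express by reusing a single `Dim` object across inputs. A hedged usage sketch, assuming `torch.export` from a recent PyTorch; the module and names are illustrative.

```
import torch
from torch.export import Dim, export

class Cat(torch.nn.Module):
    def forward(self, a, b):
        return torch.cat([a, b], dim=1)           # only works if a.size(0) == b.size(0)

batch = Dim("batch", min=2, max=64)
ep = export(
    Cat(),
    (torch.randn(4, 3), torch.randn(4, 5)),
    dynamic_shapes={"a": {0: batch}, "b": {0: batch}},   # shared Dim => dims must be equal
)
print(ep.range_constraints)
```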
In this case, only Dynamo knows about these arguments. + def track_symint(source, val, constraint=None): + log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) + assert not isinstance(val, SymInt) or is_symbolic(val) + + if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: + val = val.node.maybe_as_int() + + if isinstance(val, SymInt): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + if constraint is not None: + symbol_to_constraints[s].add(constraint) + elif isinstance(-s, sympy.Symbol): + symbol_to_source[-s].append(NegateSource(source)) + else: + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + # try inferring the ranges of the expr s + sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols} + if all(vr is not None for vr in sym_vrs.values()): + expr_vr = bound_sympy(s, sym_vrs) + if expr_vr != constraint.vr: + # the expr and constrain ranges don't match + constraint_violated = True + else: + # some of the free symbols in s don't have ranges + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + if s.is_number: + i = int(s) + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if i not in (0, 1): + constraint_violated = True + if constraint_violated: + def hint(s): + sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s) + return f"{sexpr}." + + var_with_range = self._render_range_for_constraint_violation(source, constraint) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be equal to " + ) + record_constraint_violation( + constraint.warn_only, + self._debug_name(source), + msg, + hint=functools.partial(hint, s), + ) + + input_guards.append((source, s)) + else: + s = sympy.Integer(val) + input_guards.append((source, s)) + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if val not in (0, 1): + constraint_violated = True + if constraint_violated: + var_with_range = self._render_range_for_constraint_violation(source, constraint) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be a constant ({val})." + ) + record_constraint_violation(constraint.warn_only, self._debug_name(source), msg) + + for t, source, context in zip(placeholders, sources, input_contexts): + if isinstance(source, str): + from torch._dynamo.source import LocalSource + source = LocalSource(source) + assert isinstance(source, Source) + if t is None: + continue + if isinstance(t, (SymInt, int)): + track_symint(source, t) + continue + assert isinstance(t, Tensorlike) + if is_traceable_wrapper_subclass(t): + from torch._dynamo.source import AttrSource + + assert isinstance(context, SubclassSymbolicContext) + + # For subclasses, we need to track symints on BOTH the outer + # and inner tensors. 
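A hedged, runnable illustration (assuming a recent PyTorch with `torch._dynamo`) of the `RelaxedUnspecConstraint` handling in `track_symint` above: a dimension marked dynamic that the traced code forces to a single value is reported as a constraint violation.

```
import torch
import torch._dynamo

def f(x):
    return x + torch.ones(8)                      # broadcasting pins x.size(0) to 8 (or 1)

x = torch.randn(8)
torch._dynamo.mark_dynamic(x, 0)                  # ask for dim 0 to stay dynamic
try:
    torch.compile(f)(x)
except Exception as e:                            # expected: a ConstraintViolationError-style report
    print(type(e).__name__, e)
```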
+ sources_tensors_constraints = [ + (source, t, context.constraint_sizes) + ] + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + inner_context = context.inner_contexts[attr] + sources_tensors_constraints.append(( + AttrSource(source, attr), + inner_t, + inner_context.constraint_sizes + )) + else: + sources_tensors_constraints = [(source, t, context.constraint_sizes)] + + for src, curr_t, constraint in sources_tensors_constraints: + if is_sparse_any(curr_t): + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource(src, TensorProperty.SIZE, i) + track_symint(property_source, ss, constraint[i]) + else: + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource(src, TensorProperty.SIZE, i) + track_symint(property_source, ss, constraint[i]) + for i, ss in enumerate(curr_t.stride()): + track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss) + track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset()) + + # 1. Every input must equal the final simplified symbolic expression + # stored on the placeholder. Given a placeholder (s0*2, s1), + # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. + # This does a lot of work: it covers duck sizing and equality guards. + exprs = [] + self.dim_constraints = DimConstraints( + symbol_to_source, + self.var_to_val, + set(symbol_to_constraints.keys()), + self.source_name_to_debug_name, + ) + + if not _simplified: + for source, expr in input_guards: + if self._translation_validation_enabled: + # Ignore sources that were not turned into SymInts. + srcname = source.name() + if srcname in self.source_to_symbol: + self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr)) + + # Small optimization + if ( + isinstance(expr, sympy.Symbol) and + symbol_to_source.get(expr) and + source == symbol_to_source[expr][0] + ): + continue + + # This logic excludes static values found on tensors from guarding, because + # dynamo's check_tensor_fn does that (see guards.cpp). + # However, for non tensor sources, we still need to guard here. + if ignore_static and isinstance(source, TensorPropertySource): + if expr.is_number: + self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}") + continue + + if is_dim(source): + self.dim_constraints.add_equality(source, expr) + + sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + exprs.append(f"{source_ref(source)} == {sexpr}") + if ( + isinstance(source, TensorPropertySource) + and source.prop is TensorProperty.SIZE + and equalities_inputs + and len(expr.free_symbols) == 1 + ): + symbol = next(iter(expr.free_symbols)) + if ( + isinstance(expr, sympy.Symbol) and + expr in symbol_to_constraints and + not equalities_inputs.is_equal(source, symbol_to_source[expr][0]) + ): + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " + "must always be equal." 
+ ) + record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) + + if ( + not isinstance(expr, sympy.Symbol) and + symbol in symbol_to_constraints and + not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.subs(symbol, x)) + ): + src = symbol_to_source[symbol][0] + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} must always be related to " + f"the values of {self._debug_name(src)} = {src.name()} by " + f"{self._debug_name(source)} = {expr.subs(symbol, sympy.sympify(self._debug_name(src)))}." + ) + record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) + + # NB: Not necessary to report constraint violations here: + # constraints are guaranteed to be on symbols (we've already + # caught constants and non-atomic expressions), so we only + # have relational constraints, but we don't support those + # at the moment + + # 2. Every guard must evaluate to True (but remember many guards + # like s0 == s1*2 because trivial due to simplification) + issued = set() + + def issue_guard(guard: ShapeGuard) -> None: + expr = self.simplify(guard.expr) + + # Avoid re-issueing the same guard. + if expr in issued: + return + + issued.add(expr) + + try: + is_trivial = False + if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]): + is_trivial = self.dim_constraints.add(expr) + guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + exprs.append(guard_expr) + self._add_target_expr(expr) + # A non-relational constraint on a single sizevar can violate + # a constraint + if not is_trivial and len(expr.free_symbols) == 1: + symbol = next(iter(expr.free_symbols)) + source = symbol_to_source[symbol][0] + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + var_with_range = self._render_range_for_constraint_violation(source, c) + msg = ( + f"Not all values of {var_with_range} " + f"satisfy the generated guard {guard_expr}." + ) + record_constraint_violation(c.warn_only, self._debug_name(source), msg) + elif isinstance(c, RelaxedUnspecConstraint): + # This is fine, we allow guards here as long as it + # didn't constrain it to one value (we don't + # actually know this; this depends on our + # ValueRanges reasoning capability) + pass + else: + raise AssertionError(f"unrecognized constraint {c}") + except Exception: + self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format())) + raise + + # First, issue all the non-trivial guards. + for guard in self.guards: + if self._maybe_evaluate_static(guard.expr) is not None: + continue + issue_guard(guard) + + # 3. Every symbol must be within its value range (this handles 0/1 + # specialization too). 
+ for symbol, sources in symbol_to_source.items(): + r = self.var_to_range.get(symbol) + if r is None: + if symbol not in self.var_to_range: + continue + r = self.var_to_range[symbol] + + assert sources + assert symbol.is_integer + bounds = [] + if r.lower != -sympy.oo: + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Ge(symbol, r.lower)) + # Only print lower bound in simplified mode if it is not the + # default + if not _simplified or r.lower != self._default_value_range().lower: + bounds.append(str(r.lower)) + bounds.append(source_ref(sources[0])) + # NB: This looks like an off-by-one error but it's not: the + # upper bound may be sys.maxsize - 1 because we intentionally + # exclude sys.maxsize from our bounds to deal with direct + # == INT_MAX guards, but it's still dumb to actually test it. + # Note that you can be off by a pretty large constant and it + # won't matter because sizes in practice will be no where near + # the 64-bit limit. + if r.upper != sympy.oo and r.upper < sys.maxsize - 1: + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Le(symbol, r.upper)) + # nontrivial upper bound is always interesting + bounds.append(str(r.upper)) + if len(bounds) > 1: + exprs.append(" <= ".join(bounds)) + + # Check constraints + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + # NB: By default, we have a restrictive range + # 2 <= s0 <= sys.maxsize - 1. But export users generally + # expect to be able to specify nice ranges like [0, oo] + if not (c.vr & self._default_value_range()).issubset(r): + source = sources[0] + + expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper)) + guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + var_with_range = self._render_range_for_constraint_violation(source, c) + msg = ( + f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" + ) + record_constraint_violation( + c.warn_only, + self._debug_name(source), + msg, + ) + + if constraint_violations: + warn_msgs = [] + error_msgs = [] + debug_names = set() + for warn_only, debug_name, msg in constraint_violations: + if warn_only: + msg = f" {len(warn_msgs) + 1}. {msg()}" + warn_msgs.append(msg) + else: + msg = f" - {msg()}" + error_msgs.append(msg) + debug_names.add(debug_name) + if len(error_msgs) > 0: + debug_names = ', '.join(debug_names) + err = '\n'.join(error_msgs) + raise ConstraintViolationError( + f"Constraints violated ({debug_names})! " + "For more information, run with TORCH_LOGS=\"+dynamic\".\n" + f"{err}" + ) + elif len(warn_msgs) > 0: + log.debug("%s Warning only constraints violated", len(warn_msgs)) + + signpost_event( + "dynamic", + "produce_guards", + { + **self.co_fields, + **self.counter, + "num_guards": len(exprs), + "free_symbols": sum(1 for v in symbol_to_source.values() if v), + # The keys are meaningless from an aggregate perspective, so + # don't include them. Biggest first. 
+ "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True), + }, + ) + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import PopulateValidator + + # Add all deferred runtime assertions; these are not technically + # handled by produce_guards but we need to put them in the target + # set + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + self._add_target_expr(ra.expr) + + # Add value range bound guards for all symbols with no trivial bounds. + # Reason: '_maybe_evaluate_static' may eliminate guards based on the + # refined value ranges. + for sym, vr in self.var_to_range.items(): + if vr.lower != -sympy.oo: + self._add_target_expr(sympy.Le(vr.lower, sym)) + if vr.upper != sympy.oo: + self._add_target_expr(sympy.Le(sym, vr.upper)) + + # Before validating, populate the input of the validator with the + # built FX graph. + with fx_traceback.preserve_node_meta(): + PopulateValidator(self.graph, self.validator).run() + + self._check_translation_validate() + return exprs + + def produce_guards_expression(self, placeholders, ignore_static=True): + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + arg_names = [f"t{i}" for i in range(len(placeholders))] + guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static) + if guards: + return " and ".join(guards) + return None + + def evaluate_guards_expression(self, code, args): + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. + """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + + def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True): + """Generate guards for a graph's placeholder values and evaluate the guards with args + """ + code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) + if code: + return self.evaluate_guards_expression(code, args) + return True + + def bind_symbols(self, placeholders, args): + """ + Given a paired list of placeholders (fake tensors with + symbolic sizes) and concrete arguments (regular tensors + with real sizes), returns a dictionary mapping each + symbol to its real value. So for example, if you + have a placeholder with size (s0, s1), binding + (2, 4) to it will give you {s0: 2, s1: 4}. This is + not guaranteed to bind ALL symbols in the ShapeEnv; + we can't bind a symbol if it doesn't occur in any placeholder, + and symbols that already have replacements won't get bindings. + + This is a little duplicative with evaluate_guards but + it's different enough that it seemed cleanest to make + another copy. 
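`produce_guards_expression` and `evaluate_guards_expression` above pair a guard string with a dict `L` of concrete inputs named `t0`, `t1`, and so on. A simplified sketch of that contract with a hand-written guard string; a real string would come from `produce_guards`.

```
import torch

# Hand-written stand-in for a guard string; real ones come from produce_guards().
code = "L['t0'].size()[0] == L['t1'].size()[0] and L['t0'].size()[0] >= 2"

args = (torch.randn(4, 3), torch.randn(4, 7))
arg_names = [f"t{i}" for i in range(len(args))]
print(eval(code, {}, {"L": dict(zip(arg_names, args))}))   # True: both first dims are 4
```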
This assumes the guards are already checked, + though if it's cheap we'll check for shenanigans + """ + bindings: Dict[sympy.Symbol, int] = {} + + def bind_symint(arg, val): + if isinstance(val, SymInt): + s = val.node.expr + + if isinstance(s, sympy.Symbol): + if s in bindings: + assert bindings[s] == arg, f"{bindings[s]} != {arg}" + else: + bindings[s] = arg + elif isinstance(-s, sympy.Symbol): + if -s in bindings: + assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" + else: + bindings[-s] = -arg + + for t, arg in zip(placeholders, args): + if t is None: + continue + if isinstance(t, SymInt): + bind_symint(arg, t) + continue + assert isinstance(t, torch.Tensor) + for i, s in enumerate(t.size()): + bind_symint(arg.size(i), s) + for i, s in enumerate(t.stride()): + bind_symint(arg.stride(i), s) + bind_symint(arg.storage_offset(), t.storage_offset()) + + return bindings + + def get_nontrivial_guards(self): + """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" + return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None] + + def format_guards(self, verbose=False): + """Format this shape env's guard expressions with optional traceback info if verbose""" + def format_tb(tb): + if not verbose: + return "" + return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}" + + return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards) + + def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges: + """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} + if size_oblivious: + # Clamp values of size-like variables + for x in self.size_like & var_to_range.keys(): + if var_to_range[x] is not None: + var_to_range[x] &= ValueRanges(2, sympy.oo) + return bound_sympy(expr, var_to_range) + + @_lru_cache + def _maybe_evaluate_static( + self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, + expect_rational=True, size_oblivious: bool = False + ) -> "Optional[sympy.Expr]": + """ + Tries to evaluate expr without introducing guards + + If unbacked_only == True, then we only do substitutions on + unbacked SymInts (leaving regular hinted integers alone). This could + result in an expression that still contains backed SymInts, which you + could then potentially guard on. + + Use compute_hint == True if you are trying to compute a non-binding + hint for the particular hint values of backed SymInts, e.g., if + s0 happens to be 3 this run, compute_hint will substitute s0 with 3.
+ """ + expr = self.simplify(expr) + + if compute_hint: + expr = expr.xreplace(self.var_to_val) + + expr = canonicalize_bool_expr(expr) + + symbols = list(expr.free_symbols) + + # Apply known runtime asserts + for s in symbols: + # Unbacked symints only + if s in self.var_to_val: + continue + + subst = {} + + def add_expr(expr): + # Expr and negation + subst[canonicalize_bool_expr(expr)] = sympy.true + subst[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false + if isinstance(expr, sympy.Rel): + # multiplying by -1 changes the direction of the inequality + dual = type(expr)(-expr.rhs, -expr.lhs) + subst[canonicalize_bool_expr(dual)] = sympy.true + subst[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false + + for e in itertools.chain(self.guards, self.deferred_runtime_asserts.get(s, ())): + e = e.expr + if compute_hint: + e = canonicalize_bool_expr(e.xreplace(self.var_to_val)) + add_expr(e) + # Other relational expressions this expression implies + if isinstance(e, sympy.Eq): + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ge(e.lhs, e.rhs)) + elif isinstance(e, sympy.Lt): + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ne(e.lhs, e.rhs)) + + # NB: this helps us deal with And/Or connectives + expr = expr.subs(subst) + + # Simplify making use of value range lower bound + new_shape_env = {} + new_range_env = {} + for idx, k in enumerate(symbols): + if isinstance(self.var_to_val.get(k, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + vr = self.var_to_range[k] + if size_oblivious and k in self.size_like: + lower = max(2, vr.lower) + else: + lower = vr.lower + # Don't do anything if we don't have a nontrivial lower bound + # Also don't do anything if we asked only to simplify unbacked + # SymInt + if ( + lower < (-sys.maxsize - 1) // 2 or + (unbacked_only and k in self.var_to_val) + ): + new_range_env[k] = vr + continue + # Positive means >= 1 + # Positive - 1 means >= 0 + # Positive + lower - 1 means >= lower + # The new symbol 's' is "too low", so when we substitute it in + # we have to increase it by offset (and conversely, the new + # variables have to have their value range bounds adjusted as + # well) + s = sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + offset = lower - 1 + new_shape_env[k] = s + offset + new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + + def replace(expr, repl): + return expr.xreplace(repl) + + try: + new_expr = replace(expr, new_shape_env) + except RecursionError: + log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) + self.counter["sympy_recursion_error"] += 1 + return None + + floor_div_replace = {} + for atom in new_expr.atoms(FloorDiv): + floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) + new_expr = safe_expand(new_expr.xreplace(floor_div_replace)) + # TODO: when unbacked_only, can sometimes early return even when there + # are still free symbols + if new_expr.is_number: + return new_expr + + # Check if the range can solve it statically + out = bound_sympy(new_expr, new_range_env) + if expect_rational: + _assert_bound_is_rational(new_expr, out) + if out.is_singleton(): + return out.lower + + return new_expr if unbacked_only else None + + @_lru_cache + def replace(self, expr: "sympy.Expr") -> "sympy.Expr": + """Apply symbol replacements to any symbols in the given expression + """ + replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} + return 
safe_expand(expr.xreplace(replacements)) + + @_lru_cache + def _update_divisible(self): + new_divisible = set() + for k in self.divisible: + res = self.replace(k) + if not res.is_number: + new_divisible.add(k) + + self.divisible = new_divisible + self._update_version_counter() + + @_lru_cache + def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": + """Use known constraints and replacements to simplify the given expr + """ + expr = self.replace(expr) + # TODO it would seem that this pass is not necessary given the + # below replacement of // with /, but for nested FloorDivs + # the non-recursive replacement doesn't work, and + # recursive makes it hard to look up divisibility, + # because existing divisibility info has FloorDiv in it, not / + # for now just do a separate pass to catch common nested case + if expr.has(FloorDiv): + self._update_divisible() + div_replacements = {} + for atom in expr.atoms(FloorDiv): + base, divisor = atom.args + if isinstance(divisor, FloorDiv): + base1, divisor1 = divisor.args + if self.replace(Mod(base, divisor)) in self.divisible and \ + base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: + div_replacements[atom] = divisor1 + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) + if expr.has(FloorDiv): + div_replacements = {} + pows = expr.atoms(sympy.Pow) + rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) + for fd in expr.atoms(FloorDiv): + base, divisor = fd.args + if self.replace(Mod(base, divisor)) in self.divisible: + div_replacements[fd] = base / divisor + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr + return expr + + @lru_cache(256) + def size_hint(self, expr: "sympy.Expr", *, allow_none=False): + """ + Gets a size hint for a given expression from the underlying shapes we had. 
+ Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + result_expr = safe_expand(expr).xreplace(self.var_to_val) + if not result_expr.is_number: + + from torch.utils._sympy.singleton_int import SingletonInt + + if isinstance(result_expr, SingletonInt): + return None + r = self._maybe_evaluate_static(result_expr, compute_hint=True) + if r is not None: + return r + if allow_none: + return None + raise self._make_data_dependent_error(result_expr, expr) + return result_expr + + # NB: keep in sync with size_hint + @lru_cache(256) + def has_hint(self, expr: "sympy.Expr"): + result_expr = safe_expand(expr).xreplace(self.var_to_val) + return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None + + def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None): + # TODO: in a Dynamo context, having user code, and having the + # name of the local, will be much better + size_like_symbols = [] + for s in expr.free_symbols: + stacktrace = ''.join(self.var_to_stack[s].format()) + self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace) + if s in self.size_like: + size_like_symbols.append(s) + size_oblivious_result_msg = "" + if size_oblivious_result is not None: + size_oblivious_result_msg = ( + f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" + "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" + ) + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True) + return GuardOnDataDependentSymNode( + f"Could not guard on data-dependent expression {expr} (unhinted: {unhinted_expr}). " + f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" + f"{size_oblivious_result_msg}" + "Potential framework code culprit (scroll up for full backtrace):\n" + f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n" + "For more information, run with TORCH_LOGS=\"dynamic\"\n" + "For extended logs when we create symbols, also add " + f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" + "For more debugging help, see " + "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + + maybe_extra_debug + # TODO: Help text about how to use our runtime tests to fix this + # problem + ) + + def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. + """ + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # If you have x in [2, maxint], then 2*x in [4, 2*maxint]. + # But we don't really care that the max bound says we can + # go beyond the maximum integer size, because we aren't + # using bigints anyway. Arguably, ValueRanges should know + # to do this truncation automatically (to avoid doing + # bigint compute in range analysis), but right now it doesn't, + # so we need to get rid of some unnecessary precision.
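+ # Concretely (hypothetical numbers): with x in [2, sys.maxsize - 1] and
+ # tgt = 2*x, bound_sympy reports [4, 2*(sys.maxsize - 1)], which spills past
+ # int64; intersecting both sides with int_range below keeps the issubset
+ # comparison meaningful.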
+ int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) + + def issubset(x, y): + return (x & int_range).issubset(y & int_range) + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self.var_to_range[a] = src_bound & tgt_bound + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + b_bound = self.bound_sympy(r[1]) + self.var_to_range[b] = b_bound & self.var_to_range[b] + tgt_bound = self.bound_sympy(tgt) + assert issubset(tgt_bound, src_bound) + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. + # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. 
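+ # Worked example (hypothetical symbols): for a = s0 with range [3, 5] and
+ # tgt = s1 + 1 with s1 in [0, oo], tgt_bound = [1, oo] is not a subset of
+ # [3, 5], but tgt is univariate, so try_solve inverts Eq(s0, s1 + 1) to
+ # s1 = s0 - 1, refining s1 to [2, 4]; recomputing tgt_bound then gives [3, 5]
+ # and the subset assertion below holds.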
+ + if not issubset(tgt_bound, src_bound): + self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + # This is morally equivalent to self.bound_sympy(a, size_oblivious=True) + # but handles substitutions like u0 == 0 + src_bound_so = self.var_to_range[a] + if src_bound_so.upper >= 2: + src_bound_so &= ValueRanges(2, sympy.oo) + if not issubset(tgt_bound_so, src_bound_so): + self.log.debug("skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) + return + + if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected + + # NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g., + # when adding a to self.replacements, and again when simplifying an expression containing a. + # Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is, + # it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant. + if a not in self.replacements or tgt != self.replacements[a]: + self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) + self.log.debug("SPECIALIZATION", stack_info=True) + log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. + self._add_target_expr(sympy.Eq(a, tgt)) + + def _add_divisible(self, expr: "sympy.Expr"): + self.divisible.add(expr) + self._update_version_counter() + + @_lru_cache + @record_shapeenv_event() + def _find(self, a: "sympy.Symbol") -> "sympy.Expr": + """ + Implements a DSU-like algorithm to find the variable that represents a + Also handles transitive non-identity replacements. + + a: b + c + c: d + """ + if a not in self.replacements: + return a + res = self.replacements[a] + cur_replace = {s: self._find(s) for s in res.free_symbols} + self._set_replacement(a, self.replacements[a].xreplace(cur_replace), "find") + return self.replacements[a] + + @lru_cache(256) + def _maybe_guard_rel(self, expr: "sympy.Rel") -> None: + """ + The relational guard is guarded to be true. Use this information to + simplify shapes (i.e. a == b or a % 5 == 0) + """ + assert isinstance(expr, sympy.Rel) + + # A good example of what goes wrong if you don't do this is + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + if isinstance(expr, sympy.Ne): + return + + free = list(expr.free_symbols) + + assert len(free) > 0, f"The expression should not be static by this point: {expr}" + # In case of really gnarly expression, we don't blow up + if len(free) > 5: + return + + # Prioritize unbacked symints for solving by ordering them last. + # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). + # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) + # Prefer to simplify out symbols with ephemeral sources. 
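+ # Worked example (hypothetical symbols): with hinted s2 = s3 = 8 and an
+ # unbacked u0 (no size hint, so its key falls back to sys.maxsize), the
+ # reverse sort below yields [u0, s3, s2], and the solver further down
+ # eliminates free[0], i.e. u0 first, then lexicographically higher names.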
+ def _smart_symbol_sort(x): + has_only_ephemeral_sources = ( + x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) + ) + size = self.size_hint(x, allow_none=True) or sys.maxsize + name = x.name + # 1 puts ephemeral sourced symbols first when sorting in reverse + return (1 if has_only_ephemeral_sources else 0, size, name) + + free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] + lhs = expr.lhs + rhs = expr.rhs + + self._refine_ranges(expr) + + # The rest of this stuff is for equality only + if not isinstance(expr, sympy.Eq): + return + + if not expr.has(Mod): + try: + floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) + if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms): + raise NotImplementedError + # short-circuit when no solving is needed + + if isinstance(lhs, sympy.Symbol) and free_unbacked_symbols(lhs): + self._set_replacement(lhs, self._find(rhs), "trivial_lhs") + elif isinstance(rhs, sympy.Symbol) and free_unbacked_symbols(rhs): + self._set_replacement(rhs, self._find(lhs), "trivial_rhs") + else: + r = try_solve(expr, free[0], floordiv_inequality=False) + if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): + new_var = self._find(r[1]) + ok = False + if self.is_unbacked_symint(free[0]): + # If you have i0 + i1 + i2 = s0, don't substitute i2 = + # s0 - i0 - i1. Arguably this should be OK but the + # runtime assert machinery is very delicate right now + # so this causes things to fail e.g., + # test_split_unbacked_sizes + ok = len(free_unbacked_symbols(new_var)) <= 1 + msg = "solve_unbacked" + else: + # Never substitute backed with unbacked + ok = len(free_unbacked_symbols(new_var)) == 0 + msg = "solve_backed" + if ok: + self._set_replacement(cast(sympy.Symbol, free[0]), new_var, msg) + except NotImplementedError: + pass + if expr.has(Mod): + mod_expr = next(iter(expr.atoms(Mod))) + try: + r = try_solve(expr, mod_expr, floordiv_inequality=False) + if r is not None and r[1] == 0: + self._add_divisible(mod_expr) + # This is a little bit of extra logic to make things like + # torch.empty(i0, q).view(c, -1, q) work out + p, q = mod_expr.args + if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2: + c, i0 = p.args + # Given Mod(c * i0, q) == 0 + if ( + isinstance(c, sympy.Number) and + isinstance(i0, sympy.Symbol) and + self.is_unbacked_symint(i0) + ): + # We have Mod(i0, q / c) == 0, which means we can + # rewrite i0 as (q / gcd(q, c)) * i1 + d = q / sympy.gcd(q, c) + i1 = self.create_unbacked_symint().node.expr + # Propagate the value ranges. It doesn't really + # matter if we use truediv or floordiv, because we + # have established divisibility. + self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv( + self.var_to_range[i0], ValueRanges.wrap(d) + ) + # Propagate size-like-ness + if i0 in self.size_like: + self.size_like.add(i1) + self._set_replacement(i0, d * i1, "divisibility") + + except NotImplementedError: + pass + return + + # See: Note - On 0/1 specialization + # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT + # as a sentinel sometimes. Your sizevar isn't going to be + # anywhere near the max 64-bit integer anyway. 
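+ # For example: with specialize_zero_one=True, a fresh size symbol is assumed
+ # to lie in [2, sys.maxsize - 1]; 0 and 1 are specialized away and
+ # sys.maxsize itself is excluded as the sentinel described above.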
+ def _default_value_range(self) -> ValueRanges: + lower = 2 if self.specialize_zero_one else 0 + return ValueRanges(lower, sys.maxsize - 1) + + def _default_unspecified_value_range(self) -> ValueRanges: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) + + @_lru_cache + def _simplify_floor_div(self, expr): + floor_divs = tuple(expr.atoms(FloorDiv)) + # we expect floor_divs to be exact, + # and thus add the guards for the exact floordivs, + # even if tracing doesn't require them otherwise + for fd in reversed(floor_divs): + base, divisor = fd.args + mod_expr = Mod(base, divisor) + eq_expr = sympy.Eq(mod_expr, 0) + # add necessary mod guards + self.evaluate_expr(eq_expr) + return self.simplify(expr) + + # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen + # and if so issue a warning + def _check_frozen(self, expr, concrete_val): + if self.frozen: + self.counter["ignored_backward_guard"] += 1 + signpost_event( + "dynamic", + "evaluate_expr_frozen", + { + **self.co_fields, + "ignored_guard": f"{expr} == {concrete_val}", + # no version = original state (this signpost is expected) + # version 2 = dynamic backwards is eagerly compiled + "version": 2, + }, + ) + log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val) + + + def _get_stack_summary(self, is_debug: bool = False): + fsummary = None + frame = inspect.currentframe() + try: + while frame is not None: + if frame.f_code.co_filename not in uninteresting_files(): + fsummary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + break + frame = frame.f_back + finally: + del frame + + # NB: this stack is truncated, but it's fine because the main + # stack_info will give you the rest of the info you need + maybe_user_loc = "" + user_tb = TracingContext.extract_stack() + if user_tb: + maybe_user_loc = " at " + format_frame(user_tb[-1]) + + maybe_extra_debug = "" + if is_debug and user_tb: + maybe_extra_debug = ( + '\nUser Stack (most recent call last):\n' + + ' (snipped, see stack below for prefix)\n' + + ''.join(traceback.format_list(user_tb)) + ) + if is_debug and config.extended_debug_cpp: + cpp_stack = CapturedTraceback.extract(cpp=True) + maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format()) + + return fsummary, maybe_user_loc, maybe_extra_debug + + def _log_guard(self, prefix: str, g, forcing_spec: bool): + if self.log.isEnabledFor(logging.INFO): + str_g = str(g) + is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "%s %s [guard added]%s (%s)%s", + prefix if not forcing_spec else f"{prefix} (forcing_spec)", + str_g, + maybe_user_loc, + format_frame(fsummary), + maybe_extra_debug, + stack_info=is_debug, + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True) + def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, + expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False): + """ + Given an expression, evaluates it, adding guards if necessary + """ + + # TODO: split conjunctions and evaluate them separately + + @lru_cache(None) + def compute_concrete_val(): + if hint is None: + return self.size_hint(orig_expr) + else: + return sympy.sympify(hint) + + # Check if: + # 1. 'translation_validation' is set + # 2. the corresponding 'fx_node' is not 'None' + # 3. 
the guard should not be suppressed + # + # If all of the above check, we create an FX node representing the + # actual expression to be guarded. + node = None + fresh = False + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + and not size_oblivious + ): + concrete_val = compute_concrete_val() + if concrete_val is sympy.true: + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + elif concrete_val is sympy.false: + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) + else: + eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val)) + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) + + assert node is not None + # If this is a fresh node, we have to remember the event index that + # corresponds to this assertion node. + # Reason: so that, given an assertion node, we can replay the ShapeEnv + # events until the point where this assertion node was freshly created. + if fresh: + self._add_fx_node_metadata(node) + + # After creating the FX node corresponding to orig_expr, we must make sure that + # no error will be raised until the end of this function. + # + # Reason: the translation validation may become invalid otherwise. + # + # If an error is raised before the end of this function, we remove the FX node + # inserted, and re-raise the error. + guard = None + tb = None + + try: + if orig_expr.is_number: + self.log.debug("eval %s [trivial]", orig_expr) + # NB: don't test float as there may be precision issues + if isinstance(hint, (int, bool)): + assert orig_expr == hint, f"{orig_expr} != {hint}" + return orig_expr + + expr = orig_expr + + static_expr = self._maybe_evaluate_static(expr, + expect_rational=expect_rational, + size_oblivious=size_oblivious) + if static_expr is not None: + self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) + # NB: don't test float as there may be precision issues + if isinstance(hint, (int, bool)): + assert static_expr == hint, f"{static_expr} != {hint}" + return static_expr + + if not (expr.free_symbols <= self.var_to_val.keys()): + # TODO: dedupe this with _maybe_evaluate_static + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + if not (new_expr.free_symbols <= self.var_to_val.keys()): + size_oblivious_result = None + if not size_oblivious: + size_oblivious_result = self._maybe_evaluate_static( + expr, + expect_rational=expect_rational, + size_oblivious=True + ) + + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + size_oblivious_result=size_oblivious_result + ) + expr = new_expr + + concrete_val = compute_concrete_val() + self._check_frozen(expr, concrete_val) + + if ( + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY + and isinstance(hint, bool) + and isinstance(expr, (sympy.Eq, sympy.Ne)) + ): + expr = sympy.Not(expr) + + # Turn this into a boolean expression, no longer need to consult + # concrete_val + suppress_maybe_guard_rel = False + if concrete_val is sympy.true: + g = expr + elif concrete_val is sympy.false: + g = sympy.Not(expr) + else: + # WARNING: we cannot actually do simplifications on guards + # on floating point values, because Sympy generally does not + # think expressions on integers can ever be equal to floating + # point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). 
Without + # very clear algebraic laws that hold for floating point, such + # simplifications are error prone anyway, so be sure not to + # maybe_guard_rel in those cases. + if not isinstance(concrete_val, sympy.Integer): + suppress_maybe_guard_rel = True + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] + + if isinstance(g, sympy.Rel): + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) + + if not self._suppress_guards_tls(): + stack = CapturedTraceback.extract(skip=1) + guard = ShapeGuard(g, stack) + # TODO: deal with duplicate guards somehow + self.guards.append(guard) + except Exception: + if fresh: + self._remove_fx_node(node) + raise + else: + if not self._suppress_guards_tls(): + assert guard is not None + + self._log_guard("eval", g, forcing_spec=forcing_spec) + + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec and + config.symbol_guard_limit_before_specialize is not None and + self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s + ) + self.evaluate_expr(s, forcing_spec=True) + else: + self.log.debug("eval %s [guard suppressed]", g) + + return concrete_val + + def cleanup(self): + """ + Break reference cycles. + + This destroys the stacks. If you really want to keep them, we + just need some way to break references on code objects. 
+ """ + for g in self.guards: + g.stack.cleanup() + for s in self.var_to_stack.values(): + s.cleanup() + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + ra.stack.cleanup() + + @record_shapeenv_event(save_tracked_fakes=True) + def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): + """Create an assert that is checked at runtime + + Args: + orig_expr (sympy.Expr): Boolean expression to assert is true + msg (str): Message to display on assertion failure + fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding + to the expression, if applicable + + """ + expr = orig_expr + + # TODO: split conjunctions and evaluate them separately + + static_expr = self._maybe_evaluate_static(expr) + if static_expr is not None: + self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr) + return static_expr + + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + if new_expr.free_symbols <= self.var_to_val.keys(): + # Do a normal guard + return self.evaluate_expr(new_expr, fx_node=fx_node) + # NB: Don't use new_expr as expr; it could contain gunk like shape0 + # which we don't want to guard on + + # OK, we're definitely doing a runtime assert now + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + ): + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + assert node is not None + if fresh: + self._add_fx_node_metadata(node) + + self._check_frozen(expr, sympy.true) + + # eliminate symbols on equality tests / refine ranges + if isinstance(expr, sympy.Rel): + self._maybe_guard_rel(expr) + + if not self._suppress_guards_tls(): + # canonicalise to remove equations that are trivially equal + orig_expr = expr + expr = canonicalize_bool_expr(expr) + stack = CapturedTraceback.extract(skip=1) + ra = RuntimeAssert(expr, msg, stack) + # TODO: Do this in a way that is less janky than int(s.name[1:]) + cands = sorted([s for s in expr.free_symbols if s.name.startswith("u")], key=lambda s: int(s.name[1:])) + self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra) + self.num_deferred_runtime_asserts += 1 + self._update_version_counter() + self._log_guard("runtime_assert", orig_expr, forcing_spec=False) + else: + self.log.debug("runtime_assert %s [guard suppressed]", expr) + + return True + + # Refines the ranges of the variables present in 'guard'. + # + # This function tries to refine the range of the variables inside + # 'guard' by reasoning about it. Specifically, when 'guard' is a + # 'sympy.Relational' operation. + # + # It does mainly 3 things: + # 1. Tries to isolate a variable in the left-hand side + # 2. Compute the value range of the right-hand side + # 3. Update the value range of the variable, if better + def _refine_ranges(self, expr: sympy.Expr) -> None: + expr = self.simplify(expr) + + for symbol in expr.free_symbols: + assert isinstance(symbol, sympy.Symbol) + + if isinstance(self.var_to_val.get(symbol, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + + r = try_solve(expr, symbol) + + if r is None or not (symbol.is_integer and r[1].is_integer): + # Range refinement only supports integer symbols for now. + # There are lots of SymPy bugs when it comes to comparing + # reals and integers, so we skip that for now. 
+ continue + + r_expr, rhs = r + vr = self.var_to_range[symbol] + lower, upper = vr.lower, vr.upper + + rhs_vr = bound_sympy(rhs, self.var_to_range) + _assert_bound_is_rational(rhs, rhs_vr) + + # Let's suppose that we have a preexisting range for x [0, 100]. + # Now, we issue a guard x > y, where the range for y is [50, 150]. + # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, + # refining x to [51, 100], since x must be greater than y, but the lowest + # y could be is 50. + # + # sympy.Eq may update both lower and upper bounds. + # sympy.G{t,e} may update the lower bound, only. + # sympy.L{t,e} may update the upper bound, only. + if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)): + # Strictly greater relations allow us to refine a bit more, since + # x > y implies that the lower bound for x is: y + 1. + lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) + if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)): + upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) + + # Do nothing if the new value range is no better than what we already have. + if vr == ValueRanges(lower, upper): + continue + + # Updates the range and the guards corresponding to each bound of the symbol. + self.var_to_range[symbol] = ValueRanges(lower, upper) + # Clears the cache, since this update can change the result. + self._maybe_evaluate_static.cache_clear() + +def _is_int(expr): + return isinstance(expr, SymInt) and expr.node.expr.is_number + +# WARNING: This is legacy, DO NOT USE +def _is_dim_dynamic(t, d): + return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..1dfed971a60ae49370d1b04f484509325ba57d1e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py @@ -0,0 +1,766 @@ +import functools +import logging +import math +import operator +import sympy +import builtins + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback + +from torch._dynamo.exc import TorchDynamoException +from torch.fx.node import Argument, Target +from torch.utils._sympy.interp import sympy_interp + +log = logging.getLogger(__name__) + +try: + import z3 # type: ignore[import] + + # Translation Validation for Dynamo guards + # ======================================== + # + # Checks whether optimizations applied to the collected guards are + # valid. In other words, whether the guard function we actually run + # does not have false positives (unsound). + # + # In order to do so, we build the guards using 2 different kinds of + # information attached to each 'SymNode': + # 1. SymPy expressions + # 2. FX nodes + # + # SymPy expressions have implicit optimizations baked into them, + # which may have a few bugs. On the other hand, we build the FX graph + # manually, with no optimizations enabled. This gives us access to + # the "ground truth".
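+ # A small self-contained sketch of that check (assuming only the z3-solver
+ # package), asking whether the simplified guards can hold while the
+ # original ones fail:
+ #
+ #     import z3
+ #     s0 = z3.Int("s0")
+ #     source = [s0 > 1, s0 % 2 == 0]   # guards as originally issued
+ #     target = [s0 > 1]                # simplified guards (a conjunct was dropped)
+ #     solver = z3.Solver()
+ #     solver.add(z3.Not(z3.And(*source)), *target)
+ #     solver.check()  # z3.sat, e.g. s0 = 3: the simplification was unsound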
+ # + # We then convert into Z3 expressions both the SymPy expressions + # (see [Note: SympyToZ3]) that reach the 'ShapeEnv.produce_guards' function + # and the FX nodes (see [Note: PopulateValidator]) that go through + # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. + # (see [Note: TranslationValidator]) + + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> List[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + + # Implementation of Python semantics as Z3 expressions. + # + # Z3 Real-Int theory has operators with semantics that differ from that + # of Python. Therefore, in order to get it right, we need to implement + # the (Python) semantics we are relying on in Z3. + @dataclass + class _Z3Ops: + # Validator used for adding assertions as needed. + # e.g. div(a, b) requires b != 0. + validator: "TranslationValidator" + + # The 2 functions below are used for conditionally casting between + # integers and reals. + # + # Returns a real expression from 'x'.
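+ # e.g. Python's (-7) // 2 == -4 because floor rounds toward negative
+ # infinity; Z3's ToInt rounds the same way (ToInt(-3.5) == -4), which is
+ # why floordiv below can be expressed as ToInt over real division.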
+ @staticmethod + def to_int(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_int() else z3.ToInt(x) + + # Implements Python division semantics. + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] + return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) + + def floor(self, number: z3.ArithRef) -> z3.ArithRef: + # Z3 ToInt function rounds a real number towards negative infinity. + return _Z3Ops.to_int(number) + + # Python semantics for 'FloorDiv' states that before applying the floor + # function, the operands are converted to their common type. + def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + cast_result_to_real = numerator.is_real() or denominator.is_real() + result = _Z3Ops.to_int(self.div(numerator, denominator)) + # Since the 'result' is already an integer, we just have to check + # whether we should cast it to real. + return _Z3Ops.to_real(result) if cast_result_to_real else result + + def ceil(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If( + self.floor(number) < number, + self.floor(number + 1), + number + ) # type: ignore[return-value] + + def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a > b, a, b) # type: ignore[return-value] + + def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a < b, a, b) # type: ignore[return-value] + + # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q + # It should work with both integers and reals. + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return p - self.floordiv(p, q) * q + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + # Z3 can't handle complex numbers very well. + self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] + return base ** exp + + def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: + # Square-root: + # 1. Only work with reals + number = _Z3Ops.to_real(number) + # 2. The number should be positive or zero. + # Otherwise, Z3 returns 'unknown'. + self.validator.add_assertion(number >= 0) + return number ** 0.5 + + def abs(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.Abs(number) + + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + if ndigits is not None: + raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") + + # Python's builtin 'round' implements the 'round half to even' strategy + # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to + # floating point numbers, which is different from real numbers that we are dealing with here. + # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and + # 'round half down' (ceil(x - 0.5)). + # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... + # to round down, i.e. use the 'round half down' strategy + return z3.If( + self.mod(number, z3.IntVal(2)) == 0.5, + self.ceil(number - 0.5), + self.floor(number + 0.5), + ) + + # Lifts a callable to be used in Z3. + # + # This function replaces the given 'op' by a function that: + # + # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) + # + # 2.
Calls an operation that corresponds to 'op', but works with Z3 + # inhabitants (left as is if it works as is) + def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + # Operations that have booleans as their argument. + # This is needed because the argument of some FX nodes were + # literal integers, instead of booleans. So, whenever this flag + # is set, we also convert ints to booleans. + boolean_ops = {operator.not_, operator.and_, operator.or_} + as_bool = op in boolean_ops + + # Lifts the function into 'z3.ExprRef' domain. + def lift(func): + def wrap(a) -> z3.ExprRef: + if isinstance(a, (z3.ArithRef, z3.BoolRef)): + return a + # Convert it into a Z3 value, if it is some of the supported + # types below. + if isinstance(a, bool) or (as_bool and isinstance(a, int)): + return z3.BoolVal(bool(a)) + if isinstance(a, (int, sympy.Integer)): + return z3.IntVal(int(a)) + if isinstance(a, (float, sympy.Float)): + return z3.RealVal(float(a)) + raise ValueError(f"can't lift type: {type(a)}") + + @functools.wraps(func) + def wrapper(*args): + # Lifts the arguments into a list of Z3 inhabitants. + wrapped_args = (wrap(a) for a in args) + # Run the function on the Z3 expressions. + return func(*wrapped_args) + + return wrapper + + ops = _Z3Ops(validator) + replacement_map = { + # Operator module. + operator.not_: lift(z3.Not), + operator.and_: lift(z3.And), + operator.or_: lift(z3.Or), + operator.floordiv: lift(ops.floordiv), + operator.truediv: lift(ops.div), + operator.mod: lift(ops.mod), + operator.abs: lift(ops.abs), + builtins.round: lift(ops.round), + + # Math module. + math.ceil: lift(ops.ceil), + math.floor: lift(ops.floor), + + # Torch module. + torch.sym_float: lift(ops.to_real), + torch.sym_max: lift(ops.max), + torch.sym_min: lift(ops.min), + torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] + # Not lifted because we only use this function as a + # marker for adding the expression as validator input. + torch._assert: torch._assert, + } + return replacement_map[op] if op in replacement_map else lift(op) + + # Processes an FX graph, populating the given validator. + # + # [Note: PopulateValidator] + # This class walks through each node in the FX graph, translating + # them into the Z3 world. + # + # Then, whenever it finds an 'torch._assert' call_function operation, + # it adds the Z3 expression corresponding to the argument as validator + # input. + class PopulateValidator(torch.fx.Interpreter): + def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + # Reference to the translation validator. + self.validator = validator + + # Build the graph module and call `Interpreter` constructor. + module = torch.fx.GraphModule(root={}, graph=graph) + super().__init__(module, garbage_collect_values=True) + + def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + symbol = fx_traceback.get_current_meta()["symbol"] + return self.validator.z3var(symbol) + + def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + if target != torch._assert: + # Actually runs the node target function (which is already + # lifted) with its arguments. + return super().call_function(target, args, kwargs) + # Adds the Z3 expression corresponding to the first argument + # as a validator input. + assert len(args) == 1, f"expected 1 argument on assertion. 
Got: {len(args)} " + self.validator.add_source_expr(args[0]) # type: ignore[arg-type] + + # Translates SymPy expressions into Z3 expressions. + # + # [Note: SympyToZ3] + # At the time of the translation, all free variables present in the + # SymPy expression being translated must be already mapped to a Z3 + # integer variable. + class SympyToZ3: + OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} + + def __init__( + self, + validator: "TranslationValidator", + ) -> None: + self._validator = validator + self._ops = _Z3Ops(self._validator) + + def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + if dtype is torch.int64: + return z3.IntVal(int(value)) + if dtype is torch.double: + return z3.RealVal(float(value)) + if dtype is torch.bool: + return z3.BoolVal(bool(value)) + raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: + return self._ops.round(number, ndigits) + + def __getattr__(self, name: str) -> Any: + REPLACEMENT = { + "and_": z3.And, + "or_": z3.Or, + "not_": z3.Not, + "floor": self._ops.floor, + "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, + } + + if name in REPLACEMENT: + return REPLACEMENT[name] + if name in self.OPERATOR_HANDLES: + return getattr(operator, name) + raise AttributeError(f"unhandled operator: {name}") + + def run(self, expr: sympy.Basic) -> z3.ExprRef: + return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] + + # Dynamo guards translation validator. + # + # [Note: TranslationValidator] + # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. + # That is: whether those (target) guards only yield TRUE whenever the original, + # unoptimized, (source) guards yield TRUE. + # + # More concretely, given 'source' and 'target' guard expressions, we wish to + # check whether the following expression holds: + # + # Not(And(source)) AND And(target) + # + # i.e. whether there is an assignment of the free variables where the opposite + # happens: target is TRUE, but source is FALSE. + class TranslationValidator: + def __init__(self) -> None: + log.debug("new instance") + + # Mapping of SymPy symbols to Z3 variables. + self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} + + # Set of source Z3 expressions. + # They represent the generated guards without any kind of + # simplification or transformation. + self._source_exprs: Set[z3.BoolRef] = set() + + # Set of target Z3 expressions. + # They represent the actual checked guards at runtime. They might + # be simplified or transformed versions of the source guards. + self._target_exprs: Set[z3.BoolRef] = set() + + # Set of Z3 expressions representing assertions over both the + # source and target expressions. + self._assertions: Set[z3.BoolRef] = set() + + # Retrieves the corresponding Z3 variable. 
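+ # A minimal hypothetical use of this class (symbol name assumed):
+ #
+ #     v = TranslationValidator()
+ #     s0 = sympy.Symbol("s0", positive=True, integer=True)
+ #     v.add_var(s0, int)
+ #     v.add_source_expr(v.z3var(s0) > 1)   # guard as originally issued
+ #     v.add_target_expr(sympy.Gt(s0, 1))   # guard actually checked at runtime
+ #     v.validate()  # raises ValidationException if the target set is unsound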
+ def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: + assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" + return self.symbols[symbol] + + # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exist. + def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: + if symbol in self.symbols: + return self.symbols[symbol] + + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + + if type is int: + var = z3.Int(symbol.name) + + # If 'symbol' is positive (SymPy assumption), we have to + # convey it to Z3 as well. + if symbol.is_positive: # type: ignore[attr-defined] + self._target_exprs.add(var > 0) + elif type is float: + var = z3.Real(symbol.name) + elif type is bool: + var = z3.Bool(symbol.name) + else: + raise RuntimeError(f"unsupported type for Z3 variable: {type}") + + self.symbols[symbol] = var + return var + + # Checks whether all symbols were already added. + def _check_freesymbols(self, e: sympy.Basic) -> None: + for s in e.free_symbols: + assert isinstance(s, sympy.Symbol) + # Call 'z3var' just to check whether there's already a + # Z3 variable corresponding to 's'. + self.z3var(s) + + + def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: + z3expr = SympyToZ3(self).run(e) + assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}" + return z3expr + + def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) + self._source_exprs.add(e) + + def add_target_expr(self, e: sympy.Expr) -> None: + self._check_freesymbols(e) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) + + def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: + if isinstance(e, sympy.Basic): + self._check_freesymbols(e) + ref = self.to_z3_boolean_expr(e) + else: + ref = e + assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) + self._assertions.add(ref) + + def validate(self) -> None: + from torch._dynamo.utils import dynamo_timed + + if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: + # If there are no source/target expressions, there's nothing we really + # wish to prove. So, we just return. + return None + + # Here, we use "QF_NRA" logic for the solver: + # "Quantifier-free Non-linear Real Arithmetic". + # + # Most of the guard expressions have: + # 1. arithmetic between integers and reals + # 2. no quantifiers + # 3. potentially non-linear terms. + # + # Although there's also "QF_NIRA" (mixed integer-real arithmetic), + # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. + solver = z3.SolverFor("QF_NRA") + # Set a timeout for finding a solution. + solver.set(timeout=translation_validation_timeout()) + + # Add all the assertions to the solver. + for assertion in self._assertions: + solver.add(assertion) + + # "Is there any case where it's TRUE for the target expressions, + # but FALSE for the source expressions?" + solver.add(z3.Not(z3.And(*self._source_exprs))) + solver.add(*self._target_exprs) + + log.debug("translation validation: start") + r = dynamo_timed()(solver.check)() + if r == z3.sat: + # Target expressions are unsound. + # Log the found model and the source expressions that failed.
+ model = solver.model() + raise ValidationException( + model, self._assertions, self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ] + ) + else: + if r == z3.unknown: + # Could not find a solution. It didn't fail, but it also + # didn't succeed. Canceling the validation execution (keyboard + # interrupt) also gets to this branch. + log.warning("translation validation: could not validate: got z3.unknown") + else: + # Target expressions are sound. + assert r == z3.unsat + log.debug("translation validation: success") + +except ImportError: + _HAS_Z3 = False + + __all__ = [ + "translation_validation_enabled", "translation_validation_timeout", + "ValidationException", "BisectValidationException", + ] + +else: + _HAS_Z3 = True + + __all__ = [ + "z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator", + "translation_validation_enabled", "translation_validation_timeout", + "ValidationException", "BisectValidationException", + ] + +from torch.fx.experimental import _config as config + +def translation_validation_enabled() -> bool: + # Checks every time this function is called, in case the Dynamo + # option is set, but Z3 is not installed. + _assert_z3_installed_if_tv_set() + return _HAS_Z3 and config.translation_validation + + +def translation_validation_timeout() -> int: + return config.translation_validation_timeout + + +def _assert_z3_installed_if_tv_set(): + assert _HAS_Z3 or not config.translation_validation, ( + "translation validation requires the Z3 package. Please either install " + "z3-solver or disable translation validation." + ) + + +class ValidationException(TorchDynamoException): + def __init__(self, model, assertions, target_exprs, failed_source_exprs): + assert _HAS_Z3 + + def symbolstr(sym) -> str: + return f"{sym}: {model[sym]}" + + def joinlines(xs) -> str: + return "\n".join(f" ==> {x}" for x in xs) + + model_str = joinlines(sorted(map(symbolstr, model))) + assertions_str = joinlines(sorted(map(z3str, assertions))) + target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) + failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) + + self.msg = "translation validation failed." + self.details = f"""\ +Model: +{model_str} + +Assertions: +{assertions_str} + +Target Expressions: +{target_exprs_str} + +Failed Source Expressions: +{failed_source_exprs_str}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +class BisectValidationException(TorchDynamoException): + def __init__(self, validation_exc, expr, failed_action, traced_node): + self.msg = f"translation validation failed when {failed_action}: {expr}" + self.details = f"""\ +Failure occurred while running node: + {traced_node.format_node()} + +{validation_exc.details}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + +# Checks when this module is loaded. +_assert_z3_installed_if_tv_set() + +# Translation validation bisection. +# +# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise +# the earliest ValidationException. +# +# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors +# might be silently happening. This function tries to nail down exactly at which +# point things went wrong from a validation perspective.
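+ # The binary search inside bisect below relies on failures being monotone in
+ # event order: if the ShapeEnv replayed up to assert node i fails validation,
+ # replaying up to any later node is assumed to fail as well, so the loop
+ # converges on the earliest failing torch._assert node.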
+def bisect(shape_env): + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY + from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events + + events = shape_env.events + + # Retrieves the ShapeEnvEvent associated with node. + def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: + assert SHAPEENV_EVENT_KEY in node.meta + return events[node.meta[SHAPEENV_EVENT_KEY]] + + # Creates a new instance of fake, but updating every symbolic value's ShapeEnv + # reference to the one given as argument. + # + # This is needed so as not to simplify a symbolic expression using a ShapeEnv + # "from the future", where it may have a different set of replacements. + def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + if isinstance(fake, int): + return fake + if isinstance(fake, torch.SymInt): + return torch.SymInt(fake.node.with_shape_env(shape_env)) + assert isinstance(fake, FakeTensorMeta) + return FakeTensorMeta( + tuple(new_with_shape_env(shape_env, s) for s in fake.size()), + tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), + new_with_shape_env(shape_env, fake.storage_offset()), + fake.is_nested, + ) + + # Checks whether the given shape_env fails when produce_guards is called. + def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]: + assert tracked_fakes is not None + try: + # This produce_guards call is a best-effort replication, since we + # don't populate EqualityConstraint list. Reason: we would also have + # to save OutputGraph.tracked_fakes_id_to_source. + shape_env.produce_guards( + [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], + [a.source for a in tracked_fakes], + input_contexts=[a.symbolic_context for a in tracked_fakes], + ) + return None + except ValidationException as e: + return e + + # Checks whether the ShapeEnv reconstructed by replaying the events until + # node is created fails when produce_guards is called. + def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: + number = node.meta[SHAPEENV_EVENT_KEY] + # Reconstruct shape_env until the event at event_number. + shape_env = replay_shape_env_events(events[:number + 1]) + shape_env.graph.lint() + return check_shapeenv_fails(shape_env, events[number].tracked_fakes) + + last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes()) + + if not last_exception: + # We don't actually fail due to a produce_guards call. + # Stop and don't bisect. + log.info("translation validation succeeded: no errors found.") + return + + if not shape_env.should_record_events or config.translation_validation_no_bisect: + # Bisection is off. + # Return the last ValidationException we got. + raise last_exception + + # Cache the raised exception (if any) at each bisection point. + exception = {} + + # Bisection happens on the assertion nodes of the recorded FX graph for + # dynamic shapes. + assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert] + + # Preparing the indices for binary search. + left, mid, right = 0, 0, len(assert_nodes) - 1 + + while left < right: + mid = (left + right) // 2 + + node = assert_nodes[mid] + log.debug("bisecting at %s: %s", mid, get_node_event(node)) + + # Check whether the new shape_env raises a ValidationException or not. 
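+ # Standard binary search over the assert nodes: if replaying the events up
+ # to `mid` already fails validation, the earliest failure is at or before
+ # `mid` (right = mid); otherwise it must lie after `mid` (left = mid + 1).
+ # The loop terminates with `left` pointing at the earliest failing node.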
+ exception[mid] = check_node_fails(node) + + if exception[mid]: + right = mid + else: + left = mid + 1 + + assert left in exception and isinstance(exception[left], ValidationException) + + node = assert_nodes[left] + event = get_node_event(node) + + if event.is_evaluate_expr(): + failed_action = "evaluating" + else: + assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" + failed_action = "adding runtime assert" + + args = event.args + assert args is not None + assert len(args) >= 2, ( + f"bisecting expects {event.name} to have at least 2 positional arguments. " + f"Got: {len(args)}" + ) + assert isinstance(args[1], sympy.Basic), ( + f"bisecting expects {event.name} to have a SymPy expression as its second argument. " + f"Got: {type(args[1])}" + ) + + raise BisectValidationException( + exception[left], + expr=args[1], + failed_action=failed_action, + traced_node=node.meta[CURRENT_NODE_KEY], + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..142740a322bceadc7df3798a0cdebe90661fac14 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py @@ -0,0 +1,441 @@ +import torch +import inspect +import numbers +import types +import typing +import enum +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING +from torch._jit_internal import boolean_dispatched +from ._compatibility import compatibility +from torch._ops import OpOverloadPacket, OpOverload + +if TYPE_CHECKING: + from .node import Argument + +__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint", + "type_matches", "normalize_function", "normalize_module"] + +@compatibility(is_backward_compatible=False) +class ArgsKwargsPair(NamedTuple): + """ + Simple named tuple for wrapping args/kwargs pairs. + """ + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + +_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} + +def _nonzero_schemas(): + signatures = [] + + def nonzero(self): + pass + signatures.append(inspect.signature(nonzero)) + + def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] + pass + signatures.append(inspect.signature(nonzero)) + + return signatures + +_manual_overrides[torch.nonzero] = _nonzero_schemas() + +class _FakeGlobalNamespace: + def __getattr__(self, name): + if name == 'torch': + return torch + raise RuntimeError('Expected a torch namespace lookup') + +_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, + 'number' : numbers.Number, 'Future' : torch.jit.Future, + 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, + '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), + 'Storage': torch.UntypedStorage, + 't': typing.TypeVar('t')} +for k in dir(typing): + _type_eval_globals[k] = getattr(typing, k) + +def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: + """ + Convert a TorchScript type to a Python type (including subtypes) via + eval'ing the annotation_str. 
_type_eval_globals sets up expressions + like "List" and "Future" to map to actual types (typing.List and jit.Future) + """ + return eval(ts_type.annotation_str, _type_eval_globals) + +def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + from inspect import Parameter + parameters : List[Parameter] = [] + for arg in ts_schema.arguments: + arg_type = _torchscript_type_to_python_type(arg.type) + default = arg.default_value if arg.has_default_value() else Parameter.empty + # TODO: Figure out if this is safe. It seems like when generating the type signatures for + # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor + # argument name. Downstream, if someone converts that positional argument to a keyword + # argument, the name mismatch will break things, so here we're going to normalize the + # name to "input" + name = arg.name if arg.name != 'self' else 'input' + kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD + # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument + if name == "from": + assert kind == Parameter.POSITIONAL_OR_KEYWORD + # ParameterKind type is internal implementation detail to inspec package + # which makes it hard to do type annotation + kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] + # This renders all previous arguments to positional only + for idx, p in enumerate(parameters): + assert p.kind == Parameter.POSITIONAL_OR_KEYWORD + parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation) + parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type)) + return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] + if len(return_types) == 0: + return_type = None + elif len(return_types) == 1: + return_type = return_types[0] + else: + return_type = tuple(return_types) + + return inspect.Signature(parameters, return_annotation=return_type) + +_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {} + +def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + # Cached as it's called in the hot path of FakeTensor dispatch + cache_key = ts_schema.name, ts_schema.overload_name + cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) + if cache_val is not None: + return cache_val + + res = _torchscript_schema_to_signature_impl(ts_schema) + _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res + return res + +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError as e: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' + f'code, so operations that mutate operands in-place (e.g. 
via `out` arguments) ' + f'are not supported') + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + pass + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): + """ + Given an operator on the `torch` namespace, return a list of `inspect.Signature` + objects corresponding to the overloads of that op.. May return `None` if a signature + could not be retrieved. + + Args: + op (Callable): An operator on the `torch` namespace to look up a signature for + + Returns: + Optional[List[inspect.Signature]]: A list of signatures for the overloads of this + operator, or None if the operator signatures could not be retrieved. If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature + """ + if isinstance(op, OpOverload): + schemas = [op._schema] + elif isinstance(op, OpOverloadPacket): + schemas = [getattr(op, overload)._schema for overload in op.overloads()] + else: + override = _manual_overrides.get(op) + if override: + return (override, None) if return_schemas else None + + aten_fn = torch.jit._builtins._find_builtin(op) + + if aten_fn is None: + return (None, None) if return_schemas else None + schemas = torch._C._jit_get_schemas_for_operator(aten_fn) + + signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] + return (signatures, schemas) if return_schemas else signatures + +@compatibility(is_backward_compatible=False) +def create_type_hint(x): + try: + if isinstance(x, (list, tuple)): + # todo(chilli): Figure out the right way for mypy to handle this + if isinstance(x, list): + def ret_type(x): + return List[x] # type: ignore[valid-type] + else: + def ret_type(x): + return Tuple[x, ...] + if len(x) == 0: + return ret_type(Any) + base_type = x[0] + for t in x: + if issubclass(t, base_type): + continue + elif issubclass(base_type, t): + base_type = t + else: + return ret_type(Any) + return ret_type(base_type) + except Exception as e: + # We tried to create a type hint for list but failed. + warnings.warn(f"We were not able to successfully create type hint from the type {x}") + pass + return x + +@compatibility(is_backward_compatible=False) +def type_matches(signature_type : Any, argument_type : Any): + sig_origin_type = getattr(signature_type, '__origin__', signature_type) + + if signature_type is argument_type: + return True + + # Union types in signature. Given type needs to match one of the + # contained types in the Union + if sig_origin_type is typing.Union and signature_type != argument_type: + sig_contained = signature_type.__args__ + return any(type_matches(c, argument_type) for c in sig_contained) + + if signature_type is List[int] and argument_type is int: + # int can be promoted to List[int] + return True + + if getattr(signature_type, '__origin__', None) in {list, List}: + sig_el_type = signature_type.__args__[0] + if not inspect.isclass(sig_el_type): + warnings.warn( + f"Does not support nested parametric types, got {signature_type}. 
Please file a bug.") + return False + if getattr(argument_type, '__origin__', None) in {list, List}: + return issubclass(argument_type.__args__[0], sig_el_type) + + def is_homogeneous_tuple(t): + if getattr(t, "__origin__", None) not in {tuple, Tuple}: + return False + contained = t.__args__ + if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + return True + return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) + + # Tuple[T] is accepted for List[T] parameters + return is_homogeneous_tuple(argument_type) + + # Dtype is an int in schemas + if signature_type is int and argument_type is torch.dtype: + return True + + if signature_type is numbers.Number and argument_type in {int, float}: + return True + if inspect.isclass(argument_type) and inspect.isclass(signature_type): + return issubclass(argument_type, signature_type) + + return False + +@compatibility(is_backward_compatible=False) +def normalize_function( + target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, + kwarg_types : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch functions. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). Does not support modules. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + if kwargs is None: + kwargs = {} + new_args_and_kwargs = None + if not isinstance(target, types.BuiltinFunctionType) and not ( + isinstance(target, (OpOverloadPacket, OpOverload)) + ): + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched['if_true'], dispatched['if_false'] + if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: + return None + target_for_analysis = if_true + + assert callable(target_for_analysis) + sig = inspect.signature(inspect.unwrap(target_for_analysis)) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) + else: + assert callable(target) + torch_op_schemas = get_signature_for_torch_op(target) + matched_schemas = [] + if torch_op_schemas: + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. 
If none matches, `new_args_and_kwargs` will be None + for candidate_signature in torch_op_schemas: + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append(candidate_signature) + except TypeError as e: + continue + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot normalize + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, + normalize_to_only_use_kwargs) + else: + if arg_types is not None or kwarg_types is not None: + arg_types = arg_types if arg_types else cast(Tuple[Any], ()) + kwarg_types = kwarg_types if kwarg_types else {} + for candidate_signature in torch_op_schemas: + sig_matches = True + try: + bound_types = candidate_signature.bind(*arg_types, **kwarg_types) + for arg_name, arg_type in bound_types.arguments.items(): + param = candidate_signature.parameters[arg_name] + sig_matches = sig_matches and type_matches(param.annotation, arg_type) + except TypeError as e: + sig_matches = False + if sig_matches: + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, + normalize_to_only_use_kwargs) + break + else: + # Matched more than one schema. In this situation, the caller must provide the types of + # the arguments of the overload they expect. + schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) + raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' + f'the schema match was ambiguous! Please provide argument types to ' + f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') + + return new_args_and_kwargs + +@compatibility(is_backward_compatible=False) +def normalize_module( + root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch modules. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). + + Args: + root (nn.Module): root module upon which we query modules + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. 
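+
+    Example (an illustrative sketch: normalizing the call to the ``Linear``
+    submodule of an ``nn.Sequential``, whose auto-generated name is ``"0"``)::
+
+        >>> # xdoctest: +SKIP
+        >>> root = torch.nn.Sequential(torch.nn.Linear(3, 4))
+        >>> pair = normalize_module(root, "0", (torch.randn(2, 3),),
+        ...                         normalize_to_only_use_kwargs=True)
+        >>> list(pair.kwargs.keys())
+        ['input']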
+ """ + try: + submod = root.get_submodule(target) + except AttributeError as e: + raise RuntimeError(f"Tried to normalize node with target {target} but root did not " + f"have that target!") from e + if hasattr(submod.__class__, '__name__'): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + sig = inspect.signature(inspect.unwrap(submod.forward)) + if kwargs is None: + kwargs = {} + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, + normalize_to_only_use_kwargs) + return new_args_and_kwargs + return None + +def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], + kwargs : Dict[str, Any], + normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: + """ + Given a call target, args, and kwargs, return the arguments normalized into + an ArgsKwargsPair, or None if the type signature is not supported by + this normalization. + + Args: + + sig (inspect.Signature): Signature object for the target + args (Tuple): Arguments that appear at the callsite for `target` + kwargs (Dict): Keyword arguments that appear at the callsite for `target` + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if + this target is not supported. + """ + + # Don't currently support positional-only + # or varargs (*args, **kwargs) signatures + supported_parameter_types = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): + # Add an exception for one signature, which is common for random/uniform, i.e.: + # Tensor(a!) self, float from=0, float to=1, *, Generator? 
generator=None + # `from` is Python keyword and as such functions with that signature should have + # positional-only args, but at the same time they could be dispatched as kwargs + if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']: + return None + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + new_kwargs : Dict[str, Any] = {} + new_args : List[Any] = [] + for i, param in enumerate(sig.parameters): + if not normalize_to_only_use_kwargs and i < len(args): + new_args.append(bound_args.arguments[param]) + else: + new_kwargs[param] = bound_args.arguments[param] + + return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/proxy.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ec45a6722e97d55b883ca35ca7da497ea6addb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/proxy.py @@ -0,0 +1,565 @@ +# mypy: ignore-errors + +import enum +import dis +import copy +import sys +import torch +import inspect +import operator +import traceback +import collections + +from dataclasses import is_dataclass, fields + + +from .graph import magic_methods, reflectable_magic_methods, Graph +from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable +from .node import Target, Node, Argument, base_types, map_aggregate +from ._compatibility import compatibility +from .operator_schemas import check_for_mutable_operation +import torch.fx.traceback as fx_traceback + +__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', + 'Proxy', 'Attribute', 'ParameterProxy', 'Scope', + 'ScopeContextManager'] + + +@compatibility(is_backward_compatible=False) +class Scope: + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self): + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +@compatibility(is_backward_compatible=False) +class ScopeContextManager: + """ A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. 
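+
+    A minimal usage sketch (illustrative; ``Sub`` refers to the example module
+    defined in the ``Scope`` docstring above)::
+
+        current = Scope("", None)
+        with ScopeContextManager(current, Scope("sub", Sub)):
+            ...  # nodes created here are attributed to module_path "sub"
+        # on exit, `current` is restored to module_path "" and module_type None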
+ """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return + + +_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node", "quantization_tag"] + + +@compatibility(is_backward_compatible=True) +class TracerBase: + graph: Graph + record_stack_traces : bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations : bool = False + # Feature flag for assert tracing + trace_asserts : bool = False + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes : bool = False + + # Name of the function to be traced. It will only be used when + # ``root`` is an instance of ``nn.Module`` + traced_func_name: str = "forward" + + # Maps the containing module's name to the operator name + scope : Scope + + # Records the module call stack + module_stack: OrderedDict[str, Tuple[str, Any]] + + # Mapping of node name to module scope + node_name_to_scope: Dict[str, Tuple[str, type]] + + @compatibility(is_backward_compatible=True) + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ + if kind == 'call_function' and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depreciated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + # Optionally set stack trace on the created Node for debugging purposes + if fx_traceback.has_preserved_node_meta(): + current_meta: Dict[str, Any] = fx_traceback.get_current_meta() + + stack_trace = current_meta.get("stack_trace") + if stack_trace: + node.stack_trace = stack_trace + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + for field in _COPY_META_FIELDS: + if field in current_meta: + node.meta[field] = copy.copy(current_meta[field]) + + # Here we decrement to account for the sequence_nr having + # just been incremented while tracing this lowered aten op. + new_seq_nr = torch.autograd._get_sequence_nr() - 1 + # The sequence_nr increments every time a new autograd Node + # is created. During the FWD pass we store the sequence_nr + # corresponding to the last autograd Node created on this fx + # node's meta. A single aten op can create multiple autograd + # nodes as is the case with in-place foreach ops. During the + # BWD pass we retrieve the sequence_nr stored on the current + # executing autograd Node. 
See NOTE [ Sequence Number ]. + if current_meta.get("in_grad_fn", 0) > 0: + new_seq_nr = current_meta["grad_fn_seq_nr"][-1] + node.meta["seq_nr"] = new_seq_nr + + elif self.module_stack: + node.meta['nn_module_stack'] = copy.copy(self.module_stack) + return node + + @compatibility(is_backward_compatible=True) + def proxy(self, node: Node) -> 'Proxy': + return Proxy(node, self) + + @compatibility(is_backward_compatible=True) + def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr : Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + ''' + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. + ''' + + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + + if not proxy_factory_fn: + proxy = self.proxy(node) + else: + proxy = proxy_factory_fn(node) + + if self.record_stack_traces and not proxy.node.stack_trace: + user_frame = self._find_user_frame() + if user_frame: + summary = traceback.extract_stack(user_frame) + tb_lines = summary.format() + # stack_trace would have innermost frame at the bottom + proxy.node.stack_trace = ''.join(tb_lines) + + return proxy + + def _find_user_frame(self): + """ + Find the Python stack frame executing the user code during + symbolic tracing. + """ + # We have to do a little dance here. Basically, walk up the callstack and + # record the first frame not in the pytorch source. This is the frame executing + # the user code during tracing. + frame = inspect.currentframe() + + pt_files = ['torch/fx/proxy.py', + 'torch/fx/_symbolic_trace.py', + 'torch/fx/experimental/proxy_tensor.py', + 'torch/_ops.py', + 'torch/_tensor.py', + 'torch/utils/_python_dispatch.py', + 'torch/_prims_common/wrappers.py', + 'torch/_refs/__init__.py', + 'torch/_refs/nn/functional/__init__.py', + 'torch/utils/_stats.py', + ] + while frame: + frame = frame.f_back + if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): + break + + if not frame: + return None + + return frame + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ + if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): + return a.__fx_create_arg__(self) + # aggregates + elif isinstance(a, tuple) and hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = tuple(self.create_arg(elem) for elem in a) + return type(a)(*args) # type: ignore[arg-type] + elif isinstance(a, (tuple, list)): + return type(a)(self.create_arg(elem) for elem in a) + elif isinstance(a, dict): + r = {} + for k, v in a.items(): + # Check for invalid dict keys. We do not want a Proxy to appear + # anywhere within the key. 
Since keys can be collection types, + # we iterate through the key with map_aggregate + k = self.create_arg(k) + + def no_node(arg): + if isinstance(arg, Node): + raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}") + map_aggregate(k, no_node) + + r[k] = self.create_arg(v) + return r + elif isinstance(a, slice): + return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, range): + return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, torch._ops.OpOverload): + return a + + if isinstance(a, Proxy): + # base case: we unwrap the Proxy object + return a.node + + if is_dataclass(a): + kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} + return self.create_node("call_function", a.__class__, (), kwargs) + + elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: + return a + raise NotImplementedError(f"argument of type: {type(a)}") + + @compatibility(is_backward_compatible=True) + def to_bool(self, obj: 'Proxy') -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + + @compatibility(is_backward_compatible=True) + def iter(self, obj: 'Proxy') -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError('Proxy object cannot be iterated. This can be ' + 'attempted when the Proxy is used in a loop or' + ' as a *args or **kwargs function argument. ' + 'See the torch.fx docs on pytorch.org for a ' + 'more detailed explanation of what types of ' + 'control flow can be traced, and check out the' + ' Proxy docstring for help troubleshooting ' + 'Proxy iteration errors') + + @compatibility(is_backward_compatible=True) + def keys(self, obj: 'Proxy') -> Any: + """Called when a proxy object is has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator it ** is suppose to work in your custom tracer. + """ + return Attribute(obj, 'keys')() + + +# used in Proxy object when just appending to the graph while not tracing. +@compatibility(is_backward_compatible=True) +class GraphAppendingTracer(TracerBase): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} + +@compatibility(is_backward_compatible=False) +def assert_fn(x): + assert x + +@compatibility(is_backward_compatible=True) +class TraceError(ValueError): + pass + +@compatibility(is_backward_compatible=True) +class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. 
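+
+    A minimal sketch of this recording behaviour (illustrative; it relies on the
+    default ``GraphAppendingTracer`` that is created when no tracer is passed)::
+
+        >>> # xdoctest: +SKIP
+        >>> graph = torch.fx.Graph()
+        >>> x = Proxy(graph.placeholder("x"))   # wrap a raw placeholder Node
+        >>> y = x + 1                           # recorded as call_function(operator.add)
+        >>> graph.output(y.node)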
+ + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + + ``Proxy`` objects cannot be iterated. In other words, the symbolic + tracer will throw an error if a ``Proxy`` is used in a loop or as + an ``*args``/``**kwargs`` function argument. + + There are two main ways around this: + 1. Factor out the untraceable logic into a top-level function and + use ``fx.wrap`` on it. + 2. If the control flow is static (i.e. the loop trip count is + based on some hyperparameter), the code can be kept in its original + position and refactored into something like:: + + for i in range(self.some_hyperparameter): + indexed_item = proxied_value[i] + + For a more detailed description into the Proxy internals, check out + the "Proxy" section in `torch/fx/OVERVIEW.md` + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): + if tracer is None: + # This allows you to create a Proxy object around a raw Node + tracer = GraphAppendingTracer(node.graph) + self.tracer = tracer + self.node = node + + def __repr__(self) -> str: + return f'Proxy({self.node.name})' + + def __getattr__(self, k) -> 'Attribute': + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return Attribute(self, k) + + def __call__(self, *args, **kwargs) -> 'Proxy': + return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + + def __iter__(self) -> Iterator['Proxy']: + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + inst_list = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) + else: + inst_idx = calling_frame.f_lasti // 2 + inst = inst_list[inst_idx] + if inst.opname == 'UNPACK_SEQUENCE': + return (self[i] for i in range(inst.argval)) # type: ignore[index] + + return self.tracer.iter(self) + + def __abs__(self): + return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) + + def __bool__(self) -> bool: + if self.tracer.trace_asserts: + # check if this boolean is used in an assertion, bytecode pattern for assertions + # is pretty stable for Python 3.7--3.9 + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + insts = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + cur = calling_frame.f_lasti // 2 + inst = insts[cur] + + if inst.opname == 'POP_JUMP_IF_TRUE': + first = insts[cur + 1] + assert inst.arg is not None + last = insts[inst.arg // 2 - 1] + starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' + or first.opname == 'LOAD_ASSERTION_ERROR') + if starts_with_assert and last.opname == 'RAISE_VARARGS': + self.tracer.create_proxy('call_function', assert_fn, (self,), {}) + return True + + return self.tracer.to_bool(self) + + @compatibility(is_backward_compatible=True) + def keys(self): + return self.tracer.keys(self) + + def __len__(self): + raise RuntimeError("'len' is not supported in symbolic tracing by default. 
If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope") + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers : Dict[Any, None] = {} + + def find_tracer(a): + if isinstance(a, cls): + tracers[a.tracer] = None + torch.fx.node.map_aggregate(args, find_tracer) + torch.fx.node.map_aggregate(kwargs, find_tracer) + + if len(tracers) > 1: + raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' + f'trying to trace operations {orig_method}') + tracer = next(iter(tracers.keys())) + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + if torch.overrides.is_tensor_method_or_property(orig_method): + return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + else: + if isinstance(orig_method, torch._ops.HigherOrderOperator): + # TODO: Define how to symbolically trace HigherOrderOperators + raise RuntimeError("Unable to symbolically trace HigherOrderOperators") + return tracer.create_proxy('call_function', orig_method, args, kwargs, + name=tracer.graph._target_to_str(orig_method.__name__)) + + +@compatibility(is_backward_compatible=True) +class Attribute(Proxy): + @compatibility(is_backward_compatible=True) + def __init__(self, root: Proxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + +@compatibility(is_backward_compatible=False) +class ParameterProxy(Proxy): + """ + A special proxy which lets "shape", "size", "dim", and a few other + attribute accesses pass through to the underlying module parameter object, + so that conditional tests on these attributes will not throw exception during tracing + """ + def __init__(self, tracer: TracerBase, node: Node, name, param): + super().__init__(node, tracer) + assert isinstance(param, torch.nn.Parameter) + self.param = param + self.name = name + + def __repr__(self) -> str: + return f'ParameterProxy({self.name})' + + @property + def shape(self): + return self.param.shape + + def size(self): + return self.param.size() + + def dim(self): + return self.param.dim() + + @property + def ndim(self): + return self.param.ndim + + def numel(self): + return self.param.numel() + + def nelement(self): + return self.param.nelement() + + +for method in magic_methods: + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = getattr(operator, method) + return tracer.create_proxy('call_function', target, args, kwargs) + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(Proxy, as_magic, impl) + _scope(method) + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + impl.__name__ = method_name + impl.__qualname__ = method_name + 
setattr(Proxy, method_name, impl) + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..c822a38ec78e44ecf3835aa7ef18cc682d8df522 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py @@ -0,0 +1,104 @@ +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +class TensorType: + """ + TensorType defines a type for tensors, which consists of a list of dimensions. + Example: + class M(torch.nn.Module): + def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): + return torch.add(x, y) + """ + + def __init__(self, dim): + self.__origin__ = TensorType + self.__args__ = dim + + def __repr__(self): + return f'TensorType[{self.__args__}]' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return list(self.__args__) == list(other.__args__) + else: + return False + + @staticmethod + def __class_getitem__(*args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + return TensorType(tuple(args)) + + +class _DynType: + """ + _DynType defines a type which stands for the absence of type information. + """ + def __init__(self): + self.__name__ = '_DynType' + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __str__(self): + return "Dyn" + + def __repr__(self): + return "Dyn" + + +Dyn = _DynType() + +@compatibility(is_backward_compatible=False) +def is_consistent(t1, t2): + """ + A binary relation denoted by ~ that determines if t1 is consistent with t2. + The relation is reflexive, symmetric but not transitive. + returns True if t1 and t2 are consistent and False otherwise. + Example: + Dyn ~ TensorType((1,2,3)) + int ~ Dyn + int ~ int + TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) + """ + + if t1 == t2: + return True + + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + else: + return False + + +@compatibility(is_backward_compatible=False) +def is_more_precise(t1, t2): + """ + A binary relation denoted by <= that determines if t1 is more precise than t2. + The relation is reflexive and transitive. + returns True if t1 is more precise than t2 and False otherwise. 
+ Example: + Dyn >= TensorType((1,2,3)) + int >= Dyn + int >= int + TensorType((1,Dyn,3)) <= TensorType((1,2,3)) + """ + if t1 == t2: + return True + + if isinstance(t2, _DynType): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + + else: + return False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py new file mode 100644 index 0000000000000000000000000000000000000000..99da145e75f1a9f6fb2467251948bc74361cbc02 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py @@ -0,0 +1,42 @@ +import io +import multiprocessing.queues +import pickle +from multiprocessing.reduction import ForkingPickler + + +class ConnectionWrapper: + """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization.""" + + def __init__(self, conn): + self.conn = conn + + def send(self, obj): + buf = io.BytesIO() + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) + self.send_bytes(buf.getvalue()) + + def recv(self): + buf = self.recv_bytes() + return pickle.loads(buf) + + def __getattr__(self, name): + if "conn" in self.__dict__: + return getattr(self.conn, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'") + + +class Queue(multiprocessing.queues.Queue): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + self._send = self._writer.send + self._recv = self._reader.recv + + +class SimpleQueue(multiprocessing.queues.SimpleQueue): + def _make_methods(self): + if not isinstance(self._reader, ConnectionWrapper): + self._reader: ConnectionWrapper = ConnectionWrapper(self._reader) + self._writer: ConnectionWrapper = ConnectionWrapper(self._writer) + super()._make_methods() # type: ignore[misc] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c93ae5528a39835f7dcbfea87687fe9e4b3e5a7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/batchnorm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/transformer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..769bc243ffb73113fdb1c082eeab567e9aff47a2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/transformer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/fold.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/fold.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae911252f996fb9d1001eb73ab0f195e20f5ffe --- /dev/null 
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/fold.py @@ -0,0 +1,303 @@ +from .module import Module +from .. import functional as F + +from torch import Tensor +from ..common_types import _size_any_t + +__all__ = ['Fold', 'Unfold'] + +class Fold(Module): + r"""Combines an array of sliding local blocks into a large containing tensor. + + Consider a batched :attr:`input` tensor containing sliding local blocks, + e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, + where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})` + is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})` + spatial locations each containing a :math:`C`-channeled vector), and + :math:`L` is the total number of blocks. (This is exactly the + same specification as the output shape of :class:`~torch.nn.Unfold`.) This + operation combines these local blocks into the large :attr:`output` tensor + of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the + arguments must satisfy + + .. math:: + L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`d` is over all spatial dimensions. + + * :attr:`output_size` describes the spatial shape of the large containing + tensor of the sliding local blocks. It is useful to resolve the ambiguity + when multiple input shapes map to same number of sliding blocks, e.g., + with ``stride > 0``. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Args: + output_size (int or tuple): the shape of the spatial dimensions of the + output (i.e., ``output.sizes()[2:]``) + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, + :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then + their values will be replicated across all spatial dimensions. + + * For the case of two output spatial dimensions this operation is sometimes + called ``col2im``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. 
Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + Shape: + - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)` + - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)` + or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above + + Examples:: + + >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2)) + >>> input = torch.randn(1, 3 * 2 * 2, 12) + >>> output = fold(input) + >>> output.size() + torch.Size([1, 3, 4, 5]) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + + __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding', + 'stride'] + output_size: _size_any_t + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + output_size: _size_any_t, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1 + ) -> None: + super().__init__() + self.output_size = output_size + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.fold(input, self.output_size, self.kernel_size, self.dilation, + self.padding, self.stride) + + def extra_repr(self) -> str: + return 'output_size={output_size}, kernel_size={kernel_size}, ' \ + 'dilation={dilation}, padding={padding}, stride={stride}'.format( + **self.__dict__ + ) + + +class Unfold(Module): + r"""Extracts sliding local blocks from a batched input tensor. + + Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`, + where :math:`N` is the batch dimension, :math:`C` is the channel dimension, + and :math:`*` represent arbitrary spatial dimensions. This operation flattens + each sliding :attr:`kernel_size`-sized block within the spatial dimensions + of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output` + tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where + :math:`C \times \prod(\text{kernel\_size})` is the total number of values + within each block (a block has :math:`\prod(\text{kernel\_size})` spatial + locations each containing a :math:`C`-channeled vector), and :math:`L` is + the total number of such blocks: + + .. 
math:: + L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] % + - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor, + + where :math:`\text{spatial\_size}` is formed by the spatial dimensions + of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial + dimensions. + + Therefore, indexing :attr:`output` at the last dimension (column dimension) + gives all values within a certain block. + + The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify + how the sliding blocks are retrieved. + + * :attr:`stride` controls the stride for the sliding blocks. + + * :attr:`padding` controls the amount of implicit zero-paddings on both + sides for :attr:`padding` number of points for each dimension before + reshaping. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + Args: + kernel_size (int or tuple): the size of the sliding blocks + dilation (int or tuple, optional): a parameter that controls the + stride of elements within the + neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple, optional): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 + + * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or + :attr:`stride` is an int or a tuple of length 1, their values will be + replicated across all spatial dimensions. + + * For the case of two input spatial dimensions this operation is sometimes + called ``im2col``. + + .. note:: + :class:`~torch.nn.Fold` calculates each combined value in the resulting + large tensor by summing all values from all containing blocks. + :class:`~torch.nn.Unfold` extracts the values in the local blocks by + copying from the large tensor. So, if the blocks overlap, they are not + inverses of each other. + + In general, folding and unfolding operations are related as + follows. Consider :class:`~torch.nn.Fold` and + :class:`~torch.nn.Unfold` instances created with the same + parameters: + + >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) + >>> fold = nn.Fold(output_size=..., **fold_params) + >>> unfold = nn.Unfold(**fold_params) + + Then for any (supported) ``input`` tensor the following + equality holds: + + :: + + fold(unfold(input)) == divisor * input + + where ``divisor`` is a tensor that depends only on the shape + and dtype of the ``input``: + + >>> # xdoctest: +SKIP + >>> input_ones = torch.ones(input.shape, dtype=input.dtype) + >>> divisor = fold(unfold(input_ones)) + + When the ``divisor`` tensor contains no zero elements, then + ``fold`` and ``unfold`` operations are inverses of each + other (up to constant divisor). + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. 
+ + Shape: + - Input: :math:`(N, C, *)` + - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above + + Examples:: + + >>> unfold = nn.Unfold(kernel_size=(2, 3)) + >>> input = torch.randn(2, 5, 3, 4) + >>> output = unfold(input) + >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels) + >>> # 4 blocks (2x3 kernels) in total in the 3x4 input + >>> output.size() + torch.Size([2, 30, 4]) + + >>> # xdoctest: +IGNORE_WANT + >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape) + >>> inp = torch.randn(1, 3, 10, 12) + >>> w = torch.randn(2, 3, 4, 5) + >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5)) + >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2) + >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1)) + >>> # or equivalently (and avoiding a copy), + >>> # out = out_unf.view(1, 2, 7, 8) + >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max() + tensor(1.9073e-06) + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + """ + + __constants__ = ['kernel_size', 'dilation', 'padding', 'stride'] + kernel_size: _size_any_t + dilation: _size_any_t + padding: _size_any_t + stride: _size_any_t + + def __init__( + self, + kernel_size: _size_any_t, + dilation: _size_any_t = 1, + padding: _size_any_t = 0, + stride: _size_any_t = 1 + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding + self.stride = stride + + def forward(self, input: Tensor) -> Tensor: + return F.unfold(input, self.kernel_size, self.dilation, + self.padding, self.stride) + + def extra_repr(self) -> str: + return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \ + ' stride={stride}'.format(**self.__dict__) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/instancenorm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/instancenorm.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c37b72448c3270857ac80303a844dc4ba38a36 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/instancenorm.py @@ -0,0 +1,434 @@ + +import warnings +from torch import Tensor + +from .batchnorm import _LazyNormBase, _NormBase +from .. 
import functional as F + +__all__ = ['InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d', 'LazyInstanceNorm1d', + 'LazyInstanceNorm2d', 'LazyInstanceNorm3d'] + +class _InstanceNorm(_NormBase): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = False, + track_running_stats: bool = False, + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + num_features, eps, momentum, affine, track_running_stats, **factory_kwargs) + + def _check_input_dim(self, input): + raise NotImplementedError + + def _get_no_batch_dim(self): + raise NotImplementedError + + def _handle_no_batch_input(self, input): + return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0) + + def _apply_instance_norm(self, input): + return F.instance_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, self.momentum, self.eps) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + version = local_metadata.get('version', None) + # at version 1: removed running_mean and running_var when + # track_running_stats=False (default) + if version is None and not self.track_running_stats: + running_stats_keys = [] + for name in ('running_mean', 'running_var'): + key = prefix + name + if key in state_dict: + running_stats_keys.append(key) + if len(running_stats_keys) > 0: + error_msgs.append( + 'Unexpected running stats buffer(s) {names} for {klass} ' + 'with track_running_stats=False. If state_dict is a ' + 'checkpoint saved before 0.4.0, this may be expected ' + 'because {klass} does not track running stats by default ' + 'since 0.4.0. Please remove these keys from state_dict. If ' + 'the running stats are actually needed, instead set ' + 'track_running_stats=True in {klass} to enable them. See ' + 'the documentation of {klass} for details.' + .format(names=" and ".join(f'"{k}"' for k in running_stats_keys), + klass=self.__class__.__name__)) + for key in running_stats_keys: + state_dict.pop(key) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, input: Tensor) -> Tensor: + self._check_input_dim(input) + + feature_dim = input.dim() - self._get_no_batch_dim() + if input.size(feature_dim) != self.num_features: + if self.affine: + raise ValueError( + f"expected input's size at dim={feature_dim} to match num_features" + f" ({self.num_features}), but got: {input.size(feature_dim)}.") + else: + warnings.warn(f"input's size at dim={feature_dim} does not match num_features. " + "You can silence this warning by not passing in num_features, " + "which is not used because affine=False") + + if input.dim() == self._get_no_batch_dim(): + return self._handle_no_batch_input(input) + + return self._apply_instance_norm(input) + + +class InstanceNorm1d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 2D (unbatched) or 3D (batched) input as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. 
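As an illustrative aside (not part of the vendored file), the per-instance statistics described above can be reproduced by hand; the shapes are arbitrary::

    import torch
    import torch.nn as nn

    x = torch.randn(20, 100, 40)        # (N, C, L)
    m = nn.InstanceNorm1d(100)          # affine=False, no running stats by default

    # Every (n, c) slice is normalized with its own mean and (biased) variance,
    # matching torch.var(..., unbiased=False) as stated above.
    mean = x.mean(dim=2, keepdim=True)
    var = x.var(dim=2, unbiased=False, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + m.eps)

    assert torch.allclose(m(x), manual, atol=1e-6)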
:math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm1d` is applied + on each channel of channeled data like multidimensional time series, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm1d` usually don't apply affine + transform. + + Args: + num_features: number of features or channels :math:`C` of the input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm1d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm1d(100, affine=True) + >>> input = torch.randn(20, 100, 40) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError(f'expected 2D or 3D input (got {input.dim()}D input)') + + +class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, L)` or :math:`(C, L)` + eps: a value added to the denominator for numerical stability. 
Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, L)` or :math:`(C, L)` + - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input) + """ + + cls_to_become = InstanceNorm1d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 2 + + def _check_input_dim(self, input): + if input.dim() not in (2, 3): + raise ValueError(f'expected 2D or 3D input (got {input.dim()}D input)') + + +class InstanceNorm2d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 4D input (a mini-batch of 2D inputs + with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. + + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm2d` is applied + on each channel of channeled data like RGB images, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm2d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. 
+ track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm2d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm2d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError(f'expected 3D or 4D input (got {input.dim()}D input)') + + +class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, H, W)` or :math:`(C, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` + - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm2d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 3 + + def _check_input_dim(self, input): + if input.dim() not in (3, 4): + raise ValueError(f'expected 3D or 4D input (got {input.dim()}D input)') + + +class InstanceNorm3d(_InstanceNorm): + r"""Applies Instance Normalization. + + This operation applies Instance Normalization + over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper + `Instance Normalization: The Missing Ingredient for Fast Stylization + `__. + + .. math:: + + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The mean and standard-deviation are calculated per-dimension separately + for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors + of size C (where C is the input size) if :attr:`affine` is ``True``. + The standard-deviation is calculated via the biased estimator, equivalent to + `torch.var(input, unbiased=False)`. + + By default, this layer uses instance statistics computed from input data in + both training and evaluation modes. 
+ + If :attr:`track_running_stats` is set to ``True``, during training this + layer keeps running estimates of its computed mean and variance, which are + then used for normalization during evaluation. The running estimates are + kept with a default :attr:`momentum` of 0.1. + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + .. note:: + :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but + have some subtle differences. :class:`InstanceNorm3d` is applied + on each channel of channeled data like 3D models with RGB color, but + :class:`LayerNorm` is usually applied on entire sample and often in NLP + tasks. Additionally, :class:`LayerNorm` applies elementwise affine + transform, while :class:`InstanceNorm3d` usually don't apply affine + transform. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + + Examples:: + + >>> # Without Learnable Parameters + >>> m = nn.InstanceNorm3d(100) + >>> # With Learnable Parameters + >>> m = nn.InstanceNorm3d(100, affine=True) + >>> input = torch.randn(20, 100, 35, 45, 10) + >>> output = m(input) + """ + + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError(f'expected 4D or 5D input (got {input.dim()}D input)') + + +class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): + r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument. + + The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight`, `bias`, + `running_mean` and `running_var`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + num_features: :math:`C` from an expected input of size + :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var computation. Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters, initialized the same way as done for batch normalization. + Default: ``False``. 
+ track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: ``False`` + + Shape: + - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input) + """ + + cls_to_become = InstanceNorm3d # type: ignore[assignment] + + def _get_no_batch_dim(self): + return 4 + + def _check_input_dim(self, input): + if input.dim() not in (4, 5): + raise ValueError(f'expected 4D or 5D input (got {input.dim()}D input)') diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/lazy.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..52784ae5110a81ae62a0f5ab02ddc06113675d32 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/lazy.py @@ -0,0 +1,265 @@ +import itertools +import warnings +from typing import Protocol, Optional, Type, Any + +import torch +from ..parameter import is_lazy + +__all__ = ['LazyModuleMixin'] + +class _LazyProtocol(Protocol): + """This class is used to avoid errors with mypy checks for the attributes in a mixin. + + https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes + """ + + def _register_load_state_dict_pre_hook(self, hook): + ... + + def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): + ... + + def _lazy_load_hook( + self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + ... + + def _get_name(self): + ... + + def _infer_parameters(self, module, input): + ... + + @property + def _parameters(self): + ... + + @property + def _buffers(self): + ... + + @property + def _non_persistent_buffers_set(self): + ... + + @property + def _load_hook(self): + ... + + @property + def _initialize_hook(self): + ... + + +class LazyModuleMixin: + r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules". + + .. warning: + Lazy modules are an experimental new feature under active development, + and their API is likely to change. + + Modules that lazily initialize parameters, or "lazy modules", + derive the shapes of their parameters from the first input(s) + to their forward method. Until that first forward they contain + :class:`torch.nn.UninitializedParameter` s that should not be accessed + or used, and afterward they contain regular :class:`torch.nn.Parameter` s. + Lazy modules are convenient since they don't require computing some + module arguments, like the :attr:`in_features` argument of a + typical :class:`torch.nn.Linear`. + + After construction, networks with lazy modules should first + be converted to the desired dtype and placed on the expected device. + This is because lazy modules only perform shape inference so the usual dtype + and device placement behavior applies. + The lazy modules should then perform "dry runs" to initialize all the components in the module. + These "dry runs" send inputs of the correct size, dtype, and device through + the network and to each one of its lazy modules. After this the network can be used as usual. + + >>> # xdoctest: +SKIP + >>> class LazyMLP(torch.nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.fc1 = torch.nn.LazyLinear(10) + ... 
self.relu1 = torch.nn.ReLU() + ... self.fc2 = torch.nn.LazyLinear(1) + ... self.relu2 = torch.nn.ReLU() + ... + ... def forward(self, input): + ... x = self.relu1(self.fc1(input)) + ... y = self.relu2(self.fc2(x)) + ... return y + >>> # constructs a network with lazy modules + >>> lazy_mlp = LazyMLP() + >>> # transforms the network's device and dtype + >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs' + >>> lazy_mlp = lazy_mlp.cuda().double() + >>> lazy_mlp + LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True) + (relu1): ReLU() + (fc2): LazyLinear(in_features=0, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # performs a dry run to initialize the network's lazy modules + >>> lazy_mlp(torch.ones(10,10).cuda()) + >>> # after initialization, LazyLinear modules become regular Linear modules + >>> lazy_mlp + LazyMLP( + (fc1): Linear(in_features=10, out_features=10, bias=True) + (relu1): ReLU() + (fc2): Linear(in_features=10, out_features=1, bias=True) + (relu2): ReLU() + ) + >>> # attaches an optimizer, since parameters can now be used as usual + >>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01) + + A final caveat when using lazy modules is that the order of initialization of a network's + parameters may change, since the lazy modules are always initialized after other modules. + For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module + first and then a regular :class:`torch.nn.Linear` second, the second module would be + initialized on construction and the first module would be initialized during the first dry run. + This can cause the parameters of a network using lazy modules to be initialized differently + than the parameters of a network without lazy modules as the order of parameter initializations, + which often depends on a stateful random number generator, is different. + Check :doc:`/notes/randomness` for more details. + + Lazy modules can be serialized with a state dict like other modules. For example: + + >>> lazy_mlp = LazyMLP() + >>> # The state dict shows the uninitialized parameters + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', Uninitialized parameter), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', Uninitialized parameter), + ('fc2.bias', tensor([0.0019]))]) + + + Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize + initialized LazyModules and they will remain initialized) + + + >>> full_mlp = LazyMLP() + >>> # Dry run to initialize another module + >>> full_mlp.forward(torch.ones(10, 1)) + >>> # Load an initialized state into a lazy module + >>> lazy_mlp.load_state_dict(full_mlp.state_dict()) + >>> # The state dict now holds valid values + >>> lazy_mlp.state_dict() + OrderedDict([('fc1.weight', + tensor([[-0.3837], + [ 0.0907], + [ 0.6708], + [-0.5223], + [-0.9028], + [ 0.2851], + [-0.4537], + [ 0.6813], + [ 0.5766], + [-0.8678]])), + ('fc1.bias', + tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30, + 4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])), + ('fc2.weight', + tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807, + 0.2479, 0.1091]])), + ('fc2.bias', tensor([0.0019]))]) + + Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized + when the state is loaded. 
This prevents using initialized modules in different contexts. + """ + + # modules inheriting from this will change their __class__ to the specified + # one after they are fully initialized + cls_to_become: Optional[Type[Any]] = None + + def __init__(self: _LazyProtocol, *args, **kwargs): + # Mypy doesnt like this super call in a mixin + super().__init__(*args, **kwargs) # type: ignore[misc] + self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) + self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters, with_kwargs=True) + warnings.warn('Lazy modules are a new feature under heavy development ' + 'so changes to the API or functionality can happen at any moment.') + + def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars): + # This should be ideally implemented as a hook, + # but we should override `detach` in the UninitializedParameter to return itself + # which is not clean + for name, param in self._parameters.items(): + if param is not None: + if not (is_lazy(param) or keep_vars): + param = param.detach() + destination[prefix + name] = param + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + if not (is_lazy(buf) or keep_vars): + buf = buf.detach() + destination[prefix + name] = buf + + def _lazy_load_hook( + self: _LazyProtocol, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """load_state_dict pre-hook function for lazy buffers and parameters. + + The purpose of this hook is to adjust the current state and/or + ``state_dict`` being loaded so that a module instance serialized in + both un/initialized state can be deserialized onto both un/initialized + module instance. + See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` + for the details of the hook specification. + """ + for name, param in itertools.chain(self._parameters.items(), self._buffers.items()): + key = prefix + name + if key in state_dict and param is not None: + input_param = state_dict[key] + if is_lazy(param): + # The current parameter is not initialized but the one being loaded one is + # create a new parameter based on the uninitialized one + if not is_lazy(input_param): + with torch.no_grad(): + param.materialize(input_param.shape) + + def initialize_parameters(self: _LazyProtocol, *args, **kwargs): + r"""Initialize parameters according to the input batch properties. + + This adds an interface to isolate parameter initialization from the + forward pass when doing parameter shape inference. + """ + raise NotImplementedError(f'initialize_parameters is not implemented for {self.__class__.__name__}') + + def has_uninitialized_params(self: _LazyProtocol): + r"""Check if a module has parameters that are not initialized.""" + # This is to avoid the JIT to track this parameter and force + # custom modules __setstate__ to add it + params = self._parameters.values() + buffers = self._buffers.values() + for param in itertools.chain(params, buffers): + if is_lazy(param): + return True + return False + + def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None): + r"""Infers the size and initializes the parameters according to the provided input batch. + + Given a module that contains parameters that were declared inferrable + using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass + in the complete module using the provided input to initialize all the parameters + as needed. 
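An illustrative sketch (not part of the vendored file) of the contract these methods define: a custom lazy module declares :class:`~torch.nn.parameter.UninitializedParameter` attributes and materializes them in ``initialize_parameters``; the class name ``LazyScale`` is hypothetical::

    import torch
    import torch.nn as nn
    from torch.nn.modules.lazy import LazyModuleMixin
    from torch.nn.parameter import UninitializedParameter

    class LazyScale(LazyModuleMixin, nn.Module):
        # Multiplies the input by a per-feature weight whose size is inferred
        # from the first forward pass.
        def __init__(self):
            super().__init__()
            self.weight = UninitializedParameter()

        def initialize_parameters(self, input):
            # Called by the forward pre-hook that LazyModuleMixin installs.
            if self.has_uninitialized_params():
                with torch.no_grad():
                    self.weight.materialize((input.shape[-1],))
                    self.weight.fill_(1.0)

        def forward(self, input):
            return input * self.weight

    m = LazyScale()
    m(torch.randn(4, 8))                  # dry run infers the parameter shape
    assert isinstance(m.weight, nn.Parameter) and m.weight.shape == (8,)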
+ The module is set into evaluation mode before running the forward pass in order + to avoid saving statistics or calculating gradients + """ + kwargs = kwargs if kwargs else {} + module.initialize_parameters(*args, **kwargs) + if module.has_uninitialized_params(): + raise RuntimeError(f'module {self._get_name()} has not been fully initialized') + module._initialize_hook.remove() + module._load_hook.remove() + delattr(module, '_initialize_hook') + delattr(module, '_load_hook') + if module.cls_to_become is not None: + module.__class__ = module.cls_to_become + + + def _replicate_for_data_parallel(self: _LazyProtocol): + raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. ' + 'Run a dummy forward pass to correctly initialize the modules') diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..83e1b8a368a5f934aed84361e7bc54b60089dc28 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py @@ -0,0 +1,264 @@ +import math +from typing import Any + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter, UninitializedParameter +from .. import functional as F +from .. import init +from .module import Module +from .lazy import LazyModuleMixin + + +__all__ = [ + 'Bilinear', + 'Identity', + 'LazyLinear', + 'Linear', +] + + +class Identity(Module): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +class Linear(Module): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Args: + in_features: size of each input sample + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input: :math:`(*, H_{in})` where :math:`*` means any number of + dimensions including none and :math:`H_{in} = \text{in\_features}`. + - Output: :math:`(*, H_{out})` where all but the last dimension + are the same shape as the input and :math:`H_{out} = \text{out\_features}`. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. 
+ If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + Examples:: + + >>> m = nn.Linear(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(self.bias, -bound, bound) + + def forward(self, input: Tensor) -> Tensor: + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' + + +# This class exists solely to avoid triggering an obscure error when scripting +# an improperly quantized attention layer. See this issue for details: +# https://github.com/pytorch/pytorch/issues/58969 +# TODO: fail fast on quantization API usage error, then remove this class +# and replace uses of it with plain Linear +class NonDynamicallyQuantizableLinear(Linear): + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias=bias, + device=device, dtype=dtype) + + +class Bilinear(Module): + r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`. + + Args: + in1_features: size of each first input sample + in2_features: size of each second input sample + out_features: size of each output sample + bias: If set to False, the layer will not learn an additive bias. + Default: ``True`` + + Shape: + - Input1: :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and + :math:`*` means any number of additional dimensions including none. All but the last dimension + of the inputs should be the same. + - Input2: :math:`(*, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`. + - Output: :math:`(*, H_{out})` where :math:`H_{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`. + The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. 
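As an illustrative aside (not part of the vendored file), both the initialization bound mentioned in the ``Linear.reset_parameters`` comment above and the :math:`y = xA^T + b` definition can be checked numerically::

    import math
    import torch
    import torch.nn as nn

    m = nn.Linear(20, 30)
    bound = 1 / math.sqrt(m.in_features)

    # kaiming_uniform_(..., a=math.sqrt(5)) on a (out_features, in_features)
    # weight amounts to sampling from U(-1/sqrt(in_features), 1/sqrt(in_features)).
    assert m.weight.abs().max().item() <= bound
    assert m.bias.abs().max().item() <= bound

    # The forward pass computes x @ W.T + b.
    x = torch.randn(128, 20)
    assert torch.allclose(m(x), x @ m.weight.T + m.bias, atol=1e-5)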
+ If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in1\_features}}` + + Examples:: + + >>> m = nn.Bilinear(20, 30, 40) + >>> input1 = torch.randn(128, 20) + >>> input2 = torch.randn(128, 30) + >>> output = m(input1, input2) + >>> print(output.size()) + torch.Size([128, 40]) + """ + + __constants__ = ['in1_features', 'in2_features', 'out_features'] + in1_features: int + in2_features: int + out_features: int + weight: Tensor + + def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + bound = 1 / math.sqrt(self.weight.size(1)) + init.uniform_(self.weight, -bound, bound) + if self.bias is not None: + init.uniform_(self.bias, -bound, bound) + + def forward(self, input1: Tensor, input2: Tensor) -> Tensor: + return F.bilinear(input1, input2, self.weight, self.bias) + + def extra_repr(self) -> str: + return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format( + self.in1_features, self.in2_features, self.out_features, self.bias is not None + ) + + +class LazyLinear(LazyModuleMixin, Linear): + r"""A :class:`torch.nn.Linear` module where `in_features` is inferred. + + In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter` + class. They will be initialized after the first call to ``forward`` is done and the + module will become a regular :class:`torch.nn.Linear` module. The ``in_features`` argument + of the :class:`Linear` is inferred from the ``input.shape[-1]``. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_features: size of each output sample + bias: If set to ``False``, the layer will not learn an additive bias. + Default: ``True`` + + Attributes: + weight: the learnable weights of the module of shape + :math:`(\text{out\_features}, \text{in\_features})`. The values are + initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where + :math:`k = \frac{1}{\text{in\_features}}` + bias: the learnable bias of the module of shape :math:`(\text{out\_features})`. + If :attr:`bias` is ``True``, the values are initialized from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{1}{\text{in\_features}}` + + + """ + + cls_to_become = Linear # type: ignore[assignment] + weight: UninitializedParameter + bias: UninitializedParameter # type: ignore[assignment] + + def __init__(self, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. 
+ super().__init__(0, 0, False) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_features = out_features + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.in_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore[override] + if self.has_uninitialized_params(): + with torch.no_grad(): + self.in_features = input.shape[-1] + self.weight.materialize((self.out_features, self.in_features)) + if self.bias is not None: + self.bias.materialize((self.out_features,)) + self.reset_parameters() +# TODO: PartialLinear - maybe in sparse? diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/sparse.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..f053a0c8f3c2d8f0ae0a572b638e7c417b18ebdd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/sparse.py @@ -0,0 +1,455 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter + +from .module import Module +from .. import functional as F +from .. import init + +__all__ = ['Embedding', 'EmbeddingBag'] + +class Embedding(Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". For a newly constructed Embedding, + the embedding vector at :attr:`padding_idx` will default to all zeros, + but can be updated to another value to be used as the padding vector. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the + :attr:`weight` tensor in-place. 
Since tensors needed for gradient computations cannot be + modified in-place, performing a differentiable operation on ``Embedding.weight`` before + calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when + :attr:`max_norm` is not ``None``. For example:: + + n, d, m = 3, 5, 7 + embedding = nn.Embedding(n, d, max_norm=True) + W = torch.randn((m, d), requires_grad=True) + idx = torch.tensor([1, 2]) + a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable + b = embedding(idx) @ W.t() # modifies weight in-place + out = (a.unsqueeze(0) + b.unsqueeze(1)) + loss = out.sigmoid().prod() + loss.backward() + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0, 2, 0, 5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + + >>> # example of changing `pad` vector + >>> padding_idx = 0 + >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx) + >>> embedding.weight + Parameter containing: + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + >>> with torch.no_grad(): + ... 
embedding.weight[padding_idx] = torch.ones(3) + >>> embedding.weight + Parameter containing: + tensor([[ 1.0000, 1.0000, 1.0000], + [-0.7895, -0.7089, -0.0364], + [ 0.6778, 0.5803, 0.2678]], requires_grad=True) + """ + + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm', + 'norm_type', 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: Optional[int] + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + freeze: bool + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs), + requires_grad=not _freeze) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = Parameter(_weight, requires_grad=not _freeze) + + self.sparse = sparse + + def reset_parameters(self) -> None: + init.normal_(self.weight) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + return F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.max_norm is not None: + s += ', max_norm={max_norm}' + if self.norm_type != 2: + s += ', norm_type={norm_type}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Create Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. 
it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + _freeze=freeze, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + return embedding + + +class EmbeddingBag(Module): + r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings. + + For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`, + and with 2D inputs, this class + + * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``, + * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``, + * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``. + + However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these + operations. + + EmbeddingBag also supports per-sample weights as an argument to the forward + pass. This scales the output of the Embedding before performing a weighted + reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the + only supported ``mode`` is ``"sum"``, which computes a weighted sum according to + :attr:`per_sample_weights`. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights` + into consideration. ``"mean"`` computes the average of the values + in the bag, ``"max"`` computes the max value over each bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See + Notes for more details regarding sparse gradients. Note: this option is not + supported when ``mode="max"``. + include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element + is equivalent to the size of `indices`. This matches the CSR format. 
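An illustrative check of the equivalences listed above (not part of the vendored file; the weights and indices are arbitrary)::

    import torch
    import torch.nn as nn

    weight = torch.randn(10, 3)
    emb = nn.Embedding.from_pretrained(weight)
    bag = nn.EmbeddingBag.from_pretrained(weight, mode='sum')

    # 2-D input: two bags of fixed length 4; offsets must be None.
    idx = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    assert torch.allclose(bag(idx), emb(idx).sum(dim=1))

    # The same bags expressed as a 1-D index tensor plus offsets.
    offsets = torch.tensor([0, 4])
    assert torch.allclose(bag(idx.reshape(-1), offsets), emb(idx).sum(dim=1))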
+ padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". For a newly constructed + EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all + zeros, but can be updated to another value to be used as the padding vector. + Note that the embedding vector at :attr:`padding_idx` is excluded from the + reduction. + + Attributes: + weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + initialized from :math:`\mathcal{N}(0, 1)`. + + Examples:: + + >>> # an EmbeddingBag module containing 10 tensors of size 3 + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum') + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embedding_sum(input, offsets) + tensor([[-0.8861, -5.4350, -0.0523], + [ 1.1306, -2.5798, -1.0044]]) + + >>> # Example with padding_idx + >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long) + >>> offsets = torch.tensor([0, 4], dtype=torch.long) + >>> embedding_sum(input, offsets) + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + + >>> # An EmbeddingBag can be loaded from an Embedding like so + >>> embedding = nn.Embedding(10, 3, padding_idx=2) + >>> embedding_sum = nn.EmbeddingBag.from_pretrained( + embedding.weight, + padding_idx=embedding.padding_idx, + mode='sum') + """ + + __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type', + 'scale_grad_by_freq', 'mode', 'sparse', 'include_last_offset', + 'padding_idx'] + + num_embeddings: int + embedding_dim: int + max_norm: Optional[float] + norm_type: float + scale_grad_by_freq: bool + weight: Tensor + mode: str + sparse: bool + include_last_offset: bool + padding_idx: Optional[int] + + def __init__(self, num_embeddings: int, embedding_dim: int, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None, + include_last_offset: bool = False, padding_idx: Optional[int] = None, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + if _weight is None: + self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs)) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = Parameter(_weight) + self.mode = mode + self.sparse = sparse + self.include_last_offset = include_last_offset + + def reset_parameters(self) -> None: + init.normal_(self.weight) + 
self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor: + """Forward pass of EmbeddingBag. + + Args: + input (Tensor): Tensor containing bags of indices into the embedding matrix. + offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``. + + Returns: + Tensor output shape of `(B, embedding_dim)`. + + .. note:: + + A few notes about ``input`` and ``offsets``: + + - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the + starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`, + :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have + returned vectors filled by zeros. + """ + return F.embedding_bag(input, self.weight, offsets, + self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.mode, self.sparse, + per_sample_weights, self.include_last_offset, + self.padding_idx) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}' + if self.max_norm is not None: + s += ', max_norm={max_norm}' + if self.norm_type != 2: + s += ', norm_type={norm_type}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ', mode={mode}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + return s.format(**{k: repr(v) for k, v in self.__dict__.items()}) + + @classmethod + def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Optional[float] = None, + norm_type: float = 2., scale_grad_by_freq: bool = False, + mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False, + padding_idx: Optional[int] = None) -> 'EmbeddingBag': + r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag. + First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'. + freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True`` + max_norm (float, optional): See module initialization documentation. Default: ``None`` + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``. 
+ mode (str, optional): See module initialization documentation. Default: ``"mean"`` + sparse (bool, optional): See module initialization documentation. Default: ``False``. + include_last_offset (bool, optional): See module initialization documentation. Default: ``False``. + padding_idx (int, optional): See module initialization documentation. Default: ``None``. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([[1, 0]]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> embeddingbag(input) + tensor([[ 2.5000, 3.7000, 4.6500]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embeddingbag = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + include_last_offset=include_last_offset, + padding_idx=padding_idx) + embeddingbag.weight.requires_grad = not freeze + return embeddingbag diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..019dabe3e533f6d31ba41241f65d527fab659a25 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/utils.py @@ -0,0 +1,79 @@ +import collections +from itertools import repeat +from typing import List, Dict, Any + +__all__ = ['consume_prefix_in_state_dict_if_present'] + + +def _ntuple(n, name="parse"): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) + + parse.__name__ = name + return parse + + +_single = _ntuple(1, "_single") +_pair = _ntuple(2, "_pair") +_triple = _ntuple(3, "_triple") +_quadruple = _ntuple(4, "_quadruple") + + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: + import torch + if isinstance(out_size, (int, torch.SymInt)): + return out_size + if len(defaults) <= len(out_size): + raise ValueError( + f"Input dimension should be at least {len(out_size) + 1}" + ) + return [ + v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) + ] + + +def consume_prefix_in_state_dict_if_present( + state_dict: Dict[str, Any], prefix: str +) -> None: + r"""Strip the prefix in state_dict in place, if any. + + ..note:: + Given a `state_dict` from a DP/DDP model, a local model can load it by applying + `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling + :meth:`torch.nn.Module.load_state_dict`. + + Args: + state_dict (OrderedDict): a state-dict to be loaded to the model. + prefix (str): prefix. + """ + keys = list(state_dict.keys()) + for key in keys: + if key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict[newkey] = state_dict.pop(key) + + # also strip the prefix in metadata if any. 
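The forward() docstring above describes the ``per_sample_weights`` argument but the class examples never exercise it. A minimal sketch, with arbitrary illustrative values, of how per-sample weights scale each looked-up row before the bag-wise ``sum`` reduction (``per_sample_weights`` is only supported for ``mode='sum'``):

    import torch
    import torch.nn as nn

    bag = nn.EmbeddingBag(10, 3, mode='sum')
    input = torch.tensor([1, 2, 4, 5], dtype=torch.long)
    offsets = torch.tensor([0, 2], dtype=torch.long)       # two bags of two indices each
    weights = torch.tensor([0.5, 0.5, 1.0, 2.0])           # same shape as input
    out = bag(input, offsets, per_sample_weights=weights)  # shape (2, 3)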
+ if hasattr(state_dict, "_metadata"): + keys = list(state_dict._metadata.keys()) + for key in keys: + # for the metadata dict, the key can be: + # '': for the DDP module, which we want to remove. + # 'module': for the actual model. + # 'module.xx.xx': for the rest. + if len(key) == 0: + continue + # handling both, 'module' case and 'module.' cases + if key == prefix.replace('.', '') or key.startswith(prefix): + newkey = key[len(prefix) :] + state_dict._metadata[newkey] = state_dict._metadata.pop(key) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py new file mode 100644 index 0000000000000000000000000000000000000000..8daa1117bfaf98246f83acf9e2b79666ccdf6ef8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/scatter_gather.py @@ -0,0 +1,107 @@ +import torch +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload +from ._functions import Scatter, Gather +import warnings + +__all__ = ['scatter', 'scatter_kwargs', 'gather'] + +def is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + warnings.warn("is_namedtuple is deprecated, please use the python checks instead") + return _is_namedtuple(obj) + +def _is_namedtuple(obj: Any) -> bool: + # Check if type was created from collections.namedtuple or a typing.NamedTuple. + return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + +T = TypeVar("T", dict, list, tuple) + +# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise. +@overload +def scatter( + inputs: torch.Tensor, + target_gpus: Sequence[Union[int, torch.device]], + dim: int = ..., +) -> Tuple[torch.Tensor, ...]: + ... + +@overload +def scatter(inputs: T, target_gpus: Sequence[Union[int, torch.device]], dim: int = ...) -> List[T]: + ... + +def scatter(inputs, target_gpus, dim=0): + r"""Slice tensors into approximately equal chunks and distributes them across given GPUs. + + Duplicates references to objects that are not tensors. + """ + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + return Scatter.apply(target_gpus, None, dim, obj) + if _is_namedtuple(obj): + return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return [list(i) for i in zip(*map(scatter_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] + return [obj for _ in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). 
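A short, self-contained sketch of the DP/DDP use case that the ``consume_prefix_in_state_dict_if_present`` docstring above describes; the ``"module."``-prefixed state dict is simulated here rather than loaded from a real DDP checkpoint:

    import torch
    import torch.nn as nn
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    model = nn.Linear(4, 2)
    # simulate a checkpoint saved from a DataParallel/DDP-wrapped copy of the model
    wrapped_sd = {'module.' + k: v for k, v in model.state_dict().items()}
    consume_prefix_in_state_dict_if_present(wrapped_sd, 'module.')
    model.load_state_dict(wrapped_sd)   # keys now match the unwrapped model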
To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + res = scatter_map(inputs) + finally: + scatter_map = None # type: ignore[assignment] + return res + + +def scatter_kwargs( + inputs: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]], + target_gpus: Sequence[Union[int, torch.device]], + dim: int = 0, +) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]: + r"""Scatter with support for kwargs dictionary.""" + scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else [] + scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] + if len(scattered_inputs) < len(scattered_kwargs): + scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs))) + elif len(scattered_kwargs) < len(inputs): + scattered_kwargs.extend({} for _ in range(len(scattered_inputs) - len(scattered_kwargs))) + return tuple(scattered_inputs), tuple(scattered_kwargs) + + +def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any: + r"""Gather tensors from different GPUs on a specified device. + + Use 'cpu' for CPU to avoid a deprecation warning. + """ + def gather_map(outputs): + out = outputs[0] + if isinstance(out, torch.Tensor): + return Gather.apply(target_device, dim, *outputs) + if out is None: + return None + if isinstance(out, dict): + if not all(len(out) == len(d) for d in outputs): + raise ValueError('All dicts must have the same number of keys') + return type(out)((k, gather_map([d[k] for d in outputs])) + for k in out) + if _is_namedtuple(out): + return type(out)._make(map(gather_map, zip(*outputs))) + return type(out)(map(gather_map, zip(*outputs))) + + # Recursive function calls like this create reference cycles. + # Setting the function to None clears the refcycle. + try: + res = gather_map(outputs) + finally: + gather_map = None # type: ignore[assignment] + return res diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3fe611769a1dd3a9134c710fb948f5390cd6bd5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/activation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f7a5ca3b540edc9f9b1fc15899b63240b7ac79 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/activation.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. 
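A minimal sketch of the ``scatter``/``gather`` pair defined above, assuming at least two visible CUDA devices (the device ids are illustrative). ``scatter`` splits a tensor along ``dim`` into one chunk per target GPU; ``gather`` moves the chunks back to a single device and concatenates them in order:

    import torch
    from torch.nn.parallel.scatter_gather import scatter, gather

    x = torch.arange(8.).reshape(4, 2)
    chunks = scatter(x, target_gpus=[0, 1], dim=0)    # tuple of per-GPU slices
    merged = gather(chunks, target_device=0, dim=0)   # concatenated again on cuda:0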
+""" +from torch.ao.nn.quantizable.modules.activation import MultiheadAttention diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd2ecbed8d646b9410f2440e35f0ab4c5d086afa Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90efc0646c5d751c1bf7379ea1e06b7471beefbc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24ba594159679d215fe5241db1caefafe9f99ef8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac7866b428a6fde43b1267b6e439e9cb82b2043 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b675e3b892bdb848f2599d566e6079427684e8e4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_utils.py @@ -0,0 +1,240 @@ +import torch +import torch.nn.functional as F + +import numpy as np +from typing import List, Optional + +from .expanded_weights_utils import \ + set_grad_sample_if_exists, unpack_expanded_weight_or_tensor + +THRESHOLD = 32 + + +def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): + if func == F.conv1d: + return conv1dOpt + if func == F.conv2d: + return conv2dOpt + else: + assert func == F.conv3d + return conv3dOpt + + +def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): + args = expanded_args_and_kwargs[:len(expanded_args_and_kwargs) - len(kwarg_names)] + kwargs = expanded_args_and_kwargs[len(expanded_args_and_kwargs) - len(kwarg_names):] + kwargs = dict(zip(kwarg_names, 
kwargs)) + + return conv_normalizer(*args, **kwargs) + + +def conv_normalizer(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + return (input, weight), {'bias': bias, 'stride': stride, 'padding': padding, 'dilation': dilation, 'groups': groups} + + +def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size): + if padding_style == "valid": + return input + else: + padding = int_padding_for_string_padding(func, padding_style, dilation, kernel_size) + return F.pad(input, padding) + + +def int_padding_for_string_padding(func, padding_style, dilation, kernel_size): + def get_dilation(i): + return dilation[i] if isinstance(dilation, tuple) else dilation + + if padding_style == "same": + padding: List[int] = [] + # F.pad needs the padding in reverse order from what conv expects + for i in range(conv_picker(func, 0, 1, 2), -1, -1): + padding += conv_padding_for_same(get_dilation(i), kernel_size[i]) + return padding + elif padding_style == "valid": + return conv_picker(func, 2, 4, 6) * (0,) + else: + raise RuntimeError(f"got padding type of {padding_style}, only accept 'same' or 'valid'") + + +def conv_padding_for_same(dilation, kernel_size): + total_pad = dilation * (kernel_size - 1) + left_pad = total_pad // 2 + right_pad = total_pad - left_pad + return left_pad, right_pad + + +def conv_backward(func, ctx, grad_output): + + def weight_grad_sample(weight): + if (batch_size < THRESHOLD and groups == 1): + return conv_group_weight_grad_sample(ctx.input, grad_output, weight_shape, stride, padding, dilation, batch_size, func) + else: + return conv_unfold_weight_grad_sample(ctx.input, grad_output, weight_shape, kernel_size, + stride, padding, dilation, groups, func) + + def expand(param): + if isinstance(param, int): + return conv_picker(func, (param,), (param, param), (param, param, param)) + else: + return param + + def calc_total_padding(func, was_same, padding, dilation, kernel_size): + if was_same: + all_padding = int_padding_for_string_padding(func, "same", dilation, kernel_size) + # F.pad needs the padding in reverse order from what conv expects + total_padding = tuple(all_padding[i] + all_padding[i - 1] for i in range(len(all_padding) - 1, -1, -2)) + return total_padding + else: + return tuple(2 * pad for pad in padding) + + weight_shape = ctx.weight.shape + stride, padding, dilation, groups = expand(ctx.stride), expand(ctx.padding), expand(ctx.dilation), ctx.groups + + kernel_size = [] + for i in range(2, conv_picker(func, 3, 4, 5)): + kernel_size.append(weight_shape[i]) + + batch_size = ctx.batch_size + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + # "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding + total_padding = calc_total_padding(func, ctx.was_same_padding, padding, dilation, kernel_size) + + if ctx.input_required_grad: + output_padding = [] + input_dims = conv_picker(func, 1, 2, 3) + for i in range(input_dims): + input_dim = ctx.orig_input_shape[2 + i] + output_padding.append((total_padding[i] + input_dim - (kernel_size[i] * dilation[i] - dilation[i] + 1)) % stride[i]) + weight_ = unpack_expanded_weight_or_tensor(ctx.weight) + transpose_func = conv_picker(func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d) + out = transpose_func(grad_output, weight_, None, stride, padding, tuple(output_padding), groups, dilation) + + if ctx.was_same_padding: + for i in range(len(total_padding)): 
+ out = torch.narrow(out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i]) + + results.append(out) + else: + results.append(None) + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 6 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists(ctx.weight, weight_grad_sample) + set_grad_sample_if_exists(ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2)) + return tuple(results) + + +def conv_unfold_weight_grad_sample(input, grad_output, weight_shape, kernel_size, stride, padding, dilation, groups, func): + n = input.shape[0] + in_channels = input.shape[1] + + unfold_func = conv_picker( + func, + lambda: F.unfold(input.unsqueeze(-2), + kernel_size=(1, kernel_size[0]), + dilation=(1, dilation[0]), + padding=(0, padding[0]), + stride=(1, stride[0])), + lambda: F.unfold(input, kernel_size, dilation=dilation, padding=padding, stride=stride), + lambda: unfold3d(input, kernel_size, padding, stride, dilation) + ) + + input = unfold_func() + grad_output = grad_output.reshape(n, -1, input.shape[-1]) + + # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz + weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) + # rearrange the above tensor and extract diagonals. + weight_grad_sample = weight_grad_sample.view( + n, + groups, + -1, + groups, + int(in_channels / groups), + np.prod(kernel_size), + ) + weight_grad_sample = torch.einsum("ngrg...->ngr...", weight_grad_sample).contiguous() + shape = [n] + list(weight_shape) + weight_grad_sample = weight_grad_sample.view(shape) + return weight_grad_sample + + +def conv_group_weight_grad_sample(input, grad_output, weight_shape, stride, padding, dilation, batch_size, func): + I = input.shape[1] + O = grad_output.shape[1] + + input_ = input.transpose(0, 1) + grad_output_ = grad_output.view(grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:]) + + weight_grad_sample = func(input_, grad_output_, None, stride=dilation, padding=padding, dilation=stride, groups=batch_size) + input_dims = conv_picker(func, 3, 4, 5) + for i in range(2, input_dims): + weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i]) + weight_grad_sample = weight_grad_sample.view(I, batch_size, O, *weight_grad_sample.shape[2:]) + weight_grad_sample = weight_grad_sample.movedim(0, 2) + return weight_grad_sample + + +def unfold3d( + tensor, + kernel_size, + padding, + stride, + dilation, +): + r""" + Extract sliding local blocks from an batched input tensor. + + :class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors). + This method implements the same action for 5D inputs + Args: + tensor: An input tensor of shape ``(B, C, D, H, W)``. + kernel_size: the size of the sliding blocks + padding: implicit zero padding to be added on both sides of input + stride: the stride of the sliding blocks in the input spatial dimensions + dilation: the spacing between the kernel points. + Returns: + A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L - output spatial dimensions. + See :class:`torch.nn.Unfold` for more details + Example: + >>> # xdoctest: +SKIP + >>> B, C, D, H, W = 3, 4, 5, 6, 7 + >>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W) + >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape + torch.Size([3, 32, 120]) + """ + if len(tensor.shape) != 5: + raise ValueError( + f"Input tensor must be of the shape [B, C, D, H, W]. 
Got{tensor.shape}" + ) + + if dilation != (1, 1, 1): + raise NotImplementedError(f"dilation={dilation} not supported.") + + batch_size, channels, _, _, _ = tensor.shape + + # Input shape: (B, C, D, H, W) + tensor = F.pad( + tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]) + ) + # Output shape: (B, C, D+2*padding[2], H+2*padding[1], W+2*padding[0]) + + tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0]) + tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1]) + tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2]) + # Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2]) + # For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold` + + tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7) + # Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2]) + + tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose( + 1, 2 + ) + # Output shape: (B, D_out * H_out * W_out, C * kernel_size[0] * kernel_size[1] * kernel_size[2] + + return tensor diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e68b940660263f8a9ad13fe109f82c6338de1c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -0,0 +1,60 @@ +from functools import partial +import torch +import torch.nn.functional as F +from .expanded_weights_impl import implements_per_sample_grads +from .expanded_weights_utils import \ + forward_helper, set_grad_sample_if_exists, standard_kwargs, unpack_expanded_weight_or_tensor +from typing import List, Optional + +@implements_per_sample_grads(F.instance_norm) +class InstanceNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + instance_norm = partial(torch.instance_norm, cudnn_enabled=True) + expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs) + output = forward_helper(instance_norm, expanded_args, expanded_kwargs) + ctx.input = expanded_args[0] + ctx.running_mean, ctx.running_var = expanded_kwargs['running_mean'], expanded_kwargs['running_var'] + ctx.weight, ctx.bias, ctx.eps = expanded_kwargs['weight'], expanded_kwargs['bias'], expanded_kwargs['eps'] + return output + + + @staticmethod + def backward(ctx, grad_output): + input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + b = input.shape[0] + c = input.shape[1] + new_shape = (1, b * c, *input.shape[2:]) + + weight_ = unpack_expanded_weight_or_tensor(weight, lambda orig_weight: orig_weight.repeat(b)) + running_mean_ = running_mean.repeat(b) if running_mean is not None else None + running_var_ = running_var.repeat(b) if running_var is not None else None + input_reshaped = input.contiguous().view(new_shape) + grad_output_reshaped = grad_output.contiguous().view(new_shape) + mean = torch.mean(input_reshaped, (0,) + tuple(range(2, input.dim())), False) + 
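For the ``unfold3d`` helper above, the doctest in its docstring is skipped; the following sketch reproduces the same sliding-block extraction by hand with ``Tensor.unfold`` for the concrete (assumed) values ``kernel_size=2``, ``stride=1``, ``padding=0``, and checks the advertised output shape:

    import torch

    B, C, D, H, W = 3, 4, 5, 6, 7
    x = torch.arange(1., B * C * D * H * W + 1.).view(B, C, D, H, W)
    blocks = x.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)
    blocks = blocks.permute(0, 2, 3, 4, 1, 5, 6, 7)
    blocks = blocks.reshape(B, -1, C * 8).transpose(1, 2)
    print(blocks.shape)   # torch.Size([3, 32, 120]), i.e. (B, C * prod(kernel_size), L)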
var = torch.var(input_reshaped, (0,) + tuple(range(2, input.dim())), keepdim=False, unbiased=False) + rstd = 1 / torch.sqrt(var + eps) + + # must use native batch norm since it supports all inputs. This may have used cuda or openmi during the forward but + # it didn't save the metadata, so we don't know during the backward + res = torch.ops.aten.native_batch_norm_backward( + grad_output_reshaped, input_reshaped, weight_, running_mean_, running_var_, + mean, rstd, True, eps, (True, False, False)) + results.append(res[0].reshape(input.shape)) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable (2 are not saved from the forward) + results = results + [None] * 7 + + # set grad_sample field for weight and bias with per sample gradients + set_grad_sample_if_exists(weight, + lambda _: torch.einsum("ni...->ni", F.instance_norm(input, eps=eps) * grad_output)) + set_grad_sample_if_exists(bias, lambda _: torch.einsum("ni...->ni", grad_output)) + return tuple(results) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py new file mode 100644 index 0000000000000000000000000000000000000000..e73aada232abf7e0754319428abe7b8f88289bd9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py @@ -0,0 +1,758 @@ +import torch +from torch.nn.modules.container import ModuleList, ModuleDict, Module +from torch.nn.parameter import Parameter +from torch import Tensor + +import collections +import copyreg +from copy import deepcopy +from contextlib import contextmanager +from typing import Union, Optional, Dict, Tuple, Sequence + +__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations', + 'type_before_parametrizations', 'transfer_parametrizations_and_params'] + +_cache_enabled = 0 +_cache: Dict[Tuple[int, str], Optional[Tensor]] = {} + + +@contextmanager +def cached(): + r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`. + + The value of the parametrized objects is computed and cached the first time + they are required when this context manager is active. The cached values are + discarded when leaving the context manager. + + This is useful when using a parametrized parameter more than once in the forward pass. + An example of this is when parametrizing the recurrent kernel of an RNN or when + sharing weights. + + The simplest way to activate the cache is by wrapping the forward pass of the neural network + + .. code-block:: python + + import torch.nn.utils.parametrize as P + ... + with P.cached(): + output = model(inputs) + + in training and evaluation. One may also wrap the parts of the modules that use + several times the parametrized tensors. For example, the loop of an RNN with a + parametrized recurrent kernel: + + .. 
code-block:: python + + with P.cached(): + for x in xs: + out_rnn = self.rnn_cell(x, out_rnn) + """ + global _cache + global _cache_enabled + _cache_enabled += 1 + try: + yield + finally: + _cache_enabled -= 1 + if not _cache_enabled: + _cache = {} + + +def _register_parameter_or_buffer(module, name, X): + if isinstance(X, Parameter): + module.register_parameter(name, X) + else: + module.register_buffer(name, X) + + +class ParametrizationList(ModuleList): + r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`. + + It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` + has been parametrized with :func:`register_parametrization`. + + If the first registered parametrization has a ``right_inverse`` that returns one tensor or + does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), + it will hold the tensor under the name ``original``. + If it has a ``right_inverse`` that returns more than one tensor, these will be registered as + ``original0``, ``original1``, ... + + .. warning:: + This class is used internally by :func:`register_parametrization`. It is documented + here for completeness. It shall not be instantiated by the user. + + Args: + modules (sequence): sequence of modules representing the parametrizations + original (Parameter or Tensor): parameter or buffer that is parametrized + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + """ + + original: Tensor + unsafe: bool + + def __init__( + self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False + ) -> None: + # We require this because we need to treat differently the first parametrization + # This should never throw, unless this class is used from the outside + if len(modules) == 0: + raise ValueError("ParametrizationList requires one or more modules.") + + super().__init__(modules) + self.unsafe = unsafe + + # In plain words: + # module.weight must keep its dtype and shape. 
+ # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, + # this should be of the same dtype as the original tensor + # + # We check that the following invariants hold: + # X = module.weight + # Y = param.right_inverse(X) + # assert isinstance(Y, Tensor) or + # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) + # Z = param(Y) if isinstance(Y, Tensor) else param(*Y) + # # Consistency checks + # assert X.dtype == Z.dtype and X.shape == Z.shape + # # If it has one input, this allows to be able to use set_ to be able to + # # move data to/from the original tensor without changing its id (which is what the + # # optimizer uses to track parameters) + # if isinstance(Y, Tensor) + # assert X.dtype == Y.dtype + # Below we use original = X, new = Y + + original_shape = original.shape + original_dtype = original.dtype + + # Compute new + with torch.no_grad(): + new = original + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + try: + new = module.right_inverse(new) + except NotImplementedError: + pass + # else, or if it throws, we assume that right_inverse is the identity + + if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence): + raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " + f"Got {type(new).__name__}") + + # Set the number of original tensors + self.is_tensor = isinstance(new, Tensor) + self.ntensors = 1 if self.is_tensor else len(new) + + # Register the tensor(s) + if self.is_tensor: + if original.dtype != new.dtype: + raise ValueError( + "When `right_inverse` outputs one tensor, it may not change the dtype.\n" + f"original.dtype: {original.dtype}\n" + f"right_inverse(original).dtype: {new.dtype}" + ) + # Set the original to original so that the user does not need to re-register the parameter + # manually in the optimiser + with torch.no_grad(): + original.set_(new) # type: ignore[call-overload] + _register_parameter_or_buffer(self, "original", original) + else: + for i, originali in enumerate(new): + if not isinstance(originali, Tensor): + raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors " + "(list, tuple...). " + f"Got element {i} of the sequence with type {type(originali).__name__}.") + + # If the original tensor was a Parameter that required grad, we expect the user to + # add the new parameters to the optimizer after registering the parametrization + # (this is documented) + if isinstance(original, Parameter): + originali = Parameter(originali) + originali.requires_grad_(original.requires_grad) + _register_parameter_or_buffer(self, f"original{i}", originali) + + if not self.unsafe: + # Consistency checks: + # Since f : A -> B, right_inverse : B -> A, Z and original should live in B + # Z = forward(right_inverse(original)) + Z = self() + if not isinstance(Z, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(Z).__name__}." 
+ ) + if Z.dtype != original_dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized dtype: {original_dtype}\n" + f"parametrized dtype: {Z.dtype}" + ) + if Z.shape != original_shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" + f"unparametrized shape: {original_shape}\n" + f"parametrized shape: {Z.shape}" + ) + + def right_inverse(self, value: Tensor) -> None: + r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order. + + Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor + or in ``self.original0``, ``self.original1``, ... if it outputs several. + + Args: + value (Tensor): Value to which initialize the module + """ + # All the exceptions in this function should almost never throw. + # They could throw if, for example, right_inverse function returns a different + # dtype when given a different input, which should most likely be caused by a + # bug in the user's code + + with torch.no_grad(): + # See https://github.com/pytorch/pytorch/issues/53103 + for module in reversed(self): # type: ignore[call-overload] + if hasattr(module, "right_inverse"): + value = module.right_inverse(value) + else: + raise RuntimeError(f"parametrization {type(module).__name__} does not implement " + "right_inverse.") + if self.is_tensor: + # These exceptions should only throw when a right_inverse function does not + # return the same dtype for every input, which should most likely be caused by a bug + if not isinstance(value, Tensor): + raise ValueError( + f"`right_inverse` should return a tensor. Got {type(value).__name__}" + ) + if value.dtype != self.original.dtype: + raise ValueError( + f"The tensor returned by `right_inverse` has dtype {value.dtype} " + f"while `original` has dtype {self.original.dtype}" + ) + # We know that the result is going to have the same dtype + self.original.set_(value) # type: ignore[call-overload] + else: + if not isinstance(value, collections.abc.Sequence): + raise ValueError( + "'right_inverse' must return a sequence of tensors. " + f"Got {type(value).__name__}." + ) + if len(value) != self.ntensors: + raise ValueError( + "'right_inverse' must return a sequence of tensors of length " + f"{self.ntensors}. Got a sequence of length {len(value)}." + ) + for i, tensor in enumerate(value): + original_i = getattr(self, f"original{i}") + if not isinstance(tensor, Tensor): + raise ValueError( + f"`right_inverse` must return a sequence of tensors. 
" + f"Got element {i} of type {type(tensor).__name__}" + ) + if original_i.dtype != tensor.dtype: + raise ValueError( + f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " + f"while `original{i}` has dtype {original_i.dtype}" + ) + original_i.set_(tensor) + + def forward(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') + # Unpack the originals for the first parametrization + if self.is_tensor: + x = self[0](self.original) + else: + originals = (getattr(self, f"original{i}") for i in range(self.ntensors)) + x = self[0](*originals) + # It's not possible to call self[1:] here, so we have to be a bit more cryptic + # Also we want to skip all non-integer keys + curr_idx = 1 + while hasattr(self, str(curr_idx)): + x = self[curr_idx](x) + curr_idx += 1 + return x + + +def _inject_new_class(module: Module) -> None: + r"""Set up a module to be parametrized. + + This works by substituting the class of the module by a class + that extends it to be able to inject a property + + Args: + module (nn.Module): module into which to inject the property + """ + cls = module.__class__ + + def default_deepcopy(self, memo): + # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. + obj = memo.get(id(self), None) + if obj is not None: + return obj + replica = self.__new__(self.__class__) + memo[id(self)] = replica + replica.__dict__ = deepcopy(self.__dict__, memo) + # Also save all slots if they exist. + slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] + for slot in slots_to_save: + if hasattr(self, slot): + setattr(replica, slot, deepcopy(getattr(self, slot), memo)) + return replica + + def getstate(self): + raise RuntimeError( + "Serialization of parametrized modules is only " + "supported through state_dict(). See:\n" + "https://pytorch.org/tutorials/beginner/saving_loading_models.html" + "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" + ) + + dct = {"__getstate__": getstate} + # We don't allow serialization of parametrized modules but should still allow deepcopying. + # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. + if not hasattr(cls, "__deepcopy__"): + dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment] + + param_cls = type( + f"Parametrized{cls.__name__}", + (cls,), + dct, + ) + + module.__class__ = param_cls + + +def _inject_property(module: Module, tensor_name: str) -> None: + r"""Injects a property into module[tensor_name]. + + It assumes that the class in the module has already been modified from its + original one using _inject_new_class and that the tensor under :attr:`tensor_name` + has already been moved out + + Args: + module (nn.Module): module into which to inject the property + tensor_name (str): name of the name of the property to create + """ + # We check the precondition. 
+ # This should never fire if register_parametrization is correctly implemented + assert not hasattr(module, tensor_name) + + @torch.jit.unused + def get_cached_parametrization(parametrization) -> Tensor: + global _cache + key = (id(module), tensor_name) + tensor = _cache.get(key) + if tensor is None: + tensor = parametrization() + _cache[key] = tensor + return tensor + + def get_parametrized(self) -> Tensor: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') + parametrization = self.parametrizations[tensor_name] + if _cache_enabled: + if torch.jit.is_scripting(): + # Scripting + raise RuntimeError('Caching is not implemented for scripting. ' + 'Either disable caching or avoid scripting.') + elif torch._C._get_tracing_state() is not None: + # Tracing + raise RuntimeError('Cannot trace a model while caching parametrizations.') + else: + return get_cached_parametrization(parametrization) + else: + # If caching is not active, this function just evaluates the parametrization + return parametrization() + + def set_original(self, value: Tensor) -> None: + if torch.jit.is_scripting(): + raise RuntimeError('Parametrization is not working with scripting.') + self.parametrizations[tensor_name].right_inverse(value) + + setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) + +def register_parametrization( + module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False, +) -> Module: + r"""Register a parametrization to a tensor in a module. + + Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, + the module will return the parametrized version ``parametrization(module.weight)``. + If the original tensor requires a gradient, the backward pass will differentiate + through :attr:`parametrization`, and the optimizer will update the tensor accordingly. + + The first time that a module registers a parametrization, this function will add an attribute + ``parametrizations`` to the module of type :class:`~ParametrizationList`. + + The list of parametrizations on the tensor ``weight`` will be accessible under + ``module.parametrizations.weight``. + + The original tensor will be accessible under + ``module.parametrizations.weight.original``. + + Parametrizations may be concatenated by registering several parametrizations + on the same attribute. + + The training mode of a registered parametrization is updated on registration + to match the training mode of the host module + + Parametrized parameters and buffers have an inbuilt caching system that can be activated + using the context manager :func:`cached`. + + A :attr:`parametrization` may optionally implement a method with signature + + .. code-block:: python + + def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] + + This method is called on the unparametrized tensor when the first parametrization + is registered to compute the initial value of the original tensor. + If this method is not implemented, the original tensor will be just the unparametrized tensor. + + If all the parametrizations registered on a tensor implement `right_inverse` it is possible + to initialize a parametrized tensor by assigning to it, as shown in the example below. + + It is possible for the first parametrization to depend on several inputs. + This may be implemented returning a tuple of tensors from ``right_inverse`` + (see the example implementation of a ``RankOne`` parametrization below). 
+ + In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` + with names ``original0``, ``original1``,... + + .. note:: + + If unsafe=False (default) both the forward and right_inverse methods will be called + once to perform a number of consistency checks. + If unsafe=True, then right_inverse will be called if the tensor is not parametrized, + and nothing will be called otherwise. + + .. note:: + + In most situations, ``right_inverse`` will be a function such that + ``forward(right_inverse(X)) == X`` (see + `right inverse `_). + Sometimes, when the parametrization is not surjective, it may be reasonable + to relax this. + + .. warning:: + + If a parametrization depends on several inputs, :func:`~register_parametrization` + will register a number of new parameters. If such parametrization is registered + after the optimizer is created, these new parameters will need to be added manually + to the optimizer. See :meth:`torch.Optimizer.add_param_group`. + + Args: + module (nn.Module): module on which to register the parametrization + tensor_name (str): name of the parameter or buffer on which to register + the parametrization + parametrization (nn.Module): the parametrization to register + Keyword args: + unsafe (bool): a boolean flag that denotes whether the parametrization + may change the dtype and shape of the tensor. Default: `False` + Warning: the parametrization is not checked for consistency upon registration. + Enable this flag at your own risk. + + Raises: + ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` + + Examples: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> import torch + >>> import torch.nn as nn + >>> import torch.nn.utils.parametrize as P + >>> + >>> class Symmetric(nn.Module): + >>> def forward(self, X): + >>> return X.triu() + X.triu(1).T # Return a symmetric matrix + >>> + >>> def right_inverse(self, A): + >>> return A.triu() + >>> + >>> m = nn.Linear(5, 5) + >>> P.register_parametrization(m, "weight", Symmetric()) + >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric + True + >>> A = torch.rand(5, 5) + >>> A = A + A.T # A is now symmetric + >>> m.weight = A # Initialize the weight to be the symmetric matrix A + >>> print(torch.allclose(m.weight, A)) + True + + >>> class RankOne(nn.Module): + >>> def forward(self, x, y): + >>> # Form a rank 1 matrix multiplying two vectors + >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) + >>> + >>> def right_inverse(self, Z): + >>> # Project Z onto the rank 1 matrices + >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) + >>> # Return rescaled singular vectors + >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) + >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt + >>> + >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) + >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) + 1 + + """ + parametrization.train(module.training) + if is_parametrized(module, tensor_name): + # Correctness checks. + # If A is the space of tensors with shape and dtype equal to module.weight + # we check that parametrization.forward and parametrization.right_inverse are + # functions from A to A + if not unsafe: + Y = getattr(module, tensor_name) + X = parametrization(Y) + if not isinstance(X, Tensor): + raise ValueError( + f"A parametrization must return a tensor. Got {type(X).__name__}." 
+ ) + if X.dtype != Y.dtype: + raise ValueError( + "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"parametrization(module.{tensor_name}).dtype: {X.dtype}" + ) + if X.shape != Y.shape: + raise ValueError( + "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"parametrization(module.{tensor_name}).shape: {X.shape}" + ) + if hasattr(parametrization, "right_inverse"): + try: + Z = parametrization.right_inverse(X) # type: ignore[operator] + except NotImplementedError: + pass + else: + if not isinstance(Z, Tensor): + raise ValueError( + f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" + ) + if Z.dtype != Y.dtype: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same dtype " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.dtype: {Y.dtype}\n" + f"returned dtype: {Z.dtype}" + ) + if Z.shape != Y.shape: + raise ValueError( + "The tensor returned by parametrization.right_inverse must have the same shape " + f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" + f"module.{tensor_name}.shape: {Y.shape}\n" + f"returned shape: {Z.shape}" + ) + # else right_inverse is assumed to be the identity + + # add the new parametrization to the parametrization list + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name].append(parametrization) + # If unsafe was True in previous parametrization, keep it enabled + module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr] + elif tensor_name in module._buffers or tensor_name in module._parameters: + # Set the parametrization mechanism + # Fetch the original buffer or parameter + original = getattr(module, tensor_name) + # We create this early to check for possible errors + parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe) + # Delete the previous parameter or buffer + delattr(module, tensor_name) + # If this is the first parametrization registered on the module, + # we prepare the module to inject the property + if not is_parametrized(module): + # Change the class + _inject_new_class(module) + # Inject a ``ModuleDict`` into the instance under module.parametrizations + module.parametrizations = ModuleDict() + # Add a property into the class + _inject_property(module, tensor_name) + # Add a ParametrizationList + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + module.parametrizations[tensor_name] = parametrizations + else: + raise ValueError( + f"Module '{module}' does not have a parameter, a buffer, or a " + f"parametrized element with name '{tensor_name}'" + ) + return module + + +def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: + r"""Determine if a module has a parametrization. 
+ + Args: + module (nn.Module): module to query + tensor_name (str, optional): name of the parameter in the module + Default: ``None`` + Returns: + ``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`, + or if it has any parametrization when :attr:`tensor_name` is ``None``; + otherwise ``False`` + """ + parametrizations = getattr(module, "parametrizations", None) + if parametrizations is None or not isinstance(parametrizations, ModuleDict): + return False + if tensor_name is None: + # Check that there is at least one parametrized buffer or Parameter + return len(parametrizations) > 0 + else: + return tensor_name in parametrizations + +def remove_parametrizations( + module: Module, tensor_name: str, leave_parametrized: bool = True +) -> Module: + r"""Remove the parametrizations on a tensor in a module. + + - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to + its current output. In this case, the parametrization shall not change the ``dtype`` + of the tensor. + - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to + the unparametrised tensor in ``module.parametrizations[tensor_name].original``. + This is only possible when the parametrization depends on just one tensor. + + Args: + module (nn.Module): module from which remove the parametrization + tensor_name (str): name of the parametrization to be removed + leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. + Default: ``True`` + + Returns: + Module: module + + Raises: + ValueError: if ``module[tensor_name]`` is not parametrized + ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors + """ + if not is_parametrized(module, tensor_name): + raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}") + + # Fetch the original tensor + assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy + parametrizations = module.parametrizations[tensor_name] + if parametrizations.is_tensor: + original = parametrizations.original + if leave_parametrized: + with torch.no_grad(): + t = getattr(module, tensor_name) + # We know they have the same dtype because we have checked this when registering the + # parametrizations. As such, we can use set_ + # We do this so that the parameter does not to change the id() + # This way the user does not need to update the optimizer + with torch.no_grad(): + if type(original) is torch.Tensor: + original.set_(t) + else: + try: + original.set_(t) + except RuntimeError as e: + # TODO: Fix this for tensor subclasses that are parameters: + # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach(). + raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True " + "for a parameter that is an instance of a tensor subclass requires " + "set_() to be implemented correctly for the tensor subclass. 
Either " + "set leave_parametrized=False or provide a working implementation for " + "set_() in the tensor subclass.") from e + else: + if leave_parametrized: + # We cannot use no_grad because we need to know whether one or more + # original tensors required grad + t = getattr(module, tensor_name) + # We'll have to trust the user to add it to the optimizer + original = Parameter(t) if t.requires_grad else t + else: + raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor " + "that is parametrized in terms of a sequence of tensors.") + + # Delete the property that manages the parametrization + delattr(module.__class__, tensor_name) + # Delete the ParametrizationList + del module.parametrizations[tensor_name] + + # Restore the parameter / buffer into the main class + _register_parameter_or_buffer(module, tensor_name, original) + + # Roll back the parametrized class if no other buffer or parameter + # is currently parametrized in this class + if not is_parametrized(module): + delattr(module, "parametrizations") + # Restore class + orig_cls = module.__class__.__bases__[0] + module.__class__ = orig_cls + return module + +def type_before_parametrizations(module: Module) -> type: + r"""Return the module type before parametrizations were applied and if not, then it returns the module type. + + Args: + module (nn.Module): module to get type of + """ + if is_parametrized(module): + return module.__class__.__bases__[0] + else: + return type(module) + +def transfer_parametrizations_and_params( + from_module: Module, to_module: Module, tensor_name: Optional[str] = None +) -> Module: + r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`. + + If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise + transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them. + Does nothing if from_module is not parametrized. 
+ + Args: + from_module (nn.Module): module to transfer from + to_module (nn.Module): module to transfer to + tensor_name (str, optional): parameter to transfer + + Returns: + Module: to_module + """ + if is_parametrized(from_module): + assert isinstance(from_module.parametrizations, ModuleDict) # for mypy + + # get list of all params or the single param to transfer + parameters_to_transfer: Union[list, ModuleDict] = ( + from_module.parametrizations if tensor_name is None else [tensor_name] + ) + + assert hasattr(parameters_to_transfer, "__iter__") # for mypy + for parameter_name in parameters_to_transfer: + + # initialize the to-be-transferred param in to_module if it doesn't exist already + if not hasattr(to_module, parameter_name): + setattr( + to_module, + parameter_name, + Parameter(getattr(from_module, parameter_name)), + ) + + # apply the params's parametrizations to to_module + for param_func in from_module.parametrizations[parameter_name]: + register_parametrization(to_module, parameter_name, param_func) + assert isinstance(to_module.parametrizations, ModuleDict) # for mypy + + # make values match, original values can be stored in either original or + # original0, original1..., need to check both cases + if hasattr(from_module.parametrizations[parameter_name], "original"): + to_module.parametrizations[parameter_name].original = \ + from_module.parametrizations[parameter_name].original + else: + num = 0 + orig_num = "original" + str(num) + # loop through each original# until all values have been set + while hasattr(from_module.parametrizations[parameter_name], orig_num): + setattr( + to_module.parametrizations[parameter_name], + orig_num, + getattr(from_module.parametrizations[parameter_name], orig_num), + ) + num = num + 1 + orig_num = "original" + str(num) + + return to_module
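A minimal end-to-end sketch tying together the public entry points defined in this module (``register_parametrization``, ``cached``, ``transfer_parametrizations_and_params``, ``remove_parametrizations``); the ``Symmetric`` parametrization mirrors the example from the ``register_parametrization`` docstring:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as P

    class Symmetric(nn.Module):
        def forward(self, X):
            return X.triu() + X.triu(1).T   # always produce a symmetric matrix
        def right_inverse(self, A):
            return A.triu()                 # store only the upper triangle

    m = nn.Linear(5, 5)
    P.register_parametrization(m, "weight", Symmetric())

    with P.cached():                        # reuse the computed weight within the block
        y = m(torch.randn(2, 5)) + m(torch.randn(2, 5))

    # copy the parametrization and its original tensor onto another module
    m2 = nn.Linear(5, 5)
    P.transfer_parametrizations_and_params(m, m2)

    # bake the current parametrized value back into a plain parameter on m
    P.remove_parametrizations(m, "weight", leave_parametrized=True)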