diff --git a/.gitattributes b/.gitattributes index 5db4a09bb9785cb413e7c47dec5bdab958b01a30..45f721353c93ea457c3a59e85d51f5c414156235 100644 --- a/.gitattributes +++ b/.gitattributes @@ -69,3 +69,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distl tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..86e4dcdf199df2edb2a6ea62ce98a3110f865f03 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8597f1985804f6c0c55b84d29a8744f0e2bc6600aaa695402499fbbbcba1decc +size 374848 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so new file mode 100644 index 0000000000000000000000000000000000000000..700f62ffc626c49da58fd1468260c8e2318e49c8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38073c63ab8f022926f58f7cb39c565005f382bdfacd85822e7502a5256b6671 +size 1509528 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5506b32c39f7d64545b4f5475475eb325114d2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1e41fea31e2f114e2b8bb3065092e62588a33b909a8fa70bc578e734128e529 +size 176864 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py new file mode 100644 index 0000000000000000000000000000000000000000..a770a58d30b31ed44439db4ff843105f0565bdc7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py @@ -0,0 +1,976 @@ +import dataclasses +import functools +import inspect +import sys +import typing +import weakref + +from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy + +import torch +import torch._C as _C +import torch.library as library +from torch._library.abstract_impl import AbstractImplCtx +from torch.library import get_ctx + +from .autograd import 
autograd_kernel_indirection, construct_autograd_kernel + +""" +For a detailed guide on custom ops, please see +https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + +This file includes pieces of the implementation of our custom operator API. +""" + +__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"] + + +SUPPORTED_DEVICE_TYPE_TO_KEY = { + "cpu": "CPU", + "cuda": "CUDA", +} + +# We will not let users register CustomOps with anything that could look like +# PyTorch internals to avoid confusion. +RESERVED_NS = { + "prim", + "prims", + "aten", + "at", + "torch", + "pytorch", +} + + +def custom_op( + qualname: str, manual_schema: typing.Optional[str] = None +) -> typing.Callable: + r"""Creates a new CustomOp object. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + In PyTorch, defining an op (short for "operator") is a two step-process: + - we need to define (create) the op + - we need to implement behavior for how the operator interacts with + various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. + + This entrypoint defines the CustomOp object (the first step); + you must then perform the second step by calling various methods on + the CustomOp object. + + This API is used as a decorator (see examples). + + Arguments: + qualname (str): Should be a string that looks like + "namespace::operator_name". Operators in PyTorch need a namespace to + avoid name collisions; a given operator may only be created once. + If you are writing a Python library, we recommend the namespace to + be the name of your top-level module. The operator_name must be + the same as the name of the function you pass to custom_op + (see examples). 
+ manual_schema (Optional[str]): Each PyTorch operator needs a schema that + tells PyTorch the types of the inputs/outputs. If None (default), + we will infer the schema from the type annotations on the function + (see examples). Otherwise, if you don't want to use type annotations, + you may provide us the schema string. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Step 1: define the CustomOp. + >>> # We need to provide the decorator a "prototype function" + >>> # (a function with Python ellipses as the body). + >>> @custom_op("my_library::numpy_sin") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> ... + >>> + >>> # numpy_sin is now an instance of class CustomOp + >>> print(type(numpy_sin)) + >>> + >>> # Step 2: Register an implementation for various PyTorch subsystems + >>> + >>> # Register an implementation for CPU tensors + >>> @numpy_sin.impl('cpu') + >>> def numpy_sin_impl_cpu(x): + >>> return torch.from_numpy(np.sin(x.numpy())) + >>> + >>> # Register an implementation for CUDA tensors + >>> @numpy_sin.impl('cuda') + >>> def numpy_sin_impl_cuda(x): + >>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> numpy_sin(x) # calls numpy_sin_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> numpy_sin(x) # calls numpy_sin_impl_cuda + + """ + + def inner(func): + if not inspect.isfunction(func): + raise ValueError( + f"custom_op(...)(func): Expected `func` to be a Python " + f"function, got: {type(func)}" + ) + + ns, name = parse_qualname(qualname) + validate_namespace(ns) + if func.__name__ != name: + raise ValueError( + f"custom_op(qualname='{qualname}', ...)(func): expected `func` " + f"to have name '{name}' but got '{func.__name__}'. 
" + f"Please either change the name of `func` or the qualname that " + f"is passed to `custom_op`" + ) + + schema = infer_schema(func) if manual_schema is None else manual_schema + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + if manual_schema is not None: + validate_function_matches_schema(function_schema, func) + + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + + result.__name__ = func.__name__ + result.__module__ = func.__module__ + result.__doc__ = func.__doc__ + + library.impl(lib, result._opname, "Autograd")( + autograd_kernel_indirection(weakref.proxy(result)) + ) + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + + return result + + return inner + + +# Global dictionary holding references to all CustomOp objects +# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime]) +# Used to query the CustomOp associated with a specific C++ dispatcher operator. +# An example usage is FakeTensor: FakeTensor checks if a specific operator +# has an implementation registered via the CustomOp API. +# Indexed by qualname (e.g. aten::foo) +global_registry: typing.Dict[str, "CustomOp"] = {} + + +class CustomOp: + r"""Class for custom operators in PyTorch. + + Use the CustomOp API to create user-defined custom operators that behave + just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it + comes to various PyTorch subsystems (like torch.compile). + + To construct a `CustomOp`, use `custom_op`. + """ + + def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False): + super().__init__() + if not _private_access: + raise RuntimeError( + "The CustomOp constructor is private and we do not guarantee " + "BC for it. 
Please use custom_op(...) to create a CustomOp object" + ) + name = f"{cpp_ns}::{operator_name}" + self._schema = schema + self._cpp_ns = cpp_ns + self._lib: library.Library = lib + self._ophandle: _C._DispatchOperatorHandle = ophandle + # Has the name of the op, e.g. "foo". We cache here for convenience. + self._opname: str = operator_name + # this is _opname but with namespace. e.g. "custom::foo" + self._qualname: str = name + self.__name__ = None # mypy requires this + # NB: Some of these impls are registered as kernels to DispatchKeys. + # Modifying the _impls dict directly won't do anything in that case. + self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {} + # See NOTE [CustomOp autograd kernel indirection] + self._registered_autograd_kernel_indirection = False + + global_registry[self._qualname] = self + + def _register_autograd_kernel_indirection(self): + assert not self._registered_autograd_kernel_indirection + self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd") + self._registered_autograd_kernel_indirection = True + + # Records the impl and the source location in self._impls + # Note that this doesn't cause torch.library to use the impl, that + # needs to be done in a separate self._lib.impl call. + def _register_impl(self, kind, func, stacklevel=2): + if self._has_impl(kind): + func_and_location = self._impls[kind] + assert func_and_location is not None # Pacify mypy + location = func_and_location.location + raise RuntimeError( + f"Attempting to register a {kind} impl for operator {self._qualname} " + f"that already has a {kind} impl registered from Python at " + f"{location}. This is not supported." 
+ ) + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + location = f"{frame.filename}:{frame.lineno}" + self._impls[kind] = FuncAndLocation(func, location) + + def _get_impl(self, kind): + return self._impls[kind] + + def _has_impl(self, kind): + return kind in self._impls + + def _destroy(self): + # NOTE: [CustomOp lifetime] + # A CustomOp, once created, lives forever. The mechanism is that the + # global registry holds a reference to it. However, to make testing + # easier, we want to be able to destroy CustomOp objects. + # CustomOp._destroy does the job, though it leaves the CustomOp + # in a garbage state. + del self._lib + + opnamespace = getattr(torch.ops, self._cpp_ns) + if hasattr(opnamespace, self._opname): + delattr(opnamespace, self._opname) + + del global_registry[self._qualname] + + def __repr__(self): + return f'' + + def __call__(self, *args, **kwargs): + # Bypass torch.ops.* and directly do OperatorHandle::callBoxed. + # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime + # issues from caching operators that make testing CustomOp difficult). + result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs) + return result + + def impl( + self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2, + ) -> typing.Callable: + r"""Register an implementation for a device type for this CustomOp object. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + If the CustomOp is passed multiple Tensor inputs with different device + types, it will dispatch to the registered implementation for the highest + priority device type among those present. + The supported device types, in order of priority, are {'cuda', 'cpu'}. + + This API is used as a decorator (see examples). 
+ + Arguments: + device_types (str or Iterable[str]): the device type(s) to register the function for. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @custom_op("my_library::numpy_cos") + >>> def numpy_cos(x: Tensor) -> Tensor: + >>> ... + >>> + >>> # Register an implementation for CPU Tensors + >>> @numpy_cos.impl('cpu') + >>> def numpy_cos_impl_cpu(x): + >>> return torch.from_numpy(np.cos(x.numpy())) + >>> + >>> # Register an implementation for CUDA Tensors + >>> @numpy_cos.impl('cuda') + >>> def numpy_cos_impl_cuda(x): + >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device) + >>> + >>> x = torch.randn(3) + >>> numpy_cos(x) # calls numpy_cos_impl_cpu + >>> + >>> x_cuda = x.cuda() + >>> numpy_cos(x) # calls numpy_cos_impl_cuda + + """ + if isinstance(device_types, str): + device_types = [device_types] + for device_type in device_types: + validate_device_type(device_type) + + def inner(f): + for device_type in set(device_types): + self._check_doesnt_have_library_impl(device_type) + self._register_impl(device_type, f, stacklevel=_stacklevel) + dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + library.impl(self._lib, self._opname, dispatch_key)(f) + return f + + return inner + + def _check_doesnt_have_library_impl(self, device_type): + if self._has_impl(device_type): + return + key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type] + if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl(..., device_types={device_type}): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing torch.library or TORCH_LIBRARY registration.") + + def impl_factory(self) -> typing.Callable: + r"""Register an implementation for a factory function.""" + + def inner(f): + self._register_impl("factory", f) + library.impl(self._lib, self._opname, "BackendSelect")(f) + return f + + return 
inner + + def impl_abstract(self, _stacklevel=2) -> typing.Callable: + r"""Register an abstract implementation for this operator. + + WARNING: please do not use this directly (and instead use the torch._custom_ops + APIs). Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + An "abstract implementation" specifies the behavior of this operator on + Tensors that carry no data. Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + The abstract implementation has the same signature as the operator. + It is run for both FakeTensors and meta tensors. To write an abstract + implementation, assume that all Tensor inputs to the operator are + regular CPU/CUDA/Meta tensors, but they do not have storage, and + you are trying to return regular CPU/CUDA/Meta tensor(s) as output. + The abstract implementation must consist of only PyTorch operations + (and may not directly access the storage or data of any input or + intermediate Tensors). + + This API is used as a decorator (see examples). + + Examples:: + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @custom_op('my_library::custom_linear') + >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> ... + >>> + >>> @custom_linear.impl_abstract() + >>> def custom_linear_abstract(x, weight): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> + >>> return (x @ weight.t()) + bias + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @custom_op('my_library::custom_nonzero') + >>> def custom_nonzero(x: Tensor) -> Tensor: + >>> ... 
+ >>> + >>> @custom_nonzero.impl_abstract() + >>> def custom_nonzero_abstract(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch._custom_op.get_ctx() + >>> nnz = ctx.create_unbacked_symint() + >>> shape = [x.dim(), nnz] + >>> result = x.new_empty(shape, dtype=torch.long) + >>> return result + >>> + >>> @custom_nonzero.impl(['cpu', 'cuda']) + >>> def custom_nonzero_impl(x): + >>> x_np = to_numpy(x) + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> # unbacked symbolic ints in PyTorch must be >= 2, so we + >>> # constrain the range to at least 2 + >>> if res.shape[0] <= 1: + >>> raise RuntimeError("not supported") + >>> return torch.tensor(res, device=x.device) + + """ + + def inner(f): + self._check_doesnt_have_library_meta_impl() + self._register_impl("abstract", f, stacklevel=_stacklevel) + location = self._get_impl("abstract").location + + qualname = self._qualname + + # Handle DispatchKey.Meta registration + @functools.wraps(f) + def f_with_ctx(*args, **kwargs): + def error_on_ctx(): + raise RuntimeError( + f"Attempted to call get_ctx() for the meta implementation " + f"for {qualname}." + f"You have presumably called get_ctx() because the operator " + f"has a data-dependent output shape; if so, there is no " + f"such meta implementation and this error is the correct " + f"behavior. Otherwise, please remove the call to get_ctx() " + f"in the implementation registered with impl_abstract " + f"at {location}" + ) + + with torch._library.abstract_impl.set_ctx_getter(error_on_ctx): + return f(*args, **kwargs) + + self._lib.impl(self._opname, f_with_ctx, "Meta") + return f + + return inner + + def _check_can_register_backward(self): + def error(detail): + raise RuntimeError( + f"Cannot use torch._custom_ops APIs to register backward " + f"formula for {detail}. 
Got operator " + f"{self._qualname} with schema: {schema}" + ) + + schema = self._schema + if schema.kind() != SchemaKind.functional: + error("non-functional operator") + + rets = schema.returns + if not schema.returns: + error("operator with no returns") + + assert len(rets) > 0 + is_non_mutating_view = any( + r.annotation is not None and not r.annotation.is_write for r in rets + ) + if is_non_mutating_view: + error("operator that returns views") + + # We make assumptions about the schema's return types. + allowed_return_types = { + BaseType(BaseTy.int): "int", + BaseType(BaseTy.SymInt): "SymInt", + BaseType(BaseTy.bool): "bool", + BaseType(BaseTy.float): "float", + BaseType(BaseTy.Tensor): "Tensor", + ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]", + } + for ret in schema.returns: + if ret.type in allowed_return_types: + continue + error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})") + + def _check_doesnt_have_library_autograd_impl(self): + if self._registered_autograd_kernel_indirection: + return + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an autograd formula; " + f"instead, the operator will decompose into its constituents and those " + f"can have autograd formulas defined on them.") + + # We can improve this by adding "all Autograd keys", but + # realistically people will just be using this API for CPU/CUDA for now. 
+ for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]: + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key): + raise RuntimeError( + f"impl_backward/impl_save_for_backward: " + f"the operator {self._qualname} already has an Autograd kernel " + f"registered to DispatchKey::{key} vi a pre-existing " + f"torch.library or TORCH_LIBRARY registration. Please either " + f"remove those registrations or don't use the torch._custom_ops APIs") + + def _check_doesnt_have_library_meta_impl(self): + if self._has_impl("abstract"): + return + + # If the user's operator is CompositeExplicitAutograd, + # allow them to impl_abstract. This is being pragmatic + # (existing custom ops may have CompositeExplicitAutograd + # registration that don't work with Meta kernels, so this + # gives them an escape hatch). + if ( + _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd") + and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta") + ): + return + + # Otherwise, if the user's already has a Meta kernel or their + # op is CompositeImplicitAutograd or some other alias dispatch key, + # raise. + + # Special case for CompositeImplicitAutograd + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an abstract impl; " + f"instead, the operator will decompose into its constituents and those " + f"can have abstract impls defined on them.") + + if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"): + raise RuntimeError( + f"impl_abstract(...): the operator {self._qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. 
" + f"Please either remove that registration or don't call impl_abstract.") + + # NOTE ["backward", "save_for_backward", and "autograd"] + # As a part of the explicit autograd API, a user must provide us + # a "save_for_backward" function and a "backward" function. + # When both of these have been provided, then we automatically + # construct the "autograd" kernel. + def _register_autograd_kernel(self): + assert self._has_impl("backward") + assert self._has_impl("save_for_backward") + kernel = construct_autograd_kernel( + self._schema, + self._output_differentiability, + self, + get_op(self._qualname), + self._get_impl("save_for_backward").func, + self._get_impl("backward").func) + self._register_impl("autograd", kernel) + + def impl_save_for_backward(self, _stacklevel=2): + r"""Register a function that tells us what to save for backward. + + Please see impl_backward for more details. + """ + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("save_for_backward", f, stacklevel=_stacklevel) + if self._has_impl("backward"): + self._register_autograd_kernel() + return inner + + def impl_backward(self, output_differentiability=None, _stacklevel=2): + r"""Registers a backward formula. + + WARNING: if you're a user, please do not use this directly + (instead use the torch._custom_ops APIs). + Also please see the following for a detailed guide on custom ops. + https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk + + In order for the CustomOp to work with autograd, you need to register + a backward formula. There are two pieces to this: + 1. You must give us a function to specify what to save for backward. + Call this the "save for backward" function. + 2. You must give us a function that computes gradients. Call this the + "backward" function. 
+ + Use `impl_save_for_backward` to define a "save for backward" function + that specifies what gets saved for backward. The function should accept + two arguments ``(inputs, output)`` and return the quantities to be saved + for backward. + + During runtime, when you call the CustomOp, PyTorch will invoke the + "save for backward" function with the inputs and output of the CustomOp. + + Use `impl_backward` to define the "backward" function. The backward + function must accept ``(ctx, saved, *grads)``: + - ``ctx`` is a context object where we may provide information + - ``saved`` is exactly what gets returned from the "save for backward" + function + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the CustomOp. + + The backward function must return a dict that maps the name of + an input to the CustomOp to its corresponding gradient. All inputs that + were declared to be Tensors in the CustomOp definition must be accounted + for in the dict. The gradient may be a Tensor or None. 
+ + """ + if output_differentiability is not None: + def yell(): + raise RuntimeError( + f"impl_backward(output_differentiability): expected " + f"output_differentiability to be a list of bools with " + f"length equal to the number of outputs of this CustomOp " + f"got: {output_differentiability}") + + if not isinstance(output_differentiability, list): + yell() + for diff in output_differentiability: + if not isinstance(diff, bool): + yell() + if len(self._schema.returns) != len(output_differentiability): + yell() + + def inner(f): + self._check_can_register_backward() + self._check_doesnt_have_library_autograd_impl() + if not self._registered_autograd_kernel_indirection: + self._register_autograd_kernel_indirection() + self._register_impl("backward", f, stacklevel=_stacklevel) + self._output_differentiability = output_differentiability + if self._has_impl("save_for_backward"): + self._register_autograd_kernel() + return inner + + +@dataclasses.dataclass +class FuncAndLocation: + func: typing.Callable + location: str + + +def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName): + overload_name = ( + "" if operator_name.overload_name is None else operator_name.overload_name + ) + return _C._dispatch_find_schema_or_throw( + f"{cpp_ns}::{str(operator_name.name)}", overload_name + ) + + +def validate_namespace(ns: str) -> None: + if "." in ns: + raise ValueError( + f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a ' + f"valid variable name)" + ) + if ns in RESERVED_NS: + raise ValueError( + f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, " + f"please choose something else. " + ) + +def validate_schema(schema: FunctionSchema) -> None: + if not torch._library.utils.is_functional_schema(schema): + raise ValueError( + f"custom_op only supports functional operators " + f"(ops that do not mutate any inputs, do not return " + f"views of the inputs, and has at least one return). 
" + f"Got the following non-functional schema: {schema}" + ) + + # For simplicity: don't allow self arguments + if schema.arguments.self_arg is not None: + raise ValueError( + f"custom_op does not support arguments named 'self'. Please " + f"rename your argument. Got: {schema}" + ) + + +def parse_qualname(qualname: str) -> typing.Tuple[str, str]: + names = qualname.split("::", 1) + if len(names) != 2: + raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The " + f"operator name should look something like ns::foo") + if '.' in names[1]: + raise ValueError(f"The torch.custom_ops APIs do not handle overloads, " + f"i.e. operator names with '.' in them. " + f"Please name your operator something like ns::foo. " + f"Got: {qualname}") + return names[0], names[1] + + +def validate_device_type(device_type: str) -> None: + if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY: + raise ValueError( + f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type " + f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}." + ) + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def validate_function_matches_schema( + schema: FunctionSchema, func: typing.Callable +) -> None: + sig = inspect.signature(func) + + if not all(supported_param(p) for _, p in sig.parameters.items()): + raise ValueError( + f"custom_op(..., manual_schema)(func): positional-only args, " + f"varargs, and kwargs are not supported. Please rewrite `func` " + f"to not have them. Got `func` with signature: {sig}" + ) + + if ( + any( + p.annotation is not inspect.Parameter.empty + for _, p in sig.parameters.items() + ) + or sig.return_annotation is not inspect.Signature.empty + ): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func` to have no type annotations to avoid " + f"ambiguity. 
Got `func` with signature: {sig}" + ) + + positional = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ] + kwargonly = [ + (name, param) + for name, param in sig.parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + def error(): + raise ValueError( + f"custom_op(..., manual_schema)(func): When passing in a manual " + f"schema, we expect `func`'s signature to match `manual_schema` " + f"(aside from type annotations). " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def error_default_args(): + raise ValueError( + f"custom_op(..., manual_schema)(func): " + f"neither func nor manual_schema should have default " + f"arguments. Got " + f"func's signature: {sig}, manual_schema: {schema}" + ) + + def compare(sig_args, schema_args): + if len(sig_args) != len(schema_args): + error() + for (name, param), arg in zip(sig_args, schema_args): + if name != arg.name: + error() + if param.default is not inspect.Parameter.empty or arg.default is not None: + error_default_args() + + compare(positional, schema.arguments.flat_positional) + compare(kwargonly, schema.arguments.flat_kwarg_only) + + +def infer_schema(prototype_function: typing.Callable) -> str: + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError( + f"custom_op(...)(func): {what} " f"Got func with signature {sig})" + ) + + params = [ + parse_param(name, param, error_fn) for name, param in sig.parameters.items() + ] + ret = parse_return(sig.return_annotation, error_fn) + return f"({', '.join(params)}) -> {ret}" + + +def parse_param(name, param, error_fn): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + if param.annotation not in SUPPORTED_PARAM_TYPES.keys(): + error_fn( + f"Parameter {name} has 
unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + + if param.default is not inspect.Parameter.empty: + error_fn( + f"Parameter {name} has a default value; this is not supported. " + f"If you want to use default values then create a function with " + f"default values that calls the CustomOp" + ) + + return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}" + + +def derived_types( + base_type, cpp_type, list_base, optional_base_list, optional_list_base +): + result = [ + (base_type, cpp_type), + (typing.Optional[base_type], f"{cpp_type}?"), + ] + if list_base: + result.append((typing.Sequence[base_type], f"{cpp_type}[]")) # type: ignore[valid-type] + if optional_base_list: + result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")) # type: ignore[valid-type] + if optional_list_base: + result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")) # type: ignore[valid-type] + return result + + +def get_supported_param_types(): + data = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? variant + (torch.Tensor, "Tensor", True, True, False), + (int, "SymInt", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (torch.types.Number, "Scalar", True, False, False), + (torch.dtype, "ScalarType", False, False, False), + (torch.device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + torch.Tensor: "Tensor", + typing.List[torch.Tensor]: "Tensor[]", + int: "SymInt", + float: "float", + bool: "bool", + torch.types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES.keys(): + error_fn( + f"Return has unsupported type {annotation}. 
" + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + + return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def report_error_callback(custom_op: typing.Any, key: str) -> None: + if key == "Undefined": + raise NotImplementedError( + f"{custom_op}: There were no Tensor inputs to this operator " + f"(e.g. you passed an empty list of Tensors). If your operator is a " + f"factory function (that is, it takes no Tensors and constructs " + f"a new one), then please use CustomOp.impl_factory to register " + f"an implementation for it" + ) + if key == "Meta": + raise NotImplementedError( + f"{custom_op}: when running with device='Meta' tensors: there is no " + f"abstract impl registered for this CustomOp. Please register one via " + f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors" + ) + if key in ("CPU", "CUDA"): + device = key.lower() + raise NotImplementedError( + f"{custom_op}: when running with device='{device}' tensors: there is no " + f"{device} impl registered for this CustomOp. Please register one via " + f"CustomOp.impl(device_type='{device}')" + ) + raise NotImplementedError( + f"{custom_op}: No implementation for dispatch key {key}. 
It is likely " + f"that we have not added this functionality yet, please either open an " + f"issue or if you're feeling adventurous, use the low-level " + f"torch.library API" + ) + + +def custom_op_from_existing(op): + ns = op.namespace + lib = torch.library.Library(ns, "FRAGMENT") + name = op.name().split("::")[-1] + schema_str = str(op._schema) + # CustomOp expects the schema string without the namespace + schema_str = schema_str.split("::")[-1] + schema = FunctionSchema.parse(schema_str) + return CustomOp(lib, ns, schema, name, op, _private_access=True) + + +def get_op(qualname): + def error_not_found(): + raise ValueError( + f"Could not find the operator {qualname}. Please make sure you have " + f"already registered the operator and (if registered from C++) " + f"loaded it via torch.ops.load_library.") + + ns, name = parse_qualname(qualname) + if not hasattr(torch.ops, ns): + error_not_found() + opnamespace = getattr(torch.ops, ns) + if not hasattr(opnamespace, name): + error_not_found() + packet = getattr(opnamespace, name) + if not hasattr(packet, 'default'): + error_not_found() + return packet.default + + +def _find_custom_op(qualname, also_check_torch_library=False): + if qualname in global_registry: + return global_registry[qualname] + if not also_check_torch_library: + raise RuntimeError( + f"Could not find custom op \"{qualname}\". 
Did you register it via " + f"the torch._custom_ops API?") + overload = get_op(qualname) + result = custom_op_from_existing(overload) + return result + + +def get_abstract_impl(qualname): + if qualname not in torch._custom_op.impl.global_registry: + return None + custom_op = torch._custom_op.impl.global_registry[qualname] + if custom_op is None: + return None + if not custom_op._has_impl("abstract"): + return None + return custom_op._get_impl("abstract").func + + +def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True): + ns, name = qualname.split("::") + schema_str = f"{name}{schema}" + function_schema = FunctionSchema.parse(schema_str) + validate_schema(function_schema) + tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else [] + lib = library.Library(ns, "FRAGMENT") + lib.define(schema_str, tags=tags) + ophandle = find_ophandle_or_throw(ns, function_schema.name) + result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True) + result._register_autograd_kernel_indirection() + + torch._C._dispatch_set_report_error_callback( + ophandle, functools.partial(report_error_callback, weakref.proxy(result)) + ) + return get_op(qualname) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..06dda2d3accb766cec40ee3a2e84a6cafe2e37c4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py @@ -0,0 +1,590 @@ +import math + +from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +from torch._decomp import register_decomposition +from torch._prims_common import DimsType, ShapeType, TensorLikeType +from torch._prims_common.wrappers import _maybe_convert_to_dtype, 
out_wrapper + +__all__ = [ + # Transforms + "fft", + "fft2", + "fftn", + "hfft", + "hfft2", + "hfftn", + "rfft", + "rfft2", + "rfftn", + "ifft", + "ifft2", + "ifftn", + "ihfft", + "ihfft2", + "ihfftn", + "irfft", + "irfft2", + "irfftn", + # Helpers + "fftshift", + "ifftshift", +] + +NormType = Union[None, Literal["forward", "backward", "ortho"]] +_NORM_VALUES = {None, "forward", "backward", "ortho"} +aten = torch._ops.ops.aten + + +def _apply_norm( + x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool +) -> TensorLikeType: + """Apply normalization to the un-normalized FFT result""" + torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + + if norm == "ortho": + return x * (1 / math.sqrt(signal_numel)) + + normalize = (not forward and (norm is None or norm == "backward")) or ( + forward and norm == "forward" + ) + return x * (1 / signal_numel) if normalize else x + + +def _promote_type_fft( + dtype: torch.dtype, require_complex: bool, device: torch.device +) -> torch.dtype: + """Helper to promote a dtype to one supported by the FFT primitives""" + if dtype.is_complex: + return dtype + + # Promote integral to default float type + if not dtype.is_floating_point: + dtype = torch.get_default_dtype() + + allowed_types = [torch.float32, torch.float64] + maybe_support_half = device.type in ["cuda", "meta"] + + if maybe_support_half: + allowed_types.append(torch.float16) + torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") + + if require_complex: + dtype = utils.corresponding_complex_dtype(dtype) + + return dtype + + +def _maybe_promote_tensor_fft( + t: TensorLikeType, require_complex: bool = False +) -> TensorLikeType: + """Helper to promote a tensor to a dtype supported by the FFT primitives""" + cur_type = t.dtype + new_type = _promote_type_fft(cur_type, require_complex, t.device) + return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] + + +def _resize_fft_input( + x: TensorLikeType, 
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the complex-to-real transforms (irfft / hfft).

    When ``n`` is None the output length follows the c2r convention
    n_out = 2 * (n_in - 1); a forward transform conjugates the input first.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    if n is not None:
        last_dim_size = n
    else:
        last_dim_size = 2 * (input.shape[dim] - 1)
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        # Only the Hermitian half (floor(n/2) + 1 entries) is consumed.
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Shared implementation of the complex-to-complex transforms (fft / ifft).

    Requires a complex input; ``n`` (when given) pads or truncates the
    transformed axis before the transform.
    """
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = input.shape[dim] if n is None else n
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    transformed = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(transformed, norm, dim_size, forward)
None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) + + +@register_decomposition(aten.fft_irfft) +@out_wrapper() +def irfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("irfft", input, n, dim, norm, forward=False) + + +@register_decomposition(aten.fft_hfft) +@out_wrapper() +def hfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("hfft", input, n, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ihfft) +@out_wrapper() +def ihfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) + + +class _ShapeAndDims(NamedTuple): + shape: Tuple[int, ...] + dims: Tuple[int, ...] + + +def _canonicalize_fft_shape_and_dim_args( + input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] +) -> _ShapeAndDims: + """Convert the shape and dim arguments into a canonical form where neither are optional""" + input_dim = input.ndim + input_sizes = input.shape + + if dim is not None: + if not isinstance(dim, Sequence): + dim = (dim,) + ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) + + # Check dims are unique + torch._check( + len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique" + ) + + if shape is not None: + if not isinstance(shape, Sequence): + shape = (shape,) + + # Has shape, might have dim + torch._check( + dim is None or len(dim) == len(shape), + lambda: "When given, dim and shape arguments must have the same length", + ) + transform_ndim = len(shape) + + torch._check( + transform_ndim <= input_dim, + lambda: f"Got shape with {transform_ndim} values but input tensor " + f"only has {input_dim} dimensions.", + 
def _prod(xs: Iterable[int]) -> int:
    """Return the product of the integers in ``xs``.

    Returns 1 for an empty iterable, matching the multiplicative identity.
    Uses the C-accelerated ``math.prod`` instead of a hand-rolled loop
    (``math`` is already imported at the top of this module).
    """
    return math.prod(xs)
None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) + + +@register_decomposition(aten.fft_rfftn) +@out_wrapper() +def rfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_r2c(input, dim=dim, onesided=True) + return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) + + +@register_decomposition(aten.fft_ihfftn) +@out_wrapper() +def ihfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) + + if len(dim) == 1: + tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) + return prims.conj(tmp) + + tmp = prims.conj_physical(tmp) + tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) + return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) + + +class _CanonicalizeC2rReturn(NamedTuple): + shape: Tuple[int, ...] + dim: Tuple[int, ...] 
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape/dim for n-dimensional c2r transforms and compute
    ``last_dim_size``, the output length along the final transformed axis."""
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    # Output length of the last axis: taken from s when given, otherwise
    # reconstructed via the c2r convention n_out = 2 * (n_in - 1).
    use_default = s is None or s[-1] == -1
    last_dim_size = 2 * (input.shape[dim[-1]] - 1) if use_default else shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    # The complex input only carries floor(n/2) + 1 entries on the last axis.
    half_sizes = list(shape)
    half_sizes[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(half_sizes), dim=dim, last_dim_size=last_dim_size
    )


@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.irfftn (n-dimensional c2r)."""
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)
input + tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) + tmp = prims.conj_physical(tmp) + out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) + return _apply_norm(out, norm, last_dim_size, forward=True) + + +@register_decomposition(aten.fft_fft2) +@out_wrapper() +def fft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.fftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ifft2) +@out_wrapper() +def ifft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_rfft2) +@out_wrapper() +def rfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_irfft2) +@out_wrapper() +def irfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_hfft2) +@out_wrapper() +def hfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.hfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ihfft2) +@out_wrapper() +def ihfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) + + +def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]: + """Convert Optional[DimsType] to a 
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each selected axis by n // 2 (all axes when ``dim`` is None),
    moving the zero-frequency term to the center."""
    axes = _default_alldims(dim, input)
    offsets = [input.shape[d] // 2 for d in axes]
    return torch.roll(input, offsets, axes)


@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Inverse of fftshift: roll each selected axis by (n + 1) // 2."""
    axes = _default_alldims(dim, input)
    offsets = [(input.shape[d] + 1) // 2 for d in axes]
    return torch.roll(input, offsets, axes)
mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..239da3f04d472ee845dbb9379dd339d307466b9f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..081cb17130b3765c6f6f57f0b8ec1ba901e21312 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py @@ -0,0 +1,18 @@ +from .modules import * # noqa: F403 + +__all__ = [ + 'Linear', + 'Conv1d', + 'Conv2d', + 'Conv3d', + 'ConvTranspose1d', + 'ConvTranspose2d', + 'ConvTranspose3d', + 'RNNCell', + 'LSTMCell', + 'GRUCell', + 'LSTM', + 'GRU', + 'Embedding', + 'EmbeddingBag', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e445eb1d27260b10afc676f04a9546cda5b5a306 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf35a7e531e1abc1aa60be4a3ec4a2430a01a21a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py @@ -0,0 +1,533 @@ +import enum +import operator + +import torch +import torch.nn as nn +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.quantized as nnq + +toq = torch.ops.quantized +from typing import Tuple, Callable, Dict, Set, List, Optional, Union + +from torch.fx import GraphModule +from torch.fx.graph import Node +from torch.ao.quantization import ( + ObserverBase, + FakeQuantizeBase, +) +from torch.ao.quantization.utils import getattr_from_fqn +from torch.ao.quantization.observer import _is_activation_post_process + +from .ns_types import NSNodeTargetType, NSResultsType + +# TODO(future PR): consider deleting this enum and using the torch types +# directly. This might be tricky because it is not a one to one mapping. +class NodeInputOrOutputType(enum.Enum): + FP32 = enum.auto() # torch.float + INT8 = enum.auto() # torch.qint8 or torch.quint8 + FP16 = enum.auto() # torch.float16 + UNKNOWN = enum.auto() # we cannot determine input/output dtype + # TODO(future PR): while these functions can support multiple dtypes, + # for the purposes of numerical debugging we want to get the actual + # dtype used in the model. We will likely need some kind of dtype + # propagation to estimate this. 
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """Return the (first-input dtype category, output dtype category) of `node`.

    Dispatches on `node.op`; for pass-through targets (the *_FP32_OR_INT8
    sets, loggers/observers, `dequantize`, `to`) the category is derived by
    recursing into the node's first normalized input.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            # Pass-through function: its io type is whatever its first input
            # produces, so recurse into that input.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        is_known_fp32_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32  # type: ignore[arg-type]
        )
        is_known_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_input_module:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        elif is_known_int8_input_module:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, NodeInputOrOutputType.FP32)

        elif node.target == "to":
            # to is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of to and return the correct output type.
            prev_node = get_normalized_nth_input(node, gm, 0)
            assert isinstance(prev_node, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                prev_node, gm, logger_cls, node_type_to_io_type_map
            )

            # Only the fp16 target of `to` is handled so far; anything else
            # trips this assert.
            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_node_output_type, NodeInputOrOutputType.FP16)

        elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
            # Pass-through method: io type follows the first input.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)

        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
    else:
        return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)


def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
+ """ + prev_node = get_normalized_nth_input(node, gm, 0) + + if not isinstance(prev_node, Node): + return None + + MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"] + + def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx): + scale_node = get_normalized_nth_input(node, gm, scale_arg_idx) + zp_node = get_normalized_nth_input(node, gm, zp_arg_idx) + assert isinstance(scale_node, Node) and isinstance(scale_node.target, str) + assert isinstance(zp_node, Node) and isinstance(zp_node.target, str) + scale_obj = getattr_from_fqn(gm, scale_node.target) + zp_obj = getattr_from_fqn(gm, zp_node.target) + return (scale_obj, zp_obj) + + if prev_node.op == "call_function": + + # quantize - read the args directly + if prev_node.target == torch.quantize_per_tensor: + return _get_scale_zp_from_function_args(prev_node, gm, 1, 2) + elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu): + return _get_scale_zp_from_function_args(prev_node, gm, 2, 3) + + return None + # TODO(future PR): handle more functionals + # TODO(future PR): handle functional ops which inherit qparams from input + + elif prev_node.op == "call_module": + + # get type of the module + assert isinstance(prev_node.target, str) + module_obj = getattr_from_fqn(gm, prev_node.target) + if isinstance( + module_obj, + ( + nnq.Linear, + nnq.Conv1d, + nnq.Conv2d, + nniq.ConvReLU2d, + nnq.Conv3d, + nnq.BatchNorm2d, + nnq.BatchNorm3d, + nnq.ConvTranspose1d, + nnq.ConvTranspose2d, + nnq.ELU, + nnq.GroupNorm, + nnq.InstanceNorm1d, + nnq.InstanceNorm2d, + nnq.InstanceNorm3d, + nnq.LayerNorm, + nnq.Hardswish, + nnq.LeakyReLU, + nnq.ReLU6, + nniq.BNReLU2d, + nniq.BNReLU3d, + nniq.ConvReLU1d, + nniq.ConvReLU2d, + nniq.ConvReLU3d, + nniq.LinearReLU, + ), + ): + return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value] + + is_known_fp32_or_int8_input_module = any( + isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # 
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if _is_activation_post_process(node_obj):
            assert len(node.args) == 1
            assert isinstance(node.args[0], Node)
            node = node.args[0]
            # code duplication intended, not worth refactoring: unwrap at
            # most one more level (observer fed by a fake-quant)
            assert isinstance(node.target, str)
            node_obj = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(node_obj):
                assert len(node.args) == 1
                assert isinstance(node.args[0], Node)
                node = node.args[0]
    return node


def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node. For example, for

    F.linear(x, weight, bias)

    Returns 1, because x is a non-param arg and weight and bias are params.
    For

    lstm_mod(x, hid)

    Returns 2, because both x and hid are non-param args.
    """
    if node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        # LSTM takes both the input and the hidden state as data arguments
        if isinstance(node_obj, nn.LSTM):
            return 2

    # default is 1
    return 1


def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        result = []
        for i in range(2):
            # isinstance rather than exact type comparison, so that Node
            # subclasses are logged as well
            if isinstance(node.args[i], Node):
                result.append(i)
        return result
    return [0]


def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    target_type = ""
    if node.op in ("call_function", "call_method"):
        target_type = torch.typename(node.target)
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        target_mod = getattr_from_fqn(gm, node.target)
        target_type = torch.typename(target_mod)
    return target_type


def rekey_logger_info_on_node_name_of_model(
    results: "NSResultsType",
    model_name: str,
) -> "NSResultsType":
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

        {'base_op_1_0': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    into

        {'linear1': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    new_results = {}
    for old_layer_name, result_type_to_results in results.items():
        new_layer_name = None
        for model_name_to_results in result_type_to_results.values():
            for cur_model_name, list_of_results in model_name_to_results.items():
                if cur_model_name == model_name:
                    assert len(list_of_results)
                    new_layer_name = list_of_results[0]["ref_node_name"]
                else:
                    continue
        # fall back to the original key if `model_name` had no results here
        if new_layer_name is not None:
            new_results[new_layer_name] = result_type_to_results
        else:
            new_results[old_layer_name] = result_type_to_results
    return new_results


def maybe_add_missing_fqns(results: "NSResultsType") -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.

    Mutates `results` in place; returns None.
    """

    # Check in the first result to find any model with fqn entries defined.
    model_name_with_fqns = None
    for result_type_to_results in results.values():
        for model_name_to_results in result_type_to_results.values():
            for model_name, model_results in model_name_to_results.items():
                if model_results and model_results[0]["fqn"] is not None:
                    model_name_with_fqns = model_name
                    break
            break
        break

    if model_name_with_fqns:
        # Copy the reference model's fqn entries to every other model,
        # matching results by index.
        for result_type_to_results in results.values():
            for model_name_to_results in result_type_to_results.values():
                ref_model_results = model_name_to_results[model_name_with_fqns]
                for model_name, model_results in model_name_to_results.items():
                    if model_name == model_name_with_fqns:
                        continue
                    for i in range(len(model_results)):
                        fqn = ref_model_results[i]["fqn"]
                        model_results[i]["fqn"] = fqn


def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """
    Decorator for comparison functions taking two tensors (or two parallel
    tuples/lists of tensors) as the first two arguments:

    * tuples/lists are compared elementwise, returning a list of results
    * quantized tensors are dequantized before comparison
    * non-float dtypes are not handled and yield None
    """
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            results = []
            for el0, el1 in zip(a0, a1):
                new_args = (el0, el1, *a_other)
                results.append(inner(*new_args, **kwargs))
            return results

        elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

            # for the purposes of this util, only handle floats
            if a0.dtype != torch.float or a1.dtype != torch.float:
                return None

            new_args = (a0, a1, *a_other)
            return f(*new_args, **kwargs)

    return inner


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)


def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns False for call_function nodes whose target takes multiple
    tensor inputs (add, mul, cat, stack) — shadowing for those is not
    implemented yet — and True otherwise.
    """
    if node.op == 'call_function':
        if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True


def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    `idx` indexes positional args first, then kwargs in insertion order.
    """

    def _nth_raw_input() -> Node:
        # Fallback: read directly from node.args / node.kwargs without
        # normalization.
        assert len(node.args) + len(node.kwargs) > idx
        if idx < len(node.args):
            return node.args[idx]  # type: ignore[return-value]
        # bugfix: the kwargs index is `idx` *minus* the number of positional
        # args; the previous code added them, which raised IndexError
        # whenever this path had to address a kwarg
        kwargs_idx = idx - len(node.args)
        return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]

    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True)
        if norm_args_and_kwargs is None:
            return _nth_raw_input()
        norm_args, norm_kwargs = norm_args_and_kwargs
        assert len(norm_args) + len(norm_kwargs) > idx
        if idx < len(norm_args):
            return norm_args[idx]
        # note: in Python 3.7+ dicts are ordered; with
        # normalize_to_only_use_kwargs=True, norm_args is empty, so the
        # offset below is usually a no-op (kept for safety)
        return list(norm_kwargs.values())[idx - len(norm_args)]
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        return _nth_raw_input()
0000000000000000000000000000000000000000..76c2c8010ddde07a87591cdf1eebc02b182b9de2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94f9e513e51721e9bd95067619852c25f1cd8ea Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a58dd6fd8fde9335016edf46f7676d67184f9d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f128aed67b1d8223aea00aa07b66d4d478938168 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e1e1bac77f3cec1853f88901bf99b93defcc5c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7a69b35704c7c1569dd38d7fd11e2c451710d3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d540f554f935d623f6321c99664ce0ac4d5a8643 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py new file mode 100644 index 0000000000000000000000000000000000000000..01e112b688c0428cadb5e31502f18387bcd9282f --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -0,0 +1,160 @@ +import operator +import torch +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, + BackendPatternConfig, +) + +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) +from typing import List + +def get_linear_configs(): + linear_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + # TODO: need to fix the way we insert observers for this pattern + # should be solved in the new fusion API + # reason that this doesn't work: the pattern is a bit complicated and we don't + # have a way to specify which input of the pattern we would like to observe + # pattern: + # bias input weight + # \ | / + # \ | t + # \ | / + # addmm + # we want to observe "weight" as weight, but there is not way to convey this + # information with current pattern language + # + # right now: + # original: + # weight - t \ + # input - addmm + # observed (no hack): + # weight - t - observer \ + # input - observer - addmm + # target: + # weight - observer - t \ + # input - observer - addmm + + # def root_node_getter(node_pattern): + # addmm, bias, act, weight = node_pattern + # return addmm + + # linear_configs.append( + # BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default)) + # .set_observation_type(observation_type) # noqa: E131 + # .set_dtype_configs(dtype_configs) + # ._set_root_node_getter(root_node_getter)) + + linear_configs.append( + BackendPatternConfig(torch.ops.aten.addmm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 0}) + ) + # linear is decomposed to `t 
- mm` if bias is not present + linear_configs.append( + BackendPatternConfig(torch.ops.aten.mm.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return linear_configs + +def get_conv_configs(): + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + conv_configs.append( + BackendPatternConfig(torch.ops.aten.convolution.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + conv_configs.append( + BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu.default)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + # TODO: remove when functionalization is supported in PT2 mode + conv_configs.append( + BackendPatternConfig((torch.ops.aten.convolution.default, torch.ops.aten.relu_.default)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return conv_configs + +def get_pooling_configs(): + backend_pattern_configs = [] + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + + def root_node_getter(node_pattern): + getitem, maxpool, index = node_pattern + return maxpool + + backend_pattern_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_root_node_getter(root_node_getter) + ) + + return backend_pattern_configs + +def get_relu_configs(): + backend_pattern_configs = [] + observation_type = 
ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [weighted_op_quint8_dtype_config] + backend_pattern_configs.append( + BackendPatternConfig(torch.ops.aten.relu.default) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + return backend_pattern_configs + +def get_binary_op_configs(): + binary_op_configs: List[BackendPatternConfig] = [] + dtype_configs = [weighted_op_quint8_dtype_config] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu.default), + op_with_quantized_bop_scalar_variant, + # TODO: remove when functionalization is supported in pt2_mode + (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default), + ] + for bop_pattern in bop_patterns: + binary_op_configs.append( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type(num_tensor_args_to_observation_type_mapping)) + + return binary_op_configs + +def get_qnnpack_pt2e_backend_config(): + return ( + BackendConfig("qnnpack_pytorch_2.0_export") + .set_backend_pattern_configs(get_linear_configs()) + .set_backend_pattern_configs(get_binary_op_configs()) + .set_backend_pattern_configs(get_conv_configs()) + .set_backend_pattern_configs(get_pooling_configs()) + .set_backend_pattern_configs(get_relu_configs()) + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..74759fa73580c2ab8abe9352887bf11c1f029f62 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py @@ -0,0 +1,116 @@ +import torch +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + +__all__ = [ + "get_fbgemm_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py +# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max, +# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized +# values to within [0, 127]. 
+ +fbgemm_weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +fbgemm_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +fbgemm_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +fbgemm_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +fbgemm_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + +def get_fbgemm_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native FBGEMM backend. 
+ """ + conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + fbgemm_weighted_op_quint8_dtype_config, + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + fbgemm_default_dynamic_int8_dtype_config, + fbgemm_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + fbgemm_weight_only_quint8_dtype_config, + fbgemm_weight_only_quint4x2_dtype_config, + ] + return BackendConfig("fbgemm") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py new file mode 100644 index 0000000000000000000000000000000000000000..81cfc928adb5b127a09691c5841b4cfd1d564800 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py @@ -0,0 +1,204 @@ +import torch +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + +__all__ = [ + "get_test_only_legacy_native_backend_config", + "default_op_quint8_dtype_config", + "default_op_fp16_dtype_config", + "default_dynamic_int8_dtype_config", + "default_dynamic_float16_dtype_config", + "input_output_only_quint8_dtype_config", + "weight_only_quint8_dtype_config", + "weight_only_quint4x2_dtype_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_test_only_legacy_native_backend_config_dict", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + 
bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the dtype_configs but + # it is not really used yet, + # we will enable it a bit later after we moved everything to backend_config_dict + is_dynamic=True, +) + +# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights +input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + +def get_test_only_legacy_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops. 
+ """ + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + default_op_fp16_dtype_config, + ] + binary_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config, + ] + share_qparams_op_dtype_configs = [ + default_op_quint8_dtype_config, + default_op_fp16_dtype_config + ] + tensor_info_op_dtype_configs = [ + default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return BackendConfig("_native_and_fp16") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + 
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) + +def get_native_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + """ + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs + conv_dtype_configs = [weighted_op_quint8_dtype_config] + linear_dtype_configs = [ + weighted_op_quint8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [default_op_quint8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + return BackendConfig("native") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \ + 
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) + +def get_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form. + """ + return get_native_backend_config().to_dict() + +def get_test_only_legacy_native_backend_config_dict(): + """ + Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional + fp16 ops in dictionary form. + """ + return get_test_only_legacy_native_backend_config().to_dict() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..a160346d8a89dcb17ae9abf4f8a49f974f874ef6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py @@ -0,0 +1,259 @@ +import torch.nn as nn +import torch.ao.nn.intrinsic as nni + +from typing import Any, Union, Callable, List, Tuple, Dict, Optional, Type +from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode +import itertools + +__all__ = [ + "fuse_conv_bn", + "fuse_conv_bn_relu", + "fuse_linear_bn", + "fuse_convtranspose_bn", + "get_fuser_method", + "get_fuser_method_new", +] + +def fuse_conv_bn(is_qat, conv, bn): + r"""Return the fused the conv and bn modules. 
+ Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn(m1, b1) + """ + assert conv.training == bn.training, \ + "Conv and BN both must be in the same mode (train or eval)." + + fused_module_class_map = { + nn.Conv1d: nni.ConvBn1d, + nn.Conv2d: nni.ConvBn2d, + nn.Conv3d: nni.ConvBn3d, + } + + if is_qat: + assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' + assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' + assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' + fused_module_class = fused_module_class_map.get((type(conv)), None) + if fused_module_class is not None: + return fused_module_class(conv, bn) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}") + else: + return nn.utils.fuse_conv_bn_eval(conv, bn) + +def fuse_conv_bn_relu(is_qat, conv, bn, relu): + r"""Return the fused conv and bv modules. + + Given the conv and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + conv: Module instance of type conv2d/conv3d + bn: Spatial BN instance that needs to be fused with the conv + + Examples:: + + >>> m1 = nn.Conv2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> r1 = nn.ReLU(inplace=False) + >>> # xdoctest: +SKIP + >>> m2 = fuse_conv_bn_relu(m1, b1, r1) + """ + assert conv.training == bn.training == relu.training, \ + "Conv and BN both must be in the same mode (train or eval)." 
+ fused_module : Optional[Type[nn.Sequential]] = None + if is_qat: + map_to_fused_module_train = { + nn.Conv1d: nni.ConvBnReLU1d, + nn.Conv2d: nni.ConvBnReLU2d, + nn.Conv3d: nni.ConvBnReLU3d, + } + assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm' + assert bn.affine, 'Only support fusing BatchNorm with affine set to True' + assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True' + fused_module = map_to_fused_module_train.get(type(conv), None) + if fused_module is not None: + return fused_module(conv, bn, relu) + else: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}") + else: + map_to_fused_module_eval = { + nn.Conv1d: nni.ConvReLU1d, + nn.Conv2d: nni.ConvReLU2d, + nn.Conv3d: nni.ConvReLU3d, + } + fused_module = map_to_fused_module_eval.get(type(conv), None) + if fused_module is not None: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return fused_module(fused_conv, relu) + else: + raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}") + +def fuse_linear_bn(is_qat, linear, bn): + r"""Return the fused linear and bn modules. + Given the linear and bn modules, fuses them and returns the fused module + + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + + Examples:: + + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> # xdoctest: +SKIP + >>> m2 = fuse_linear_bn(m1, b1) + """ + assert linear.training == bn.training, \ + "Linear and BN both must be in the same mode (train or eval)." 
+ + if is_qat: + assert bn.num_features == linear.out_features, \ + "Output features of Linear must match num_features of BatchNorm1d" + assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" + assert bn.track_running_stats, \ + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + return nni.LinearBn1d(linear, bn) + else: + return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + +def fuse_convtranspose_bn(is_qat, convt, bn): + r"""Return the fused ConvTranspose and bn modules. + Given ConvTranspose and bn modules, fuses them and returns the fused module + + Args: + convt: Module instance of type ConvTransposeNd + bn: BatchNormNd instance that needs to be fused with the linear layer. + batch norm N should match the ConvTranspose N + + Examples:: + + >>> m1 = nn.ConvTranspose2d(10, 20, 3) + >>> b1 = nn.BatchNorm2d(20) + >>> # xdoctest: +SKIP + >>> m2 = fuse_convtranspose_bn(m1, b1) + """ + assert convt.training == bn.training, \ + "ConvTranspose and BN both must be in the same mode (train or eval)." + + if is_qat: + raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in QAT.") + else: + return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) + +def _sequential_wrapper2(sequential): + """Return a sequential wrapped that for is_qat and two modules. 
+ Given a sequential class for two modules, return a function that takes + is_qat, and then two modules as argument, that ignores the is_qat flag + and always returns the sequential that combines the two input modules + """ + def fuser_method(is_qat, m1, m2): + return sequential(m1, m2) + return fuser_method + +_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = { + (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, + (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, + (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, + (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), + (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), + (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, + (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, + (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, +} + +def get_fuser_method(op_list, additional_fuser_method_mapping=None): + """Get fuser method for the given list of module types. 
+ + Get fuser method for the given list of module types, + return None if fuser method does not exist + """ + if additional_fuser_method_mapping is None: + additional_fuser_method_mapping = {} + all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD, + additional_fuser_method_mapping) + fuser_method = all_mappings.get(op_list, None) + assert fuser_method is not None, f"did not find fuser method for: {op_list} " + return fuser_method + +def _reverse2(f): + def reversed(is_qat, x, y): + return f(is_qat, y, x) + return reversed + +def _reverse3(f): + def reversed(is_qat, x, w): + y, z = w + return f(is_qat, z, y, x) + return reversed + +def _get_valid_patterns(op_pattern): + """Return a list of valid patterns generated from the op_pattern. + + Returns a list of valid patterns generated from the op_pattern, + since MatchAllNode can match all types of nodes, + e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like + (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode) + + Example Input: + (torch.add, (torch.nn.ReLU, torch.nn.Conv2d)) + + Example Output: + [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)), + (torch.add, (torch.nn.ReLU, MatchAllNode)), + (torch.add, (MatchAllNode, torch.nn.Conv2d)), + (torch.add, (MatchAllNode, MatchAllNode)), + (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)), + (MatchAllNode, (torch.nn.ReLU, MatchAllNode)), + (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)), + (MatchAllNode, (MatchAllNode, MatchAllNode)), + ] + """ + result: List[Any] + if isinstance(op_pattern, (tuple, list)): + sub_combs = [] + for sub_pattern in op_pattern: + sub_combs.append(_get_valid_patterns(sub_pattern)) + result = list(itertools.product(*sub_combs)) + else: + result = [op_pattern, MatchAllNode] + return result + +def get_fuser_method_new( + op_pattern: Pattern, + fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]): + """Get fuser method. 
+ + This will be made default after we deprecate the get_fuser_method + Would like to implement this first and have a separate PR for deprecation + """ + op_patterns = _get_valid_patterns(op_pattern) + fuser_method = None + for op_pattern in op_patterns: + fuser_method = fuser_method_mapping.get(op_pattern, None) + if fuser_method is not None: + break + assert fuser_method is not None, f"did not find fuser method for: {op_pattern} " + return fuser_method diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e37eaaded975381d6153b2a66c9d9550d07cd03 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py @@ -0,0 +1,3 @@ +from .prepare import prepare +from .convert import convert +from .fuse import fuse diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d0bf594012654d149cba3656305ab1d08748559 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9dfcf3d5439cf5243d79a6b9033e4e2f77a2755 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e97c912eabae2da7785802b098191e9dda9a6cc3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90dab2dc06fd818810699b69c3a73cb8c63b99a6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3876a058512c3df2b96055516000a5ac240fe6eb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d930f730261fa349657cb59db77f15f0d07c3bb3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d2983ee98ec6545201464674202bd0241dc9823 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1736ff3c40c6149a6b44f63a5cd6e8dd529bec63 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c956ad0e1d1444f806dbb754410d5fd913b38b2f Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1430220514cea35c1fcefe96f7162c048f9063d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..716e9518ebb51dc773875a2def4a42daa6788c42 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..55bcb52576b212b941a9ebdd9c76b8566d593d10 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py @@ -0,0 +1,820 @@ +import warnings + +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.ao.nn.intrinsic as nni +from torch.fx import GraphModule +from 
torch.fx.graph import Node +from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr + +from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver +from ..utils import _parent_name, check_min_max_valid + +from .utils import ( + get_new_attr_name_with_prefix, + maybe_get_next_module, + node_arg_is_weight, +) + +CUSTOM_MODULE_SUPP_LIST: List[Any] = [] + +def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor: + """Reshapes the scale so that we can multiply it to the input by the given axis. + """ + new_shape = [1] * input.ndim + new_shape[axis] = input.size(axis) + return scale.view(new_shape) + +qsheme_mapping_per_tensor_to_per_channel = { + torch.per_tensor_affine: torch.per_channel_affine, + torch.per_tensor_symmetric: torch.per_channel_symmetric, +} + + +class _InputEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of input columns, and + computing the quantization parameters for the overall min/max input values. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + The running minimum/maximum :math:`x_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`, + with the difference that the running min/max values are stored per column. + This observer is intended to be used along with a WeightEqualizationObserver + to calculate the equalization scale. 
+ """ + + def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, + quant_min=None, quant_max=None, factory_kwargs=None) -> None: + super().__init__() + + if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + raise TypeError("Input qscheme must be per-tensor") + + self.dtype = dtype + self.qscheme = qscheme + + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs) + + self.equalization_scale = torch.tensor(1) + self.equalization_shape: List[int] = [] + + def forward(self, x_orig): + if not (x_orig.ndim >= 2 and x_orig.ndim <= 5): + raise ValueError("InputEqualizationObserver only supports Linear and Conv layers") + + # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers) + self.equalization_shape = [1] * x_orig.ndim + self.equalization_shape[1] = x_orig.size(1) + + return self.input_obs(x_orig) + + def get_input_minmax(self): + return (self.input_obs.min_val, self.input_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + # Reshape the equalization scale along axis=1 so that it can be + # multiplied with the input along axis=1 + if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): + return + self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape) + + def calculate_scaled_minmax(self): + r""" Returns the scaled min/max inputs + """ + if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1): + warnings.warn( + "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " + + "Will not scale the next quantization observer." 
+ ) + return None, None + + # Calculate qparams for the scaled min/max inputs + # Scale the input by the equalization scale located at the same column + # index + (min_inputs, max_inputs) = self.get_input_minmax() + equalization_scale_reshaped = reshape_scale(self.equalization_scale, 0, min_inputs) + min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped)) + max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped)) + + return min_input_scaled, max_input_scaled + + with_args = classmethod(_with_args) + + +class _WeightEqualizationObserver(nn.Module): + r"""Observer for tracking the running min/max values of weight columns and + rows, and computing the quantization parameters for the weight rows. + + Args: + dtype: Quantized data type + qscheme: Quantization scheme + quant_min: Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + + This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used + to record the running minimum and maximum of columns of incoming weight + tensors. This observer is intended to be used along with an + InputEqualizationObserver to calculate the equalization scale. + + The running minimum/maximum :math:`w_\text{min/max}` are computed in the + same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`. 
+ """ + + def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None, + quant_max=None, factory_kwargs=None) -> None: + super().__init__() + + self.dtype = dtype + self.qscheme = qscheme + self.ch_axis = 1 + + per_channel_qscheme = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme] + self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype, + qscheme=per_channel_qscheme, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs) + + self.equalization_scale = torch.tensor(1) + + def forward(self, w_orig): + if not (w_orig.ndim >= 2 and w_orig.ndim <= 5): + raise ValueError("InputEqualizationObserver only supports Linear and Conv layers") + + return self.weight_col_obs(w_orig) + + def get_weight_col_minmax(self): + return (self.weight_col_obs.min_val, self.weight_col_obs.max_val) + + def set_equalization_scale(self, equalization_scale): + self.equalization_scale = equalization_scale + + with_args = classmethod(_with_args) + + +def calculate_equalization_scale(input_obs: _InputEqualizationObserver, + weight_obs: _WeightEqualizationObserver) -> torch.Tensor: + r""" Calculates the equalization scale and sets the equalization_scale value + in the observers. + + Args: + input_obs: Observer that tracks the ranges for the input columns + weight_obs: Observer that tracks the ranges for the weight columns + """ + + (min_inputs, max_inputs) = input_obs.get_input_minmax() + (min_weights, max_weights) = weight_obs.get_weight_col_minmax() + + if not (check_min_max_valid(min_inputs, max_inputs) and check_min_max_valid(min_weights, max_weights)): + warnings.warn( + "Must run observer before calling calculate_equalization_scale. " + + "Returning default equalization scale torch.tensor(1)." 
+ ) + return torch.tensor(1) + + if not (min_inputs.shape == min_weights.shape): + raise ValueError( + "Input and Weight must have the same column dimension. " + + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." + ) + + equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs)) + # Replace all 'inf', 'nan', 0's with 1s to prevent errors + equalization_scale[equalization_scale == 0.] = 1 + equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1) + return equalization_scale + + +class EqualizationQConfig(namedtuple('EqualizationQConfig', ['input_activation', 'weight'])): + """ + Describes how to quantize a layer or a part of the network specifically for + input-weight equalization by providing settings (observer classes) for + inputs, outputs, and weights. + + Note that EqualizationQConfig needs to contain observer **classes** (like + MinMaxObserver) or a callable that returns instances on invocation, not the + concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of + the layers. + + Observer classes have usually reasonable default arguments, but they can be + overwritten with `with_args` method (that behaves like functools.partial): + + my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8), + weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8)) + """ + def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity): + if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError("EqualizationQConfig received observer instance, please pass observer class instead. 
" + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed") + self = super().__new__(cls, input_activation, weight) + return self + + +input_equalization_observer = _InputEqualizationObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_tensor_symmetric) +weight_equalization_observer = _WeightEqualizationObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric) +default_equalization_qconfig = EqualizationQConfig(input_activation=input_equalization_observer, + weight=weight_equalization_observer) + + +def fused_module_supports_equalization(module) -> bool: + """ Checks if the fused node supports equalization. """ + return type(module) in [nni.LinearReLU, nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d] + +def nn_module_supports_equalization(module) -> bool: + """ Checks if the torch.nn node supports equalization. """ + return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d] + +def custom_module_supports_equalization(module) -> bool: + """ Checks if the custom node supports equalization. 
""" + return type(module) in CUSTOM_MODULE_SUPP_LIST + + +def node_supports_equalization(node: Node, modules) -> bool: + """ Checks if the current node supports equalization + Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers + """ + if node.op == 'call_module': + return nn_module_supports_equalization(modules[str(node.target)]) or \ + fused_module_supports_equalization(modules[str(node.target)]) or \ + custom_module_supports_equalization(modules[str(node.target)]) + elif node.op == 'call_function': + return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d] + return False + +def is_equalization_observer(observer: nn.Module) -> bool: + return (isinstance(observer, (_InputEqualizationObserver, _WeightEqualizationObserver))) + + +############################################################################### +# Functions for equalization during convert # +############################################################################### + +def get_op_node_and_weight_eq_obs( + input_eq_obs_node: Node, + model: GraphModule, + modules: Dict[str, nn.Module] +) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]: + """ Gets the following weight equalization observer. There should always + exist a weight equalization observer after an input equalization observer. 
+ + Returns the operation node that follows the input equalization observer node + and the weight equalization observer + """ + + # Find the op node that comes directly after the input equalization observer + op_node = None + for user in input_eq_obs_node.users.keys(): + if node_supports_equalization(user, modules): + op_node = user + break + + assert op_node is not None + if op_node.op == 'call_module': + # If the op_node is a nn.Linear layer, then it must have a + # WeightEqualizationObserver configuration + maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig") + assert maybe_equalization_node_name_to_config is not None + equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment] + assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None + weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight() + + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + return op_node, weight_eq_obs + + elif op_node.op == 'call_function': + weight_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_node is not None: + weight_eq_obs = modules[str(weight_node.target)] + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + return op_node, weight_eq_obs + + return None, None + +def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]: + """ Gets the weight equalization observer node if it exists. 
+ """ + assert op_node.op == 'call_function' + for node_arg in op_node.args: + if node_arg_is_weight(op_node, node_arg): + assert (isinstance(node_arg, Node) and node_arg.op == 'call_module' and + isinstance(modules[str(node_arg.target)], _WeightEqualizationObserver)) + return node_arg + return None + +def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]: + """ Gets the following input equalization observer if it exists. + + For example, in the case of connecting linear layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + If the node being passed in is the linear1 node, then we want to return eq_obs2, + the following equalization observer for linear2. + + However, if there are no connecting layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add + Then we want to return None. + + In the case of an unfused linear-relu layer with a connecting linear layer: + linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + Since it is unfused, we want to skip over the relu layer and return eq_obs2, + the following equalization observer for linear2. + """ + + assert node_supports_equalization(node, modules) + + # Locate the following nn.ReLU or F.relu node if it exists + maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) + if maybe_relu_node is None: + maybe_relu_node = maybe_get_next_module(node, modules, target_functional_type=F.relu) + + # Locate the following output observer if it exists. + # We will skip the relu node if it exists. 
+ maybe_obs_node = ( + maybe_get_next_module(node, modules, ObserverBase) + if maybe_relu_node is None + else maybe_get_next_module(maybe_relu_node, modules, ObserverBase) + ) + if maybe_obs_node is None: + return None + + maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver) + if maybe_eq_obs_node is None: + return None + + maybe_eq_obs = modules[str(maybe_eq_obs_node)] + assert isinstance(maybe_eq_obs, _InputEqualizationObserver) + return maybe_eq_obs + +def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]: + """ If the next next node is an InputEqualizationObserver then we want to + return its equalization scale, else we return 1 + + This is used in the case where there are two connecting linear layers: + linear1 -> LinearOutObs -> InputEqObs -> linear2 + In this case, the node given is linear1 and we want to locate the InputEqObs. + """ + next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules) + if next_inp_eq_obs: + if next_inp_eq_obs.equalization_scale.nelement() == 1 and \ + next_inp_eq_obs.equalization_scale == torch.tensor(1): + return None + return next_inp_eq_obs.equalization_scale + return None + +def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None: + """ Scales the following input quantization observer's min/max values by + updating the values with the scaled min/max values calculated by the input + equalization observer + """ + input_eq_obs = modules[str(node.target)] + assert isinstance(input_eq_obs, _InputEqualizationObserver) + + input_quant_obs_node = node.args[0] + assert isinstance(input_quant_obs_node, Node) + + input_quant_obs = modules[str(input_quant_obs_node.target)] + if not isinstance(input_quant_obs, ObserverBase): + return + + min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax() + if min_input_scaled is None and max_input_scaled is None: + return + input_quant_obs.min_val = min_input_scaled + 
input_quant_obs.max_val = max_input_scaled + +def scale_weight_node( + node: Node, + modules: Dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: Optional[torch.Tensor], +) -> None: + """ Scale the weights for input-weight equalization by multiplying the + weight by 1/equalization_scale and next_equalization_scale + + Args: + node: Current node whose weights we want to scale + equalization_scale: Current node's calculated equalization scale + next_equalization_scale: Next node's calculated equalization scale if + the following node needs to be equalized, 1 otherwise + """ + if equalization_scale is None: + return + + if fused_module_supports_equalization(modules[str(node.target)]): + op_module = modules[str(node.target)][0] # type: ignore[index] + else: + op_module = modules[str(node.target)] + assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + weight = op_module.weight + assert isinstance(weight, torch.Tensor) + + # Scale the weights by the reciprocal of the equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + op_module.weight = nn.Parameter(scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=0 + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + op_module.weight = nn.Parameter(scaled_weight) + + # Multiply the bias element wise by the next equalization scale + 
bias = op_module.bias + if bias is None: + return + assert isinstance(bias, torch.Tensor) + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + op_module.bias = nn.Parameter(scaled_bias) + +def scale_weight_functional( + op_node: Node, + model: GraphModule, + modules: Dict[str, nn.Module], + equalization_scale: torch.Tensor, + next_equalization_scale: Optional[torch.Tensor], +) -> None: + """ Scales the weight value for functional layers + """ + if equalization_scale is None: + return + + # From the given op_node, the path looks like: + # get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node + # So we want to trace back from the op_node to get the equalization observer + # node, then the quantization observer node, and then finally the weight + # node which contains the weight values. + + # Get the equalization observer node + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + # Get the quantization observer node + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + assert (isinstance(weight_quant_obs_node, Node) and + isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)) + + # Get the get_attr(weight) node + weight_node = weight_quant_obs_node.args[0] + if weight_node is None: + return + assert isinstance(weight_node, Node) and weight_node.op == 'get_attr' + + weight_parent_name, weight_name = _parent_name(weight_node.target) + weight = getattr(modules[weight_parent_name], weight_name) + + # Scale the weights for input-weight equalization + # If the following layer needs to be equalized then we will multiply its scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + equalization_scale_reshaped = 
reshape_scale(equalization_scale, 1, weight) + scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped)) + + if next_equalization_scale is None: + setattr(modules[weight_parent_name], weight_name, scaled_weight) + return + + # Multiply the weights row wise by the next equalization scale + # Reshape the equalization scale so that we can multiply it to the weight along axis=1 + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, scaled_weight) + scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) + + setattr(modules[weight_parent_name], weight_name, scaled_weight) + assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight) + + # Multiply the bias element wise by the next equalization scale + bias_node = None + for node in op_node.args: + # Find the node containing the weight values + if isinstance(node, Node) and node.op == 'get_attr' and 'bias' in node.name: + bias_node = node + break + if bias_node is None: + return + + bias_parent_name, bias_name = _parent_name(bias_node.target) + bias = getattr(modules[bias_parent_name], bias_name) + + # Reshape the equalization scale so that we can multiply it element-wise to the bias + next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) + scaled_bias = torch.mul(bias, next_equalization_scale_reshaped) + setattr(modules[bias_parent_name], bias_name, scaled_bias) + +def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None: + """ Given the operation node, we want find the corresponding quantization + observer and reset its min/max values + """ + weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules) + if weight_eq_obs_node is None: + return + + weight_quant_obs_node = weight_eq_obs_node.args[0] + if weight_quant_obs_node is None: + return + assert isinstance(weight_quant_obs_node, Node) + + weight_quant_obs = modules[str(weight_quant_obs_node.target)] + assert 
isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) + weight_quant_obs.reset_min_max_vals() # type: ignore[operator] + +def remove_node(model: GraphModule, node: Node, prev_node: Node): + """ Removes the given node from the model by replacing all of its users with + the given previous node + """ + # For all of the current node's users, replace the current node with + # the input quantization observer node + orig_users = list(node.users.keys()) + for user_node in orig_users: + user_node.replace_input_with(node, prev_node) + + # Erase the InputEqualizationObserver node + model.graph.erase_node(node) + +def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]: + """ Update all of the observer's equalization scale. For each + InputEqualizationObserver, we will find the location of the next + WeightEqualizationObserver, create it, and calculate the equalization scale + based on the two observers. + + We will then return a dictionary mapping operation node names to + the corresponding WeightEqualizationObservers for that operation. 
+ """ + weight_eq_obs_dict = {} + for node in model.graph.nodes: + if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver): + input_eq_obs = modules[node.target] + assert isinstance(input_eq_obs, _InputEqualizationObserver) + op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) + + if op_node is None or weight_eq_obs is None: + continue + + if op_node.op == 'call_module': + # Calibrate the weight equalization observer since it has just + # been created + if fused_module_supports_equalization(modules[str(op_node.target)]): + module = modules[str(op_node.target)][0] # type: ignore[index] + assert nn_module_supports_equalization(module) + weight_eq_obs(module.weight) + else: + weight_eq_obs(modules[str(op_node.target)].weight) + + # Calculate and set the equalization scale values + equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs) + input_eq_obs.set_equalization_scale(equalization_scale) + weight_eq_obs.set_equalization_scale(equalization_scale) + + weight_eq_obs_dict[op_node.name] = weight_eq_obs + + return weight_eq_obs_dict + +def convert_eq_obs( + model: GraphModule, + modules: Dict[str, nn.Module], + weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver], +) -> None: + """ Converts the equalization operations and updates the other nodes in the + following way: + - Removes the input equalization observers and inserts a mul operator + along with an equalization scale node wherever applicable (we do not + want to insert a mul operator between connecting linear layers). + - Updates the input quantization observers with the scaled input min/max + values. + - Scales the weights by the current and next equalization scales. + - Removes the weight equalization observer node if it exists. 
+ + Before (after prepare): + weight values + | + WeightQuantObs + | + WeightEqObs + | + x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs + + After this function: + scaled weight values + | + equalization scale WeightQuantObs + | | + x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs + + After convert: + equalization scale scaled weight values + | | + x -> mul -> quantize_per_tensor -> quantized::linear + + Note that although the equalization observer appeared after the quantization + observer after prepare_fx, the mul node appears before the quantization node + after convert_fx. This is because placing the equalization observer after + the quantization observer in prepare_fx would allow us to keep the invariant + that the graph before the current node inserts its observers is not + modified. + + Having the equalization observer before the quantization observer would also + cause some inconsistences between the ordering of the quantization and + equalization observers. 
+ For example, a single linear layer would look like: + x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1 + But between two connected linear layers, it would look like: + linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2 + """ + for node in model.graph.nodes: + if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver): + inp_quant_obs_node = node.args[0] + prev_node = inp_quant_obs_node.args[0] + + # If the previous node is a layer that needs to be equalized, then + # we will remove the current node because we do not need to add any + # equalization nodes between two layers that need to be equalized + + # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2 + # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2 + if node_supports_equalization(prev_node, modules) or "relu" in prev_node.name: + remove_node(model, node, inp_quant_obs_node) + continue + + # Update the following input quantization observer's min/max values + scale_input_observer(node, modules) + + # Remove the InputEqualization node and add a mul operator before + # the quantization observer node that appears before the equalization node + # Before: x -> input_quant_obs -> input_eq_obs -> linear + # After: x -> mul -> input_quant_obs -> linear + + # Create a node containing the equalization scale + with model.graph.inserting_before(inp_quant_obs_node): + get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale') + name = get_new_eq_scale_name(modules) + setattr(model, name, modules[node.target].equalization_scale) + eq_scale_node = model.graph.create_node('get_attr', name) + + # Create a node multiplying the input with the equalization scale + with model.graph.inserting_after(eq_scale_node): + inputs = (prev_node, eq_scale_node) + mul_node = model.graph.create_node("call_function", torch.mul, inputs) + + # Set the mul nod 
to be the input_quant_obs_node's input instead of + # the previous node + inp_quant_obs_node.replace_input_with(prev_node, mul_node) + remove_node(model, node, inp_quant_obs_node) + + elif weight_eq_obs_dict.get(node.name, None) is not None: + weight_eq_obs = weight_eq_obs_dict.get(node.name) + assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + equalization_scale = weight_eq_obs.equalization_scale + + if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1): + equalization_scale = None # type: ignore[assignment] + maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules) + + # Scale the weight nodes + if node.op == 'call_module': + scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale) + elif node.op == 'call_function': + scale_weight_functional(node, model, modules, equalization_scale, maybe_next_equalization_scale) + + weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) + if weight_eq_obs_node is None: + return + assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver) + + # Clear the quantization observer's min/max values so that they + # can get updated later based on the new scale values + clear_weight_quant_obs_node(node, modules) + + # Erase the weight equalization observer node + prev_node = weight_eq_obs_node.args[0] + remove_node(model, weight_eq_obs_node, prev_node) + else: + raise ValueError("Expected operation node to be 'call_module' or 'call_function" + + f"Instead got node {node.name} as '{node.op}'.") + +def _convert_equalization_ref(model: GraphModule): + """ Reference function which applies changes needed for equalization, but + does not quantize the nodes + """ + modules = dict(model.named_modules(remove_duplicate=False)) + + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + 
convert_eq_obs(model, modules, weight_eq_obs_dict) + + return GraphModule(model, model.graph) + + +############################################################################### +# Functions for running the equalized model on the Numeric Suite # +############################################################################### + +def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor) -> Dict[str, float]: + """ Runs the Numeric Suite on model_a and model_b and returns a dictionary + containing the SQNR between layers in model_a and model_b. + + Note: In order to support equalized models, this function has a hacky fix in + which we do not match any torch.mul operators. This is because equalized + models contain extra mul operators to scale the input by the equalization + scale, but this edge case has not been resolved yet within the numeric suite code. + + Args: + model_a: A float model + model_b: A quantized model + x: Inputs to use during calibration + """ + import torch.ao.ns._numeric_suite_fx as ns + from torch.ao.ns.fx.mappings import get_unmatchable_types_map + + unmatchable_types_map = get_unmatchable_types_map() + unmatchable_types_map["funs_unmatchable"].add(torch.mul) + + model_a_ns, model_b_ns = ns.add_loggers( + 'fp32', model_a, + 'int8', model_b, + ns.OutputLogger, + unmatchable_types_map=unmatchable_types_map + ) + + model_a_ns(x) + model_b_ns(x) + + activation_comparison_dict = ns.extract_logger_info( + model_a_ns, + model_b_ns, + ns.OutputLogger, + 'int8') + ns.extend_logger_results_with_comparison( + activation_comparison_dict, + 'fp32', 'int8', + torch.ao.ns.fx.utils.compute_sqnr, 'sqnr' + ) + + # Construct a dictionary mapping layer names to the SQNR values + layer_sqnr_dict = {} + for key in activation_comparison_dict: + layer = activation_comparison_dict[key]['node_output']['int8'][0]['fqn'] + sqnr = activation_comparison_dict[key]['node_output']['int8'][0]['sqnr'][0] + layer_sqnr_dict[layer] = sqnr + + return 
layer_sqnr_dict + +def get_equalization_qconfig_dict( + layer_sqnr_dict: Dict[str, float], + num_layers_to_equalize: int +) -> Any: + """ Given the layer to SQNR dictionary, find the layers with the highest + quantization errors, and return an equalization_qconfig_dict + specifying to only equalize those top layers. + + Args: + layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found + when comparing an equalized model against a float model) + num_layers_to_equalize: Number of layers with the highest quantization + errors to equalize + """ + + # Sort the layer_sqnr_dictionary values and get the layers with the lowest + # SQNR values (aka highest quantization errors) + layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=lambda item: item[1]) + layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize] + + # Constructs an equalization_qconfig_dict that specifies to only equalize + # the layers with the highest quantization errors + module_to_qconfig_list = [(item[0], default_equalization_qconfig) for item in layers_to_equalize] + equalization_qconfig_dict = {"module_name": module_to_qconfig_list} + return equalization_qconfig_dict diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..271c9c6d423184df409dc39ae0e21dc2c514f383 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9327b8874491b9e90b1b52c03e0eb5bfa014175 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c241db249ef3e1a454b41fcbebfd78251baea5ef Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a5462095753d0fd21a85e96e43ecc963f478e40 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e8dd825f6e27e00641a0e846423597f4d4550b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py new file mode 100644 index 0000000000000000000000000000000000000000..71986fd17fbe8e8c5061634dc69e4871ae61f4e7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py @@ -0,0 +1,1539 @@ +from typing import Any, Dict, Set, Tuple, Callable, List + +import torch +import torch.nn as nn +import torch.ao.nn.qat as nnqat +from abc import ABC, abstractmethod +from torch.ao.quantization.fake_quantize import FakeQuantize +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver +from torch.ao.quantization.qconfig import ( + QConfig, + default_qconfig, + _assert_valid_qconfig, +) +from torch.ao.quantization.observer import ( + ObserverBase, + default_dynamic_quant_observer, + default_per_channel_weight_observer, + default_observer, + default_weight_observer, +) +from torch.ao.quantization.fx._equalize import ( + default_equalization_qconfig, + EqualizationQConfig, +) +from torch.ao.quantization.observer import _is_activation_post_process + +# Names for observer 
insert keys +DETECTOR_TARGET_NODE_KEY = "target_node" +DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert" +DETECTOR_IS_POST_OBS_KEY = "is_post_observer" +DETECTOR_OBS_ARGS_KEY = "observer_args" + +# Mapping related code +class DetectorQConfigInfo: + r""" + This class contains the QConfig information for a single module. + The list of variables / values this contains can grow depending on the + extensibility of the qconfig mapping feature set but this currently includes: + - if activation observer is dynamic + - if weight observer is per channel + + + Args: + module_fqn (str): The fully qualified name (fqn) of the module that this + information contains info relevant to qconfig for + """ + + def __init__(self, module_fqn: str): + super().__init__() + self.module_fqn = module_fqn + + # populate this section with all the variables we might find important + # change from none if your detector is actually using this + self.is_activation_dynamic = False + self.is_weight_per_channel = False + + # equalization related options + self.is_equalization_recommended = False + + def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig: + r""" + Args: + module (torch.nn.Module) The module we are generating + the qconfig for + + Returns the generated quantization QConfig according to what a valid configuration is + """ + # Apply suggestions to new qconfig + module_qconfig = default_qconfig + + # keep track of dynamic and per_channel recommendations + recommendations_list = [] + # append as if a list of combinations + recommendations_list.append((self.is_activation_dynamic, self.is_weight_per_channel)) + recommendations_list.append((self.is_activation_dynamic, False)) # only trying dynamic rec + recommendations_list.append((False, self.is_weight_per_channel)) # only trying dynamic + + # now we try each of the combinations + for rec in recommendations_list: + # rec[0] -> dynamic recommended + # rec[1] -> per channel recommended + activation = 
default_dynamic_quant_observer if rec[0] else default_observer + weight = default_per_channel_weight_observer if rec[1] else default_weight_observer + test_config = QConfig(activation, weight) + try: + _assert_valid_qconfig(test_config, module) + module_qconfig = test_config + break + except AssertionError: + # if not a valid configuration, we move on to the next one in priority + continue + + # return the QConfig chosen + return module_qconfig + + def generate_equalization_qconfig(self) -> EqualizationQConfig: + r""" + This returns the equalization configuration for a module. + + For now, it just returns the default, but as more equalization options become + possible, this method can get more fleshed out with more nuanced granularity. + + + Returns the generated equalization QConfig according to what a valid configuration is + """ + # in this case, we just return default equalization config + # we know this is valid because only valid modules would even + # have this option + return default_equalization_qconfig + +# Adding base class for detectors +class DetectorBase(ABC): + r""" Base Detector Module + Any detector class should derive from this class. + + Concrete detectors should follow the same general API, which includes: + - A method to calculate and return observer insertion points + - Should return both the fqns and the Observer class to insert + - A method to return a report based on the detector + - Should return a str-based report and dict info in Tuple[str,Dict] format + """ + + def __init__(self): + super().__init__() + self.detector_config_info = None + + @abstractmethod + def determine_observer_insert_points(self, model) -> Dict: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict. 
+ This dict maps string keys to detector specific information + """ + pass + + @abstractmethod + def get_detector_name(self) -> str: + r""" Returns the name of the current detector """ + pass + + + @abstractmethod + def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]: + r""" Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + pass + + def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a GraphModule and the target_fqn and finds the node whose target is this fqn. + + If it's not found, it means it is most likely inside a fused layer + We just go one layer up in terms of the fqn we are searching for until we find parent node + If we get to empty string, then we know that it doesn't exist + + The reason for the recursion is that if the model that we are looking for got fused, + we will have module fqn as e.g. x.linear.0 but the graph will only have a node for the fused module, + which would have fqn as x.linear so they will not match. + To handle this, if we don't match, we then take off the last bit of the fqn e.g. x.linear.0 -> x.linear, + or more generally foo.bar.baz -> foo.bar and search again, this will allow us to locate the correct module + even in cases with fusion + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + target_fqn (str): The fqn of the layer we are trying to target + + Returns the node object we are trying to add observers around + """ + for node in prepared_fx_model.graph.nodes: + # if the node's target is our target, return it + if node.target == target_fqn: + return node + + # getting here means node not found + # if no "." 
we are already at base and failed + parent_fqn_sep_index = target_fqn.rfind(".") + if parent_fqn_sep_index == -1: + raise ValueError("passed in target_fqn not found in graph's targets.") + else: + # recursively call it with parent fqn + return self._get_targeting_node(prepared_fx_model, target_fqn[:parent_fqn_sep_index]) + + @abstractmethod + def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]: + r""" + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Tuple of two elements: + Str: string report of the suggested improvements + Dict: contains useful data collected by the observer pertinent to this report + """ + pass + +class PerChannelDetector(DetectorBase): + r""" This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. + + per_channel quantization can lead to major benefits in the form of accuracy. 
+ Therefore, if the backend used by the user supports it, it is recommended to use + + Args: + backend (str, optional): the backend the user wishes to use in production + Default value is current torch.backends.quantized.engine + """ + + # Keys for return dictionary + BACKEND_KEY = "backend" + PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported" + PER_CHAN_USED_KEY = "per_channel_quantization_used" + + # Default map for representing supported per channel quantization modules for different backends + DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = { + "fbgemm": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "qnnpack": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "onednn": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + "x86": {nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d}, + } + + def __init__(self, backend: str = torch.backends.quantized.engine): + super().__init__() + + # store the backend information + self.backend_chosen = backend + self.supported_modules = set() + if self.backend_chosen in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: + self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen] + else: + raise ValueError(f"Not configured to work with {self.backend_chosen}. 
Try a different default backend") + + def get_detector_name(self) -> str: + r""" returns the string name of this detector""" + return "per_channel_detector" + + def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]: + r""" Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # we actually have a qconfig info object we are populating + module_fqn_to_detector_qconfig_info = {} + + for module_fqn in per_channel_info: + # create a detector info instance + detector_qconfig_info = DetectorQConfigInfo(module_fqn) + + # see if per channel quantization is supported + per_chan_supported: bool = per_channel_info[module_fqn][self.PER_CHAN_SUPPORTED_KEY] + detector_qconfig_info.is_weight_per_channel = per_chan_supported + module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info + + return module_fqn_to_detector_qconfig_info + + def determine_observer_insert_points(self, model: nn.Module) -> Dict: + r""" + There is no observers inserted for the PerChannelDetector. + + Returns an empty dictionary since no observers are added or needed + """ + return {} + + + def _detect_per_channel_helper(self, model: nn.Module): + r""" + determines if per_channel quantization is supported in modules and submodules. + + Returns a dictionary in the higher level _detect_per_channel function. + Each entry maps the fully-qualified-name to information on whether per_channel quantization. 
+ + Args: + model: The current module that is being checked to see if it is per_channel quantizable + + Returns dictionary mapping fqns to if per_channel quantization is possible + """ + # create dict we will return + per_channel_info: Dict = {} + + # get the fully qualified name and check if in list of modules to include and list of modules to ignore + for fqn, module in model.named_modules(): + + is_in_include_list = sum([isinstance(module, x) for x in self.supported_modules]) > 0 + + # check if the module per_channel is supported + # based on backend + per_channel_supported = False + + if is_in_include_list: + per_channel_supported = True + + # assert statement for MyPy + q_config_file = module.qconfig + assert isinstance(q_config_file, QConfig) + + # this object should either be fake quant or observer + q_or_s_obj = module.qconfig.weight.p.func() + assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)) + + per_channel_used = False # will be true if found in qconfig + + if hasattr(q_or_s_obj, "ch_axis"): # then we know that per_channel quantization used + + # all fake quants have channel axis so need to check is_per_channel + if isinstance(q_or_s_obj, FakeQuantize): + if hasattr(q_or_s_obj, "is_per_channel") and q_or_s_obj.is_per_channel: + per_channel_used = True + elif isinstance(q_or_s_obj, ObserverBase): + # should be an observer otherwise + per_channel_used = True + else: + raise ValueError("Should be either observer or fake quant") + + per_channel_info[fqn] = { + self.PER_CHAN_SUPPORTED_KEY: per_channel_supported, + self.PER_CHAN_USED_KEY: per_channel_used, + self.BACKEND_KEY: self.backend_chosen + } + + return per_channel_info + + def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]: + r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization. + Only Linear and Conv layers can use per_channel as of now so only these two are currently checked. 
+ + Looks at q_config format and backend to determine if per_channel can be utilized. + Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support + + Args: + model: The prepared and calibrated model we want to check if using per_channel + + Returns a tuple with two elements: + String report of potential actions to improve model (if per_channel quantization is available in backend) + Dictionary mapping per_channel quantizable elements to: + whether per_channel quantization is supported by the backend + if it is being utilized in the current model + """ + + # run the helper function to populate the dictionary + per_channel_info = self._detect_per_channel_helper(model) + + # String to let the user know of further optimizations + further_optims_str = f"Further Optimizations for backend {self.backend_chosen}: \n" + + optimizations_possible = False + for fqn in per_channel_info: + fqn_dict = per_channel_info[fqn] + if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]: + optimizations_possible = True + further_optims_str += f"Module {fqn} can be configured to use per_channel quantization.\n" + + if optimizations_possible: + further_optims_str += ( + "To use per_channel quantization, make sure the qconfig has a per_channel weight observer." + ) + else: + further_optims_str += "No further per_channel optimizations possible." + + # return the string and the dictionary form of same information + return (further_optims_str, per_channel_info) + + +class DynamicStaticDetector(DetectorBase): + r""" + Determines whether dynamic or static quantization is more appropriate for a given module. + + Takes advantage of the ModelReportObserver that records range information. + Stationary distribution of data are strictly above tolerance level for the comparison statistic: + + S = average_batch_activation_range/epoch_activation_range + + Nonstationary distributions are below or at the tolerance level for this metric. 
class DynamicStaticDetector(DetectorBase):
    r"""
    Determines whether dynamic or static quantization is more appropriate for a given module.

    Takes advantage of the ModelReportObserver that records range information.
    Stationary distribution of data are strictly above tolerance level for the comparison statistic:

        S = average_batch_activation_range/epoch_activation_range

    Nonstationary distributions are below or at the tolerance level for this metric.

    If the distribution of data right after the module is non-stationary, recommend dynamic quantization
    Otherwise recommend static quantization

    Args:
        tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5
    """
    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

    # naming conventions for stationary vs non-stationary data
    STATIONARY_STR = "stationary"
    NON_STATIONARY_STR = "non-stationary"

    # naming for activation
    INPUT_ACTIVATION_PREFIX = "input_activation_"
    OUTPUT_ACTIVATION_PREFIX = "output_activation_"

    # naming conventions for the keys of the return module info
    TOLERANCE_KEY = "dynamic_static_tolerance"
    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported both dynamic and static for this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will be supported soon for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # set tolerance level and initialize a set to keep track of useful fqn locations
        self.tolerance = tolerance
        self.useful_observer_fqns: Set[str] = set()

    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
        r"""
        Determines where observers need to be inserted for the Dynamic vs Static detector.
        For this detector, we want to place observers on either side of linear layers in the model.

        Currently inserts observers for:
            linear layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # make sure module is supported
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # pre-observer watches the node's own inputs
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args
                }

                # post-observer watches the node's output
                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
                obs_fqn_to_info[post_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: True,
                    DETECTOR_OBS_ARGS_KEY: (targeted_node,)
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the string name of this detector."""
        return "dynamic_vs_static_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant.

        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        dynamic_static_info = self._generate_dict_info(model)

        # build one DetectorQConfigInfo per module, carrying the dynamic/static verdict
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn, module_info in dynamic_static_info.items():
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # record whether dynamic quantization was recommended
            dynamic_static_recommended: bool = module_info[self.DEFAULT_DYNAMIC_REC_KEY]
            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers.

        Args:
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # check to see if module is of a supported type (currently or in the future)
        is_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)
        future_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED)

        supported = is_supported_type or future_supported_type

        # this is check for observer insertion
        if insert:
            return supported

        # this is for report gen and we also need to check if it contains observers
        has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME)
        return supported and has_obs

    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
            whether dynamic quantization is recommended
            their S metric of input to module
            whether input to module is stationary or non-stationary
            their S metric of output of module
            whether output of module is stationary or non-stationary
            the tolerance level to decided whether input/output is stationary or non-stationary
            whether it is currently supported or planned for the future
        """
        # store modules dynamic vs static information
        module_dynamic_static_info = {}

        # loop through all submodules, including nested ones, and pull range stats off
        # any module that has both ModelReportObservers attached
        for fqn, module in model.named_modules():
            # if module is supported and has the ModelReportObserver attached to it
            if self._is_supported(module):
                # get pre and post observers for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)

                # get the S statistics for each side of the module
                pre_stat = pre_obs.get_batch_to_epoch_ratio()
                post_stat = post_obs.get_batch_to_epoch_ratio()

                # dynamic is recommended when the output distribution is non-stationary,
                # i.e. the comparison stat is at or below the tolerance
                dynamic_recommended = post_stat <= self.tolerance

                # classify each side as stationary (above tolerance) or non-stationary
                pre_obs_dist_classif = self.STATIONARY_STR if pre_stat > self.tolerance else self.NON_STATIONARY_STR
                post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR

                # check if current support or future support
                is_supported_type = any(isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)

                # store the set of important information for this module
                module_dynamic_static_info[fqn] = {
                    self.TOLERANCE_KEY: self.tolerance,
                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
                    self.POST_OBS_COMP_STAT_KEY: post_stat,
                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
                }

        return module_dynamic_static_info

    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether dynamic or static quantization is more appropriate for a given module.

        Takes advantage of the ModelReportObserver that records range information.
        Stationary distribution of data are strictly above tolerance level for the comparison statistic:

            S = average_batch_activation_range/epoch_activation_range

        Nonstationary distributions are below or at the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
        Otherwise recommend static quantization

        This will then generate suggestions for dynamic vs static quantization focused around Linear.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decided whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """
        # get the dictionary of the information to format the string report
        module_dynamic_static_info = self._generate_dict_info(model)

        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"

        modules_added: bool = False  # check to make sure at least 1 module added.

        dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
        static_benefit = " You can increase model efficiency if you use static quantization"
        future_support_str = ". This layer is not yet supported for dynamic quantization"

        # For each module, compose a formatted suggestion line from its stats and
        # append it to the overall report
        for module_fqn, module_info in module_dynamic_static_info.items():

            # there is at least 1 module for suggestion
            modules_added = True
            suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"

            # decide what string formatting values will be
            quantization_type = ""
            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."

            benefit_str = ""

            # strings for if dynamic quantized per tensor is needed
            recommend_per_tensor = ". We recommend to add a {} before this module if it is static."
            rec_lay_to_add = "dynamic quantize per tensor layer"
            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
            dynamic_per_tensor_reasoning_string = (
                " This is because the input to this module has a non-stationary distribution"
            )

            # start composing explanation
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                # check if currently supported or future supported
                benefit_str = dynamic_benefit
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            # now set the quantization explanation string
            quantization_reasoning = (
                quantization_reasoning.format(
                    module_fqn, module_info[self.PRE_OBS_DATA_DIST_KEY], module_info[self.POST_OBS_DATA_DIST_KEY]
                )
                + benefit_str
            )

            # if we have a non-stationary input -> linear -> stationary we suggested static
            # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made
            if (
                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
            ):
                quantization_reasoning = (
                    quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
                )

            # format the overall suggestion string with the specific inputs
            module_suggestion_string = suggestion_string_template.format(
                module_fqn, quantization_type, quantization_reasoning
            )

            # append to overall suggestion
            dynamic_vs_static_string += module_suggestion_string

        if not modules_added:
            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"

        # return the string as well as the dictionary of information
        return (dynamic_vs_static_string, module_dynamic_static_info)
class InputWeightEqualizationDetector(DetectorBase):
    r"""
    Determines whether input-weight equalization can help improve quantization for certain modules.

    Specifically, this list of modules includes:
        linear
        conv

    Determines whether input-weight equalization is recommended based on the comp stat:
        s_c = sqrt(w_c/W)/sqrt(i_c/I)
        where:
            w_c is range of weight for channel c, W is range of weight over all channels
            i_c is range of input for channel c, I is range of input over all channels

    if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization

    Args:
        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
            Should be between 0 and 1 (both non-inclusive)
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization

    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
                                        nn.Conv1d,
                                        nn.Conv2d,
                                        nn.Conv3d,
                                        nnqat.Linear,
                                        nnqat.Conv1d,
                                        nnqat.Conv2d,
                                        nnqat.Conv3d}

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # weight / activation prefix for each of the below info
    WEIGHT_PREFIX = "weight_"
    ACTIVATION_PREFIX = "input_activation_"

    # string names for keys of info dictionaries
    PER_CHANNEL_MAX_KEY = "per_channel_max"
    PER_CHANNEL_MIN_KEY = "per_channel_min"
    GLOBAL_MAX_KEY = "global_max"
    GLOBAL_MIN_KEY = "global_min"

    # keys for return dict of recommendations
    RECOMMENDED_KEY = "input_weight_equalization_recommended"
    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
    THRESHOLD_KEY = "input_weight_threshold"
    CHANNEL_KEY = "input_weight_channel_axis"

    # default weight and info strings
    WEIGHT_STR = "weight"
    INPUT_STR = "input"

    # default for what ratio we recommend input weight
    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4

    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
        # consistent with the other detectors in this file
        super().__init__()

        # ensure passed in inputs are valid
        if ratio_threshold <= 0 or ratio_threshold >= 1:
            raise ValueError("Make sure threshold is > 0 and < 1")

        # initialize attributes based on args
        self.ratio_threshold: float = ratio_threshold
        self.ch_axis: int = ch_axis

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers.

        Args:
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # exact-type membership (deliberately NOT isinstance, so subclasses are excluded);
        # O(1) set lookup replaces the original per-element scan
        is_supported_type = type(module) in self.SUPPORTED_MODULES

        # this is check for observer insertion
        if insert:
            return is_supported_type

        # this is for report gen and we also need to check if it contains observers
        has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
        return is_supported_type and has_obs

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant.

        Args:
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

        # build one DetectorQConfigInfo per module, carrying the equalization verdict
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn, module_info in input_weight_equalization_info.items():
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # record whether input-weight equalization was recommended
            input_weight_recommended: bool = module_info[self.RECOMMENDED_KEY]
            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            linear layers
            conv layers

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # this detector only needs a pre-observer, watching the node's inputs
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector."""
        return "input_weight_equalization_detector"

    def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the input information for each observer and returns it.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns (str) to a dict with keys:
            "input_activation_per_channel_max" : maps to the per_channel max values
            "input_activation_per_channel_min" : maps to the per_channel min values
            "input_activation_global_max" : maps to the global max recorded
            "input_activation_global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        input_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # get pre observer for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

                input_info[fqn] = {
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
                }

        return input_info

    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "per_channel_max" : maps to the per_channel max values
            "per_channel_min" : maps to the per_channel min values
            "global_max" : maps to the global max recorded
            "global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # we don't need the actual observer, just the module weights
                weight = module.weight
                weight_dims = weight.size()

                # move the channel axis to the front so flattening groups values per channel
                permute_order = list(range(len(weight_dims)))
                permute_order[self.ch_axis] = 0
                permute_order[0] = self.ch_axis
                per_channel = weight.permute(permute_order)

                # match the default float dtype the observers use for min/max buffers
                per_channel = per_channel.to(torch.float)
                per_channel = torch.flatten(per_channel, start_dim=1)

                # one-shot per-channel min/max; the original seeded +/-inf tensors and
                # merged with torch.min/torch.max, which is equivalent for a single pass
                min_val, max_val = torch.aminmax(per_channel, dim=1)

                weight_info[fqn] = {
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
                }

        return weight_info

    def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor:
        r"""
        Takes in an info dict and calculates the s_c matrix.

        Args:
            info_dict (dict): A dictionary of either input or weight range info
            info_str (str): A str describing whether currently looking at weight or input info
                Either "weight" or "input"
            module_fqn (str): The fqn of the module we are looking at

        Returns a tensor of values, where each value is the s_c stat for a different channel

        Raises:
            ValueError: if the global range is 0 (constant value channel)
        """
        # get the prefix str for the kind of info we are ratio-ing
        prefix_str = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX

        per_channel_range = info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
        global_range = info_dict[prefix_str + self.GLOBAL_MAX_KEY] - info_dict[prefix_str + self.GLOBAL_MIN_KEY]

        # a zero global range would divide by zero; surface it to the user instead
        if global_range == 0:
            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
            raise ValueError(
                "The range of the {} data for module {} is 0, which means you have a constant value channel. {}".format(
                    info_str, module_fqn, range_zero_explanation
                )
            )

        ratio = per_channel_range / global_range

        return ratio

    def _generate_comparison_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]:
        r"""
        Takes in the information on the min and max values of the inputs and weights and:
            Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)

        Args:
            input_info (dict): A dict mapping each observer to input range information
            weight_info (dict): A dict mapping each observer to weight range information

        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
            Each value is a different s_c value for a different channel

        Raises:
            KeyError: if a module has input stats but no weight stats
        """
        # create return dictionary for each observer
        module_fqn_to_channel: Dict[str, torch.Tensor] = {}

        # for each module (both passed in dicts should have same keys)
        for module_fqn in input_info:

            # raise error if not in weight info
            if module_fqn not in weight_info:
                raise KeyError(f"Unable to find weight range stats for module {module_fqn}")

            # calculate the ratios of the weight info and input info
            weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn)
            input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn)

            # if mismatched size, because of grouping, we want to replicate weight enough times
            weight_channels = len(weight_ratio)
            input_channels = len(input_ratio)
            if weight_channels != input_channels:
                # we try to replicate
                assert input_channels % weight_channels == 0, "input channels should be divisible by weight channels."
                # get replication factor
                rep_factor: int = input_channels // weight_channels

                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n
                weight_ratio = weight_ratio.repeat(rep_factor)

            # calculate the s metric per channel
            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
            module_fqn_to_channel[module_fqn] = s

        # return compiled observer ratios
        return module_fqn_to_channel

    def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            input_info (dict): A dict mapping each module to input range information
            weight_info (dict): A dict mapping each module to weight range information
            comp_stats (dict): A dict mapping each module to its corresponding comp stat

        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
        """
        # store modules input weight equalization info
        input_weight_equalization_info: Dict[str, Dict] = {}

        # for each module we add separate set of suggestions
        for module_fqn in input_info:

            # get relevant info for this module
            mod_input_info: Dict = input_info[module_fqn]
            mod_weight_info: Dict = weight_info[module_fqn]
            mod_comp_stat: Dict = comp_stats[module_fqn]

            # decide if each channel should have input weight equalization or not
            channel_rec_vals: list = []

            for val in mod_comp_stat:
                float_rep: float = val.item()

                # recommend when s_c falls within [threshold, 1/threshold]
                recommended: bool = float_rep >= self.ratio_threshold and float_rep <= 1 / self.ratio_threshold
                channel_rec_vals.append(recommended)

            # build the return dict input
            # also unpack input and weight dicts into it
            input_weight_equalization_info[module_fqn] = {
                self.RECOMMENDED_KEY: channel_rec_vals,
                self.COMP_METRIC_KEY: mod_comp_stat,
                self.THRESHOLD_KEY: self.ratio_threshold,
                self.CHANNEL_KEY: self.ch_axis,
                **mod_input_info,
                **mod_weight_info,
            }

        # return our compiled info for each module
        return input_weight_equalization_info

    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReport Observer which records per channel information of input range
        It then uses the passed in weight info inconjunction to compute the desired ratio
        Finally, it gives suggestions based on this information for each module of interest

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of of whether input weight equalization is recommended for certain modules
            Dictionary mapping modules of interest to:
                whether input weight equalization is recommended
                their s_c metric compared to the threshold
                the threshold used to make the recommendation
                the channel used for recording data
                the input channel range info
                the weight channel range info
        """
        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

        # now we can generate report based on this information
        input_weight_string = "Input-Weight Equalization suggestions: \n"

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tWe suggest {} input weight equalization because {}\n"
        use_str = "to use"
        no_use_str = "to not use"
        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
        input_weight_non_benefit_reasoning = "{}/{} channels benefitting from input-weight equalization being applied."
        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"

        # added module check
        added_module: bool = False

        # compile the suggestion string
        for module_fqn, mod_info in input_weight_equalization_info.items():
            # we added at least 1 module
            added_module = True
            # add the module level description
            input_weight_string += module_suggestion_str.format(module_fqn, self.ch_axis)

            # gather info on how many channels would benefit from input weight equalization
            # (this is a list of per-channel bools, not a tensor)
            recommendation_per_channel = mod_info[self.RECOMMENDED_KEY]
            num_recs = sum(recommendation_per_channel)

            # recommend equalization only when enough channels would benefit
            if num_recs / len(recommendation_per_channel) >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO:
                input_benefit_formatted = input_weight_benefit_str.format(num_recs, len(recommendation_per_channel))
                channel_str = channel_suggestion_str.format(use_str, input_benefit_formatted)
                input_weight_string += channel_str
            else:
                non_benefit_reason_formatted = input_weight_non_benefit_reasoning.format(num_recs, len(recommendation_per_channel))
                non_benefit_str = input_weight_non_benefit_str.format(non_benefit_reason_formatted)
                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
                input_weight_string += channel_str

        # if no modules looked at, amend return string
        if not added_module:
            input_weight_string += "No applicable layers for suggestions. Only linear and conv valid.\n"

        # return a tuple with the string explanation and the compiled dict info
        return (input_weight_string, input_weight_equalization_info)
non-stationary distribution: + If the data is stationary, and there are significant outliers, then we want to flag them + We want to do this on a per channel basis for detecting outliers + + Determines whether activation data is flagged as outlier based on if data is stationary and: + p_r = avg(100th percentile / "reference_percentile"th percentile) + where: + p_r is average percentile ratio across all batches in the epoch + reference_percentile is a percentile values between 0 and 100 exclusive + + if p_r is above some threshold, then we consider the activations to have significant outliers + + Args: + ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations + Should be >= 1 + Default: 3.5 + reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile + Should be between 0 and 1 + Default: 0.975 + fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier + If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user + regardless of whether we detected outliers or not in channel to take a closer look at channel results + Should be between 0 and 1 + Default: 0.95 + ch_axis (int, optional): The channel axis being observed to determine input weight equalization + Default: 1 + + * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations + The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold + If it is significantly greater, then we consider it an outlier + This threshold was calculated based on the ratio of the percentiles in a normal distribution + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th 
percentile + Should be between 0 and 1 + The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing + + * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this + Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used + Should be between 0 and 1 + + * :attr:`ch_axis`: The channel axis being observed to determine outliers + + * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector + """ + + # names for the pre observers that are inserted + DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer" + + # pre activation prefix + INPUT_ACTIVATION_PREFIX = "input_activation_" + + # names for dict keys + OUTLIER_KEY = "outliers_detected" + NUM_BATCHES_KEY = "outlier_detection_batches_used" + IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches" + COMP_METRIC_KEY = "outlier_detection_percentile_ratios" + RATIO_THRES_KEY = "outlier_detection_ratio_threshold" + REF_PERCENTILE_KEY = "outlier_detection_reference_percentile" + CHANNEL_AXIS_KEY = "outlier_detection_channel_axis" + MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max" + CONSTANT_COUNTS_KEY = "constant_batch_counts" + + def __init__( + self, + ratio_threshold: float = 3.5, + reference_percentile: float = 0.975, + fraction_batches_used_threshold: float = 0.95, + ch_axis: int = 1, + ): + # initialize the variables of interest + self.ratio_threshold = ratio_threshold + + # make sure passed in percentile is valid + assert reference_percentile >= 0 and reference_percentile <= 1 + assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1 + self.reference_percentile = reference_percentile + self.fraction_batches_used_threshold = fraction_batches_used_threshold + self.ch_axis = ch_axis + + def get_detector_name(self) -> str: + 
r"""Returns the name of this detector""" + return "outlier_detector" + + def _supports_insertion(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for observers insertion + + Any module that doesn't have children and isn't an observer itself is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + # case for insertion of module + # check if the module has any children and isn't observer + num_children = len(list(module.children())) + return num_children == 0 and not _is_activation_post_process(module) + + def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]: + r""" Returns the DetectorQConfigInfo for each module_fqn relevant + Args + model (nn.Module or subclass): model to find observer insertion points + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to: + A DetectorQConfigInfo with the information to generate a QConfig for a specific module + """ + # currently doesn't do anything for outlier detector + return {} + + def _supports_report_gen(self, module: nn.Module) -> bool: + r"""Returns whether the given module is supported for report generation + + Any module that has a model report pre-observer is supported + + Args + module: The module to check and ensure is supported + + Returns True if the module is supported by observer, False otherwise + """ + return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]: + r""" Determines where observers need to be inserted for the Outlier Detector. + + For this detector, we want to place observers in front of supported layers. 
+ + Currently inserts observers for: + all layers that do not have children (leaf level layers) + + Args: + prepared_fx_model (GraphModule): The prepared Fx GraphModule + + Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with: + key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node) + key "observer_to_insert" -> the observer we wish to insert (ObserverBase) + key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer + key "observer_args" -> The arguments that are meant to be passed into the observer + """ + # observer for this detector is ModelReportObserver + obs_ctr = ModelReportObserver + + # return dict + obs_fqn_to_info: Dict[str, Dict[str, Any]] = {} + + for fqn, module in prepared_fx_model.named_modules(): + # check to see if module is of a supported type + if self._supports_insertion(module): + # if it's a supported type, we want to get node and add observer insert locations + targeted_node = self._get_targeting_node(prepared_fx_model, fqn) + + # add entry for pre-observer + pre_obs_fqn = fqn + "." 
+ self.DEFAULT_PRE_OBSERVER_NAME + + obs_fqn_to_info[pre_obs_fqn] = { + DETECTOR_TARGET_NODE_KEY: targeted_node, + DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis, comp_percentile=self.reference_percentile), + DETECTOR_IS_POST_OBS_KEY: False, + DETECTOR_OBS_ARGS_KEY: targeted_node.args, + } + + return obs_fqn_to_info + + def _calculate_outlier_info( + self, + percentile_ratios: torch.Tensor, + counted_batches: torch.Tensor, + total_batches: int, + ) -> Dict[str, List[bool]]: + r""" + Gives info on whether the percentile ratios calculated would be considered outliers + Also gives information on whether the collected data is statistically significant to make this claim + + Args: + percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer + counted_batches (torch.Tensor): The number of batches used for average calculation per tensor + total_batches (int): The total number of batches that passed through observer in this epoch + + Returns a dictionary mapping: + "outliers_detected" : list of bools per channel that are true if it is considered an outlier + "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold: + where o_r = counted_batches / total_batches + """ + outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []} + + # get both as flattened lists for easy mapping + ratios_list: List = percentile_ratios.tolist() + num_batches_list: List = counted_batches.tolist() + + # calculate whether channels were statistically significant + significant_size = [ + batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list + ] + outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size + + # calculate for each channel whether it's an outlier or not based on ratio + outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list] + outlier_dict[self.OUTLIER_KEY] = outlier_detected + + # return the dictionary with 
the two lists + return outlier_dict + + def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]: + r""" + Helper function for generate_detector_report that does the generation of the dictionary. + This process is done as specified in generate_detector_report documentation + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a dict mapping relevant module fqns to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # return dictionary mapping observer fqns to desired info + info_dict: Dict[str, Dict] = {} + + for fqn, module in model.named_modules(): + # if module is supported and it has a pre-observer + if self._supports_report_gen(module): + # get pre observer for the module + pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME) + + # get the number of batches and calculated ratio thresholds + num_batches: torch.Tensor = pre_obs.percentile_batches_tracked + average_ratios: torch.Tensor = pre_obs.average_percentile_ratio + channel_batch_cnts: torch.Tensor = pre_obs.constant_channels + total_batches: int = pre_obs.num_batches_tracked + + # also get the max values + max_vals: torch.Tensor = pre_obs.max_val + + # we have to specifically modify how we are recording negative ratio for pre-relu layers + for index, ratio_val in enumerate(average_ratios): + # check if we have a negative ratio + # a ratio might be negative if we have a situation where the 100th percentile is + # > 0 while the nth percentile is < 0, in which case this would not be 
detected + # as an outlier. Since we care more about magnitude, we make it positive. + if ratio_val.item() < 0: + # first make it positive + average_ratios[index] = -ratio_val + + if ratio_val.item() < 1: + # if it's less than 1 we have the flip it as well + average_ratios[index] = 1 / ratio_val + + outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches) + + # calculate whether ratios were outliers + info_dict[fqn] = { + self.CHANNEL_AXIS_KEY: self.ch_axis, + self.REF_PERCENTILE_KEY: self.reference_percentile, + self.RATIO_THRES_KEY: self.ratio_threshold, + self.COMP_METRIC_KEY: average_ratios, + self.NUM_BATCHES_KEY: num_batches, + self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY], + self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY], + self.CONSTANT_COUNTS_KEY: channel_batch_cnts, + self.MAX_VALS_KEY: max_vals + } + + return info_dict + + def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]: + r""" + Determines whether input weight equalization is appropriate for a given module. 
+ + Takes advantage of the ModelReport Observer which records the relevant percentile information + + Args: + model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers + + Returns a tuple with two elements: + String report of of whether there are outliers in the activations around certain modules + Dictionary mapping modules of interest to: + whether there were outliers found in activation before + the number of batches used for each channel + whether fraction of applicable batches used is above fraction_batches_used_threshold + their p_r metric compared to the threshold + the threshold used to make the recommendation + the reference_percentile used to make the recommendation + the channel axis used to determine individual channels + the constant batch counts per channel + the per channel max values + """ + # generate the information dictionary of outlier information + info_dict = self._generate_info_dict(model) + + # now we can generate report based on this information + outlier_string = "Outlier detection report: \n" + + # added module check + added_module: bool = False + + # some strings to be formatted depending on module we are adding + module_suggestion_str = "For Module {} looked at with axis {}: \n" + channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n" + channel_max_value_str = "a max value across all batches of {}" + note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results." + note_distribution = "stationary distributions" + note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary" + + # suggestion for constant batch check since that can make it no outliers + constant_str = "\tFor channel {}, we found {} constant value batches. {}\n" + constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why." 
+ + # compile the suggestion string + for module_fqn in info_dict: + # get module specific info + mod_info: Dict[str, Any] = info_dict[module_fqn] + # check to see if we already added high level model desc + added_model_desc = False + # look at each individual channel and add a suggestion + for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]): + if outlier_detected: + # we found at least 1 outlier + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis) + added_model_desc = True + + # we mark that we found at least one outlier + added_module = True + max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index]) + channel_str = channel_suggestion_str.format(index, max_value_found_str) + outlier_string += channel_str + + # also check if we found constant batch + if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0: + # make sure we add a module level highlight. + if not added_model_desc: + # add the module level description + outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis) + added_model_desc = True + + constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index] + formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion) + outlier_string += formatted_str + # we also added at least one thing to description + added_module = True + + + # if found outlier, give suggestion, else give default response + if added_module: + # compose the note string + note_composed = note_string.format(note_distribution, note_rec) + outlier_string += note_composed + else: + outlier_string += "There were no outliers found in the activations.\n" + + return (outlier_string, info_dict) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5463862aa1cd41dde5b2713ee391429824d13bb2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py @@ -0,0 +1,666 @@ +import torch +from typing import Any, Set, Dict, List, Tuple, OrderedDict +from collections import OrderedDict as OrdDict + +# try to import tablate +got_tabulate = True +try: + from tabulate import tabulate +except ImportError: + got_tabulate = False + + +# var to see if we could import matplotlib +got_matplotlib = True +try: + import matplotlib.pyplot as plt +except ImportError: + got_matplotlib = False + +class ModelReportVisualizer: + r""" + The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics + that were generated by the ModelReport API. However, at a higher level, the class aims to provide + some level of visualization of statistics to PyTorch in order to make it easier to parse data and + diagnose any potential issues with data or a specific model. With respect to the visualizations, + the ModelReportVisualizer class currently supports several methods of visualizing data. 
+ + Supported Visualization Methods Include: + - Table format + - Plot format (line graph) + - Histogram format + + For all of the existing visualization methods, there is the option to filter data based on: + - A module fqn prefix + - Feature [required for the plot and histogram] + + * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below + Ensure sure that features that are the same across different report contain the same name + Ensure that objects representing the same features are the same type / dimension (where applicable) + + Note: + Currently, the ModelReportVisualizer class supports visualization of data generated by the + ModelReport class. However, this structure is extensible and should allow the visualization of + other information as long as the information is structured in the following general format: + + Report Structure + -- module_fqn [module with attached detectors] + | + -- feature keys [not every detector extracts same information] + [same collected info has same keys, unless can be specific to detector] + + + The goal behind the class is that the generated visualizations can be used in conjunction with the generated + report for people to get a better understanding of issues and what the fix might be. It is also just to provide + a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as + that grows in size. + + General Use Flow Expected + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance + 7.) 
Use instance to view different views of data as desired, applying filters as needed + 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram + + """ + + # keys for table dict + TABLE_TENSOR_KEY = "tensor_level_info" + TABLE_CHANNEL_KEY = "channel_level_info" + + # Constants for header vals + NUM_NON_FEATURE_TENSOR_HEADERS = 2 + NUM_NON_FEATURE_CHANNEL_HEADERS = 3 + + # Constants for row index in header + CHANNEL_NUM_INDEX = 2 + + def __init__(self, generated_reports: OrderedDict[str, Any]): + r""" + Initializes the ModelReportVisualizer instance with the necessary reports. + + Args: + generated_reports (Dict[str, Any]): The reports generated by the ModelReport class + can also be a dictionary generated in another manner, as long as format is same + """ + self.generated_reports = generated_reports + + def get_all_unique_module_fqns(self) -> Set[str]: + r""" + The purpose of this method is to provide a user the set of all module_fqns so that if + they wish to use some of the filtering capabilities of the ModelReportVisualizer class, + they don't need to manually parse the generated_reports dictionary to get this information. + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + # returns the keys of the ordered dict + return set(self.generated_reports.keys()) + + def get_all_unique_feature_names(self, plottable_features_only: bool = True) -> Set[str]: + r""" + The purpose of this method is to provide a user the set of all feature names so that if + they wish to use the filtering capabilities of the generate_table_view(), or use either of + the generate_plot_view() or generate_histogram_view(), they don't need to manually parse + the generated_reports dictionary to get this information. 
+ + Args: + plottable_features_only (bool): True if the user is only looking for plottable features, + False otherwise + plottable features are those that are tensor values + Default: True (only return those feature names that are plottable) + + Returns all the unique module fqns present in the reports the ModelReportVisualizer + instance was initialized with. + """ + unique_feature_names = set() + for module_fqn in self.generated_reports: + # get dict of the features + feature_dict: Dict[str, Any] = self.generated_reports[module_fqn] + + # loop through features + for feature_name in feature_dict: + # if we need plottable, ensure type of val is tensor + if not plottable_features_only or type(feature_dict[feature_name]) == torch.Tensor: + unique_feature_names.add(feature_name) + + # return our compiled set of unique feature names + return unique_feature_names + + def _get_filtered_data(self, feature_filter: str, module_fqn_filter: str) -> OrderedDict[str, Any]: + r""" + Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed. + + Args: + feature_filter (str): The feature filter, if we want to filter the set of data to only include + a certain set of features that include feature_filter + If feature = "", then we do not filter based on any features + module_fqn_filter (str): The filter on prefix for the module fqn. 
All modules that have fqn with + this prefix will be included + If module_fqn_filter = "" we do not filter based on module fqn, and include all modules + + First, the data is filtered based on module_fqn, and then filtered based on feature + Returns an OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + """ + # create return dict + filtered_dict: OrderedDict[str, Any] = OrdDict() + + for module_fqn in self.generated_reports: + # first filter based on module + if module_fqn_filter == "" or module_fqn_filter in module_fqn: + # create entry for module and loop through features + filtered_dict[module_fqn] = {} + module_reports = self.generated_reports[module_fqn] + for feature_name in module_reports: + # check if filtering on features and do so if desired + if feature_filter == "" or feature_filter in feature_name: + filtered_dict[module_fqn][feature_name] = module_reports[feature_name] + + # we have populated the filtered dict, and must return it + + return filtered_dict + + def _generate_tensor_table( + self, + filtered_data: OrderedDict[str, Dict[str, Any]], + tensor_features: List[str] + ) -> Tuple[List, List]: + r""" + Takes in the filtered data and features list and generates the tensor headers and table + + Currently meant to generate the headers and table for both the tensor information. 
+ + Args: + filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + tensor_features (List[str]): A list of the tensor level features + + Returns a tuple with: + A list of the headers of the tensor table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the tensor information table + tensor_table: List[List[Any]] = [] + tensor_headers: List[str] = [] + + # append the table row to the table only if we have features + if len(tensor_features) > 0: + # now we add all the data + for index, module_fqn in enumerate(filtered_data): + # we make a new row for the tensor table + tensor_table_row = [index, module_fqn] + for feature in tensor_features: + # we iterate in same order of added features + + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if isinstance(feature_val, torch.Tensor): + feature_val = feature_val.item() + + # we add to our list of values + tensor_table_row.append(feature_val) + + tensor_table.append(tensor_table_row) + + # add row of headers of we actually have something, otherwise just empty + if len(tensor_table) != 0: + tensor_headers = ["idx", "layer_fqn"] + tensor_features + + return (tensor_headers, tensor_table) + + def _generate_channels_table( + self, + filtered_data: OrderedDict[str, Any], + channel_features: List[str], + num_channels: int + ) -> Tuple[List, List]: + r""" + Takes in the filtered data and features list and generates the channels headers and table + + Currently meant to generate the headers and table for both the channels information. 
+ + Args: + filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping: + module_fqns -> feature_names -> values + channel_features (List[str]): A list of the channel level features + num_channels (int): Number of channels in the channel data + + Returns a tuple with: + A list of the headers of the channel table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + """ + # now we compose the table for the channel information table + channel_table: List[List[Any]] = [] + channel_headers: List[str] = [] + + # counter to keep track of number of entries in + channel_table_entry_counter: int = 0 + + if len(channel_features) > 0: + # now we add all channel data + for module_fqn in filtered_data: + # we iterate over all channels + for channel in range(num_channels): + # we make a new row for the channel + new_channel_row = [channel_table_entry_counter, module_fqn, channel] + for feature in channel_features: + if feature in filtered_data[module_fqn]: + # add value if applicable to module + feature_val = filtered_data[module_fqn][feature][channel] + else: + # add that it is not applicable + feature_val = "Not Applicable" + + # if it's a tensor we want to extract val + if type(feature_val) is torch.Tensor: + feature_val = feature_val.item() + + # add value to channel specific row + new_channel_row.append(feature_val) + + # add to table and increment row index counter + channel_table.append(new_channel_row) + channel_table_entry_counter += 1 + + # add row of headers of we actually have something, otherwise just empty + if len(channel_table) != 0: + channel_headers = ["idx", "layer_fqn", "channel"] + channel_features + + return (channel_headers, channel_table) + + def generate_filtered_tables(self, feature_filter: str = "", module_fqn_filter: str = "") -> Dict[str, Tuple[List, List]]: + r""" + Takes in optional filter values and generates 
two tables with desired information. + + The generated tables are presented in both a list-of-lists format + + The reason for the two tables are that they handle different things: + 1.) the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Returns a dictionary with two keys: + (Dict[str, Tuple[List, List]]) A dict containing two keys: + "tensor_level_info", "channel_level_info" + Each key maps to a tuple with: + A list of the headers of each table + A list of lists containing the table information row by row + The 0th index row will contain the headers of the columns + The rest of the rows will contain data + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_filtered_tables( + ... feature_filter = "per_channel_min", + ... module_fqn_filter = "block1" + ... 
) # generates table with per_channel_min info for all modules in block 1 of the model + """ + # first get the filtered data + filtered_data: OrderedDict[str, Any] = self._get_filtered_data(feature_filter, module_fqn_filter) + + # now we split into tensor and per-channel data + tensor_features: Set[str] = set() + channel_features: Set[str] = set() + + # keep track of the number of channels we have + num_channels: int = 0 + + for module_fqn in filtered_data: + for feature_name in filtered_data[module_fqn]: + # get the data for that specific feature + feature_data = filtered_data[module_fqn][feature_name] + + # check if not zero dim tensor + is_tensor: bool = isinstance(feature_data, torch.Tensor) + is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0 + + if is_not_zero_dim or isinstance(feature_data, list): + # works means per channel + channel_features.add(feature_name) + num_channels = len(feature_data) + else: + # means is per-tensor + tensor_features.add(feature_name) + + # we make them lists for iteration purposes + tensor_features_list: List[str] = sorted(tensor_features) + channel_features_list: List[str] = sorted(channel_features) + + # get the tensor info + tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list) + + # get the channel info + channel_headers, channel_table = self._generate_channels_table( + filtered_data, channel_features_list, num_channels + ) + + # let's now create the dictionary to return + table_dict = { + self.TABLE_TENSOR_KEY : (tensor_headers, tensor_table), + self.TABLE_CHANNEL_KEY : (channel_headers, channel_table) + } + + # return the two tables + return table_dict + + def generate_table_visualization(self, feature_filter: str = "", module_fqn_filter: str = ""): + r""" + Takes in optional filter values and prints out formatted tables of the information. + + The reason for the two tables printed out instead of one large one are that they handle different things: + 1.) 
the first table handles all tensor level information + 2.) the second table handles and displays all channel based information + + The reasoning for this is that having all the info in one table can make it ambiguous which collected + statistics are global, and which are actually per-channel, so it's better to split it up into two + tables. This also makes the information much easier to digest given the plethora of statistics collected + + Tensor table columns: + idx layer_fqn feature_1 feature_2 feature_3 .... feature_n + ---- --------- --------- --------- --------- --------- + + Per-Channel table columns: + + idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n + ---- --------- ------- --------- --------- --------- --------- + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_table_visualization( + ... feature_filter = "per_channel_min", + ... module_fqn_filter = "block1" + ... 
) + >>> # prints out neatly formatted table with per_channel_min info + >>> # for all modules in block 1 of the model + """ + # see if we got tabulate + if not got_tabulate: + print("Make sure to install tabulate and try again.") + return None + + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # get the table string and print it out + # now we have populated the tables for each one + # let's create the strings to be returned + table_str = "" + # the tables will have some headers columns that are non-feature + # ex. table index, module name, channel index, etc. + # we want to look at header columns for features, that come after those headers + if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS: + # if we have at least one tensor level feature to be added we add tensor table + table_str += "Tensor Level Information \n" + table_str += tabulate(tensor_table, headers=tensor_headers) + if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS: + # if we have at least one channel level feature to be added we add tensor table + table_str += "\n\n Channel Level Information \n" + table_str += tabulate(channel_table, headers=channel_headers) + + # if no features at all, let user know + if table_str == "": + table_str = "No data points to generate table with." 
+ + print(table_str) + + def _get_plottable_data(self, feature_filter: str, module_fqn_filter: str) -> Tuple[List, List[List], bool]: + r""" + Takes in the feature filters and module filters and outputs the x and y data for plotting + + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str): Only includes modules that contains this string + + Returns a tuple of three elements + The first is a list containing relevant x-axis data + The second is a list containing the corresponding y-axis data + If the data is per channel + """ + # get the table dict and the specific tables of interest + table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter) + tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY] + channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY] + + # make sure it is only 1 feature that is being plotted + # get the number of features in each of these + tensor_info_features_count = len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + channel_info_features_count = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + + # see if valid tensor or channel plot + is_valid_per_tensor_plot: bool = tensor_info_features_count == 1 + is_valid_per_channel_plot: bool = channel_info_features_count == 1 + + # offset should either be one of tensor or channel table or neither + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS + table = tensor_table + + # if a per_channel plot, we have different offset and table + if is_valid_per_channel_plot: + feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS + table = channel_table + + x_data: List = [] + y_data: List[List] = [] + # the feature will either be a tensor feature or channel feature + if is_valid_per_tensor_plot: + for table_row_num, row in enumerate(table): + # get x_value to append + 
x_val_to_append = table_row_num + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if not type(row_value) == str: + x_data.append(x_val_to_append) + y_data.append(row_value) + elif is_valid_per_channel_plot: + # gather the x_data and multiple y_data + # calculate the number of channels + num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1 + for channel in range(num_channels): + y_data.append([]) # separate data list per channel + + for table_row_num, row in enumerate(table): + # get x_value to append + x_val_to_append = table_row_num + current_channel = row[self.CHANNEL_NUM_INDEX] # initially chose current channel + new_module_index: int = table_row_num // num_channels + x_val_to_append = new_module_index + + # the index of the feature will the 0 + num non feature columns + tensor_feature_index = feature_column_offset + row_value = row[tensor_feature_index] + if not type(row_value) == str: + # only append if new index we are appending + if len(x_data) == 0 or x_data[-1] != x_val_to_append: + x_data.append(x_val_to_append) + + # append value for that channel + y_data[current_channel].append(row_value) + else: + # more than one feature was chosen + error_str = "Make sure to pick only a single feature with your filter to plot a graph." + error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names." + error_str += " Pick one of those features to plot." + raise ValueError(error_str) + + # return x, y values, and if data is per-channel + return (x_data, y_data, is_valid_per_channel_plot) + + def generate_plot_visualization(self, feature_filter: str, module_fqn_filter: str = ""): + r""" + Takes in a feature and optional module_filter and plots of the desired data. + + For per channel features, it averages the value across the channels and plots a point + per module. 
The reason for this is that for models with hundreds of channels, it can + be hard to differentiate one channel line from another, and so the point of generating + a single average point per module is to give a sense of general trends that encourage + further deep dives. + + Note: + Only features in the report that have tensor value data are plottable by this class + When the tensor information is plotted, it will plot: + idx as the x val, feature value as the y_val + When the channel information is plotted, it will plot: + the first idx of each module as the x val, feature value as the y_val [for each channel] + The reason for this is that we want to be able to compare values across the + channels for same layer, and it will be hard if values are staggered by idx + This means each module is represented by only 1 x value + Args: + feature_filter (str): Filters the features presented to only those that + contain this filter substring + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + + Example Use: + >>> # xdoctest: +SKIP("undefined variables") + >>> mod_report_visualizer.generate_plot_visualization( + ... feature_filter = "per_channel_min", + ... module_fqn_filter = "block1" + ... 
) + >>> # outputs line plot of per_channel_min information for all + >>> # modules in block1 of model each channel gets it's own line, + >>> # and it's plotted across the in-order modules on the x-axis + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter) + + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_ylabel(feature_filter) + ax.set_title(feature_filter + " Plot") + plt.xticks(x_data) # only show ticks for actual points + + if data_per_channel: + ax.set_xlabel("First idx of module") + # set the legend as well + # plot a single line that is average of the channel values + num_modules = len(y_data[0]) # all y_data have same length, so get num modules + num_channels = len(y_data) # we want num channels to be able to calculate average later + + avg_vals = [sum(y_data[:][index]) / num_channels for index in range(num_modules)] + + # plot the three things we measured + ax.plot(x_data, avg_vals, label=f"Average Value Across {num_channels} Channels") + ax.legend(loc='upper right') + else: + ax.set_xlabel("idx") + ax.plot(x_data, y_data) + + # actually show the plot + plt.show() + + def generate_histogram_visualization(self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10): + r""" + Takes in a feature and optional module_filter and plots the histogram of desired data. 
+ + Note: + Only features in the report that have tensor value data can be viewed as a histogram + If you want to plot a histogram from all the channel values of a specific feature for + a specific model, make sure to specify both the model and the feature properly + in the filters and you should be able to see a distribution of the channel data + + Args: + feature_filter (str, optional): Filters the features presented to only those that + contain this filter substring + Default = "", results in all the features being printed + module_fqn_filter (str, optional): Only includes modules that contains this string + Default = "", results in all the modules in the reports to be visible in the table + num_bins (int, optional): The number of bins to create the histogram with + Default = 10, the values will be split into 10 equal sized bins + + Example Use: + >>> # xdoctest: +SKIP + >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization( + ... feature_filter = "per_channel_min", + ... module_fqn_filter = "block1" + ... 
) + # outputs histogram of per_channel_min information for all modules in block1 of model + information is gathered across all channels for all modules in block 1 for the + per_channel_min and is displayed in a histogram of equally sized bins + """ + # checks if we have matplotlib and let's user know to install it if don't + if not got_matplotlib: + print("make sure to install matplotlib and try again.") + return None + + # get the x and y data and if per channel + x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter) + + # for histogram, we just care about plotting the y data + # plot based on whether data is per channel or not + ax = plt.subplot() + ax.set_xlabel(feature_filter) + ax.set_ylabel("Frequency") + ax.set_title(feature_filter + " Histogram") + + if data_per_channel: + # set the legend as well + # combine all the data + all_data = [] + for channel_info in y_data: + all_data.extend(channel_info) + + val, bins, _ = plt.hist( + all_data, + bins=num_bins, + stacked=True, + rwidth=0.8, + ) + plt.xticks(bins) + else: + val, bins, _ = plt.hist( + y_data, + bins=num_bins, + stacked=False, + rwidth=0.8, + ) + plt.xticks(bins) + + plt.show() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..e70040f7e6492fcf0d8abb29d3b07e48669adc91 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py @@ -0,0 +1,197 @@ +from abc import ABC +from typing import Callable, Dict, List, Optional, Type + +import torch + +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, + ObservationType, +) +from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls +from torch.fx.graph import 
Node + +from .utils import all_node_args_have_no_tensors + + +__all__ = [ + "QuantizeHandler", + "BinaryOpQuantizeHandler", + "CatQuantizeHandler", + "ConvReluQuantizeHandler", + "LinearReLUQuantizeHandler", + "BatchNormQuantizeHandler", + "EmbeddingQuantizeHandler", + "RNNDynamicQuantizeHandler", + "DefaultNodeQuantizeHandler", + "FixedQParamsOpQuantizeHandler", + "CopyNodeQuantizeHandler", + "GeneralTensorShapeOpQuantizeHandler", + "CustomModuleQuantizeHandler", + "StandaloneModuleQuantizeHandler", +] + +def _default_root_node_getter(node_pattern): + if node_pattern is None: + return node_pattern + while not isinstance(node_pattern, Node): + node_pattern = node_pattern[-1] + return node_pattern + +# Base Pattern Handler +class QuantizeHandler(ABC): # noqa: B024 + """ Base handler class for the quantizer patterns + """ + def __init__( + self, + node_pattern: NodePattern, + modules: Dict[str, torch.nn.Module], + root_node_getter: Optional[Callable] = None, + is_custom_module=False, + is_standalone_module=False): + """ Records pattern information in __init__, which will be used + in convert + """ + self.node_pattern = node_pattern + self.modules = modules + if root_node_getter is None: + root_node_getter = _default_root_node_getter + self.root_node = root_node_getter(node_pattern) + self.is_custom_module_ = is_custom_module + self.is_standalone_module_ = is_standalone_module + self.num_tensor_args = 0 + # determine how many of the first two args are Tensors (versus scalars) + # this distinguishes things like "x + y" from "x + 2" or "2 + x" + if isinstance(self.root_node, Node): + cache_for_no_tensor_check: Dict[Node, bool] = {} + for arg_idx in range(len(self.root_node.args)): + arg = self.root_node.args[arg_idx] + if isinstance(arg, Node) and ( + not all_node_args_have_no_tensors( + arg, self.modules, cache_for_no_tensor_check)): + self.num_tensor_args += 1 + + def is_general_tensor_value_op(self) -> bool: + """ + Returns True if the operator works for both 
floating point and + quantized input, and does some computation based on the input Tensor, + or the ops that only re-arranges the Tensor values or query some metadata + about the Tensor + so we need to insert observer/fake_quant for the output of the + operator (same observer instance as input) + since the distribution of values is different for input and output + Tensors (for HistogramObserver) while they share the same quantization + parameters + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + return False + + def is_custom_module(self): + return self.is_custom_module_ + + def is_standalone_module(self): + return self.is_standalone_module_ + +def _get_quantize_handler_cls( + observation_type: ObservationType, + dtype_configs: List[DTypeConfig], + num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> Type[QuantizeHandler]: + """ + Return a configurable QuantizeHandler that matches the given specifications from the backend. 
+ """ + + class ConfigurableQuantizeHandler(QuantizeHandler): + def __init__( + self, + node_pattern: NodePattern, + modules: Dict[str, torch.nn.Module], + root_node_getter: Optional[Callable] = None): + super().__init__(node_pattern, modules, root_node_getter) + if num_tensor_args_to_observation_type: + assert self.num_tensor_args in num_tensor_args_to_observation_type, \ + f"Must provide observation_type config for tensor number {self.num_tensor_args}" \ + f" in num_tensor_args_to_observation_type for {node_pattern}" + self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args] + else: + self.observation_type = observation_type + self.dtype_configs = dtype_configs + + def is_general_tensor_value_op(self) -> bool: + return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + + return ConfigurableQuantizeHandler + +def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]: + """ + Note: Quantize handler is just a holder for some check methods like + (should_insert_observer_for_output), maybe this can be a enum as well, + we can refactor this after we convert the path for fbgemm/qnnpack fully to the + new path, this is not exposed to backend developers + """ + pattern_to_quantize_handlers = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + observation_type = config.observation_type + dtype_configs = config.dtype_configs + num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type + pattern_to_quantize_handlers[pattern] = \ + _get_quantize_handler_cls( + observation_type, + dtype_configs, + num_tensor_args_to_observation_type) + return pattern_to_quantize_handlers + +# TODO: remove this class, this is still exposed in torch.ao.quantization +# but we should be able to break bc +class BinaryOpQuantizeHandler(QuantizeHandler): + pass + +class CatQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class 
ConvReluQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class LinearReLUQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class BatchNormQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class EmbeddingQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class RNNDynamicQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove this class +class DefaultNodeQuantizeHandler(QuantizeHandler): + """ Common quantized op, first input and first output will be quantized + """ + pass + +# TODO: remove this class +class FixedQParamsOpQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove +class CopyNodeQuantizeHandler(QuantizeHandler): + pass + +# TODO: remove +class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler): + pass + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class CustomModuleQuantizeHandler(QuantizeHandler): + pass + +# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated +class StandaloneModuleQuantizeHandler(QuantizeHandler): + pass diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38176c5907d940071d06bd8c4fe56dea771d1a5b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py @@ -0,0 +1,885 @@ +import copy +import torch +import torch.nn as nn +from torch.ao.quantization import ( + QConfigAny, + QuantType, +) +from torch.ao.quantization.backend_config import ( + DTypeWithConstraints, +) +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( 
+ float16_static_qconfig, + float16_dynamic_qconfig, + qconfig_equals, +) +from torch.ao.quantization.stubs import DeQuantStub +from torch.ao.quantization.utils import ( + activation_is_statically_quantized, +) +from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.qconfig_mapping import QConfigMapping + +from torch.fx import GraphModule, map_arg + +from torch.fx.graph import ( + Graph, + Node, +) +from .custom_config import PrepareCustomConfig +# importing the lib so that the quantized_decomposed ops are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 + +from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type +from dataclasses import dataclass +from collections import namedtuple +import operator +import warnings + +# TODO: revisit this list. Many helper methods shouldn't be public +__all__ = [ + "all_node_args_except_first", + "all_node_args_have_no_tensors", + "assert_and_get_unique_device", + "collect_producer_nodes", + "create_getattr_from_value", + "create_node_from_old_node_preserve_meta", + "EMPTY_ARG_DICT", + "get_custom_module_class_keys", + "get_linear_prepack_op_for_dtype", + "get_new_attr_name_with_prefix", + "get_non_observable_arg_indexes_and_types", + "get_qconv_prepack_op", + "get_skipped_module_name_and_classes", + "graph_module_from_producer_nodes", + "maybe_get_next_module", + "NodeInfo", + "node_arg_is_bias", + "node_arg_is_weight", + "NON_OBSERVABLE_ARG_DICT", + "NON_QUANTIZABLE_WEIGHT_OPS", + "return_arg_list", + "ObservedGraphModuleAttrs", +] + +NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm} + +@dataclass +class ObservedGraphModuleAttrs: + node_name_to_qconfig: Dict[str, QConfigAny] + node_name_to_scope: Dict[str, Tuple[str, type]] + prepare_custom_config: PrepareCustomConfig + equalization_node_name_to_qconfig: Dict[str, Any] + qconfig_mapping: QConfigMapping + 
is_qat: bool + observed_node_names: Set[str] + is_observed_standalone_module: bool = False + standalone_module_input_quantized_idxs: Optional[List[int]] = None + standalone_module_output_quantized_idxs: Optional[List[int]] = None + +def node_arg_is_weight(node: Node, arg: Any) -> bool: + """Returns if node arg is weight""" + weight_index = None + if "target_dtype_info" in node.meta: + weight_index = node.meta["target_dtype_info"].get("weight_index", None) + if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg: + return True + return node.kwargs.get("weight") is arg + +def node_arg_is_bias(node: Node, arg: Any) -> bool: + """Returns if node arg is bias""" + bias_index = None + if "target_dtype_info" in node.meta: + bias_index = node.meta["target_dtype_info"].get("bias_index", None) + if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg: + return True + return node.kwargs.get("bias") is arg + +def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]: + r""" Get all the unique custom module keys in the custom config dict + e.g. 
+ Input: + { + QuantType.STATIC: { + CustomModule1: ObservedCustomModule + }, + QuantType.DYNAMIC: { + CustomModule2: DynamicObservedCustomModule + }, + QuantType.WEIGHT_ONLY: { + CustomModule3: WeightOnlyObservedCustomModule + }, + } + + Output: + # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes : Set[Any] = set() + for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]: + quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {}) + quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) + +def get_linear_prepack_op_for_dtype(dtype): + if dtype == torch.float16: + return torch.ops.quantized.linear_prepack_fp16 + elif dtype == torch.qint8: + return torch.ops.quantized.linear_prepack + else: + raise Exception("can't get linear prepack op for dtype:", dtype) + +def get_qconv_prepack_op(conv_op: Callable) -> Callable: + prepack_ops = { + torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack, + torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack, + torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack, + torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack, + torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack, + torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack, + } + prepack_op = prepack_ops.get(conv_op, None) + assert prepack_op, f"Didn't find prepack op for {conv_op}" + return prepack_op + +# Returns a function that can get a new attribute name for module with given +# prefix, for example, +# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer') +# >> new_name = get_new_observer_name(module) +# new_name will be an unused 
attribute name on module, e.g. `_observer_1` +def get_new_attr_name_with_prefix(prefix: str) -> Callable: + prefix = prefix.replace(".", "_") + + def get_new_attr_name(module: torch.nn.Module): + def get_attr_name(i: int): + return prefix + str(i) + i = 0 + attr_name = get_attr_name(i) + while hasattr(module, attr_name): + i += 1 + attr_name = get_attr_name(i) + return attr_name + return get_new_attr_name + +def collect_producer_nodes(node: Node) -> Optional[List[Node]]: + r''' Starting from a target node, trace back until we hit inpu or + getattr node. This is used to extract the chain of operators + starting from getattr to the target node, for example + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + collect_producer_nodes(observed) will either return a list of nodes that + produces the observed node or None if we can't extract a self contained + graph without free variables(inputs of the forward function). + ''' + nodes = [node] + frontier = [node] + while frontier: + node = frontier.pop() + all_args = list(node.args) + list(node.kwargs.values()) + for arg in all_args: + if not isinstance(arg, Node): + continue + if arg.op == 'placeholder': + # hit input, can't fold in this case + return None + nodes.append(arg) + if not (arg.op == 'call_function' and arg.target == getattr): + frontier.append(arg) + return nodes + +def graph_module_from_producer_nodes( + root: GraphModule, producer_nodes: List[Node]) -> GraphModule: + r''' Construct a graph module from extracted producer nodes + from `collect_producer_nodes` function + Args: + root: the root module for the original graph + producer_nodes: a list of nodes we use to construct the graph + Return: + A graph module constructed from the producer nodes + ''' + assert len(producer_nodes) > 0, 'list of producer nodes can not be empty' + # since we traced back from node to getattr + producer_nodes.reverse() + graph = Graph() + env: Dict[Any, Any] = {} + + def load_arg(a): + 
return map_arg(a, lambda node: env[node]) + for producer_node in producer_nodes: + env[producer_node] = graph.node_copy(producer_node, load_arg) + graph.output(load_arg(producer_nodes[-1])) + graph_module = GraphModule(root, graph) + return graph_module + +def assert_and_get_unique_device(module: torch.nn.Module) -> Any: + """ + Returns the unique device for a module, or None if no device is found. + Throws an error if multiple devices are detected. + """ + devices = {p.device for p in module.parameters()} | \ + {p.device for p in module.buffers()} + """ + As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564 + """ + if {torch.device("cpu"), torch.device("meta")} == devices: + warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.") + devices = {torch.device("cpu")} + "" + assert len(devices) <= 1, ( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + return device + +def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node: + """ + Given a value of any type, creates a getattr node corresponding to the value and + registers the value as a buffer to the module. + """ + get_new_attr_name = get_new_attr_name_with_prefix(prefix) + attr_name = get_new_attr_name(module) + device = assert_and_get_unique_device(module) + new_value = value.clone().detach() if isinstance(value, torch.Tensor) \ + else torch.tensor(value, device=device) + module.register_buffer(attr_name, new_value) + # Create get_attr with value + attr_node = graph.create_node("get_attr", attr_name) + return attr_node + +def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool: + """ + If we know for sure that all of this node's args have no + tensors (are primitives), return True. 
If we either + find a tensor or are not sure, return False. Note: this + function is not exact. + """ + if cache and node in cache: + return cache[node] + + result = False # will be overwritten + if not isinstance(node, Node): + result = True + elif node.op == 'placeholder': + result = False + elif node.op == 'call_module': + assert isinstance(node.target, str) + if _is_activation_post_process(modules[node.target]): + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == 'call_module': + result = False + elif node.op == 'call_function' and node.target is operator.getitem: + result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] + elif node.op == 'get_attr': + result = False + elif node.target is getattr and node.args[1] in ['ndim', 'shape']: + # x1 = x0.ndim + result = True + elif node.op == 'call_method' and node.target == 'size': + # x1 = x0.size(0) + result = True + else: + found_one_tensor = False + for arg in node.args: + if isinstance(arg, list): + for list_el in arg: + if isinstance(list_el, Node): + this_list_el_args_have_no_tensors = \ + all_node_args_have_no_tensors(list_el, modules, cache) + found_one_tensor = found_one_tensor or \ + (not this_list_el_args_have_no_tensors) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. + # TODO(future PR): remove this entire function and + # change to dtype inference without recursion. + if found_one_tensor: + result = not found_one_tensor + if cache: + cache[node] = result + return result + elif isinstance(arg, int): + pass + else: + if isinstance(arg, Node): + this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache) + found_one_tensor = found_one_tensor or \ + (not this_arg_args_have_no_tensors) + # If found_one_tensor is True, there is no point in + # recursing further as the end result will always + # be True. 
def all_node_args_except_first(node: Node) -> List[int]:
    """Return the indices of every positional arg of `node` except the first."""
    return list(range(1, len(node.args)))

def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """Build a function that maps a node to the subset of `arg_indices`
    that actually exist on that node's args."""
    def valid_indices(node: Node) -> List[int]:
        return [idx for idx in arg_indices if idx < len(node.args)]
    return valid_indices

NodeInfo = namedtuple("NodeInfo", "op target")

# Identifies which arg positions of a node hold non-tensor values, so that
# they can be propagated correctly; inserting observers for them would
# cause errors.
NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill"): {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2]),
    },
    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
    # NOTE(review): function-object targets under "call_method" look
    # unreachable (call_method targets are strings in torch.fx) — kept
    # as-is to preserve behavior; confirm upstream intent.
    NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
    NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
}

EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}

def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """Look up, keyed by (op, target), the non-float-tensor arg-index
    retrievers for `node`; returns an empty dict when none are registered."""
    return NON_OBSERVABLE_ARG_DICT.get(NodeInfo(node.op, node.target), EMPTY_ARG_DICT)

def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """Return the first user of `node` that matches the requested module
    type or functional target, if any.

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """
    for user in node.users.keys():
        module_match = (
            user.op == 'call_module'
            and target_module_type is not None
            and isinstance(modules[str(user.target)], target_module_type)
        )
        functional_match = (
            user.op == 'call_function'
            and target_functional_type is not None
            and user.target == target_functional_type
        )
        if module_match or functional_match:
            return user
    return None

def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """Create a node in `quantized_graph` and copy the necessary metadata
    (stack trace) to it from `old_node`."""
    fresh_node = quantized_graph.create_node(*create_node_args)
    fresh_node.stack_trace = old_node.stack_trace
    return fresh_node
+ """ + new_node = quantized_graph.create_node(*create_node_args) + new_node.stack_trace = old_node.stack_trace + return new_node + +def get_skipped_module_name_and_classes( + prepare_custom_config: PrepareCustomConfig, + is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]: + skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names) + skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes) + if not is_standalone_module: + # standalone module and custom module config are applied in top level module + skipped_module_names += list(prepare_custom_config.standalone_module_names.keys()) + skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys()) + skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping) + + return skipped_module_names, skipped_module_classes + +def _is_custom_module_lstm( + node: Node, + named_modules: Dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Optional[Any] = None, +) -> bool: + """ + Return whether this refers to the custom module LSTM flow. + """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined] + return isinstance(mod, torch.nn.LSTM) and \ + activation_is_statically_quantized(qconfig) and \ + qhandler.is_custom_module() + else: + return isinstance(mod, torch.ao.nn.quantizable.LSTM) + +def _is_custom_module_mha( + node: Node, + named_modules: Dict[str, torch.nn.Module], + qconfig: QConfigAny = None, + # QuantizeHandler, but we cannot include the type here due to circular imports + qhandler: Optional[Any] = None, +) -> bool: + """ + Return whether this refers to the custom module MultiheadAttention flow. 
+ """ + mod = _get_module(node, named_modules) + if qconfig is not None and qhandler is not None: + assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined] + return isinstance(mod, torch.nn.MultiheadAttention) and \ + activation_is_statically_quantized(qconfig) and \ + qhandler.is_custom_module() + else: + return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention) + +def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]: + """ + If `node` refers to a call_module node, return the module, else None. + """ + if node.op == "call_module" and str(node.target) in named_modules: + return named_modules[str(node.target)] + else: + return None + +def _insert_dequant_stub( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attach a `DeQuantStub` to the model and create a node that calls this + `DeQuantStub` on the output of `node`, similar to how observers are inserted. + """ + prefix = "dequant_stub_" + get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix) + dequant_stub_name = get_new_dequant_stub_name(model) + dequant_stub = DeQuantStub() + setattr(model, dequant_stub_name, dequant_stub) + named_modules[dequant_stub_name] = dequant_stub + with graph.inserting_after(node): + return graph.call_module(dequant_stub_name, (node,)) + +def _insert_dequant_stubs_for_custom_module_lstm_output( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Insert DeQuantStubs after each internal output node of custom module LSTM. + + Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)), + Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its + components through `getitem`. 
def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)),
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /  (getitem)  \\
                 v               v
               output          hidden
                 |             /     \\
          (DeQuantStub)     (getitem)
                 |          v        v
                 v       hidden0   hidden1
             output_dq      |         |
                 |   (DeQuantStub) (DeQuantStub)
                 |          v         v
                 |     hidden0_dq  hidden1_dq
                 |          \\       /
                 |           (tuple)
                 |              v
                 |          hidden_dq
                  \\  (tuple)  /
                       v
                 lstm_output_dq
                       |
                       v
                 original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes.
    # `output` and `hidden` were just created above as internal getitems, so
    # they must keep consuming the raw LSTM node rather than the dq'ed tuple.
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes we added here may interfere with reference quantized
    # pattern matching, so we need to redirect the consumers of internal nodes to the
    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
    # in order to preserve reference patterns like "dequantize - consumer - quantize".
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq
def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """
    # Each matcher below tests one link of a candidate path; the patterns are
    # walked bottom-up starting from `arg` toward the LSTM node.
    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]  # type: ignore[assignment,index]
                else:
                    a = a.args[0]  # type: ignore[assignment]
        return a

    all_match_patterns = [
        [match_dq, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        [match_dq, match_getitem, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None
def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b   c
        |    \\ /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """
    def find_patterns(
            node: Node,
            index_stack: List[int],
            current_pattern: List[Node],
            matched_patterns: List[List[Node]],
            seen: Set[Tuple[Node, Tuple[int, ...]]]):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        # An empty index stack with a non-empty pattern means every tuple
        # index pushed so far was popped by a matching getitem: record it.
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: List[List[Node]] = []
    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the correct input
    # of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)
def _get_observer_from_activation_post_process(
    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
    """
    If `activation_post_process` is an observer, return the observer.
    If `activation_post_process` is a fake quantize, return the internal observer.
    """
    if isinstance(activation_post_process, ObserverBase):
        return activation_post_process
    else:
        assert isinstance(activation_post_process, FakeQuantizeBase)
        return activation_post_process.activation_post_process  # type: ignore[return-value]

def _qconfig_satisfies_dtype_config_constraints(
        qconfig: QConfigAny,
        dtype_with_constraints: DTypeWithConstraints,
        is_activation: bool = True) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints.

    1. QConfig specified a quantization range that falls within the backend's, if any
    2. QConfig specified a min scale value that is >= the backend's, if any
    3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
       scale and zero point that match the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """
    # TODO: log warnings only when the user enabled a debug flag
    def _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process: Union[ObserverBase, FakeQuantizeBase],
            dtype_with_constraints: DTypeWithConstraints,
            debug_string: str) -> bool:
        observer = _get_observer_from_activation_post_process(activation_post_process)
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
        # resolve the differences between the two, either by renaming eps or some other way
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
        # check quantization ranges: the qconfig's range must lie inside the backend's
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn(f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}")
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(
                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
                    f"ignoring {qconfig}"
                )
                return False
        # check scale min (the observer's eps stands in for the qconfig's min scale)
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn(f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}")
                return False
            if app_scale_min < backend_scale_min:
                warnings.warn(
                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
                )
                return False
        # check fixed scale and zero point
        if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None:
            # For tests only, accept the following qconfigs for now
            # TODO: handle fp16 qconfigs properly
            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
                if qconfig_equals(qconfig, accepted_qconfig):
                    return True
            suggestion_str = (
                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
                "    qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n"
                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
            )
            if not isinstance(activation_post_process, FixedQParamsObserver) and \
                    not isinstance(activation_post_process, FixedQParamsFakeQuantize):
                warnings.warn(
                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
            if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match:
                warnings.warn(
                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
                    f"ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
        return True

    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        activation_post_process = activation_post_process_ctr()
        assert _is_activation_post_process(activation_post_process)
        # If dtypes don't match, don't check the activation_post_process and return True early
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process, dtype_with_constraints, debug_string)
    return satisfies_constraints
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa62de2c55d74c408170cf3329a5d1a4cbf51526 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd37abc415a4281798846d585377b4c07dfba1b5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ca1f71b7d20c3d230f0ccd20924d2c3ef02d7a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/generate_numeric_debug_handle.py @@ -0,0 +1,17 @@ +from torch.fx import GraphModule, Node + +__all__ = ["generate_numeric_debug_handle"] + + +def generate_numeric_debug_handle(graph_module: GraphModule) -> None: + unique_id = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function": + 
node.meta["numeric_debug_handle"] = {} + for arg in node.args: + if isinstance(arg, Node): + node.meta["numeric_debug_handle"][arg] = unique_id + unique_id += 1 + + node.meta["numeric_debug_handle"]["output"] = unique_id + unique_id += 1 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/graph_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bacb4d8a28f155bdbe5e636be12449f82fa1b381 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/graph_utils.py @@ -0,0 +1,109 @@ +import itertools +from typing import Any, List, OrderedDict, Set, Optional, Callable +import operator +from torch.fx import Node + +import torch + +from torch.fx.passes.utils.source_matcher_utils import ( + check_subgraphs_connected, + get_source_partitions, + SourcePartition, +) + +__all__ = [ + "find_sequential_partitions", + "get_equivalent_types", + "update_equivalent_types_dict", +] + +_EQUIVALENT_TYPES: List[Set] = [ + {torch.nn.Conv1d, torch.nn.functional.conv1d}, + {torch.nn.Conv2d, torch.nn.functional.conv2d}, + {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d}, + {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_}, + {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm}, + {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_}, + {torch.add, operator.add, operator.iadd, "add", "add_"}, + {torch.mul, operator.mul, operator.imul, "mul", "mul_"}, +] + + +def _create_equivalent_types_dict(): + _DICT = {} + for values in _EQUIVALENT_TYPES: + for v in values: + _DICT[v] = list(values) + return _DICT + + +_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + +def get_equivalent_types() -> List[Set]: + return _EQUIVALENT_TYPES + +def 
update_equivalent_types_dict(customized_equivalent_types=None): + """Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + When customized_equivalent_types passes in, + re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT. + """ + if customized_equivalent_types is None: + raise ValueError("customized_equivalent_types should not be None") + global _EQUIVALENT_TYPES + global _EQUIVALENT_TYPES_DICT + _EQUIVALENT_TYPES = customized_equivalent_types + _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict() + +def _partitions_sequential(partitions: List[SourcePartition]): + prev_partition = None + for partition in partitions: + if prev_partition is not None and not check_subgraphs_connected( + prev_partition, partition + ): + return False + prev_partition = partition + return True + + +def _get_matching_types(partition_type): + matching_types = [partition_type] + if partition_type in _EQUIVALENT_TYPES_DICT: + matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type]) + return matching_types + + +def _valid_type_sequence(partition_types: List[Any]): + partition_types_set = set() # type: ignore[var-annotated] + for partition_type in partition_types: + matching_types = _get_matching_types(partition_type) + matching_types_set = set(matching_types) + if len(partition_types_set & matching_types_set) > 0: + return False + partition_types_set |= matching_types_set + return True + + +def find_sequential_partitions( + gm: torch.fx.GraphModule, + partition_types: List[Any], + include_functional_equivalent=True, + filter_fn: Optional[Callable[[Node], bool]] = None, +): + if not _valid_type_sequence(partition_types): + raise ValueError( + f"Invalid partition types: {partition_types}. 
Each type in the sequence must be unique" + ) + + typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict() + for partition_type in partition_types: + types_to_match = _get_matching_types(partition_type) + partitions = get_source_partitions(gm.graph, types_to_match, filter_fn) + typed_partitions[partition_type] = list(itertools.chain.from_iterable(partitions.values())) + + typed_partitions_list = list(typed_partitions.values()) + fusion_candidates = itertools.product(*typed_partitions_list) + fused_partitions = [] + for candidate in fusion_candidates: + if _partitions_sequential(candidate): # type: ignore[arg-type] + fused_partitions.append(candidate) + return fused_partitions diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..3f02943146e927809cf6b63a8d0ccb3b183f5d68 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -0,0 +1,198 @@ +import logging +from typing import Optional + +import torch +from torch._export.error import InternalError + +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _find_q_dq_node_for_user, + _is_valid_annotation, +) + +from torch.ao.quantization.quantizer import QuantizationSpecBase + +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + +__all__ = ["PortNodeMetaForQDQ"] + +_METADATA_TO_PORT = [ + "stack_trace", + "quantization_tag", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + 
torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: + from_meta = from_node.meta + for meta_name in _METADATA_TO_PORT: + if meta_name in from_meta: + to_node.meta[meta_name] = from_meta[meta_name] + + +def _has_quant_annotation(node: torch.fx.Node) -> bool: + return "quantization_annotation" in node.meta + + +def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: + # BFS to look for choose qparams + from collections import deque + + queue = deque(list(node.users.keys())) + while len(queue): + n = queue.popleft() + if n.op == "output": + continue + if ( + n.op == "call_function" + and n.target == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + return n + for k in n.users.keys(): + queue.append(k) + return None + + +def _port_metadata_for_input_quant_nodes( + input_node: torch.fx.Node, + node: torch.fx.Node, + qspec: Optional[QuantizationSpecBase], +): + if qspec is None: + return + + is_dynamic_quant = getattr(qspec, "is_dynamic", None) + if is_dynamic_quant is not None and is_dynamic_quant is True: + choose_qparams_node = _find_choose_qparams_node(input_node) + if choose_qparams_node is None: + raise ValueError(f"No chose qparams node found for {node}") + choose_qparam_users = _filter_sym_size_users(choose_qparams_node) + if len(choose_qparam_users) != 2: + raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") + scale_node = choose_qparam_users.pop() + dynamic_q_node = next(iter(scale_node.users.keys())) + dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) + if len(dynamic_q_node_users) > 1: + raise InternalError(f"Expecting single user for {dynamic_q_node}") + dynamic_dq_node = dynamic_q_node_users.pop() + _add_metadata(choose_qparams_node, node) + 
_add_metadata(dynamic_q_node, node) + _add_metadata(dynamic_dq_node, node) + else: + q_node, dq_node = _find_q_dq_node_for_user(input_node, node) + if q_node is None or dq_node is None: + return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while isinstance(q_node_input, torch.fx.Node) and q_node_input.op not in [ + "placeholder", + "get_attr", + ]: + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) + _add_metadata(dq_node, node) + + +def _port_metadata_for_output_quant_nodes( + node: torch.fx.Node, qspec: Optional[QuantizationSpecBase] +): + if qspec is None: + return + + node_users = _filter_sym_size_users(node) + if len(node_users) != 1: + raise InternalError(f"Expecting {node} to have single user") + q_node = node_users.pop() + if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: + logger.warning( + f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 + ) # noqa: G004 + return + + _add_metadata(q_node, node) + + +class PortNodeMetaForQDQ(PassBase): + """ + Port metadata for nodes added by quantization flow. 
class PortNodeMetaForQDQ(PassBase):
    """
    Port metadata for nodes added by quantization flow.
    For static quant these are:
    - quantizer_per_tensor.default, dequantize_per_tensor.default
    - quantizer_per_channel.default, dequantize_per_channel.default
    For dynamic quant these are:
    - choose_qparams.tensor
    - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor
    - quantizer_per_channel.default, dequantize_per_channel.default

    Rules of porting metadata:
    - Metadata to be ported:
      - nn_module_stack
      - stack_trace
      - quantization_tag
    - Metadata to NOT be ported:
      - Everything else
    - Rules:
      - Statically quantized patterns:
        - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
        - Quantize nodes on the outputs inherit metadata of the producer node.
        - Example 1:
          - Original: [Conv -> AvgPool -> Linear]
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
          - Inner brackets specify which nodes Q/DQ inherit metadata from
          - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
          - Note first Q and last DQ do not inherit metadata from any nodes
        - Example 2:
          - Original: [Conv -> AvgPool -> Linear]
          - AvgPool is not quantized
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
          - Inner brackets specify which nodes Q/DQ inherit metadata from
          - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
          - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
            AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation
            on the nodes (in this case AvgPool node) to conclude if the node or pattern was
            supposed to be quantized. And subsequently decide if the preceding Q, if any, should
            inherit metadata from AvgPool.
      - Dynamically quantized patterns:
        - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes
        - For example, below linear is dynamically quantized while rest statically:
          - Original: [Conv -> AvgPool -> Linear]
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear]
          - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]]
          - Note first Q does not inherit metadata from any nodes
    NB:
    - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
      knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
      However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit.
      Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant
      code, this pass should like to be integrated in the refactored variant of "convert" step.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Walk every node once; only nodes carrying a valid quantization
        # annotation contribute metadata to their surrounding Q/DQ nodes.
        for node in graph_module.graph.nodes:
            annotation = node.meta.get("quantization_annotation", None)
            if _is_valid_annotation(annotation):
                input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
                output_qspec = node.meta["quantization_annotation"].output_qspec
                for input_node, qspec in input_qspec_map.items():
                    _port_metadata_for_input_quant_nodes(input_node, node, qspec)
                _port_metadata_for_output_quant_nodes(node, output_qspec)
        return PassResult(graph_module, True)
# Example inputs for quantized and folded conv-bn1d patterns used in convert
_quantized_conv1d_bn_example_inputs = (
    torch.randn(1, 1, 3),  # x
    torch.randn(1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)

# Example inputs for quantized and folded conv-bn2d patterns used in convert
_quantized_conv2d_bn_example_inputs = (
    torch.randn(1, 1, 3, 3),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),  # bn_weight
    torch.randn(1),  # bn_bias
    torch.randn(1),  # bn_running_mean
    torch.randn(1),  # bn_running_var
)


def _get_quantized_conv_bn_example_inputs_kwargs(
    is_per_channel: bool,
    has_bias: bool,
    is_cuda: bool,
) -> Dict[str, Any]:
    """
    Optional example inputs for quantized and folded conv-bn patterns
    used in convert, expressed as kwargs.

    Returns a dict with (depending on the flags) "scale", "zero_point"
    and "conv_bias" tensors, moved to CUDA when `is_cuda` is True.
    """
    kwargs = {}
    # Per tensor quantization uses literals to represent scale and zero
    # point, so there is no need to include them here as kwargs
    if is_per_channel:
        kwargs["scale"] = torch.tensor([1], dtype=torch.float)
        kwargs["zero_point"] = torch.tensor([0], dtype=torch.int)
    if has_bias:
        kwargs["conv_bias"] = torch.randn(1)
    if is_cuda:
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                kwargs[k] = v.cuda()
    return kwargs

def _get_conv_bn_pattern(conv_fn: Callable) -> Callable:
    """Return a traceable pattern module computing conv_fn followed by batch_norm (training=True)."""
    def _conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        x = conv_fn(x, conv_weight, conv_bias)
        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
        return x
    return _WrapperModule(_conv_bn_pattern)
# TODO: merge this with the `no_conv_bias` case
def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
    """Return the fused QAT conv + BN forward pattern (conv has a bias)."""
    def _qat_conv_bn_pattern(
        x: torch.Tensor,
        conv_weight: torch.Tensor,
        conv_bias: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ) -> torch.Tensor:
        """
        Approximated method to fuse conv and bn. It requires only one forward pass.
        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std.
        This is based on `nniqat.ConvBn2d._forward_approximate`.
        """
        # TODO: allow setting eps
        bn_eps = 1e-5
        running_std = torch.sqrt(bn_running_var + bn_eps)
        scale_factor = bn_weight / running_std
        # Shapes broadcast scale_factor over the conv weight's output-channel
        # dim (dim 0) and over the conv output's channel dim (dim 1).
        weight_shape = [1] * len(conv_weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(conv_weight.shape)
        bias_shape[1] = -1
        scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
        # Run conv with zero bias, then undo the scaling and add the real bias
        # afterwards, so BN sees the same activations as the unfused model.
        zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
        x = conv_fn(x, scaled_weight, zero_bias)
        x = x / scale_factor.reshape(bias_shape)
        x = x + conv_bias.reshape(bias_shape)
        x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
        return x
    return _WrapperModule(_qat_conv_bn_pattern)
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps) + return x + return _WrapperModule(_qat_conv_bn_pattern_no_conv_bias) + +def _append_qdq(x, is_per_channel, kwargs): + """ + Helper function to append q-dq ops after `x`, using dummy values for the qparams + and qmin/qmax. We use dummy values here because we match with `ignore_literals=True` + and will manually replace these values after subgraph rewriting. + + Return the dq node. + """ + # Dummy args to be passed into q-dq ops + per_channel_axis = 0 + scale = kwargs["scale"] if is_per_channel else 1.0 + zp = kwargs["zero_point"] if is_per_channel else 0 + qmin = -127 + qmax = 127 + dtype = torch.int8 + + qd = torch.ops.quantized_decomposed + if is_per_channel: + x = qd.quantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + x = qd.dequantize_per_channel(x, scale, zp, per_channel_axis, qmin, qmax, dtype) + else: + x = qd.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + x = qd.dequantize_per_tensor(x, scale, zp, qmin, qmax, dtype) + return x + +def _get_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Return the quantized version of QAT conv + BN pattern. + This is based on `nniqat.ConvBn2d._forward_approximate`, + used in QAT convert. We first match this pattern and replace + it with the normal [conv - bn] pattern, then fold the BN + weights into conv. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + running_std = torch.sqrt(bn_running_var + bn_eps) + scale_factor = bn_weight / running_std + weight_shape = [1] * len(conv_weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(conv_weight.shape) + bias_shape[1] = -1 + scaled_weight = conv_weight * scale_factor.reshape(weight_shape) + scaled_weight = _append_qdq(scaled_weight, is_per_channel, kwargs) + if has_bias: + zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype) + if bias_is_quantized: + zero_bias = _append_qdq(zero_bias, is_per_channel, kwargs) + x = conv_fn(x, scaled_weight, zero_bias) + else: + x = conv_fn(x, scaled_weight, None) + x = x / scale_factor.reshape(bias_shape) + if has_bias: + x = x + kwargs["conv_bias"].reshape(bias_shape) + x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps) + return x + return _WrapperModule(_quantized_qat_conv_bn_pattern) + +def _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel: bool, + has_bias: bool, + bias_is_quantized: bool, + conv_fn: Callable, + bn_is_training: bool, +) -> Callable: + """ + Quantized QAT conv - bn pattern with bn weights being folded into conv. 
+ """ + # TODO: allow setting eps + bn_eps = 1e-5 + + def _folded_quantized_qat_conv_bn_pattern( + x: torch.Tensor, + conv_weight: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + conv_weight = _append_qdq(conv_weight, is_per_channel, kwargs) + if has_bias: + bias = kwargs["conv_bias"] + if bias_is_quantized: + bias = _append_qdq(bias, is_per_channel, kwargs) + else: + bias = None + x = conv_fn(x, conv_weight, bias) + x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=bn_is_training, eps=bn_eps) + return x + return _WrapperModule(_folded_quantized_qat_conv_bn_pattern) + +def _has_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph has bias. + """ + for n in match.nodes_map.values(): + if _is_conv(n): + return len(n.args) > 2 and n.args[2] is not None + raise ValueError("Could not find conv node in matched conv + bn pattern") + +def _no_conv_bias_filter( + match: "InternalMatch", + original_graph: Graph, + pattern_graph: Graph, +) -> bool: + """ + Match filter for the subgraph rewriter that returns True if the conv node in + the original graph does NOT have bias. 
+ """ + return not _has_conv_bias_filter(match, original_graph, pattern_graph) + +def _is_quantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ] + +def _is_dequantize(n: Node) -> bool: + return n.target in [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ] + +def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]: + """ + Helper function to extract the nodes in the conv-bn fusion pattern after + subgraph rewriting, in the form of a map: + + {name: (original_node, replacement_node)} + + The following names must exist in the map: + + "conv", "conv_weight", "conv_input", "bn", "getitem" + + The following names may exist in the map: + + "conv_weight_q", "conv_weight_dq", "conv_bias", + "conv_bias_q", "conv_bias_dq" + """ + def _get_nodes(nodes: List[Node]) -> Tuple[Node, Node, Node]: + """ + Return a 3-tuple of (conv_node, bn_node, getitem_node). + This asserts that the match contains exactly one of each node. + """ + conv_node, bn_node, getitem_node = None, None, None + for n in nodes: + if n.op != "call_function": + continue + if _is_conv(n): + assert conv_node is None + conv_node = n + if _is_bn_node(n): + assert bn_node is None + bn_node = n + if n.target == operator.getitem: + assert getitem_node is None + getitem_node = n + assert conv_node is not None + assert bn_node is not None + assert getitem_node is not None + return (conv_node, bn_node, getitem_node) + + def _get_q_dq_nodes(n: Node) -> Tuple[Node, Node, Node]: + """ + Return a 3-tuple of (orig_node, q_node, dq_node). 
+ """ + assert _is_dequantize(n) + q_node = n.args[0] + assert isinstance(q_node, Node) + assert _is_quantize(q_node) + orig_node = q_node.args[0] + assert isinstance(orig_node, Node) + return (orig_node, q_node, n) + + original_nodes = list(_filter_nodes_map(r.nodes_map).values()) + o_conv, o_bn, o_getitem = _get_nodes(original_nodes) + r_conv, r_bn, r_getitem = _get_nodes(r.replacements) + + # Create the mapping from original node to replacement node + mapping = { + "conv": (o_conv, r_conv), + "bn": (o_bn, r_bn), + "getitem": (o_getitem, r_getitem), + } + + # Extract conv input and weight + # Note: here we extract the original nodes indirectly through the pattern nodes + # because the args of the original nodes are no longer available after replacement + (p_conv, _, _) = _get_nodes(list(r.nodes_map.keys())) + (p_conv_input, p_conv_weight, *_) = p_conv.args + (r_conv_input, r_conv_weight, *_) = r_conv.args + assert isinstance(p_conv_input, Node) + assert isinstance(p_conv_weight, Node) + assert isinstance(r_conv_input, Node) + assert isinstance(r_conv_weight, Node) + o_conv_input = r.nodes_map[p_conv_input] + o_conv_weight = r.nodes_map[p_conv_weight] + + # If conv weight is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_weight): + p_conv_weight, p_conv_weight_q, p_conv_weight_dq = _get_q_dq_nodes(p_conv_weight) + r_conv_weight, r_conv_weight_q, r_conv_weight_dq = _get_q_dq_nodes(r_conv_weight) + o_conv_weight = r.nodes_map[p_conv_weight] + o_conv_weight_q = r.nodes_map[p_conv_weight_q] + o_conv_weight_dq = r.nodes_map[p_conv_weight_dq] + mapping["conv_weight_q"] = (o_conv_weight_q, r_conv_weight_q) + mapping["conv_weight_dq"] = (o_conv_weight_dq, r_conv_weight_dq) + mapping["conv_input"] = (o_conv_input, r_conv_input) + mapping["conv_weight"] = (o_conv_weight, r_conv_weight) + + # Extract conv bias + if len(p_conv.args) > 2 and len(r_conv.args) > 2: + p_conv_bias = p_conv.args[2] + r_conv_bias = r_conv.args[2] + assert isinstance(p_conv_bias, 
Node) + assert isinstance(r_conv_bias, Node) + o_conv_bias = r.nodes_map[p_conv_bias] + + # If conv bias is quantized, extract the q - dq nodes + if _is_dequantize(p_conv_bias): + p_conv_bias, p_conv_bias_q, p_conv_bias_dq = _get_q_dq_nodes(p_conv_bias) + r_conv_bias, r_conv_bias_q, r_conv_bias_dq = _get_q_dq_nodes(r_conv_bias) + o_conv_bias = r.nodes_map[p_conv_bias] + o_conv_bias_q = r.nodes_map[p_conv_bias_q] + o_conv_bias_dq = r.nodes_map[p_conv_bias_dq] + mapping["conv_bias_q"] = (o_conv_bias_q, r_conv_bias_q) + mapping["conv_bias_dq"] = (o_conv_bias_dq, r_conv_bias_dq) + mapping["conv_bias"] = (o_conv_bias, r_conv_bias) + return mapping + +def _filter_nodes_map(nodes_map: Dict[Node, Node]) -> Dict[Node, Node]: + """ + Return a filtered `nodes_map` returned from the subgraph rewriter. + The filtered `nodes_map` will contain only nodes that are actually + matched in the pattern, excluding None or placeholder nodes. + """ + new_nodes_map: Dict[Node, Node] = {} + for pattern_node, graph_node in nodes_map.items(): + # bias can be None + if graph_node is None: + continue + # skip pattern placeholder nodes + if pattern_node.op == "placeholder": + continue + new_nodes_map[pattern_node] = graph_node + return new_nodes_map + +# TODO: this is error prone, use the replace_literals_with_placeholders hack instead +def _copy_over_literal_conv_args(original_node: Node, new_node: Node): + """ + Copy over literal args in conv, such as stride and padding, from the matched node + in the original graph to its replacement in the new graph. + + This is needed due to the following limitation in the subgraph rewriter when used + with dynamo export: literal (non-tensor) args are not supported in the match and + replacement patterns. This is because dynamo export automatically inlines these + literal args, making them dead placeholder nodes. In the future, we should check + if dynamo export can optionally disable this inlining, or if subgraph rewriter + can do the copying for us. 
# TODO: this is error prone, use the replace_literals_with_placeholders hack instead
def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
    """
    Copy over literal args in conv, such as stride and padding, from the matched node
    in the original graph to its replacement in the new graph.

    This is needed due to the following limitation in the subgraph rewriter when used
    with dynamo export: literal (non-tensor) args are not supported in the match and
    replacement patterns. This is because dynamo export automatically inlines these
    literal args, making them dead placeholder nodes. In the future, we should check
    if dynamo export can optionally disable this inlining, or if subgraph rewriter
    can do the copying for us. See https://github.com/pytorch/pytorch/issues/100419.

    Note: Unlike other tensor args like conv weights and biases, literal args are
    preserved in the original nodes after replacement, so we can access them here.
    """
    assert _is_conv(original_node)
    assert _is_conv(new_node)
    # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
    new_args = list(new_node.args)
    if len(new_args) < 3:
        # bias is optional, when it is not present, it means it is None
        new_args.append(None)
    # Keep (x, weight, bias) from the replacement; take all literals after them
    # from the original node.
    new_node.args = tuple(new_args[:3]) + original_node.args[3:]

def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacement_node: Node):
    """
    Update the `input_qspec_map` in the annotation after subgraph rewriting.

    The original annotation referred to the nodes in the original graph,
    so the keys in the `input_qspec_map` will need to be updated to reflect
    the corresponding nodes in the replacement graph.
    """
    assert _is_conv(original_node)
    assert _is_conv(replacement_node)
    if "quantization_annotation" not in original_node.meta:
        return
    original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
    input_qspec_map = {}
    # get the list of configs, it should be ordered as input, weight, bias
    # note: this is really hacky, we need a better solution, hopefully
    # in subgraph_rewriter, issue tracking the problem: https://github.com/pytorch/pytorch/issues/101820
    # NOTE(review): relies on dict insertion order matching (input, weight, bias)
    # in the original annotation — confirm against the annotating quantizer.
    all_configs = list(original_input_qspec_map.items())
    # input activation
    input_qspec_map[replacement_node.args[0]] = all_configs[0][1]
    # weight
    input_qspec_map[replacement_node.args[1]] = all_configs[1][1]
    # bias
    if len(replacement_node.args) > 2 and len(all_configs) > 2:
        input_qspec_map[replacement_node.args[2]] = all_configs[2][1]
    replacement_node.meta["quantization_annotation"].input_qspec_map = input_qspec_map
Dict[Node, Node], +): + """ + Update the `SharedQuantizationSpec`s and `DerivedQuantizationSpec`s + used in `node`'s quantization annotation after subgraph rewriting. + + The original annotation referred to the nodes in the original graph, + so the nodes used in these special quantization specs will need to + be updated to the corresponding nodes in the replacement graph. + """ + def _get_new_edge_or_node(edge_or_node: EdgeOrNode): + if isinstance(edge_or_node, Node): + _node = edge_or_node + return original_to_replacement_node.get(_node, _node) + elif isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node): + src, dest = edge_or_node + return ( + original_to_replacement_node.get(src, src), + original_to_replacement_node.get(dest, dest), + ) + else: + raise ValueError("unexpected type for edge_or_node: ", type(edge_or_node)) + + def _get_new_qspec(qspec: QuantizationSpecBase): + if isinstance(qspec, SharedQuantizationSpec): + new_edge_or_node = _get_new_edge_or_node(qspec.edge_or_node) + return SharedQuantizationSpec(new_edge_or_node) + elif isinstance(qspec, DerivedQuantizationSpec): + new_derived_from = [_get_new_edge_or_node(x) for x in qspec.derived_from] + return dataclasses.replace(qspec, derived_from=new_derived_from) + else: + return qspec + + if "quantization_annotation" not in node.meta: + return + annotation = node.meta["quantization_annotation"] + for input_node, qspec in annotation.input_qspec_map.items(): + annotation.input_qspec_map[input_node] = _get_new_qspec(qspec) + annotation.output_qspec = _get_new_qspec(annotation.output_qspec) + +def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=False) + m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=False) + if torch.cuda.is_available(): + m = 
def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
    """
    Fuse all (conv + bn) patterns in `m` into the approximated QAT subgraph,
    for both conv1d and conv2d. Returns `m` unchanged if it has no BN nodes.
    """
    has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
    if not has_bn:
        return m
    m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=False)
    m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=False)
    # When CUDA is available, also match patterns traced with CUDA example
    # inputs, since device placement is baked into the traced pattern graphs.
    if torch.cuda.is_available():
        m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=True)
        m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=True)
    return m
def _fuse_conv_bn_qat_helper(
    m: GraphModule,
    conv_fn: Callable,
    example_inputs: Tuple[Any, ...],
    is_cuda: bool,
) -> GraphModule:
    """
    Given a graph of decomposed aten ops, replace the (conv + bn) pattern with
    the fused QAT subgraph equivalent. The input graph should already be annotated.
    The annotations in the original nodes will be preserved in the corresponding
    nodes in the new subgraph.

    Note: This also handles the (conv + bn + relu) pattern.
    """
    m.graph.eliminate_dead_code()
    m.recompile()
    conv_bn_pattern = _get_conv_bn_pattern(conv_fn)
    match_pattern = get_aten_graph_module(conv_bn_pattern, example_inputs, is_cuda)

    # Step (1): Replace patterns with conv bias
    #
    # Here we do replacement separately for cases with and without conv bias, since
    # the replacement patterns for these two cases are substantially different.
    # TODO: use the public replace_pattern API once it also returns replacement nodes

    qat_conv_bn_pattern = _get_qat_conv_bn_pattern(conv_fn)
    replacement_pattern_with_conv_bias = get_aten_graph_module(
        qat_conv_bn_pattern,
        example_inputs,
        is_cuda,
    )
    replacements_with_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_with_conv_bias,
        match_filters=[_has_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (2): Replace patterns without conv bias

    qat_conv_bn_pattern_no_conv_bias = _get_qat_conv_bn_pattern_no_conv_bias(conv_fn)
    replacement_pattern_no_conv_bias = get_aten_graph_module(
        qat_conv_bn_pattern_no_conv_bias,
        example_inputs,
        is_cuda,
    )
    replacements_no_conv_bias = replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern_no_conv_bias,
        match_filters=[_no_conv_bias_filter],
        ignore_literals=True,
    )
    m.recompile()

    # Step (3): Post processing
    #
    # Due to limited functionality in the subgraph rewriter, here we manually
    # update the replacement graph as follows:
    #
    #   (a) Copy over metadata from original subgraph. This ensures the stack traces
    #       and annotations are preserved in the new subgraph
    #
    #   (b) Copy over literal args for conv from the original subgraph
    #       TODO: do this for literal args for batchnorm as well
    #
    #   (c) Update all references of the old nodes in the original subgraph to refer
    #       to the corresponding nodes in the new subgraph in the annotations
    #
    # In the future, we should try to push as much of this functionality into the
    # subgraph rewriter as possible, so we don't have to manually copy anything over.
    # For more detail, see https://github.com/pytorch/pytorch/issues/100419.

    all_original_to_replacement_nodes = {}
    for r in replacements_with_conv_bias + replacements_no_conv_bias:
        for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
            # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
            replacement_node.meta = original_node.meta
            if _is_conv(original_node):
                # Step (3b): Copy over conv literal args
                _copy_over_literal_conv_args(original_node, replacement_node)
                # Step (3c): Update old references in the conv node's input_qspec_map
                _update_conv_input_qspec_map_after_replacement(original_node, replacement_node)
            all_original_to_replacement_nodes[original_node] = replacement_node

    # Step (3c): Update old references in the special qspecs for all nodes in the graph
    for n in m.graph.nodes:
        _update_special_qspecs_after_replacement(n, all_original_to_replacement_nodes)

    return m
+ """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + if n.op != "call_function" or n.target != dq_op or len(n.users) == 1: + continue + for user in list(n.users): + with m.graph.inserting_before(n): + new_node = m.graph.create_node("call_function", dq_op, n.args, n.kwargs) + user.replace_input_with(n, new_node) + m.graph.erase_node(n) + m.recompile() + +def _remove_extra_dequantize(m: GraphModule): + """ + Removes duplicate dequant nodes in the graph, for an operator that has + multiple dequant nodes as a user, replace them with a single dequant node + that can be shared across all the uses. This should be seen as the "reverse" + of `_duplicate_dequantize_node`. + """ + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor + for n in m.graph.nodes: + dq_users = [user for user in n.users if user.op == "call_function" and user.target == dq_op] + if len(dq_users) > 1: + with m.graph.inserting_after(dq_users[0]): + new_node = m.graph.create_node("call_function", dq_op, dq_users[0].args, {}) + for dq_user in dq_users: + dq_user.replace_all_uses_with(new_node) + m.graph.erase_node(dq_user) + m.recompile() + +def _copy_over_q_dq_args(original_node: Node, replacement_node: Node): + """ + Given a pair of quantize or dequantize nodes, copy over all literal args + from the original node to the replacement node. 
+ """ + # For quantize_per_tensor, scale and zp are literals and need to be copied + # For quantize_per_channel, scale and zp are get_attr nodes and should be skipped + assert original_node.target == replacement_node.target + if original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ): + # Args: input, [scale, zp, qmin, qmax, dtype] + start_copy_arg_index = 1 + elif original_node.target in ( + torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ): + # Args: input, scale, zp, [axis, qmin, qmax, dtype] + start_copy_arg_index = 3 + else: + raise ValueError("Expected quantize/dequantize nodes, got '%s'" % original_node.target) + replacement_node.args = ( + replacement_node.args[:start_copy_arg_index] + original_node.args[start_copy_arg_index:] + ) + +def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return m + m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=False) + m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=False) + if torch.cuda.is_available(): + m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=True) + m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=True) + return m + +def _fold_conv_bn_qat_helper( + m: GraphModule, + conv_fn: Callable, + example_inputs: Tuple[Any, ...], + is_cuda: bool, +) -> GraphModule: + """ + Replace the quantized (conv + bn) pattern with conv with bn weights folded into the weights of conv. 
+ """ + m.graph.eliminate_dead_code() + m.recompile() + _duplicate_dequantize_node(m) + + # Step (1): Replace QAT pattern with simple [conv - bn] pattern + replacements = [] + replacement_options = itertools.product( + [True, False], # is_per_channel + [True, False], # has_bias + [True, False], # bias_is_quantized + [True, False], # bn_is_training + ) + for is_per_channel, has_bias, bias_is_quantized, bn_is_training in replacement_options: + # For the cases without bias, `bias_is_quantized` is irrelevant, so here we arbitrarily + # filter out one of the values for this flag to avoid having duplicate patterns + if not has_bias and bias_is_quantized: + continue + kwargs = _get_quantized_conv_bn_example_inputs_kwargs(is_per_channel, has_bias, is_cuda) + match_pattern = _get_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + match_pattern = get_aten_graph_module(match_pattern, example_inputs, is_cuda, **kwargs) + replacement_pattern = _get_folded_quantized_qat_conv_bn_pattern( + is_per_channel, has_bias, bias_is_quantized, conv_fn, bn_is_training + ) + replacement_pattern = get_aten_graph_module(replacement_pattern, example_inputs, is_cuda, **kwargs) + replacements.extend( + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + ignore_literals=True, + ) + ) + m.recompile() + _remove_extra_dequantize(m) + + for r in replacements: + node_map = _get_conv_bn_pattern_nodes(r) + + # Step (2): Copy over metadata from original subgraph + for original_node, replacement_node in node_map.values(): + replacement_node.meta = original_node.meta + + # Step (3): Copy over args for weight (and optionally bias) q - dq nodes + _copy_over_q_dq_args(*node_map["conv_weight_q"]) + _copy_over_q_dq_args(*node_map["conv_weight_dq"]) + if "conv_bias_q" in node_map: + assert "conv_bias_dq" in node_map + _copy_over_q_dq_args(*node_map["conv_bias_q"]) + _copy_over_q_dq_args(*node_map["conv_bias_dq"]) + + # Step (4): 
Fold BN weights into conv + conv_bias = None + (_, conv_node) = node_map["conv"] + (_, bn_node) = node_map["bn"] + (_, conv_weight) = node_map["conv_weight"] + if "conv_bias" in node_map: + (_, conv_bias) = node_map["conv_bias"] + fold_bn_weights_into_conv_node(conv_node, conv_weight, conv_bias, bn_node, m) + + # Copy over literal args for conv + for original_node in _filter_nodes_map(r.nodes_map).values(): + if _is_conv(original_node): + _copy_over_literal_conv_args(original_node, conv_node) + + m.graph.eliminate_dead_code() + m.recompile() + return m diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e62c276fecc8bd343a79b2d126244cf240138a9d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70e4662be2c54aaaaaeb0781cbb48ad1e19f84ea --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/utils.py @@ -0,0 +1,540 @@ +import operator +import types + +import torch +from torch._export import capture_pre_autograd_graph +from torch.fx import ( + GraphModule, + Node, +) +from torch.nn.utils.fusion import fuse_conv_bn_weights +from typing import Any, Callable, Dict, Optional, Tuple, List, Union +from torch.utils._pytree import LeafSpec +from torch.export.unflatten import _AttrKind, _assign_attr + +# Makes sure 
that quantized_decomposed ops are registered +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 + +from torch.ao.quantization.quantizer import QuantizationAnnotation + + +__all__ = [ + "fold_bn_weights_into_conv_node", + "get_aten_graph_module", + "remove_tensor_overload_for_qdq_ops", +] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + +# Example inputs for conv-bn1d patterns +_conv1d_bn_example_inputs = ( + torch.randn(1, 1, 3), # x + torch.randn(1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var +) + +# Example inputs for conv-bn2d patterns +_conv2d_bn_example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1, 1, 1, 1), # conv_weight + torch.randn(1), # conv_bias + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var +) + +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. 
Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}") + dest = dest.args[0] + return (dest == source) + + +def _find_q_dq_node_for_user( + produer: torch.fx.Node, user: torch.fx.Node +) -> Tuple[Any, Any]: + """ + Find q, dq pair corresponding to [producer -> q -> dq -> user] + Utils works by finding dq arg of user and ensuring it is connected to + producer + """ + dq_node = None + for n in user.args: + if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS: + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + for n in user.kwargs: + if isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target in _DEQUANTIZE_OPS: + if _is_connected(produer, n): + dq_node = n + break + if dq_node is None: + return (None, None) + + q_node = None + if dq_node.args[0].op == "call_function" and dq_node.args[0].target in _QUANTIZE_OPS: + q_node = dq_node.args[0] + return (q_node, dq_node) + + + +def _is_sym_size_node(node: Node): + return ( + node.op == "call_function" + and node.target == torch.ops.aten.sym_size.default + or node.target == torch.ops.aten.sym_numel.default + or node.target == torch.ops.aten.sym_numel + or node.target == torch.ops.aten.sym_size + ) + + +def _filter_sym_size_users(node: torch.fx.Node) -> List[torch.fx.Node]: + node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users)) + return node_users + + +def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool: + if annotation is None: + return False + input_qspec_map = annotation.input_qspec_map + output_qspec = annotation.output_qspec + if 
len(input_qspec_map) == 0 and output_qspec is None: + return False + return True + + +def _get_tensor_constant_from_node(node, m): + if node is None: + return None + assert node.op == "get_attr" + target_atoms = node.target.split('.') + attr_itr = m + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + +def _get_all_arguments(orig_args, orig_kwargs, args_schema): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + +def _is_supported_batch_norm_for_training(node: Node): + """ + Return True if the given node refers to an aten batch norm op QAT supports. + """ + supported_ops = [ + torch.ops.aten._native_batch_norm_legit.default, + # Note: we won't need this op anymore after batch norm consolidation + # For now, we need to continue to support it because it gives better + # training numerics than `_native_batch_norm_legit` + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.miopen_batch_norm.default, + ] + return node.target in supported_ops + +# TODO: rename this to _is_conv_node +def _is_conv(n: Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ] + +# TODO: rename this to _is_conv_transpose_node +def _is_conv_transpose(n: Node): + """ + Return whether the node refers to an aten conv_transpose op. 
+ """ + return n.op == "call_function" and n.target in [ + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose2d, + ] + +def _is_bn_node(n: Node): + return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default + +def fold_bn_weights_into_conv_node( + conv_node: Node, + conv_weight_node: Node, + conv_bias_node: Optional[Node], + bn_node: Node, + m: GraphModule +) -> None: + # conv args: input, weight, bias, stride, padding, dilation, ... + conv_w = _get_tensor_constant_from_node(conv_weight_node, m) + conv_b = _get_tensor_constant_from_node(conv_bias_node, m) + transpose = _is_conv_transpose(conv_node) + + # eval bn args: input, weight, bias, running mean, running var, momentum, eps + # train bn args: input, weight, bias, running mean, running var, training, momentum, eps + bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr] + bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema) + bn_w = _get_tensor_constant_from_node(bn_args[1], m) + bn_b = _get_tensor_constant_from_node(bn_args[2], m) + bn_rm = _get_tensor_constant_from_node(bn_args[3], m) + bn_rv = _get_tensor_constant_from_node(bn_args[4], m) + if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default: + eps_arg_index = 6 + elif _is_supported_batch_norm_for_training(bn_node): + eps_arg_index = 7 + else: + raise ValueError("BN node target is unexpected ", bn_node.target) + bn_eps = bn_args[eps_arg_index] + + fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose) + + # update the weight and bias for conv + conv_args = list(conv_node.args) + # filling in the default bias argument + if len(conv_args) == 2: + conv_args.append(None) + + # calling data since the fused_weight and fused_bias are nn.Parameter + weight_attr_name = conv_weight_node.target + assert isinstance(weight_attr_name, str) + 
_assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER) + if conv_bias_node is not None: + bias_attr_name = conv_bias_node.target + _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER) + else: + bias_attr_name = weight_attr_name + "_bias" + _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER) + with m.graph.inserting_before(conv_node): + get_bias_node = m.graph.get_attr(bias_attr_name) + # NOTE: here we assume the bias of conv is not quantized! + conv_args[2] = get_bias_node + conv_node.args = tuple(conv_args) + + # native_batch_norm has 3 outputs, we expect getitem calls on the output + # and we want to replace the uses of getitem 0 with the output of conv + # + # Before: + # conv -> bn - (first output) -> users1 + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # After: + # conv -> (first output) -> users1 + # bn - + # \ - (second output) -> users2 + # \ - (third output) -> users3 + # if users2 and users3 are empty then bn will be removed through dead code elimination + + for user in bn_node.users: + if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0: + continue + user.replace_all_uses_with(conv_node) + +# fuse conv bn weights, inplace modification of the graph_module and graph +def _fuse_conv_bn_(m: GraphModule) -> None: + has_bn = any(_is_bn_node(n) for n in m.graph.nodes) + if not has_bn: + return + for n in m.graph.nodes: + if n.op != "call_function" or n.target != torch.ops.aten._native_batch_norm_legit_no_training.default: + continue + bn_node = n + n = bn_node.args[0] + if not _is_conv(n): + continue + conv_node = n + conv_weight_node = conv_node.args[1] + conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None + fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, m) + + m.graph.eliminate_dead_code() + m.recompile() + +def _get_node_name_to_scope(model: GraphModule) -> Dict[str, Tuple[str, type]]: + # TODO: move 
this information to fx node itself + node_name_to_scope: Dict[str, Tuple[str, type]] = {} + for n in model.graph.nodes: + nn_module_stack = n.meta.get("nn_module_stack", None) + current_scope = ("", type(None)) + if nn_module_stack: + bt = list(nn_module_stack.values())[-1] + current_scope = (bt[0].split(".")[-1], bt[1]) + node_name_to_scope[n.name] = current_scope + return node_name_to_scope + +def get_aten_graph_module( + pattern: Callable, + example_inputs: Tuple[Any, ...], + is_cuda: bool = False, + **kwargs, +) -> GraphModule: + """ + Convert the pattern to an FX graph with decomposed aten ops. + """ + if is_cuda: + example_inputs = tuple([x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]) + aten_pattern = capture_pre_autograd_graph( + pattern, + example_inputs, + kwargs, + ) + aten_pattern.graph.eliminate_dead_code() + aten_pattern.recompile() + return aten_pattern + +def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None: + """ Remove .tensor overload for quantize/dequantize ops so that we can + use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e + """ + _MAP = { + torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel, + 
torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel, + torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp, + } + for n in match_pattern.graph.nodes: + if n.op != "call_function": + continue + if n.target in _MAP: + n.target = _MAP[n.target] + +def _is_literal(arg): + if isinstance(arg, (int, float)): + return True + if isinstance(arg, (tuple, list)): + return all(map(_is_literal, arg)) + return False + +def _replace_literals_with_new_placeholders( + gm: torch.fx.GraphModule, + merge_dup: bool = False, + exclude_literals: Optional[List[Any]] = None +): + """Replace the literals in the graph with placeholder nodes that's created on the fly while we + traverse the graph, so that the literal arguments in the graph can be matched and replaced + + To use this, the pattern and replacement graph should have the exact same number of literal args + and they should be used in the exact same order in the pattern and replacement graph. + + If the literal arguments are not used in the same order in pattern and replacement graph, please + use `_replace_literals_with_existing_placeholders` instead + + Args: + `gm`: input GraphModule that we'll transform + `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in + the graph, whether they should correspond to the same placeholder or not + `exclude_literals`: a list of literals that will not be replaced with placeholders + + Example: + + # 1. Original Graph + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + example_inputs = (torch.randn(1, 3, 3, 3),) + pattern_gm = get_aten_graph_module(pattern, example_inputs) + replacement_gm = get_aten_graph_module(pattern, example_inptus) + + # 2. 
Before calling replace literals we'll see the following graph: + def pattern(self, x): + return x + 3 + + def replacement(self, x): + return x - 3 + + pattern_gm = _replace_literals_with_new_placeholders(pattern_gm) + replacement_gm = _replace_literals_with_new_placeholders(replacement_gm) + + # 3. After replacing literals with new placeholder nodes + + def pattern(self, x, new_ph): + return x + new_ph + + def pattern(self, x, new_ph): + return x - new_ph + + """ + last_ph = None + cnt = 0 + literal_to_ph: Dict[Union[float, bool, int, torch.dtype], Node] = {} + if exclude_literals is None: + exclude_literals = [] + + in_spec = gm._in_spec + args_spec = in_spec.children_specs[0] + for node in gm.graph.nodes: + if node.op == "placeholder": + last_ph = node + cnt += 1 + continue + with gm.graph.inserting_after(last_ph): + new_args = [] + for arg in node.args: + if _is_literal(arg) and arg not in exclude_literals: + if merge_dup and arg in literal_to_ph: + new_args.append(literal_to_ph[arg]) + else: + ph_node = gm.graph.placeholder("arg" + str(cnt)) + new_args.append(ph_node) + args_spec.children_specs.append(LeafSpec()) + cnt += 1 + if merge_dup: + literal_to_ph[arg] = ph_node + else: + new_args.append(arg) + new_args = tuple(new_args) + + node.args = new_args + + # Update `num_nodes`, `num_leaves`, `num_children`. + args_spec.__post_init__() + in_spec.__post_init__() + return gm + + +def _replace_literals_with_existing_placeholders( + gm: torch.fx.GraphModule, + exclude_literals: Optional[List[Any]] = None, + literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None +): + """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments + in the graph can be matched and replaced + + To use this, all literal args in the graph should be unique and each of them should correspond + to exactly one placeholder node + + # 1. 
def _replace_literals_with_existing_placeholders(
    gm: torch.fx.GraphModule,
    exclude_literals: Optional[List[Any]] = None,
    literal_to_ph_idx: Optional[Dict[Union[float, int, bool, torch.dtype], int]] = None
):
    """Replace literal args in `gm` with **existing** placeholder nodes.

    Use this when the literal args appear in different orders in the pattern
    and replacement graphs (where `_replace_literals_with_new_placeholders`
    would mis-pair them). All literal args should be unique and each should
    map, via `literal_to_ph_idx`, to the index of exactly one placeholder.

    Example:
        With placeholders (x_i8, scale, zero_point, quant_min, quant_max) and
        burnt-in scalars such as `torch.clamp(x_i8, -128, 127)`, passing
        `literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}` rewrites the
        literals back into references to the existing placeholders, e.g.
        `torch.clamp(x_i8, quant_min, quant_max)`.
    """
    if exclude_literals is None:
        exclude_literals = []

    if literal_to_ph_idx is None:
        literal_to_ph_idx = {}

    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]

    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        rewritten = []
        for arg in node.args:
            if _is_literal(arg) and arg not in exclude_literals and arg in literal_to_ph_idx:
                rewritten.append(placeholders[literal_to_ph_idx[arg]])
            else:
                rewritten.append(arg)
        node.args = tuple(rewritten)
    return gm


# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
    """
    Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
    This is useful for exported models, where these methods don't actually behave as expected.
    """
    error_message = \
        """
        Calling train() or eval() is not supported for exported models.
        Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.

        If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
        the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
        which does the above automatically for you. Note that this has limited effect on switching
        behavior between train and eval modes, and should be used only for special ops such as dropout
        and batchnorm.
        """

    def _blocked_mode_switch(self, mode: bool = True):
        # both entry points raise identically, so one callable serves both
        raise NotImplementedError(error_message)

    model.train = types.MethodType(_blocked_mode_switch, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_blocked_mode_switch, model)  # type: ignore[method-assign]
    return model
Note that this has limited effect on switching + behavior between train and eval modes, and should be used only for special ops such as dropout + and batchnorm. + """ + + def _train(self, mode: bool = True): + raise NotImplementedError(error_message) + + def _eval(self, mode: bool = True): + raise NotImplementedError(error_message) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a3db87552aeccc5763c5a3cc6449547d5afb53 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_fx.py @@ -0,0 +1,726 @@ +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch +import copy +from torch.fx import GraphModule +from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY +from .fx.tracer import QuantizationTracer +from .fx.tracer import ( # noqa: F401 + Scope, + ScopeContextManager +) +from .fx.fuse import fuse # noqa: F401 +from .fx.prepare import prepare # noqa: F401 +from .fx.convert import convert +from .backend_config import ( # noqa: F401 + BackendConfig, + get_tensorrt_backend_config, +) +from .fx.graph_module import ObservedGraphModule # noqa: F401 +from .fx.custom_config import ( + ConvertCustomConfig, + FuseCustomConfig, + PrepareCustomConfig, +) +from .fx.utils import get_custom_module_class_keys # noqa: F401 +from .fx.utils import get_skipped_module_name_and_classes +from .qconfig_mapping import QConfigMapping + +def attach_preserved_attrs_to_model( + model: Union[GraphModule, torch.nn.Module], + preserved_attrs: Dict[str, Any], +) -> None: + """ Store preserved attributes 
to the model.meta so that it can be preserved during deepcopy + """ + model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment] + # set the preserved attributes in the model so that user can call + # model.attr as they do before calling fx graph mode quantization + for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr] + setattr(model, attr_name, attr) + +def _check_is_graph_module(model: torch.nn.Module) -> None: + if not isinstance(model, GraphModule): + raise ValueError( + "input model must be a GraphModule, " + + "Got type:" + + str(type(model)) + + " Please make " + + "sure to follow the tutorials." + ) + +def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None: + """ Attach meta field to all nodes of the graph if it does not exist, + meta field is a field stores some meta information about the node, such + as dtype and shape information for output of the node, this only exists + if the program is captured by make_fx (used in quantize_pt2e flow), if + the program is captured by torch.fx symbolic tracing, this field may not exist, + so we add it here to avoid checking this all over the places + """ + for node in model.graph.nodes: + if not hasattr(node, "meta"): + node.meta = {} + +def _swap_ff_with_fxff(model: torch.nn.Module) -> None: + r""" Swap FloatFunctional with FXFloatFunctional + """ + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + _swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + +def _fuse_fx( + model: GraphModule, + is_qat: bool, + fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" 
def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Internal helper function to fuse modules in preparation for quantization

    Args:
        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(model)
    return fuse(model, is_qat, fuse_custom_config, backend_config)  # type: ignore[operator]


def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    r""" Internal helper function for prepare_fx

    Args:
        `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
            see docs for :func:`~torch.ao.quantization.prepare_fx`
        `is_standalone_module`: a boolean flag indicating whether we are
            quantizing a standalone module — a submodule of the parent module
            that is not inlined in the forward graph of the parent module; the
            way we quantize standalone modules is described in
            :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    # legacy dict-based config: warn and convert to the dataclass form
    if isinstance(prepare_custom_config, Dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(
        prepare_custom_config, is_standalone_module
    )
    preserved_attr_names = prepare_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }
    # symbolically trace the model
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
    graph_module = GraphModule(model, tracer.trace(model))
    _attach_meta_to_node_if_not_exist(graph_module)

    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(preserved_attr_names)
    graph_module = _fuse_fx(
        graph_module,
        is_qat,
        fuse_custom_config,
        backend_config)
    prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )  # type: ignore[operator]

    attach_preserved_attrs_to_model(prepared, preserved_attrs)
    return prepared
+ + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + + Returns: + + * model(GraphModule): prepared standalone module. It has these attributes in + model.meta: + + * `standalone_module_input_quantized_idxs(List[Int])`: a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + * `standalone_module_output_quantized_idxs(List[Int])`: a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + + """ + return _prepare_fx( + model, + qconfig_mapping, + is_qat, + example_inputs, + prepare_custom_config, + backend_config=backend_config, + is_standalone_module=True, + ) + + +def fuse_fx( + model: torch.nn.Module, + fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode. + Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py + + Args: + + * `model` (torch.nn.Module): a torch.nn.Module model + * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx. + See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details + Example:: + + from torch.ao.quantization import fuse_fx + m = Model().eval() + m = fuse_fx(m) + + """ + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, Dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. 
def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    # legacy dict-based config: warn and convert to the dataclass form
    if isinstance(fuse_custom_config, Dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in fuse_custom_config.preserved_attributes
        if hasattr(model, attr)
    }

    traced = torch.fx.symbolic_trace(model)
    _attach_meta_to_node_if_not_exist(traced)
    traced = _fuse_fx(traced, False, fuse_custom_config, backend_config)

    attach_preserved_attrs_to_model(traced, preserved_attrs)
    return traced


def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" Prepare a model for post training quantization

    Args:
        * `model` (torch.nn.Module): torch.nn.Module model

        * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
            quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
            for more details

        * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
            Tuple of positional args (keyword args can be passed as positional args as well)

        * `prepare_custom_config` (PrepareCustomConfig): customization configuration for
            quantization tool. See
            :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details

        * `_equalization_config`: config for specifying how to perform equalization on the model

        * `backend_config` (BackendConfig): config that specifies how operators are quantized
            in a backend: how the operators are observed, supported fusion patterns, how
            quantize/dequantize ops are inserted, supported dtypes etc. See
            :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A GraphModule with observer (configured by qconfig_mapping), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig_mapping is a collection of quantization configurations; the
        # call below gets the mapping that works best for models targeting the
        # "fbgemm" backend. It can be customized, e.g.:
        #   QConfigMapping().set_global(qconfig)                     # same qconfig everywhere
        #   QConfigMapping().set_module_name("linear", qconfig)      # one submodule
        #   QConfigMapping().set_object_type(torch.nn.Linear, qconfig)  # all nn.Linear
        # see the docstring of :class:`torch.ao.quantization.QConfigMapping`
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # example_inputs is a tuple of inputs used to infer output types;
        # make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # `prepare_fx` inserts observers based on qconfig_mapping and
        # backend_config: a setting in qconfig_mapping takes effect only if the
        # backend_config declares the corresponding dtype configuration as
        # supported by the target hardware; otherwise it is ignored.
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        False,  # is_qat
        example_inputs,
        prepare_custom_config,
        _equalization_config,
        backend_config,
    )
self.linear = torch.nn.Linear(5, 5) + self.sub = Submodule() + + def forward(self, x): + x = self.linear(x) + x = self.sub(x) + x + return x + + # initialize a floating point model + float_model = M().train() + # (optional, but preferred) load the weights from pretrained model + # float_model.load_weights(...) + + # define the training loop for quantization aware training + def train_loop(model, train_data): + model.train() + for image, target in data_loader: + ... + + # qconfig is the configuration for how we insert observers for a particular + # operator + # qconfig = get_default_qconfig("fbgemm") + # Example of customizing qconfig: + # qconfig = torch.ao.quantization.QConfig( + # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)), + # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8))) + # `activation` and `weight` are constructors of observer module + + # qconfig_mapping is a collection of quantization configurations, user can + # set the qconfig for each operator (torch op calls, functional calls, module calls) + # in the model through qconfig_mapping + # the following call will get the qconfig_mapping that works best for models + # that target "fbgemm" backend + qconfig_mapping = get_default_qat_qconfig("fbgemm") + + # We can customize qconfig_mapping in different ways, please take a look at + # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways + # to configure this + + # example_inputs is a tuple of inputs, that is used to infer the type of the + # outputs in the model + # currently it's not used, but please make sure model(*example_inputs) runs + example_inputs = (torch.randn(1, 3, 224, 224),) + + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and + # backend_config, if the configuration for an operator in qconfig_mapping + # is supported in the backend_config (meaning it's supported by the target + # hardware), we'll insert fake_quantize modules according to the qconfig_mapping + # otherwise the configuration in qconfig_mapping will be ignored + # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of + # how qconfig_mapping interacts with backend_config + prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs) + # Run training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx") + return _prepare_fx( + model, + qconfig_mapping, + True, # is_qat + example_inputs, + prepare_custom_config, + backend_config=backend_config, + ) + + +def _convert_fx( + graph_module: GraphModule, + is_reference: bool, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, + is_decomposed: bool = False, +) -> GraphModule: + """ `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx` + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, Dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. 
Please pass in a ConvertCustomConfig instead.") + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + _check_is_graph_module(graph_module) + preserved_attr_names = convert_custom_config.preserved_attributes + preserved_attrs = {attr: getattr(graph_module, attr) for attr in preserved_attr_names if hasattr(graph_module, attr)} + + quantized = convert( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module, + _remove_qconfig_flag=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=is_decomposed, + ) + + attach_preserved_attrs_to_model(quantized, preserved_attrs) + return quantized + + +def convert_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" Convert a calibrated or trained model to a quantized model + + Args: + * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + + The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`, + with the same values or `None`. Additional keys can be specified with values set to `None`. 
+ + For each entry whose value is set to None, we skip quantizing that entry in the model:: + + qconfig_mapping = QConfigMapping + .set_global(qconfig_from_prepare) + .set_object_type(torch.nn.functional.add, None) # skip quantizing torch.nn.functional.add + .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) + .set_module_name("foo.bar", None) # skip quantizing module "foo.bar" + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend, this includes quantization + mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.), + observer placement for each operators and fused operators. + See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details + + Return: + A quantized model (torch.nn.Module) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # convert_fx converts a calibrated/trained model to a quantized model for the + # target hardware, this includes converting the model first to a reference + # quantized model, and then lower the reference quantized model to a backend + # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and + # they share the same set of quantized operators, so we are using the same + # lowering procedure + # + # backend_config defines the corresponding reference quantized module for + # the weighted modules in the model, e.g. nn.Linear + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. 
backend_config = get_default_backend_config("fbgemm") + quantized_model = convert_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx") + return _convert_fx( + graph_module, + is_reference=False, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + ) + + +def convert_to_reference_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + _remove_qconfig: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" Convert a calibrated or trained model to a reference quantized model, + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. 
+ + Return: + A reference quantized model (GraphModule) + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = convert_to_reference_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=_remove_qconfig, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + ) + +def _convert_to_reference_decomposed_fx( + graph_module: GraphModule, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" Convert a calibrated or trained model to a reference quantized model, with + decomposed representation for quantized Tensor + see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details, + reference quantized model is a standard representation of a quantized model provided + by FX Graph Mode Quantization, it can be further lowered to run on the target + hardware, like accelerators + + Note: this is not public API + + Args: + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) + + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. + + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. 
+ See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + * `backend_config` (BackendConfig): A configuration for the backend which describes how + operators should be quantized in the backend. See + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. + + Return: + A reference quantized model (GraphModule) with operators working with decomposed quantized Tensor + + Example:: + + # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training + # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack + # e.g. backend_config = get_default_backend_config("fbgemm") + reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_fx._convert_to_reference_decomposed_fx") + return _convert_fx( + graph_module, + is_reference=True, + convert_custom_config=convert_custom_config, + _remove_qconfig=False, + qconfig_mapping=qconfig_mapping, + backend_config=backend_config, + is_decomposed=True, + ) + + +def _convert_standalone_module_fx( + graph_module: GraphModule, + is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, +) -> GraphModule: + r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx` + and convert it to a quantized model + + Returns a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for prepare_fx for details + """ + return _convert_fx( + graph_module, + is_reference, + convert_custom_config, + is_standalone_module=True, + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_pt2e.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_pt2e.py new file mode 100644 
index 0000000000000000000000000000000000000000..d9919aa2e9c5be03078c45103f162936137fe571 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_pt2e.py @@ -0,0 +1,250 @@ +import torch +from torch.fx import GraphModule +from torch.fx import Node + +from .pt2e.prepare import prepare +from .pt2e.qat_utils import ( + _fuse_conv_bn_qat, + _fold_conv_bn_qat, +) +from .pt2e.utils import ( + _get_node_name_to_scope, + _fuse_conv_bn_, + _disallow_eval_train, +) +from .pt2e.representation import reference_representation_rewrite +from .quantize_fx import _convert_to_reference_decomposed_fx +from torch.ao.quantization.quantizer import ( # noqa: F401 + Quantizer, + QuantizationSpecBase, + QuantizationSpec, + FixedQParamsQuantizationSpec, + SharedQuantizationSpec, + DerivedQuantizationSpec, + QuantizationAnnotation, +) +from torch.fx.passes.infra.pass_manager import PassManager +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch._inductor.constant_folding import constant_fold + +__all__ = [ + "prepare_pt2e", + "prepare_qat_pt2e", + "convert_pt2e", +] + + +def prepare_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for post training quantization + + Args: + * `model` (torch.fx.GraphModule): a model captured by `torch.export` API + in the short term we are using `torch._export.capture_pre_autograd_graph`, + in the long term we'll migrate to some `torch.export` API + * `quantizer`: A backend specific quantizer that conveys how user want the + model to be quantized. 
Tutorial for how to write a quantizer can be found here: + https://pytorch.org/tutorials/prototype/pt2e_quantizer.html + + Return: + A GraphModule with observer (based on quantizer annotation), ready for calibration + + Example:: + + import torch + from torch.ao.quantization.quantize_pt2e import prepare_pt2e + from torch._export import capture_pre_autograd_graph + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define calibration function + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = capture_pre_autograd_graph(m, *example_inputs) + # we get a model with aten ops + + # Step 2. 
quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_pt2e(m, quantizer) + + # run calibration + # calibrate(m, sample_inference_data) + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + # TODO: check qconfig_mapping to make sure conv and bn are both configured + # to be quantized before fusion + # TODO: (maybe) rewrite this with subgraph_rewriter + _fuse_conv_bn_(model) + quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + model = prepare(model, node_name_to_scope, is_qat=False) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + +def prepare_qat_pt2e( + model: GraphModule, + quantizer: Quantizer, +) -> GraphModule: + """Prepare a model for quantization aware training + + Args: + * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` + + Return: + A GraphModule with fake quant modules (based on quantizer annotation), ready for + quantization aware training + + Example:: + import torch + from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e + from torch._export import capture_pre_autograd_graph + from torch.ao.quantization.quantizer import ( + XNNPACKQuantizer, + get_symmetric_quantization_config, + ) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # initialize a floating point model + float_model = M().eval() + + # define the training loop for quantization aware training + def train_loop(model, train_data): + 
model.train() + for image, target in data_loader: + ... + + # Step 1. program capture + # NOTE: this API will be updated to torch.export API in the future, but the captured + # result shoud mostly stay the same + m = capture_pre_autograd_graph(m, *example_inputs) + # we get a model with aten ops + + # Step 2. quantization + # backend developer will write their own Quantizer and expose methods to allow + # users to express how they + # want the model to be quantized + quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) + m = prepare_qat_pt2e(m, quantizer) + + # run quantization aware training + train_loop(prepared_model, train_loop) + + """ + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") + original_graph_meta = model.meta + node_name_to_scope = _get_node_name_to_scope(model) + quantizer.transform_for_annotation(model) + quantizer.annotate(model) + quantizer.validate(model) + # Perform fusion after annotate to avoid quantizing ops in the new + # subgraph that don't need to be quantized + # TODO: only fuse if conv and bn are both configured to be quantized + _fuse_conv_bn_qat(model) + model = prepare(model, node_name_to_scope, is_qat=True) + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model + +_QUANT_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] +def _quant_node_constraint(n: Node) -> bool: + """If there is any pure ops between get_attr and quantize op they will be const propagated + e.g. 
get_attr(weight) -> transpose -> quantize -> dequantize* + (Note: dequantize op is not going to be constant propagated) + + This filter is added because we don't want to constant fold the things that are not + related to quantization + """ + return n.op == "call_function" and n.target in _QUANT_OPS + +def convert_pt2e( + model: GraphModule, + use_reference_representation: bool = False, + fold_quantize: bool = True, +) -> GraphModule: + """Convert a calibrated/trained model to a quantized model + + Args: + * `model` (torch.fx.GraphModule): calibrated/trained model + * `use_reference_representation` (bool): boolean flag to indicate whether to produce referece representation or not + * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not + + Returns: + quantized model, either in q/dq representation or reference representation + + Example:: + + # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training + # `convert_pt2e` produces a quantized model that represents quantized computation with + # quantize dequantize ops and fp32 ops by default. 
+ # Please refer to + # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model + # for detailed explanation of output quantized model + quantized_model = convert_pt2e(prepared_model) + + """ # flake8: noqa + torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") + if not isinstance(use_reference_representation, bool): + raise ValueError( + "Unexpected argument type for `use_reference_representation`, " + f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e") + original_graph_meta = model.meta + model = _convert_to_reference_decomposed_fx(model) + model = _fold_conv_bn_qat(model) + + pm = PassManager([DuplicateDQPass()]) + model = pm(model).graph_module + + pm = PassManager([PortNodeMetaForQDQ()]) + model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + + if use_reference_representation: + model = reference_representation_rewrite(model) + + model.meta.update(original_graph_meta) + model = _disallow_eval_train(model) + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e65652573b1b30bc755ab1861d4f7de0359bcedc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__init__.py @@ -0,0 +1,21 @@ +from .quantizer import ( + DerivedQuantizationSpec, + EdgeOrNode, + FixedQParamsQuantizationSpec, + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) + +__all__ = [ + "EdgeOrNode", + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + 
"QuantizationAnnotation", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb59828f9ecd449dcfff3d053f9efb03ee3b992 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66a10a99e12762594d6e9ae2cf3c49384f93f700 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10ea2a49f561b46d74bc9b74a8ca4523787ddd09 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ba264fc9556790a034e658490e3182357d4878c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6530cc561f2b7b60b17fe21d46c44a5bae79539f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffd2002e580db7cc6cae69161de1e88c788073c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import copy +from typing import List, Set + +import torch +import torch.nn.functional as F +from torch.ao.quantization.observer import PerChannelMinMaxObserver +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + OperatorConfig, + OperatorPatternType, + QuantizationConfig, +) + +__all__ = [ + 
"get_embedding_operators_config", + "EmbeddingQuantizer", +] + + +def get_embedding_operators_config() -> OperatorConfig: + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12), + ) + quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None) + ops: List[OperatorPatternType] = [[torch.nn.Embedding]] + ops.append([F.embedding]) + supported_config_and_operators = OperatorConfig( + config=quantization_config, operators=ops + ) + return copy.deepcopy(supported_config_and_operators) + + +class EmbeddingQuantizer(Quantizer): + def __init__(self): + super().__init__() + + @classmethod + def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: + op_configs: Set[QuantizationConfig] = set({}) + for spec, _ in cls.get_supported_operators(): + op_configs.add(spec) + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: QuantizationConfig + ) -> List[OperatorPatternType]: + for config, ops in cls.get_supported_operators(): + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. 
we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + self._annotate_embedding_ops(model.graph) + return model + + def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: + embedding_config: OperatorConfig = get_embedding_operators_config() + for node in graph.nodes: + # Keep node parsing based annotations instead of module partitioners + # just as an example of alternate ways of annotating + if ( + node.op == "call_function" + and node.target == torch.ops.aten.embedding.default + ): + if embedding_config.config.weight is None: + raise ValueError( + "Embedding config must have a valid weight quantization spec." + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + node.args[0]: embedding_config.config.weight, + } + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + return [get_embedding_operators_config()] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/quantizer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd1fd20b460310e270bc2a2ef973a4fcc584d66 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/quantizer.py @@ -0,0 +1,158 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.ao.quantization import ObserverOrFakeQuantize +from 
torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.fx import Node + +__all__ = [ + "Quantizer", + "QuantizationSpecBase", + "QuantizationSpec", + "FixedQParamsQuantizationSpec", + "EdgeOrNode", + "SharedQuantizationSpec", + "DerivedQuantizationSpec", + "QuantizationAnnotation", +] + + +class QuantizationSpecBase(ABC): # noqa: B024 + """Base class for different types of quantization specs that allows users to + specify how to quantize a Tensor (input/output of a Node) in the model + """ + + pass + + +@dataclass(eq=True, frozen=True) +class QuantizationSpec(QuantizationSpecBase): + """Quantization spec for common operators that allows user to specify how to + quantize a Tensor, this includes dtype, quant_min, quant_max etc. + """ + + dtype: torch.dtype + # observer or fake_quantize constructor such as + # MinMaxObserver, PerChannelHistogramObserver etc. + # or we can attach some custom args to them + # e.g. MinMaxObserver.with_args(eps=eps) + observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + is_dynamic: bool = False + + def __post_init__(self): + # quant_min must be less than quant_max + if ( + self.quant_min is not None + and self.quant_max is not None + and self.quant_min > self.quant_max + ): + raise ValueError( + f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." + ) + + # ch_axis must be less than the number of channels + # but no way to check here. Just check that it is not < 0. 
+ if self.ch_axis is not None and self.ch_axis < 0: + raise ValueError("Ch_axis is < 0.") + + +@dataclass(eq=True, frozen=True) +class FixedQParamsQuantizationSpec(QuantizationSpecBase): + dtype: torch.dtype + scale: float + zero_point: int + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + + +""" +The way we refer to other points of quantization in the graph will be either +an input edge or an output value +input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node] +output value is an fx Node +""" +EdgeOrNode = Union[Tuple[Node, Node], Node] +EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer" + + +@dataclass(eq=True, frozen=True) +class SharedQuantizationSpec(QuantizationSpecBase): + """ + Quantization spec for the Tensors whose quantization parameters are shared with other Tensors + """ + + # the edge or node to share observer or fake quant instances with + edge_or_node: EdgeOrNode + + +@dataclass(eq=True, frozen=True) +class DerivedQuantizationSpec(QuantizationSpecBase): + """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors""" + + derived_from: List[EdgeOrNode] + derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]] + dtype: torch.dtype + quant_min: Optional[int] = None + quant_max: Optional[int] = None + qscheme: Optional[torch.qscheme] = None + ch_axis: Optional[int] = None + + +@dataclass +class QuantizationAnnotation: + """How are input arguemnt or output should be quantized, + expressed as QuantizationSpec, this corresponds to how a Tensor in the + operator Graph is observed (PTQ) or fake quantized (QAT) + """ + + # a map from torch.fx.Node to a type of QuantizationSpecBase + input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field( + default_factory=dict + ) + + # How the output of this node is quantized, expressed as QuantizationSpec + # TODO: 
change the value to QuantizationSpec in a separate PR + output_qspec: Optional[QuantizationSpecBase] = None + + # For a Node: node1 and edge: (node1, node2), since they are observing the same + # Tensor, we may want to implicitly share observers, this flag allows people to + # turn off this behavior for the output of the node + allow_implicit_sharing: bool = True + + # whether the node is annotated or not + _annotated: bool = False + + +class Quantizer(ABC): + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Allows for user defined transforms to run before annotating the graph. + This allows quantizer to allow quantizing part of the model that are otherwise not quantizable. + For example quantizer can + a) decompose a compound operator like scaled dot product attention, + into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa + or b) transform scalars to tensor to allow quantizing scalares. + + Note: this is an optional method + """ + return model + + # annotate nodes in the graph with observer or fake quant constructors + # to convey the desired way of quantization + @abstractmethod + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + pass + + # validate the annotated graph is supported by the backend + @abstractmethod + def validate(self, model: torch.fx.GraphModule) -> None: + pass diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f25d0916018bae597ee0ef4ab4f3ddfe24eb1ff0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/utils.py @@ -0,0 +1,49 @@ +from typing import List + +from torch.ao.quantization.pt2e.utils import _is_sym_size_node + +from 
torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torch.fx import Node + + +def _annotate_input_qspec_map(node: Node, input_node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + if quantization_annotation.input_qspec_map is None: + quantization_annotation.input_qspec_map = {} + quantization_annotation.input_qspec_map[input_node] = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _annotate_output_qspec(node: Node, qspec): + quantization_annotation = node.meta.get( + "quantization_annotation", QuantizationAnnotation() + ) + quantization_annotation.output_qspec = qspec + node.meta["quantization_annotation"] = quantization_annotation + + +def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]): + """ + This utility is used to handle cases when dynami_shape=True tracing leads + to symint nodes in the pattern of linear module. In those cases, we need to + distinguish between the nodes that are in input for just extracting value of + some dimentions (and symint nodes) vs. the one that is activation. + For example: + graph(x, y, weight): + size_0 = torch.ops.aten.sym_size([x], [0]) + size_1 = torch.ops.aten.sym_size([y], [1]) + view_size = size_0 * size_1 + size_3 = torch.ops.aten.sym_size([x], [2]) + vie_out = torch.ops.aten.view(x, [view_size, size_3]) + return mm(view_out, weight) + In the example above y node is not actual input. 
It exist only to extract size_1 + """ + if _is_sym_size_node(node): + return True + + return all( + ((user not in partition_nodes) or _is_sym_size_node(user)) + for user in node.users + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..797753b7d940bc0358d976fd088993118da0452d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -0,0 +1,1016 @@ +import copy +import functools +import itertools +import operator +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple + +import torch +import torch.nn.functional as F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, + SharedQuantizationSpec, +) +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _is_annotated, + get_bias_qspec, + get_input_act_qspec, + get_output_act_qspec, + get_weight_qspec, + OperatorConfig, + OperatorPatternType, + QuantizationConfig, +) +from torch.fx import Node +from torch.fx.passes.utils.source_matcher_utils import ( + get_source_partitions, + SourcePartition, +) + +__all__ = [ + "X86InductorQuantizer", + "get_default_x86_inductor_quantization_config", +] + + 
+@dataclass +class _X86InductorQuantizationAnnotation(QuantizationAnnotation): + # _is_output_of_quantized_pattern: + # * Node as output node of a fusion pattern. + # * The fusion pattern supports int8 data type. + # * The fusion pattern has inputs annotated to insert observer. + _is_output_of_quantized_pattern: bool = False + + +# Operations that: +# 1. Operations are optimized to run with int8 when int8 input provided. +# 2. Operations do not support int8 input and produce fp32 output. +int8_in_int8_out_ops_pt2e: Set = { + torch.ops.aten.max_pool2d.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.flatten.using_ints, +} + + +# Operations support the int8 data type and exclude operations such as conv and linear. +# A superset of int8_in_int8_out_ops_pt2e incorporating additional operators. +quantizable_ops_pt2e = copy.deepcopy(int8_in_int8_out_ops_pt2e) + +QUANT_ANNOTATION_KEY = "quantization_annotation" + + +def _mark_nodes_as_annotated(nodes: List[Node]): + for node in nodes: + if node is not None: + if QUANT_ANNOTATION_KEY not in node.meta: + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation() + node.meta[QUANT_ANNOTATION_KEY]._annotated = True + + +def _is_node_annotated(_node): + """ + return True if the node is annotated, otherwise return False + """ + return ( + QUANT_ANNOTATION_KEY in _node.meta + and _node.meta[QUANT_ANNOTATION_KEY]._annotated + ) + + +def _is_any_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False. + """ + return any(_is_node_annotated(node) for node in nodes) + + +def _is_all_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + return True if all of the node is annotated, otherwise return False. 
+ """ + return all(_is_node_annotated(node) for node in nodes) + + +def _is_quantized_op_pt2e(node: torch.fx.Node): + """ + Used for pt2e flow to check if the node is a quantized node: + Case1: the node has been annotated as output node of a fusion pattern. + Case2: the node has been annotated as single quantized node. + """ + if not _is_any_annotated([node]): + # The node has not been annotated, directly return False + return False + quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) + assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + return quantization_annotation._is_output_of_quantized_pattern + + +def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]: + # TODO: Add more supported operators here. + supported_operators: Dict[str, List[OperatorPatternType]] = { + "conv2d": [ + [torch.nn.Conv2d], + [F.conv2d], + ], + } + + # Append Conv Optional(Add) Optioinal(ReLU) + conv_add_relu_options = itertools.product( + [torch.nn.Conv2d, F.conv2d], + [torch.add, operator.add, None], # add + [torch.nn.ReLU, F.relu, None], # relu + ) + for conv_op, add_op, relu_op in conv_add_relu_options: + if add_op is None: + # Append Conv ReLU + supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item] + elif relu_op is None: + # Append Conv Add + supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item] + else: + # Append Conv Add ReLU + supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item] + + return copy.deepcopy(supported_operators) + + +def _get_supported_x86_inductor_config_and_operators() -> List[OperatorConfig]: + supported_config_and_operators: List[OperatorConfig] = [] + for quantization_config in [ + get_default_x86_inductor_quantization_config(), + ]: + ops = _supported_quantized_operators() + for pattern_list in ops.values(): + supported_config_and_operators.append( + OperatorConfig(quantization_config, pattern_list) + 
) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_default_x86_inductor_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, +): + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + # Copy from x86 default qconfig from torch/ao/quantization/qconfig.py + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, # reduce_range=False + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + FusedMovingAvgObsFakeQuantize if is_qat else PerChannelMinMaxObserver + ) + + if is_qat: + # Only support per channel quant for now + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, # 0 corresponding to weight shape = (oc, ic, kh, kw) of conv + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + bias_quantization_spec = None # will use placeholder observer by default + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return 
quantization_config + + +def _get_supported_config_and_operators() -> List[OperatorConfig]: + return _get_supported_x86_inductor_config_and_operators() + + +class X86InductorQuantizer(Quantizer): + supported_config_and_operators = _get_supported_config_and_operators() + + def __init__(self): + super().__init__() + self.global_config: QuantizationConfig = None # type: ignore[assignment] + self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: + op_configs: Set[QuantizationConfig] = set({}) + for spec, _ in cls.supported_config_and_operators: + op_configs.add(spec) + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> List[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig): + self.global_config = quantization_config + return self + + def set_config_for_operator_type( + self, operator_type: str, quantization_config: QuantizationConfig + ): + self.operator_type_config[operator_type] = quantization_config + return self + + def _annotate_conv_node_helper( + self, + conv_node: torch.fx.Node, + annotate_output: bool, + quantization_config: QuantizationConfig, + ) -> None: + """Helper function to annotate the conv node""" + input_qspec_map = {} + input_node = conv_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + weight_node = conv_node.args[1] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + bias_node = None if len(conv_node.args) == 
2 else conv_node.args[2] + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + if annotate_output: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + conv_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + + def _annotate_linear_node_helper( + self, + linear_node: torch.fx.Node, + annotate_output: bool, + quantization_config: QuantizationConfig, + ) -> None: + """Helper function to annotate the linear node""" + input_qspec_map = {} + assert linear_node.target in (torch.ops.aten.linear.default,) + has_bias = len(linear_node.args) == 3 + input_index = 0 + weight_index = 1 + bias_index = 2 + + input_node = linear_node.args[input_index] + assert isinstance(input_node, Node) + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + + weight_node = linear_node.args[weight_index] + assert isinstance(weight_node, Node) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + + bias_node = linear_node.args[bias_index] if has_bias else None + if isinstance(bias_node, Node): + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + + if annotate_output: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + else: + linear_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, _annotated=True + ) + + def _get_output_nodes_of_partitions( + self, + partition_list: List[SourcePartition], + ) -> List[torch.fx.Node]: + """Helper function to get the output node list from partition list""" + output_node_list = [] + for partition in partition_list: + if len(partition.output_nodes) > 1: + raise ValueError("Input 
partition has more than one output node") + output_node = partition.output_nodes[0] + assert isinstance(output_node, Node) + output_node_list.append(output_node) + if len(output_node_list) != len(partition_list): + raise ValueError( + "length of output_node_list should equal to length of partition_list" + ) + return output_node_list + + def _get_input_idx_for_binary_node( + self, + conv_gemm_node: torch.fx.Node, + binary_node: torch.fx.Node, + ): + """Helper function to check conv_gemm and extra input node index + for binary node fused with conv_gemm. + """ + conv_gemm_node_idx = None + extra_input_node_idx = None + if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[0] == conv_gemm_node + ): + conv_gemm_node_idx = 0 + extra_input_node_idx = 1 + elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] + binary_node.args[1] == conv_gemm_node + ): + conv_gemm_node_idx = 1 + extra_input_node_idx = 0 + extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] + assert isinstance(extra_input_node, Node) + return conv_gemm_node_idx, extra_input_node_idx + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + r""" + High-level description of quantization recipe for X86 Inductor Backend: + Step 1: Apply quantization recipe for fusion patterns of conv/linear to enable int8 data type actively. + Step 2: Propagate quantization annotation for patterns besides conv/linear. Go through the pattern in model + from start to the end. 
If a pattern supports computation with int8 data type and inputs connected to + quantized patterns, annotate its inputs as quantized pattern. + Step 3: Since in step 2, we only annotate the inputs of quantized pattern. For some quantized patterns, + such as maxpool2d, which only supports output with int8 data type when the input is with int8 data type, + we need to annotate the output of this pattern. + """ + + config = self.global_config + + # Step1: Recipe of fusion patterns like conv/linear. + if config.is_qat: + # Annotate QAT specific pattern: mainly due to BN not folded in prepare_qat + self._annotate_qat_conv2d_fusion_pattern(model, config) + + self._annotate_conv2d_fusion_pattern(model, config) + + # Step2: Recipe to propagate annotation for patterns beside conv/linear. + # Go through all the nodes from start to end. + # Recipe refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L538 + for node in model.graph.nodes: + self._annotation_propagation_quantizable_pattern(node, config) + + # Step3: For quantizable ops, such as maxpool2d, we need to quantize its output if it is quantized + # in inputs. So, we can fuse dq-operator-q into a quantized op. 
+ # Refer to https://github.com/intel/intel-extension-for-pytorch/blob/ + # 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_recipe.py#L487 + for node in model.graph.nodes: + self._annotate_output_for_int8_in_int8_out_pattern(node, config) + + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + config = self.global_config + self._annotate_linear(model, config) + return model + + def _annotate_qat_conv2d_fusion_pattern( + self, model: torch.fx.GraphModule, config: QuantizationConfig + ): + # Annotate QAT Specific patterns + self._annotate_qat_conv2d_bn_binary_unary(model, config) + self._annotate_qat_conv2d_bn_binary(model, config) + self._annotate_qat_conv2d_bn_unary(model, config) + self._annotate_qat_conv2d_bn(model, config) + + def _annotate_qat_conv2d_bn_binary_unary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + ( + conv_partition, + bn_partition, + binary_partition, + unary_partition, + ) = fused_partition + + ( + conv_node, + bn_output_node, + binary_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition, unary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. 
+ continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _is_annotated([unary_node, binary_node, bn_output_node, conv_node]): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_binary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition, binary_partition = fused_partition + ( + conv_node, + bn_output_node, + binary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, binary_partition] + ) + if len(bn_output_node.users) != 1: + # Conv BN pattern should only has 1 user. 
+ continue + ( + bn_output_node_idx, + extra_input_node_idx, + ) = self._get_input_idx_for_binary_node(bn_output_node, binary_node) + if (bn_output_node_idx is None) or (extra_input_node_idx is None): + continue + if bn_output_node != binary_node.args[bn_output_node_idx]: + raise ValueError(f"{bn_output_node} doesn't match input of binary node") + + extra_input_node = binary_node.args[extra_input_node_idx] + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _is_annotated([binary_node, bn_output_node, conv_node]): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(binary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn_unary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, bn_partition, unary_partition = fused_partition + ( + conv_node, + bn_output_node, + unary_node, + ) = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition, unary_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _is_annotated([unary_node, bn_output_node, conv_node]): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + nodes_to_mark_annotated.extend(list(unary_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_qat_conv2d_bn( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d] + ) + for fused_partition in fused_partitions: + conv_partition, bn_partition = fused_partition + conv_node, bn_output_node = self._get_output_nodes_of_partitions( + [conv_partition, bn_partition] + ) + + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + + if _is_annotated([bn_output_node, conv_node]): + continue + + self._annotate_conv_node_helper(conv_node, False, quantization_config) + bn_output_node.meta[ + QUANT_ANNOTATION_KEY + ] = _X86InductorQuantizationAnnotation( + # TODO Remove the annotate of output in QAT when qat util support pattern matcher. 
+ output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + nodes_to_mark_annotated = list(conv_partition.nodes) + nodes_to_mark_annotated.extend(list(bn_partition.nodes)) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + def _annotate_conv2d_fusion_pattern( + self, model: torch.fx.GraphModule, config: QuantizationConfig + ): + self._annotate_conv2d_binary_unary(model, config) + self._annotate_conv2d_binary(model, config) + self._annotate_conv2d_unary(model, config) + self._annotate_conv2d(model, config) + self._annotate_linear_unary(model, config) + self._annotate_linear(model, config) + + def _annotate_conv2d_binary_unary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + # Conv2d + add + unary op + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add, torch.nn.ReLU] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition, unary_partition = fused_partition + conv_node, binary_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition, unary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _is_annotated([unary_node, binary_node, conv_node]): + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + 
binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + ) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_binary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + # Conv2d + add + fused_partitions = find_sequential_partitions( + gm, [torch.nn.Conv2d, operator.add] + ) + for fused_partition in fused_partitions: + conv_partition, binary_partition = fused_partition + conv_node, binary_node = self._get_output_nodes_of_partitions( + [conv_partition, binary_partition] + ) + if len(conv_node.users) != 1: + # Conv Node should only has 1 user node + continue + conv_node_idx, extra_input_node_idx = self._get_input_idx_for_binary_node( + conv_node, binary_node + ) + if (conv_node_idx is None) or (extra_input_node_idx is None): + continue + if conv_node != binary_node.args[conv_node_idx]: + raise ValueError(f"{conv_node} doesn't match input of binary node") + extra_input_node = binary_node.args[extra_input_node_idx] + assert isinstance(conv_node, Node) + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + # No conv node found to be fused with add + continue + if _is_annotated([binary_node, conv_node]): + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) + binary_node_input_qspec_map = {} + binary_node_input_qspec_map[extra_input_node] = get_input_act_qspec( + quantization_config + ) + binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=binary_node_input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d_unary( + self, gm: torch.fx.GraphModule, quantization_config: 
QuantizationConfig + ) -> None: + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.Hardswish], + [torch.nn.Conv2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + + for fused_partition in fused_partitions: + conv_partition, unary_partition = fused_partition + conv_node, unary_node = self._get_output_nodes_of_partitions( + [conv_partition, unary_partition] + ) + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + continue + if _is_annotated([unary_node, conv_node]): + continue + self._annotate_conv_node_helper(conv_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_conv2d( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain.from_iterable(conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.conv2d.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d operator") + # skip annotation if it is already annotated + if _is_annotated([conv_node]): + continue + self._annotate_conv_node_helper(conv_node, True, quantization_config) + + def _annotate_maxpool2d( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + if node.target is not 
torch.ops.aten.max_pool2d.default: + return + maxpool_node = node + if _is_any_annotated( + [ + maxpool_node, + ] + ): + return + input_node = maxpool_node.args[0] + assert isinstance(input_node, Node) + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotate_cat( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + cat_node = node + input_nodes = cat_node.args[0] + assert isinstance(input_nodes, Sequence) + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(cat_node, Node) + input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, cat_node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + # There has the case of cat same nodes: torch.cat([input0, input0], 1) + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def _annotation_propagation_quantizable_pattern( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + # Propagate annotation to quantizable patterns. 
+ if ( + (node.target in quantizable_ops_pt2e) + and (not _is_any_annotated([node])) + and (node.op == "call_function") + ): + + def is_all_inputs_connected_to_quantized_op(input_nodes): + # Ensure all the inputs connect to fusion pattern or quantized node + for input_node in input_nodes: + if not _is_quantized_op_pt2e(input_node): + return False + return True + + if node.target is torch.ops.aten.max_pool2d.default: + # Recipe of maxpool2d: check input arg[0] of maxpool2d is quantized or not + input_nodes_to_check = [node.all_input_nodes[0]] + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_maxpool2d(node, quantization_config) + return + elif node.target is torch.ops.aten.cat.default: + input_nodes_to_check = node.all_input_nodes + if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check): + return + self._annotate_cat(node, quantization_config) + else: + input_node = node.all_input_nodes[0] + if not is_all_inputs_connected_to_quantized_op( + [ + input_node, + ] + ): + return + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + return + + def _annotate_output_share_observer_as_input( + self, input_node: Node, source_node: Node + ): + source_node_quantization_annotation = ( + source_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in source_node.meta + else None + ) + if ( + source_node_quantization_annotation + and source_node_quantization_annotation._is_output_of_quantized_pattern + ): + edge_or_node = (input_node, source_node) + source_node_quantization_annotation.output_qspec = SharedQuantizationSpec( + edge_or_node + ) + return + + def _annotate_output_for_int8_in_int8_out_pattern( + self, node: Node, quantization_config: QuantizationConfig + ) -> None: + r""" + Check and insert 
observer at output of node in int8_in_int8_out_ops_pt2e if needed. + Recipe refers to https://github.com/intel/intel-extension-for-pytorch/blob/ + 90d19323d96afc53fcc22ba5a7bb3fb07fdd6c1c/intel_extension_for_pytorch/quantization/_utils.py#L495 + """ + edge_or_node: Tuple[Node, Node] + if (node.target in int8_in_int8_out_ops_pt2e) and (_is_any_annotated([node])): + if node.target == torch.ops.aten.max_pool2d.default: + maxpool_node = node + if not _is_all_annotated( + [ + maxpool_node, + ] + ): + return + # Get the quantization_annotation from getitem_node + maxpool_node_quantization_annotation = ( + maxpool_node.meta[QUANT_ANNOTATION_KEY] + if QUANT_ANNOTATION_KEY in maxpool_node.meta + else None + ) + if ( + maxpool_node_quantization_annotation + and maxpool_node_quantization_annotation._is_output_of_quantized_pattern + ): + # Annotate the output_qspec of getitem_node + input_act = maxpool_node.args[0] + assert isinstance(input_act, Node) + assert isinstance(maxpool_node, Node) + edge_or_node = (input_act, maxpool_node) + maxpool_node_quantization_annotation.output_qspec = ( + SharedQuantizationSpec(edge_or_node) + ) + else: + input_node = node.all_input_nodes[0] + self._annotate_output_share_observer_as_input(input_node, node) + return + + def _annotate_linear( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + linear_partitions = get_source_partitions( + gm.graph, [torch.nn.Linear, torch.nn.functional.linear] + ) + linear_partitions = list( + itertools.chain.from_iterable(linear_partitions.values()) + ) + for partition in linear_partitions: + if len(partition.output_nodes) > 1: + raise ValueError( + "Linear partition cannot have more than one output node" + ) + linear_node = partition.output_nodes[0] + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + raise ValueError(f"{linear_node} is not an aten linear operator") + # skip annotation if it is already annotated + if 
_is_annotated([linear_node]): + continue + self._annotate_linear_node_helper(linear_node, True, quantization_config) + + def _annotate_linear_unary( + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig + ) -> None: + postop_list = [ + torch.nn.ReLU, + torch.nn.LeakyReLU, + torch.nn.Tanh, + ] + fused_partitions: List[tuple] = [] + for postop in postop_list: + fused_partitions = fused_partitions + find_sequential_partitions( + gm, [torch.nn.Linear, postop] + ) + for fused_partition in fused_partitions: + linear_partition, unary_partition = fused_partition + linear_node, unary_node = self._get_output_nodes_of_partitions( + [linear_partition, unary_partition] + ) + if linear_node.op != "call_function" or linear_node.target not in ( + torch.ops.aten.linear.default, + ): + continue + if _is_annotated([unary_node, linear_node]): + continue + self._annotate_linear_node_helper(linear_node, False, quantization_config) + unary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( + _annotated=True, + _is_output_of_quantized_pattern=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1062c62428975ac42c0c19c6fddf69904e0ebe19 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -0,0 +1,453 @@ +from __future__ import annotations + +import copy +import functools + +from typing import Any, Callable, Dict, List, Optional, Set + +import torch +import torch._dynamo as torchdynamo +import torch.nn.functional as 
F +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, +) +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + PerChannelMinMaxObserver, + PlaceholderObserver, +) + +from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor + +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer + +from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _convert_scalars_to_attrs, + OP_TO_ANNOTATOR, + OperatorConfig, + OperatorPatternType, + propagate_annotation, + QuantizationConfig, +) + +from torch.fx import Node + + +__all__ = [ + "XNNPACKQuantizer", + "get_symmetric_quantization_config", +] + + +def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph: + gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs) + gm.graph.eliminate_dead_code() + return gm.graph + + +def _get_linear_patterns(input_size: List[int]): + in_channels = input_size[-1] + out_channels = 8 # hard coding but this should not matter + weight = torch.ones((out_channels, in_channels)) + bias = torch.ones((out_channels,)) + act = torch.ones(input_size) + + def linear_op(act, weight, bias=None): + return F.linear(act, weight, bias) + + pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias)) + pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight)) + return [pattern_w_bias, pattern_wo_bias] + + +def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]: + supported_operators: Dict[str, List[OperatorPatternType]] = { + # Both conv and linear should be able to handle relu + hardtanh fusion since + # those are clamp ops + "conv2d": [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, F.relu], + [F.conv2d, torch.nn.ReLU], + [F.conv2d, F.relu], + ], + "linear": [[torch.nn.Linear], [F.linear]], + "add": [[torch.add]], + "max_pool2d": 
[[torch.nn.MaxPool2d], [F.max_pool2d]], + "adaptive_avg_pool2d": [ + [torch.nn.AdaptiveAvgPool2d], + [F.adaptive_avg_pool2d], + ], + } + return copy.deepcopy(supported_operators) + + +def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]: + supported_config_and_operators: List[OperatorConfig] = [] + for quantization_config in [ + get_symmetric_quantization_config(), + get_symmetric_quantization_config(is_qat=True), + get_symmetric_quantization_config(is_per_channel=True), + get_symmetric_quantization_config(is_per_channel=True, is_qat=True), + ]: + ops = _supported_symmetric_quantized_operators() + for pattern_list in ops.values(): + supported_config_and_operators.append( + OperatorConfig(quantization_config, pattern_list) + ) + return copy.deepcopy(supported_config_and_operators) + + +@functools.lru_cache +def get_symmetric_quantization_config( + is_per_channel: bool = False, + is_qat: bool = False, + is_dynamic: bool = False, + act_qmin: int = -128, + act_qmax: int = 127, + weight_qmin: int = -127, + weight_qmax: int = 127, +): + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=act_qmin, + quant_max=act_qmax, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + weight_qscheme = ( + torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric + 
) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + MinMaxObserver + ) + if is_qat: + # TODO: qat + per channel? + weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize + elif is_per_channel: + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + + extra_args: Dict[str, Any] = {"eps": 2**-12} + if is_qat: + if weight_qscheme == torch.per_tensor_symmetric: + extra_args["observer"] = MovingAverageMinMaxObserver + else: + extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=weight_qmin, + quant_max=weight_qmax, + qscheme=weight_qscheme, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( + **extra_args + ), + ) + + bias_quantization_spec = None + if is_dynamic: + quantization_config = QuantizationConfig( + act_quantization_spec, + None, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + else: + quantization_config = QuantizationConfig( + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, + is_qat, + ) + return quantization_config + + +def _get_supported_config_and_operators() -> List[OperatorConfig]: + return _get_supported_symmetric_config_and_operators() + + +def _get_module_name_filter(module_name: str): + """Get the module_name_filter function for a given module name, the filter accepts + a node and checks if the node comes from a module that has certain module name + + For example: + node: linear_op = call_function[...](...) 
# comes from a module with name blocks.sub.linear1 + + + >> module_name_filter = _get_module_name_filter("blocks.sub") + >> print(module_name_filter(node)) + True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1" + """ + + def module_name_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) + names = [n[len("L['self'].") :] for n, klass in nn_module_stack.values()] + return module_name in names + + return module_name_filter + + +def _get_module_type_filter(tp: Callable): + """Get the module_type_filter function for a given module type, the filter accepts + a node and checks if the node comes from a module that has certain module type + + For example: + node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear + + + >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule + >> print(module_type_filter(node)) + True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well) + """ + + def module_type_filter(n: Node) -> bool: + # example: { + # 'L__self___sub': ("L['self'].sub", ), + # 'L__self___sub_linear': ("L['self'].sub.linear", ) + # } + nn_module_stack = n.meta.get("nn_module_stack", {}) + types = [t for _, t in nn_module_stack.values()] + return tp in types + + return module_type_filter + + +def _get_not_module_type_or_name_filter( + tp_list: List[Callable], module_name_list: List[str] +) -> Callable[[Node], bool]: + module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] + module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + + def not_module_type_or_name_filter(n: Node) -> bool: + return not any(f(n) for f in module_type_filters + module_name_list_filters) + + return 
not_module_type_or_name_filter + + +class XNNPACKQuantizer(Quantizer): + supported_config_and_operators = _get_supported_config_and_operators() + STATIC_QAT_ONLY_OPS = [ + "conv_bn_relu", + "conv_bn", + ] + + # static quantization ops (both PTQ and QAT) + # Preserve the order that fusions come before singular ops + STATIC_OPS = [ + "linear_relu", + "linear", + "conv_relu", + "conv", + "adaptive_avg_pool2d", + # TODO: move this to BoltNNQuantizer? + "gru_io_only", + "max_pool2d", + "add_relu", + "add", + "mul_relu", + "mul", + "cat", + ] + + DYNAMIC_OPS = [ + "linear", + ] + + def __init__(self): + super().__init__() + self.global_config: Optional[QuantizationConfig] = None + self.operator_type_config: Dict[ + torch._ops.OpOverloadPacket, Optional[QuantizationConfig] + ] = {} + self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + + @classmethod + def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: + op_configs: Set[QuantizationConfig] = set({}) + for spec, _ in cls.supported_config_and_operators: + op_configs.add(spec) + return list(op_configs) + + @classmethod + def get_supported_operator_for_quantization_config( + cls, quantization_config: Optional[QuantizationConfig] + ) -> List[OperatorPatternType]: + if quantization_config is None: + all_ops = [] + for _, ops in cls.supported_config_and_operators: + all_ops.extend(ops) + return all_ops + + for config, ops in cls.supported_config_and_operators: + # note: this assumes each entry in cls.supported_spec_and_operators + # corresponds to one spec, e.g. 
we don't have + # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] + # where the first and second entry have the same spec but did not + # merge the op list + if config == quantization_config: + return ops + return [] + + def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer: + self.global_config = quantization_config + return self + + def set_operator_type( + self, + operator_type: torch._ops.OpOverloadPacket, + quantization_config: QuantizationConfig, + ) -> XNNPACKQuantizer: + self.operator_type_config[operator_type] = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: QuantizationConfig + ): + """Set quantization_config for a submodule with type: `module_type`, for example: + quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator + patterns in the submodule with this module type with the given `quantization_config` + """ + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ): + """Set quantization_config for a submodule with name: `module_name`, for example: + quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator + patterns in the submodule with this module name with the given `quantization_config` + """ + assert ( + quantization_config is not None + ), " quantization_config == None is not supported yet" + self.module_name_config[module_name] = quantization_config + return self + + def transform_for_annotation( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + """Transforms scalar values to tensor attributes""" + return _convert_scalars_to_attrs(model) + + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + """just handling global spec for now""" + # hacked for handling dynamic linear quant. 
will fix later. + if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr] + model = self._annotate_for_dynamic_quantization_config(model) + else: + model = self._annotate_for_static_quantization_config(model) + propagate_annotation(model) + return model + + def _annotate_all_static_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + if quantization_config.is_qat: + for op in self.STATIC_QAT_ONLY_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + for op in self.STATIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_all_dynamic_patterns( + self, + model: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, + ) -> torch.fx.GraphModule: + # TODO: implement the support for None to be canceling out previous annotations + if quantization_config is None: + return model + + for op in self.DYNAMIC_OPS: + OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + return model + + def _annotate_for_static_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_static_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_static_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, 
module_name_list), + ) + return model + + def _annotate_for_dynamic_quantization_config( + self, model: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + module_name_list = list(self.module_name_config.keys()) + for module_name, config in self.module_name_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_name_filter(module_name) + ) + + tp_list = list(self.module_type_config.keys()) + for module_type, config in self.module_type_config.items(): + self._annotate_all_dynamic_patterns( + model, config, _get_module_type_filter(module_type) + ) + + self._annotate_all_dynamic_patterns( + model, + self.global_config, + _get_not_module_type_or_name_filter(tp_list, module_name_list), + ) + return model + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + @classmethod + def get_supported_operators(cls) -> List[OperatorConfig]: + return cls.supported_config_and_operators diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..042163705a0b93d5ffdef6562b55581182a8852a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -0,0 +1,1032 @@ +import itertools +import operator +from dataclasses import dataclass +from typing import Callable, Dict, List, NamedTuple, Optional + +import torch +import torch.nn.functional as F +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix +from torch.ao.quantization.pt2e.export_utils import _WrapperModule +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.ao.quantization.pt2e.utils import ( + _conv1d_bn_example_inputs, + 
_conv2d_bn_example_inputs, + get_aten_graph_module, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, +) + +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node +from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, +) +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +__all__ = [ + "OperatorConfig", + "OperatorPatternType", + "QuantizationConfig", + "get_input_act_qspec", + "get_output_act_qspec", + "get_weight_qspec", + "get_bias_qspec", + "OP_TO_ANNOTATOR", + "propagate_annotation", +] + + +# In the absence of better name, just winging it with QuantizationConfig +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec] + # TODO: remove, since we can use observer_or_fake_quant_ctr to express this + is_qat: bool = False + + +OperatorPatternType = List[Callable] +OperatorPatternType.__module__ = ( + "torch.ao.quantization.quantizer.xnnpack_quantizer_utils" +) + +AnnotatorType = Callable[ + [ + torch.fx.GraphModule, + Optional[QuantizationConfig], + Optional[Callable[[Node], bool]], + ], + Optional[List[List[Node]]], +] +OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {} + + +def register_annotator(op: str): + def decorator(annotator: AnnotatorType): + OP_TO_ANNOTATOR[op] = annotator + + return decorator + + +class OperatorConfig(NamedTuple): + # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] + # Basically we are mapping a quantization config to some list of patterns. + # a pattern is defined as a list of nn module, function or builtin function names + # e.g. 
[nn.Conv2d, torch.relu, torch.add] + # We have not resolved whether fusion can be considered internal details of the + # quantizer hence it does not need communication to user. + # Note this pattern is not really informative since it does not really + # tell us the graph structure resulting from the list of ops. + config: QuantizationConfig + operators: List[OperatorPatternType] + + +def _is_annotated(nodes: List[Node]): + """ + Given a list of nodes (that represents an operator pattern), + check if any of the node is annotated, return True if any of the node + is annotated, otherwise return False + """ + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def _mark_nodes_as_annotated(nodes: List[Node]): + for node in nodes: + if node is not None: + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True + + +def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.input_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.input_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + if quantization_config.output_activation is None: + return None + quantization_spec: QuantizationSpec = quantization_config.output_activation + assert quantization_spec.qscheme in [ + torch.per_tensor_affine, + torch.per_tensor_symmetric, + ] + return quantization_spec + + +def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + 
assert quantization_config is not None + if quantization_config.weight is None: + return None + quantization_spec: QuantizationSpec = quantization_config.weight + if quantization_spec.qscheme not in [ + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + ]: + raise ValueError( + f"Unsupported quantization_spec {quantization_spec} for weight" + ) + return quantization_spec + + +def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): + if quantization_config is None: + return None + assert quantization_config is not None + if quantization_config.bias is None: + return None + quantization_spec: QuantizationSpec = quantization_config.bias + assert ( + quantization_spec.dtype == torch.float + ), "Only float dtype for bias is supported for bias right now" + return quantization_spec + + +@register_annotator("linear") +def _annotate_linear( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target != torch.ops.aten.linear.default: + continue + if filter_fn and not filter_fn(node): + continue + act_node = node.args[0] + weight_node = node.args[1] + bias_node = None + if len(node.args) > 2: + bias_node = node.args[2] + + if _is_annotated([node]) is False: # type: ignore[list-item] + _annotate_input_qspec_map( + node, + act_node, + input_act_qspec, + ) + _annotate_input_qspec_map( + node, + weight_node, + weight_qspec, + ) + nodes_to_mark_annotated = [node, weight_node] + if bias_node: + _annotate_input_qspec_map( + node, + bias_node, + bias_qspec, + ) + nodes_to_mark_annotated.append(bias_node) + 
_annotate_output_qspec(node, output_act_qspec) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + annotated_partitions.append(nodes_to_mark_annotated) + + return annotated_partitions + + +@register_annotator("linear_relu") +def _annotate_linear_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + input_act_qspec = get_input_act_qspec(quantization_config) + output_act_qspec = get_output_act_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + bias_qspec = get_bias_qspec(quantization_config) + for node in gm.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + ]: + continue + relu_node = node + maybe_linear_node = node.args[0] + if ( + not isinstance(maybe_linear_node, Node) + or maybe_linear_node.op != "call_function" + or maybe_linear_node.target != torch.ops.aten.linear.default + ): + continue + + linear_node = maybe_linear_node + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = input_act_qspec + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = weight_qspec + + # adding weight node to the partition as well + partition = [relu_node, linear_node, weight] + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = bias_qspec + partition.append(bias) + + if _is_annotated(partition): + continue + + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + relu_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=output_act_qspec, + _annotated=True, + ) + 
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv")
def _annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate standalone aten.conv1d/conv2d nodes with qspecs from the config."""
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv2d.default,
        ]:
            continue
        conv_node = n

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [conv_node, conv_node.args[1]]

        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        # NOTE: `n` in the generator shadows the outer loop variable; harmless
        # because generator expressions have their own scope.
        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=get_output_act_qspec(quantization_config),
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate fused conv -> relu patterns; the output qspec goes on the relu."""
    annotated_partitions = []
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.relu.default,
            torch.ops.aten.relu_.default,
        ]:
            continue
        relu_node = n
        maybe_conv_node = n.args[0]
        if (
            not isinstance(maybe_conv_node, Node)
            or maybe_conv_node.op != "call_function"
            or maybe_conv_node.target
            not in [
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            ]
        ):
            continue
        conv_node = maybe_conv_node

        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
        input_qspec_map[input_act] = get_input_act_qspec(quantization_config)

        weight = conv_node.args[1]
        assert isinstance(weight, Node)
        input_qspec_map[weight] = get_weight_qspec(quantization_config)

        # adding weight node to the partition as well
        partition = [relu_node, conv_node, conv_node.args[1]]
        bias = conv_node.args[2] if len(conv_node.args) > 2 else None
        if isinstance(bias, Node):
            input_qspec_map[bias] = get_bias_qspec(quantization_config)
            partition.append(bias)

        if _is_annotated(partition):
            continue

        if filter_fn and any(not filter_fn(n) for n in partition):
            continue

        conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map, _annotated=True
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=get_output_act_qspec(quantization_config),  # type: ignore[arg-type]
            _annotated=True,
        )
        _mark_nodes_as_annotated(partition)
        annotated_partitions.append(partition)
    return annotated_partitions


@register_annotator("conv_bn")
def _annotate_conv_bn(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """
    Find conv + batchnorm partitions
    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+ """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False) + + +@register_annotator("conv_bn_relu") +def _annotate_conv_bn_relu( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + """ + Find conv + batchnorm + relu parititions + Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv. + """ + return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True) + + +def _do_annotate_conv_bn( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]], + has_relu: bool, +) -> List[List[Node]]: + """ + Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern, + return a list of annotated partitions. + + The output of the pattern must include a dictionary from string name to node + for the following names: "input", "conv", "weight", "bias", and "output". 
+ """ + + def get_pattern(conv_fn: Callable, relu_is_inplace: bool): + def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): + conv = conv_fn(x, conv_weight, conv_bias) + bn = F.batch_norm(conv, bn_rm, bn_rv, bn_weight, bn_bias, training=True) + if has_relu: + output = F.relu_(bn) if relu_is_inplace else F.relu(bn) + else: + output = bn + return output, { + "input": x, + "conv": conv, + "weight": conv_weight, + "bias": conv_bias, + "output": output, + } + + return _WrapperModule(_conv_bn) + + # Needed for matching, otherwise the matches gets filtered out due to unused + # nodes returned by batch norm + gm.graph.eliminate_dead_code() + gm.recompile() + + matches = [] + combinations = [ + (F.conv1d, _conv1d_bn_example_inputs), + (F.conv2d, _conv2d_bn_example_inputs), + ] + + # Add `is_cuda` and `relu_is_inplace` dimensions + combinations = itertools.product( + combinations, + [True, False] if torch.cuda.is_available() else [False], # is_cuda + [True, False] if has_relu else [False], # relu_is_inplace + ) + + # Match against all conv dimensions and cuda variants + for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations: + pattern = get_pattern(conv_fn, relu_is_inplace) + pattern = get_aten_graph_module(pattern, example_inputs, is_cuda) + pattern.graph.eliminate_dead_code() + pattern.recompile() + matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True) + matches.extend(matcher.match(gm.graph)) + + # Annotate nodes returned in the matches + annotated_partitions = [] + for match in matches: + name_node_map = match.name_node_map + input_node = name_node_map["input"] + conv_node = name_node_map["conv"] + weight_node = name_node_map["weight"] + bias_node = name_node_map["bias"] + output_node = name_node_map["output"] + + # TODO: annotate the uses of input, weight, and bias separately instead + # of assuming they come from a single conv node. 
This is not possible today + # because input may have multiple users, and we can't rely on the conv node + # always being the first user. This was the case in models with skip + # connections like resnet18 + + # Validate conv args + if conv_node.args[0] is not input_node: + raise ValueError("Conv arg did not contain input node ", input_node) + if conv_node.args[1] is not weight_node: + raise ValueError("Conv arg did not contain weight node ", weight_node) + if len(conv_node.args) > 2 and conv_node.args[2] is not bias_node: + raise ValueError("Conv arg did not contain bias node ", bias_node) + + # Skip if the partition is already annotated or is filtered out by the user + partition = [conv_node, weight_node] + if bias_node is not None: + partition.append(bias_node) + if _is_annotated(partition): + continue + if filter_fn and any(not filter_fn(n) for n in partition): + continue + + # Annotate conv inputs and pattern output + input_qspec_map = {} + input_qspec_map[input_node] = get_input_act_qspec(quantization_config) + input_qspec_map[weight_node] = get_weight_qspec(quantization_config) + if bias_node is not None: + input_qspec_map[bias_node] = get_bias_qspec(quantization_config) + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type] + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + annotated_partitions.append(partition) + return annotated_partitions + + +@register_annotator("gru_io_only") +def _annotate_gru_io_only( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + gru_partitions = get_source_partitions(gm.graph, [torch.nn.GRU], filter_fn) + gru_partitions = 
list(itertools.chain.from_iterable(gru_partitions.values())) + annotated_partitions = [] + for gru_partition in gru_partitions: + annotated_partitions.append(gru_partition.nodes) + output_nodes = gru_partition.output_nodes + input_nodes = gru_partition.input_nodes + # skip annotation if it is already annotated + if _is_annotated(input_nodes + output_nodes): + continue + # inside each GRU partition, we should be able to annotate each linear + # subgraph + input_qspec_map: Dict[Node, QuantizationSpecBase] = {} + input_act = input_nodes[0] + input_act_user = next(iter(input_act.users.keys())) + assert isinstance(input_act, Node) + assert isinstance(input_act_user, Node) + input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + hidden_state = input_nodes[1] + hidden_state_user = next(iter(hidden_state.users.keys())) + assert isinstance(hidden_state, Node) + assert isinstance(hidden_state_user, Node) + hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + hidden_state: get_input_act_qspec(quantization_config), + }, + _annotated=True, + ) + + assert len(output_nodes) == 2, "expecting GRU to have two outputs" + for output in output_nodes: + output.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=get_output_act_qspec(quantization_config), + _annotated=True, + ) + nodes_to_mark_annotated = list(gru_partition.nodes) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + return annotated_partitions + + +@register_annotator("max_pool2d") +def _annotate_max_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + module_partitions = get_source_partitions( + gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn + ) + maxpool_partitions = 
list(itertools.chain.from_iterable(module_partitions.values())) + annotated_partitions = [] + for maxpool_partition in maxpool_partitions: + annotated_partitions.append(maxpool_partition.nodes) + output_node = maxpool_partition.output_nodes[0] + maxpool_node = None + for n in maxpool_partition.nodes: + if n.target == torch.ops.aten.max_pool2d.default: + maxpool_node = n + assert ( + maxpool_node is not None + ), "XNNPACKQuantizer only works with torch.ops.aten.max_pool2d.default, " + "please make sure you are exporting the model correctly" + if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item] + continue + + input_act = maxpool_node.args[0] # type: ignore[union-attr] + assert isinstance(input_act, Node) + + # only annotate maxpool when the output of the input node is annotated + if ( + "quantization_annotation" not in input_act.meta + or not input_act.meta["quantization_annotation"]._annotated + or input_act.meta["quantization_annotation"].output_qspec is None + ): + continue + # input and output of maxpool will share quantization parameter with input of maxpool + act_qspec = SharedQuantizationSpec(input_act) + # act_qspec = get_act_qspec(quantization_config) + maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] + input_qspec_map={ + input_act: act_qspec, + }, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=act_qspec, + _annotated=True, + ) + return annotated_partitions + + +@register_annotator("adaptive_avg_pool2d") +def _annotate_adaptive_avg_pool2d( + gm: torch.fx.GraphModule, + quantization_config: Optional[QuantizationConfig], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + """Always annotate adaptive_avg_pool2d op""" + module_partitions = get_source_partitions( + gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn + ) + partitions = 
list(itertools.chain.from_iterable(module_partitions.values()))
    annotated_partitions = []
    for partition in partitions:
        pool_node = partition.output_nodes[0]
        if (
            pool_node.op != "call_function"
            or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default
        ):
            raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")

        if _is_annotated([pool_node]):
            continue

        annotated_partitions.append(partition.nodes)
        input_act = pool_node.args[0]
        assert isinstance(input_act, Node)

        # only annotate input output sharing operator
        # when the output of the input node is annotated
        if (
            "quantization_annotation" not in input_act.meta
            or not input_act.meta["quantization_annotation"]._annotated
            or input_act.meta["quantization_annotation"].output_qspec is None
        ):
            input_act_qspec = get_input_act_qspec(quantization_config)
        else:
            input_act_qspec = SharedQuantizationSpec(input_act)

        # output sharing with input
        output_act_qspec = SharedQuantizationSpec((input_act, pool_node))
        pool_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                input_act: input_act_qspec,
            },
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


def _is_input_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Check if input is a large scalar value. So that we can skip quantization for the node
    since histc op (in HistogramObserver) only works for values up to certain upper bound
    """
    if node.op == "get_attr":
        tensor = getattr(gm, node.target)  # type: ignore[arg-type]
        # torch.histc works until this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False


def _is_input_non_float_tensor(node: Node):
    """Check if the input is not a float tensor, so that we can skip quantization for the node
    since observers only works with float Tensors
    """
    # Missing/untyped meta["val"] is treated as non-float (i.e. skip quantization).
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


@register_annotator("add_relu")
def _annotate_add_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate fused add -> relu patterns; output qspec goes on the relu."""
    fused_partitions = find_sequential_partitions(
        gm, [torch.add, torch.nn.ReLU], filter_fn=filter_fn
    )
    annotated_partitions = []
    for fused_partition in fused_partitions:
        add_partition, relu_partition = fused_partition
        annotated_partitions.append(add_partition.nodes + relu_partition.nodes)
        if len(relu_partition.output_nodes) > 1:
            raise ValueError("Relu partition has more than one output node")
        relu_node = relu_partition.output_nodes[0]
        if len(add_partition.output_nodes) > 1:
            raise ValueError("add partition has more than one output node")
        add_node = add_partition.output_nodes[0]

        if _is_annotated([relu_node, add_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = 
input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


@register_annotator("add")
def _annotate_add(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate standalone add nodes; large-scalar / non-float inputs are skipped."""
    add_partitions = get_source_partitions(
        gm.graph, [operator.add, torch.add, operator.iadd], filter_fn
    )
    add_partitions = list(itertools.chain.from_iterable(add_partitions.values()))
    annotated_partitions = []
    for add_partition in add_partitions:
        annotated_partitions.append(add_partition.nodes)
        add_node = add_partition.output_nodes[0]
        if _is_annotated([add_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = add_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = add_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        add_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


@register_annotator("mul_relu")
def _annotate_mul_relu(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate fused mul -> relu patterns; output qspec goes on the relu."""
    fused_partitions = find_sequential_partitions(
        gm, [torch.mul, torch.nn.ReLU], filter_fn=filter_fn
    )
    annotated_partitions = []
    for fused_partition in fused_partitions:
        mul_partition, relu_partition = fused_partition
        annotated_partitions.append(mul_partition.nodes + relu_partition.nodes)
        if len(relu_partition.output_nodes) > 1:
            raise ValueError("Relu partition has more than one output node")
        relu_node = relu_partition.output_nodes[0]
        if len(mul_partition.output_nodes) > 1:
            raise ValueError("mul partition has more than one output node")
        mul_node = mul_partition.output_nodes[0]

        if _is_annotated([relu_node, mul_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            _annotated=True,
        )
        relu_node.meta["quantization_annotation"] = QuantizationAnnotation(
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


@register_annotator("mul")
def _annotate_mul(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate standalone mul nodes; large-scalar / non-float inputs are skipped."""
    mul_partitions = 
get_source_partitions(
        gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
    )
    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
    annotated_partitions = []
    for mul_partition in mul_partitions:
        annotated_partitions.append(mul_partition.nodes)
        mul_node = mul_partition.output_nodes[0]
        if _is_annotated([mul_node]):
            continue

        input_act_qspec = get_input_act_qspec(quantization_config)
        output_act_qspec = get_output_act_qspec(quantization_config)

        input_qspec_map = {}
        input_act0 = mul_node.args[0]
        if isinstance(input_act0, Node):
            if _is_input_large_scalar(input_act0, gm):
                continue
            if _is_input_non_float_tensor(input_act0):
                continue
            input_qspec_map[input_act0] = input_act_qspec

        input_act1 = mul_node.args[1]
        if isinstance(input_act1, Node):
            if _is_input_large_scalar(input_act1, gm):
                continue
            if _is_input_non_float_tensor(input_act1):
                continue
            input_qspec_map[input_act1] = input_act_qspec

        mul_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


# TODO: remove Optional in return type, fix annotated_partitions logic
@register_annotator("cat")
def _annotate_cat(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
    """Annotate cat nodes; all inputs and the output share input 0's qparams."""
    cat_partitions = get_source_partitions(gm.graph, [torch.cat], filter_fn)
    cat_partitions = list(itertools.chain.from_iterable(cat_partitions.values()))
    annotated_partitions = []
    for cat_partition in cat_partitions:
        cat_node = cat_partition.output_nodes[0]
        if _is_annotated([cat_node]):
            continue

        if cat_node.target != torch.ops.aten.cat.default:
            # TODO: change this to AnnotationException
            raise Exception(
                f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
                " please check if you are calling the correct capture API"
            )

        annotated_partitions.append(cat_partition.nodes)

        input_act_qspec = get_input_act_qspec(quantization_config)
        inputs = cat_node.args[0]

        input_qspec_map = {}
        input_act0 = inputs[0]
        if isinstance(input_act0, Node):
            input_qspec_map[input_act0] = input_act_qspec

        # Remaining inputs and the output all share input 0's quantization params.
        shared_with_input0_qspec = SharedQuantizationSpec((input_act0, cat_node))
        for input_act in inputs[1:]:
            input_qspec_map[input_act] = shared_with_input0_qspec

        output_act_qspec = shared_with_input0_qspec

        cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=output_act_qspec,
            _annotated=True,
        )
    return annotated_partitions


def _is_share_obs_or_fq_op(op: Callable) -> bool:
    """Return True for ops whose output shares its input's observer/fake-quant."""
    return op in [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.squeeze_copy.dim,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.view_copy.default,
        torch.ops.aten.view.default,
        torch.ops.aten.slice_copy.Tensor,
        torch.ops.aten.flatten.using_ints,
    ]


def propagate_annotation(model: torch.fx.GraphModule) -> None:
    """Propagate output qspecs through shape/view-like ops via shared qspecs.

    For each op in `_is_share_obs_or_fq_op`, if its (single Node) input has an
    annotated output qspec and the op itself is not yet annotated, annotate the
    op so its input and output share the producer's quantization parameters.
    """
    for n in model.graph.nodes:
        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
            continue

        prev_node = n.args[0]
        if not isinstance(prev_node, Node):
            continue

        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
        if not quantization_annotation:
            continue

        output_qspec = quantization_annotation.output_qspec
        if not output_qspec:
            continue

        # make sure current node is not annotated
        if (
            "quantization_annotation" in n.meta
            and n.meta["quantization_annotation"]._annotated
        ):
            continue

        shared_qspec = SharedQuantizationSpec(prev_node)
        # propagate the previous output_qspec to the current node
        n.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                prev_node: shared_qspec,
            },
            output_qspec=shared_qspec,
            _annotated=True,
        )


# TODO: make the list of ops customizable
def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Replace scalar args of add/mul with registered float tensor buffers.

    Each scalar is materialized as a `_tensor_constant_*` buffer plus a
    get_attr node so observers can see a Tensor input. Recompiles the module.
    """
    for n in model.graph.nodes:
        if n.op != "call_function" or n.target not in [
            torch.ops.aten.add.Tensor,
            torch.ops.aten.mul.Tensor,
        ]:
            continue
        args = list(n.args)
        new_args = []
        for i in range(len(args)):
            if isinstance(args[i], torch.fx.Node):
                new_args.append(args[i])
                continue
            prefix = "_tensor_constant_"
            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
            tensor_constant_name = get_new_attr_name(model)
            float_tensor = torch.tensor(float(args[i]))
            model.register_buffer(tensor_constant_name, float_tensor)
            fake_mode = n.meta["val"].fake_mode
            with model.graph.inserting_before(n):
                get_attr_node = model.graph.create_node(
                    "get_attr", tensor_constant_name, (), {}
                )
                get_attr_node.meta["val"] = fake_mode.from_tensor(
                    float_tensor, static_shapes=True
                )
                new_args.append(get_attr_node)
        n.args = tuple(new_args)
    model.recompile()
    return model
diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5ab78c3ccea63d137eb35828542a4a43cd9811f
--- /dev/null
+++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/utils.py
@@ -0,0 +1,703 @@
"""
Utils shared by different modes of quantization (eager/graph)
"""
import functools
import warnings
from collections import OrderedDict
from inspect import getfullargspec, signature
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch.ao.quantization.quant_type import QuantType
from torch.fx import Node
from torch.nn.utils.parametrize import is_parametrized

NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
NodePattern.__module__ = "torch.ao.quantization.utils"

# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
# TODO(future PR): improve this.
+# make this public once fixed (can't be public as is because setting the module directly +# doesn't work) +QuantizerCls = Any + +# Type for fusion patterns, it can be more complicated than the following actually, +# see pattern.md for docs +# TODO: not sure if typing supports recursive data types +Pattern = Union[ + Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any +] +Pattern.__module__ = "torch.ao.quantization.utils" + +# TODO: maybe rename this to MatchInputNode +class MatchAllNode: + """ A node pattern that matches all nodes, used in defining + fusion patterns in FX Graph Mode Quantization + """ + pass + +module_type_list = { + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.Identity, + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, +} +func_list = { + torch.nn.functional.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.silu, + torch.nn.functional.mish, + torch.nn.functional.dropout, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.transpose, + torch.repeat_interleave, + torch.sigmoid, + torch.squeeze, + torch.stack, + torch.sum, + torch.tanh, + torch.unsqueeze, + torch.cat, +} +method_list = { + torch.mean, + 'relu', + 'relu_', + 'contiguous', + 'detach', + 'detach_', + 'hardsigmoid', + 'hardsigmoid_', + 'permute', + 'repeat', + 
'repeat_interleave', + 'reshape', + 'resize_', + 'shape', + 'sigmoid', + 'sigmoid_', + 'size', + 'squeeze', + 'squeeze_', + 'tanh', + 'tanh_', + 'transpose', + 'unsqueeze', + 'unsqueeze_', + 'view', +} + +# TODO: not used now, remove +def check_node(node, modules): + # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + return is_call_function, is_call_method, is_call_module + +def get_combined_dict(default_dict, additional_dict): + d = default_dict.copy() + d.update(additional_dict) + return d + +def is_per_tensor(qscheme): + return qscheme == torch.per_tensor_affine or \ + qscheme == torch.per_tensor_symmetric + +def is_per_channel(qscheme): + return qscheme in [torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric] + +def getattr_from_fqn(obj: Any, fqn: str) -> Any: + """ + Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz. 
+ """ + return functools.reduce(getattr, fqn.split("."), obj) + +def to_underlying_dtype(qdtype): + DTYPE_MAPPING = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + torch.quint4x2: torch.uint8, + torch.quint2x4: torch.uint8, + torch.uint8: torch.uint8, + torch.int8: torch.int8, + torch.int16: torch.int16, + torch.int32: torch.int32, + } + assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) + return DTYPE_MAPPING[qdtype] + +def get_qparam_dict(observer_or_fake_quant): + from torch.ao.quantization.observer import PlaceholderObserver + + qscheme = getattr(observer_or_fake_quant, "qscheme", None) + dtype = observer_or_fake_quant.dtype + qparams = {"qscheme": qscheme, "dtype": dtype} + + if not qscheme or isinstance(observer_or_fake_quant, PlaceholderObserver): + return {"qscheme": None, "dtype": dtype} + + if is_per_tensor(qscheme): + qscheme = torch.per_tensor_affine + elif is_per_channel(qscheme): + # change symmetric to affine since we do not have symmetric + # quantized Tensor + if qscheme == torch.per_channel_symmetric: + qscheme = torch.per_channel_affine + qparams["axis"] = observer_or_fake_quant.ch_axis + else: + raise RuntimeError(f"Unrecognized qscheme: {qscheme}") + # update qscheme, since we don't have symmetric quant qscheme + # in quantized Tensor + qparams["qscheme"] = qscheme + + scale, zero_point = observer_or_fake_quant.calculate_qparams() + qparams["scale"] = scale + qparams["zero_point"] = zero_point + + if hasattr(observer_or_fake_quant, "quant_min"): + qparams["quant_min"] = observer_or_fake_quant.quant_min + if hasattr(observer_or_fake_quant, "quant_max"): + qparams["quant_max"] = observer_or_fake_quant.quant_max + + return qparams + + +def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): + """ Get the observed/quantized custom module class that we need + to swap `custom_module` to + Input: + custom_module: input, can be an instance of either a float or 
observed custom module + custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping + qconfig: qconfig configured for the custom module + + Output: + corresponding observed/quantized custom module class for input custom module instance + """ + quant_type = get_quant_type(qconfig) + class_mapping = custom_module_class_mapping.get(quant_type, {}) + assert type(custom_module) in class_mapping, "did not find corresponding observed " \ + f"module class for {type(custom_module)} in mapping: {class_mapping}" + return class_mapping[type(custom_module)] + +def activation_dtype(qconfig): + assert qconfig is not None + activation = qconfig.activation() + return activation.dtype + +def weight_dtype(qconfig): + assert qconfig is not None + weight = qconfig.weight() + return weight.dtype + +def activation_is_statically_quantized(qconfig): + """ Given a qconfig, decide if the activation needs to be + quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16 + """ + return ( + activation_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32 + ] + and (not activation_is_dynamically_quantized(qconfig)) + ) + +def activation_is_dynamically_quantized(qconfig): + """ Given a qconfig, decide if the activation needs to be + dynamically quantized or not, this includes dynamically quantizing to + quint8, qint8 and float16 + """ + activation_dtype, _, activation_is_dynamic = \ + get_qconfig_dtypes(qconfig) + return activation_is_dynamic + +def activation_is_int8_quantized(qconfig): + """ Given a qconfig, decide if the activation needs to be + quantized to int8 or not, this includes quantizing to quint8, qint8 + """ + return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + +def activation_is_int32_quantized(qconfig): + """ Given a qconfig, decide if the activation needs to be + quantized to int32 or 
not + """ + return activation_dtype(qconfig) in [torch.qint32, torch.int32] + +def weight_is_quantized(qconfig): + """ Given a qconfig, decide if the weight needs to be + quantized or not + """ + return weight_dtype(qconfig) in [ + torch.quint8, + torch.qint8, + torch.float16, + torch.quint4x2, + torch.uint8, + torch.int8, + torch.int16, + torch.int32 + ] + +def weight_is_statically_quantized(qconfig): + """ Given a qconfig, decide if the weight needs to be statically + quantized or not + """ + return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.uint8, torch.int8] + +def op_is_int8_dynamically_quantized(qconfig) -> bool: + """ Given a qconfig, returns True if this op is using int8 dynamic + quantization + """ + activation_dtype, weight_dtype, activation_is_dynamic = \ + get_qconfig_dtypes(qconfig) + return ( + activation_dtype in [torch.quint8, torch.uint8] and + # for now, the lines below assume fbgemm or qnnpack + weight_dtype in [torch.qint8, torch.int8] and + activation_is_dynamic + ) + +def get_qconfig_dtypes(qconfig): + r""" returns the qconfig tuple for qconfig: + (activation_dtype, weight_dtype, activation_is_dynamic) + """ + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + act_is_dynamic = getattr(activation, "is_dynamic", False) + return (activation.dtype, weight.dtype, act_is_dynamic) + +def get_quant_type(qconfig): + assert qconfig is not None + activation = qconfig.activation() + weight = qconfig.weight() + static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] + if weight.dtype in static_dtypes: + if hasattr(activation, 'is_dynamic') and activation.is_dynamic: + return QuantType.DYNAMIC + elif activation.dtype in static_dtypes: + return QuantType.STATIC + else: + return QuantType.WEIGHT_ONLY + + if weight.dtype == torch.float16: + if hasattr(activation, 'is_dynamic') and activation.is_dynamic: + return QuantType.DYNAMIC + 
elif activation.dtype == torch.float16: + return QuantType.STATIC + + raise Exception(f"Unrecognized dtype combination in get_quant_type: activation({activation.dtype})," + f"weight({weight.dtype})") + +def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: + """ Checks if the given minimum and maximum values are valid, meaning that + they exist and the min value is less than the max value. + """ + if min_val.numel() == 0 or max_val.numel() == 0: + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + return False + + if min_val.dim() == 0 or max_val.dim() == 0: + if min_val == float("inf") and max_val == float("-inf"): + warnings.warn( + "must run observer before calling calculate_qparams. " + + "Returning default values." + ) + + return False + + assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" + else: + assert torch.all( + min_val <= max_val + ), f"min {min_val} should be less than max {max_val}" + + return True + + +def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype, + reduce_range: bool) -> Tuple[int, int]: + r"""Calculates actual qmin and qmax based on the quantization range, + observer datatype and if range is reduced. + """ + # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. + if has_customized_qrange: + # This initialization here is to be resolve TorchScript compilation issues and allow + # using of refinement to decouple initial_qmin and initial_qmax from quantization range. + # The actual values of initial_qmin and initial_qmax will be reset below. 
+ if dtype in [torch.qint32, torch.int32]: + initial_quant_min, initial_quant_max = 0, 2**32 - 1 + else: + initial_quant_min, initial_quant_max = 0, 255 + # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the + # attribute from Optional valid integers for use, based on TorchScript's requirements. + custom_quant_min, custom_quant_max = quant_min, quant_max + if custom_quant_min is not None and custom_quant_max is not None: + initial_quant_min, initial_quant_max = ( + custom_quant_min, + custom_quant_max, + ) + + qrange_len = initial_quant_max - initial_quant_min + 1 + if dtype in [torch.qint8, torch.int8]: + assert ( + 0 < qrange_len <= 256 + ), "quantization range should be positive and not exceed the maximum bit range (=256)." + elif dtype in [torch.qint32, torch.int32]: + assert ( + 0 < qrange_len <= 2**32 + ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + if reduce_range: + quant_min, quant_max = quant_min // 2, quant_max // 2 + else: + # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used. 
def _parent_name(target):
    """Split a dotted name into (parent_path, leaf).

    ``'foo.bar'`` -> ``('foo', 'bar')``; a bare name yields an empty parent.
    """
    parent, _, leaf = target.rpartition(".")
    return parent, leaf

def has_no_children_ignoring_parametrizations(module):
    """True when ``module`` has no child modules, treating a sole
    ``parametrizations`` child (added by torch.nn.utils.parametrize)
    as "no children"."""
    child_count = len(module._modules)
    if child_count == 0:
        return True
    if is_parametrized(module):
        return child_count == 1 and "parametrizations" in module._modules
    return False

def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
    """Return the fully qualified name of ``submodule`` under ``root``
    (matched by object identity), or None when it is not a descendant.

    Example::

        >> m = M()           # M has self.linear = torch.nn.Linear(5, 5)
        >> _get_path_of_module(m, m.linear)
        "linear"
    """
    for name, candidate in root.named_modules():
        if submodule is candidate:
            return name
    return None

def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
    """Filter ``loc`` down to the keys that are parameters of ``f``.

    Example::

        >> def f(self, a, b=9): pass
        >> _get_signature_locals(f, {"a": 6, "c": 7})
        {"a": 6}
    """
    params = signature(f).parameters
    return {name: value for name, value in loc.items() if name in params}

def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
    """Collect every defaulted parameter of ``f`` (variadic ``*args`` and
    ``**kwargs`` contribute ``()`` and ``{}`` respectively).

    Example::

        >> def f(self, a, b=9): pass
        >> _get_default_kwargs(f)
        {"b": 9}
    """
    defaults = {}
    for name, param in signature(f).parameters.items():
        if param.default is not param.empty:
            defaults[name] = param.default
        elif param.kind is param.VAR_POSITIONAL:
            defaults[name] = ()
        elif param.kind is param.VAR_KEYWORD:
            defaults[name] = {}
    return OrderedDict(defaults)
def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
    """Overlay the keyword arguments in ``loc`` onto ``func``'s defaulted
    parameters, yielding a fully populated keyword dict.

    Example::

        >> def f(self, key1=3, key2=3): pass
        >> _normalize_kwargs(f, {"key2": 6})
        {"key1": 3, "key2": 6}
    """
    normalized = _get_default_kwargs(func)
    for name, value in _get_signature_locals(func, loc).items():
        if name in normalized:
            # override the default keyword arguments
            normalized[name] = value
    return normalized

def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
    r"""Sanity-check a user-specified quantization range.

    The range must contain 0 (so zero is exactly representable) and must be
    non-degenerate. Custom (qmin, qmax) pairs are used for lower-bit fake
    quantization whose scale/zero-point estimates are learned through
    backpropagation; see:

        Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
        Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
    """
    assert (
        quant_min <= 0 <= quant_max
    ), "Used-specified quantization range must include 0."
    assert (
        quant_min < quant_max
    ), "qmin must be strictly less than qmax for user-specified quantization range."


# Functionally equivalent to '_calculate_qparams' in observer.py. Observers must
# stay torchscriptable and qscheme cannot be passed as a parameter in torchscript
# functions, so the logic is duplicated here rather than shared; revisit once
# torchscript is fully deprecated. TODO(jakeszwe, jerryzh168)
def determine_qparams(
        min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
        dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
        qscheme: torch.qscheme = torch.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute quantization parameters from observed min/max values.
    Works for both the per-tensor and per-channel cases.

    Args:
        min_val: Minimum values per channel
        max_val: Maximum values per channel

    Returns:
        scales: Scales tensor of shape (#channels,)
        zero_points: Zero points tensor of shape (#channels,)
    """
    if not check_min_max_valid(min_val, max_val):
        # No stats observed yet: identity scale, zero offset.
        return (
            torch.tensor([1.0], device=min_val.device.type),
            torch.tensor([0], device=min_val.device.type),
        )

    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

    device = min_val_neg.device
    scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    if qscheme in (torch.per_tensor_symmetric, torch.per_channel_symmetric):
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scale = max_val_pos / (float(quant_max - quant_min) / 2)
        scale = torch.max(scale, eps)
        if dtype in (torch.uint8, torch.quint8):
            if has_customized_qrange:
                # When a customized quantization range is used, pick the
                # down-rounded midpoint of the range as the zero point.
                zero_point = zero_point.new_full(
                    zero_point.size(), (quant_min + quant_max) // 2
                )
            else:
                zero_point = zero_point.new_full(zero_point.size(), 128)
    elif qscheme == torch.per_channel_affine_float_qparams:
        scale = (max_val - min_val) / float(quant_max - quant_min)
        scale = torch.where(scale > eps, scale, torch.ones_like(scale))
        # We use the quantize function
        # xq = Round(Xf * inv_scale + zero_point),
        # setting zero_point to (-1 * min * inv_scale) we get
        # Xq = Round((Xf - min) * inv_scale)
        zero_point = -1 * min_val / scale
    else:
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
        scale = torch.max(scale, eps)
        zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
        zero_point = torch.clamp(zero_point, quant_min, quant_max)

    # For scalar values, cast them to Tensors of size 1 to keep the shape
    # consistent with default values in FakeQuantize.
    if len(scale.shape) == 0:
        # TODO: switch to scale.item() after adding JIT support
        scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
    if len(zero_point.shape) == 0:
        # TODO: switch to zero_point.item() after adding JIT support
        zero_point = torch.tensor(
            [int(zero_point)], dtype=zero_point.dtype, device=device
        )
        if qscheme == torch.per_channel_affine_float_qparams:
            zero_point = torch.tensor(
                [float(zero_point)], dtype=zero_point.dtype, device=device
            )

    return scale.to(torch.double), zero_point.to(torch.int64)
Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type) + + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + if ( + qscheme == torch.per_tensor_symmetric + or qscheme == torch.per_channel_symmetric + ): + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, eps) + if dtype in [torch.uint8, torch.quint8]: + if has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. 
def _get_num_pos_args(f: Callable) -> int:
    """Count ``f``'s positional parameters (``self`` included for methods).

    Example::

        >> def f(self, key1=3, key2=3): pass
        >> _get_num_pos_args(f)
        3
    """
    return len(getfullargspec(f).args)

def get_fqn_to_example_inputs(
    model: torch.nn.Module,
    example_inputs: Tuple[Any, ...]
) -> Dict[str, Tuple[Any, ...]]:
    """Run one forward pass and record, per submodule FQN, the flattened
    positional example inputs that reached it, e.g.
    ``{"linear1": (tensor1,), "sub.linear1": (tensor4,), ...}``.

    Keyword arguments are flattened into positional slots using each
    submodule's forward signature, with missing keywords filled from their
    defaults: for ``forward(self, x, key1=3, key2=3)`` a call
    ``self.submodule(x, key2=6)`` records ``(x, 3, 6)`` and
    ``self.submodule(x, 5, key2=6)`` records ``(x, 5, 6)``.

    Variadic ``*args``/``**kwargs`` in submodule forwards are not supported.
    Used to make quantizing submodules easier now that FX Graph Mode
    Quantization requires example inputs.
    """
    root = model
    fqn_to_example_inputs = {}

    def _record_module_call(self, *args, **kwargs):
        flattened = list(args).copy()
        normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
        # minus 1 to skip counting `self`
        num_args = _get_num_pos_args(self.forward) - 1
        num_to_pop = num_args - len(flattened)
        # Drop leading normalized kwargs already covered by positional args.
        while num_to_pop and normalized_kwargs:
            normalized_kwargs.popitem(last=False)
            num_to_pop -= 1
        flattened.extend(normalized_kwargs.values())
        fqn = _get_path_of_module(root, self)
        if fqn is not None:
            fqn_to_example_inputs[fqn] = tuple(flattened)
        return orig_module_call(self, *args, **kwargs)

    orig_module_call = torch.nn.Module.__call__
    torch.nn.Module.__call__ = _record_module_call  # type: ignore[method-assign]
    try:
        model(*example_inputs)
    finally:
        # restore the module call even if there is an exception
        torch.nn.Module.__call__ = orig_module_call  # type: ignore[method-assign]
    return fqn_to_example_inputs

__all__ = [
    "NodePattern",
    "Pattern",
    "MatchAllNode",
    "check_node",
    "get_combined_dict",
    "is_per_tensor",
    "is_per_channel",
    "getattr_from_fqn",
    "get_qparam_dict",
    "get_swapped_custom_module_class",
    "activation_dtype",
    "weight_dtype",
    "activation_is_statically_quantized",
    "activation_is_dynamically_quantized",
    "activation_is_int8_quantized",
    "activation_is_int32_quantized",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    "op_is_int8_dynamically_quantized",
    "get_qconfig_dtypes",
    "get_quant_type",
    "check_min_max_valid",
    "calculate_qmin_qmax",
    "has_no_children_ignoring_parametrizations",
    "get_fqn_to_example_inputs",
    "to_underlying_dtype",
    "determine_qparams",
    "validate_qmin_qmax",
]
"Pattern", + "MatchAllNode", + "check_node", + "get_combined_dict", + "is_per_tensor", + "is_per_channel", + "getattr_from_fqn", + "get_qparam_dict", + "get_swapped_custom_module_class", + "activation_dtype", + "weight_dtype", + "activation_is_statically_quantized", + "activation_is_dynamically_quantized", + "activation_is_int8_quantized", + "activation_is_int32_quantized", + "weight_is_quantized", + "weight_is_statically_quantized", + "op_is_int8_dynamically_quantized", + "get_qconfig_dtypes", + "get_quant_type", + "check_min_max_valid", + "calculate_qmin_qmax", + "has_no_children_ignoring_parametrizations", + "get_fqn_to_example_inputs", + "to_underlying_dtype", + "determine_qparams", + "validate_qmin_qmax", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62b9ef4e4035a4d1716b0d90a1bfe8d68780c5b0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__init__.py @@ -0,0 +1,344 @@ +import builtins +import copy +import dataclasses +import inspect +import io +import os +import sys +import typing +import warnings +from enum import auto, Enum +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch.fx._compatibility import compatibility + +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager + +from torch.utils._pytree import ( + FlattenFunc, + FromDumpableContextFn, + ToDumpableContextFn, + UnflattenFunc, +) + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from 
if TYPE_CHECKING:
    # Only needed for annotations; importing symbolic_shapes pulls in sympy,
    # which is very slow, so keep it out of the runtime import graph.
    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint


__all__ = [
    "Constraint",
    "Dim",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "ExportedProgram",
    "ModuleCallEntry",
    "ModuleCallSignature",
    "dims",
    "dynamic_dim",
    "export",
    "load",
    "register_dataclass",
    "save",
    "unflatten",
    "FlatArgsAdapter",
    "UnflattenedModule",
]


from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule


# A pass takes a GraphModule and either mutates it in place (returning None)
# or returns a PassResult describing the rewrite.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


def export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
    """Trace ``mod.forward`` ahead-of-time into an :class:`ExportedProgram`.

    The traced graph (1) contains normalized operators from the functional
    ATen operator set (plus any user-registered custom operators), (2) has all
    Python control flow and data structures eliminated (with some exceptions),
    and (3) records the shape constraints needed for this elimination to be
    sound for future inputs.

    **Soundness.** Tracing records shape-related assumptions made by the user
    program and the underlying kernels; the resulting program is only valid
    where they hold. Static input shapes are validated automatically; dynamic
    dimensions must be declared via :func:`Dim` and passed through
    ``dynamic_shapes``. When validation fails, the error message contains
    suggested fixes (e.g. ``dim = Dim("dim0_x", max=5)``) that can be copied
    verbatim into the :func:`Dim` definitions.

    Args:
        mod: The module whose ``forward`` is traced.
        args: Example positional inputs.
        kwargs: Optional example keyword inputs.
        dynamic_shapes: Either a dict mapping argument names of ``mod.forward``
            to their dynamic-shape specification, or a tuple/list of
            specifications in declaration order. A tensor's specification is a
            dict from dimension index to :func:`Dim` (static dims may be
            omitted or mapped to None) or a tuple/list of :func:`Dim`/None per
            dimension; containers of tensors are specified recursively.
        strict: When True (default) the program is traced through TorchDynamo,
            which validates the implicit assumptions in the graph. When False
            those assumptions are not validated, which may cause divergence
            between the original and exported model — experimental, use at
            your own risk. The resulting IR spec and serialization are the
            same either way.
        preserve_module_call_signature: Submodule paths whose original
            call signatures should be recorded.

    Returns:
        An :class:`ExportedProgram` containing the traced callable.

    Accepted input/output types: ``torch.Tensor``, ``int``, ``float``,
    ``bool`` and ``str``; dataclasses registered via
    :func:`register_dataclass`; and (nested) ``dict``/``list``/``tuple``/
    ``namedtuple``/``OrderedDict`` containers of the above.
    """
    from ._trace import _export

    if not isinstance(mod, torch.nn.Module):
        raise ValueError(
            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
        )

    return _export(
        mod,
        args,
        kwargs,
        dynamic_shapes,
        strict=strict,
        preserve_module_call_signature=preserve_module_call_signature,
    )


def save(
    ep: ExportedProgram,
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    opset_version: Optional[Dict[str, int]] = None,
) -> None:
    """Serialize an :class:`ExportedProgram` to a file name or binary buffer;
    reload it with :func:`torch.export.load`.

    .. warning::
        Under active development; saved files may not be usable in newer
        versions of PyTorch.

    Args:
        ep: The exported program to save.
        f: A file-like object (must implement write and flush) or a string
            containing a file name.
        extra_files: Map from filename to contents stored alongside ``f``.
        opset_version: Map of opset names to the version of that opset.

    Example::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        ep = torch.export.export(MyModule(), (torch.randn(5),))
        torch.export.save(ep, 'exported_program.pt2')           # to a file
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)                           # to a buffer
        extra_files = {'foo.txt': b'bar'.decode('utf-8')}
        torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
    """
    from torch._export import save

    if not isinstance(ep, ExportedProgram):
        raise TypeError(
            f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
        )

    save(ep, f, extra_files=extra_files, opset_version=opset_version)


def load(
    f: Union[str, os.PathLike, io.BytesIO],
    *,
    extra_files: Optional[Dict[str, Any]] = None,
    expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
    """Load an :class:`ExportedProgram` previously written with
    :func:`torch.export.save`.

    .. warning::
        Under active development; saved files may not be usable in newer
        versions of PyTorch.

    Args:
        f: A file-like object (must implement write and flush) or a string
            containing a file name.
        extra_files: The filenames in this map are loaded and their content
            stored into the map's values.
        expected_opset_version: Map of opset names to expected opset versions.

    Returns:
        An :class:`ExportedProgram` object.

    Example::

        import torch
        import io

        ep = torch.export.load('exported_program.pt2')          # from a file

        with open('exported_program.pt2', 'rb') as f:           # from a buffer
            buffer = io.BytesIO(f.read())
        buffer.seek(0)
        ep = torch.export.load(buffer)

        extra_files = {'foo.txt': ''}  # values will be replaced with data
        ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
        print(extra_files['foo.txt'])
        print(ep(torch.randn(5)))
    """
    from torch._export import load

    return load(
        f, extra_files=extra_files, expected_opset_version=expected_opset_version
    )


def register_dataclass(
    cls: Type[Any],
    *,
    serialized_type_name: Optional[str] = None,
) -> None:
    """Register a dataclass as a valid input/output type for
    :func:`torch.export.export`.

    Args:
        cls: The dataclass type to register.
        serialized_type_name: The serialized name for the dataclass; required
            if you want to serialize the pytree TreeSpec containing it.

    Example::

        @dataclass
        class InputDataClass:
            feature: torch.Tensor
            bias: int

        @dataclass
        class OutputDataClass:
            res: torch.Tensor

        torch.export.register_dataclass(InputDataClass)
        torch.export.register_dataclass(OutputDataClass)

        def fn(o: InputDataClass) -> OutputDataClass:
            return OutputDataClass(res=o.feature + o.bias)

        ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1),))
        print(ep)
    """
    from torch._export.utils import register_dataclass_as_pytree_node

    return register_dataclass_as_pytree_node(
        cls, serialized_type_name=serialized_type_name
    )
+ + Example:: + + @dataclass + class InputDataClass: + feature: torch.Tensor + bias: int + + class OutputDataClass: + res: torch.Tensor + + torch.export.register_dataclass(InputDataClass) + torch.export.register_dataclass(OutputDataClass) + + def fn(o: InputDataClass) -> torch.Tensor: + res = res=o.feature + o.bias + return OutputDataClass(res=res) + + ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) + print(ep) + + """ + + from torch._export.utils import register_dataclass_as_pytree_node + + return register_dataclass_as_pytree_node( + cls, serialized_type_name=serialized_type_name + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_safeguard.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_safeguard.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa4d273a1af1e3bddc3cc619c9f09edfdf18a97 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_safeguard.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_trace.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_trace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..957f88de5f4a90a3a3e0b168e5302c57f3404e56 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_trace.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_unlift.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_unlift.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f2d6ab2ba1da7490ce493c6011d0cf9b9efcbf Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_unlift.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/exported_program.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/exported_program.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8a2f15f5c72a37044cbdb3051cbe1ce65f7e09e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/exported_program.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_trace.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..a7170ea2a17eb0fe03c890a1fb0f278a46ef9c38 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_trace.py @@ -0,0 +1,1060 @@ +import dataclasses +import functools +import inspect +import logging +import re +import time +import warnings +from contextlib import contextmanager, nullcontext +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch._dynamo +import torch.fx + +import torch.utils._pytree as pytree +from torch._dynamo.exc import UserError, UserErrorType +from torch._export.non_strict_utils import ( + make_constraints, + make_fake_inputs, + make_fake_params_buffers, +) +from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _AddRuntimeAssertionsForInlineConstraintsPass, +) +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._export.passes.lift_constants_pass import ( + ConstantAttrMap, + lift_constants_pass, + rewrite_script_object_meta, +) +from torch._export.wrappers import _wrap_submodules +from 
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.exported_program import OutputKind
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    free_unbacked_symbols,
    GuardOnDataDependentSymNode,
    ShapeEnv,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.utils._sympy.value_ranges import ValueRangeError

from ._safeguard import AutogradStateOpsFailSafeguard

from .dynamic_shapes import _process_constraints, Constraint
from .exported_program import (
    _disable_prexisiting_fake_mode,
    ExportedProgram,
    InputKind,
    ModuleCallEntry,
    ModuleCallSignature,
)
from .graph_signature import (
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    SymIntArgument,
    TensorArgument,
)


log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """Export-specific Dynamo configuration knobs."""

    # Whether Dynamo may trace through RNN modules during export.
    allow_rnn: bool = True
    # Logging-style callables that Dynamo is allowed to reorder around
    # traced code.
    reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )


DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
    logging.critical,
    logging.debug,
    logging.error,
    logging.exception,
    logging.info,
    logging.log,
    logging.warning,
    print,
    warnings.warn,
}


@contextmanager
def _ignore_backend_decomps():
    """Temporarily disable the mkldnn and nnpack backends for the duration of
    the ``with`` block, restoring the previous flags on exit."""
    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
    try:
        yield
    finally:
        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
        torch.backends.nnpack.set_flags(*orig_nnpack_flag)


def _convert_input_to_fake(gm, args, kwargs):
    """Map the caller's example inputs onto the fake tensors recorded on
    ``gm``'s placeholders, reusing the fake mode detected from those
    placeholders (or creating a fresh one).

    Returns ``(fake_args, fake_kwargs, fake_params_buffers, fake_mode)``.
    """
    params_buffers = _get_params_buffers(gm)
    fake_inps: List[torch.Tensor] = []
    for node in gm.graph.nodes:
        if node.op == "placeholder" and "val" in node.meta:
            fake_val = node.meta["val"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_inps.append(fake_val)

    if detected_fake_mode := detect_fake_mode(fake_inps):
        fake_mode = detected_fake_mode
    else:
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())

    if len(args) == 0 and len(kwargs) == 0:
        return (), {}, params_buffers, fake_mode

    count = 0

    def _next_recorded_fake(_x):
        # Positional tensors are replaced by the placeholder fakes, in order.
        nonlocal count
        val = fake_inps[count]
        count += 1
        return val

    fake_args = pytree.tree_map_only(torch.Tensor, _next_recorded_fake, args)
    # TODO properly use the cached fake tensor
    fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
    fake_params_buffers = pytree.tree_map_only(
        torch.Tensor,
        functools.partial(fake_mode.from_tensor, static_shapes=True),
        params_buffers,
    )
    return fake_args, fake_kwargs, fake_params_buffers, fake_mode


def _replace_param_buffer_names(param_buffer_table, sig):
    """Rewrite parameter/buffer targets in ``sig`` through
    ``param_buffer_table`` (traced name -> original FQN)."""
    for spec in sig.input_specs:
        if spec.kind in (InputKind.PARAMETER, InputKind.BUFFER):
            spec.target = param_buffer_table[spec.target]
    for spec in sig.output_specs:
        if spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.GRADIENT_TO_PARAMETER,
        ):
            spec.target = param_buffer_table[spec.target]
param_buffer_table[spec.target] + for spec in sig.output_specs: + if spec.kind in ( + OutputKind.BUFFER_MUTATION, + OutputKind.GRADIENT_TO_PARAMETER, + ): + spec.target = param_buffer_table[spec.target] + + +def _convert_to_positional_args(orig_arg_names, args, kwargs): + assert len(orig_arg_names) == len(args) + len(kwargs), ( + f"Total number of arg names is expected to be {len(orig_arg_names)} " + f"but got {len(args)} positional args, {len(kwargs)} kwargs." + ) + reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]] + return ( + *args, + *reordered_kwargs, + ) + + +def _normalize_nn_module_stack(gm_torch_level, root_cls): + # Append a root module to every nn_module_stack. + root = "L['self']" + root_key = re.sub(r"[^a-zA-Z0-9]", "_", root) + for gm in gm_torch_level.modules(): + if not isinstance(gm, torch.fx.GraphModule): + continue + for node in gm.graph.nodes: + if node.op in ["placeholder", "output"]: + continue + add_root = True + if nn_module_stack := node.meta.get("nn_module_stack", {}): + path, ty = next(iter(nn_module_stack.values())) + # After deserializing the class `ty` might not exist anymore so + # it could be a string + if inspect.isclass(ty) and issubclass(ty, torch.nn.Module): + # TODO Figure out why sometimes we have root sometimes we don't. + if path == root and ty is root_cls: + add_root = False + else: + assert isinstance(ty, str) + if add_root: + + def normalize_path(path): + try: + parts = [] + + class Path: + def __getattr__(self, name): + parts.append(name) + return self + + def __getitem__(self, idx): + parts.append(str(idx)) + return self + + eval(path, {"L": {"self": Path()}}) + return ".".join(parts) + except Exception: # TODO(zhxchen17) Remove this. 
+ return path + + nn_module_stack = {root_key: (root, root_cls), **nn_module_stack} + node.meta["nn_module_stack"] = { + key: (normalize_path(path), ty) + for key, (path, ty) in nn_module_stack.items() + } + + +def _get_param_buffer_mapping( + original_module: torch.nn.Module, + traced_module: torch.nn.Module, +) -> Dict[str, str]: + """ + Returns a mapping of parameter/buffer names from the new module to the + original model. This is to help with restoring the FQN for parameter/buffers + of a traced module to what the original module contains. + """ + + param_lookup: Dict[int, List[str]] = {} + buffer_lookup: Dict[int, List[str]] = {} + for name, param in original_module.named_parameters(remove_duplicate=False): + param_lookup.setdefault(id(param), []).append(name) + for name, buffer in original_module.named_buffers(remove_duplicate=False): + buffer_lookup.setdefault(id(buffer), []).append(name) + + param_buffer_table: Dict[str, str] = {} + for dynamo_name, dynamo_param in traced_module.named_parameters( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_param) in param_lookup: + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop() + + for dynamo_name, dynamo_buffer in traced_module.named_buffers( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_buffer) in buffer_lookup: + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop() + + return param_buffer_table + + +def _remap_constants( + orig_constant_attrs: ConstantAttrMap, + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], +) -> None: + """Rewrite the graph signature and constants table to use the FQN from the original module.""" + remap_table: Dict[str, str] = {} + for name, value in constants.items(): + if value in orig_constant_attrs: + remap_table[name] = orig_constant_attrs[value] + + for spec in graph_signature.input_specs: + if 
spec.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + orig_target = spec.target + assert orig_target is not None + spec.target = remap_table.get(orig_target, orig_target) + + constant = constants[orig_target] + del constants[orig_target] + constants[spec.target] = constant + + +def _restore_state_dict( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Restores the state dict of the traced module to that of the original module. + """ + param_buffer_table = _get_param_buffer_mapping(original_module, traced_module) + # Since the graph module is flattened (no module heirarchy), we + # need to noramlize the module by replacing "." with "_". If we + # don't, it will try to save the weight to a submodule which no + # longer exists. + for name, fqn in param_buffer_table.items(): + param_buffer_table[name] = fqn.replace(".", "_") + + # Replace state dict attr names with the fqn + for name, fqn in param_buffer_table.items(): + if not hasattr(traced_module, name): + continue + + attr = getattr(traced_module, name) + if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter): + traced_module.register_buffer(fqn, attr) + else: + setattr(traced_module, fqn, attr) + delattr(traced_module, name) + + # Replace graph getattr nodes with the correct name + for node in traced_module.graph.nodes: + if node.op == "get_attr": + attr_name = node.target + if attr_name in param_buffer_table: + node.target = param_buffer_table[attr_name] + + traced_module.recompile() + + +def _export_to_torch_ir( + f: Callable, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + constraints: Optional[List[Constraint]] = None, + *, + preserve_module_call_signature: Tuple[str, ...] 
= (), + disable_constraint_solver: bool = False, + restore_fqn: bool = True, + _log_export_usage: bool = True, +) -> torch.fx.GraphModule: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a torch.fx.GraphModule in torch IR. + """ + + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"}) + + kwargs = kwargs or {} + + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + try: + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} + with _wrap_submodules( + f, preserve_module_call_signature, module_call_specs + ), _ignore_backend_decomps(): + gm_torch_level, _ = torch._dynamo.export( + f, + constraints=constraints, # type: ignore[arg-type] + assume_static_by_default=True, + tracing_mode="symbolic", + disable_constraint_solver=disable_constraint_solver, + _log_export_usage=_log_export_usage, + )( + *args, + **kwargs, + ) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + except GuardOnDataDependentSymNode as e: + raise UserError( # noqa: TRY200 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._constrain_as_*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + gm_torch_level.meta["module_call_specs"] = module_call_specs + + if isinstance(f, torch.nn.Module) and restore_fqn: + _restore_state_dict(f, gm_torch_level) + + return gm_torch_level + + +def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: + """Search the module hierarchy, gathering up all tensor and ScriptObject constants. + + Returns a dictionary mapping hash(value) to the name of the constant. 
We + have to abuse `hash` here unfortunately, see: [ScriptObject hash]. + """ + constants = ConstantAttrMap() + buffers_parameters = set(m.buffers()) + buffers_parameters.update(m.parameters()) + + def inner(m: torch.nn.Module, prefix_atoms: List[str], constants): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, torch.ScriptObject)): + if v in buffers_parameters: + # filter out buffers and parameters, leaving only constants + continue + + fqn = ".".join(prefix_atoms + [k]) + if v in constants: + raise ValueError( + f"Duplicate reference to constant attribute found: '{constants[v]}' and '{fqn}'." + ) + + constants[v] = fqn + for k, v in m.named_children(): + inner(v, prefix_atoms + [k], constants) + + inner(m, [], constants) + return constants + + +def _export_non_strict( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + *, + transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. + pre_dispatch=False, +): + # [NOTE] If the user is exporting under training mode, we want to detect if there is any + # state change in the autograd global state and error. If the user is exporting under inference + # mode, we don't care. + is_grad_enabled = torch._C.is_grad_enabled() + grad_safe_guard = ( + AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext() + ) + + @contextmanager + def _compiling_state_context(): + old_value = torch.compiler._is_compiling_flag + try: + torch.compiler._is_compiling_flag = True + yield + finally: + torch.compiler._is_compiling_flag = old_value + + # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. 
+ with torch.nn.utils.stateless._reparametrize_module( + mod, fake_params_buffers + ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] + gm, graph_signature = transform(aot_export_module)( + mod, + fake_args, + trace_joint=False, + pre_dispatch=pre_dispatch, + kwargs=fake_kwargs, + ) + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): + gm.meta.update(mod.meta) + + if pre_dispatch: + from torch._export.passes.replace_set_grad_with_hop_pass import ( + replace_set_grad_with_hop_pass, + ) + + gm = replace_set_grad_with_hop_pass(gm) + + # NOTE: aot_export adds symint metadata for placeholders with int values; + # since these become specialized, we replace such metadata with the original values + flat_args = pytree.tree_leaves((fake_args, fake_kwargs)) + index = 0 + total_non_user_inputs = ( + len(graph_signature.parameters) + + len(graph_signature.buffers) + + len(graph_signature.input_tokens) + ) + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= total_non_user_inputs: + user_arg = flat_args[index - total_non_user_inputs] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + + is_joint = graph_signature.backward_signature is not None + + def make_argument_spec(node) -> ArgumentSpec: + if isinstance(node, (int, bool, float, type(None))): + # For const outputs we just directly return this + return ConstantArgument(value=node) + + assert ( + "val" in node.meta + ), f"{node} is not a constant or a node with a 'val' metadata field" + val = node.meta["val"] + if isinstance(val, FakeTensor): + return TensorArgument(name=node.name) + elif isinstance(val, torch.SymInt): + return SymIntArgument(name=node.name) + elif isinstance(val, torch.ScriptObject): + return CustomObjArgument( + 
name=node.name, class_fqn=val._type().qualified_name() # type: ignore[attr-defined] + ) + else: + # TODO: this branch is likely wrong, all permissible ConstantArgument type + # should have been handled already + return ConstantArgument(value=val) + + input_specs, output_specs = _sig_to_specs( + user_inputs=set(graph_signature.user_inputs), + inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type] + inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type] + user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type] + buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type] + user_input_mutations=graph_signature.user_inputs_to_mutate, # type: ignore[arg-type] + grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr] + grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr] + loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr] + inputs=[ + make_argument_spec(node) + for node in gm.graph.nodes + if node.op == "placeholder" + ], + outputs=[ + make_argument_spec(node) + for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args) + ], + input_tokens=graph_signature.input_tokens, + output_tokens=graph_signature.output_tokens, + ) + export_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + + @dataclasses.dataclass + class _ExportedProgramNonStrict: + gm: torch.fx.GraphModule + sig: ExportGraphSignature + constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] + + return _ExportedProgramNonStrict( + gm, + export_graph_signature, + constants, + ) + + +def _get_params_buffers(mod: 
torch.nn.Module) -> Dict[str, torch.Tensor]: + params_buffers: Dict[str, torch.Tensor] = {} + for name, param in mod.named_parameters(remove_duplicate=False): + params_buffers[name] = param + + for name, buffer in mod.named_buffers(remove_duplicate=False): + params_buffers[name] = buffer + return params_buffers + + +def _rewrite_dynamo_tensor_constants( + orig_mod_buffers: Set[torch.Tensor], + traced_mod_buffers: Dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], +): + """Dynamo erroneously marks tensor attributes on modules as a buffers. + + Rewrite them to be tensor constants. + """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER: + assert spec.target is not None + value = traced_mod_buffers[spec.target] + if value not in orig_mod_buffers: + # This was a tensor constant erroneously marked as a buffer. + # Convert it int oa constant in the graph signature, and add its + # value to the constants table. + spec.kind = InputKind.CONSTANT_TENSOR + constants[spec.target] = value + + +def _rewrite_non_persistent_buffers( + orig_mod: torch.nn.Module, + graph_signature: ExportGraphSignature, + constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]], +): + """Dynamo erroneously drops the persistent flag on buffers. + + Rewrite non-persistent buffers to reflect the original module. 
+ """ + state_dict = orig_mod.state_dict() + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER: + assert spec.target is not None + if spec.target not in state_dict: + assert spec.target not in constants + spec.persistent = False + constants[spec.target] = orig_mod.get_buffer(spec.target) + + +def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]: + op_count = 0 + op_set = set() + for m in ep.graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if node.op != "call_function": + continue + op_count += 1 + assert hasattr(node.target, "__module__") + assert hasattr(node.target, "__name__") + op_set.add(f"{node.target.__module__}.{node.target.__name__}") + return {"op_count": op_count, "op_set": op_set} + + +_EXPORT_FLAGS: Optional[Set[str]] = None + + +def _log_export_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + global _EXPORT_FLAGS + try: + start = time.time() + ep = fn(*args, **kwargs) + end = time.time() + log_export_usage( + event="export.time", + metrics=end - start, + flags=_EXPORT_FLAGS, + **get_ep_stats(ep), + ) + except Exception as e: + t = type(e) + error_type = t.__module__ + "." + t.__qualname__ + log_export_usage( + event="export.error", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + raise e + finally: + _EXPORT_FLAGS = None + + return ep + + return wrapper + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + *, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), + pre_dispatch: bool = False, +) -> ExportedProgram: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a ExportedProgram. + + Args: + f: the `nn.Module` to trace. 
+ + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. + + Returns: + An ExportedProgram containing the traced method. 
+ """ + from .dynamic_shapes import _process_dynamic_shapes + + global _EXPORT_FLAGS + flags = set() + flags.add("strict" if strict else "non_strict") + flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch") + log_export_usage(event="export.enter", flags=flags) + _EXPORT_FLAGS = flags + + constraints = _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes) or [] + + kwargs = kwargs or {} + + constant_attrs = _gather_constant_attrs(mod) + + flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs)) + + if not strict: + out_spec = None + + module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {} + + def strip_root(x): + if isinstance(x, str) and x.startswith("_export_root"): + stripped = x[len("_export_root") :] + return stripped[1:] if stripped.startswith(".") else stripped + return x + + def fixup_key(x): + return "L__self__" + strip_root(x) + + def _tuplify_outputs(aot_export): + def _aot_export_non_strict(mod, args, kwargs=None, **flags): + kwargs = kwargs or {} + + class Wrapper(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self._export_root = mod + + def forward(self, *args, **kwargs): + nonlocal out_spec + if isinstance(self._export_root, torch.fx.GraphModule): + with torch.fx.traceback.preserve_node_meta(): + tree_out = torch.fx.Interpreter(self._export_root).run( + *args, **kwargs + ) + else: + tree_out = self._export_root(*args, **kwargs) + flat_outs, out_spec = pytree.tree_flatten(tree_out) + return tuple(flat_outs) + + wrapped_mod = Wrapper(mod) + # Patch export_root to the signatures so that wrapper module correctly populates the + # in/out spec + new_preserved_call_signatures = [ + "_export_root." 
+ i for i in preserve_module_call_signature + ] + with _wrap_submodules( + wrapped_mod, new_preserved_call_signatures, module_call_specs + ): + gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags) + + sig.parameters = pytree.tree_map(strip_root, sig.parameters) + sig.buffers = pytree.tree_map(strip_root, sig.buffers) + sig.inputs_to_buffers = pytree.tree_map( + strip_root, sig.inputs_to_buffers + ) + sig.inputs_to_parameters = pytree.tree_map( + strip_root, sig.inputs_to_parameters + ) + sig.buffers_to_mutate = pytree.tree_map( + strip_root, sig.buffers_to_mutate + ) + for node in gm.graph.nodes: + if "nn_module_stack" in node.meta: + nn_module_stack = node.meta["nn_module_stack"] + node.meta["nn_module_stack"] = { + fixup_key(key): val + for key, val in pytree.tree_map( + strip_root, nn_module_stack + ).items() + } + + return gm, sig + + return _aot_export_non_strict + + ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + ) = make_fake_inputs(mod, args, kwargs, constraints) + + fake_params_buffers = make_fake_params_buffers( + fake_mode, _get_params_buffers(mod) + ) + with fake_mode: + ep_non_strict = _export_non_strict( + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs, + pre_dispatch=pre_dispatch, + transform=_tuplify_outputs, + ) + try: + range_constraints = make_constraints( + fake_mode, + equalities_inputs, + original_signature, + ep_non_strict.gm, + ) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + + assert out_spec is not None + + gm = ep_non_strict.gm + + module_call_signatures = { + strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs) + for fqn, specs in module_call_specs.items() + } + + if len(preserve_module_call_signature) > 0: + for node in gm.graph.nodes: + if node.target == torch.ops.higher_order._export_tracepoint: + if "path" in node.kwargs: + path = 
strip_root(node.kwargs["path"]) + with gm.graph.inserting_before(node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order._export_tracepoint, + args=node.args, + kwargs={ + "path": path, + "kind": node.kwargs["kind"], + }, + ) + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + + res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm) + assert res is not None + gm = res.graph_module + + _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + return ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=ep_non_strict.sig, + state_dict=mod.state_dict(keep_vars=True), + range_constraints=range_constraints, + module_call_graph=[ + ModuleCallEntry( + "", + ModuleCallSignature( + inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec + ), + ) + ] + + [ + ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items() + ], + example_inputs=(args, kwargs), + constants=ep_non_strict.constants, + ) + + gm_torch_level = _export_to_torch_ir( + mod, + args, + kwargs, + constraints, + preserve_module_call_signature=preserve_module_call_signature, + restore_fqn=False, # don't need to restore because we will do it later + _log_export_usage=False, + ) + + # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo. + ( + fake_args, + fake_kwargs, + fake_params_buffers, + dynamo_fake_mode, + ) = _convert_input_to_fake(gm_torch_level, args, kwargs) + + # First, we want to pass through the graph to try populating + # val field for getattr if there is anything missing. 
+ # This can happen when quantization adds extra params and forgets + # to update "val" + for node in gm_torch_level.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + attr = getattr(gm_torch_level, node.target) + # Checks if it is not a HigherOrderOp branch or a module + if not isinstance(attr, torch.nn.Module): + assert ( + dynamo_fake_mode is not None + ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + node.meta["val"] = dynamo_fake_mode.from_tensor( + attr, static_shapes=True + ) + + # When aot_export lifts the params, we lose the nn_module_stack + # and source_fn from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + params_buffers_to_node_meta = {} + for node in gm_torch_level.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = getattr(gm_torch_level, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = getattr(gm_torch_level, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. 
+ # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + # Fix the graph output signature to be tuple if scalar + out_spec = orig_out_spec = gm_torch_level._out_spec + assert out_spec is not None + # aot_export expect the return type to always be a tuple. + if out_spec.type not in (list, tuple): + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + + orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] + + gm_torch_level.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + gm_torch_level._in_spec, + out_spec, + ) + ) + gm_torch_level.recompile() + + _normalize_nn_module_stack(gm_torch_level, type(mod)) + + # NOTE: graph module expects only positional args + ep_non_strict = _export_non_strict( + gm_torch_level, + _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), + {}, + fake_params_buffers, + constant_attrs, + pre_dispatch=pre_dispatch, + ) + + gm = ep_non_strict.gm + export_graph_signature = ep_non_strict.sig + constants = ep_non_strict.constants + + # After aot_export, set the param/buffer metadata back into placeholders + # Technically, users can still construct this data from param names + # without relying on this metadata + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in export_graph_signature.inputs_to_parameters: + param_name = export_graph_signature.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in export_graph_signature.inputs_to_buffers: + buffer_name = 
export_graph_signature.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo + + gm.meta["inline_constraints"] = { + k: v + for k, v in dynamo_fake_mode.shape_env.var_to_range.items() + if free_unbacked_symbols(k) + } + + num_lifted = next( + ( + i + for i, s in enumerate(export_graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(export_graph_signature.input_specs), + ) + range_constraints = _process_constraints( + dynamo_fake_mode, + gm, + num_lifted, + flat_args, + ) + + # Do some cleanups on the graph module to restore the state dict to the + # expected form. Each of these steps should probably get fixed upstream. + # 1. Remove tensor constants that were added as buffers. + _rewrite_dynamo_tensor_constants( + orig_mod_buffers=set(mod.buffers()), + traced_mod_buffers=dict(gm_torch_level.named_buffers()), + graph_signature=ep_non_strict.sig, + constants=ep_non_strict.constants, + ) + # 2. Restore FQN of param/buffers + param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) + _replace_param_buffer_names(param_buffer_table, export_graph_signature) + + # 3. Remove non-persistent buffers from the graph signature + _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants) + + # 4. Rewrite constants to have the same FQN as the original module. 
+ _remap_constants(constant_attrs, export_graph_signature, constants) + + module_call_signatures = { + fqn: ModuleCallSignature(inputs=[], outputs=[], **specs) + for fqn, specs in gm_torch_level.meta["module_call_specs"].items() + } + + if len(preserve_module_call_signature) > 0: + res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) + assert res is not None + gm = res.graph_module + + assert orig_out_spec is not None + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=mod.state_dict(keep_vars=True), + range_constraints=range_constraints, + module_call_graph=[ + ModuleCallEntry( + "", + ModuleCallSignature( + inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec + ), + ) + ] + + [ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()], + example_inputs=(args, kwargs), + constants=constants, + ) + log.debug("Exported program from AOTAutograd:\n%s", exported_program) + + if len(range_constraints) > 0: + exported_program = exported_program._transform_do_not_use( + _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints) + ) + + return exported_program diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/dynamic_shapes.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..73d64b8f5d07a481968bd90caaef2ac4b2e8382f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/dynamic_shapes.py @@ -0,0 +1,876 @@ +import builtins +import dataclasses +import inspect +import math +import sys +import weakref +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._pytree import SUPPORTED_NODES + +from 
class _Dim(type):
    """
    Metaclass for :func:`Dim` types.

    A ``Dim`` is a class object (subclassing ``int``) whose class attributes
    ``min`` and ``max`` carry its inclusive bounds. Arithmetic on the class
    (``dim + 1``, ``dim * 2``) yields derived dims; only increasing linear
    operations with integer coefficients are supported.
    """

    @staticmethod
    def readable(name, min_, max_):
        """Render a ``Dim(...)`` expression, omitting bounds equal to the defaults."""
        # Defaults are min=2 and max=sys.maxsize - 1; show only explicit bounds.
        lo = None if min_ == 2 else min_
        hi = None if max_ == sys.maxsize - 1 else max_
        if lo is None and hi is None:
            return f"Dim('{name}')"
        if lo is None:
            return f"Dim('{name}', max={hi})"
        if hi is None:
            return f"Dim('{name}', min={lo})"
        return f"Dim('{name}', min={lo}, max={hi})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        # int - dim would be a decreasing function of dim, which is unsupported.
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2; the coefficient must be a positive integer to keep
        # the mapping increasing.
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        # Human-readable name of the derived dim, e.g. "2*batch + 1".
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
    These restrictions on the form of derived Dims makes the metatheory simpler: e.g.,
    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.

    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
    """

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        assert _min_symint >= 2, (
            f"Expected derived min value of {self.__name__} to be >= 2. "
            f"Please specify an appropriate min value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.min})."  # type: ignore[attr-defined]
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.max})."  # type: ignore[attr-defined]
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1, implemented by composing the
        # operations on the same root. As a consequence, roots are always
        # regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )


def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """
    # Clamp the bounds into the supported window [2, sys.maxsize - 1].
    lo = 2 if min is None else builtins.max(min, 2)
    hi = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
    assert hi > lo, f"Cannot create Dim with inconsistent min={min}, max={max}"
    dim = _Dim(name, (int,), {"min": lo, "max": hi})
    # Attribute the new type to the caller's module for nicer reprs/pickling.
    dim.__module__ = getattr(
        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
    )
    return dim


def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types sharing the same bounds.
    """
    return tuple(Dim(n, min=min, max=max) for n in names)
@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions. Don't create this
    class directly; instead, use :func:`dynamic_dim`.
    """

    w_tensor: Any  # weakref to torch.Tensor
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int  # id() of the tensor at constraint-creation time
    dim: int  # index of the constrained dimension


class _ConstraintFactory(type):
    """
    Metaclass that ensures a private constructor for :class:`_Constraint`.

    Calling the class directly raises; instances are created only through the
    internal :meth:`_create`, which bypasses the guard via ``super().__call__``.
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )

    def _create(
        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
    ):
        # Delegate to type.__call__ to sidestep the guard above.
        return super().__call__(
            w_tensor, t_id, dim, constraint_range, shared, debug_name
        )


def _create_constraint(
    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
):
    # Module-internal helper: the only sanctioned way to build a _Constraint.
    return _Constraint._create(
        w_tensor, t_id, dim, constraint_range, shared, debug_name
    )


@dataclasses.dataclass
class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """

    .. warning::
        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.

    This represents constraints on input tensor dimensions, e.g., requiring
    them to be fully polymorphic or within some range.

    """

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, ...]
    constraint_range: "StrictMinMaxConstraint"
    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=2, upper=math.inf):
        # Return a copy whose value range is intersected with [lower, upper].
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            self.shared,
            self.debug_name,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertion. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }

    def __eq__(self, other):
        # `dyn_dim0 == dyn_dim1` builds a new constraint whose range is the
        # intersection, with `shared` recording the other target.
        # NOTE(review): defining __eq__ in the class body implicitly sets
        # __hash__ to None, making instances unhashable — presumably intended.
        if not isinstance(other, _Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )

        # import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        if self.debug_name is None:
            debug_name = other.debug_name
        else:
            assert other.debug_name is None or self.debug_name == other.debug_name
            debug_name = self.debug_name
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
            debug_name=debug_name,
        )


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int
@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).
    """

    # NOTE: This is not currently a subclass of _Constraint because we do not support
    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
    # reusing the same (derived or normal) `Dim`.
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable  # increasing linear map from root value to this dim's value
    constraint_range: "StrictMinMaxConstraint"
    debug_name: Optional[str] = None

    @property
    def shared(self):
        # Some code paths expect a union of _Constraint and _DerivedConstraint.
        # Thus we expose a `shared` field that is always None.
        # TODO(avik): clean this up
        return None

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


# Public union of the two constraint flavors.
Constraint = Union[_Constraint, _DerivedConstraint]


def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
    """
    .. warning::
        (This feature is DEPRECATED. See :func:`Dim` instead.)

    :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of
    a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to
    ``constraints`` argument of :func:`export`.

    Args:
        t (torch.Tensor): Example input tensor that have dynamic dimension size(s)
        index (int): Index of dynamic dimension
        debug_name (Optional[str]): Human-readable name used in error messages.

    Returns:
        A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so
        that :func:`export` does not assume static size of specified tensor, i.e. keeping it dynamic
        as a symbolic size rather than specializing according to size of example tracing input.

    Specifically :func:`dynamic_dim` can be used to express following types of dynamism.

    - Size of a dimension is dynamic and unbounded::

        t0 = torch.rand(2, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size rather than always being static size 2
        constraints = [dynamic_dim(t0, 0)]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with a lower bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) >= 5,
            dynamic_dim(t1, 1) > 2,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with an upper bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
        # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) <= 16,
            dynamic_dim(t1, 1) < 8,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # Sizes of second dimension of t0 and first dimension are always equal
        constraints = [
            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Mix and match all types above as long as they do not express conflicting requirements

    """
    from torch._dynamo.exc import UserError, UserErrorType

    if not isinstance(t, torch.Tensor):
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected tensor as input to dynamic_dim but got {type(t)}",
        )

    if t.dim() < 1:
        raise UserError(
            UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic"
        )

    if index >= t.dim():
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
            f" but got {index}, which is out of bounds for the given tensor.",
        )

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
    from torch.utils._sympy.value_ranges import ValueRanges

    # Default range is [2, oo): lower bound 2 because 0/1 sizes specialize.
    return _create_constraint(
        weakref.ref(t),
        id(t),
        index,
        StrictMinMaxConstraint(
            vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False
        ),
        debug_name=debug_name,
    )
def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.

    Args:
        constraint: a `_Constraint` or `_DerivedConstraint` to process.
        get_sources: maps (tensor id, dim) to the list of sources for that dim.
        shape_env: used to create symbols for phantom roots.
        source_pairs: mutated in place; pairs of sources required to be equal.
        derived_equalities: mutated in place; (source, root, fn) triples.
        phantom_symbols: mutated in place; name -> symbol for phantom roots.
    """

    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.shared is not None:
            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
            # and t'.size()[dim'] maps to src1', ..., srcN', we add
            # constraints that also make src0 "equal" to src1', ..., srcN'.
            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))
def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    """
    Convert a user-facing `dynamic_shapes` pytree spec (built from `Dim` types)
    into the legacy list of `Constraint` objects consumed by export.

    The spec is zipped against the bound arguments of `f`; each `_Dim`
    encountered on a tensor dimension produces a constraint. Derived dims whose
    roots do not describe any input dimension get "phantom" roots. Returns None
    when no dynamic shapes were specified.
    """
    from collections import defaultdict
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        # Walk args and spec in lockstep, yielding (tensor, shape-spec) leaves.
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            # User-registered pytree node: flatten and recurse.
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        # Turn one `_Dim` (regular or derived) on tensor.shape[i] into a constraint.
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: TRY200
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.w_tensor,
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                weakref.ref(tensor),
                id(tensor),
                i,
                root,
                dim.fn,  # type: ignore[attr-defined]
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                debug_name=dim.__name__,
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        else:
            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
            if dim.min != 2:
                constraint = constraint >= dim.min
            if dim.max != sys.maxsize - 1:
                constraint = constraint <= dim.max
        return constraint

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        # A Dim name must always be declared with the same (min, max) bounds.
        if dim.__name__ in symbols:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )

        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def update_symbols(tensor, shape):
        # Record a constraint for every `_Dim` appearing in this tensor's spec.
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        else:
            if shape is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead",
                )

    # NOTE(review): redundant — `inspect` is already imported at module top; harmless.
    import inspect

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    combined_args = signature.bind(*args, **kwargs).arguments

    # This means user didn't specify dynamic shapes with argument names.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols are, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        if all(
            isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims
        ):
            constraints.extend(dynamic_dims)
        else:
            primary, *others = dynamic_dims
            if others:
                for other in others:
                    constraints.append(primary == other)  # type: ignore[arg-type]
            else:
                constraints.append(primary)

    return constraints  # type: ignore[return-value]
def _process_constraints(
    fake_mode,
    graph_module: torch.fx.GraphModule,
    num_lifted_params_buffers: int,
    example_inputs: List[torch.Tensor],
) -> Dict:
    """
    Process the constraints stored in the graph module to return something more readable.

    Args:
        fake_mode: fake-tensor mode whose shape_env supplies ranges for
            derived root symbols not covered by input constraints.

        graph_module (torch.fx.GraphModule): GraphModule returned from
            dynamo.export, which contains the "input_shape_constraints" and
            "inline_constraints" metadata

        num_lifted_params_buffers: number of leading placeholders that are
            lifted params/buffers rather than user inputs.

        example_inputs: Flattened list of example inputs used to export the graph module

    Returns:
        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
            symbols (from SymInts) appearing in the fake tensors in
            node.meta["val"] to their range constraints, which are a tuple
            containing (lower, upper) constraints.
    """
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        InputDim,
    )

    # Import sympy locally
    from torch.fx.experimental.symbolic_shapes import SymInt
    from torch.utils._sympy.value_ranges import ValueRanges

    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
    inline_constraints = graph_module.meta.get("inline_constraints", [])

    # Create dict mapping tensor_id to node names
    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
    # Create dict mapping placeholder node names to their nodes
    placeholder_nodes: Dict[str, torch.fx.Node] = {}
    for i, node in enumerate(graph_module.graph.nodes):
        if node.op != "placeholder":
            # All placeholder nodes should be together in the beginning of the
            # graph
            break
        if i >= num_lifted_params_buffers:
            example_input = example_inputs[i - num_lifted_params_buffers]
            tensor_id_to_nodes[id(example_input)].append(node.name)
            placeholder_nodes[node.name] = node

    # Create dict mapping (node name, dim) a list of range (lower, upper)
    # constraints
    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
    for constraint in input_shape_constraints:
        for node in tensor_id_to_nodes[constraint["t_id"]]:
            node_dim = InputDim(node, constraint["dim"])

            # Accumulate range constraints
            multi_range_constraints[node_dim].append(
                ValueRanges(constraint["min"], constraint["max"])
            )

    # Create dict mapping symbol to a singular range (lower, upper)
    range_constraints: Dict[Any, ValueRanges] = {}

    # Add inline constraints to range_constraints
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }

    free_symbols: Set["Symbol"] = set()
    # Add input range constraints to range_constraints
    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
        # Simplify the range constraints into a single range constraint
        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
        min_vals = [rc.lower for rc in multi_range_constraint]
        max_vals = [rc.upper for rc in multi_range_constraint]
        min_val = max(min_vals)  # type: ignore[type-var]
        max_val = min(max_vals)  # type: ignore[type-var]
        assert min_val <= max_val  # type: ignore[operator]

        # Add input node range constraints
        val = placeholder_nodes[input_dim.input_name].meta["val"]
        assert isinstance(val, FakeTensor)
        symint = val.shape[input_dim.dim]
        assert isinstance(
            symint, SymInt
        ), f"Expected SymInt but got {symint}: {type(symint)}"
        symbol = symint.node.expr
        range_constraints[symbol] = ValueRanges(min_val, max_val)
        free_symbols.update(symbol.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

    return range_constraints