diff --git a/.gitattributes b/.gitattributes index f5441aff30b541916513f3d17b0033aa85c4b75c..5db4a09bb9785cb413e7c47dec5bdab958b01a30 100644 --- a/.gitattributes +++ b/.gitattributes @@ -66,3 +66,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torchgen/__pycach tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/tests/__pycache__/test_fp.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..fec9da9646b5fffb19df7fa5d4f8948fea672512 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c7b91ab0731f5672d976d4408f3525891b8c4e1d4ed4d403f56d1c141c7f94 +size 688080 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af11ea15581dfbbbe7800700e9a6ad76f74c5d8e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e828bf211daa379b740684868a31081921397805bfc7ef4b41a8572d794eaafb +size 137864 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cbab4966d9d2d18e180dce763d0c5057607e236 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a7da30d6865deaf94e2814884970e99b253843c23d4aa93b1107a23e61de6c1 +size 123664 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py new file mode 100644 index 0000000000000000000000000000000000000000..6903b49715ecda81bd55db9d8e85126f74c2eb46 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_tensor_str.py @@ -0,0 +1,697 @@ +import contextlib +import dataclasses +import math +import textwrap +from typing import Any, Dict, Optional + +import torch +from torch import inf + + +@dataclasses.dataclass +class __PrinterOptions: + precision: int = 4 + threshold: float = 1000 + edgeitems: int = 3 + linewidth: int = 80 + sci_mode: Optional[bool] = None + + +PRINT_OPTS = __PrinterOptions() + + +# We could use **kwargs, but this will give better docs 
+def set_printoptions( + precision=None, + threshold=None, + edgeitems=None, + linewidth=None, + profile=None, + sci_mode=None, +): + r"""Set options for printing. Items shamelessly taken from NumPy + + Args: + precision: Number of digits of precision for floating point output + (default = 4). + threshold: Total number of array elements which trigger summarization + rather than full `repr` (default = 1000). + edgeitems: Number of array items in summary at beginning and end of + each dimension (default = 3). + linewidth: The number of characters per line for the purpose of + inserting line breaks (default = 80). Thresholded matrices will + ignore this parameter. + profile: Sane defaults for pretty printing. Can override with any of + the above options. (any one of `default`, `short`, `full`) + sci_mode: Enable (True) or disable (False) scientific notation. If + None (default) is specified, the value is defined by + `torch._tensor_str._Formatter`. This value is automatically chosen + by the framework. 
+ + Example:: + + >>> # Limit the precision of elements + >>> torch.set_printoptions(precision=2) + >>> torch.tensor([1.12345]) + tensor([1.12]) + >>> # Limit the number of elements shown + >>> torch.set_printoptions(threshold=5) + >>> torch.arange(10) + tensor([0, 1, 2, ..., 7, 8, 9]) + >>> # Restore defaults + >>> torch.set_printoptions(profile='default') + >>> torch.tensor([1.12345]) + tensor([1.1235]) + >>> torch.arange(10) + tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + """ + if profile is not None: + if profile == "default": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + elif profile == "short": + PRINT_OPTS.precision = 2 + PRINT_OPTS.threshold = 1000 + PRINT_OPTS.edgeitems = 2 + PRINT_OPTS.linewidth = 80 + elif profile == "full": + PRINT_OPTS.precision = 4 + PRINT_OPTS.threshold = inf + PRINT_OPTS.edgeitems = 3 + PRINT_OPTS.linewidth = 80 + + if precision is not None: + PRINT_OPTS.precision = precision + if threshold is not None: + PRINT_OPTS.threshold = threshold + if edgeitems is not None: + PRINT_OPTS.edgeitems = edgeitems + if linewidth is not None: + PRINT_OPTS.linewidth = linewidth + PRINT_OPTS.sci_mode = sci_mode + + +def get_printoptions() -> Dict[str, Any]: + r"""Gets the current options for printing, as a dictionary that + can be passed as ``**kwargs`` to set_printoptions(). + """ + return dataclasses.asdict(PRINT_OPTS) + + +@contextlib.contextmanager +def printoptions(**kwargs): + r"""Context manager that temporarily changes the print options. 
Accepted + arguments are same as :func:`set_printoptions`.""" + old_kwargs = get_printoptions() + set_printoptions(**kwargs) + try: + yield + finally: + set_printoptions(**old_kwargs) + + +def tensor_totype(t): + dtype = torch.float if t.is_mps else torch.double + return t.to(dtype=dtype) + + +class _Formatter: + def __init__(self, tensor): + self.floating_dtype = tensor.dtype.is_floating_point + self.int_mode = True + self.sci_mode = False + self.max_width = 1 + + with torch.no_grad(): + tensor_view = tensor.reshape(-1) + + if not self.floating_dtype: + for value in tensor_view: + value_str = f"{value}" + self.max_width = max(self.max_width, len(value_str)) + + else: + nonzero_finite_vals = torch.masked_select( + tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0) + ) + + if nonzero_finite_vals.numel() == 0: + # no valid number, do nothing + return + + # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. + nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs()) + nonzero_finite_min = tensor_totype(nonzero_finite_abs.min()) + nonzero_finite_max = tensor_totype(nonzero_finite_abs.max()) + + for value in nonzero_finite_vals: + if value != torch.ceil(value): + self.int_mode = False + break + + if self.int_mode: + # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites + # to indicate that the tensor is of floating type. add 1 to the len to account for this. + if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{value:.0f}" + self.max_width = max(self.max_width, len(value_str) + 1) + else: + # Check if scientific representation should be used. 
+ if ( + nonzero_finite_max / nonzero_finite_min > 1000.0 + or nonzero_finite_max > 1.0e8 + or nonzero_finite_min < 1.0e-4 + ): + self.sci_mode = True + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + else: + for value in nonzero_finite_vals: + value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + self.max_width = max(self.max_width, len(value_str)) + + if PRINT_OPTS.sci_mode is not None: + self.sci_mode = PRINT_OPTS.sci_mode + + def width(self): + return self.max_width + + def format(self, value): + if self.floating_dtype: + if self.sci_mode: + ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value) + elif self.int_mode: + ret = f"{value:.0f}" + if not (math.isinf(value) or math.isnan(value)): + ret += "." + else: + ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value) + else: + ret = f"{value}" + return (self.max_width - len(ret)) * " " + ret + + +def _scalar_str(self, formatter1, formatter2=None): + if formatter2 is not None: + real_str = _scalar_str(self.real, formatter1) + imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip() + # handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(self.item()) + + +def _vector_str(self, indent, summarize, formatter1, formatter2=None): + # length includes spaces and comma between elements + element_length = formatter1.width() + 2 + if formatter2 is not None: + # width for imag_formatter + an extra j for complex + element_length += formatter2.width() + 1 + + elements_per_line = max( + 1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length))) + ) + + def _val_formatter(val, formatter1=formatter1, formatter2=formatter2): + if formatter2 is not None: + real_str = formatter1.format(val.real) + imag_str = (formatter2.format(val.imag) + "j").lstrip() + # 
handles negative numbers, +0.0, -0.0 + if imag_str[0] == "+" or imag_str[0] == "-": + return real_str + imag_str + else: + return real_str + "+" + imag_str + else: + return formatter1.format(val) + + if summarize and not PRINT_OPTS.edgeitems: + # Deal with edge case that negative zero is zero + data = ["..."] + elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + data = ( + [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()] + + [" ..."] + + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()] + ) + else: + data = [_val_formatter(val) for val in self.tolist()] + + data_lines = [ + data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line) + ] + lines = [", ".join(line) for line in data_lines] + return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]" + + +# formatter2 is only used for printing complex tensors. +# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real +# and tensor.imag respesectively +def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None): + dim = self.dim() + + if dim == 0: + return _scalar_str(self, formatter1, formatter2) + + if dim == 1: + return _vector_str(self, indent, summarize, formatter1, formatter2) + + if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: + slices = ( + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, PRINT_OPTS.edgeitems) + ] + + ["..."] + + [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(len(self) - PRINT_OPTS.edgeitems, len(self)) + ] + ) + else: + slices = [ + _tensor_str_with_formatter( + self[i], indent + 1, summarize, formatter1, formatter2 + ) + for i in range(0, self.size(0)) + ] + + tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices) + return "[" + tensor_str + "]" + + +def _tensor_str(self, indent): + if self.numel() == 0: + return 
"[]" + + if self.has_names(): + # There are two main codepaths (possibly more) that tensor printing goes through: + # - tensor data can fit comfortably on screen + # - tensor data needs to be summarized + # Some of the codepaths don't fully support named tensors, so we send in + # an unnamed tensor to the formatting code as a workaround. + self = self.rename(None) + + summarize = self.numel() > PRINT_OPTS.threshold + + if self._is_zerotensor(): + self = self.clone() + + # handle the negative bit + if self.is_neg(): + self = self.resolve_neg() + + if self.dtype in [ + torch.float16, + torch.bfloat16, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ]: + self = self.float() + + if self.dtype is torch.complex32: + self = self.cfloat() + + if self.dtype.is_complex: + # handle the conjugate bit + self = self.resolve_conj() + real_formatter = _Formatter( + get_summarized_data(self.real) if summarize else self.real + ) + imag_formatter = _Formatter( + get_summarized_data(self.imag) if summarize else self.imag + ) + return _tensor_str_with_formatter( + self, indent, summarize, real_formatter, imag_formatter + ) + else: + formatter = _Formatter(get_summarized_data(self) if summarize else self) + return _tensor_str_with_formatter(self, indent, summarize, formatter) + + +def _add_suffixes(tensor_str, suffixes, indent, force_newline): + tensor_strs = [tensor_str] + last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1 + for suffix in suffixes: + suffix_len = len(suffix) + if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth: + tensor_strs.append(",\n" + " " * indent + suffix) + last_line_len = indent + suffix_len + force_newline = False + else: + tensor_strs.append(", " + suffix) + last_line_len += suffix_len + 2 + tensor_strs.append(")") + return "".join(tensor_strs) + + +def get_summarized_data(self): + dim = self.dim() + if dim == 0: + return self + if dim == 1: + if self.size(0) > 2 * 
PRINT_OPTS.edgeitems: + return torch.cat( + (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) + ) + else: + return self + if not PRINT_OPTS.edgeitems: + return self.new_empty([0] * self.dim()) + elif self.size(0) > 2 * PRINT_OPTS.edgeitems: + start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] + end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))] + return torch.stack([get_summarized_data(x) for x in (start + end)]) + else: + return torch.stack([get_summarized_data(x) for x in self]) + + +def _str_intern(inp, *, tensor_contents=None): + if torch._C._functorch.is_functorch_wrapped_tensor(inp): + return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents) + is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter + if inp.is_nested: + prefix = "nested_tensor(" + elif is_plain_tensor: + prefix = "tensor(" + else: + prefix = f"{type(inp).__name__}(" + indent = len(prefix) + suffixes = [] + custom_contents_provided = tensor_contents is not None + if custom_contents_provided: + tensor_str = tensor_contents + + # This is used to extract the primal value and thus disable the forward AD + # within this function. + # TODO(albanD) This needs to be updated when more than one level is supported + self, tangent = torch.autograd.forward_ad.unpack_dual(inp) + + # Note [Print tensor device]: + # A general logic here is we only print device when it doesn't match + # the device specified in default tensor type. + # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus + # torch._C._get_default_device() only returns either cpu or cuda. + # In other cases, we don't have a way to set them as default yet, + # and we should always print out device for them. 
+ if ( + self.device.type != torch._C._get_default_device() + or ( + self.device.type == "cuda" + and torch.cuda.current_device() != self.device.index + ) + or (self.device.type == "mps") + ): + suffixes.append("device='" + str(self.device) + "'") + + # Tensor printing performs tensor operations like slice, indexing, etc to make it in a + # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence, + # to avoid compilations, copying the tensor to cpu before printing. + if self.device.type in ["xla", "lazy", "ipu", "mtia"]: + self = self.to("cpu") + + # TODO: add an API to map real -> complex dtypes + _default_complex_dtype = ( + torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat + ) + has_default_dtype = self.dtype in ( + torch.get_default_dtype(), + _default_complex_dtype, + torch.int64, + torch.bool, + ) + if self.is_sparse: + suffixes.append("size=" + str(tuple(self.shape))) + from torch._subclasses.fake_tensor import FakeTensor + + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + indices_prefix = "indices=tensor(" + indices = self._indices().detach() + if is_meta: + indices_str = "..." + else: + indices_str = _tensor_str(indices, indent + len(indices_prefix)) + if indices.numel() == 0 or is_meta: + indices_str += ", size=" + str(tuple(indices.shape)) + values_prefix = "values=tensor(" + values = self._values().detach() + if is_meta: + values_str = "..." 
+ else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0 or is_meta: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + indices_prefix + + indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }: + from torch._subclasses.fake_tensor import FakeTensor + + suffixes.append("size=" + str(tuple(self.shape))) + is_meta = self.is_meta or isinstance(self, FakeTensor) + if not is_meta: + suffixes.append("nnz=" + str(self._nnz())) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + compressed_indices_method, plain_indices_method = { + torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), + torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), + }[self.layout] + if self.layout in {torch.sparse_csr, torch.sparse_bsr}: + cdimname, pdimname = "row", "column" + else: + cdimname, pdimname = "column", "row" + compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor(" + compressed_indices = compressed_indices_method(self).detach() + if is_meta: + compressed_indices_str = "..." + else: + compressed_indices_str = _tensor_str( + compressed_indices, indent + len(compressed_indices_prefix) + ) + if compressed_indices.numel() == 0 or is_meta: + compressed_indices_str += ", size=" + str( + tuple(compressed_indices.shape) + ) + plain_indices_prefix = f"{pdimname[:3]}_indices=tensor(" + plain_indices = plain_indices_method(self).detach() + if is_meta: + plain_indices_str = "..." 
+ else: + plain_indices_str = _tensor_str( + plain_indices, indent + len(plain_indices_prefix) + ) + if plain_indices.numel() == 0 or is_meta: + plain_indices_str += ", size=" + str(tuple(plain_indices.shape)) + values_prefix = "values=tensor(" + values = self.values().detach() + if is_meta: + values_str = "..." + else: + values_str = _tensor_str(values, indent + len(values_prefix)) + if values.numel() == 0 or is_meta: + values_str += ", size=" + str(tuple(values.shape)) + tensor_str = ( + compressed_indices_prefix + + compressed_indices_str + + "),\n" + + " " * indent + + plain_indices_prefix + + plain_indices_str + + "),\n" + + " " * indent + + values_prefix + + values_str + + ")" + ) + elif self.is_quantized: + suffixes.append("size=" + str(tuple(self.shape))) + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + suffixes.append("quantization_scheme=" + str(self.qscheme())) + if ( + self.qscheme() == torch.per_tensor_affine + or self.qscheme() == torch.per_tensor_symmetric + ): + suffixes.append("scale=" + str(self.q_scale())) + suffixes.append("zero_point=" + str(self.q_zero_point())) + elif ( + self.qscheme() == torch.per_channel_affine + or self.qscheme() == torch.per_channel_symmetric + or self.qscheme() == torch.per_channel_affine_float_qparams + ): + suffixes.append("scale=" + str(self.q_per_channel_scales())) + suffixes.append("zero_point=" + str(self.q_per_channel_zero_points())) + suffixes.append("axis=" + str(self.q_per_channel_axis())) + if not custom_contents_provided: + tensor_str = _tensor_str(self.dequantize(), indent) + elif self.is_nested: + if not custom_contents_provided: + + def indented_str(s, indent): + return "\n".join(f" {line}" for line in s.split("\n")) + + strs = ",\n".join( + indented_str(str(t), indent + 1) + for t in torch.ops.aten.unbind.int(self, 0) + ) + tensor_str = f"[\n{strs}\n]" + elif torch._is_functional_tensor(self): + prefix = "_to_functional_tensor(" + tensor_str = 
repr(torch._from_functional_tensor(self)) + else: + # Circular import problem, so we import it here + from torch._subclasses.fake_tensor import FakeTensor + + if self.is_meta or isinstance(self, FakeTensor): + suffixes.append("size=" + str(tuple(self.shape))) + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + # TODO: This implies that ellipses is valid syntax for allocating + # a meta tensor or FakeTensor, which it could be, but it isn't right now + if not custom_contents_provided: + tensor_str = "..." + else: + if self.numel() == 0 and not self.is_sparse: + # Explicitly print the shape if it is not (0,), to match NumPy behavior + if self.dim() != 1: + suffixes.append("size=" + str(tuple(self.shape))) + + # In an empty tensor, there are no elements to infer if the dtype + # should be int64, so it must be shown explicitly. + if self.dtype != torch.get_default_dtype(): + suffixes.append("dtype=" + str(self.dtype)) + if not custom_contents_provided: + tensor_str = "[]" + else: + if not PRINT_OPTS.edgeitems: + suffixes.append("size=" + str(tuple(self.shape))) + + if not has_default_dtype: + suffixes.append("dtype=" + str(self.dtype)) + + if not custom_contents_provided: + if self.layout != torch.strided: + tensor_str = _tensor_str(self.to_dense(), indent) + else: + tensor_str = _tensor_str(self, indent) + + if self.layout != torch.strided: + suffixes.append("layout=" + str(self.layout)) + + # Use inp here to get the original grad_fn and not the one generated by the forward grad + # unpacking. + grad_fn_name = None + try: + grad_fn = inp.grad_fn + except RuntimeError: + # Accessing the grad_fn calls rebasing logic which would cause an error + # if that tensor is a view created in no-grad mode modified in-place in + # no-grad mode. 
See: https://github.com/pytorch/pytorch/issues/99968 + grad_fn_name = "Invalid" + + if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined] + grad_fn_name = type(grad_fn).__name__ + if grad_fn_name == "CppFunction": + grad_fn_name = grad_fn.name().rsplit("::", 1)[-1] + + if grad_fn_name is not None: + suffixes.append(f"grad_fn=<{grad_fn_name}>") + elif inp.requires_grad: + suffixes.append("requires_grad=True") + + if self.has_names(): + suffixes.append(f"names={self.names}") + + if tangent is not None: + suffixes.append(f"tangent={tangent}") + + string_repr = _add_suffixes( + prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse # type: ignore[possibly-undefined] + ) + + # Check if this instance is flagged as a parameter and change the repr accordingly. + # Unfortunately, this function has to be aware of this detail. + # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future, + # this should be done for those as well to produce a valid repr. 
+ if isinstance(self, torch.nn.Parameter) and not is_plain_tensor: + string_repr = f"Parameter({string_repr})" + + return string_repr + + +def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None): + level = torch._C._functorch.maybe_get_level(tensor) + assert level != -1 + + if torch._C._functorch.is_functionaltensor(tensor): + # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure + # that it's up to date first + torch._sync(tensor) + + value = torch._C._functorch.get_unwrapped(tensor) + value_repr = repr(value) + + indented_value_repr = textwrap.indent(value_repr, " " * 4) + if torch._C._functorch.is_batchedtensor(tensor): + bdim = torch._C._functorch.maybe_get_bdim(tensor) + assert bdim != -1 + return ( + f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n" + f"{indented_value_repr}\n" + f")" + ) + if torch._C._functorch.is_gradtrackingtensor(tensor): + return ( + f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")" + ) + if torch._C._functorch.is_functionaltensor(tensor): + return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})" + + raise ValueError("We don't know how to print this, please file us an issue") + + +def _str(self, *, tensor_contents=None): + with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes(): + guard = torch._C._DisableFuncTorch() + return _str_intern(self, tensor_contents=tensor_contents) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4c6f489e99aa3cf1b31c340bea4a03f589d4ae --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/modules/fused.py @@ -0,0 +1,160 @@ +import torch +from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d 
+from torch.nn.utils.parametrize import type_before_parametrizations + +__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d', + 'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d', + 'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d'] + +# Used for identifying intrinsic modules used in quantization +class _FusedModule(torch.nn.Sequential): + pass + +class ConvReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv1d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, relu): + assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}' + super().__init__(conv, relu) + +class ConvReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, relu): + assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}' + super().__init__(conv, relu) + +class ConvReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, relu): + assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}' + super().__init__(conv, relu) + +class LinearReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and ReLU modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, linear, relu): + assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU, \ + 'Incorrect types for input modules{}{}'.format( + type_before_parametrizations(linear), type_before_parametrizations(relu)) + super().__init__(linear, relu) + +class ConvBn1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn): + assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}' + super().__init__(conv, bn) + +class ConvBn2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn): + assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}' + super().__init__(conv, bn) + +class ConvBnReLU1d(_FusedModule): + r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn, relu): + assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and \ + type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ + .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu)) + super().__init__(conv, bn, relu) + +class ConvBnReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn, relu): + assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and \ + type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ + .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu)) + super().__init__(conv, bn, relu) + +class ConvBn3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn): + assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d, \ + f'Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}' + super().__init__(conv, bn) + +class ConvBnReLU3d(_FusedModule): + r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, bn, relu): + assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and \ + type_before_parametrizations(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \ + .format(type_before_parametrizations(conv), type_before_parametrizations(bn), type_before_parametrizations(relu)) + super().__init__(conv, bn, relu) + + +class BNReLU2d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, batch_norm, relu): + assert type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU, \ + 'Incorrect types for input modules{}{}'.format( + type_before_parametrizations(batch_norm), type_before_parametrizations(relu)) + super().__init__(batch_norm, relu) + +class BNReLU3d(_FusedModule): + r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, batch_norm, relu): + assert type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU, \ + 'Incorrect types for input modules{}{}'.format( + type_before_parametrizations(batch_norm), type_before_parametrizations(relu)) + super().__init__(batch_norm, relu) + + +class LinearBn1d(_FusedModule): + r"""This is a sequential container which calls the Linear and BatchNorm1d modules. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, linear, bn): + assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d, \ + f'Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}' + super().__init__(linear, bn) + +class LinearLeakyReLU(_FusedModule): + r"""This is a sequential container which calls the Linear and LeakyReLU modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, linear, leaky_relu): + assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, \ + f'Incorrect types for input modules{type(linear)}{type(leaky_relu)}' + super().__init__(linear, leaky_relu) + +class LinearTanh(_FusedModule): + r"""This is a sequential container which calls the Linear and Tanh modules. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, linear, tanh): + assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \ + f'Incorrect types for input modules{type(linear)}{type(tanh)}' + super().__init__(linear, tanh) + +class ConvAdd2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d modules with extra Add. + During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, add): + super().__init__(conv) + self.add = add + + def forward(self, x1, x2): + return self.add(self[0](x1), x2) + +class ConvAddReLU2d(_FusedModule): + r"""This is a sequential container which calls the Conv2d, add, Relu. 
+ During quantization this will be replaced with the corresponding fused module.""" + def __init__(self, conv, add, relu): + super().__init__(conv) + self.add = add + self.relu = relu + + def forward(self, x1, x2): + return self.relu(self.add(self[0](x1), x2)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78c75f0c82b5605575e3abceafd41b3036cb9431 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py @@ -0,0 +1,14 @@ +from .modules import * # noqa: F403 + +__all__ = [ + 'BNReLU2d', + 'BNReLU3d', + 'ConvReLU1d', + 'ConvReLU2d', + 'ConvReLU3d', + 'LinearReLU', + 'LinearLeakyReLU', + 'LinearTanh', + 'ConvAdd2d', + 'ConvAddReLU2d', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8052fdd92044bfa8f8336993b7841658a5530c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdc9004c99c600fbeaad5cf6d2196614ca36810 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py @@ -0,0 +1,175 @@ + +import torch +import torch.ao.nn.intrinsic +import torch.ao.nn.intrinsic.qat +import torch.nn.functional as F +import torch.ao.nn.quantized as nnq + +from torch.nn.utils import fuse_conv_bn_weights + +__all__ = [ + "ConvReLU1d", + "ConvReLU2d", + "ConvReLU3d", +] + +_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding + +# TODO: factor out the common parts to ConvNd +class ConvReLU1d(nnq.Conv1d): + r""" + A ConvReLU1d module is a fused module of Conv1d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv1d + + """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment] + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode, device=device, dtype=dtype) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != 'zeros': + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return torch.ops.quantized.conv1d_relu( + input, self._packed_params, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedConvReLU1d' + + @classmethod + def from_float(cls, mod): + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d: + assert mod.bn.running_var is not None and 
mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, + mod.bn.eps, mod.bn.weight, mod.bn.bias) + return super().from_float(mod) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \ + "BatchNorm1d should be fused into Conv1d before converting to reference module" + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + +class ConvReLU2d(nnq.Conv2d): + r""" + A ConvReLU2d module is a fused module of Conv2d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`. + + Attributes: + Same as torch.ao.nn.quantized.Conv2d + + """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment] + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode, device=device, dtype=dtype) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return torch.ops.quantized.conv2d_relu( + input, self._packed_params, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedConvReLU2d' + + @classmethod + def from_float(cls, mod): + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d: + assert mod.bn.running_var is not None and mod.bn.running_mean is not None + mod.weight, 
mod.bias = fuse_conv_bn_weights( + mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, + mod.bn.eps, mod.bn.weight, mod.bn.bias) + return super().from_float(mod) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, \ + "BatchNorm2d should be fused into Conv2d before converting to reference module" + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) + + +class ConvReLU3d(nnq.Conv3d): + r""" + A ConvReLU3d module is a fused module of Conv3d and ReLU + + We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`. + + Attributes: Same as torch.ao.nn.quantized.Conv3d + + """ + _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment] + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + assert padding_mode != 'reflect', "Conv3d does not support reflection padding" + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias, + padding_mode=padding_mode, device=device, dtype=dtype) + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return torch.ops.quantized.conv3d_relu( + input, self._packed_params, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedConvReLU3d' + + @classmethod + def from_float(cls, mod): + if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d: + assert mod.bn.running_var is not None and 
mod.bn.running_mean is not None + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, + mod.bias, + mod.bn.running_mean, + mod.bn.running_var, + mod.bn.eps, + mod.bn.weight, + mod.bn.bias, + ) + return super().from_float(mod) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, \ + "BatchNorm3d should be fused into Conv3d before converting to reference module" + return super().from_reference(ref_qconv[0], output_scale, output_zero_point) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..e774a72dc8229194328ef2af3054599506f00d3e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,177 @@ +import torch +import torch.ao.nn.quantized as nnq +import torch.ao.nn.intrinsic as nni +from torch.ao.nn.quantized.modules.utils import _quantize_weight + +__all__ = [ + "LinearReLU", + "LinearLeakyReLU", + "LinearTanh", +] + +class LinearReLU(nnq.Linear): + r""" + A LinearReLU module fused from Linear and ReLU modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. 
+ + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearReLU(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + _FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedLinearReLU' + + @classmethod + def from_float(cls, mod): + return super().from_float(mod) + + @classmethod + def from_reference(cls, ref_linear_relu, output_scale, output_zero_point): + return super().from_reference(ref_linear_relu[0], output_scale, output_zero_point) + +class LinearLeakyReLU(nnq.Linear): + r""" + For onednn backend only + A LinearLeakyReLU module fused from Linear and LeakyReLU modules + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. 
+ Attributes: + Same as torch.ao.nn.quantized.Linear + + negative_slope + Examples:: + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + _FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment] + + def __init__(self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + self.negative_slope = negative_slope + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_leaky_relu( + x, self._packed_params._packed_params, self.scale, self.zero_point, self.negative_slope) + + def _get_name(self): + return 'QuantizedLinearLeakyReLU' + + @classmethod + def from_float(cls, mod): + assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU' + assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + activation_post_process = mod.activation_post_process + leaky_relu = mod[1] + mod = mod[0] + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_leaky_relu = cls( + mod.in_features, + mod.out_features, + leaky_relu.negative_slope, + dtype=dtype) + qlinear_leaky_relu.set_weight_bias(qweight, mod.bias) + qlinear_leaky_relu.scale = float(act_scale) + qlinear_leaky_relu.zero_point = int(act_zp) + return qlinear_leaky_relu + + @classmethod + def from_reference(cls, ref_mod, output_scale, output_zero_point): + linear = ref_mod[0] + leaky_relu = ref_mod[1] + qlinear_leaky_relu = cls( + linear.in_features, + linear.out_features, + leaky_relu.negative_slope) + qweight 
= linear.get_quantized_weight() + qlinear_leaky_relu.set_weight_bias(qweight, linear.bias) + qlinear_leaky_relu.scale = float(output_scale) + qlinear_leaky_relu.zero_point = int(output_zero_point) + return qlinear_leaky_relu + +class LinearTanh(nnq.Linear): + r""" + A LinearTanh module fused from Linear and Tanh modules + + We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`. + + Attributes: + Same as torch.ao.nn.quantized.Linear + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = nn.intrinsic.LinearTanh(20, 30) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 30]) + """ + _FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment] + + def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8): + super().__init__(in_features, out_features, bias, dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.linear_tanh( + x, self._packed_params._packed_params, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedLinearTanh' + + @classmethod + def from_float(cls, mod): + assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh' + assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + activation_post_process = mod.activation_post_process + mod = mod[0] + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + dtype = weight_post_process.dtype + act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator] + assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + qlinear_tanh = cls( + mod.in_features, + mod.out_features, + dtype=dtype) + qlinear_tanh.set_weight_bias(qweight, mod.bias) + qlinear_tanh.scale = float(act_scale) + qlinear_tanh.zero_point = int(act_zp) + return qlinear_tanh + + @classmethod + def from_reference(cls, ref_mod, 
output_scale, output_zero_point): + linear = ref_mod[0] + qlinear_tanh = cls( + linear.in_features, + linear.out_features) + qweight = linear.get_quantized_weight() + qlinear_tanh.set_weight_bias(qweight, linear.bias) + qlinear_tanh.scale = float(output_scale) + qlinear_tanh.zero_point = int(output_zero_point) + return qlinear_tanh diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..c93dfab1f15b03dde346f85bac793e41c843c168 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,25 @@ +import torch + +__all__ = ["Linear"] + +class Linear(torch.ao.nn.qat.Linear): + r""" + A linear module attached with FakeQuantize modules for weight, + used for dynamic quantization aware training. + + We adopt the same interface as `torch.nn.Linear`, please see + https://pytorch.org/docs/stable/nn.html#torch.nn.Linear + for documentation. + + Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to + default. + """ + + def __init__(self, in_features, out_features, bias=True, + qconfig=None, device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias, qconfig, device, dtype) + if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig): + raise ValueError( + "Dynamic QAT requires a memoryless observer." 
+ + "This means a MovingAverage observer with averaging constant equal to 1" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..988a1dd5ed4b61b038bdd510f831b0109d684489 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/modules/__init__.py @@ -0,0 +1,14 @@ +from .linear import Linear +from .conv import Conv1d +from .conv import Conv2d +from .conv import Conv3d +from .embedding_ops import EmbeddingBag, Embedding + +__all__ = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "EmbeddingBag", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9464dfccfcfa05fdabc822ead01c7a9ca7a66ec1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..56be29a09d6243dd0d288b71c5691bef011119c1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/activation.py @@ -0,0 +1,465 @@ +import torch +import torch.jit # this is needed to avoid a circular import +from torch import nn +import torch.nn.functional as nnF + +from torch import Tensor +from typing 
import Optional, Tuple + +import warnings + +__all__ = [ + "MultiheadAttention" +] + +class MultiheadAttention(nn.MultiheadAttention): + _FLOAT_MODULE = nn.MultiheadAttention + + r"""Quantizable implementation of the MultiheadAttention. + + Note:: + Please, refer to :class:`~torch.nn.MultiheadAttention` for more + information + + Allows the model to jointly attend to information from different + representation subspaces. + See reference: Attention Is All You Need + + The original MHA module is not quantizable. + This reimplements it by explicitly instantiating the linear layers. + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. + + Examples:: + + >>> import torch.ao.nn.quantizable as nnqa + >>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + Note:: + Please, follow the quantization flow to convert the quantizable MHA. 
+ """ + __constants__ = ['batch_first'] + + def __init__(self, embed_dim: int, num_heads: int, + dropout: float = 0., bias: bool = True, + add_bias_kv: bool = False, add_zero_attn: bool = False, + kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(embed_dim, num_heads, dropout, + bias, add_bias_kv, + add_zero_attn, kdim, vdim, batch_first, + **factory_kwargs) + self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) + self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs) + self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs) + # for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969 + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment] + + # Functionals + self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional() + # note: importing torch.ao.nn.quantized at top creates a circular import + + # Quant/Dequant + self.quant_attn_output = torch.ao.quantization.QuantStub() + self.quant_attn_output_weights = torch.ao.quantization.QuantStub() + self.dequant_q = torch.ao.quantization.DeQuantStub() + self.dequant_k = torch.ao.quantization.DeQuantStub() + self.dequant_v = torch.ao.quantization.DeQuantStub() + + def _get_name(self): + return 'QuantizableMultiheadAttention' + + @classmethod + def from_float(cls, other): + assert type(other) == cls._FLOAT_MODULE + assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" + # Setting the dropout to 0.0! 
+ observed = cls(other.embed_dim, other.num_heads, other.dropout, + (other.in_proj_bias is not None), + (other.bias_k is not None), + other.add_zero_attn, other.kdim, other.vdim, + other.batch_first) + observed.bias_k = other.bias_k + observed.bias_v = other.bias_v + observed.qconfig = other.qconfig + + # Set the linear weights + # for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969 + observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type] + observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type] + if other._qkv_same_embed_dim: + # Use separate params + bias = other.in_proj_bias + _start = 0 + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_Q.weight = torch.nn.Parameter(weight, + weight.requires_grad) + observed.linear_Q.bias = bias + + bias = other.in_proj_bias + _start = _end + _end = _start + other.embed_dim + weight = other.in_proj_weight[_start:_end, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad) + observed.linear_K.weight = torch.nn.Parameter(weight, + weight.requires_grad) + observed.linear_K.bias = bias + + bias = other.in_proj_bias + _start = _end + weight = other.in_proj_weight[_start:, :] + if bias is not None: + bias = torch.nn.Parameter(bias[_start:], bias.requires_grad) + observed.linear_V.weight = torch.nn.Parameter(weight, + weight.requires_grad) + observed.linear_V.bias = bias + else: + observed.linear_Q.weight = nn.Parameter(other.q_proj_weight) + observed.linear_K.weight = nn.Parameter(other.k_proj_weight) + observed.linear_V.weight = nn.Parameter(other.v_proj_weight) + if other.in_proj_bias is None: + observed.linear_Q.bias = None # type: ignore[assignment] + observed.linear_K.bias = None # type: ignore[assignment] + observed.linear_V.bias = None # type: ignore[assignment] + else: + observed.linear_Q.bias = 
nn.Parameter(other.in_proj_bias[0:other.embed_dim]) + observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)]) + observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):]) + observed.eval() + # Explicit prepare + observed = torch.ao.quantization.prepare(observed, inplace=True) + return observed + + @torch.jit.unused + def dequantize(self): + r"""Utility to convert the quantized MHA back to float. + + The motivation for this is that it is not trivial to conver the weights + from the format that is used in the quantized version back to the + float. + """ + fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout, + (self.linear_Q._weight_bias()[1] is not None), + (self.bias_k is not None), + self.add_zero_attn, self.kdim, self.vdim, self.batch_first) + assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim + if self.bias_k is not None: + fp.bias_k = nn.Parameter(self.bias_k.dequantize()) + if self.bias_v is not None: + fp.bias_v = nn.Parameter(self.bias_v.dequantize()) + + # Set the linear weights + # Note: Because the linear layers are quantized, mypy does not nkow how + # to deal with them -- might need to ignore the typing checks. 
+ # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type] + fp.out_proj.weight = nn.Parameter(w.dequantize()) + if b is not None: + fp.out_proj.bias = nn.Parameter(b) + + wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator] + wQ = wQ.dequantize() + wK, bK = self.linear_K._weight_bias() # type: ignore[operator] + wK = wK.dequantize() + wV, bV = self.linear_V._weight_bias() # type: ignore[operator] + wV = wV.dequantize() + if fp._qkv_same_embed_dim: + # Use separate params + _start = 0 + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wQ + if fp.in_proj_bias is not None: + assert all(bQ == 0) + fp.in_proj_bias[_start:_end] = bQ + + _start = _end + _end = _start + fp.embed_dim + fp.in_proj_weight[_start:_end, :] = wK + if fp.in_proj_bias is not None: + assert all(bK == 0) + fp.in_proj_bias[_start:_end] = bK + + _start = _end + fp.in_proj_weight[_start:, :] = wV + if fp.in_proj_bias is not None: + assert all(bV == 0) + fp.in_proj_bias[_start:] = bV + else: + fp.q_proj_weight = nn.Parameter(wQ) + fp.k_proj_weight = nn.Parameter(wK) + fp.v_proj_weight = nn.Parameter(wV) + if fp.in_proj_bias is None: + self.linear_Q.bias = None + self.linear_K.bias = None + self.linear_V.bias = None + else: + fp.in_proj_bias[0:fp.embed_dim] = bQ + fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK + fp.in_proj_bias[(fp.embed_dim * 2):] = bV + + return fp + + + @classmethod + def from_observed(cls, other): + # The whole flow is float -> observed -> quantized + # This class does float -> observed only + # See nn.quantized.MultiheadAttention + raise NotImplementedError("It looks like you are trying to prepare an " + "MHA module. 
Please, see " + "the examples on quantizable MHAs.") + + def forward(self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Note:: + Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more + information + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask. + Default: ``False``. + - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True.``. Default: True (i.e. average weights across heads) + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``. + - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged + across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length, + S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(N, num_heads, L, S)`. + """ + return self._forward_impl(query, key, value, key_padding_mask, + need_weights, attn_mask, average_attn_weights, + is_causal) + + def _forward_impl(self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: + # This version will not deal with the static key/value pairs. + # Keeping it here for future changes. 
+ # + # TODO: This method has some duplicate lines with the + # `torch.nn.functional.multi_head_attention`. Will need to refactor. + static_k = None + static_v = None + + if attn_mask is not None and is_causal: + raise AssertionError("Only allow causal mask or attn_mask") + + if is_causal: + raise AssertionError("causal mask not supported by AO MHA module") + + if self.batch_first: + query, key, value = (x.transpose(0, 1) for x in (query, key, value)) + + tgt_len, bsz, embed_dim_to_check = query.size() + assert self.embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = self.embed_dim // self.num_heads + assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + q = self.linear_Q(query) + k = self.linear_K(key) + v = self.linear_V(value) + + q = self.q_scaling_product.mul_scalar(q, scaling) + + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f'Only float and bool types are supported for attn_mask, not {attn_mask.dtype}' + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + # attn_mask's dim is 3 now. 
+ + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + if self.bias_k is not None and self.bias_v is not None: + if static_k is None and static_v is None: + + # Explicitly assert that bias_k and bias_v are not None + # in a way that TorchScript can understand. + bias_k = self.bias_k + assert bias_k is not None + bias_v = self.bias_v + assert bias_v is not None + + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = nnF.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = nnF.pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert self.bias_k is None + assert self.bias_v is None + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * self.num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * self.num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:]) + if k.is_quantized: + k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype) + k = torch.cat([k, k_zeros], 
dim=1) + v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:]) + if v.is_quantized: + v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype) + v = torch.cat([v, v_zeros], dim=1) + + if attn_mask is not None: + attn_mask = nnF.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = nnF.pad(key_padding_mask, (0, 1)) + + # Leaving the quantized zone here + q = self.dequant_q(q) + k = self.dequant_k(k) + v = self.dequant_v(v) + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_output_weights = nnF.softmax( + attn_output_weights, dim=-1) + attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim] + if self.batch_first: + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + else: + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + + # Reentering the quantized zone + attn_output = self.quant_attn_output(attn_output) + # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969 + attn_output = self.out_proj(attn_output) # type: ignore[has-type] + attn_output_weights = self.quant_attn_output_weights(attn_output_weights) + + if need_weights: + # average attention weights over heads + 
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + return attn_output, attn_output_weights + else: + return attn_output, None diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9653cfbb675ef7b6e077b273c9774153148206bc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/functional.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7db4c5a88a8d80a24f2867966aeeb56d37e1c71d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b35df1c2f1ec0a94c0ac94c4bf32725487bbe74 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c4e8defc7765341cf83ca980fed4fed065d6e01 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d02ec5ae24f8098a14ce005e0564aa5703ed75 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ed4025fae9e8f3afe476784d9b80c2fc7d76309 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0fcbcff95b36ed43d8885f07ba4c236ae674c46e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a8817927fc549ba3ed01adc37f7a2a6c3acf149 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..6fcd223e50499a917e20ed8f0fcdc95736ee00fb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/activation.py @@ -0,0 +1,302 @@ +import torch +from warnings import warn +__all__ = [ + "ReLU6", + "Hardswish", + "ELU", + "LeakyReLU", + "Sigmoid", + "Softmax", + "MultiheadAttention", + "PReLU" +] + +class ReLU6(torch.nn.ReLU): + r"""Applies the element-wise function: + + :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the + zero_point, and :math:`q(6)` is the quantized representation of number 6. + + Args: + inplace: can optionally do the operation in-place. Default: ``False`` + + Shape: + - Input: :math:`(N, *)` where `*` means, any number of additional + dimensions + - Output: :math:`(N, *)`, same shape as the input + + .. 
image:: ../scripts/activation_images/ReLU6.png + + Examples:: + + >>> m = nn.quantized.ReLU6() + >>> input = torch.randn(2) + >>> # xdoctest: +SKIP + >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32) + >>> output = m(input) + """ + def __init__(self, inplace=False): + super().__init__(inplace) + self.inplace = inplace + + def forward(self, input): + return torch.ops.quantized.relu6(input, self.inplace) + + def _get_name(self): + return 'QuantizedReLU6' + + @staticmethod + def from_float(mod): + return ReLU6(mod.inplace) + +class Hardswish(torch.nn.Hardswish): + r"""This is the quantized version of :class:`~torch.nn.Hardswish`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + def __init__(self, scale, zero_point, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.register_buffer('scale', torch.tensor(scale, **factory_kwargs)) + self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.hardswish(input, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedHardswish' + + @staticmethod + def from_float(mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Hardswish(float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point)) + +class ELU(torch.nn.ELU): + r"""This is the quantized equivalent of :class:`~torch.nn.ELU`. 
+ + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + alpha: the alpha constant + """ + def __init__(self, scale, zero_point, alpha=1.): + super().__init__(alpha) + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + return torch.ao.nn.quantized.functional.elu( + input, self.scale, self.zero_point, self.alpha) + + def _get_name(self): + return 'QuantizedELU' + + @staticmethod + def from_float(mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return ELU(float(scale), int(zero_point), mod.alpha) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.alpha) + +class LeakyReLU(torch.nn.LeakyReLU): + r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + negative_slope: Controls the angle of the negative slope. 
Default: 1e-2 + """ + def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2, + inplace: bool = False, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(negative_slope, inplace) + self.register_buffer('scale', torch.tensor(scale, **factory_kwargs)) + self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs)) + + def forward(self, input): + return torch.ops.quantized.leaky_relu( + input, self.negative_slope, self.inplace, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedLeakyReLU' + + @classmethod + def from_float(cls, mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace) + +class Sigmoid(torch.nn.Sigmoid): + r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + + def __init__(self, output_scale: float, output_zero_point: int): + super().__init__() + self.output_scale = output_scale + self.output_zero_point = output_zero_point + + def forward(self, input): + return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point) + + @classmethod + def from_float(cls, mod): + output_scale, output_zero_point = mod.activation_post_process.calculate_qparams() + return cls(float(output_scale), int(output_zero_point)) + +class Softmax(torch.nn.Softmax): + r"""This is the quantized version of :class:`~torch.nn.Softmax`. + + Args: + dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1). 
+ scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + """ + def __init__(self, dim=None, scale=1.0, zero_point=0): + super().__init__() + self.dim = dim + self.scale = scale + self.zero_point = zero_point + + def forward(self, input): + dim = self.dim + if dim is None: + stacklevel = 3 + # Note: adding the mypy ignore on _get_softmax_dim seems less bad + # than making `_get_softmax_dim` an official API. + dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined] + "softmax", input.dim(), stacklevel) + return torch.ops.quantized.softmax( + input, dim, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedSoftmax' + + @staticmethod + def from_float(mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + return Softmax(mod.dim, float(scale), int(zero_point)) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.dim, float(scale), int(zero_point)) + + +class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention): + _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention + + def _get_name(self): + return "QuantizedMultiheadAttention" + + @classmethod + def from_float(cls, other): + # The whole flow is float -> observed -> quantized + # This class does observed -> quantized only + raise NotImplementedError("It looks like you are trying to convert a " + "non-observed MHA module. Please, see " + "the examples on quantizable MHAs.") + + @classmethod + def from_observed(cls, other): + converted = torch.ao.quantization.convert(other, mapping=None, + inplace=False, + remove_qconfig=True, + convert_custom_config_dict=None) + converted.__class__ = cls + # Remove the parameters for the bias_k and bias_v to quantize them + # TODO: This is a potential source of accuracy drop. 
+ # quantized cat takes the scale and zp of the first + # element, which might lose the precision in the bias_k + # and the bias_v (which are cat'ed with k/v being first). + if converted.bias_k is not None: + bias_k = converted._parameters.pop('bias_k') + sc, zp = torch._choose_qparams_per_tensor(bias_k, + reduce_range=False) + bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8) + setattr(converted, 'bias_k', bias_k) # noqa: B010 + + if converted.bias_v is not None: + bias_v = converted._parameters.pop('bias_v') + sc, zp = torch._choose_qparams_per_tensor(bias_k, # type: ignore[possibly-undefined] + reduce_range=False) + bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8) + setattr(converted, 'bias_v', bias_v) # noqa: B010 + + del converted.in_proj_weight + del converted.in_proj_bias + + return converted + +class PReLU(torch.nn.Module): + r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`. + + Args: + scale: quantization scale of the output tensor + zero_point: quantization zero point of the output tensor + num_parameters: number of parameters: 1, or the number of channels at input. 
Default: 1 + """ + def __init__(self, output_scale: float, output_zero_point: int, + num_parameters: int = 1) -> None: + super().__init__() + self.num_parameters = num_parameters + self.scale = output_scale + self.zero_point = output_zero_point + w = torch.randn(num_parameters, dtype=torch.float) + qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8) + self.set_weight(qw) + + def set_weight(self, w: torch.Tensor) -> None: + self.weight = w + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point) + + def _get_name(self): + return 'QuantizedPReLU' + + @classmethod + def from_float(cls, mod): + scale, zero_point = mod.activation_post_process.calculate_qparams() + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8) + qprelu.set_weight(qweight) + return qprelu + + @classmethod + def from_reference(cls, mod, scale, zero_point): + qprelu = cls(float(scale), int(zero_point), mod.num_parameters) + float_wt = mod.weight.float() + observer = mod.qconfig.weight() + observer(float_wt) + if observer.dtype != torch.quint8: + warn( + f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}" + ) + wt_scale, wt_zp = observer.calculate_qparams() + qweight = torch.quantize_per_tensor( + float_wt, float(wt_scale), int(wt_zp), torch.quint8) + qprelu.set_weight(qweight) + return qprelu diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1a51ee9c3b18863c9f205f2f47a6e06a380419 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/conv.py @@ -0,0 +1,945 @@ +r"""Quantized convolution modules.""" + +from typing import Optional, List, TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat + +from torch._ops import ops +from torch.nn.common_types import _size_1_t +from torch.nn.modules.utils import _single, _pair, _triple +from torch.nn.utils import fuse_conv_bn_weights + +from .utils import _quantize_weight, WeightedQuantizedModule + +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] + +_SUPPORTED_PADDING = { + 'zeros', + 'reflect' +} + + +def _reverse_repeat_padding(padding: List[int]) -> List[int]: + _reversed_padding_repeated_twice: List[int] = [] + N = len(padding) + for idx in range(N): + for _ in range(2): + _reversed_padding_repeated_twice.append(padding[N - idx - 1]) + return _reversed_padding_repeated_twice + + +class _ConvNd(WeightedQuantizedModule): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + # All subclasses have this signature - See PR #49702s + raise NotImplementedError + + def _init(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, + transposed, output_padding, + groups, bias, + padding_mode='zeros', + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise 
ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + if padding_mode not in _SUPPORTED_PADDING: + raise ValueError(f"'padding_mode' {padding_mode} is not supported by quantized convolution") + self.padding_mode = padding_mode + # Initialize as NCHW. set_weight will internally transpose to NHWC. + if self.transposed: + weight_shape = [in_channels, out_channels // self.groups] + else: + weight_shape = [out_channels, in_channels // self.groups] + qweight = torch._empty_affine_quantized( + weight_shape + list(kernel_size), + scale=1, zero_point=0, dtype=torch.qint8, + **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) + bias_float = ( + torch.zeros(out_channels, dtype=torch.float, + **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None) + + self.set_weight_bias(qweight, bias_float) + self.scale = 1.0 + self.zero_point = 0 + + def set_weight_bias(self, qweight, bias_float): + raise NotImplementedError + + def bias(self): + raise NotImplementedError + + def _weight_bias(self): + raise NotImplementedError + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}, scale={scale}, zero_point={zero_point}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias() is None: + s += ', bias=False' + return s.format(**self.__dict__) + + # ===== Serialization methods ===== + # The special consideration here is that we have to unpack the weights into + # their 
regular QTensor form for serialization. Packed weights should not + # live outside the process in which they were created, rather they should be + # derived from the QTensor weight. + # self + # |--- weight : Tensor + # |--- bias : Tensor + # + # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed + # self + # |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + (w, b) = self._weight_bias() + destination[prefix + 'weight'] = w + destination[prefix + 'bias'] = b + destination[prefix + 'scale'] = torch.tensor(self.scale) + destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) + + @torch.jit.export + def __getstate__(self): + (w, b) = self._weight_bias() + return ( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.transposed, + self.output_padding, + self.groups, + self.padding_mode, + w, + b, + self.scale, + self.zero_point, + self.training + ) + + # ===== Deserialization methods ===== + # Counterpart to the serialization methods, we must pack the serialized + # QTensor weight into its packed format for use by the FBGEMM ops. 
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + self.set_weight_bias( + state_dict[prefix + 'weight'], state_dict[prefix + 'bias']) + state_dict.pop(prefix + 'weight') + state_dict.pop(prefix + 'bias') + self.scale = float(state_dict[prefix + 'scale']) + state_dict.pop(prefix + 'scale') + self.zero_point = int(state_dict[prefix + 'zero_point']) + state_dict.pop(prefix + 'zero_point') + super()._load_from_state_dict( + state_dict, prefix, local_metadata, False, missing_keys, + unexpected_keys, error_msgs) + + @torch.jit.export + def __setstate__(self, state): + self.in_channels = state[0] + self.out_channels = state[1] + self.kernel_size = state[2] + self.stride = state[3] + self.padding = state[4] + self.dilation = state[5] + self.transposed = state[6] + self.output_padding = state[7] + self.groups = state[8] + self.padding_mode = state[9] + self.set_weight_bias(state[10], state[11]) + self.scale = state[12] + self.zero_point = state[13] + self.training = state[14] + + def __deepcopy__(self, memo): + new_instance = type(self).__new__(type(self)) + torch.nn.Module.__init__(new_instance) + state = self.__getstate__() + new_instance.__setstate__(state) + return new_instance + + def __copy__(self): + return self.__deepcopy__({}) + + @classmethod + def get_qconv(cls, mod, activation_post_process, weight_post_process=None): + r"""Creates a qconv object and returns it. 
+ """ + if weight_post_process is None: + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, \ + 'Weight observer must have a dtype of qint8' + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, + mod.stride, mod.padding, mod.dilation, mod.groups, + mod.bias is not None, mod.padding_mode) + qconv.set_weight_bias(qweight, mod.bias) + if activation_post_process is None or activation_post_process.dtype == torch.float: + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = activation_post_process.calculate_qparams() + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_float(cls, mod): + if hasattr(mod, "weight_fake_quant"): + # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \ + # ".from_float only works for " + cls.__QAT_MODULE.__name__ + if type(mod) == cls._NNIQAT_CONV_BN_MODULE: + mod.weight, mod.bias = fuse_conv_bn_weights( + mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var, + mod.bn.eps, mod.bn.weight, mod.bn.bias) + assert hasattr(mod, "activation_post_process"), \ + "Input QAT module must have observer attached" + weight_post_process = mod.weight_fake_quant + activation_post_process = mod.activation_post_process + else: + assert type(mod) == cls._FLOAT_MODULE, \ + " nnq." + cls.__name__ + ".from_float only works for " + \ + cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod)) + assert hasattr(mod, "qconfig"), \ + "Input float module must have qconfig defined." 
+ activation_post_process = None if not hasattr( + mod, "activation_post_process") else mod.activation_post_process + if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]: + mod = mod[0] + weight_post_process = mod.qconfig.weight() + return cls.get_qconv(mod, activation_post_process, weight_post_process) + + @classmethod + def from_reference(cls, ref_qconv, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconv.in_channels, + ref_qconv.out_channels, + ref_qconv.kernel_size, # type: ignore[arg-type] + ref_qconv.stride, # type: ignore[arg-type] + ref_qconv.padding, # type: ignore[arg-type] + ref_qconv.dilation, # type: ignore[arg-type] + ref_qconv.groups, + ref_qconv.bias is not None, # type: ignore[arg-type] + ref_qconv.padding_mode, + device=ref_qconv.weight.device, + dtype=ref_qconv.weight.dtype) + qweight = ref_qconv.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconv.bias) + qconv.scale = float(output_scale) + qconv.zero_point = int(output_zero_point) + return qconv + + +class Conv1d(_ConvNd): + r"""Applies a 1D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv1d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. 
+ scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv1d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, + ... dtype=torch.quint8) + >>> output = m(q_input) + + """ + + _FLOAT_MODULE = nn.Conv1d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: None = None + _NNI_CONV_ADD_RELU_MODULE: None = None + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = padding if isinstance(padding, str) else _single(padding) + dilation = _single(dilation) + + # Subclasses of _ConvNd needs to call _init rather than __init__. 
See + # discussion on PR #49702 + super()._init( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _single(0), groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConv1d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv1d_prepack( + w, b, self.stride, _pair(0), self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv1d_unpack(self._packed_params) + return w, b + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + if self.padding_mode != 'zeros': + # Padding in Conv1d is stored as (p, p), need to get (p,) + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1]) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_float(cls, mod): + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float(cls, mod) + + +class Conv2d(_ConvNd): + r"""Applies a 2D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv2d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. 
note:: + Only `torch.quint8` is supported for the input data type. + + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + _FLOAT_MODULE = nn.Conv2d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. 
See + # discussion on PR #49702 + super()._init( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConv2d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv2d_prepack( + w, b, self.stride, _pair(0), self.dilation, self.groups) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return ops.quantized.conv2d( + input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_float(cls, mod): + r"""Creates a quantized module from a float module or qparams_dict. + + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float(cls, mod) + + +class Conv3d(_ConvNd): + r"""Applies a 3D convolution over a quantized input signal composed of + several quantized input planes. + + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.Conv3d`. + + .. note:: + Only `zeros` is supported for the :attr:`padding_mode` argument. + + .. note:: + Only `torch.quint8` is supported for the input data type. 
+ + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + + See :class:`~torch.nn.Conv3d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # With square kernels and equal stride + >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2)) + >>> input = torch.randn(20, 16, 56, 56, 56) + >>> # quantize input to quint8 + >>> # xdoctest: +SKIP + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + + """ + _FLOAT_MODULE = nn.Conv3d + _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: None = None + _NNI_CONV_ADD_RELU_MODULE: None = None + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', device=None, dtype=None): + assert padding_mode != 'reflect', "Conv3d does not support reflection padding" + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + # Subclasses of _ConvNd need to call _init rather than __init__. 
See + # discussion on PR #49702 + super()._init( + in_channels, out_channels, kernel_size, stride, padding, dilation, + False, _triple(0), groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConv3d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + if self.padding_mode == 'zeros': + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) + else: + self._packed_params = torch.ops.quantized.conv3d_prepack( + w, b, self.stride, _triple(0), self.dilation, self.groups) + + def _weight_bias(self): + return self._packed_params.unpack() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, D, H, W)`!") + if self.padding_mode != 'zeros': + _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding) + input = F.pad(input, _reversed_padding_repeated_twice, + mode=self.padding_mode) + return ops.quantized.conv3d( + input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_float(cls, mod): + r"""Creates a quantized module from a float module or qparams_dict. 
+ + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + return _ConvNd.from_float(cls, mod) + +# === Transposed Convolutions === +MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd) + + +class _ConvTransposeNd(_ConvNd): + + _FLOAT_MODULE = MOD + + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, device=None, dtype=None): + if padding_mode != 'zeros': + raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}') + factory_kwargs = {'device': device, 'dtype': dtype} + # Subclasses of _ConvNd need to call _init rather than __init__. See + # discussion on PR #49702 + super()._init( + in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, **factory_kwargs) + + def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]: + res = torch.jit.annotate(List[int], []) + for kdx in range(len(kernel_size)): + pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx]) + res.append(pad) + return res + + @classmethod + def from_float(cls, mod): + r"""Creates a quantized module from a float module or qparams_dict. + Args: + mod (Module): a float module, either produced by torch.ao.quantization + utilities or provided by the user + """ + # derived classes override cls._FLOAT_MODULE attribute + msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined] + assert type(mod) == cls._FLOAT_MODULE, msg + assert hasattr(mod, 'qconfig'), \ + 'Input float module must have qconfig defined.' 
+ weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + assert weight_post_process.dtype == torch.qint8, \ + 'Weight observer must have a dtype of qint8' + qweight = _quantize_weight(mod.weight.float(), weight_post_process) + # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd + qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg] + mod.stride, mod.padding, mod.output_padding, mod.groups, + mod.bias is not None, mod.dilation, mod.padding_mode) + qconv.set_weight_bias(qweight, mod.bias) + if not hasattr(mod, "activation_post_process") or mod.activation_post_process.dtype == torch.float: + return qconv # dynamic quantization doesn't need scale/zero_point + else: + act_scale, act_zp = mod.activation_post_process.calculate_qparams() + qconv.scale = float(act_scale) + qconv.zero_point = int(act_zp) + return qconv + + @staticmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module + Args: + ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization + utilities or provided by the user + output_scale (float): scale for output Tensor + output_zero_point (int): zero point for output Tensor + """ + qconv = cls( + ref_qconvt.in_channels, + ref_qconvt.out_channels, + ref_qconvt.kernel_size, # type: ignore[arg-type] + ref_qconvt.stride, # type: ignore[arg-type] + ref_qconvt.padding, # type: ignore[arg-type] + ref_qconvt.output_padding, # type: ignore[arg-type] + ref_qconvt.groups, + ref_qconvt.bias is not None, # type: ignore[arg-type] + ref_qconvt.dilation, # type: ignore[arg-type] + ref_qconvt.padding_mode, + device=ref_qconvt.weight.device, + dtype=ref_qconvt.weight.dtype) + qweight = ref_qconvt.get_quantized_weight() + qconv.set_weight_bias(qweight, ref_qconvt.bias) + qconv.scale = float(output_scale) + qconv.zero_point = 
int(output_zero_point) + return qconv + + +class ConvTranspose1d(_ConvTransposeNd): + r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose1d`. + + .. note:: Currently only the QNNPACK engine is implemented. + Please, set the `torch.backends.quantized.engine = 'qnnpack'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'qnnpack' + >>> from torch.ao.nn import quantized as nnq + >>> # With square kernels and equal stride + >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose1d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, 
+ padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros', device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConvTranspose1d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose1d_prepack( + w, b, self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + return torch.ops.quantized.conv_transpose1d( + input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point) + + +class ConvTranspose2d(_ConvTransposeNd): + r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose2d`. + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. 
+ scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose2d` for other attributes. + + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> # QNNPACK or FBGEMM as backend + >>> torch.backends.quantized.engine = 'qnnpack' + >>> # With square kernels and equal stride + >>> import torch.ao.nn.quantized as nnq + >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose2d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros', device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConvTranspose2d' 
+ + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose2d_prepack( + w, b, self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv2d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + return ops.quantized.conv_transpose2d( + input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point) + + +class ConvTranspose3d(_ConvTransposeNd): + r"""Applies a 3D transposed convolution operator over an input image + composed of several input planes. + For details on input arguments, parameters, and implementation see + :class:`~torch.nn.ConvTranspose3d`. + + .. note:: Currently only the FBGEMM engine is implemented. + Please, set the `torch.backends.quantized.engine = 'fbgemm'` + + For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d` + + Attributes: + weight (Tensor): packed tensor derived from the learnable weight + parameter. + scale (Tensor): scalar for the output scale + zero_point (Tensor): scalar for the output zero point + See :class:`~torch.nn.ConvTranspose3d` for other attributes. 
+ + Examples:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE) + >>> torch.backends.quantized.engine = 'fbgemm' + >>> from torch.ao.nn import quantized as nnq + >>> # With cubic kernels and equal stride + >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-cubic kernels and unequal stride and with padding + >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2)) + >>> input = torch.randn(20, 16, 50, 100, 100) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> output = m(q_input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12, 12) + >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8) + >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1) + >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(q_input) + >>> h.size() + torch.Size([1, 16, 6, 6, 6]) + >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12, 12]) + """ + + _FLOAT_MODULE = nn.ConvTranspose3d + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, output_padding=0, groups=1, bias=True, + dilation=1, padding_mode='zeros', device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def _get_name(self): + return 'QuantizedConvTranspose3d' + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None: + self._packed_params = torch.ops.quantized.conv_transpose3d_prepack( + w, b, 
self.stride, self.padding, self.output_padding, self.dilation, + self.groups) + + def _weight_bias(self): + w, b = torch.ops.quantized.conv3d_unpack(self._packed_params) + return w, b + + def weight(self): + (w, _) = self._weight_bias() + return w + + def bias(self): + (_, b) = self._weight_bias() + return b + + def forward(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, T, H, W)`!") + return ops.quantized.conv_transpose3d( + input, self._packed_params, self.scale, self.zero_point) + + @classmethod + def from_reference(cls, ref_qconvt, output_scale, output_zero_point): + return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..64110ab53bed9a7620804a7732caebf36c84d2ca --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/dropout.py @@ -0,0 +1,27 @@ +import torch + +__all__ = ['Dropout'] + +class Dropout(torch.nn.Dropout): + r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`. + And this is a placeholder to enable models where fp32 tensors + had dropout to work with quantized tensors in train and eval mode. + + Args: + p: probability of an element to be zeroed + inplace: can optionally do the operation in-place. 
Default: ``False`` + """ + + def forward(self, input): + return input + + def _get_name(self): + return 'QuantizedDropout' + + @classmethod + def from_float(cls, mod): + return cls(mod.p, mod.inplace) + + @classmethod + def from_reference(cls, mod, scale, zero_point): + return cls(mod.p, mod.inplace) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..96408457a4497c72236986122c0cc3204fad4437 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/functional_modules.py @@ -0,0 +1,249 @@ +from typing import List + +import torch +from torch import Tensor +from torch._ops import ops + +__all__ = ['FloatFunctional', 'FXFloatFunctional', 'QFunctional'] + +class FloatFunctional(torch.nn.Module): + r"""State collector class for float operations. + + The instance of this class can be used instead of the ``torch.`` prefix for + some operations. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). + + Examples:: + + >>> f_add = FloatFunctional() + >>> a = torch.tensor(3.0) + >>> b = torch.tensor(4.0) + >>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + def __init__(self): + super().__init__() + self.activation_post_process = torch.nn.Identity() + + def forward(self, x): + raise RuntimeError("FloatFunctional is not intended to use the " + + "'forward'. 
Please use the underlying operation") + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.cat``""" + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + r = self.activation_post_process(r) + return r + +class FXFloatFunctional(torch.nn.Module): + r""" module to replace FloatFunctional module before FX graph mode quantization, + since activation_post_process will be inserted in top level module directly + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + def forward(self, x): + raise RuntimeError("FloatFunctional is not intended to use the " + + "'forward'. 
Please use the underlying operation") + + r"""Operation equivalent to ``torch.add(Tensor, Tensor)``""" + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.add(Tensor, float)``""" + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.add(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``""" + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.mul(Tensor, float)``""" + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = torch.mul(x, y) + return r + + r"""Operation equivalent to ``torch.cat``""" + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: + r = torch.cat(x, dim=dim) + return r + + r"""Operation equivalent to ``relu(torch.add(x,y))``""" + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.add(x, y) + r = torch.nn.functional.relu(r) + return r + + r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``""" + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.matmul(x, y) + return r + +class QFunctional(torch.nn.Module): + r"""Wrapper class for quantized operations. + + The instance of this class can be used instead of the + ``torch.ops.quantized`` prefix. See example usage below. + + .. note:: + + This class does not provide a ``forward`` hook. Instead, you must use + one of the underlying functions (e.g. ``add``). 
+ + Examples:: + + >>> q_add = QFunctional() + >>> # xdoctest: +SKIP + >>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32) + >>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32) + >>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)`` + + Valid operation names: + - add + - cat + - mul + - add_relu + - add_scalar + - mul_scalar + """ + def __init__(self): + super().__init__() + self.scale = 1.0 + self.zero_point = 0 + self.activation_post_process = torch.nn.Identity() + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + 'scale'] = torch.tensor(self.scale) + destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + + self.scale = float(state_dict.pop(prefix + 'scale')) + self.zero_point = int(state_dict.pop(prefix + 'zero_point')) + super()._load_from_state_dict(state_dict, prefix, local_metadata, False, + missing_keys, unexpected_keys, error_msgs) + + def _get_name(self): + return 'QFunctional' + + def extra_repr(self): + return f'scale={self.scale}, zero_point={self.zero_point}' + + def forward(self, x): + raise RuntimeError("Functional is not intended to use the " + + "'forward'. Please use the underlying operation") + + r"""Operation equivalent to ``torch.ops.quantized.add``""" + def add(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``""" + def add_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.add_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. 
+ return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``""" + def mul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``""" + def mul_scalar(self, x: Tensor, y: float) -> Tensor: + r = ops.quantized.mul_scalar(x, y) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. + return r + + r"""Operation equivalent to ``torch.ops.quantized.cat``""" + def cat(self, x: List[Tensor], dim: int = 0) -> Tensor: + r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.add_relu``""" + def add_relu(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point) + r = self.activation_post_process(r) + return r + + r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``""" + def matmul(self, x: Tensor, y: Tensor) -> Tensor: + r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point) + # Note: this operation is not observed because the observation is not + # needed for the quantized op. 
+ return r + + @classmethod + def from_float(cls, mod): + assert type(mod) == FloatFunctional, \ + "QFunctional.from_float expects an instance of FloatFunctional" + scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator] + new_mod = QFunctional() + new_mod.scale = float(scale) + new_mod.zero_point = int(zero_point) + return new_mod diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..884720774c5f83c9122cf4de43a4265c9c6afb59 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__init__.py @@ -0,0 +1,21 @@ +from .linear import Linear +from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d +from .rnn import RNNCell, LSTMCell, GRUCell, LSTM, GRU +from .sparse import Embedding, EmbeddingBag + +__all__ = [ + 'Linear', + 'Conv1d', + 'Conv2d', + 'Conv3d', + 'ConvTranspose1d', + 'ConvTranspose2d', + 'ConvTranspose3d', + 'RNNCell', + 'LSTMCell', + 'GRUCell', + 'LSTM', + 'GRU', + 'Embedding', + 'EmbeddingBag', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f3aa38f724771dd5d094da77bcaafc7f713ae7e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28c60342afb4c06fff32063fd949592d0aec90e0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4120338ce271af197a83ba2de8a767ed5ffe3716 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/rnn.py @@ -0,0 +1,614 @@ +import torch +import torch.nn as nn +from torch import Tensor +from .utils import _quantize_and_dequantize_weight +from .utils import _quantize_weight +from typing import Optional, Dict, Any, Tuple +from torch import _VF +from torch.nn.utils.rnn import PackedSequence + +__all__ = ['RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell', 'RNNBase', 'LSTM', 'GRU', 'get_quantized_weight'] + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + +def _get_weight_and_quantization_params(module, wn): + weight = getattr(module, wn) + params = [weight] + for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]]: + if hasattr(module, param_name): + param = getattr(module, param_name) + else: + param = None + params.append(param) + return params + +def get_quantized_weight(module, wn): + if not hasattr(module, wn): + return None 
+ params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_weight(*params) + return weight + +def _get_quantize_and_dequantized_weight(module, wn): + if not hasattr(module, wn): + return None + params = _get_weight_and_quantization_params(module, wn) + weight = _quantize_and_dequantize_weight(*params) + return weight + +class RNNCellBase(nn.RNNCellBase): + def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, + device=None, dtype=None, weight_qparams_dict=None) -> None: + super().__init__(input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + weight_qparams_dict = { + "weight_ih": weight_qparams, + "weight_hh": weight_qparams, + "is_decomposed": False, + } + assert len(weight_qparams_dict) == 3, "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)" + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + assert weight_qparams_dict is not None + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + # TODO: refactor the duplicated code to utils.py + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}") + if weight_qscheme is not None: + scale = weight_qparams["scale"] + scale_tensor = scale.clone().detach() \ + if isinstance(scale, torch.Tensor) else \ + torch.tensor(scale, dtype=torch.float, 
device=device) + self.register_buffer(key + "_scale", scale_tensor) + zp = weight_qparams["zero_point"] + zp_tensor = zp.clone().detach() \ + if isinstance(zp, torch.Tensor) else \ + torch.tensor(zp, dtype=torch.int, device=device) + self.register_buffer(key + "_zero_point", zp_tensor) + if weight_qscheme == torch.per_channel_affine: + axis = weight_qparams["axis"] + axis_tensor = axis.clone().detach() \ + if isinstance(axis, torch.Tensor) else \ + torch.tensor(axis, dtype=torch.int, device=device) + self.register_buffer(key + "_axis", axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device)) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + + def _get_name(self): + return "QuantizedRNNCellBase(Reference)" + + def get_quantized_weight_ih(self): + return get_quantized_weight(self, "weight_ih") + + def get_quantized_weight_hh(self): + return get_quantized_weight(self, "weight_hh") + + def get_weight_ih(self): + return _get_quantize_and_dequantized_weight(self, "weight_ih") + + def get_weight_hh(self): + return _get_quantize_and_dequantized_weight(self, "weight_hh") + +class RNNCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. 
weight_ih, to the weight_qparams for that weight + """ + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", + device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def _get_name(self): + return "QuantizedRNNCell(Reference)" + + # TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input + # and remove duplicated code, same for the other two Cell modules + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in (1, 2), \ + f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, hx, + self.get_weight_ih(), self.get_weight_hh(), + self.bias_ih, self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, hx, + self.get_weight_ih(), self.get_weight_hh(), + self.bias_ih, self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError( + f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.nonlinearity, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = 
mod.bias_hh + return ref_mod + +class LSTMCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. weight_ih, to the weight_qparams for that weight + """ + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def _get_name(self): + return "QuantizedLSTMCell(Reference)" + + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + assert input.dim() in (1, 2), \ + f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, hx, + self.get_weight_ih(), self.get_weight_hh(), + self.bias_ih, self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + +class GRUCell(RNNCellBase): + """ + We'll store weight_qparams for all the weights (weight_ih and weight_hh), + we need to pass in a `weight_qparams_dict` that maps from weight name, + e.g. 
weight_ih, to the weight_qparams for that weight + """ + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def _get_name(self): + return "QuantizedGRUCell(Reference)" + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + assert input.dim() in (1, 2), \ + f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor" + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, hx, + self.get_weight_ih(), self.get_weight_hh(), + self.bias_ih, self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.bias, + mod.weight_ih.device, + mod.weight_ih.dtype, + weight_qparams_dict) + ref_mod.weight_ih = mod.weight_ih + ref_mod.weight_hh = mod.weight_hh + ref_mod.bias_ih = mod.bias_ih + ref_mod.bias_hh = mod.bias_hh + return ref_mod + +class RNNBase(nn.RNNBase): + def __init__(self, mode: str, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, batch_first: bool = False, + dropout: float = 0., bidirectional: bool = False, proj_size: int = 0, + device=None, dtype=None, + weight_qparams_dict: Optional[Dict[str, Any]] = None) -> None: + super().__init__( + mode, input_size, hidden_size, num_layers, bias, batch_first, dropout, + bidirectional, proj_size, device, dtype + ) + # TODO(jerryzh168): maybe make this arg a required arg + if weight_qparams_dict is None: 
+ weight_qparams = { + 'qscheme': torch.per_tensor_affine, + 'dtype': torch.quint8, + 'scale': 1.0, + 'zero_point': 0 + } + weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item] + for wn in self._flat_weights_names: + if wn.startswith("weight"): + weight_qparams_dict[wn] = weight_qparams + self._init_weight_qparams_dict(weight_qparams_dict, device) + + def _init_weight_qparams_dict(self, weight_qparams_dict, device): + self.is_decomposed = weight_qparams_dict["is_decomposed"] + for key, weight_qparams in weight_qparams_dict.items(): + if key == "is_decomposed": + continue + weight_qscheme = weight_qparams["qscheme"] + weight_dtype = weight_qparams["dtype"] + setattr(self, key + "_qscheme", weight_qscheme) + setattr(self, key + "_dtype", weight_dtype) + assert weight_qscheme in [None, torch.per_tensor_affine, torch.per_channel_affine], \ + Exception(f"qscheme: {weight_qscheme} is not support in {self._get_name()}") + if weight_qscheme is not None: + self.register_buffer( + key + "_scale", + torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device)) + self.register_buffer( + key + "_zero_point", + torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device)) + if weight_qscheme == torch.per_channel_affine: + self.register_buffer( + key + "_axis", + torch.tensor(weight_qparams["axis"], dtype=torch.int, device=device)) + else: + # added for TorchScriptability, not used + self.register_buffer( + key + "_axis", torch.tensor(0, dtype=torch.int, device=device)) + setattr(self, key + "_axis_int", getattr(self, key + "_axis").item()) + +class LSTM(RNNBase): + """ Reference Quantized LSTM Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. 
weight_ih_l0, + to the weight_qparams for that weight + """ + def __init__(self, *args, **kwargs): + super().__init__('LSTM', *args, **kwargs) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden(self, # type: ignore[override] + hx: Tuple[Tensor, Tensor], + permutation: Optional[Tensor] + ) -> Tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) + + def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args(self, # type: ignore[override] + input: Tensor, + hidden: Tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes), + 'Expected hidden[0] size {}, got {}') + self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes), + 'Expected hidden[1] size {}, got {}') + + def get_quantized_weight_bias_dict(self): + """ dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... 
+ } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + h_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, real_hidden_size, + dtype=input.dtype, device=input.device) + c_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = (h_zeros, c_zeros) + else: + if batch_sizes is None: # If not PackedSequence input. 
+ if is_batched: # type: ignore[possibly-undefined] + if (hx[0].dim() != 3 or hx[1].dim() != 3): + msg = ("For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ("For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.lstm(input, hx, self.get_flat_weights(), self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, self.batch_first) + else: + result = _VF.lstm(input, batch_sizes, hx, self.get_flat_weights(), self.bias, + self.num_layers, self.dropout, self.training, self.bidirectional) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedLSTM(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, 
getattr(mod, wn)) + return ref_mod + +class GRU(RNNBase): + """ Reference Quantized GRU Module + We'll store weight_qparams for all the weights in _flat_weights, we need to pass in + a `weight_qparams_dict` that maps from weight name, e.g. weight_ih_l0, + to the weight_qparams for that weight + """ + def __init__(self, *args, **kwargs): + if 'proj_size' in kwargs: + raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") + super().__init__('GRU', *args, **kwargs) + + def get_quantized_weight_bias_dict(self): + """ dictionary from flat_weight_name to quantized weight or (unquantized) bias + e.g. + { + "weight_ih_l0": quantized_weight, + "bias_ih_l0": unquantized_bias, + ... + } + """ + quantized_weight_bias_dict = {} + for wn in self._flat_weights_names: + if hasattr(self, wn): + if wn.startswith("weight"): + weight_or_bias = get_quantized_weight(self, wn) + else: + weight_or_bias = getattr(self, wn) + else: + weight_or_bias = None + quantized_weight_bias_dict[wn] = weight_or_bias + return quantized_weight_bias_dict + + def get_flat_weights(self): + flat_weights = [] + for wn in self._flat_weights_names: + if hasattr(self, wn): + weight = getattr(self, wn) + if wn.startswith("weight"): + params = _get_weight_and_quantization_params(self, wn) + weight = _quantize_and_dequantize_weight(*params) + else: + weight = None + flat_weights.append(weight) + return flat_weights + + def forward(self, input, hx=None): # noqa: F811 + # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py + # only changed self._flat_weights to self.get_flat_weights() + # TODO: maybe we can try inheriting from that class and define get_flat_weights + # as a @property? 
this might interfere with TorchScript, if we remove that + # requirement in the future we should be able to do this + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = int(batch_sizes[0]) + else: + batch_sizes = None + assert (input.dim() in (2, 3)), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor" + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. 
+ hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru(input, hx, self.get_flat_weights(), self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, self.batch_first) + else: + result = _VF.gru(input, batch_sizes, hx, self.get_flat_weights(), self.bias, + self.num_layers, self.dropout, self.training, self.bidirectional) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + def _get_name(self): + return "QuantizedGRU(Reference)" + + @classmethod + def from_float(cls, mod, weight_qparams_dict): + ref_mod = cls( + mod.input_size, + mod.hidden_size, + mod.num_layers, + mod.bias, + mod.batch_first, + mod.dropout, + mod.bidirectional, + weight_qparams_dict=weight_qparams_dict) + for wn in mod._flat_weights_names: + setattr(ref_mod, wn, getattr(mod, wn)) + return ref_mod diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2c1f52cdf884f6b2c032469c86e85d566e4b216f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/utils.py @@ -0,0 +1,323 @@ +import torch +import typing + +__all__ = [ + "ReferenceQuantizedModule", +] + +class 
ReferenceQuantizedModule(torch.nn.Module): + def _init_weight_qparams(self, weight_qparams, device): + if weight_qparams is None: + weight_qparams = { + "qscheme": torch.per_tensor_affine, + "dtype": torch.quint8, + "scale": 1.0, + "zero_point": 0 + } + self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"] + self.weight_dtype = weight_qparams["dtype"] + assert self.weight_qscheme in [ + None, torch.per_tensor_affine, torch.per_channel_affine, + torch.per_channel_affine_float_qparams], \ + Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}") + if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]: + zero_point_dtype = weight_qparams["zero_point"].dtype if \ + isinstance(weight_qparams["zero_point"], torch.Tensor) else \ + torch.int + w_scale = weight_qparams["scale"] + w_scale_tensor = w_scale.clone().detach() \ + if isinstance(w_scale, torch.Tensor) \ + else torch.tensor(w_scale, dtype=torch.float, device=device) + self.register_buffer("weight_scale", w_scale_tensor) + w_zp = weight_qparams["zero_point"] + w_zp_tensor = w_zp.clone().detach() \ + if isinstance(w_zp, torch.Tensor) \ + else torch.tensor(w_zp, dtype=zero_point_dtype, device=device) + self.register_buffer("weight_zero_point", w_zp_tensor) + if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]: + w_axis = weight_qparams["axis"] + w_axis_tensor = w_axis.clone().detach() \ + if isinstance(w_axis, torch.Tensor) \ + else torch.tensor(w_axis, dtype=torch.int, device=device) + self.register_buffer("weight_axis", w_axis_tensor) + else: + # added for TorchScriptability, not used + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device)) + else: + # added for TorchScriptability, and for torch.float + self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device)) + self.register_buffer("weight_zero_point", torch.tensor(0, 
dtype=torch.int, device=device)) + self.register_buffer( + "weight_axis", torch.tensor(0, dtype=torch.int, device=device)) + self.is_decomposed: bool = weight_qparams.get("is_decomposed", False) + # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export + # for capturing `.item` operations + self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment] + self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min", None) + self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max", None) + + def get_weight(self): + """ + Fake quantize (quantize and dequantize) the weight with + the quantization parameters for weight, this is used to + simulate the numerics for the quantized weight in a quantized + model + """ + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + if self.is_decomposed: + return _quantize_and_dequantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max) + else: + return _quantize_and_dequantize_weight( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int) + + def get_quantized_weight(self): + # suppress mypy warning + assert isinstance(self.weight_scale, torch.Tensor) + assert isinstance(self.weight_zero_point, torch.Tensor) + # assert isinstance(self.weight_axis, torch.Tensor) + if self.is_decomposed: + return _quantize_weight_decomposed( + self.weight, # type: ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int, + self.weight_quant_min, + self.weight_quant_max) + else: + return _quantize_weight( + self.weight, # type: 
ignore[arg-type] + self.weight_qscheme, + self.weight_dtype, + self.weight_scale, + self.weight_zero_point, + self.weight_axis_int) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + _save_weight_qparams( + destination, prefix, self.weight_qscheme, self.weight_dtype, + self.weight_scale, self.weight_zero_point, self.weight_axis) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + for key in _get_weight_qparam_keys(state_dict, prefix): + setattr(self, key, state_dict[prefix + key]) + state_dict.pop(prefix + key) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, False, + missing_keys, unexpected_keys, error_msgs) + +def _quantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + _DTYPE_TO_QVALUE_BOUNDS = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (-(2**31), 2**31 - 1), + } + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + weight = torch.ops.quantized_decomposed.quantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_ + ) + return weight + elif weight_qscheme in [torch.per_channel_affine, 
torch.per_channel_affine_float_qparams]: + # TODO: torch.quint4x2 is not supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + weight = torch.ops.quantized_decomposed.quantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_) # type: ignore[arg-type] + return weight + raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + +def _dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + # TODO: get the quant_min and quant_max from activation_post_process + _DTYPE_TO_QVALUE_BOUNDS = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int32: (-(2**31), 2**31 - 1), + } + # TODO: add an util function for converting qdtype to dtype + _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { + torch.quint8: torch.uint8, + torch.qint8: torch.int8, + torch.qint32: torch.int32, + } + weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype] + if weight_quant_min is None or weight_quant_max is None: + weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_] + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_tensor( + weight, + weight_scale, + weight_zero_point, + weight_quant_min, + weight_quant_max, + weight_dtype_ + ) + return weight + elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]: + # TODO: torch.quint4x2 is not 
supported + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.ops.quantized_decomposed.dequantize_per_channel( + weight, + weight_scale, + weight_zero_point, + weight_axis, + weight_quant_min, + weight_quant_max, + weight_dtype_) # type: ignore[arg-type] + return weight + raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + +def _quantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int +) -> torch.Tensor: + if weight_dtype == torch.float16: + weight = weight.to(weight_dtype) + return weight + + if weight_qscheme == torch.per_tensor_affine: + if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]: + weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype) + return weight + elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]: + if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]: + weight = torch.quantize_per_channel( + weight, weight_scale, + weight_zero_point, weight_axis_int, weight_dtype) # type: ignore[arg-type] + return weight + raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}") + +def _quantize_and_dequantize_weight_decomposed( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int, + weight_quant_min: typing.Optional[int], + weight_quant_max: typing.Optional[int], +) -> torch.Tensor: + """ Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams]: + weight_quant = _quantize_weight_decomposed( + weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, 
weight_axis_int, + weight_quant_min, weight_quant_max) + weight_dequant = _dequantize_weight_decomposed( + weight_quant, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, + weight_axis_int, weight_quant_min, weight_quant_max) + else: + weight_dequant = weight + return weight_dequant + +def _quantize_and_dequantize_weight( + weight: torch.Tensor, + weight_qscheme: torch.qscheme, + weight_dtype: torch.dtype, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + weight_axis_int: int +) -> torch.Tensor: + """ Quantize and then dequantize the weight based on + the quantization parameters + """ + if weight_qscheme in [ + torch.per_tensor_affine, + torch.per_channel_affine, + torch.per_channel_affine_float_qparams]: + weight_quant = _quantize_weight( + weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int) + weight_dequant = weight_quant.dequantize() + else: + weight_dequant = weight + return weight_dequant + +def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis): + destination[prefix + "weight_qscheme"] = weight_qscheme + destination[prefix + "weight_dtype"] = weight_dtype + if weight_qscheme is not None: + destination[prefix + "weight_scale"] = weight_scale + destination[prefix + "weight_zero_point"] = weight_zero_point + if weight_qscheme == torch.per_channel_affine: + destination[prefix + "weight_axis"] = weight_axis + +def _get_weight_qparam_keys( + state_dict: typing.Dict[str, typing.Any], + prefix: str): + keys = ["weight_qscheme", "weight_dtype"] + weight_qscheme = state_dict[prefix + "weight_qscheme"] + if weight_qscheme is not None: + keys.append("weight_scale") + keys.append("weight_zero_point") + if weight_qscheme == torch.quantize_per_channel: + keys.append("weight_axis") + return keys diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1f1433ce3b8b30ad22b6e115696792e1573062 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..5347b682fb5a2bb430ab0b9a947f9ca0b830fb91 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/linear.py @@ -0,0 +1,139 @@ +from typing import Optional + +import torch +import torch.ao.nn.intrinsic as nni + +from torch.ao.nn.sparse.quantized import linear +from torch.ao.nn.sparse.quantized.utils import LinearBlockSparsePattern +from torch.ao.nn.quantized.modules.utils import _quantize_weight, _hide_packed_params_repr + +__all__ = ['Linear'] + +class Linear(torch.nn.Module): + r""" + A dynamically quantized sparse linear module with float tensor as inputs and outputs. 
+ """ + _version = 1 + _op_type = "sparse_dynamic" + _FLOAT_MODULE = torch.nn.Linear + + def __init__(self, in_features, out_features, row_block_size, col_block_size, bias=True, dtype=torch.qint8): + super().__init__() + + if dtype != torch.qint8: + raise NotImplementedError("Only QINT8 is supported for Sparse Quantized Linear Dynamic") + + self.in_features = in_features + self.out_features = out_features + + if bias: + bias = torch.zeros(self.out_features, dtype=torch.float) + else: + bias = None + + qweight = torch._empty_affine_quantized([out_features, in_features], + scale=1, zero_point=0, dtype=torch.qint8) + self._packed_params = linear.LinearPackedParams(row_block_size=row_block_size, + col_block_size=col_block_size, + dtype=dtype) + self._packed_params.set_weight_bias(qweight, bias, row_block_size, col_block_size) + + def _get_name(self): + return 'SparseQuantizedDynamicLinear' + + def extra_repr(self): + return f'in_features={self.in_features}, out_features={self.out_features}, qscheme={self.weight().qscheme()}' + + def __repr__(self): + return _hide_packed_params_repr(self, linear.LinearPackedParams) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.sparse.qlinear_dynamic(x, self._packed_params._packed_params) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + 'op_type'] = self._op_type + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + op_type = int(state_dict[prefix + 'op_type']) + assert op_type == 'sparse', \ + f"Cannot load from op_type [{op_type}], expecting [{self._op_type}]" + state_dict.pop(prefix + 'op_type') + + version = local_metadata.get('version', None) + assert version <= self._version + + # Is this code valid? 
In old quantization it seemed to be used to load + # older model + weight = state_dict.pop(prefix + 'weight') + bias = state_dict.pop(prefix + 'bias') + state_dict.update({prefix + '_packed_params.weight': weight, + prefix + '_packed_params.bias': bias}) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, False, + missing_keys, unexpected_keys, error_msgs) + + def _weight_bias(self): + return self._packed_params._weight_bias() + + def weight(self): + return self._weight_bias()[0] + + def bias(self): + return self._weight_bias()[1] + + def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor], + row_block_size: Optional[int], col_block_size: Optional[int]) -> None: + assert row_block_size is not None and col_block_size is not None + self.out_features = w.shape[0] + self.in_features = w.shape[1] + self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size) + + @classmethod + def from_float(cls, mod): + r"""Create a quantized sparse dynamic module from a float module. + + We only care about the convert at this stage, no need for observers just yet. + """ + assert type(mod) == cls._FLOAT_MODULE, ' nnq.' + cls.__name__ + '.from_float only works for ' + \ + cls._FLOAT_MODULE.__name__ + # TODO: Need to add options to qconfig to avoid the calibration. + # TODO: Add calibration for the sparsity + assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' + if type(mod) == nni.LinearReLU: + mod = mod[0] + if mod.qconfig is not None and mod.qconfig.weight is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. 
+ from torch.ao.quantization.qconfig import default_dynamic_qconfig + weight_observer = default_dynamic_qconfig.weight() + + # It is important to multiply by the mask BEFORE calling the `weight_observer` + # TODO (zaf): Mask might not be part of the qconfig (T83295194) + weight = mod.weight + if getattr(mod.qconfig, 'mask', False): + weight = mod.qconfig.mask * mod.weight + + weight_observer(weight) + dtype = weight_observer.dtype + assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' + w_sc, w_zp = weight_observer.calculate_qparams() + if isinstance(w_zp, torch.Tensor): + assert not torch.any(w_zp.bool()), "All weight zero points must map to 0" + else: + assert w_zp == 0, 'Weight zero point must map to 0' + qweight = _quantize_weight(weight.float(), weight_observer) + + row_block_size, col_block_size = LinearBlockSparsePattern.block_size() + qlinear = cls(mod.in_features, + mod.out_features, + row_block_size, + col_block_size, + dtype=dtype) + qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size) + return qlinear diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a16e10bb0da39d1c5f131cadc4ca98716891a7eb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a6f98e2d12628c6095c761d936b35f0ec891d --- 
/dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_scheduler/__init__.py @@ -0,0 +1,5 @@ +from .base_data_scheduler import BaseDataScheduler + +__all__ = [ + "BaseDataScheduler", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc530df1b2a0f47c9031c2d17605c9f1e4b5ca95 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e76cfc345ac5fde2861e6b09c85cc550bf2e6d4 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module +from typing import Dict, List, Optional + 
+SUPPORTED_MODULES = { + nn.Embedding, + nn.EmbeddingBag +} + + +def _fetch_all_embeddings(model): + """Fetches Embedding and EmbeddingBag modules from the model + """ + embedding_modules = [] + stack = [model] + while stack: + module = stack.pop() + for _, child in module.named_children(): + fqn_name = module_to_fqn(model, child) + if type(child) in SUPPORTED_MODULES: + embedding_modules.append((fqn_name, child)) + else: + stack.append(child) + return embedding_modules + + +def post_training_sparse_quantize(model, + data_sparsifier_class, + sparsify_first=True, + select_embeddings: Optional[List[nn.Module]] = None, + **sparse_config): + """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. + The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. + + Args: + - model (nn.Module) + model whose embeddings needs to be sparsified + - data_sparsifier_class (type of data sparsifier) + Type of sparsification that needs to be applied to model + - sparsify_first (bool) + if true, sparsifies first and then quantizes + otherwise, quantizes first and then sparsifies. + - select_embeddings (List of Embedding modules) + List of embedding modules to in the model to be sparsified & quantized. + If None, all embedding modules with be sparsified + - sparse_config (Dict) + config that will be passed to the constructor of data sparsifier object. + + Note: + 1. When `sparsify_first=False`, quantization occurs first followed by sparsification. + - before sparsifying, the embedding layers are dequantized. + - scales and zero-points are saved + - embedding layers are sparsified and `squash_mask` is applied + - embedding weights are requantized using the saved scales and zero-points + 2. When `sparsify_first=True`, sparsification occurs first followed by quantization. 
+ - embeddings are sparsified first + - quantization is applied on the sparsified embeddings + """ + data_sparsifier = data_sparsifier_class(**sparse_config) + + # if select_embeddings is None, perform it on all embeddings + if select_embeddings is None: + embedding_modules = _fetch_all_embeddings(model) + + else: + embedding_modules = [] + assert isinstance(select_embeddings, List), "the embedding_modules must be a list of embedding modules" + for emb in select_embeddings: + assert type(emb) in SUPPORTED_MODULES, "the embedding_modules list must be an embedding or embedding bags" + fqn_name = module_to_fqn(model, emb) + assert fqn_name is not None, "the embedding modules must be part of input model" + embedding_modules.append((fqn_name, emb)) + + if sparsify_first: + # sparsify + for name, emb_module in embedding_modules: + valid_name = name.replace('.', '_') + data_sparsifier.add_data(name=valid_name, data=emb_module) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + else: + # quantize + for _, emb_module in embedding_modules: + emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig + + torch.ao.quantization.prepare(model, inplace=True) + torch.ao.quantization.convert(model, inplace=True) + + # retrieve scale & zero_points + quantize_params: Dict[str, Dict] = {'scales': {}, 'zero_points': {}, + 'dequant_weights': {}, 'axis': {}, + 'dtype': {}} + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + + quantized_weight = quantized_emb.weight() # type: ignore[operator] + quantize_params['scales'][name] = quantized_weight.q_per_channel_scales() + quantize_params['zero_points'][name] = 
quantized_weight.q_per_channel_zero_points() + quantize_params['dequant_weights'][name] = torch.dequantize(quantized_weight) + quantize_params['axis'][name] = quantized_weight.q_per_channel_axis() + quantize_params['dtype'][name] = quantized_weight.dtype + + # attach data to sparsifier + data_sparsifier.add_data(name=name.replace('.', '_'), data=quantize_params['dequant_weights'][name]) + + data_sparsifier.step() + data_sparsifier.squash_mask() + + for name, _ in embedding_modules: + quantized_emb = fqn_to_module(model, name) + assert quantized_emb is not None # satisfy mypy + requantized_vector = torch.quantize_per_channel(quantize_params['dequant_weights'][name], + scales=quantize_params['scales'][name], + zero_points=quantize_params['zero_points'][name], + dtype=quantize_params['dtype'][name], + axis=quantize_params['axis'][name]) + + quantized_emb.set_weight(requantized_vector) # type: ignore[operator] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ece53af75b16f205c0512c2993ad269ec33439c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65cd16c93e242d67792cd1a8c01979a1e431f55f Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..f965fa647de9e880cf82474527d1b01b49c80871 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -0,0 +1,29 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. 
+ """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + assert saliency.shape == mask.shape + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a35d3892dc0a487451c6afc0a4d0fc9b89d683 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a88d99a1f83b4b352025452b5863b1275c187b4d --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/scheduler/lambda_scheduler.py @@ -0,0 +1,47 @@ +import warnings + +from .base_scheduler import BaseScheduler + +__all__ = ["LambdaSL"] + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) 
+ >>> scheduler.step() + """ + + def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}") + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + return [base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27b8046c38587dd243e80e0ecd675952c11be96b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b1a060be8e1a46fc2043f0136b6edd050c818e0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bd928675493cf6a30dea256e9810c0909ea33b6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py new file mode 100644 index 0000000000000000000000000000000000000000..7d39dbcf1ca861fc6ae8b19ed3684ec6ec72a0f2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_equalize.py @@ -0,0 +1,182 @@ +import torch +import copy +from typing import Dict, Any + +__all__ = [ + "set_module_weight", + "set_module_bias", + "get_module_weight", + "get_module_bias", + "max_over_ndim", + "min_over_ndim", + "channel_range", + "cross_layer_equalization", + "equalize", + "converged", +] + +_supported_types = {torch.nn.Conv2d, torch.nn.Linear} +_supported_intrinsic_types = {torch.ao.nn.intrinsic.ConvReLU2d, torch.ao.nn.intrinsic.LinearReLU} +_all_supported_types = _supported_types.union(_supported_intrinsic_types) + +def set_module_weight(module, weight) -> None: + if type(module) in _supported_types: + module.weight = torch.nn.Parameter(weight) + else: + module[0].weight = torch.nn.Parameter(weight) + +def set_module_bias(module, bias) -> None: + if type(module) in _supported_types: + module.bias = torch.nn.Parameter(bias) + else: + module[0].bias = torch.nn.Parameter(bias) + +def get_module_weight(module): + if type(module) in _supported_types: + return 
module.weight + else: + return module[0].weight + +def get_module_bias(module): + if type(module) in _supported_types: + return module.bias + else: + return module[0].bias + +def max_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.max' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.max(axis, keepdim) + return input + +def min_over_ndim(input, axis_list, keepdim=False): + """Apply 'torch.min' over the given axes.""" + axis_list.sort(reverse=True) + for axis in axis_list: + input, _ = input.min(axis, keepdim) + return input + +def channel_range(input, axis=0): + """Find the range of weights associated with a specific channel.""" + size_of_tensor_dim = input.ndim + axis_list = list(range(size_of_tensor_dim)) + axis_list.remove(axis) + + mins = min_over_ndim(input, axis_list) + maxs = max_over_ndim(input, axis_list) + + assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range does not match size of requested axis" + return maxs - mins + +def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1): + """Scale the range of Tensor1.output to equal Tensor2.input. 
+ + Given two adjacent tensors', the weights are scaled such that + the ranges of the first tensors' output channel are equal to the + ranges of the second tensors' input channel + """ + if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types: + raise ValueError("module type not supported:", type(module1), " ", type(module2)) + + weight1 = get_module_weight(module1) + weight2 = get_module_weight(module2) + + if weight1.size(output_axis) != weight2.size(input_axis): + raise TypeError("Number of output channels of first arg do not match \ + number input channels of second arg") + + bias = get_module_bias(module1) + + weight1_range = channel_range(weight1, output_axis) + weight2_range = channel_range(weight2, input_axis) + + # producing scaling factors to applied + weight2_range += 1e-9 + scaling_factors = torch.sqrt(weight1_range / weight2_range) + inverse_scaling_factors = torch.reciprocal(scaling_factors) + + bias = bias * inverse_scaling_factors + + # formatting the scaling (1D) tensors to be applied on the given argument tensors + # pads axis to (1D) tensors to then be broadcasted + size1 = [1] * weight1.ndim + size1[output_axis] = weight1.size(output_axis) + size2 = [1] * weight2.ndim + size2[input_axis] = weight2.size(input_axis) + + scaling_factors = torch.reshape(scaling_factors, size2) + inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1) + + weight1 = weight1 * inverse_scaling_factors + weight2 = weight2 * scaling_factors + + set_module_weight(module1, weight1) + set_module_bias(module1, bias) + set_module_weight(module2, weight2) + +def equalize(model, paired_modules_list, threshold=1e-4, inplace=True): + """Equalize modules until convergence is achieved. 
+ + Given a list of adjacent modules within a model, equalization will + be applied between each pair, this will repeated until convergence is achieved + + Keeps a copy of the changing modules from the previous iteration, if the copies + are not that different than the current modules (determined by converged_test), + then the modules have converged enough that further equalizing is not necessary + + Implementation of this referced section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf + + Args: + model: a model (nn.module) that equalization is to be applied on + paired_modules_list: a list of lists where each sublist is a pair of two + submodules found in the model, for each pair the two submodules generally + have to be adjacent in the model to get expected/reasonable results + threshold: a number used by the converged function to determine what degree + similarity between models is necessary for them to be called equivalent + inplace: determines if function is inplace or not + """ + if not inplace: + model = copy.deepcopy(model) + + name_to_module : Dict[str, torch.nn.Module] = {} + previous_name_to_module: Dict[str, Any] = {} + name_set = {name for pair in paired_modules_list for name in pair} + + for name, module in model.named_modules(): + if name in name_set: + name_to_module[name] = module + previous_name_to_module[name] = None + while not converged(name_to_module, previous_name_to_module, threshold): + for pair in paired_modules_list: + previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]]) + previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]]) + + cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]]) + + return model + +def converged(curr_modules, prev_modules, threshold=1e-4): + """Test whether modules are converged to a specified threshold. 
+ + Tests for the summed norm of the differences between each set of modules + being less than the given threshold + + Takes two dictionaries mapping names to modules, the set of names for each dictionary + should be the same, looping over the set of names, for each name take the difference + between the associated modules in each dictionary + + """ + if curr_modules.keys() != prev_modules.keys(): + raise ValueError("The keys to the given mappings must have the same set of names of modules") + + summed_norms = torch.tensor(0.) + if None in prev_modules.values(): + return False + for name in curr_modules.keys(): + curr_weight = get_module_weight(curr_modules[name]) + prev_weight = get_module_weight(prev_modules[name]) + + difference = curr_weight.sub(prev_weight) + summed_norms += torch.norm(difference) + return bool(summed_norms < threshold) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0eab9c9521644825141f0684c7ad5bac6176d1e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__init__.py @@ -0,0 +1,23 @@ +from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, DTypeWithConstraints, ObservationType +from .fbgemm import get_fbgemm_backend_config +from .native import get_native_backend_config, get_native_backend_config_dict +from .qnnpack import get_qnnpack_backend_config +from .tensorrt import get_tensorrt_backend_config, get_tensorrt_backend_config_dict +from .executorch import get_executorch_backend_config +from .onednn import get_onednn_backend_config + +__all__ = [ + "get_fbgemm_backend_config", + "get_native_backend_config", + "get_native_backend_config_dict", + "get_qnnpack_backend_config", + 
"get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", + "get_executorch_backend_config", + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", + "get_onednn_backend_config", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4dd2677aeaf279d822ea5da87e8e1d8a53fafd1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e563dbab895ce13525bcd12ee64239c5d5731677 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/executorch.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13828e973ef81de3aac6d67070feb5de68ad2852 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a2785ff58630e5eff969b7512117fe394dad34d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7f67d0d7fe3c9dc25f598a0b6d5fc11fdcd06d3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e3f2dd3385afb816629db68fc952194a6a9bba2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/x86.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..4e946a25ffbbf003d39a020ea75fea185551ce46 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -0,0 +1,637 @@ +import copy +import operator +import torch +import torch.nn.functional as F +import torch.nn as nn +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +from collections import namedtuple +from typing import Callable, Dict, List, Union +from .backend_config import ( + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from ..fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_linear_bn, + fuse_convtranspose_bn, +) + +__all__: List[str] = [] + +# TODO: rename to be more explicit, e.g. qat_conv_relu +_ConvMetadata = namedtuple( + "_ConvMetadata", + ["root", "transpose", "bn", "reference", "transpose_reference", + "fused_conv_relu", "fused_conv_bn", "fused_conv_bn_relu", + "qat", "relu_qat", "bn_qat", "bn_relu_qat", + "func", "func_transpose"]) +_Conv1dMetadata = _ConvMetadata( + nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, nnqr.ConvTranspose1d, + nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d, + F.conv1d, F.conv_transpose1d) +_Conv2dMetadata = _ConvMetadata( + nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, nnqr.ConvTranspose2d, + nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, + F.conv2d, F.conv_transpose2d) +_Conv3dMetadata = _ConvMetadata( + nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, nnqr.ConvTranspose3d, + nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, + F.conv3d, 
F.conv_transpose3d) + +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + +def _get_binary_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + binary_op_configs: List[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [operator.add, 
torch.add, operator.mul, torch.mul]: + bop_patterns = [ + (op_with_quantized_bop_scalar_variant, nn.ReLU), + (op_with_quantized_bop_scalar_variant, F.relu), + (op_with_quantized_bop_scalar_variant, torch.relu), + op_with_quantized_bop_scalar_variant + ] + for bop_pattern in bop_patterns: + binary_op_configs.append( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type(num_tensor_args_to_observation_type_mapping)) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul) + .set_dtype_configs(dtype_configs) # noqa: E131 + ) + return binary_op_configs + +def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. + """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: List[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear)) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear)) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2})) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + 
BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU)) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) + .set_fused_module(nni.LinearReLU)) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearReLU)) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear)) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig((F.linear, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig((F.linear, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # (3) Linear + batchnorm + # ------------------------ + # 3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_linear_bn) + .set_fused_module(nni.LinearBn1d)) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d) + 
.set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nniqat.LinearBn1d)) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear)) + return linear_configs + +def _get_conv_configs(dtype_configs): + """ + Return all configs related to conv modules and ops. + """ + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat)) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference)) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2})) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, torch.nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu)) + # conv relu fusion, 
conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu)) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat)) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference)) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, torch.nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # conv relu, functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat)) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference)) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + 
.set_fused_module(convs.fused_conv_bn)) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu)) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu)) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat)) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat)) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference)) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference)) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference)) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig((convs.transpose, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: 
E131 + .set_fuser_method(fuse_convtranspose_bn) + .set_root_module(convs.transpose) + .set_reference_quantized_module(convs.transpose_reference)) + + # 4.3 functional conv transpose + conv_configs.append( + BackendPatternConfig(convs.func_transpose) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_input_type_to_index({"weight": 1, "bias": 2})) + + return conv_configs + +def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig: + return BackendPatternConfig(torch.cat) \ + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + +def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + ln_configs = [] + ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + ln_configs.append( + BackendPatternConfig(torch.nn.functional.layer_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + return ln_configs + +def _get_default_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + configs = [] + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + for op in default_ops: + configs.append( + BackendPatternConfig(op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + configs.append( + BackendPatternConfig(torch.nn.functional.group_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # 
noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 2, "bias": 3}) + ) + + configs.append( + BackendPatternConfig(torch.nn.functional.instance_norm) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 3, "bias": 4}) + ) + return configs + +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: List[DTypeConfig], + constraints: DTypeWithConstraints, +) -> List[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. + """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [dc.input_dtype_with_constraints, dc.output_dtype_with_constraints]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError(f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}") + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError(f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}") + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + +def _get_fixed_qparams_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + 
fixed_qparams_op_configs = [] + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs(dtype_configs, constraints) + fixed_qparams_op_configs.append( + BackendPatternConfig(fixed_qparam_op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(new_dtype_configs)) + return fixed_qparams_op_configs + +def _get_share_qparams_op_configs(dtype_configs): + """ Get the operator config for the operators that works for both float and quantized input + if input is quantized, the output Tensor shares the same quantization parameter + with input. + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return BackendPatternConfig(op) \ + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + 
torch.narrow, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + "contiguous", + "clamp", + "detach", + "detach_", + "mean", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "relu", + "relu_", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view" + ] + return [_get_share_qprams_op_backend_config(op) for op in share_qparams_ops] + +def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """ Get configs related to batchnorm. """ + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn.keys(): + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig((bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn)) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig((bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(fused_bn)) + .set_fused_module(fused_bn)) + bn_configs.append( + BackendPatternConfig(bn) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs)) + return bn_configs + +def _get_rnn_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [ + (nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM), + (nn.GRU, nnqr.GRU) + ]: + rnn_op_configs.append( + 
BackendPatternConfig(rnn_op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(rnn_op) + .set_reference_quantized_module(ref_rnn_op)) + return rnn_op_configs + +def _get_embedding_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op)) + + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op)) + return embedding_op_configs + +def _get_tensor_info_op_configs(dtype_configs): + """ + These ops work on tensors of different dtypes but return non-tensors + containing information about the input tensor. 
+ """ + + def _get_config(op): + return BackendPatternConfig(op) \ + .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED) \ + .set_dtype_configs(dtype_configs) + + return [_get_config(op) for op in ("shape", "size")] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/backend_config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/backend_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a4d2f3afa349688365fe19e498cd3bedcb08e7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/backend_config.py @@ -0,0 +1,659 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import torch +from torch.ao.quantization.utils import Pattern +from enum import Enum + + +__all__ = [ + "BackendConfig", + "BackendPatternConfig", + "DTypeConfig", + "DTypeWithConstraints", + "ObservationType", +] + + +# DTypeConfig dict keys +INPUT_DTYPE_DICT_KEY = "input_dtype" +OUTPUT_DTYPE_DICT_KEY = "output_dtype" +WEIGHT_DTYPE_DICT_KEY = "weight_dtype" +BIAS_DTYPE_DICT_KEY = "bias_dtype" +IS_DYNAMIC_DICT_KEY = "is_dynamic" + +# BackendConfig dict keys +NAME_DICT_KEY = "name" +CONFIGS_DICT_KEY = "configs" + +# BackendPatternConfig dict keys +PATTERN_DICT_KEY = "pattern" +PATTERN_COMPLEX_FORMAT_DICT_KEY = "pattern_complex_format" +OBSERVATION_TYPE_DICT_KEY = "observation_type" +DTYPE_CONFIGS_DICT_KEY = "dtype_configs" +ROOT_MODULE_DICT_KEY = "root_module" +QAT_MODULE_DICT_KEY = "qat_module" +REFERENCE_QUANTIZED_MODULE_DICT_KEY = "reference_quantized_module_for_root" +FUSED_MODULE_DICT_KEY = "fused_module" +FUSER_METHOD_DICT_KEY = "fuser_method" +ROOT_NODE_GETTER_DICT_KEY = "root_node_getter" +EXTRA_INPUTS_GETTER_DICT_KEY = "extra_inputs_getter" +NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY 
= "num_tensor_args_to_observation_type" +INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" + + +# TODO: maybe rename this to something that's not related to observer +# e.g. QParamsType +class ObservationType(Enum): + """ An enum that represents different ways of how an operator/operator pattern + should be observed + """ + + OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 + """this means input and output are observed with different observers, based + on qconfig.activation + example: conv, linear, softmax + """ + + OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 + """this means the output will use the same observer instance as input, based + on qconfig.activation + example: torch.cat, maxpool + """ + + INPUT_OUTPUT_NOT_OBSERVED = 2 + """this means the input and output are never observed + example: x.shape, x.size + """ + + +@dataclass +class DTypeWithConstraints: + """ + Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. + + The constraints currently supported are: + + * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper + bounds for the minimum and maximum quantized values respectively. If + the QConfig’s `quant_min` and `quant_max` fall outside this range, + then the QConfig will be ignored. + + * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper + bounds for the minimum and maximum scale values respectively. If the + QConfig’s minimum scale value (currently exposed as `eps`) falls below + the lower bound, then the QConfig will be ignored. Note that the upper + bound is currently not enforced. + + * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements + for scale and zero point, to be used for operators with fixed quantization + parameters such as sigmoid and tanh. 
If the observer specified in the QConfig + is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if + the quantization parameters don't match, then the QConfig will be ignored. + """ + dtype: Optional[torch.dtype] = None + quant_min_lower_bound: Union[int, float, None] = None + quant_max_upper_bound: Union[int, float, None] = None + scale_min_lower_bound: Union[int, float, None] = None + scale_max_upper_bound: Union[int, float, None] = None + scale_exact_match: Optional[float] = None + zero_point_exact_match: Optional[int] = None + + +@dataclass +class DTypeConfig: + """ + Config object that specifies the supported data types passed as arguments to + quantize ops in the reference model spec, for input and output activations, + weights, and biases. + + For example, consider the following reference model: + + quant1 - [dequant1 - fp32_linear - quant2] - dequant2 + + The pattern in the square brackets refers to the reference pattern of + statically quantized linear. Setting the input dtype as `torch.quint8` + in the DTypeConfig means we pass in `torch.quint8` as the dtype argument + to the first quantize op (quant1). Similarly, setting the output dtype as + `torch.quint8` means we pass in `torch.quint8` as the dtype argument to + the second quantize op (quant2). + + Note that the dtype here does not refer to the interface dtypes of the + op. For example, the "input dtype" here is not the dtype of the input + tensor passed to the quantized linear op. Though it can still be the + same as the interface dtype, this is not always the case, e.g. the + interface dtype is fp32 in dynamic quantization but the "input dtype" + specified in the DTypeConfig would still be quint8. The semantics of + dtypes here are the same as the semantics of the dtypes specified in + the observers. + + These dtypes are matched against the ones specified in the user’s + QConfig. 
If there is a match, and the QConfig satisfies the constraints + specified in the DTypeConfig (if any), then we will quantize the given + pattern using this DTypeConfig. Otherwise, the QConfig is ignored and + the pattern will not be quantized. + + Example usage:: + + >>> # xdoctest: +SKIP(failing) + >>> dtype_config1 = DTypeConfig( + ... input_dtype=torch.quint8, + ... output_dtype=torch.quint8, + ... weight_dtype=torch.qint8, + ... bias_dtype=torch.float) + + >>> dtype_config2 = DTypeConfig( + ... input_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... output_dtype=DTypeWithConstraints( + ... dtype=torch.quint8, + ... quant_min_lower_bound=0, + ... quant_max_upper_bound=255, + ... ), + ... weight_dtype=DTypeWithConstraints( + ... dtype=torch.qint8, + ... quant_min_lower_bound=-128, + ... quant_max_upper_bound=127, + ... ), + ... bias_dtype=torch.float) + + >>> dtype_config1.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype + torch.quint8 + + >>> dtype_config2.input_dtype_with_constraints + DTypeWithConstraints(dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=255, \ +scale_min_lower_bound=None, scale_max_upper_bound=None) + """ + input_dtype_with_constraints: DTypeWithConstraints + output_dtype_with_constraints: DTypeWithConstraints + weight_dtype_with_constraints: DTypeWithConstraints + bias_dtype: Optional[torch.dtype] + is_dynamic: Optional[bool] + + def __init__( + self, + input_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + output_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + weight_dtype: Union[torch.dtype, DTypeWithConstraints, None] = None, + bias_dtype: Optional[torch.dtype] = None, + is_dynamic: Optional[bool] = None, + ): + if isinstance(input_dtype, DTypeWithConstraints): + self.input_dtype_with_constraints = input_dtype + else: + self.input_dtype_with_constraints = DTypeWithConstraints(dtype=input_dtype) + 
+ if isinstance(output_dtype, DTypeWithConstraints): + self.output_dtype_with_constraints = output_dtype + else: + self.output_dtype_with_constraints = DTypeWithConstraints(dtype=output_dtype) + + if isinstance(weight_dtype, DTypeWithConstraints): + self.weight_dtype_with_constraints = weight_dtype + else: + self.weight_dtype_with_constraints = DTypeWithConstraints(dtype=weight_dtype) + + self.bias_dtype = bias_dtype + self.is_dynamic = is_dynamic + + @property + def input_dtype(self) -> Optional[torch.dtype]: + return self.input_dtype_with_constraints.dtype + + @property + def output_dtype(self) -> Optional[torch.dtype]: + return self.output_dtype_with_constraints.dtype + + @property + def weight_dtype(self) -> Optional[torch.dtype]: + return self.weight_dtype_with_constraints.dtype + + @classmethod + def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig: + """ + Create a ``DTypeConfig`` from a dictionary with the following items (all optional): + "input_dtype": torch.dtype or ``DTypeWithConstraints`` + "output_dtype": torch.dtype or ``DTypeWithConstraints`` + "weight_dtype": torch.dtype or ``DTypeWithConstraints`` + "bias_type": torch.dtype + "is_dynamic": bool + """ + input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None) + if input_dtype is not None and not isinstance(input_dtype, (torch.dtype, DTypeWithConstraints)): + raise ValueError("Expected input_dtype to be a torch.dtype or DTypeWithConstraints") + output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None) + if output_dtype is not None and not isinstance(output_dtype, (torch.dtype, DTypeWithConstraints)): + raise ValueError("Expected output_dtype to be a torch.dtype or DTypeWithConstraints") + weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None) + if weight_dtype is not None and not isinstance(weight_dtype, (torch.dtype, DTypeWithConstraints)): + raise ValueError("Expected weight_dtype to be a torch.dtype or DTypeWithConstraints") + bias_dtype = 
dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None) + is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None) + return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``DTypeConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. + """ + dtype_config_dict: Dict[str, Any] = {} + if self.input_dtype is not None: + dtype_config_dict[INPUT_DTYPE_DICT_KEY] = self.input_dtype_with_constraints + if self.output_dtype is not None: + dtype_config_dict[OUTPUT_DTYPE_DICT_KEY] = self.output_dtype_with_constraints + if self.weight_dtype is not None: + dtype_config_dict[WEIGHT_DTYPE_DICT_KEY] = self.weight_dtype_with_constraints + if self.bias_dtype is not None: + dtype_config_dict[BIAS_DTYPE_DICT_KEY] = self.bias_dtype + if self.is_dynamic is not None: + dtype_config_dict[IS_DYNAMIC_DICT_KEY] = self.is_dynamic + return dtype_config_dict + + +class BackendConfig: + # TODO: refer to NativeBackendConfig once that is implemented + """Config that defines the set of patterns that can be quantized on a given backend, and how reference + quantized models can be produced from these patterns. + + A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph + of the above. Each pattern supported on the target backend can be individually configured through + :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: + + (1) The supported input/output activation, weight, and bias data types + + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + + (3) (Optionally) Fusion, QAT, and reference module mappings. 
+ + The format of the patterns is described in: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md + + Example usage:: + + import torch + from torch.ao.quantization.backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType, + ) + + weighted_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float) + + def fuse_conv2d_relu(is_qat, conv, relu): + return torch.ao.nn.intrinsic.ConvReLU2d(conv, relu) + + # For quantizing Linear + linear_config = BackendPatternConfig(torch.nn.Linear) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Linear) \ + .set_qat_module(torch.ao.nn.qat.Linear) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear) + + # For fusing Conv2d + ReLU into ConvReLU2d + conv_relu_config = BackendPatternConfig((torch.nn.Conv2d, torch.nn.ReLU)) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_fused_module(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_fuser_method(fuse_conv2d_relu) + + # For quantizing ConvReLU2d + fused_conv_relu_config = BackendPatternConfig(torch.ao.nn.intrinsic.ConvReLU2d) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_int8_dtype_config) \ + .set_root_module(torch.nn.Conv2d) \ + .set_qat_module(torch.ao.nn.intrinsic.qat.ConvReLU2d) \ + .set_reference_quantized_module(torch.ao.nn.quantized.reference.Conv2d) + + backend_config = BackendConfig("my_backend") \ + .set_backend_pattern_config(linear_config) \ + .set_backend_pattern_config(conv_relu_config) \ + .set_backend_pattern_config(fused_conv_relu_config) + + """ + def __init__(self, name: str = ""): + self.name = name + # Store all 
BackendPatternConfigs in a map to handle duplicates + # Note: the key in this map uses the complex reversed tuple format. + # This is intended only for internal use; users who wish to access + # the original patterns should go through `self.configs` instead. + self._pattern_complex_format_to_config: Dict[Pattern, BackendPatternConfig] = {} + + def __repr__(self): + return f"BackendConfig({self.__dict__})" + + def set_name(self, name: str) -> BackendConfig: + """ + Set the name of the target backend. + """ + self.name = name + return self + + def set_backend_pattern_config(self, config: BackendPatternConfig) -> BackendConfig: + """ + Set the config for an pattern that can be run on the target backend. + This overrides any existing config for the given pattern. + """ + # Avoid circular dependencies + pattern_complex_format = torch.ao.quantization.backend_config.utils \ + ._get_pattern_in_reversed_nested_tuple_format(config) # type: ignore[attr-defined] + self._pattern_complex_format_to_config[pattern_complex_format] = config + return self + + def set_backend_pattern_configs(self, configs: List[BackendPatternConfig]) -> BackendConfig: + """ + Set the configs for patterns that can be run on the target backend. + This overrides any existing config for a given pattern if it was previously registered already. + """ + for conf in configs: + self.set_backend_pattern_config(conf) + return self + + @property + def configs(self) -> List[BackendPatternConfig]: + """ + Return a copy of the list of configs set in this `BackendConfig`. 
+ """ + return list(self._pattern_complex_format_to_config.values()) + + @classmethod + def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig: + """ + Create a ``BackendConfig`` from a dictionary with the following items: + + "name": the name of the target backend + + "configs": a list of dictionaries that each represents a `BackendPatternConfig` + + """ + conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) + for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): + if isinstance(d, BackendPatternConfig): + conf.set_backend_pattern_config(d) + elif isinstance(d, Dict): + conf.set_backend_pattern_config(BackendPatternConfig.from_dict(d)) + else: + raise ValueError(f"Expected backend_config_dict['{CONFIGS_DICT_KEY}'] to be a dictionary") + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``BackendConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. + """ + return { + NAME_DICT_KEY: self.name, + CONFIGS_DICT_KEY: [c.to_dict() for c in self.configs], + } + + +class BackendPatternConfig: + """ + Config object that specifies quantization behavior for a given operator pattern. + For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. 
+ """ + def __init__(self, pattern: Optional[Pattern] = None): + self.pattern: Optional[Pattern] = pattern + self.observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + self.dtype_configs: List[DTypeConfig] = [] + self.root_module: Optional[Type[torch.nn.Module]] = None + self.qat_module: Optional[Type[torch.nn.Module]] = None + self.reference_quantized_module: Optional[Type[torch.nn.Module]] = None + self.fused_module: Optional[Type[torch.nn.Module]] = None + self.fuser_method: Optional[Callable] = None + + # Temporary/internal configs + self._root_node_getter: Optional[Callable] = None + self._extra_inputs_getter: Optional[Callable] = None + self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {} + self._input_type_to_index: Dict[str, int] = {} + self._pattern_complex_format: Optional[Pattern] = None + + def __repr__(self): + dict_nonempty = { + k: v for k, v in self.__dict__.items() + if ( + (not isinstance(v, (list, dict)) and v is not None) + or (isinstance(v, (list, dict)) and len(v) > 0) + ) + } + return f"BackendPatternConfig({dict_nonempty})" + + def set_pattern(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure. + + The pattern can be a float module, functional operator, pytorch operator, or a tuple + combination of the above. Tuple patterns are treated as sequential patterns, and + currently only tuples of 2 or 3 elements are supported. + """ + if self._pattern_complex_format is not None: + raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set") + self.pattern = pattern + return self + + def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig: + """ + Set how observers should be inserted in the graph for this pattern. + + Observation type here refers to how observers (or quant-dequant ops) will be placed + in the graph. This is used to produce the desired reference patterns understood by + the backend. 
Weighted ops such as linear and conv require different observers + (or quantization parameters passed to quantize ops in the reference model) for the + input and the output. + + There are two observation types: + + `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance + will be different from the input. This is the most common observation type. + + `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the + same as the input. This is useful for operators like `cat`. + + Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs + with observers (and fake quantizes) attached instead of observers themselves. + """ + self.observation_type = observation_type + return self + + def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: + """ + Add a set of supported data types passed as arguments to quantize ops in the + reference model spec. + """ + self.dtype_configs.append(dtype_config) + return self + + def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig: + """ + Set the supported data types passed as arguments to quantize ops in the + reference model spec, overriding all previously registered data types. + """ + self.dtype_configs = dtype_configs + return self + + def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the root for this pattern. + + When we construct the reference quantized model during the convert phase, + the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU) + will be swapped to the corresponding reference quantized modules (e.g. + torch.ao.nn.reference.quantized.Linear). This allows custom backends to + specify custom reference quantized module implementations to match the + numerics of their lowered operators. 
Since this is a one-to-one mapping, + both the root module and the reference quantized module must be specified + in the same BackendPatternConfig in order for the conversion to take place. + """ + self.root_module = root_module + return self + + def set_qat_module(self, qat_module: Type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the QAT implementation for this pattern. + """ + self.qat_module = qat_module + return self + + def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the reference quantized implementation for + this pattern's root module. + + For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`. + """ + self.reference_quantized_module = reference_quantized_module + return self + + def set_fused_module(self, fused_module: Type[torch.nn.Module]) -> BackendPatternConfig: + """ + Set the module that represents the fused implementation for this pattern. + """ + self.fused_module = fused_module + return self + + def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig: + """ + Set the function that specifies how to fuse this BackendPatternConfig's pattern. + + The first argument of this function should be `is_qat`, and the rest of the arguments + should be the items in the tuple pattern. The return value of this function should be + the resulting fused module. + + For example, the fuser method for the pattern `(torch.nn.Linear, torch.nn.ReLU)` can be: + + def fuse_linear_relu(is_qat, linear, relu): + return torch.ao.nn.intrinsic.LinearReLU(linear, relu) + + For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6. 
+ """ + self.fuser_method = fuser_method + return self + + def _set_root_node_getter(self, root_node_getter: Callable) -> BackendPatternConfig: + self._root_node_getter = root_node_getter + return self + + def _set_extra_inputs_getter(self, extra_inputs_getter: Callable) -> BackendPatternConfig: + self._extra_inputs_getter = extra_inputs_getter + return self + + def _set_num_tensor_args_to_observation_type( + self, num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> BackendPatternConfig: + self._num_tensor_args_to_observation_type = num_tensor_args_to_observation_type + return self + + def _set_input_type_to_index(self, input_type_to_index: Dict[str, int]) -> BackendPatternConfig: + self._input_type_to_index = input_type_to_index + return self + + def _set_pattern_complex_format(self, pattern: Pattern) -> BackendPatternConfig: + """ + Set the pattern to configure, using the reversed nested tuple format. + + See the BackendConfig README for more detail: + https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md#advanced-pattern-specification + """ + if self.pattern is not None: + raise ValueError("Only one of 'pattern' or 'pattern_complex_format' can be set") + self._pattern_complex_format = pattern + return self + + @classmethod + def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig: + """ + Create a ``BackendPatternConfig`` from a dictionary with the following items: + + "pattern": the pattern being configured + "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s + "root_module": a :class:`torch.nn.Module` that represents the root for this pattern + "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern + 
"reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized + implementation for this pattern's root module. + "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern + "fuser_method": a function that specifies how to fuse the pattern for this pattern + "pattern_complex_format": the pattern specified in the reversed nested tuple format (deprecated) + + """ + def _get_dtype_config(obj: Any) -> DTypeConfig: + """ + Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. + """ + if isinstance(obj, DTypeConfig): + return obj + if isinstance(obj, Dict): + return DTypeConfig.from_dict(obj) + raise ValueError( + f"Expected a list of DTypeConfigs in " + f"backend_pattern_config_dict[\"{DTYPE_CONFIGS_DICT_KEY}\"], got '{type(obj)}'" + ) + + conf = cls() + if PATTERN_DICT_KEY in backend_pattern_config_dict: + conf.set_pattern(backend_pattern_config_dict[PATTERN_DICT_KEY]) + if OBSERVATION_TYPE_DICT_KEY in backend_pattern_config_dict: + conf.set_observation_type(backend_pattern_config_dict[OBSERVATION_TYPE_DICT_KEY]) + for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []): + conf.add_dtype_config(_get_dtype_config(d)) + conf.set_root_module(backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None)) + conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) + conf.set_reference_quantized_module(backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None)) + conf.set_fused_module(backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None)) + conf.set_fuser_method(backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None)) + conf._set_root_node_getter(backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None)) + conf._set_extra_inputs_getter(backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None)) + conf._set_num_tensor_args_to_observation_type( + 
backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {})) + conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})) + if PATTERN_COMPLEX_FORMAT_DICT_KEY in backend_pattern_config_dict: + conf._set_pattern_complex_format(backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY]) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``BackendPatternConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. + """ + backend_pattern_config_dict: Dict[str, Any] = { + OBSERVATION_TYPE_DICT_KEY: self.observation_type, + DTYPE_CONFIGS_DICT_KEY: [c.to_dict() for c in self.dtype_configs], + } + if self.pattern is not None: + backend_pattern_config_dict[PATTERN_DICT_KEY] = self.pattern + if self.root_module is not None: + backend_pattern_config_dict[ROOT_MODULE_DICT_KEY] = self.root_module + if self.qat_module is not None: + backend_pattern_config_dict[QAT_MODULE_DICT_KEY] = self.qat_module + if self.reference_quantized_module is not None: + backend_pattern_config_dict[REFERENCE_QUANTIZED_MODULE_DICT_KEY] = self.reference_quantized_module + if self.fused_module is not None: + backend_pattern_config_dict[FUSED_MODULE_DICT_KEY] = self.fused_module + if self.fuser_method is not None: + backend_pattern_config_dict[FUSER_METHOD_DICT_KEY] = self.fuser_method + if self._root_node_getter is not None: + backend_pattern_config_dict[ROOT_NODE_GETTER_DICT_KEY] = self._root_node_getter + if self._extra_inputs_getter is not None: + backend_pattern_config_dict[EXTRA_INPUTS_GETTER_DICT_KEY] = self._extra_inputs_getter + if len(self._num_tensor_args_to_observation_type) > 0: + backend_pattern_config_dict[NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY] = self._num_tensor_args_to_observation_type + if len(self._input_type_to_index) > 0: + backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index + if 
self._pattern_complex_format is not None: + backend_pattern_config_dict[PATTERN_COMPLEX_FORMAT_DICT_KEY] = self._pattern_complex_format + return backend_pattern_config_dict diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/executorch.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/executorch.py new file mode 100644 index 0000000000000000000000000000000000000000..86a2d13e19ff1a2dc2e9bdc5e5920bd1b207ab42 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/executorch.py @@ -0,0 +1,494 @@ +# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime +# not a specific backend + +import operator +from typing import List + +import torch +import torch.ao.nn.qat as nnqat +import torch.ao.nn.quantized.reference as nnqr +import torch.nn as nn +import torch.nn.functional as F + +from ..fuser_method_mappings import ( + _sequential_wrapper2, + fuse_conv_bn, + fuse_conv_bn_relu, +) +from ._common_operator_config_utils import _Conv2dMetadata +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + DTypeWithConstraints, + ObservationType, +) +from .qnnpack import ( + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_qint8_symmetric_dtype_config, +) + + +__all__ = [ + "get_executorch_backend_config", +] + + +# =================== +# | DTYPE CONFIGS | +# =================== + +executorch_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +executorch_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +executorch_default_dynamic_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + 
bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2**-12, +) + +executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2**-12, +) + +executorch_default_dynamic_qint8_dtype_config = DTypeConfig( + input_dtype=executorch_act_qint8_scale_min_2_neg_12, + output_dtype=torch.float, + weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +executorch_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + + +# ============================= +# | BACKEND PATTERN CONFIGS | +# ============================= + + +def _get_linear_configs() -> List[BackendPatternConfig]: + """ + Return all configs related to linear modules and ops. 
+ """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + executorch_default_dynamic_quint8_dtype_config, + executorch_default_dynamic_qint8_dtype_config, + executorch_default_dynamic_float16_dtype_config, + ] + linear_configs: List[BackendPatternConfig] = [] + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + .set_qat_module(nnqat.Linear) + ) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(torch.nn.Linear) + .set_reference_quantized_module(nnqr.Linear) + ) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + return linear_configs + + +def _get_conv_configs() -> List[BackendPatternConfig]: + """ + Return all configs related to conv modules and ops. 
+ """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + conv_configs = [] + for convs in [_Conv2dMetadata]: + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.qat) + ) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1, "bias": 2}) + ) + + # (2) Conv + relu + # ----------------------------------- + # conv module + relu module + conv_configs.append( + BackendPatternConfig((convs.root, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # conv module + functional relu + conv_configs.append( + BackendPatternConfig((convs.root, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) + .set_fused_module(convs.fused_conv_relu) + ) + # fused conv relu module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + .set_qat_module(convs.relu_qat) + ) + # conv 
relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # functional conv + relu module + conv_configs.append( + BackendPatternConfig((convs.func, nn.ReLU)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # functional conv + functional relu + conv_configs.append( + BackendPatternConfig((convs.func, F.relu)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat) + ) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # conv + batchnorm (+ relu) + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn) + .set_fused_module(convs.fused_conv_bn) + ) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig((convs.root, convs.bn, F.relu)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root) + .set_fuser_method(fuse_conv_bn_relu) + .set_fused_module(convs.fused_conv_bn_relu) + ) + # TODO: we can add fusion for torch.relu as well + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + 
conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat) + ) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat) + ) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(convs.root) + .set_reference_quantized_module(convs.reference) + ) + return conv_configs + + +def _get_binary_ops_configs() -> List[BackendPatternConfig]: + """ + Return all configs related to binary ops. 
+ """ + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_weighted_op_int8_dtype_config, + ] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + binary_op_configs: List[BackendPatternConfig] = [] + for op in [operator.add, torch.add, operator.sub, torch.sub, operator.mul, torch.mul]: + bop_patterns = [ + (op, torch.nn.ReLU), + (op, torch.nn.functional.relu), + (op, torch.relu), + op + ] + for bop_pattern in bop_patterns: + binary_op_configs.append( + BackendPatternConfig(bop_pattern) + .set_dtype_configs(dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping + ) + ) + return binary_op_configs + + +def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]: + """ + Return the operator configs for the operators that works for both float and quantized + input if input is quantized, the output Tensor shares the same quantization parameter + with input. 
+ + Example operator: avgpool2d, reshape, transpose, maxpool2d + Example observed operator: + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + share_qparams_ops = [ + torch.nn.Flatten, + F.adaptive_avg_pool2d, + F.elu, + F.hardtanh, + F.max_pool2d, + F.pad, + F.relu, + F.relu6, + F.leaky_relu, + F.leaky_relu_, + torch.nn.AdaptiveAvgPool2d, + torch.nn.ConstantPad2d, + torch.nn.ELU, + torch.nn.MaxPool2d, + torch.nn.ReLU6, + torch.nn.Hardtanh, + torch.nn.LeakyReLU, + torch.clamp, + torch.flatten, + torch.mean, + torch.permute, + torch.permute_copy, + torch.squeeze, + "clamp", + "mean", + "permute", + "reshape", + "relu", + "relu_", + "squeeze", + "squeeze_", + "leaky_relu", + ] + share_qparams_op_configs: List[BackendPatternConfig] = [] + for op in share_qparams_ops: + share_qparams_op_configs.append( + BackendPatternConfig(op) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return share_qparams_op_configs + + +def _get_bn_configs() -> List[BackendPatternConfig]: + """ + Return all configs related to batchnorm. 
+ """ + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + bn_configs = [] + bn_configs.append( + BackendPatternConfig(nn.BatchNorm2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + ) + return bn_configs + + +def _get_cat_configs() -> List[BackendPatternConfig]: + dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + executorch_default_op_quint8_dtype_config, + ] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + cat_configs.append( + BackendPatternConfig(torch.concatenate) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs) + ) + return cat_configs + + +def _get_embedding_op_configs() -> List[BackendPatternConfig]: + dtype_configs = [ + executorch_weight_only_quint8_dtype_config, + ] + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_qat_module(qat_embedding_op) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + 
.set_dtype_configs(dtype_configs) + .set_root_module(embedding_op) + .set_reference_quantized_module(ref_embedding_op) + ) + + # config for functional embedding + embedding_op_configs.append( + BackendPatternConfig(torch.nn.functional.embedding) + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs) + ._set_input_type_to_index({"weight": 1}) + ) + return embedding_op_configs + + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_executorch_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack. + """ + return ( + BackendConfig("executorch") + .set_backend_pattern_configs(_get_linear_configs()) + .set_backend_pattern_configs(_get_conv_configs()) + .set_backend_pattern_configs(_get_binary_ops_configs()) + .set_backend_pattern_configs(_get_share_qparams_ops_configs()) + .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_cat_configs()) + .set_backend_pattern_configs(_get_embedding_op_configs()) + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/onednn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/onednn.py new file mode 100644 index 0000000000000000000000000000000000000000..6eab945f7d743285160dd591bad59c0b1881dada --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/onednn.py @@ -0,0 +1,542 @@ +import torch +import torch.nn as nn +import torch.ao.nn.intrinsic as nni +import torch.nn.functional as F +import torch.ao.nn.quantized.reference as nnqr +from ._common_operator_config_utils import ( + _get_conv_configs, + _get_linear_configs, + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_default_op_configs, + _get_embedding_op_configs, + 
_get_fixed_qparams_op_configs, + _get_ln_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import ( + BackendPatternConfig, + BackendConfig, + DTypeConfig, + ObservationType, +) +from ..fuser_method_mappings import ( + _sequential_wrapper2, +) +import operator +from torch.ao.quantization.utils import MatchAllNode +import itertools + +# =================== +# | DTYPE CONFIGS | +# =================== + +onednn_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +onednn_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +onednn_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +onednn_weight_only_qint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.qint8, +) + +onednn_input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +# =================== +# | FUSER METHODS | +# =================== + +def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): + r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module + Args: + is_qat: a flag for whether we are using quantization aware training fusion + or post training quantization fusion + linear: Module instance of type Linear + bn: BatchNorm1d instance that needs to be fused with the linear layer + leaky_relu: LeakyReLU instance that needs to be fused with the linear layer + Examples:: + >>> # xdoctest: +SKIP(failing) + >>> m1 = nn.Linear(20, 10) + >>> b1 = nn.BatchNorm1d(10) + >>> lr = nn.LeakyReLU(0.01) + >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) + """ + assert linear.training == bn.training and 
bn.training == leaky_relu.training, \ + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(linear, bn, leaky_relu)}") + else: + map_to_fused_module_eval = { + nn.Linear: nni.LinearLeakyReLU, + } + fused_module = map_to_fused_module_eval.get(type(linear), None) + if fused_module is not None: + fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn) + fm = fused_module(fused_linear, leaky_relu) + return fm + else: + raise NotImplementedError(f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}") + +# ====================== +# | CONFIGS FOR CONV | +# ====================== +observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + +conv_dtype_configs = [onednn_weighted_op_int8_dtype_config] +conv_configs = _get_conv_configs(conv_dtype_configs) + +# (1) Conv2d + Add + +# conv2d Y +# \ / +# add + +# include: +# conv2d conv2d +# \ / +# add + +def _fuse_conv_add_left(is_qat, add, conv, _): + return nni.ConvAdd2d(conv, add) + +def _conv_add_root_node_getter_left(pattern): + _, conv, _ = pattern + return conv + +def _conv_add_extra_inputs_getter_left(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, conv, extra_input = pattern + return [extra_input] + +# conv2d +# \ +# bn Y +# \ / +# add + +def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + +def _conv_bn_add_root_node_getter_left(add_pattern): + _, bn_conv, _ = add_pattern + bn, conv = bn_conv + return conv + +def _conv_bn_add_extra_inputs_getter_left(add_pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused 
node + """ + _, bn_conv, extra_input = add_pattern + bn, conv = bn_conv + return [extra_input] + +conv_add_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_left) + ._set_root_node_getter(_conv_bn_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d)) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_left) + ._set_root_node_getter(_conv_add_root_node_getter_left) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left) + .set_fused_module(nni.ConvAdd2d)) + +# Y conv2d +# \ / +# add + +def _fuse_conv_add_right(is_qat, add, _, conv): + return nni.ConvAdd2d(conv, add) + +def _conv_add_root_node_getter_right(pattern): + add, _, conv = pattern + return conv + +def _conv_add_extra_inputs_getter_right(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, conv = pattern + return [extra_input] + +# conv2d +# / +# Y bn +# \ / +# add + +def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv): + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAdd2d(fused_conv, add) + +def _conv_bn_add_root_node_getter_right(pattern): + add, _, bn_conv = pattern + bn, conv = bn_conv + return 
conv + +def _conv_bn_add_extra_inputs_getter_right(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + _, extra_input, bn_conv = pattern + bn, conv = bn_conv + return [extra_input] + +conv_add_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_right) + ._set_root_node_getter(_conv_bn_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d)) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_right) + ._set_root_node_getter(_conv_add_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right) + .set_fused_module(nni.ConvAdd2d)) + +conv_configs.append( + BackendPatternConfig(nni.ConvAdd2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d)) + +# (2) Conv2d + Add + Relu + +# conv2d Y +# \ / +# add +# \ +# relu + +def _fuse_conv_add_relu_left(is_qat, relu, add_pattern): + add, conv, _ = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + +def _conv_add_relu_root_node_getter_left(pattern): + relu, add_pattern = pattern + _, conv, _ = add_pattern + return conv + +def _conv_add_relu_extra_inputs_getter_left(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are 
assumed to be copied over from root node to the fused node + """ + relu, add_pattern = pattern + _, conv, extra_input = add_pattern + return [extra_input] + +# conv2d +# \ +# bn Y +# \ / +# add +# \ +# relu + +def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern): + add, bn_conv, _ = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + +def _conv_bn_add_relu_root_node_getter_left(pattern): + relu, add_pattern = pattern + _, bn_conv, _ = add_pattern + bn, conv = bn_conv + return conv + +def _conv_bn_add_relu_extra_inputs_getter_left(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + relu, add_pattern = pattern + _, bn_conv, extra_input = add_pattern + bn, conv = bn_conv + return [extra_input] + +conv_add_relu_left_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_left_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_bn_add_relu_left) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d)) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_left) + ._set_root_node_getter(_conv_add_relu_root_node_getter_left) + 
._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left) + .set_fused_module(nni.ConvAddReLU2d)) + +# Y conv2d +# \ / +# add +# \ +# relu + +def _fuse_conv_add_relu_right(is_qat, relu, add_pattern): + add, _, conv = add_pattern + return nni.ConvAddReLU2d(conv, add, relu) + +def _conv_add_relu_root_node_getter_right(pattern): + relu, add_pattern = pattern + _, _, conv = add_pattern + return conv + +def _conv_add_relu_extra_inputs_getter_right(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + relu, add_pattern = pattern + _, extra_input, conv = add_pattern + return [extra_input] + +# conv2d +# / +# Y bn +# \ / +# add +# \ +# relu + +def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern): + add, _, bn_conv = add_pattern + bn, conv = bn_conv + if is_qat: + raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}") + else: + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return nni.ConvAddReLU2d(fused_conv, add, relu) + +def _conv_bn_add_relu_root_node_getter_right(pattern): + relu, add_pattern = pattern + _, _, bn_conv = add_pattern + bn, conv = bn_conv + return conv + +def _conv_bn_add_relu_extra_inputs_getter_right(pattern): + """ get inputs pattern for extra inputs, inputs for root node + are assumed to be copied over from root node to the fused node + """ + relu, add_pattern = pattern + _, extra_input, bn_conv = add_pattern + bn, conv = bn_conv + return [extra_input] + +conv_add_relu_optioins = itertools.product( + [True, False], # with_bn + [torch.add, operator.add], # add_op +) + +for with_bn, add_op in conv_add_relu_optioins: + if with_bn: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + 
.set_fuser_method(_fuse_conv_bn_add_relu_right) + ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d)) + else: + conv_configs.append( + BackendPatternConfig() + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 + .set_observation_type(observation_type) + .set_dtype_configs(conv_dtype_configs) + .set_fuser_method(_fuse_conv_add_relu_right) + ._set_root_node_getter(_conv_add_relu_root_node_getter_right) + ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right) + .set_fused_module(nni.ConvAddReLU2d)) + +conv_configs.append( + BackendPatternConfig(nni.ConvAddReLU2d) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(conv_dtype_configs) + .set_root_module(nn.Conv2d) + .set_reference_quantized_module(nnqr.Conv2d)) + +# ======================== +# | CONFIGS FOR LINEAR | +# ======================== + +linear_dtype_configs = [ + onednn_weighted_op_int8_dtype_config, + onednn_dynamic_int8_dtype_config, +] +linear_configs = _get_linear_configs(linear_dtype_configs) + +def _add_eltwise_fusion_configs(configs, root_module, root_op, post_module, post_op, + dtype_configs, fuser_method, fused_module, observation_type, + ref_quant_module): + # 1 base module + op module fusion config + configs.append( + BackendPatternConfig((root_module, post_module)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module)) + # base module + functional post op + configs.append( + BackendPatternConfig((root_module, post_op)) + .set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(fuser_method) + .set_fused_module(fused_module)) + + # 2 fused module configs + configs.append( + BackendPatternConfig(fused_module) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs) + .set_root_module(root_module) 
+ .set_reference_quantized_module(ref_quant_module)) + + # 3 functional base op + post op configs + configs.append( + BackendPatternConfig((root_op, post_module)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + configs.append( + BackendPatternConfig((root_op, post_op)) + .set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + +# Configs for linear + leaky_relu fusion +_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear, + nn.LeakyReLU, F.leaky_relu, linear_dtype_configs, + _sequential_wrapper2(nni.LinearLeakyReLU), + nni.LinearLeakyReLU, observation_type, nnqr.Linear) + +# Configs for linear module + batchnorm + leaky_relu +linear_configs.append( + BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU)) + .set_dtype_configs(linear_dtype_configs) # noqa: E131 + .set_fuser_method(_fuse_linear_bn_leaky_relu) + .set_fused_module(nni.LinearLeakyReLU)) + +# Configs for linear + tanh fusion +_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear, + nn.Tanh, torch.tanh, linear_dtype_configs, + _sequential_wrapper2(nni.LinearTanh), + nni.LinearTanh, observation_type, nnqr.Linear) + +# =========================== +# | CONFIGS FOR OTHER OPS | +# =========================== + +binary_op_dtype_configs = [onednn_op_quint8_dtype_config] +default_op_dtype_configs = [onednn_op_quint8_dtype_config] +fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config] +rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config] +embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config] +layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config] + +# ===================== +# | BACKEND CONFIGS | +# ===================== + +def get_onednn_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native ONEDNN backend. 
+ """ + return BackendConfig("onednn") \ + .set_backend_pattern_configs(conv_configs) \ + .set_backend_pattern_configs(linear_configs) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) + +__all__ = [ + "get_onednn_backend_config", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuse_modules.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuse_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2caa0a2b7f2d484b0e948afd8af8ea6d854b9868 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuse_modules.py @@ -0,0 +1,175 @@ +import copy + +import torch.nn as nn + +from torch.ao.quantization.fuser_method_mappings import get_fuser_method +# for backward compatibility +from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401 +from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401 +from torch.nn.utils.parametrize import type_before_parametrizations + +from typing import List, Optional + +__all__ = [ + "fuse_known_modules", + "fuse_modules", + "fuse_modules_qat", +] + +# Generalization of getattr +def _get_module(model, submodule_key): + tokens = 
submodule_key.split('.') + cur_mod = model + for s in tokens: + cur_mod = getattr(cur_mod, s) + return cur_mod + +# Generalization of setattr +def _set_module(model, submodule_key, module): + tokens = submodule_key.split('.') + sub_tokens = tokens[:-1] + cur_mod = model + for s in sub_tokens: + cur_mod = getattr(cur_mod, s) + + setattr(cur_mod, tokens[-1], module) + +def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None): + r"""Return a list of known fuse modules. + + Returns a list of modules that fuses the operations specified + in the input module list. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, bn + linear, relu + For these sequences, the first element in the output module list performs + the fused operation. The rest of the elements are set to nn.Identity() + """ + types = tuple(type_before_parametrizations(m) for m in mod_list) + fuser_method = get_fuser_method(types, additional_fuser_method_mapping) + if fuser_method is None: + raise NotImplementedError(f"Cannot fuse modules: {types}") + new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) + fused = fuser_method(is_qat, *mod_list) + # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion + # Move pre forward hooks of the base module to resulting fused module + for pre_hook_fn in mod_list[0]._forward_pre_hooks.values(): + fused.register_forward_pre_hook(pre_hook_fn) + mod_list[0]._forward_pre_hooks.clear() + # Move post forward hooks of the last module to resulting fused module + for hook_fn in mod_list[-1]._forward_hooks.values(): + fused.register_forward_hook(hook_fn) + mod_list[-1]._forward_hooks.clear() + new_mod[0] = fused + + for i in range(1, len(mod_list)): + identity = nn.Identity() + identity.training = mod_list[0].training + new_mod[i] = identity + + return new_mod + +def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, 
fuse_custom_config_dict=None): + if fuse_custom_config_dict is None: + fuse_custom_config_dict = {} + additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) + mod_list = [] + for item in modules_to_fuse: + mod_list.append(_get_module(model, item)) + + # Fuse list of modules + new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping) + + # Replace original module list with fused module list + for i, item in enumerate(modules_to_fuse): + _set_module(model, item, new_mod_list[i]) + +def _fuse_modules(model, modules_to_fuse, is_qat, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): + if not inplace: + model = copy.deepcopy(model) + + if all(isinstance(module_element, str) for module_element in modules_to_fuse): + # Handle case of modules_to_fuse being a list + _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict) + else: + # Handle case of modules_to_fuse being a list of lists + for module_list in modules_to_fuse: + _fuse_modules_helper(model, module_list, is_qat, fuser_func, fuse_custom_config_dict) + return model + +def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): + r"""Fuse a list of modules into a single module. + + Fuses only the following sequence of modules: + conv, bn + conv, bn, relu + conv, relu + linear, relu + bn, relu + All other sequences are left unchanged. + For these sequences, replaces the first item in the list + with the fused module, replacing the rest of the modules + with identity. + + Args: + model: Model containing the modules to be fused + modules_to_fuse: list of list of module names to fuse. Can also be a list + of strings if there is only a single list of modules to fuse. 
+ inplace: bool specifying if fusion happens in place on the model, by default + a new model is returned + fuser_func: Function that takes in a list of modules and outputs a list of fused modules + of the same length. For example, + fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()] + Defaults to torch.ao.quantization.fuse_known_modules + `fuse_custom_config_dict`: custom configuration for fusion + + .. code-block:: python + + # Example of fuse_custom_config_dict + fuse_custom_config_dict = { + # Additional fuser_method mapping + "additional_fuser_method_mapping": { + (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn + }, + } + + Returns: + model with fused modules. A new copy is created if inplace=True. + + Examples:: + + >>> # xdoctest: +SKIP + >>> m = M().eval() + >>> # m is a module containing the sub-modules below + >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + >>> m = M().eval() + >>> # Alternately provide a single list of modules to fuse + >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] + >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) + >>> output = fused_m(input) + + """ + return _fuse_modules( + model, + modules_to_fuse, + is_qat=False, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict) + +def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): + """QAT version for `fuse_modules`.""" + return _fuse_modules( + model, + modules_to_fuse, + is_qat=True, + inplace=inplace, + fuser_func=fuser_func, + fuse_custom_config_dict=fuse_custom_config_dict) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_decomposed.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_decomposed.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc6d9c06ecade0d64fd47a92345ed545467ba79 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_decomposed.py @@ -0,0 +1,925 @@ +import math +from typing import Optional, Tuple + +import torch +from torch.library import Library, impl +from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax +from torch._refs import _unsqueeze_multiple + + +# Note: decomposed means decomposed quantized tensor, using decomposed so that the +# name is not too long +quantized_decomposed_lib = Library("quantized_decomposed", "DEF") + +_DTYPE_TO_QVALUE_BOUNDS = { + torch.uint8: (0, 255), + torch.int8: (-128, 127), + torch.int16: (-(2**15), 2**15 - 1), + torch.int32: (-(2**31), 2**31 - 1) +} + +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise ValueError(f"Unsupported dtype: {dtype}") + quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] + + assert quant_min >= quant_min_lower_bound, \ + "quant_min out of bound for dtype, " \ + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + + assert quant_max <= quant_max_upper_bound, \ + "quant_max out of bound for dtype, " \ + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine 
quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + inv_scale = 1.0 / scale + return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype) + +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd") +def quantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + return 
quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") +def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + return torch.empty_like(input, dtype=dtype) + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd") +def quantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine quantization for the Tensor using the same quantization parameters to map + from floating point to quantized values + Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype) + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta") +def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype): + return 
quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype) + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_tensor(Tensor input, float scale, int zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. 
(`torch.uint8`), it is a per tensor quantized Tensor if combined with + quantization parameters in the argument of this function (scale/zero_point) + + scale (float): quantization parameter for affine quantization + + zero_point (int): quantization parameter for affine quantization + + quant_min (int): minimum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for input Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): dtype for input Tensor (not used in computation, + reserved for pattern matching) + + Returns: + dequantized float32 Tensor + """ + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + # TODO: investigate why + # (input - zero_point).to(torch.float32) * scale + # failed the test + return (input.to(torch.float32) - zero_point) * scale + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + + +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor_tensor( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + return 
dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") +def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if dtype in _DTYPE_TO_QVALUE_BOUNDS: + return torch.empty_like(input, dtype=torch.float32) + else: + raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + +# TODO: remove other variants and keep this one +quantized_decomposed_lib.define( + "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " + "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd") +def dequantize_per_tensor_tensor2( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: torch.Tensor, + quant_max: torch.Tensor, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine dequantization for the Tensor using the same quantization parameters to map + from quantized values to floating point values + Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of + scalar values + """ + assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" + return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype) + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta") +def dequantize_per_tensor_tensor2_meta(input, scale, 
zero_point, quant_min, quant_max, dtype): + return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype) + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)") + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor( + input: torch.Tensor, + qmin: int, + qmax: int, + eps: float, + dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + """ Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, \ + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + + return determine_qparams( + min_val, max_val, qmin, qmax, dtype, torch.Tensor([eps]), has_customized_qrange=False) + +quantized_decomposed_lib.define( + "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, " + "float eps, ScalarType dtype) -> (Tensor, Tensor)") + +@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "CompositeExplicitAutograd") +def choose_qparams_symmetric_tensor( + input: torch.Tensor, + qmin: int, + qmax: int, + eps: float, + dtype: torch.dtype +) -> Tuple[torch.Tensor, 
torch.Tensor]: + """ Given an input Tensor, derive the per tensor affine quantization parameter + (scale and zero_point) for target quantized Tensor from the Tensor + + Args: + input (torch.Tensor): floating point input Tensor + quant_min (int): minimum quantized value for target quantized Tensor + quant_max (int): maximum quantized value for target quantized Tensor + dtype (torch.dtype): dtype for target quantized Tensor + + Returns: + scale (float): quantization parameter for the target quantized Tensor + zero_point (int): quantization parameter for the target quantized Tensor + """ + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert dtype in _DTYPE_TO_QVALUE_BOUNDS, \ + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + validate_qmin_qmax(qmin, qmax) + + min_val, max_val = torch.aminmax(input) + return determine_qparams( + min_val, + max_val, + qmin, + qmax, + dtype, + torch.Tensor([eps]), + has_customized_qrange=False, + qscheme=torch.per_tensor_symmetric + ) + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta") +def choose_qparams_tensor_meta( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: \ + {quant_min} max: {quant_max}" + return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(1, dtype=torch.int64, device=input.device) + +@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta") +def choose_qparams_symmetric_tensor_meta( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(1, dtype=torch.double, 
device=input.device), torch.empty(1, dtype=torch.int64, device=input.device) + +# Helper function used to implement per-channel quantization against any axis +def _permute_to_axis_zero(x, axis): + new_axis_list = list(range(x.dim())) + new_axis_list[axis] = 0 + new_axis_list[0] = axis + y = x.permute(tuple(new_axis_list)) + return y, new_axis_list + +quantized_decomposed_lib.define( + "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd") +def quantize_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine per channel quantization for the Tensor using the same quantization + parameters for each channel/axis to map from floating point to quantized values + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + zero_point (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. 
torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + res = torch.zeros_like(input) + + for i in range(input.size(0)): + res[i] = torch.clamp( + torch.round(input[i] * (1.0 / scales[i])) + zero_points[i], + quant_min, + quant_max + ) + + out = res.permute(tuple(permute_axis_list)) + return out.to(dtype) + +@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta") +def quantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + +# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in +# the signature as metadata for the input Tensor, this might be useful for pattern +# matching in the future +# We will revisit this later if we found there are no use cases for it +quantized_decomposed_lib.define( + "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") +def dequantize_per_channel( + input: torch.Tensor, + 
scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + """ Affine per channel dequantization for the Tensor using the same quantization + parameters for each channel/axis to map from quantized values to floating point values + + Args: + input (torch.Tensor): Tensor with dtype matching `dtype` argument, + e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with + quantization parameter in the argument of this function (scales/zero_points/axis) + + scales (torch.Tensor): a list of scale quantization parameter for + affine quantization, one per channel + + zero_points (torch.Tensor): a list of zero_point quantization parameter for + affine quantization, one per channel + + quant_min (int): minimum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + quant_max (int): maximum quantized value for output Tensor (not used in computation, + reserved for pattern matching) + + dtype (torch.dtype): requested dtype for output Tensor (not used in computation, + reserved for pattern matching) + + Returns: + dequantized float32 Tensor + """ + assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + input, permute_axis_list = _permute_to_axis_zero(input, axis) + res = torch.zeros_like(input, dtype=torch.float32) + + for i in range(input.size(0)): + # TODO: investigate why + # (input[i] - zero_points[i]).to(torch.float32) * scales[i] + # failed the test + res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i] + + out = res.permute(tuple(permute_axis_list)) + return out + +@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta") +def dequantize_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: 
int, + quant_min: int, + quant_max: int, + dtype: torch.dtype +) -> torch.Tensor: + assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=torch.float32) + + +quantized_decomposed_lib.define( + "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token( + input: torch.Tensor, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. 
torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + + scales = input.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + if dtype == torch.int8: + n_bits = 8 + quant_max = 2 ** (n_bits - 1) - 1 + else: + raise Exception(f"unsupported dtype in choose_qparams_per_token: {dtype}") + + scales = scales.clamp(min=1e-5).div(quant_max) + zero_points = torch.zeros_like(scales) + return scales, zero_points + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token", + "Meta", +) +def choose_qparams_per_token_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + size = (1, input.size(-1)) + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +# TODO: move this to https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py +quantized_decomposed_lib.define( + "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)" +) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "CompositeExplicitAutograd", +) +def choose_qparams_per_token_asymmetric( + input: torch.Tensor, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Choose quantization parameters for per token quantization. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32/float16 Tensor + dtype (torch.dtype): dtype (e.g. 
torch.uint8) for input Tensor + + Returns: + scales and zero_points, both float32 Tensors + """ + # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18 + qmin, qmax = -128, 127 + min_val, max_val = torch.aminmax(input, dim=-1, keepdim=True) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + eps = torch.finfo(torch.float32).eps # use xnnpack eps? + + # scale + scale = (max_val_pos - min_val_neg) / float(qmax - qmin) + scale = scale.clamp(min=eps) + + # zero point + descaled_min = min_val_neg / scale + descaled_max = max_val_pos / scale + zero_point_from_min_error = qmin + descaled_min + zero_point_from_max_error = qmax + descaled_max + zero_point = torch.where( + zero_point_from_min_error + zero_point_from_max_error > 0, + qmin - descaled_min, + qmax - descaled_max, + ) + zero_point = torch.clamp(zero_point, qmin, qmax).round() + + return scale.to(torch.float32), zero_point.to(torch.float32) + + +@impl( + quantized_decomposed_lib, + "choose_qparams_per_token_asymmetric", + "Meta", +) +def choose_qparams_per_token_asymmetric_meta( + input: torch.Tensor, + dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + size = (1, input.size(-1)) + return torch.empty(size, dtype=torch.double, device=input.device), torch.empty( + size, dtype=torch.int64, device=input.device + ) + + +def _per_token_quant_qparam_dim_check(input, scales, zero_points): + num_tokens = math.prod(list(input.size())[:-1]) + assert ( + num_tokens == scales.numel() + ), f"num_tokens: {num_tokens} scales: {scales.size()}" + assert ( + num_tokens == zero_points.numel() + ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + + +quantized_decomposed_lib.define( + "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, 
"quantize_per_token", "CompositeExplicitAutograd") +def quantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + """Per token quantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. 
torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + _per_token_quant_qparam_dim_check(input, scales, zero_points) + input = ( + torch.round(input / scales + zero_points).clamp(quant_min, quant_max).to(dtype) + ) + return input + + +@impl(quantized_decomposed_lib, "quantize_per_token", "Meta") +def quantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, " + "int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor" +) + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd") +def dequantize_per_token( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + """Per token dequantization for the Tensor using the quantization parameters to map + from floating point to quantized values. This means for a N dimension Tensor + (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize + every N elements with the same quantization parameter. The dimension for scales/zero_points + will be (M1 * M2 ... * Mn) + + Args: + input (torch.Tensor): quantized Tensor (uint8, int8 etc.) + scales (float32 torch.Tensor): quantization parameter for per token affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. 
torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + input = input - zero_points + input = input.to(output_dtype) * scales + return input + + +@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta") +def dequantize_per_token_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + output_dtype: torch.dtype = torch.float32, +): + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + # TODO: support fp16 + return torch.empty_like(input, dtype=output_dtype) + + +quantized_decomposed_lib.define( + "quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size) -> Tensor" +) + + +# TODO: dtype is ignored for now +@impl( + quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd" +) +def quantize_per_channel_group( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + assert group_size > 1 + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + + # TODO: check for dtype, currently we can't express torch.int4 so it's omitted + to_quant = input.reshape(-1, group_size) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zero_points = zero_points.reshape(-1, 1) + + input_int8 = ( + to_quant.div(scales) + .add(zero_points) + .round() + .clamp_(quant_min, quant_max) + .to(dtype) + .reshape_as(input) + ) + + return input_int8 + + +@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta") +def quantize_per_channel_group_meta( + input: torch.Tensor, + scales: 
torch.Tensor, + zero_points: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size=128, +): + """Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for output Tensor + quant_max (int): maximum quantized value for output Tensor + dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + + Returns: + Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters + are not stored in the Tensor, we are storing them in function arguments instead + """ + assert group_size > 1 + # needed for GPTQ single column quantize + if group_size > input.shape[-1] and scales.shape[-1] == 1: + group_size = input.shape[-1] + + assert input.shape[-1] % group_size == 0 + assert input.dim() == 2 + return torch.empty_like(input, dtype=dtype) + + +quantized_decomposed_lib.define( + "dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? 
zero_points, int quant_min, " + "int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor" +) + + +@impl( + quantized_decomposed_lib, + "dequantize_per_channel_group", + "CompositeExplicitAutograd", +) +def dequantize_per_channel_group( + w_int8: torch.Tensor, + scales: torch.Tensor, + zero_points: Optional[torch.Tensor], + quant_min: int, + quant_max: int, + dtype: torch.dtype, + group_size: int = 128, + output_dtype: torch.dtype = torch.float32, +): + """Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters + to map from floating point to quantized values. This means for each row of a 2-d Tensor + (M, N), we calculate scales/zero_points for each `group_size` elements + and quantize every `group_size` elements with the same quantization parameter. + The dimension for scales/zero_points will be (M * ceil(N, group_size),) + + Args: + input (torch.Tensor): quantized Tensor (uint8/int8 etc.) + scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization + zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization + quant_min (int): minimum quantized value for input Tensor + quant_max (int): maximum quantized value for input Tensor + dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + output_dtype (torch.dtype): dtype (e.g. 
torch.float32) for output Tensor + + Returns: + dequantized Tensor with dtype `output_dtype` + """ + + assert group_size > 1 + # needed for GPTQ single column dequantize + if group_size > w_int8.shape[-1] and scales.shape[-1] == 1: + group_size = w_int8.shape[-1] + assert w_int8.shape[-1] % group_size == 0 + assert w_int8.dim() == 2 + + w_int8_grouped = w_int8.reshape(-1, group_size) + scales = scales.reshape(-1, 1) + if zero_points is not None: + zp = zero_points.reshape(-1, 1) + else: + zp = torch.zeros([], dtype=torch.int32, device=scales.device) + w_dq = w_int8_grouped.sub(zp).mul(scales).reshape_as(w_int8).to(output_dtype) + return w_dq + + +quantized_decomposed_lib.define( + "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " + "int quant_min, int quant_max) -> Tensor") + +class FakeQuantPerChannel(torch.autograd.Function): + @staticmethod + def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): + with torch._C._AutoDispatchBelowAutograd(): + if input.dtype == torch.bfloat16: + input = input.to(torch.float32) + if scales.dtype != torch.float32: + scales = scales.to(torch.float32) + if zero_points.dtype != torch.int32: + zero_points = zero_points.to(torch.int32) + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim)) + unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) + unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims) + temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points + out = (torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points) * unsqueeze_scales + mask = torch.logical_and((temp >= quant_min), (temp <= quant_max)) + + ctx.save_for_backward(mask) + return out + + @staticmethod + def backward(ctx, gy): + mask, = ctx.saved_tensors + 
return gy * mask, None, None, None, None, None + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "AutogradCPU") +def fake_quant_per_channel( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return FakeQuantPerChannel.apply(input, scales, zero_points, axis, quant_min, quant_max) + +@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta") +def fake_quant_per_channel_meta( + input: torch.Tensor, + scales: torch.Tensor, + zero_points: torch.Tensor, + axis: int, + quant_min: int, + quant_max: int, +) -> torch.Tensor: + return torch.empty_like(input) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..728506037b558c8798477a8d98b7191cb9fed3f0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -0,0 +1,1170 @@ +import torch +from torch.fx import map_arg, Node +from torch.fx.graph import Graph +import torch.nn as nn +import torch.nn.functional as F +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized.reference as nnqr +from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule +from torch.fx import GraphModule +from .utils import ( + collect_producer_nodes, + get_linear_prepack_op_for_dtype, + get_new_attr_name_with_prefix, + get_qconv_prepack_op, + graph_module_from_producer_nodes, +) +from ..utils import _parent_name +from ..qconfig import QConfigAny +from ..quantization_mappings import 
get_quantized_operator +from .utils import create_node_from_old_node_preserve_meta +from typing import Dict, Tuple, Type, List, Callable, Any, Union, Set, Optional +import operator + +QOP_TO_ARG_NAMES_TO_SKIP = { + torch._ops.ops.quantized.hardswish: ['inplace'], + torch._ops.ops.quantized.elu: ['inplace'], + torch._ops.ops.quantized.dropout: ['inplace'], + torch._ops.ops.quantized.instance_norm: + ['running_mean', 'running_var', 'use_input_stats', 'momentum'], +} + +def _is_node_in_list(node, modules, func_list, method_list, module_type_list): + is_call_function = node.op == "call_function" and node.target in func_list + is_call_method = node.op == "call_method" and node.target in method_list + is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list + return is_call_function, is_call_method, is_call_module + +def is_fixed_qparams_node(node, modules): + func_list = [ + torch.nn.functional.hardsigmoid, + torch.nn.functional.sigmoid, + torch.sigmoid, + torch.tanh, + ] + method_list = [ + "hardsigmoid", + "hardsigmoid_", + "sigmoid", + "sigmoid_", + "tanh", + "tanh_", + ] + module_type_list = [ + torch.nn.Hardsigmoid, + torch.nn.Sigmoid, + torch.nn.Tanh, + torch.nn.Softmax, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + +def is_default_node(node, modules): + func_list = [ + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.instance_norm, + torch.nn.functional.layer_norm, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + method_list: List[Any] = [] + module_type_list = [ + nnqr.ConvTranspose1d, + nnqr.ConvTranspose2d, + nnqr.ConvTranspose3d, + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.ao.nn.intrinsic.BNReLU2d, + 
torch.ao.nn.intrinsic.BNReLU3d, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + +def is_copy_node(node, modules): + func_list = [ + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + operator.floordiv, + # F.channel_shuffle and torch.channel_shuffle are essentially the same thing + # so we only need to put one of them here + torch.channel_shuffle, + ] + method_list = [ + "clamp", + "mean", + "relu", + "relu_", + ] + module_type_list = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.nn.ReLU6, + torch.nn.ChannelShuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + +def is_general_tensor_shape_node(node, modules): + func_list = [ + torch.narrow, + torch.transpose, + torch.repeat_interleave, + torch.squeeze, + torch.stack, + torch.unsqueeze, + torch.nn.functional.pixel_shuffle, + torch.nn.functional.pixel_unshuffle, + ] + method_list = [ + "contiguous", + "detach", + "detach_", + "permute", + "repeat", + "repeat_interleave", + "reshape", + "resize_", + "shape", + "size", + "squeeze", + "squeeze_", + "transpose", + "unsqueeze", + "unsqueeze_", + "view", + ] + module_type_list = [ + torch.nn.Identity, + torch.nn.PixelShuffle, + torch.nn.PixelUnshuffle, + ] + return _is_node_in_list(node, modules, func_list, method_list, 
module_type_list) + +def is_other_node(node, modules): + func_list = [ + torch.cat, + ] + method_list: List[Any] = [] + module_type_list: List[Any] = [] + return _is_node_in_list(node, modules, func_list, method_list, module_type_list) + +def is_special_pattern_node(node, modules): + res_function, res_method, res_module = False, False, False + for checker in [is_fixed_qparams_node, is_default_node, is_copy_node, is_general_tensor_shape_node, is_other_node]: + is_call_function, is_call_method, is_call_module = checker(node, modules) + res_function = res_function or is_call_function + res_method = res_method or is_call_method + res_module = res_module or is_call_module + return res_function, res_method, res_module + +def is_dequantize_node(node): + return isinstance(node, Node) and node.op == "call_method" and node.target == "dequantize" + +def is_getattr_tensor_metadata_node(node): + return node.op == "call_function" and \ + node.target == getattr and \ + node.args[1] in ["shape"] + +def is_get_tensor_info_node(node): + return node.op == "call_method" and \ + node.target in ["shape", "size"] + +def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]): + """ + Return True if the op is configured with a None qconfig, False otherwise. + Note: maybe need to generalize this to also check for the dtype, and we + only lower when dtype matches, but right now fbgemm/qnnpack only support + a single dtype, so it is OK for now. 
+ """ + return op.name in qconfig_map and qconfig_map[op.name] is None + +# Mapping from reference module class to the replacement static quantized module class for lowering +STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = { + nnqr.Linear: nnq.Linear, + nnqr.Conv1d: nnq.Conv1d, + nnqr.Conv2d: nnq.Conv2d, + nnqr.Conv3d: nnq.Conv3d, +} + +# Mapping from reference module class to the replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = { + nnqr.Linear: nnqd.Linear, + nnqr.GRUCell: nnqd.GRUCell, + nnqr.LSTMCell: nnqd.LSTMCell, + nnqr.RNNCell: nnqd.RNNCell, + nnqr.LSTM: nnqd.LSTM, + nnqr.GRU: nnqd.GRU, +} + +# Mapping from reference module class to the replacement weight only quantized module class for lowering +# TODO: correct the namespace for these modules +WEIGHT_ONLY_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = { + nnqr.Embedding: nnq.Embedding, + nnqr.EmbeddingBag: nnq.EmbeddingBag, +} + +# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge +# _lower_static_weighted_ref_module and special_pattern_replacement +SPECIAL_PATTERN_LOWER_MODULE_MAP = { + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nnqr.ConvTranspose1d: nnq.ConvTranspose1d, + nnqr.ConvTranspose2d: nnq.ConvTranspose2d, + nnqr.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.LeakyReLU: nnq.LeakyReLU, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.Dropout: nnq.Dropout, + nn.Softmax: nnq.Softmax, + nn.PReLU: nnq.PReLU, + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], 
Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = { + nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU), + # TODO: LinearLeakyReLU is registered as global but it is only fused and + # lowered when ondnn's backend config is used. Maybe need to separate + # registration and lowering functions for different backends in the future. + nni.LinearLeakyReLU: (nnqr.Linear, nniq.LinearLeakyReLU), + nni.LinearTanh: (nnqr.Linear, nniq.LinearTanh), + nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d), + nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d), + nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d), +} + +# The difference between STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP and STATIC_LOWER_FUSED_MODULE_MAP: +# The refer node inside STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP has 2 inputs. +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement static quantized module class for lowering +STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = { + nni.ConvAdd2d: (nnqr.Conv2d, nniq.ConvAdd2d), + nni.ConvAddReLU2d: (nnqr.Conv2d, nniq.ConvAddReLU2d), +} + +# Mapping from fused module class to a 2-tuple of: +# 1) The inner reference module class +# 2) The replacement dynamic quantized module class for lowering +DYNAMIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[nn.Module]]] = { + nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU), +} + +# Mapping from a functional to lower to a 2-tuple of +# 1) The quantized version of the op +# 2) The quantized version of the op fused with relu, if it exists, else None +STATIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Optional[Callable]]] = { + F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu), + F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu), + F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu), + F.conv3d: 
(torch.ops.quantized.conv3d, torch.ops.quantized.conv3d_relu), + F.conv_transpose1d: (torch.ops.quantized.conv_transpose1d, None), + F.conv_transpose2d: (torch.ops.quantized.conv_transpose2d, None), + F.conv_transpose3d: (torch.ops.quantized.conv_transpose3d, None), +} + +WEIGHT_PREPACK_OPS: Set[Callable] = { + torch._ops.ops.quantized.linear_prepack, + torch._ops.ops.quantized.linear_prepack_fp16, + torch._ops.ops.quantized.conv1d_prepack, + torch._ops.ops.quantized.conv2d_prepack, + torch._ops.ops.quantized.conv3d_prepack, + torch.ops.quantized.conv_transpose1d_prepack, + torch.ops.quantized.conv_transpose2d_prepack, + torch.ops.quantized.conv_transpose3d_prepack, +} + +# Mapping from a functional to a dictionary, where the key is a 2-tuple of +# (input_activation_dtype, weight_dtype) and the value is a 2-tuple of +# 1) The dynamically quantized version of the op +# 2) The dynamically quantized version of the op fused with relu, if it exists, else None +DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]] = { + F.linear: { + (torch.quint8, torch.qint8): (torch.ops.quantized.linear_dynamic, + torch.ops.quantized.linear_relu_dynamic), + (torch.float16, torch.float16): (torch.ops.quantized.linear_dynamic_fp16, + torch.ops.quantized.linear_relu_dynamic_fp16) + }, + # dynamic conv + relu is not available yet + F.conv1d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None), + }, + F.conv2d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None), + }, + F.conv3d: { + (torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None), + }, +} + +CONV_FUNCTIONAL_OPS: Set[Callable] = { + F.conv1d, + F.conv2d, + F.conv3d, +} + +CONV_TRANSPOSE_FUNCTIONAL_OPS: Set[Callable] = { + F.conv_transpose1d, + F.conv_transpose2d, + F.conv_transpose3d, +} + +# TODO: add tests for lowering these ops +QBIN_OP_MAPPING: Dict[Union[Callable, str], Callable] = { + 
operator.add: torch.ops.quantized.add, + torch.add: torch.ops.quantized.add, + operator.mul: torch.ops.quantized.mul, + operator.matmul: torch.ops.quantized.matmul, + torch.mul: torch.ops.quantized.mul, + torch.matmul: torch.ops.quantized.matmul, +} +QBIN_RELU_OP_MAPPING: Dict[Union[Callable, str], Callable] = { + operator.add: torch.ops.quantized.add_relu, + torch.add: torch.ops.quantized.add_relu, + operator.mul: torch.ops.quantized.mul_relu, + torch.mul: torch.ops.quantized.mul_relu, +} + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and \ + isinstance(getattr(self, attr_name), torch._C.ScriptObject): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + +def _load_packed_weight(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + +def fold_weight( + quantized_model: GraphModule, + node_name_to_scope: Dict[str, Tuple[str, type]] +) -> GraphModule: + """ + Trace back from the weight node util we hit getattr, reconstruct the + graph module with the traced nodes and run the graph module to pack the + weight. then replace the original chain of ops with the packed weight. 
+ """ + packed_weights = {} + # map from folded node name to the prepacked weight name + folded_nodes = {} + # get packed weights + for node in quantized_model.graph.nodes: + if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS: + nodes_to_fold = collect_producer_nodes(node) + if nodes_to_fold is not None: + for node_to_fold in nodes_to_fold: + folded_nodes[node_to_fold.name] = node + + prepacking_module = graph_module_from_producer_nodes( + quantized_model, nodes_to_fold) + packed_weight = prepacking_module() + packed_weights[node.name] = packed_weight + + # remove folded nodes and replace the prepacking node with getattr + folded_graph = Graph() + env: Dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + for node in quantized_model.graph.nodes: + prepack_node = folded_nodes.get(node.name, None) + if prepack_node is node: + packed_weight = packed_weights[node.name] + # add a prepacked attribute to root + op_node = next(iter(prepack_node.users)) + module_path, _ = node_name_to_scope[op_node.name] + get_new_packed_weight_name = \ + get_new_attr_name_with_prefix(module_path + '_packed_weight_') + packed_weight_name = get_new_packed_weight_name(quantized_model) + setattr(quantized_model, packed_weight_name, packed_weight) + # replace prepack node with a getattr node + env[node.name] = folded_graph.create_node( + 'get_attr', packed_weight_name, (), {}) + elif prepack_node is not None: + # remove the foled node + continue + else: + # copy other nodes + env[node.name] = folded_graph.node_copy(node, load_arg) + + quantized_model = GraphModule(quantized_model, folded_graph) + quantized_model._register_state_dict_hook(_save_packed_weight) + quantized_model._register_load_state_dict_pre_hook(_load_packed_weight, with_module=True) + return quantized_model + +def _get_module(node: Node, modules: Dict[str, nn.Module]) -> Optional[nn.Module]: + """ + Return the `torch.nn.Module` that corresponds to the specified node's 
target. + If no such node exists, return None. + """ + if node.op == "call_module" and str(node.target) in modules: + return modules[str(node.target)] + else: + return None + +def _match_static_pattern( + node: Node, + modules: Dict[str, nn.Module], + qconfig_map: Dict[str, QConfigAny], + matching_modules_or_ops: List[Callable], + dequantize_node_arg_indices: List[int] +) -> Union[Tuple[Node, Node, Node], Tuple[None, None, None]]: + """ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 3-tuple of: + 1) q_node: the quantize node, + 2) relu_node: a relu node wrapping the ref_node, and + 3) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 3-tuple of (None, None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. + dequantize_node_arg_indices: A list of indices in the reference node args where dequantize + nodes may be present. An empty list means skipping the check for dequantize nodes. 
+ """ + SKIP_LOWERING_VALUE = (None, None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + assert isinstance(ref_node, Node) + + # Handle cases where the node is wrapped in a ReLU + if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or\ + (ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU): + relu_node = ref_node + ref_node = relu_node.args[0] + assert isinstance(ref_node, Node) + else: + relu_node = None + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass(matching_modules_or_ops[0], nn.Module): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + expected_op = "call_function" + match_key = ref_node.target + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Match dequantize node(s). 
Both of the following conditions must pass: + # (1) All `torch.fx.Node`s at the matching indices must be a dequantize node + # (2) There must be at least one dequantize node + matched_dequantize = False + for i in dequantize_node_arg_indices: + assert i < len(ref_node.args), \ + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + arg = ref_node.args[i] + if is_dequantize_node(arg): + matched_dequantize = True + elif isinstance(arg, Node): + return SKIP_LOWERING_VALUE + if not matched_dequantize: + return SKIP_LOWERING_VALUE + + return (q_node, relu_node, ref_node) + +def _match_static_pattern_with_two_inputs( + node: Node, + modules: Dict[str, nn.Module], + qconfig_map: Dict[str, QConfigAny], + matching_modules_or_ops: List[Callable] +) -> Union[Tuple[Node, Node], Tuple[None, None]]: + """ + (dequantize \ + Match the pattern (dequantize - ref node - quantize) against the node provided. + + If there is a match, return a 2-tuple of: + 1) q_node: the quantize node, + 2) ref_node: a reference module or functional node to replace with its quantized counterpart + Otherwise, if there is no match, return a 2-tuple of (None, None). + + Parameters: + node: The `torch.fx.Node` to match against. + modules: A mapping from node names to modules in the model graph, used for module lookup. + qconfig_map: A mapping from node names to the qconfigs associated with the nodes. + If the corresponding qconfig for the reference node is None, then return no match. + matching_modules_or_ops: Either a list of functions or a list of `torch.nn.Module`s. + If the reference node is not in this list, then return no match. 
+ """ + SKIP_LOWERING_VALUE = (None, None) + + # Match quantize node + if node.op != "call_function" or node.target != torch.quantize_per_tensor: + return SKIP_LOWERING_VALUE + q_node = node + ref_node = q_node.args[0] + assert isinstance(ref_node, Node) + + if should_skip_lowering(ref_node, qconfig_map): + return SKIP_LOWERING_VALUE + + # Match reference module or functional + if isinstance(matching_modules_or_ops[0], type) and issubclass(matching_modules_or_ops[0], nn.Module): + expected_op = "call_module" + match_key = type(_get_module(ref_node, modules)) + else: + # This pass only support op of "call_module" + return SKIP_LOWERING_VALUE + + if ref_node.op != expected_op or match_key not in matching_modules_or_ops: + return SKIP_LOWERING_VALUE + + # Check ref_node has 2 input nodes, both are dq node. + if len(ref_node.args) != 2: + return SKIP_LOWERING_VALUE + for i in range(len(ref_node.args)): + arg = ref_node.args[i] + if not is_dequantize_node(arg): + return SKIP_LOWERING_VALUE + + return (q_node, ref_node) + +def _lower_static_weighted_ref_module( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny]): + """ + Traverse the graph and find dequantize - ref module - quantize patterns + and replace them with the quantized version of the ref module. 
+ """ + modules = dict(model.named_modules(remove_duplicate=False)) + nodes = list(model.graph.nodes) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_MODULE_MAP.keys()) + list(STATIC_LOWER_FUSED_MODULE_MAP.keys()) + (q_node, relu_node, ref_node) = _match_static_pattern( + n, modules, qconfig_map, matching_modules, dequantize_node_arg_indices=[0]) # type: ignore[arg-type] + if q_node is None: + continue + assert ref_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + continue + else: + q_class = STATIC_LOWER_MODULE_MAP[ref_class] + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + assert len(ref_node.args) == 1 + dq_node = ref_node.args[0] + assert isinstance(dq_node, Node) + ref_node.replace_input_with(dq_node, dq_node.args[0]) + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + 
model.graph.erase_node(zero_point_node) + +def _lower_static_weighted_ref_module_with_two_inputs( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny]): + """ + Traverse the graph and find patterns + dequantize dequantize + \\ // + ref module + \\ + quantize + and replace them with the quantized version of the ref module. + """ + modules = dict(model.named_modules(remove_duplicate=False)) + nodes = list(model.graph.nodes) + for n in model.graph.nodes: + # (dequantize \ + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + matching_modules = list(STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP.keys()) + (q_node, ref_node) = _match_static_pattern_with_two_inputs( + n, modules, qconfig_map, matching_modules) # type: ignore[arg-type] + if q_node is None: + continue + assert ref_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + ref_module = _get_module(ref_node, modules) + ref_class = type(ref_module) + assert isinstance(scale_node, Node) + assert isinstance(zero_point_node, Node) + assert issubclass(ref_class, nn.Module) + + # Step 1: Change this pattern to use the corresponding quantized module + # For fused modules, we also check whether the inner module is a reference module + # If so, we replace the entire fused module with the corresponding quantized module + if ref_class in STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP: + inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[ref_class] + if type(ref_module[0]) != inner_ref_class: # type: ignore[index] + continue + else: + continue + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + q_module = q_class.from_reference(ref_module, output_scale, output_zero_point) + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, q_module) + + # Step 2: Reroute around dq_node, and remove q_node and its args + 
assert len(ref_node.args) == 2 + for arg in ref_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + assert isinstance(dq_node, Node) + ref_node.replace_input_with(dq_node, dq_node.args[0]) + + q_node.replace_all_uses_with(ref_node) + model.graph.erase_node(q_node) + model.graph.erase_node(scale_node) + model.graph.erase_node(zero_point_node) + +def _lower_dynamic_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns + and replace them with the dynamically quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or \ + type(named_modules[str(n.target)]) not in \ + set(DYNAMIC_LOWER_MODULE_MAP.keys()).union( + set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())): + continue + ref_node = n + dq_node = ref_node.args[0] + if dq_node.op != "call_method" or dq_node.target != "dequantize": + continue + + input_dynamic_q_node = dq_node.args[0] + + if input_dynamic_q_node.op != "call_function" or \ + input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic: + continue + + activation_dtype = input_dynamic_q_node.args[1] + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP: + inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class] + if type(ref_module[0]) != inner_ref_class: + continue + else: + q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment] + # TODO: maybe define a WeightedDynamicallyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[attr-defined] + + # replace reference module with dynamically quantized module + parent_name, module_name = 
_parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + ref_node.replace_input_with(dq_node, input_dynamic_q_node.args[0]) + +def _lower_weight_only_weighted_ref_module(model: GraphModule): + """ + Traverse the graph and find ref_module patterns + and replace them with the weight only quantized version of the ref module. + """ + named_modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + if n.op != "call_module" or \ + type(named_modules[str(n.target)]) not in \ + set(WEIGHT_ONLY_LOWER_MODULE_MAP.keys()): + continue + ref_node = n + ref_module = named_modules[str(ref_node.target)] + ref_class = type(ref_module) + q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class) + # TODO: WeightedQuantizedModule is currently assuming static quant apis + # with output_scale, output_zero_point in from_reference, we may want to + # relax that, or rename this + # TODO: maybe define a WeightedWeightOnlyQuantizedModule + q_module = q_class.from_reference(ref_module) # type: ignore[union-attr] + + # replace reference module with dynamically quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(named_modules[parent_name], module_name, q_module) + +def _lower_static_weighted_ref_functional( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny]): + """ + Traverse the graph and replace functional reference patterns with their quantized versions. 
+ """ + modules = dict(model.named_modules(remove_duplicate=False)) + nodes = list(model.graph.nodes) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - functional op - quantize) + matching_ops = list(STATIC_LOWER_FUNCTIONAL_MAP.keys()) + (q_node, relu_node, func_node) = _match_static_pattern( + n, modules, qconfig_map, matching_ops, dequantize_node_arg_indices=[0, 1]) + if q_node is None: + continue + assert func_node is not None + (_, output_scale_node, output_zp_node, _) = q_node.args + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + assert isinstance(output_zp_node, Node) + assert isinstance(input_dq_node, Node) + assert isinstance(weight_dq_node, Node) + quantized_weight = weight_dq_node.args[0] + assert isinstance(quantized_weight, Node) + if quantized_weight.op != "call_function" or\ + quantized_weight.target not in (torch.quantize_per_tensor, torch.quantize_per_channel): + continue + + # Step 1: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + if func_node.target == F.linear: + weight_dtype = quantized_weight.args[-1] + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + elif func_node.target in CONV_TRANSPOSE_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) # type: 
ignore[arg-type] + # For conv_transpose1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv_transpose1d: + # Note prepack_args[5] is groups. + for i in [2, 3, 4, 6]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + # swap dilation and groups + # prepack op has arguments: {w, b, stride, padding, output_padding, dilation, groups} + # transposed conv op has arguments: {x, w, b, stride, padding, output_padding, groups, dilation} + if (len(prepack_args) > 6): + prepack_args[5], prepack_args[6] = prepack_args[6], prepack_args[5] + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(output_scale_node): + # kwargs of the func node are needed for prepack op (i.e., quantized::linear_prepack) + # They are not needed for compute op (i.e., quantized::linear) + kwargs = func_node.kwargs + # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias + if func_node.target == F.linear and 'bias' in kwargs: + kwargs = kwargs.copy() + kwargs['B'] = kwargs['bias'] + del kwargs['bias'] + packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), kwargs) + + # Step 2: Replace reference pattern with the corresponding quantized op + (q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target] # type: ignore[index] + # conv_transpose does not support fusion with relu yet. 
q_relu_func is None in such cases + if q_relu_func is not None: + func_node.target = q_relu_func if relu_node is not None else q_func + else: + func_node.target = q_func + func_node.args = (input_dq_node.args[0], packed_weight, output_scale_node, output_zp_node) + # kwargs for func_node has been moved to kwargs for prepack op + func_node.kwargs = {} + q_node.replace_all_uses_with(func_node) + # Move func_node after output_zp_node in the graph + output_zp_node.append(func_node) + + # Clean up: Remove quantize node, and the relu node if it exists + model.graph.erase_node(q_node) + if relu_node is not None and q_relu_func is not None: + model.graph.erase_node(relu_node) + +def _lower_dynamic_weighted_ref_functional( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny]): + """ + Traverse the graph and replace functional reference patterns with their dynamically + quantized versions. + Examples: + quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic + to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16 + """ + modules = dict(model.named_modules(remove_duplicate=False)) + nodes = list(model.graph.nodes) + # we want to search in reserved order so that we can match the larger patterns first + # e.g. we want to match linear - relu before linear. 
+ for n in reversed(model.graph.nodes): + + # Step 0: Find nodes that match this pattern + # (quantize_per_tensor_dynamic - dequantize - dynamically quantized op) + # We search for the pattern backwards, starting with the quantize node + # Quantize node args: (func, scale, zp, dtype) + func_node = n + # Handle cases where the functional op is wrapped in a ReLU + if func_node.op == "call_function" and func_node.target == F.relu or \ + func_node.op == "call_module" and \ + type(modules[str(func_node.target)]) == torch.nn.ReLU: + relu_node = func_node + func_node = relu_node.args[0] + else: + relu_node = None + if should_skip_lowering(func_node, qconfig_map): + continue + # Linear args: (dequantized inputs, dequantized weights[, bias]) + # Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups]) + if func_node.op != "call_function" or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP: + continue + (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args + if input_dq_node.op != "call_method" or input_dq_node.target != "dequantize" or \ + weight_dq_node.op != "call_method" or weight_dq_node.target != "dequantize": + continue + + input_dynamic_q_node = input_dq_node.args[0] + + if input_dynamic_q_node.op != "call_function" or \ + input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic: + continue + + reduce_range_node = None + (pattern_input, activation_dtype, reduce_range_node) = input_dynamic_q_node.args + is_fp16 = activation_dtype == torch.float16 + is_int8 = activation_dtype in [torch.quint8, torch.qint8] + if not is_int8 and not is_fp16: + continue + + quantized_weight = weight_dq_node.args[0] + weight_dtype = quantized_weight.args[-1] + + # Step 1: Try to select reference pattern with the corresponding quantized op + dynamic_quant_dtype_key = (activation_dtype, weight_dtype) + if dynamic_quant_dtype_key not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]: + print(f"Didn't find dtype combination 
{dynamic_quant_dtype_key} during " + f"dynamic quantized op lowering for {func_node.target}") + continue + (q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][dynamic_quant_dtype_key] + + if q_func is None or q_relu_func is None: + print("Didn't find corresponding quantized function or quantized relu function " + f"for {func_node.target}, {dynamic_quant_dtype_key}") + continue + + # Step 2: Replace quantized weights with packed weights, which will be folded later + # Use the right prepack op and prepare the corresponding args + # Linear prepack args: (quantized weights[, bias]) + # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) + prepack_args = [quantized_weight] + remaining_func_args + if func_node.target == F.linear: + prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) + elif func_node.target in CONV_FUNCTIONAL_OPS: + prepack_op = get_qconv_prepack_op(func_node.target) + # For conv1d, the stride, padding, and dilation args may be ints, + # in which case we need to convert them to tuples + if func_node.target == F.conv1d: + for i in [2, 3, 4]: + if len(prepack_args) > i and isinstance(prepack_args[i], int): + prepack_args[i] = (prepack_args[i],) + else: + raise ValueError(f"Lowering is not supported for op '{func_node.target}'") + with model.graph.inserting_before(func_node): + packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {}) + + # Step 3: Replace reference pattern with the corresponding quantized op + func_node.target = q_relu_func if relu_node is not None else q_func + if is_int8: + func_node.args = (pattern_input, packed_weight, reduce_range_node) + else: + func_node.args = (pattern_input, packed_weight) + + if relu_node is not None: + relu_node.replace_all_uses_with(func_node) + + # Step 4: Remove the relu node if it exists + if relu_node is not None: + model.graph.erase_node(relu_node) + +def _lower_quantized_binary_op( + model: GraphModule, + 
qconfig_map: Dict[str, QConfigAny]): + binary_ops_to_lower: List[Callable] = [operator.add, torch.add, operator.mul, torch.mul, torch.matmul] + modules = dict(model.named_modules(remove_duplicate=False)) + for n in model.graph.nodes: + # Step 0: Find nodes that match this pattern (dequantize - ref module - quantize) + (q_node, relu_node, bop_node) = _match_static_pattern( + n, modules, qconfig_map, binary_ops_to_lower, dequantize_node_arg_indices=[0, 1]) + if q_node is None: + continue + assert bop_node is not None + (_, scale_node, zero_point_node, _) = q_node.args + + # Step 1: Remove dequant nodes + num_dq_nodes = 0 + for arg in bop_node.args: + if not is_dequantize_node(arg): + continue + dq_node = arg + assert isinstance(dq_node, Node) + dn_input = dq_node.args[0] + bop_node.replace_input_with(dq_node, dn_input) + num_dq_nodes += 1 + assert num_dq_nodes > 0 + + # Step 2: Swap binary op to quantized binary op + assert bop_node.target in QBIN_OP_MAPPING + binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING + qbin_op = binop_to_qbinop[bop_node.target] + # prepare the args for quantized binary op + # (x, y) + qop_node_args = list(bop_node.args) + # (x, y, scale, zero_point) + # add scale and zero_point arguments for Tensor - Tensor operation + if num_dq_nodes == 2: + qop_node_args.extend([scale_node, zero_point_node]) + # insert a call to quantized binary op and remove the original binary op + with model.graph.inserting_after(q_node): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, + ("call_function", qbin_op, tuple(qop_node_args), {}), + bop_node) + q_node.replace_all_uses_with(qop_node) + + # Step 3: Remove quantize node, binary op node, and relu node if any + model.graph.erase_node(q_node) + if relu_node is not None: + model.graph.erase_node(relu_node) + model.graph.erase_node(bop_node) + +def special_pattern_replacement(model: GraphModule): + modules = dict(model.named_modules(remove_duplicate=False)) + for n 
in model.graph.nodes: + q_node = n + is_quantize = q_node.target == torch.quantize_per_tensor + is_to_fp16 = q_node.op == "call_method" and q_node.target == "to" and \ + len(q_node.args) == 2 and q_node.args[1] == torch.float16 + if not (is_quantize or is_to_fp16): + continue + ref_node = q_node.args[0] + # get output scale/zero_point/dtype from the quantize node + # ref_node, scale_node, zero_point_node, dtype = q_node.args + # TODO: add safety checks that users for the ref_node and dq_node needs to be one + is_call_function, is_call_method, is_call_module = is_fixed_qparams_node(ref_node, modules) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? (bc-breaking if error out) + # warnings.warn( + # "Only reference patterns are currently supported for {dtype} dtype with {op} op" + # "".format(dtype=dtypes, op=ref_node)) + continue + + is_call_function, is_call_method, is_call_module = is_default_node(ref_node, modules) + if is_to_fp16 and (is_call_function or is_call_method or is_call_module): + # TODO: add a warning or error out here? 
(bc-breaking if error out) + continue + + # This check includes all supported ops + is_call_function, is_call_method, is_call_module = is_special_pattern_node(ref_node, modules) + if not (is_call_module or is_call_function or is_call_method): + continue + assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0 + dq_node_or_nodes = ref_node.args[0] if len(ref_node.args) > 0 else next(iter(ref_node.kwargs.values())) + assert isinstance(dq_node_or_nodes, (Node, tuple, list)) + is_dequantize = False + if isinstance(dq_node_or_nodes, Node): + is_dequantize = dq_node_or_nodes.op == 'call_method' and \ + dq_node_or_nodes.target == 'dequantize' + elif isinstance(dq_node_or_nodes, (tuple, list)): + is_dequantize = all( + x.op == 'call_method' and x.target == 'dequantize' + for x in dq_node_or_nodes) + + if not is_dequantize: + continue + + # TODO: enable we have patterns that needs to swap the modules + if is_call_module: + ref_module = modules[ref_node.target] + if type(ref_module) in SPECIAL_PATTERN_LOWER_MODULE_MAP and is_quantize: + qmodule_cls = SPECIAL_PATTERN_LOWER_MODULE_MAP.get(type(ref_module)) + scale_node = q_node.args[1] + zero_point_node = q_node.args[2] + output_scale = getattr(model, scale_node.target) + output_zero_point = getattr(model, zero_point_node.target) + + qmodule = qmodule_cls.from_reference(ref_module, output_scale, output_zero_point) # type:ignore[union-attr] + # replace reference module with quantized module + parent_name, module_name = _parent_name(ref_node.target) + setattr(modules[parent_name], module_name, qmodule) + + # reroute around dq node: + dq_nodes: List[Node] = [] + if isinstance(dq_node_or_nodes, Node): + dq_nodes = [dq_node_or_nodes] + elif isinstance(dq_node_or_nodes, (tuple, list)): + dq_nodes = list(dq_node_or_nodes) + + for dq_node in dq_nodes: + dn_input = dq_node.args[0] + ref_node.replace_input_with(dq_node, dn_input) + + # store q node args + qnode_qparams = list(q_node.args)[1:] + # replace uses of q node with input and 
remove q node + q_node_input = q_node.args[0] + q_node.replace_all_uses_with(q_node_input) + model.graph.erase_node(q_node) + + is_call_function, is_call_method, is_call_module = is_default_node(ref_node, modules) + if is_call_function: + # pass scale/zer_point arguments from quantize_per_tensor to the default node operator + # insert an op after the zero_point node so that the scale/zero_point + # nodes are is available + qop = get_quantized_operator(ref_node.target) + args = list(ref_node.args) + kwargs = dict(ref_node.kwargs) + if qop in QOP_TO_ARG_NAMES_TO_SKIP: + args_to_skip = QOP_TO_ARG_NAMES_TO_SKIP[qop] + for arg in args_to_skip: + if arg in kwargs: + kwargs.pop(arg) + kwargs["output_scale"] = qnode_qparams[0] + kwargs["output_zero_point"] = qnode_qparams[1] + with model.graph.inserting_after(qnode_qparams[1]): + qop_node = create_node_from_old_node_preserve_meta( + model.graph, + ("call_function", qop, tuple(args), kwargs), + ref_node) + ref_node.replace_all_uses_with(qop_node) + model.graph.erase_node(ref_node) + else: + # remove scale/zero_point node for quantize node + for n in qnode_qparams: + if isinstance(n, Node): + model.graph.erase_node(n) + + return model + +def _lower_getattr_tensor_metadta_op(model: GraphModule): + """ Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if is_getattr_tensor_metadata_node(n): + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + +def _lower_get_tensor_info_op(model: GraphModule): + """ Modified the graph of the model inplace, to skip extra dequantize op before + the general tensor shape ops when possible + """ + for n in model.graph.nodes: + if not is_get_tensor_info_node(n): + continue + maybe_dq = n.args[0] + if maybe_dq.op != "call_method" or 
maybe_dq.target != "dequantize": + continue + # skip the dequantize node + args = list(n.args) + args[0] = n.args[0].args[0] + n.args = tuple(args) + +def _lower_to_native_backend( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny], + node_name_to_scope: Dict[str, Tuple[str, type]] +) -> GraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same + operator signature so they can be lowered with the same function + """ + _lower_static_weighted_ref_module(model, qconfig_map) + _lower_static_weighted_ref_module_with_two_inputs(model, qconfig_map) + _lower_dynamic_weighted_ref_module(model) + _lower_weight_only_weighted_ref_module(model) + _lower_static_weighted_ref_functional(model, qconfig_map) + _lower_dynamic_weighted_ref_functional(model, qconfig_map) + _lower_quantized_binary_op(model, qconfig_map) + _lower_getattr_tensor_metadta_op(model) + _lower_get_tensor_info_op(model) + special_pattern_replacement(model) + model.graph.eliminate_dead_code() + model = fold_weight(model, node_name_to_scope) + model.graph.eliminate_dead_code() + model.recompile() + model.graph.lint() + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccf692dbe228a45f656b81cf190a3fd9e79ce93 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_observer.py @@ -0,0 +1,265 @@ +import torch +from torch.ao.quantization.observer import ObserverBase + + +class ModelReportObserver(ObserverBase): + r"""This observer is used to record additional information regarding keeping track + of S = 
average_batch_activation_range/epoch_activation_range. + + The purpose of this information is to prepare a report to present to users on whether + Dynamic or Static Quantization is more appropriate for their model given the general + distributions of their data. + + Args: + ch_axis (int, optional): The channel axis for which the range and outlier stats are computed + Default: 1 + comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers + Should be between 0 and 1 exclusive + Default: 0.9 + + * :attr:`num_batches_tracked` specifies number of batches passed through the observer + + * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through + + * :attr:`epoch_activation_min` defines the minimum value passed through the observer + + * :attr:`epoch_activation_max` defines the maximum value passed through the observer + + * :attr:`ch_axis` defines the channel being used to compute per channel min max stats + + * :attr:`min_val` defines the per channel minimum values passed through + + * :attr:`max_val` defines the per channel maximum values passed through + + * :attr:`comp_percentile` defines comparison percentile to find outliers + + * :attr:`average_percentile_ratio` defines the per channel average percentile ratios + + * :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel + + * :attr:`constant_channels` defines the number of batches that aren't constant channels per channel + + Note: this tool is meant for FX Graph Mode Quantization + """ + + epoch_activation_min: torch.Tensor + epoch_activation_max: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + comp_percentile: torch.Tensor + average_percentile_ratio: torch.Tensor + percentile_batches_tracked: torch.Tensor + constant_channels: torch.Tensor + + def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9): + super().__init__(torch.qint8) + 
self.num_batches_tracked = 0 + + # keep track of the min and mix of the range for average batch and epoch as a whole + self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0)) + self.register_buffer("epoch_activation_min", torch.tensor(float("inf"))) + self.register_buffer("epoch_activation_max", torch.tensor(float("-inf"))) + + # keep track of per channel min max information using the given channel + self.ch_axis: int = ch_axis + self.register_buffer("min_val", torch.tensor([])) + self.register_buffer("max_val", torch.tensor([])) + + # keep track of percentile ratio information per channel + self.register_buffer("comp_percentile", torch.tensor([comp_percentile])) + self.register_buffer("average_percentile_ratio", torch.tensor([])) + self.register_buffer("percentile_batches_tracked", torch.tensor([])) + self.register_buffer("constant_channels", torch.tensor([])) + + def forward(self, x): + x_copy = x.detach() # avoid keeping autograd tape + x_copy = x_copy.to(self.epoch_activation_min.dtype) + + x_copy = self._calculate_range_stats(x_copy) + x_copy = self._calculate_min_max_stats(x_copy) + x_copy = self._calculate_percentile_stats(x_copy) + + # return the passed in the value + return x + + def _calculate_range_stats(self, x_copy): + r"""Calculates and stores range stats with forward values. 
+ + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the min, max values of the data + min_val_cur, max_val_cur = torch.aminmax(x_copy) + + # calculate new epoch range values + epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur) + epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur) + + self.epoch_activation_min.copy_(epoch_min_val) + self.epoch_activation_max.copy_(epoch_max_val) + + # calculate the average batch activation range + current_batch_range = max_val_cur - min_val_cur + new_range = ( + self.average_batch_activation_range * self.num_batches_tracked + + current_batch_range + ) / (self.num_batches_tracked + 1) + + self.average_batch_activation_range = new_range + self.num_batches_tracked += 1 # new batch was processed + + return x_copy + + def _calculate_min_max_stats(self, x_copy): + r"""Calculates and stores the per_channel min, max stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the current min and max vals + min_val = self.min_val + max_val = self.max_val + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x_copy + + def 
_calculate_percentile_stats(self, x_copy): + r"""Calculates and stores the per_channel percentile stats with forward values. + Does calculation based on channel axis: self.ch_axis + + Args + x_copy: A copy of the forward data + + Returns the passed in x_copy + """ + # get the dimension of the copy + x_dim = x_copy.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x_copy.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + y = y.to(dtype=self.min_val.dtype, device="cpu") + + # find the percentile values along the axis + # we want both 100th percentile and comp_percentile + # we also want to find 0th quartile to see if we have constant channel + quantiles_list = [0, self.comp_percentile, 1.00] + quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype) + + # find the quantiles + desired_quantiles = torch.quantile(y, quantiles_to_find, dim=self.ch_axis, interpolation="lower") + zero_quantile = desired_quantiles[0] + comp_quantile = desired_quantiles[1] + hundreth_quartile = desired_quantiles[2] + + # if any of the channels have 0s, we ignore that channel for this calculation + any_non_zero_quantile_value: torch.Tensor = (comp_quantile != torch.tensor([0])) | (hundreth_quartile != torch.tensor([0])) + any_non_zero_quantile_value = any_non_zero_quantile_value.int() # transform boolean values to int values + + # we also check if we have a constant channel + any_constant_channels: torch.Tensor = (hundreth_quartile - zero_quantile) == torch.tensor([0]) + any_constant_channels = any_constant_channels.int() # transform boolean values to int values + + # possibilities to get nan as an answer + # will ignore any of these three cases with 0s and just not deal with them for now + # case (1) 0 in 
numerator: issue if 0 is largest, all negative, and rest are really negative + # case (2) 0 in denominator: is possible unless case 3, we just ignore + # case (3) 0 in both: not outlier, channel just kinda useless, ignore + + # get the ratio and get rid of nan values + quantile_ratios = hundreth_quartile / comp_quantile + quantile_ratios = torch.nan_to_num(quantile_ratios) + # update averages, remembering to only update if didn't have zeros + ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios + + # if num_batches and average_ratio are not initialized, we want to initialize them + if self.percentile_batches_tracked.shape[0] == 0 or self.average_percentile_ratio.shape[0] == 0: + self.percentile_batches_tracked = torch.zeros_like(any_non_zero_quantile_value) + self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero) + + # also initialize the constant channel var if that is not initialized separately + if self.constant_channels.shape[0] == 0: + self.constant_channels = torch.zeros_like(any_constant_channels) + + # get current num batches and average ratio + num_batches = self.percentile_batches_tracked + average_ratio = self.average_percentile_ratio + + # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches + new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value + new_ratios: torch.Tensor = ((average_ratio * num_batches) + ratio_if_not_zero) / new_number_of_batches + new_ratios = torch.nan_to_num(new_ratios) + + # update the number of non-constant channels + new_constant_count: torch.Tensor = self.constant_channels + any_constant_channels + + # update the values locally + self.percentile_batches_tracked.copy_(new_number_of_batches) + self.average_percentile_ratio.copy_(new_ratios) + self.constant_channels.copy_(new_constant_count) + + return x_copy + + @torch.jit.export + def get_batch_to_epoch_ratio(self): + epoch_activation_range = self.epoch_activation_max - 
self.epoch_activation_min + + if epoch_activation_range == torch.tensor(float(0)): + raise ValueError("Range for Epoch is 0") + elif epoch_activation_range == torch.tensor(float("inf")): + raise ValueError( + "No data has been run through observer or infinity value present" + ) + else: + return self.average_batch_activation_range / epoch_activation_range + + @torch.jit.export + def reset_batch_and_epoch_values(self): + # set all the values back to their original defaults for a new epoch + # keep device + device = self.max_val.device + self.num_batches_tracked = 0 + self.average_batch_activation_range = torch.tensor(float(0), device=device) + self.epoch_activation_min = torch.tensor(float("inf"), device=device) + self.epoch_activation_max = torch.tensor(float("-inf"), device=device) + self.min_val = torch.tensor([], device=device) + self.max_val = torch.tensor([], device=device) + self.average_percentile_ratio = torch.tensor([], device=device) + self.percentile_batches_tracked = torch.tensor([], device=device) + self.constant_channels = torch.tensor([], device=device) + + @torch.jit.export + def calculate_qparams(self): + raise Exception( + "calculate_qparams should not be called for ModelReportObserver" + ) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/custom_config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb2c3a28cb0a589784f899d199b6eab7b2111a8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/custom_config.py @@ -0,0 +1,419 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.quant_type import 
QuantType, _quant_type_from_str, _get_quant_type_to_str + + +__all__ = [ + "ConvertCustomConfig", + "FuseCustomConfig", + "PrepareCustomConfig", + "StandaloneModuleConfigEntry", +] + + +# TODO: replace all usages with these constants +STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name" +STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class" +FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class" +OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class" +NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name" +NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class" +INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs" +OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs" +PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes" + + +@dataclass +class StandaloneModuleConfigEntry: + # qconfig_mapping for the prepare function called in the submodule, + # None means use qconfig from parent qconfig_mapping + qconfig_mapping: Optional[QConfigMapping] + example_inputs: Tuple[Any, ...] + prepare_custom_config: Optional[PrepareCustomConfig] + backend_config: Optional[BackendConfig] + + +class PrepareCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and + :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`. 
+ + Example usage:: + + prepare_custom_config = PrepareCustomConfig() \ + .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \ + child_prepare_custom_config, backend_config) \ + .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \ + .set_non_traceable_module_names(["module2", "module3"]) \ + .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \ + .set_input_quantized_indexes([0]) \ + .set_output_quantized_indexes([0]) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + def __init__(self): + self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {} + self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {} + self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {} + self.non_traceable_module_names: List[str] = [] + self.non_traceable_module_classes: List[Type] = [] + self.input_quantized_indexes: List[int] = [] + self.output_quantized_indexes: List[int] = [] + self.preserved_attributes: List[str] = [] + + def __repr__(self): + dict_nonempty = { + k: v for k, v in self.__dict__.items() + if len(v) > 0 + } + return f"PrepareCustomConfig({dict_nonempty})" + + def set_standalone_module_name( + self, + module_name: str, + qconfig_mapping: Optional[QConfigMapping], + example_inputs: Tuple[Any, ...], + prepare_custom_config: Optional[PrepareCustomConfig], + backend_config: Optional[BackendConfig]) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_name``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. 
+ """ + self.standalone_module_names[module_name] = \ + StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + return self + + def set_standalone_module_class( + self, + module_class: Type, + qconfig_mapping: Optional[QConfigMapping], + example_inputs: Tuple[Any, ...], + prepare_custom_config: Optional[PrepareCustomConfig], + backend_config: Optional[BackendConfig]) -> PrepareCustomConfig: + """ + Set the configuration for running a standalone module identified by ``module_class``. + + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. + """ + self.standalone_module_classes[module_class] = \ + StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + return self + + def set_float_to_observed_mapping( + self, + float_class: Type, + observed_class: Type, + quant_type: QuantType = QuantType.STATIC) -> PrepareCustomConfig: + """ + Set the mapping from a custom float module class to a custom observed module class. + + The observed module class must have a ``from_float`` class method that converts the float module class + to the observed module class. This is currently only supported for static quantization. + """ + if quant_type != QuantType.STATIC: + raise ValueError("set_float_to_observed_mapping is currently only supported for static quantization") + if quant_type not in self.float_to_observed_mapping: + self.float_to_observed_mapping[quant_type] = {} + self.float_to_observed_mapping[quant_type][float_class] = observed_class + return self + + def set_non_traceable_module_names(self, module_names: List[str]) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by name. 
+ """ + self.non_traceable_module_names = module_names + return self + + def set_non_traceable_module_classes(self, module_classes: List[Type]) -> PrepareCustomConfig: + """ + Set the modules that are not symbolically traceable, identified by class. + """ + self.non_traceable_module_classes = module_classes + return self + + def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig: + """ + Set the indexes of the inputs of the graph that should be quantized. + Inputs are otherwise assumed to be in fp32 by default instead. + """ + self.input_quantized_indexes = indexes + return self + + def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig: + """ + Set the indexes of the outputs of the graph that should be quantized. + Outputs are otherwise assumed to be in fp32 by default instead. + """ + self.output_quantized_indexes = indexes + return self + + def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, prepare_custom_config_dict: Dict[str, Any]) -> PrepareCustomConfig: + """ + Create a ``PrepareCustomConfig`` from a dictionary with the following items: + + "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, + child_prepare_custom_config, backend_config) tuples + + "float_to_observed_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from float module classes to observed module classes, e.g. 
+ {"static": {FloatCustomModule: ObservedCustomModule}} + + "non_traceable_module_name": a list of modules names that are not symbolically traceable + "non_traceable_module_class": a list of module classes that are not symbolically traceable + "input_quantized_idxs": a list of indexes of graph inputs that should be quantized + "output_quantized_idxs": a list of indexes of graph outputs that should be quantized + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]: + """ + Convert the given object into a QConfigMapping if possible, else throw an exception. + """ + if isinstance(obj, QConfigMapping) or obj is None: + return obj + if isinstance(obj, Dict): + return QConfigMapping.from_dict(obj) + raise ValueError(f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'") + + def _get_prepare_custom_config(obj: Any, dict_key: str) -> Optional[PrepareCustomConfig]: + """ + Convert the given object into a PrepareCustomConfig if possible, else throw an exception. + """ + if isinstance(obj, PrepareCustomConfig) or obj is None: + return obj + if isinstance(obj, Dict): + return PrepareCustomConfig.from_dict(obj) + raise ValueError(f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'") + + def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]: + """ + Convert the given object into a BackendConfig if possible, else throw an exception. 
+ """ + if isinstance(obj, BackendConfig) or obj is None: + return obj + if isinstance(obj, Dict): + return BackendConfig.from_dict(obj) + raise ValueError(f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'") + + conf = cls() + for (module_name, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\ + prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY) + prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY) + backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY) + conf.set_standalone_module_name( + module_name, qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + for (module_class, qconfig_dict, example_inputs, _prepare_custom_config_dict, backend_config_dict) in\ + prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []): + qconfig_mapping = _get_qconfig_mapping(qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY) + prepare_custom_config = _get_prepare_custom_config(_prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY) + backend_config = _get_backend_config(backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY) + conf.set_standalone_module_class( + module_class, qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get(FLOAT_TO_OBSERVED_DICT_KEY, {}).items(): + quant_type = _quant_type_from_str(quant_type_name) + for float_class, observed_class in custom_module_mapping.items(): + conf.set_float_to_observed_mapping(float_class, observed_class, quant_type) + conf.set_non_traceable_module_names(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, [])) + conf.set_non_traceable_module_classes(prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, [])) 
+ conf.set_input_quantized_indexes(prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, [])) + conf.set_output_quantized_indexes(prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, [])) + conf.set_preserved_attributes(prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``PrepareCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. + """ + def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): + qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None + prepare_custom_config_dict = e.prepare_custom_config.to_dict() if e.prepare_custom_config else None + return (key, qconfig_dict, e.example_inputs, prepare_custom_config_dict, e.backend_config) + + d: Dict[str, Any] = {} + for module_name, sm_config_entry in self.standalone_module_names.items(): + if STANDALONE_MODULE_NAME_DICT_KEY not in d: + d[STANDALONE_MODULE_NAME_DICT_KEY] = [] + d[STANDALONE_MODULE_NAME_DICT_KEY].append(_make_tuple(module_name, sm_config_entry)) + for module_class, sm_config_entry in self.standalone_module_classes.items(): + if STANDALONE_MODULE_CLASS_DICT_KEY not in d: + d[STANDALONE_MODULE_CLASS_DICT_KEY] = [] + d[STANDALONE_MODULE_CLASS_DICT_KEY].append(_make_tuple(module_class, sm_config_entry)) + for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items(): + if FLOAT_TO_OBSERVED_DICT_KEY not in d: + d[FLOAT_TO_OBSERVED_DICT_KEY] = {} + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping + if len(self.non_traceable_module_names) > 0: + d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names + if len(self.non_traceable_module_classes) > 0: + d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes + if len(self.input_quantized_indexes) > 0: + 
d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes + if len(self.output_quantized_indexes) > 0: + d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class ConvertCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. + + Example usage:: + + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ + .set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self): + self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {} + self.preserved_attributes: List[str] = [] + + def __repr__(self): + dict_nonempty = { + k: v for k, v in self.__dict__.items() + if len(v) > 0 + } + return f"ConvertCustomConfig({dict_nonempty})" + + def set_observed_to_quantized_mapping( + self, + observed_class: Type, + quantized_class: Type, + quant_type: QuantType = QuantType.STATIC) -> ConvertCustomConfig: + """ + Set the mapping from a custom observed module class to a custom quantized module class. + + The quantized module class must have a ``from_observed`` class method that converts the observed module class + to the quantized module class. + """ + if quant_type not in self.observed_to_quantized_mapping: + self.observed_to_quantized_mapping[quant_type] = {} + self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class + return self + + def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. 
+ """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, convert_custom_config_dict: Dict[str, Any]) -> ConvertCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization + mode to an inner mapping from observed module classes to quantized module classes, e.g.:: + { + "static": {FloatCustomModule: ObservedCustomModule}, + "dynamic": {FloatCustomModule: ObservedCustomModule}, + "weight_only": {FloatCustomModule: ObservedCustomModule} + } + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(OBSERVED_TO_QUANTIZED_DICT_KEY, {}).items(): + quant_type = _quant_type_from_str(quant_type_name) + for observed_class, quantized_class in custom_module_mapping.items(): + conf.set_observed_to_quantized_mapping(observed_class, quantized_class, quant_type) + conf.set_preserved_attributes(convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``ConvertCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. 
+ """ + d: Dict[str, Any] = {} + for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items(): + if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: + d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d + + +class FuseCustomConfig: + """ + Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`. + + Example usage:: + + fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) + """ + + def __init__(self): + self.preserved_attributes: List[str] = [] + + def __repr__(self): + dict_nonempty = { + k: v for k, v in self.__dict__.items() + if len(v) > 0 + } + return f"FuseCustomConfig({dict_nonempty})" + + def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig: + """ + Set the names of the attributes that will persist in the graph module even if they are not used in + the model's ``forward`` method. + """ + self.preserved_attributes = attributes + return self + + # TODO: remove this + @classmethod + def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig: + """ + Create a ``ConvertCustomConfig`` from a dictionary with the following items: + + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` + + This function is primarily for backward compatibility and may be removed in the future. + """ + conf = cls() + conf.set_preserved_attributes(fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) + return conf + + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``FuseCustomConfig`` to a dictionary with the items described in + :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. 
+ """ + d: Dict[str, Any] = {} + if len(self.preserved_attributes) > 0: + d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes + return d diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..91b876997d10910e5b411225c2654857eab07f2b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse.py @@ -0,0 +1,161 @@ +from torch.fx import ( + GraphModule, + Node, + map_arg +) +from torch.fx.graph import Graph +from .match_utils import ( + _is_match, + MatchAllNode, +) +from .pattern_utils import ( + _sorted_patterns_dict, +) + +from ..backend_config import ( + BackendConfig, + get_native_backend_config, +) +from ..backend_config.utils import ( + get_fuser_method_mapping, + get_fusion_pattern_to_root_node_getter, + get_fusion_pattern_to_extra_inputs_getter, +) + +from .custom_config import FuseCustomConfig + +from .fuse_handler import ( + _get_fusion_pattern_to_fuse_handler_cls, + FuseHandler, +) + +from typing import Any, Callable, Dict, List, Tuple, Union +import warnings + +from torch.ao.quantization.utils import Pattern, NodePattern + + +__all__ = [ + "fuse", + # TODO: We should make this private in the future + # This is currently needed for test_public_bindings for some reason + "FuseHandler", +] + + +def fuse( + model: GraphModule, + is_qat: bool, + fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, +) -> GraphModule: + if fuse_custom_config is None: + fuse_custom_config = FuseCustomConfig() + + if isinstance(fuse_custom_config, Dict): + warnings.warn( + "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported " + "in a future version. 
Please pass in a FuseCustomConfig instead.") + fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config) + + if isinstance(backend_config, Dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.") + backend_config = BackendConfig.from_dict(backend_config) + + named_modules = dict(model.named_modules()) + + if backend_config is None: + backend_config = get_native_backend_config() + + fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config)) + fuser_method_mapping = get_fuser_method_mapping(backend_config) + fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config) + + # find fusion + fusion_pairs = _find_matches( + model, model.graph, fusion_pattern_to_fuse_handler_cls) + # TODO: change this to inplace changes to graph, since we no longer construct + # new GraphModule anymore + fused_graph = Graph() + env: Dict[Any, Any] = {} + + def load_arg(a): + return map_arg(a, lambda node: env[node.name]) + + def default_root_node_getter(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + + for node in model.graph.nodes: + maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \ + fusion_pairs.get(node.name, (None, None, None, None, None)) + # get the corresponding subpattern for the current node + if node_to_subpattern is not None: + node_subpattern = node_to_subpattern.get(node, None) + else: + node_subpattern = None + if maybe_last_node is node: + assert obj is not None + root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter) + root_node = root_node_getter(matched_node_pattern) # type: ignore[index] + extra_inputs_getter = 
fusion_pattern_to_extra_inputs_getter.get(pattern, None) + extra_inputs = [] + if extra_inputs_getter is not None: + extra_inputs = extra_inputs_getter(matched_node_pattern) + # TODO: add validation that root_node is a module and has the same type + # as the root_module in the configuration + env[node.name] = obj.fuse( + load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern, # type: ignore[arg-type] + fuse_custom_config, fuser_method_mapping, is_qat) + elif maybe_last_node is None or node_subpattern is MatchAllNode: + env[node.name] = fused_graph.node_copy(node, load_arg) + # node matched in patterns and is not root is removed here + + model = GraphModule(model, fused_graph) + return model + +def _find_matches( + root: GraphModule, + graph: Graph, + pattern_to_fuse_handler_cls: Dict[Pattern, Callable], +) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]: + modules = dict(root.named_modules()) + # node name -> (root_node, match_value) + match_map : Dict[ + str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {} + # a map from node to the matched subpattern + node_to_subpattern: Dict[Node, Any] = {} + + # TODO: dedup with quantization matching function in match_utils.py + def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern): + if isinstance(pattern, tuple): + s, *args = pattern + current_node_pattern: List[Node] = [] + apply_match(s, node, match, current_node_pattern, node_to_subpattern) + for subpattern, arg in zip(args, node.args): + apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern) + matched_node_pattern.append(tuple(current_node_pattern)) + else: + # the first pattern matches will take precedence + if node.name not in match_map: + matched_node_pattern.append(node) + # MatchAllNode here is actually MatchAllInputNode which should not + # be added to match_map + if pattern is not MatchAllNode: + node_to_subpattern[node] = pattern + 
root_node, pattern, handler = match + match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern) + + for node in reversed(graph.nodes): + if node.name not in match_map: + for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items(): + matched_node_pattern: List[Node] = [] + if _is_match(modules, node, pattern): + apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern) + break + + return match_map diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse_handler.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..718cc561bfa0bb68935a899c7c1ba94b9f9820dc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/fuse_handler.py @@ -0,0 +1,120 @@ +import torch +from torch.ao.quantization.backend_config import BackendConfig +from torch.fx.graph import Node, Graph +from ..utils import _parent_name, NodePattern, Pattern +from ..fuser_method_mappings import get_fuser_method_new +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Union +from .custom_config import FuseCustomConfig +from .match_utils import MatchAllNode +from torch.nn.utils.parametrize import type_before_parametrizations + +__all__ = [ + "DefaultFuseHandler", + "FuseHandler", +] + + +# ---------------------------- +# Fusion Pattern Registrations +# ---------------------------- + +# Base Pattern Handler +class FuseHandler(ABC): + """ Base handler class for the fusion patterns + """ + @abstractmethod + def __init__(self, node: Node): + pass + + @abstractmethod + def fuse(self, + load_arg: Callable, + named_modules: Dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: List[Any], + matched_node_pattern: NodePattern, + 
fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]], + is_qat: bool) -> Node: + pass + +class DefaultFuseHandler(FuseHandler): + def __init__( + self, + node: Node): + super().__init__(node) + + def fuse(self, + load_arg: Callable, + named_modules: Dict[str, torch.nn.Module], + fused_graph: Graph, + root_node: Node, + extra_inputs: List[Any], + matched_node_pattern: NodePattern, + fuse_custom_config: FuseCustomConfig, + fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]], + is_qat: bool) -> Node: + assert root_node.op == "call_module", "Expecting module node to be a call_module Node" + root_module = named_modules[str(root_node.target)] + + def get_modules(pattern): + """ Given a node pattern, extract the corresponding modules + e.g. input: (relu_node, (bn_node, conv_node)) + output: (relu_module, (bn_module, conv_module)) + """ + if isinstance(pattern, (tuple, list)): + n, *args = pattern + modules: List[torch.nn.Module] = [] + modules.append(get_modules(n)) + for a in args: + modules.append(get_modules(a)) + return tuple(modules) + else: + n = pattern + if n.op == "call_module": + return named_modules[n.target] + elif n.op == "call_function" and n.target == torch.nn.functional.relu: + relu = torch.nn.ReLU() + relu.training = root_module.training + return relu + elif n.op == "call_function" or n.op == "call_method": + return n.target + else: + return MatchAllNode + + # since relu can be used multiple times, we'll need to create a relu module for each match + matched_modules = get_modules(matched_node_pattern) + + def get_matched_types(m): + if isinstance(m, tuple): + return tuple(map(get_matched_types, m)) + if isinstance(m, torch.nn.Module): + return type_before_parametrizations(m) + return m + + matched_module_types = get_matched_types(matched_modules) + module_parent_name, module_name = _parent_name(root_node.target) + fuser_method = get_fuser_method_new(matched_module_types, 
fuser_method_mapping) + # TODO: change the signature for fuser_method to take matched module patterns + # as input + fused_module = fuser_method(is_qat, *matched_modules) + setattr(named_modules[module_parent_name], module_name, fused_module) + extra_args = [] + for input in extra_inputs: + extra_args.append(load_arg(input)) + node = fused_graph.node_copy(root_node, load_arg) + args = list(node.args) + args.extend(extra_args) + node.args = tuple(args) + return node + +def _get_fusion_pattern_to_fuse_handler_cls( + backend_config: BackendConfig) -> Dict[Pattern, Callable]: + fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # TODO: is this logic right? + fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler + return fusion_pattern_to_fuse_handlers diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/graph_module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9187285ae6313b07e03fe47e0eaec8ca4a265b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/graph_module.py @@ -0,0 +1,119 @@ +import torch +import copy +from torch.fx import GraphModule +from torch.fx.graph import Graph +from typing import Union, Dict, Any, Set + +__all__ = [ + "FusedGraphModule", + "ObservedGraphModule", + "ObservedStandaloneGraphModule", + "QuantizedGraphModule", +] + +class FusedGraphModule(GraphModule): + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)} + super().__init__(root, graph) + for attr in 
preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return FusedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names)) + +class ObservedGraphModule(GraphModule): + + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): + self.preserved_attr_names = { + '_activation_post_process_map', + '_activation_post_process_indexes', + '_patterns', + '_node_name_to_qconfig', + '_prepare_custom_config', + '_equalization_node_name_to_qconfig', + '_node_name_to_scope', + '_qconfig_mapping', + '_is_qat', + '_observed_node_names'}.union(preserved_attr_names) + preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)} + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + + # GraphModule does not copy attributes which are not in the __dict__ + # of vanilla nn.Module. So, we override __deepcopy__ in order + # to copy the quantization specific attributes correctly. 
+ def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names)) + +def _is_observed_module(module: Any) -> bool: + return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta + +def _get_observed_graph_module_attr(model: Union[torch.nn.Module, GraphModule], attr_name: str) -> Any: + if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta: # type: ignore[operator, index] + return getattr(model.meta["_observed_graph_module_attrs"], attr_name) # type: ignore[index] + return None + +class ObservedStandaloneGraphModule(ObservedGraphModule): + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): + preserved_attr_names = preserved_attr_names.union({ + "_standalone_module_input_quantized_idxs", + "_standalone_module_output_quantized_idxs"}) + super().__init__(root, graph, preserved_attr_names) + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names)) + +def _is_observed_standalone_module(module: Any) -> bool: + return _is_observed_module(module) and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module + +def _save_packed_weight(self, destination, prefix, keep_vars): + for attr_name in dir(self): + if "_packed_weight" in attr_name and \ + isinstance(getattr(self, attr_name), torch._C.ScriptObject): # type: ignore[attr-defined] + packed_weight = getattr(self, attr_name) + destination[prefix + attr_name] = packed_weight + +class QuantizedGraphModule(GraphModule): + """ This class is created to make sure PackedParams + (e.g. 
LinearPackedParams, Conv2dPackedParams) to appear in state_dict + so that we can serialize and deserialize quantized graph module with + torch.save(m.state_dict()) and m.load_state_dict(state_dict) + """ + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]): + self.preserved_attr_names = preserved_attr_names + preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)} + super().__init__(root, graph) + for attr in preserved_attrs: + setattr(self, attr, preserved_attrs[attr]) + self._register_state_dict_hook(_save_packed_weight) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + attrs_to_pop = [] + for attr_name in state_dict: + if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject): # type: ignore[attr-defined] # noqa: B950 + setattr(self, attr_name, state_dict[attr_name]) + attrs_to_pop.append(attr_name) + + # pop the packed param attributesn + for attr_name in attrs_to_pop: + state_dict.pop(attr_name) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + + def __deepcopy__(self, memo): + fake_mod = torch.nn.Module() + fake_mod.__dict__ = copy.deepcopy(self.__dict__) + return QuantizedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..ef58652b1adda0dc135fbef21afe789d6f538eda --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py @@ -0,0 +1,16 @@ +from 
._lower_to_native_backend import _lower_to_native_backend +from ..qconfig import QConfigAny +from torch.fx import GraphModule +from typing import Dict, Tuple + +__all__ = ['lower_to_fbgemm'] + +def lower_to_fbgemm( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny], + node_name_to_scope: Dict[str, Tuple[str, type]] +) -> GraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to fbgemm + """ + return _lower_to_native_backend(model, qconfig_map, node_name_to_scope) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a82179789dc392132a791632f0397a2dcf7595 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lower_to_qnnpack.py @@ -0,0 +1,18 @@ +from ._lower_to_native_backend import _lower_to_native_backend +from ..qconfig import QConfigAny +from torch.fx import GraphModule +from typing import Dict, Tuple + +__all__ = [ + "lower_to_qnnpack" +] + +def lower_to_qnnpack( + model: GraphModule, + qconfig_map: Dict[str, QConfigAny], + node_name_to_scope: Dict[str, Tuple[str, type]] +) -> GraphModule: + """ Lower a quantized reference model (with reference quantized operator patterns) + to qnnpack + """ + return _lower_to_native_backend(model, qconfig_map, node_name_to_scope) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lstm_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lstm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f163a1869ac1dc12ed2dca4a59a698482afc2f2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/lstm_utils.py @@ 
-0,0 +1,183 @@ +import copy +import operator +import torch +from typing import Any, Callable, Optional, Tuple +from torch.ao.quantization import ( + default_weight_observer, + default_weight_fake_quant, + FakeQuantizeBase, + QConfig, + QConfigMapping, +) +from torch.ao.quantization.backend_config import BackendConfig +from torch.ao.quantization.observer import _PartialWrapper +from torch.ao.quantization.quantize_fx import ( + convert_to_reference_fx, + prepare_fx, +) + +# TODO: move all LSTM util functions from fx/utils.py to this file +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + example_inputs: Tuple[Any, ...], + backend_config: Optional[BackendConfig] = None, + linear_output_obs_ctr: Optional[_PartialWrapper] = None, + sigmoid_obs_ctr: Optional[_PartialWrapper] = None, + tanh_obs_ctr: Optional[_PartialWrapper] = None, + cell_state_obs_ctr: Optional[_PartialWrapper] = None, + hidden_state_obs_ctr: Optional[_PartialWrapper] = None, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + This is meant to be used to convert a float module to an observed module in the + custom module flow. 
+ + Args: + `float_lstm`: The float LSTM module + `example_inputs`: example inputs for the forward function of the LSTM module + `backend_config`: BackendConfig to use to observe the LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + assigned to the inner ops. + """ + def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. + """ + if isinstance(obs_ctr(), FakeQuantizeBase): + weight = default_weight_fake_quant + else: + weight = default_weight_observer + return QConfig(activation=obs_ctr, weight=weight) + + quantizable_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias, + float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional) + quantizable_lstm.qconfig = float_lstm.qconfig + + for idx in range(float_lstm.num_layers): + quantizable_lstm.layers[idx] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(float_lstm, + idx, + float_lstm.qconfig, + batch_first=False) + + # Build QConfigMapping for the LSTM cell + # Note: FloatFunctional qconfigs will be configured separately below + cell_qm = QConfigMapping().set_global(float_lstm.qconfig) # type: ignore[arg-type] + if sigmoid_obs_ctr is not None: + cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr)) + cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr)) + 
cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr)) + if tanh_obs_ctr is not None: + cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr)) + + # Insert observers into each LSTM cell + # TODO: maybe make this work for layer_bw as well + for layer in quantizable_lstm.layers: + cell = layer.layer_fw.cell + cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) + # HACK: Manually replace the activation_post_process following these ops. + # This is needed for FloatFunctional ops because there is currently no way + # to configure these ops in FX graph mode quantization today. This is because + # the FloatFunctional modules simply disappear from the graph after tracing. + # In the future, we should rewrite quantizable LSTM without FloatFunctionals. + op_index_to_activation_post_process_ctr = { + (torch.add, 0): linear_output_obs_ctr, # gates.add + (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul + (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul + (torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add + (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul + } + add_count = 0 + mul_count = 0 + for node in cell.graph.nodes: + op_index: Optional[Tuple[Callable, int]] = None # e.g. 
(torch.add, 1) + if node.target == torch.add: + op_index = (torch.add, add_count) + add_count += 1 + elif node.target == torch.mul: + op_index = (torch.mul, mul_count) + mul_count += 1 + else: + # Neither torch.add nor torch.mul + continue + if op_index not in op_index_to_activation_post_process_ctr: + continue + assert len(node.users) == 1 + activation_post_process_name = next(iter(node.users.keys())).name + activation_post_process_ctr = op_index_to_activation_post_process_ctr[op_index] + if activation_post_process_ctr is not None: + setattr(cell, activation_post_process_name, activation_post_process_ctr()) + layer.layer_fw.cell = cell + return quantizable_lstm + +def _get_reference_quantized_lstm_module( + observed_lstm: torch.ao.nn.quantizable.LSTM, + backend_config: Optional[BackendConfig] = None, +) -> torch.ao.nn.quantized.LSTM: + """ + Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM` + with observers or fake quantizes inserted through `prepare_fx`, e.g. from + `_get_lstm_with_individually_observed_parts`. + + This is meant to be used to convert an observed module to a quantized module in the + custom module flow. + + Args: + `observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx` + `backend_config`: BackendConfig to use to produce the reference quantized model + + Return: + A reference `torch.ao.nn.quantized.LSTM` module. 
+ """ + quantized_lstm = torch.ao.nn.quantized.LSTM( + observed_lstm.input_size, observed_lstm.hidden_size, observed_lstm.num_layers, + observed_lstm.bias, observed_lstm.batch_first, observed_lstm.dropout, + observed_lstm.bidirectional) + + for i, layer in enumerate(quantized_lstm.layers): + cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] + cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] + assert isinstance(cell, torch.fx.GraphModule) + # HACK: Manually remove input quantize nodes and output dequantize nodes, + # since custom modules expect quint8 inputs and outputs for now. Note that + # this functionality is supposedly handled through PrepareCustomConfig's + # `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that + # API doesn't currently handle tuple inputs and outputs, so we have to do + # this manually for now. In the future we should (1) relax the restriction + # on custom module input/output dtypes, and (2) expand support for complex + # input/output structures. 
+ for node in cell.graph.nodes: + if node.target == torch.quantize_per_tensor: + arg = node.args[0] + # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1]) + if arg.target == "x" or (arg.target == operator.getitem and arg.args[0].target == "hidden"): + with cell.graph.inserting_before(node): + node.replace_all_uses_with(arg) + cell.graph.erase_node(node) + if node.target == "output": + # Remove all dequantize nodes in the output tuple + for arg in node.args[0]: + with cell.graph.inserting_before(node): + node.replace_input_with(arg, arg.args[0]) + cell.graph.eliminate_dead_code() + cell.recompile() + layer.layer_fw.cell = cell + return quantized_lstm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/pattern_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d8648a0aed5e701e26da22e218cab66bceab594b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/pattern_utils.py @@ -0,0 +1,87 @@ +from collections import OrderedDict +from typing import Dict, Any +from torch.ao.quantization.utils import Pattern +from ..fake_quantize import FixedQParamsFakeQuantize +from ..observer import ObserverBase +import copy + +__all__ = [ + "get_default_fusion_patterns", + "get_default_quant_patterns", + "get_default_output_activation_post_process_map", +] + +# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency) +QuantizeHandler = Any + +# pattern for conv bn fusion +_DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict() +def _register_fusion_pattern(pattern): + def insert(fn): + _DEFAULT_FUSION_PATTERNS[pattern] = fn + return fn + return insert + +def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_FUSION_PATTERNS) + 
+_DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict() + +# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation +# e.g. pattern: torch.sigmoid, +# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant +_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, QuantizeHandler] = {} +_DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, QuantizeHandler] = {} + +# Register pattern for both static quantization and qat +def _register_quant_pattern(pattern, fixed_qparams_observer=None): + def insert(fn): + _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn + if fixed_qparams_observer is not None: + _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[pattern] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer) + _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer + return fn + return insert + +# Get patterns for both static quantization and qat +def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]: + return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS) + +# a map from pattern to output activation post process constructor +# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant +def get_default_output_activation_post_process_map(is_training) -> Dict[Pattern, ObserverBase]: + if is_training: + return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP) + else: + return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP) + +# Example use of register pattern function: +# @_register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))) +# class ConvOrLinearBNReLUFusion(): +# def __init__(...): +# ... +# + +def _sorted_patterns_dict(patterns_dict: Dict[Pattern, QuantizeHandler]) -> Dict[Pattern, QuantizeHandler]: + """ + Return a sorted version of the patterns dictionary such that longer patterns are matched first, + e.g. match (F.relu, F.linear) before F.relu. 
+ This works for current use cases, but we may need to have a more clever way to sort + things to address more complex patterns + """ + + def get_len(pattern): + """ this will calculate the length of the pattern by counting all the entries + in the pattern. + this will make sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before + (nn.BatchNorm, nn.Conv2d) so that we can match the former first + """ + len = 0 + if isinstance(pattern, tuple): + for item in pattern: + len += get_len(item) + else: + len += 1 + return len + + return OrderedDict(sorted(patterns_dict.items(), key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b906a1777de0168511190a4d4d6ec4442a36b99 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -0,0 +1,343 @@ +import torch +import re +from collections import defaultdict, OrderedDict +from typing import Callable, Any, Dict, Tuple, Set, List, Union +from torch.ao.quantization import QConfig +from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals +from torch.ao.quantization.observer import ( + _is_activation_post_process, +) +from torch.ao.quantization.backend_config import ( + BackendConfig, + DTypeConfig, +) +from torch.ao.quantization.backend_config.utils import ( + get_module_to_qat_module, +) + +from torch.fx import ( + GraphModule, +) +from torch.fx.graph import ( + Graph, +) +from torch.ao.nn.intrinsic import _FusedModule + +from ..utils import ( + _parent_name, + get_qconfig_dtypes, +) +from ..qconfig_mapping import ( + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + 
_MODULE_NAME_REGEX_DICT_KEY, + QConfigMapping, +) + +__all__: List[str] = [] + + + +def _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping: QConfigMapping, + cur_module_path: str, + cur_object_type: Callable, + cur_object_type_idx: int, + fallback_qconfig: QConfigAny, +) -> QConfigAny: + for (module_name, object_type, index), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items(): + if ( + (module_name == cur_module_path) and + (object_type == cur_object_type) and + (index == cur_object_type_idx) + ): + return qconfig + return fallback_qconfig + + +def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping): + """ + Update the QConfigMapping to account for fused modules such as LinearReLU. + This assumes the QConfigMapping's attributes have already been converted to OrderedDicts. + """ + object_type_dict = qconfig_mapping.object_type_qconfigs + if len(object_type_dict) == 0: + return qconfig_mapping + + modules = dict(model.named_modules()) + + for node in model.graph.nodes: + if node.op == 'call_module' and node.target in modules: + maybe_fused_module = modules[str(node.target)] + if not isinstance(maybe_fused_module, _FusedModule): + continue + + ops = list(maybe_fused_module._modules.values()) + fused_qconfig = object_type_dict.get(type(ops[0]), None) + + # Raise an error if the modules in the fused module have + # different qconfigs specified in the qconfig_dict + # TODO: currently it only works for modules, + # need to make this work for torch.nn.functional.relu + # TODO: currently it only works for object_type configurations, + # ideally it should work for different types of configurations, + # maybe we want to redesign this part + for op in ops[1:]: + if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig): + raise LookupError( + "During fusion, we need to specify the same " + + f"qconfigs for all module types in {type(maybe_fused_module)} " + + f"offending type: 
{type(op)}") + + if fused_qconfig is not None: + object_type_dict[type(maybe_fused_module)] = fused_qconfig + +def _generate_node_name_to_qconfig( + root: torch.nn.Module, + modules: Dict[str, torch.nn.Module], + input_graph: Graph, + qconfig_mapping: QConfigMapping, + node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]: + global_qconfig = qconfig_mapping.global_qconfig + node_name_to_qconfig = {} + + # example: + # + # {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...} + # + # meaning in submodule 'foo.bar', we have seen 0 F.linear and + # 1 F.conv2d invocations so far. + submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \ + defaultdict(lambda: defaultdict(int)) + for node in input_graph.nodes: + qconfig = None + if node.op == "get_attr": + module_name, _ = _parent_name(node.target) + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[module_name]), module_name, global_qconfig) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + elif node.op == "call_function": + # precedence: module_name_qconfig + # > function_qconfig > global_qconfig + # module_name takes precedence over function qconfig + function_qconfig = _get_object_type_qconfig( + qconfig_mapping, node.target, global_qconfig) + module_path, module_type = node_name_to_scope[node.name] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, function_qconfig) + + cur_object_type_idx = \ + submodule_to_object_type_to_cur_idx[module_path][node.target] + submodule_to_object_type_to_cur_idx[module_path][node.target] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + + elif node.op == "call_method": + module_path, module_type = 
node_name_to_scope[node.name] + # first use node.target (string) to get the qconfig + # this is to support configs like + # "object_type": [("reshape", qconfig)] + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, node.target, module_path, global_qconfig) + # if there is no special config for the method, we'll fall back to the + # config for the module that contains the call_method node + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, module_type, module_path, qconfig) + # currently call_method does not support modifying qconfig + # by order, we can add this later if it is needed. + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + + elif node.op == 'call_module': + # if the node is an observer, just continue - don't add it to the qconfig_map + if _is_activation_post_process(modules[node.target]): + continue + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( + qconfig_mapping, type(modules[node.target]), node.target, global_qconfig) + + module_path, module_type = node_name_to_scope[node.name] + # Note: for call_module, the module_path is the current module's name. + # to meaningfully count invocations, we need to count them in the parent + # module. 
+ parent_name, _ = _parent_name(module_path) + cur_object_type_idx = \ + submodule_to_object_type_to_cur_idx[parent_name][module_type] + submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1 + qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order( + qconfig_mapping, parent_name, module_type, cur_object_type_idx, + qconfig) + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) + + # regex is not supported eager mode propagate_qconfig_, we'll + # need to set the qconfig explicitly here in case regex + # is used + modules[node.target].qconfig = qconfig_with_device_check + else: + qconfig_with_device_check = None + + node_name_to_qconfig[node.name] = qconfig_with_device_check + return node_name_to_qconfig + + +def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None: + r""" Checks if the given config_dict has the correct keys + + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict.keys(): + if k not in allowed_keys: + raise ValueError( + 'Expected ' + dict_name + ' to have the following keys: ' + + str(allowed_keys) + '. 
But found \'' + k + + '\' instead.') + + +def _compare_prepare_convert_qconfig_mappings( + prepare_qconfig_mapping: QConfigMapping, + convert_qconfig_mapping: QConfigMapping): + r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values + + Args: + `prepare_qconfig_mapping`: configuration for prepare quantization step + `convert_qconfig_mapping`: configuration for convert quantization step + """ + assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \ + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + prepare_dicts: List[OrderedDict] = [ + prepare_qconfig_mapping.object_type_qconfigs, + prepare_qconfig_mapping.module_name_qconfigs, + prepare_qconfig_mapping.module_name_regex_qconfigs, + ] + convert_dicts: List[OrderedDict] = [ + convert_qconfig_mapping.object_type_qconfigs, + convert_qconfig_mapping.module_name_qconfigs, + convert_qconfig_mapping.module_name_regex_qconfigs, + ] + dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY] + for i in range(len(prepare_dicts)): + for name in prepare_dicts[i].keys(): + assert name in convert_dicts[i], f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ + when it was present in prepare" + assert convert_dicts[i][name] is None \ + or qconfig_equals(prepare_dicts[i][name], convert_dicts[i][name]), \ + f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ + prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + +def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]): + for dtype_config in dtype_configs: + is_dynamic = dtype_config.is_dynamic + if is_dynamic is None: + is_dynamic = False + input_dtype = dtype_config.input_dtype or torch.float + weight_dtype = dtype_config.weight_dtype or torch.float + bias_dtype = dtype_config.bias_dtype or 
torch.float + output_dtype = dtype_config.output_dtype or torch.float + qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \ + get_qconfig_dtypes(qconfig) + qconfig_bias_dtype = torch.float16 \ + if ( + qconfig_activation_dtype == torch.float16 + and qconfig_weight_dtype == torch.float16 + and not is_dynamic + ) else torch.float + + if is_dynamic: + is_match = qconfig_input_act_is_dynamic and \ + input_dtype == qconfig_activation_dtype and \ + output_dtype == torch.float and \ + weight_dtype == qconfig_weight_dtype + else: + is_match = input_dtype == qconfig_activation_dtype and \ + output_dtype == qconfig_activation_dtype and \ + weight_dtype == qconfig_weight_dtype and \ + bias_dtype == qconfig_bias_dtype + if is_match: + return True + return False + +def _get_object_type_qconfig( + qconfig_mapping: QConfigMapping, + object_type: Union[Callable, str], + fallback_qconfig: QConfigAny) -> QConfigAny: + return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) + + +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): + for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): + if re.match(regex_pattern, module_name): + # first match wins + return qconfig + return fallback_qconfig + + +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): + if module_name == '': + # module name qconfig not found + return fallback_qconfig + if module_name in qconfig_mapping.module_name_qconfigs: + return qconfig_mapping.module_name_qconfigs[module_name] + else: + parent, _ = _parent_name(module_name) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + + +def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig): + # get qconfig for module_name, + # fallback to module_name_regex_qconfig, module_type_qconfig, + # global_qconfig if necessary + module_type_qconfig = 
_get_object_type_qconfig( + qconfig_mapping, module_type, global_qconfig) + module_name_regex_qconfig = _get_module_name_regex_qconfig( + qconfig_mapping, module_name, module_type_qconfig) + module_name_qconfig = _get_module_name_qconfig( + qconfig_mapping, module_name, module_name_regex_qconfig) + return module_name_qconfig + + +def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]: + """ flatten the global, object_type and module_name qconfig + to the same qconfig_dict so that it can be used by + propagate_qconfig_ function. + "module_name_regex" is ignored for now since it's not supported + in propagate_qconfig_, but it can be fixed later. + + For example: + Input: { + "": qconfig, + "object_type": [ + (torch.add, qconfig) + ], + "module_name": [ + ("conv", qconfig) + ] + } + + Output: { + "": qconfig, + torch.add: qconfig, + "conv": qconfig + } + """ + flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig} + for obj, qconfig in qconfig_mapping.object_type_qconfigs.items(): + flattened[obj] = qconfig + for obj, qconfig in qconfig_mapping.module_name_qconfigs.items(): + flattened[obj] = qconfig + return flattened + + +def _update_qconfig_for_qat( + qconfig_mapping: QConfigMapping, + backend_config: BackendConfig): + """ + Update the qconfig_mapping to account for module swaps during QAT. + During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types. 
+ """ + module_to_qat_module_class = get_module_to_qat_module(backend_config) + object_type_dict = qconfig_mapping.object_type_qconfigs + new_object_type_dict = object_type_dict.copy() + for k, v in new_object_type_dict.items(): + if k in module_to_qat_module_class: + object_type_dict[module_to_qat_module_class[k]] = v diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/observer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..e19a73b3cc263e27f704bb4fc1b6b9ea6c85368f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/observer.py @@ -0,0 +1,1688 @@ +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). +""" + +import re +import warnings +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from functools import partial +from typing import Any, List, Tuple, Optional, Dict + +import torch +import torch.nn as nn +from torch.ao.quantization.utils import ( + check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel, validate_qmin_qmax) + +__all__ = [ + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_quant_observer", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_observer", + "default_weight_observer", + "get_observer_state_dict", + "load_observer_state_dict", + "per_channel_weight_observer_range_neg_127_to_127", + "weight_observer_range_neg_127_to_127", + 
"FixedQParamsObserver", + "HistogramObserver", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", +] + + +class _PartialWrapper: + def __init__(self, p): + self.p = p + self.callable_args = {} + + def __call__(self, *args, **keywords): + # call each arg in callable_args and add them partial, then run with keywords + # skip if arg_name in keywords so its possible to overwrite + for arg_name in self.callable_args: + if arg_name not in keywords: + keywords = {**keywords, arg_name: self.callable_args[arg_name]()} + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + self.callable_args.__repr__() + + def with_args(self, **kwargs): + return _with_args(self, **kwargs) + + def with_callable_args(self, **kwargs): + result = _PartialWrapper(p=self.p) + result.callable_args = {**self.callable_args, **kwargs} + return result + + +def _with_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. Can be used in conjunction with + _callable_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, **kwargs)) + return r + +def _with_callable_args(cls_or_self, **kwargs): + r"""Wrapper that allows creation of class factories args that need to be + called at construction time. 
+ + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances and those arguments should only + be calculated at construction time. Can be used in conjunction with _with_args + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_callable_args = classmethod(_with_callable_args) + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan") + >>> foo_instance1 = foo_builder() + >>> # wait 50 + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time) + False + """ + r = _PartialWrapper(partial(cls_or_self)) + return r.with_callable_args(**kwargs) + + +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: + + +class ObserverBase(ABC, nn.Module): + r"""Base observer Module. + Any observer implementation should derive from this class. + + Concrete observers should follow the same API. In forward, they will update + the statistics of the observed Tensor. And they should provide a + `calculate_qparams` function that computes the quantization parameters given + the collected statistics. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + is_dynamic: indicator for whether the observer is a placeholder for dynamic quantization + or static quantization + """ + + def __init__(self, dtype, is_dynamic=False): + super().__init__() + self.dtype = dtype + self.is_dynamic = is_dynamic + + @abstractmethod + def forward(self, x): + pass + + @abstractmethod + def calculate_qparams(self, **kwargs): + pass + + with_args = classmethod(_with_args) + with_callable_args = classmethod(_with_callable_args) + + +class UniformQuantizationObserverBase(ObserverBase): + r"""Common base for all observers using uniform quantization to calculate + scale and zero_point. 
+ + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + .. warning:: + + :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + or `torch.int8` or `torch.uint8` + + .. warning:: + + :attr:`qscheme` can only take one of the following options: + + - ``torch.per_tensor_affine`` + - ``torch.per_tensor_symmetric`` + - ``torch.per_channel_affine`` + - ``torch.per_channel_symmetric`` + """ + + # Note: the version is shared by all observer types + # + # Version 1/None + # self + # + # Version 2 (base class only, does not include child class buffers) + # self + # |--- eps : Tensor + # + # Version 3 + # for HistogramObserver only, changed the shape of uninitialized + # min_val and max_val buffers from torch.Size([0]) to torch.Size([]) + # for PerChannelObservers, changed the name of the buffers from min_vals + # to min_val and from max_vals to max_val. + _version = 3 + + eps: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.qscheme = qscheme + if reduce_range: + warnings.warn( + "Please use quant_min and quant_max to specify the range for observers. \ + reduce_range will be deprecated in a future release of PyTorch." 
+ ) + self.reduce_range = reduce_range + self.register_buffer( + "eps", torch.tensor([eps], **factory_kwargs) + ) + assert self.qscheme in ( + torch.per_tensor_affine, + torch.per_tensor_symmetric, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + ), "Default Observer only works for per_tensor_affine, \ + per_tensor_symmetric, per_channel_affine, \ + per_channel_symmetric and per_channel_float_qparams quantization scheme" + + _ALLOWED_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + ) + + assert self.dtype in _ALLOWED_DTYPES, f"Default Observer only works for {_ALLOWED_DTYPES} data type" + self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) + if self.has_customized_qrange: + validate_qmin_qmax(quant_min, quant_max) + self.quant_min, self.quant_max = \ + calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + + version = local_metadata.get("version", None) + + if version is None or version == 1: + # eps was moved to a buffer in version 2 + eps = torch.tensor([torch.finfo(torch.float32).eps]) + state_dict[prefix + "eps"] = eps + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: + r"""Validates that the user-specified quantization range is properly initialized + and within the given bound supported by the observer dtype. 
+ + To accommodate lower-bit quantization with respect to the existing torch.qint8 and + torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing + in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax + values are used to calculate static estimates of the scale and zero point for aggressive lower-bit + fake quantization. These estimates are compared against parameters learned through backpropagation. + The related literatures for scale and zero point via backpropagation are as follows: + + Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS + Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf + """ + # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted + # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. + assert ( + quant_min <= 0 <= quant_max + ), "Used-specified quantization range must include 0." + assert ( + quant_min < quant_max + ), "qmin must be strictly less than qmax for user-specified quantization range." + + @torch.jit.export + def _calculate_qparams( + self, min_val: torch.Tensor, max_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given min and max + value tensors. Works for both per tensor and per channel cases + + Args: + min_val: Minimum values per channel + max_val: Maximum values per channel + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + # Functionally equivalent to 'determine_qparams' in utils.py. Observers must be torchscriptable however and qscheme + # as far as I can tell is not allowed to passed as a parameter in torchscript functions. This makes refactoring observer + # to use this utility a massive pain and very gross. 
For now Im opting just to duplicate as this code + # seems unlikey to change (last update over 1 year ago) and when torchscript is fully deprecated we can refactor. + # TODO(jakeszwe, jerryzh168) + if not check_min_max_valid(min_val, max_val): + return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type) + + quant_min, quant_max = self.quant_min, self.quant_max + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + + device = min_val_neg.device + scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + if ( + self.qscheme == torch.per_tensor_symmetric + or self.qscheme == torch.per_channel_symmetric + ): + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = torch.max(scale, self.eps) + if self.dtype in [torch.quint8, torch.uint8]: + if self.has_customized_qrange: + # When customized quantization range is used, down-rounded midpoint of the range is chosen. 
+ zero_point = zero_point.new_full( + zero_point.size(), (quant_min + quant_max) // 2 + ) + else: + zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.qscheme == torch.per_channel_affine_float_qparams: + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) + # We use the quantize function + # xq = Round(Xf * inv_scale + zero_point), + # setting zero_point to (-1 * min *inv_scale) we get + # Xq = Round((Xf - min) * inv_scale) + zero_point = -1 * min_val / scale + else: + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.max(scale, self.eps) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. + if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor( + [int(zero_point)], dtype=zero_point.dtype, device=device + ) + if self.qscheme == torch.per_channel_affine_float_qparams: + zero_point = torch.tensor( + [float(zero_point)], dtype=zero_point.dtype, device=device + ) + + return scale, zero_point + + @torch.jit.export + def reset_min_max_vals(self): + raise NotImplementedError("Cannot reset min/max values in the given observer.") + + +# Originally, this class was called `_ObserverBase`. Keeping the old name around +# for backwards compatibility. +# TODO(after v1.13): delete this +_ObserverBase = UniformQuantizationObserverBase + + +class MinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running min and max values. 
+ + This observer uses the tensor min/max statistics to compute the quantization + parameters. The module records the running minimum and maximum of incoming + tensors, and uses this statistic to compute the quantization parameters. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`, + scale :math:`s` and zero point :math:`z` are computed as: + + The running minimum/maximum :math:`x_\text{min/max}` is computed as: + + .. math:: + + \begin{array}{ll} + x_\text{min} &= \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + \min\left(x_\text{min}, \min(X)\right) & \text{otherwise} + \end{cases}\\ + x_\text{max} &= \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + \max\left(x_\text{max}, \max(X)\right) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`X` is the observed tensor. + + The scale :math:`s` and zero point :math:`z` are then computed as: + + .. math:: + + \begin{aligned} + \text{if Symmetric:}&\\ + &s = 2 \max(|x_\text{min}|, x_\text{max}) / + \left( Q_\text{max} - Q_\text{min} \right) \\ + &z = \begin{cases} + 0 & \text{if dtype is qint8} \\ + 128 & \text{otherwise} + \end{cases}\\ + \text{Otherwise:}&\\ + &s = \left( x_\text{max} - x_\text{min} \right ) / + \left( Q_\text{max} - Q_\text{min} \right ) \\ + &z = Q_\text{min} - \text{round}(x_\text{min} / s) + \end{aligned} + + where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and + maximum of the quantized data type. + + .. 
warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "MinMaxObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + # TODO: MinMaxObserver by itself doesn't support dynamic quantization, but + # if it's inherited by MovingAverageObserver, and averaging_constant is 1, it + # supports dynamic quantization, we may need to better error checking here + + # For x86 quantized kernels, we need to ensure that the vpmaddubsw + # instruction does not overflow. We allow for a reduce_range argument to + # observers that reduces the quantized range to (0,127) or (-64, 63). + # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not an optimal choice for non x86 backends as it loses a bit + # of precision for activations. 
+ super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric \ + quantization for quint8" + ) + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + r"""Calculates the quantization parameters.""" + return self._calculate_qparams(self.min_val, self.max_val) + + @torch.jit.export + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float("inf"))) + self.max_val.copy_(torch.tensor(float("-inf"))) + +class MovingAverageMinMaxObserver(MinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + moving average of the min and max values. + + This observer computes the quantization parameters based on the moving + averages of minimums and maximums of the incoming tensors. The module + records the average minimum and maximum of incoming tensors, and uses this + statistic to compute the quantization parameters. 
+ + Args: + averaging_constant: Averaging constant for min/max. + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The moving average min/max is computed as follows + + .. math:: + + \begin{array}{ll} + x_\text{min} = \begin{cases} + \min(X) & \text{if~}x_\text{min} = \text{None} \\ + (1 - c) x_\text{min} + c \min(X) & \text{otherwise} + \end{cases}\\ + x_\text{max} = \begin{cases} + \max(X) & \text{if~}x_\text{max} = \text{None} \\ + (1 - c) x_\text{max} + c \max(X) & \text{otherwise} + \end{cases}\\ + \end{array} + + where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is + is the incoming tensor, and :math:`c` is the ``averaging_constant``. + + The scale and zero point are then computed as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`. + + .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme. + + .. note:: If the running minimum equals to the running maximum, the scale + and zero_point are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + f"MovingAverageMinMaxObserver's qscheme only support \ + torch.per_tensor_symmetric and torch.per_tensor_affine. 
\ + but got: {qscheme}" + ) + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of {self.averaging_constant}" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs + ) + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + if min_val == float("inf") and max_val == float("-inf"): + min_val, max_val = torch.aminmax(x) + else: + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class PerChannelMinMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. + + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 
+ + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference + that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "PerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." + ) + if is_dynamic: + raise NotImplementedError( + "PerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("min_val", torch.tensor([], **factory_kwargs)) + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: 
C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, torch.Tensor], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + version = local_metadata.get("version", None) + if version is not None and version < 3: + local_state = ["min_vals", "max_vals"] + expected_min_name = "min_vals" + expected_max_name = "max_vals" + else: + local_state = ["min_val", "max_val"] + expected_min_name = "min_val" + expected_max_name = "max_val" + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading min_val or max_val + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. 
+ if name == expected_min_name: + self.min_val.resize_(val.shape) + elif name == expected_max_name: + self.max_val.resize_(val.shape) + else: + warnings.warn(f"Observer load_from_state_dict got unexpected name {name}") + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == expected_min_name: + self.min_val.copy_(val) + elif name == expected_max_name: + self.max_val.copy_(val) + else: + warnings.warn(f"Observer load_from_state_dict got unexpected name {name}") + elif strict: + missing_keys.append(key) + + if not torch.jit.is_scripting(): + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _load_from_state_dict_script( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, torch.Tensor], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + + self._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.min_val = torch.rand(0, ) + self.max_val = torch.rand(0, ) + + +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Observer module for computing the quantization parameters based on the + running per channel min and max values. + + This observer uses the tensor min/max statistics to compute the per channel + quantization parameters. The module records the running minimum and maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. 
+ + Args: + averaging_constant: Averaging constant for min/max. + ch_axis: Channel axis + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the + difference that the running min/max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + .. note:: If the running minimum equals to the running maximum, the scales + and zero_points are set to 1.0 and 0. + """ + + def __init__( + self, + averaging_constant=0.01, + ch_axis=0, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs + ) -> None: + if not is_per_channel(qscheme): + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver's qscheme only support \ + torch.per_channel_symmetric, torch.per_channel_affine and torch.per_channel_affine_float_qparams." 
+ ) + if is_dynamic: + raise NotImplementedError( + "MovingAveragePerChannelMinMaxObserver doesn't support dynamic quantization" + ) + super().__init__( + ch_axis=ch_axis, + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + is_dynamic=is_dynamic, + **kwargs + ) + self.averaging_constant = averaging_constant + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = min_val + self.averaging_constant * (min_val_cur - min_val) + max_val = max_val + self.averaging_constant * (max_val_cur - max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + +class HistogramObserver(UniformQuantizationObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. ``calculate_qparams`` will calculate scale and zero_point. + + Args: + bins: Number of bins to use for the histogram + upsample_rate: Factor by which the histograms are upsampled, this is + used to interpolate histograms with varying ranges across observations + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. 
+ + The scale and zero point are computed as follows: + + 1. Create the histogram of the incoming inputs. + The histogram is computed continuously, and the ranges per bin change + with every new tensor observed. + 2. Search the distribution in the histogram for optimal min/max values. + The search for the min/max values ensures the minimization of the + quantization error with respect to the floating point model. + 3. Compute the scale and zero point the same way as in the + :class:`~torch.ao.quantization.MinMaxObserver` + """ + histogram: torch.Tensor + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__( + self, + bins: int = 2048, + upsample_rate: int = 128, + dtype: torch.dtype = torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + is_dynamic=False, + **kwargs, + ) -> None: + if not is_per_tensor(qscheme): + raise NotImplementedError( + "HistogramObserver's qscheme only support torch.per_tensor_symmetric \ + and torch.per_tensor_affine." + ) + if is_dynamic: + raise NotImplementedError( + "HistogramObserver doesn't support dynamic quantization" + ) + # bins: The number of bins used for histogram calculation. 
+ super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.bins = bins + self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs)) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits + self.upsample_rate = upsample_rate + + def _get_norm( + self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor + ) -> torch.Tensor: + r""" + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + Currently only L2 norm is supported. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + norm = ( + delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int): + r""" + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins + if dst_bin_width == 0.0: + return 0.0 + + src_bin = torch.arange(self.bins, device=self.histogram.device) + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? 
+ dst_bin_of_begin = torch.clamp( + torch.div(src_bin_begin, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1 + ) + dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width + + dst_bin_of_end = torch.clamp( + torch.div(src_bin_end, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1 + ) + density = self.histogram / bin_width + + norm = torch.zeros(self.bins, device=self.histogram.device) + + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm += self._get_norm(delta_begin, + torch.ones(self.bins, device=self.histogram.device) * delta_end, + density) + + norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density + ) + + dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm += self._get_norm(torch.tensor(delta_begin), delta_end, density) + + return norm.sum().item() + + def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Non-linear parameter search. + + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. 
+ This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + assert self.histogram.size()[0] == self.bins, "bins mismatch" + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = torch.sum(self.histogram).item() + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 # granularity + alpha = 0.0 # lower bound + beta = 1.0 # upper bound + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + # Find the next step + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + # decide the next move + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + # move the start bin + next_start_bin = l + alpha = next_alpha + else: + # move the end bin + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin) + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _adjust_min_max( + self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int + ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + # We ensure that: + # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins) + # This allows us to have a common grid of resolution s, where we can align + # the input histogram + # start_idx maps min_val to the histogram 
bin index. + + # Compute the width of histogram bins is a straightforward solution, where + # hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate) + # Underflow happens if the numerator is close to the smallest positive subnormal number of FP32 + # Therefore, we avoid such division operation. + downsample_rate = int( + torch.ceil( + ((combined_max - combined_min) / (self.max_val - self.min_val)) * upsample_rate + ).item() + ) + e = downsample_rate / upsample_rate * (self.max_val - self.min_val) - (combined_max - combined_min) + start_idx = int( + torch.round((self.min_val - combined_min) / (self.max_val - self.min_val) * self.bins * upsample_rate).item() + ) + combined_max = combined_max + e + return combined_min, combined_max, downsample_rate, start_idx + + def _combine_histograms( + self, + orig_hist: torch.Tensor, + new_hist: torch.Tensor, + upsample_rate: int, + downsample_rate: int, + start_idx: int, + Nbins: int, + ) -> torch.Tensor: + # First up-sample the histogram with new data by a factor of L + # This creates an approximate probability density thats piecewise constant + upsampled_histogram = new_hist.repeat_interleave(upsample_rate) + # Now insert the upsampled histogram into the output + # histogram, which is initialized with zeros. 
+ # The offset at which the histogram is introduced is determined + # by the start index as the output histogram can cover a wider range + histogram_with_output_range = torch.zeros( + (Nbins * downsample_rate), device=orig_hist.device + ) + histogram_with_output_range[ + start_idx : Nbins * upsample_rate + start_idx + ] = upsampled_histogram + # Compute integral histogram, double precision is needed to ensure + # that there are no overflows + integral_histogram = torch.cumsum( + histogram_with_output_range, 0, dtype=torch.double + )[downsample_rate - 1 :: downsample_rate] + # Finally perform interpolation + shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device) + shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1] + interpolated_histogram = ( + integral_histogram - shifted_integral_histogram + ) / upsample_rate + orig_hist = orig_hist + interpolated_histogram.to(torch.float) + return orig_hist + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() + x_min, x_max = torch.aminmax(x) + # want to ignore torch.inf since we don't actually + # want to make our quantization range infinite + # and in practice those values will be clamped + if x_min == -torch.inf or x_max == torch.inf: + warnings.warn("torch.inf detected in input tensor, ignoring input") + x = x[x.abs() != torch.inf] + if x.numel() == 0: + return x_orig + x_min, x_max = torch.aminmax(x) + min_val = self.min_val + max_val = self.max_val + same_values = min_val.item() == max_val.item() + is_uninitialized = min_val == float("inf") and max_val == float("-inf") + if is_uninitialized or same_values: + min_val, max_val = x_min, x_max + self.min_val.resize_(min_val.shape) + self.min_val.copy_(min_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + assert ( + min_val.numel() == 1 and max_val.numel() == 1 + ), "histogram min/max values must be scalar." 
+ torch.histc( + x, self.bins, min=min_val, max=max_val, out=self.histogram # type: ignore[arg-type] + ) + else: + new_min, new_max = x_min, x_max + combined_min = torch.min(new_min, min_val) + combined_max = torch.max(new_max, max_val) + # combine the existing histogram and new histogram into 1 histogram + # We do this by first upsampling the histogram to a dense grid + # and then downsampling the histogram efficiently + ( + combined_min, + combined_max, + downsample_rate, + start_idx, + ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate) + assert ( + combined_min.numel() == 1 and combined_max.numel() == 1 + ), "histogram min/max values must be scalar." + + # TODO: For some reason, this is required for it to pass torchscript test + # combined_min and combined_max should already have requires_grad set to False + combined_min, combined_max = combined_min.detach(), combined_max.detach() + + combined_histogram = torch.histc( + x, self.bins, min=combined_min, max=combined_max # type: ignore[arg-type] + ) + if combined_min == min_val and combined_max == max_val: + combined_histogram += self.histogram + else: + combined_histogram = self._combine_histograms( + combined_histogram, + self.histogram, + self.upsample_rate, + downsample_rate, + start_idx, + self.bins, + ) + + self.histogram.detach_().resize_(combined_histogram.shape) + self.histogram.copy_(combined_histogram) + self.min_val.detach_().resize_(combined_min.shape) + self.min_val.copy_(combined_min) + self.max_val.detach_().resize_(combined_max.shape) + self.max_val.copy_(combined_max) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + is_uninitialized = self.min_val == float("inf") and self.max_val == float( + "-inf" + ) + if is_uninitialized: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor([0], device=self.min_val.device.type) 
+ assert self.bins == len(self.histogram), ( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min, new_max) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + destination[prefix + "min_val"] = self.min_val + destination[prefix + "max_val"] = self.max_val + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + + if version is None or version < 3: + # if min_val and max_val are not initialized, update their shape + # to account for the differences between v2 and v3 + min_val_name, max_val_name = prefix + "min_val", prefix + "max_val" + if min_val_name in state_dict: + if state_dict[min_val_name].shape == torch.Size([0]): + state_dict[min_val_name] = torch.tensor(float("inf")) + if max_val_name in state_dict: + if state_dict[max_val_name].shape == torch.Size([0]): + state_dict[max_val_name] = torch.tensor(float("-inf")) + + local_state = ["min_val", "max_val"] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + setattr(self, name, val) + elif strict: + missing_keys.append(key) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def extra_repr(self): + return f"min_val={self.min_val}, max_val={self.max_val}" + + +class FixedQParamsObserver(ObserverBase): + r""" + Observer that simulates quantize and dequantize with fixed + quantization parameters in training time. Only per tensor + quantization is supported. 
+ + Args: + `scale` (float): fixed scale for the observer + `zero_point` (int): fixed zero point for the observer + `dtype`, `qscheme`, `quant_min`, `quant_max` + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__( + self, + scale, + zero_point, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + quant_min=0, + quant_max=255, + is_dynamic=False, + **kwargs, + ): + if is_dynamic: + raise NotImplementedError( + "FixedQParamsObserver doesn't support dynamic quantization" + ) + super().__init__(dtype=dtype, is_dynamic=is_dynamic, **kwargs) + self.quant_min = quant_min + self.quant_max = quant_max + self.register_buffer('scale', torch.tensor([scale], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int)) + self.dtype = dtype + self.qscheme = qscheme + + def forward(self, X): + return X + + @torch.jit.export + def calculate_qparams(self): + return self.scale, self.zero_point + + +class PlaceholderObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Can be used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_max: maximum value in quantized domain + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + compute_dtype (deprecated): if set, marks the future quantize function to use + dynamic quantization instead of static quantization. + This field is deprecated, use `is_dynamic=True` instead. + is_dynamic: if True, the `quantize` function in the reference model + representation taking stats from this observer instance will + use dynamic quantization. 
+ """ + + def __init__( + self, dtype=torch.float32, custom_op_name="", compute_dtype=None, + quant_min=None, quant_max=None, qscheme=None, eps=None, + is_dynamic=False, + ) -> None: + super().__init__(dtype=dtype, is_dynamic=is_dynamic) + if qscheme is None: + qscheme = torch.per_tensor_affine + if eps is None: + eps = torch.finfo(torch.float32).eps + + # dtype of input of the target operator, e.g. for dynamic quantization + # ops, the dtype will be float32 + self.dtype = dtype + self.qscheme = qscheme + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.custom_op = custom_op_name + # used for configuration of computation type for dynamic quantization + if compute_dtype: + is_dynamic = True + warnings.warn( + "Please use `is_dynamic` instead of `compute_dtype`. \ + `compute_dtype` will be deprecated in a future release \ + of PyTorch." + ) + + def forward(self, x): + return x + + @torch.jit.export + def extra_repr(self): + return f"dtype={self.dtype}, is_dynamic={self.is_dynamic}" + + @torch.jit.export + def calculate_qparams(self): + raise Exception( + "calculate_qparams should not be called for PlaceholderObserver" + ) + + +class RecordingObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime. 
+ + Args: + dtype: Quantized data type + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + """ + __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]} + + def __init__(self, dtype=torch.quint8): + super().__init__(dtype=dtype, is_dynamic=False) # type: ignore[call-arg] + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception("calculate_qparams should not be called for RecordingObserver") + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + + +class NoopObserver(ObserverBase): + r""" + Observer that doesn't do anything and just passes its configuration to the + quantized module's ``.from_float()``. + + Primarily used for quantization to float16 which doesn't require determining + ranges. + + Args: + dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). + """ + + def __init__(self, dtype=torch.float16, custom_op_name="") -> None: + super().__init__(dtype=dtype, is_dynamic=False) + self.dtype = dtype + self.custom_op = custom_op_name + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception("calculate_qparams should not be called for NoopObserver") + +class ReuseInputObserver(ObserverBase): + r""" This observer is used when we want to reuse the observer from the operator + that produces the input Tensor, typically used for operators like reshape, e.g. + ``` + x0 = ... + x1 = x0.reshape() + ``` + if we configure x0 to be observed by some observer, let's say MinMaxObserver, + and reshape is configured with ReuseInputObserver, we'll reuse the observer instance + for x0 for x1 (output of reshape). If x0 is not observed, we also won't observe x1. 
+ + Note: this is only enabled in FX Graph Mode Quantization + """ + def __init__(self): + super().__init__(torch.quint8, is_dynamic=False) + + def forward(self, x): + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception("calculate_qparams should not be called for ReuseInputObserver") + +def _is_observer_script_module(mod, obs_type_name): + """Returns true if given mod is an instance of Observer script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver' + suffix = mod._c.qualified_name.split(".", 1)[1] + name = re.sub(r"\.___torch_mangle_\d+", "", suffix) + return obs_type_name in name + return False + + +def _is_activation_post_process(module): + return ( + isinstance(module, (torch.ao.quantization.ObserverBase, + torch.ao.quantization.FakeQuantizeBase)) or _is_observer_script_module(module, "quantization.observer") + ) + + +def _is_per_channel_script_obs_instance(module): + if isinstance(module, torch.jit.RecursiveScriptModule): + return _is_observer_script_module( + module, "quantization.observer.PerChannelMinMaxObserver" + ) or _is_observer_script_module( + module, "quantization.observer.MovingAveragePerChannelMinMaxObserver" + ) + return False + + +def get_observer_state_dict(mod): + r""" + Returns the state dict corresponding to the observer stats. + Traverse the model state_dict and extract out the stats. 
+ """ + od = OrderedDict() + if isinstance(mod, torch.jit.RecursiveScriptModule): + for k, v in mod.state_dict().items(): + if "observer" in k: + od[k] = v + else: + # path for GraphModule and nn.Module (eager mode) + for k, v in mod.state_dict().items(): + if "activation_post_process" in k: + od[k] = v + od._metadata = mod.state_dict()._metadata # type: ignore[attr-defined] + return od + + +def load_observer_state_dict(mod, obs_dict): + r""" + Given input model and a state_dict containing model observer stats, + load the stats back into the model. The observer state_dict can be saved + using torch.ao.quantization.get_observer_state_dict + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + for name, module in mod.named_modules(): + prefix = name + "." + if _is_activation_post_process(module): + if _is_per_channel_script_obs_instance(module): + # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor. + # However this is not called when the module is scripted and we end up calling the default one in module.py + module._load_from_state_dict_script( + obs_dict, prefix, {}, True, missing_keys, unexpected_keys, [] + ) + else: + module._load_from_state_dict( + obs_dict, prefix, {}, False, missing_keys, unexpected_keys, [] + ) + for k in missing_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception(f"Missing keys for observer {k} in state_dict") + for k in unexpected_keys: + if "observer" in k or "activation_post_process" in k: + raise Exception(f"Unexpected keys for observer {k} in state_dict") + + +# Restrict activations to be in the range (0,127) +default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127) +""" +Default observer for static quantization, usually used for debugging. +""" + +default_placeholder_observer = PlaceholderObserver +""" +Default placeholder observer, usually used for quantization to torch.float16. 
+""" + +default_debug_observer = RecordingObserver +""" +Default debug-only observer. +""" + +default_weight_observer = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric +) +""" +Default weight observer. +""" + +weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, + quant_min=-127, quant_max=127, eps=2 ** -12) +""" +Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127) +""" +Default histogram observer, usually used for PTQ. +""" + +default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric +) +""" +Default per-channel weight observer, usually used on backends where per-channel +weight quantization is supported, such as `fbgemm`. +""" + +per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric, + quant_min=-127, quant_max=127, eps=2 ** -12) +""" +Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128. +""" + +default_dynamic_quant_observer = PlaceholderObserver.with_args( + dtype=torch.quint8, quant_min=0, quant_max=255, is_dynamic=True, +) +""" +Default observer for dynamic quantization. +""" + +default_float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point. +""" + +default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args( + dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 +) +""" +Default observer for a floating point zero-point and 4 bit activations. 
+""" + +# TODO(future PR): remove these defaults and enforce activation functions +# to explicitly specify their output range +default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args( + scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255) +default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args( + scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer +default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer + +""" +Default observers for fixed qparams operations. +""" + +default_reuse_input_observer = ReuseInputObserver +""" +Default observer for operators like reshape that reuses the observer of input to +the operator +""" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8353d6172990b5be63f27dfb798605095f425e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig.py @@ -0,0 +1,560 @@ +from collections import namedtuple +from typing import Optional, Any, Union, Type + +import torch +import torch.nn as nn +from torch.ao.quantization.fake_quantize import ( + FakeQuantize, + FakeQuantizeBase, + default_fake_quant, + default_dynamic_fake_quant, + default_per_channel_weight_fake_quant, + default_weight_fake_quant, + default_fused_act_fake_quant, + default_fused_wt_fake_quant, + FusedMovingAvgObsFakeQuantize, + default_fused_per_channel_wt_fake_quant, + default_embedding_fake_quant, + default_embedding_fake_quant_4bit, + fused_wt_fake_quant_range_neg_127_to_127, + 
fused_per_channel_wt_fake_quant_range_neg_127_to_127, +) + +from .observer import ( + _PartialWrapper, + MinMaxObserver, + HistogramObserver, + MovingAverageMinMaxObserver, + NoopObserver, + PlaceholderObserver, + ReuseInputObserver, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_float_qparams_observer_4bit, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_weight_observer, + weight_observer_range_neg_127_to_127, + per_channel_weight_observer_range_neg_127_to_127, + default_reuse_input_observer, + ObserverBase, +) +import warnings +import copy + +__all__ = [ + "QConfig", + # TODO: deprecated, remove + "QConfigDynamic", + "default_qconfig", + "default_debug_qconfig", + "default_per_channel_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float16_static_qconfig", + "per_channel_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + "float_qparams_weight_only_qconfig_4bit", + "default_quint8_weight_qconfig", + "default_qat_qconfig", + "default_dynamic_qat_qconfig", + "default_weight_only_qconfig", + "default_activation_only_qconfig", + "default_qat_qconfig_v2", + "default_reuse_input_qconfig", + "default_symmetric_qnnpack_qconfig", + "default_per_channel_symmetric_qnnpack_qconfig", + "default_symmetric_qnnpack_qat_qconfig", + "default_per_channel_symmetric_qnnpack_qat_qconfig", + "default_embedding_qat_qconfig", + "default_embedding_qat_qconfig_4bit", + "get_default_qconfig", + "get_default_qat_qconfig", + "get_default_qconfig_dict", + "get_default_qat_qconfig_dict", + "QConfigAny", + "qconfig_equals", + +] + +class QConfig(namedtuple('QConfig', ['activation', 'weight'])): + """ + Describes how to quantize a layer or a part of the network by providing + settings (observer classes) for activations and weights respectively. 
+ + + Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization preparation function will instantiate observers multiple times for each of the layers. + + + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfig( + activation=MinMaxObserver.with_args(dtype=torch.qint8), + weight=default_observer.with_args(dtype=torch.qint8)) + + """ + def __new__(cls, activation, weight): + # catch common mistakes + if isinstance(activation, nn.Module) or isinstance(weight, nn.Module): + raise ValueError("QConfig received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed") + return super().__new__(cls, activation, weight) + + +class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])): + """ + Describes how to dynamically quantize a layer or a part of the network by providing + settings (observer classes) for weights. + + It's like QConfig, but for dynamic quantization. + + Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns + instances on invocation, not the concrete observer instances themselves. + Quantization function will instantiate observers multiple times for each of the layers. 
+ + Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args` + method (that behaves like functools.partial):: + + my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8)) + """ + def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity): + # catch common mistakes + if isinstance(weight, nn.Module): + raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " + + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed") + warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead") + return super().__new__(cls, activation, weight) + + +default_qconfig = QConfig(activation=default_observer, + weight=default_weight_observer) +""" +Default qconfig configuration. +""" + +default_debug_qconfig = QConfig(weight=default_weight_observer, + activation=default_debug_observer) +""" +Default qconfig configuration for debugging. +""" + +default_per_channel_qconfig = QConfig(activation=default_observer, + weight=default_per_channel_weight_observer) +""" +Default qconfig configuration for per channel weight quantization. +""" + +default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer, + weight=default_weight_observer) +""" +Default dynamic qconfig. +""" + +float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16, is_dynamic=True), + weight=PlaceholderObserver.with_args(dtype=torch.float16)) +""" +Dynamic qconfig with weights quantized to `torch.float16`. +""" + +float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16), + weight=PlaceholderObserver.with_args(dtype=torch.float16)) +""" +Dynamic qconfig with both activations and weights quantized to `torch.float16`. 
+""" + +per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer, + weight=default_per_channel_weight_observer) +""" +Dynamic qconfig with weights quantized per channel. +""" + +float_qparams_weight_only_qconfig = QConfig( + activation=default_placeholder_observer, + weight=default_float_qparams_observer) +""" +Dynamic qconfig with weights quantized with a floating point zero_point. +""" + +float_qparams_weight_only_qconfig_4bit = QConfig( + activation=default_placeholder_observer, + weight=default_float_qparams_observer_4bit) + +default_qat_qconfig = QConfig(activation=default_fake_quant, + weight=default_weight_fake_quant) +""" +Default qconfig for QAT. +""" + +default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant, + weight=default_weight_fake_quant) +""" +Default qconfig for dynamic QAT. +""" + +default_weight_only_qconfig = QConfig(activation=torch.nn.Identity, + weight=default_weight_fake_quant) +""" +Default qconfig for quantizing weights only. +""" + +default_activation_only_qconfig = QConfig(activation=default_fake_quant, + weight=torch.nn.Identity) +""" +Default qconfig for quantizing activations only. +""" + +# QAT config that uses a fused observer + fake quant modules for optimized training performance. +# to modify the activation/weight observers, the default entries in fake_quantize.py can be modified. +default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant) +""" +Fused version of `default_qat_config`, has performance benefits. +""" + +default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer, + weight=NoopObserver) +""" +Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape +""" + +def get_default_qconfig(backend='x86', version=0): + """ + Returns the default PTQ qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. 
Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + if version == 0: + if backend == 'fbgemm': + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer) + elif backend == 'qnnpack': + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False), + weight=default_weight_observer) + elif backend == 'onednn': + if not torch.cpu._is_cpu_support_vnni(): + warnings.warn( + "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " + "on CPU without Vector Neural Network Instruction support.") + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False), + weight=default_per_channel_weight_observer) + elif backend == 'x86': + qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True), + weight=default_per_channel_weight_observer) + else: + # won't reach + qconfig = default_qconfig + else: + raise AssertionError("Version number: " + str(version) + + " in get_default_qconfig is not supported. Version number must be 0") + + return qconfig + +""" +Default, symmetric PTQ qconfig for the specified backend. And a per_channel +variant of the same. + +Symmetric here applies to signed weights with zero point = 0, and additional +value restrictions. The activations are also signed 8-bit integers with this +qconfig. + + * Once this change is merged [as of 3/17/22], with backend or qengine = + 'qnnpack', some quantized operators with this symmetric qconfig may use + operators from xnnpack library. 
+ + ** Support to use xnnpack ops with `qnnpack` backed for asymmetric + qconfig (returned by get_default_qconfig()) is not available yet. + + * This qconfig uses signed activations and weights. Weights have added + restrictions such as zero point is forced to be 0, making the weights + symmetric, hence the name. And the 8-bit quantized values are + restricting to to [-127, +127], excluding -128. + + * xnnpack has a requantization scale value restriction, 0x1p-32 <= + requantization_scale < 256.0 where, `requantization_scale = (input_scale + * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value + of 256) is to prevent requantization_scale to go below xnnpack lower + threshold. +""" +default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8, + reduce_range=False, + eps=2 ** -12), + weight=weight_observer_range_neg_127_to_127) + +default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8, + reduce_range=False, + eps=2 ** -12), + weight=per_channel_weight_observer_range_neg_127_to_127) + +default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant) + +default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32), + weight=default_embedding_fake_quant_4bit) + +default_quint8_weight_qconfig = QConfig(activation=HistogramObserver, weight=MinMaxObserver) + +def get_default_qat_qconfig(backend='x86', version=1): + """ + Returns the default QAT qconfig for the specified backend. + + Args: + * `backend` (str): a string representing the target backend. Currently supports + `x86` (default), `fbgemm`, `qnnpack` and `onednn`. + * `version`: version, for backwards compatibility. Can be `None` or `1`. 
+ + Return: + qconfig + """ + supported_backends = ["fbgemm", "x86", "qnnpack", "onednn"] + if backend not in supported_backends: + raise AssertionError( + "backend: " + str(backend) + + f" not supported. backend must be one of {supported_backends}" + ) + + # Histogram observer is too slow for quantization aware training + if version == 0: + if backend == 'fbgemm': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_per_channel_weight_fake_quant) + elif backend == 'qnnpack': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False), + weight=default_weight_fake_quant) + elif backend == 'onednn': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255), + weight=default_per_channel_weight_fake_quant) + elif backend == 'x86': + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_per_channel_weight_fake_quant) + else: + qconfig = default_qat_qconfig + # Use the fused observe + fake_quant modules for doing QAT. 
+ elif version == 1: + if backend == 'fbgemm': + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_fused_per_channel_wt_fake_quant) + elif backend == 'qnnpack': + # TODO: make this compatible with xnnpack constraints + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=False), + weight=default_fused_wt_fake_quant) + elif backend == 'onednn': + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255), + weight=default_fused_per_channel_wt_fake_quant) + elif backend == 'x86': + qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + reduce_range=True), + weight=default_fused_per_channel_wt_fake_quant) + else: + qconfig = default_qat_qconfig_v2 + else: + raise AssertionError("Version number: " + str(version) + + "in get_default_qat_qconfig is not supported. Version number must be 0 or 1") + + return qconfig + +""" +Default symmetric QAT qconfig for qnnpack. And its per channel weight variant. 
+""" +default_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2 ** -12), + weight=fused_wt_fake_quant_range_neg_127_to_127) + +default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig( + activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + reduce_range=False, + eps=2 ** -12), + weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127) + +_default_fp32_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.float32), + weight=PlaceholderObserver.with_args(dtype=torch.float32) +) + +_default_quint8_placeholder_qconfig = QConfig( + activation=PlaceholderObserver.with_args(dtype=torch.quint8), + # operators using this qconfig doesn't have weights + weight=None, +) + +def get_default_qconfig_dict(backend='x86', version=0): + warnings.warn( + "torch.ao.quantization.get_default_qconfig_dict is deprecated and will be removed in " + "a future version. Please use torch.ao.quantization.get_default_qconfig_mapping instead.") + return torch.ao.quantization.get_default_qconfig_mapping(backend, version).to_dict() + +def get_default_qat_qconfig_dict(backend='x86', version=1): + warnings.warn( + "torch.ao.quantization.get_default_qat_qconfig_dict is deprecated and will be removed in " + "a future version. Please use torch.ao.quantization.get_default_qat_qconfig_mapping instead.") + return torch.ao.quantization.get_default_qat_qconfig_mapping(backend, version).to_dict() + +def _assert_valid_qconfig(qconfig: Optional[QConfig], + mod: torch.nn.Module) -> None: + """ + Verifies that this `qconfig` is valid. 
+ """ + if qconfig is None: + return + is_conv_transpose_mod = ( + isinstance(mod, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d))) + if is_conv_transpose_mod: + if qconfig.weight is None: + # for now, we assume that any qconfig for ConvTranspose without a weight is valid + return + example_observer = qconfig.weight() + is_per_channel = ( + isinstance(example_observer, (torch.ao.quantization.PerChannelMinMaxObserver, + torch.ao.quantization.MovingAveragePerChannelMinMaxObserver)) + ) + assert not is_per_channel, \ + 'Per channel weight observer is not supported yet for ConvTranspose{n}d.' + +QConfigAny = Optional[QConfig] +QConfigAny.__module__ = "torch.ao.quantization.qconfig" + +def _add_module_to_qconfig_obs_ctr( + qconfig: QConfigAny, + module: Optional[nn.Module]) -> Any: + r"""This is a helper function for use in quantization prepare that updates a qconfig so that + the constructors stored in the qconfig will create observers on the same device that + 'module' is on. This is intended to be used when the qconfigs are propagated to each + module in order to avoid potential device alignment issues. 
+ + Args: + qconfig: QConfig with obs constructors stored in activation and weight + module: module which the qconfig is related to + + Return: + qconfig: configured so that obs constructors set to construct on the same device as module + """ + + if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'): + return qconfig + + def get_factory_kwargs_based_on_module_device(): + assert isinstance(module, torch.nn.Module) + devices = {p.device for p in module.parameters()} | \ + {p.device for p in module.buffers()} + device = next(iter(devices)) if len(devices) > 0 else None + return None if device is None else {'device': device} + + def configure_constructor_to_put_obs_on_module_device(original_constructor): + try: + # check if constructor can accept factory_kwargs + check = original_constructor.with_args(factory_kwargs=None) + check() + return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device) + except AttributeError: # qconfig doesn't have activation or weight + return original_constructor + except TypeError: # the class doesn't accept factory_kwargs argument + return original_constructor + + activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation) + weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight) + + return QConfig(activation, weight) + +_ObserverOrFakeQuantizeConstructor = Union[_PartialWrapper, Type[ObserverBase], Type[FakeQuantizeBase]] + +def _obs_or_fq_ctr_equals(obs_or_fq1: _ObserverOrFakeQuantizeConstructor, obs_or_fq2: _ObserverOrFakeQuantizeConstructor): + if isinstance(obs_or_fq1, _PartialWrapper) and isinstance(obs_or_fq2, _PartialWrapper): + return _partial_wrapper_equals(obs_or_fq1, obs_or_fq2) + return obs_or_fq1 == obs_or_fq2 + +def _partial_wrapper_equals(obs_or_fq1: _PartialWrapper, obs_or_fq2: _PartialWrapper): + """ + Return whether the two partial wrappers are equal, + """ + # functools.partial has no __eq__ operator 
defined so '==' defaults to 'is' + obs_or_fq1_keywords = copy.copy(obs_or_fq1.p.keywords) + obs_or_fq2_keywords = copy.copy(obs_or_fq2.p.keywords) + keywords_equal = True + # compare observer constructor with _obs_or_fq_ctr_equals since direct compare would fail + if "observer" in obs_or_fq1_keywords and "observer" in obs_or_fq2_keywords: + keywords_equal = keywords_equal and _obs_or_fq_ctr_equals(obs_or_fq1_keywords["observer"], obs_or_fq2_keywords["observer"]) + obs_or_fq1_keywords.pop("observer") + obs_or_fq2_keywords.pop("observer") + keywords_equal = keywords_equal and obs_or_fq1_keywords == obs_or_fq2_keywords + return obs_or_fq1.p.func == obs_or_fq2.p.func and obs_or_fq1.p.args == obs_or_fq2.p.args and keywords_equal + +def qconfig_equals(q1: QConfigAny, q2: QConfigAny): + """ + Returns `True` if `q1` equals `q2`, and `False` otherwise. + """ + if q1 is None or q2 is None: + return q1 == q2 + else: + assert q1 is not None and q2 is not None + try: + # Qconfig weight and activation can be either a partial wrapper, + # or an observer class. Special handling is required (above) for + # comparing partial wrappers. + activation_same = _obs_or_fq_ctr_equals(q1.activation, q2.activation) + weight_same = _obs_or_fq_ctr_equals(q1.weight, q2.weight) + return activation_same and weight_same + except AttributeError: + return q1 == q2 + +def _activation_is_memoryless(qconfig: QConfig): + """ + Return whether the observer for activations defined in the given QConfig is memoryless. + This means a MovingAverage observer with averaging constant equal to 1. 
+ """ + def _is_memoryless(observer): + return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1 + act = qconfig.activation() + if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"): + return _is_memoryless(act.activation_post_process) + else: + return _is_memoryless(act) + +def _is_reuse_input_qconfig(qconfig: Optional[QConfig]): + return qconfig is not None and \ + isinstance(qconfig.activation(), ReuseInputObserver) and \ + isinstance(qconfig.weight(), NoopObserver) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantization_mappings.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantization_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..179cddca27427bae08139b5b777207cdd31650e8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantization_mappings.py @@ -0,0 +1,348 @@ +import copy + +import torch +from torch import nn + +import torch.nn.functional as F +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +import torch.ao.nn.intrinsic.quantized.dynamic as nniqd +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.reference as nnqr +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.qat as nnqat +import torch.ao.nn.qat.dynamic as nnqatd + +from typing import Optional, Union, Dict, Set, Callable, Any + +# Because `torch.ao.nn` uses lazy imports, we need to make +# sure we import the contents explicitly here. 
+import torch.ao.nn.sparse +import torch.ao.nn as ao_nn +from torch.ao.quantization.stubs import QuantStub, DeQuantStub +from torch.ao.quantization.fake_quantize import ( + default_fixed_qparams_range_0to1_fake_quant, + default_fixed_qparams_range_neg1to1_fake_quant, +) +from torch.ao.quantization.utils import get_combined_dict +from torch.nn.utils.parametrize import type_before_parametrizations + +__all__ = [ + "DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_STATIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_QAT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS", + "DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS", + "DEFAULT_MODULE_TO_ACT_POST_PROCESS", + "DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS", + "DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS", + "no_observer_set", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_static_quant_module_class", + "get_dynamic_quant_module_class", + "get_default_qat_module_mappings", + "get_embedding_qat_module_mappings", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_default_float_to_quantized_operator_mappings", + "get_quantized_operator", +] + +# Default map for swapping float module to reference quantized modules +DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.Linear: nnqr.Linear, + nn.Conv1d: nnqr.Conv1d, + nn.Conv2d: nnqr.Conv2d, + nn.Conv3d: nnqr.Conv3d, + nn.ConvTranspose1d: nnqr.ConvTranspose1d, + nn.ConvTranspose2d: nnqr.ConvTranspose2d, + nn.ConvTranspose3d: nnqr.ConvTranspose3d, + nn.Embedding: nnqr.Embedding, + nn.EmbeddingBag: nnqr.EmbeddingBag, + nn.GRUCell: nnqr.GRUCell, + nn.LSTMCell: nnqr.LSTMCell, + nn.RNNCell: 
nnqr.RNNCell, + nn.LSTM: nnqr.LSTM, +} + +# Default map for swapping float module to quantized ones +DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + nn.BatchNorm2d: nnq.BatchNorm2d, + nn.BatchNorm3d: nnq.BatchNorm3d, + nn.Dropout: nnq.Dropout, + nn.Conv1d: nnq.Conv1d, + nn.Conv2d: nnq.Conv2d, + nn.Conv3d: nnq.Conv3d, + nn.ConvTranspose1d: nnq.ConvTranspose1d, + nn.ConvTranspose2d: nnq.ConvTranspose2d, + nn.ConvTranspose3d: nnq.ConvTranspose3d, + nn.ELU: nnq.ELU, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.GroupNorm: nnq.GroupNorm, + nn.Hardswish: nnq.Hardswish, + nn.InstanceNorm1d: nnq.InstanceNorm1d, + nn.InstanceNorm2d: nnq.InstanceNorm2d, + nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.LayerNorm: nnq.LayerNorm, + nn.LeakyReLU: nnq.LeakyReLU, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear, + nn.Linear: nnq.Linear, + nn.ReLU6: nnq.ReLU6, + nn.Dropout: nnq.Dropout, + nn.PReLU: nnq.PReLU, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.BNReLU2d: nniq.BNReLU2d, + nni.BNReLU3d: nniq.BNReLU3d, + nni.ConvReLU1d: nniq.ConvReLU1d, + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.ConvReLU3d: nniq.ConvReLU3d, + nni.ConvAdd2d: nniq.ConvAdd2d, + nni.ConvAddReLU2d: nniq.ConvAddReLU2d, + nni.LinearReLU: nniq.LinearReLU, + nni.LinearLeakyReLU: nniq.LinearLeakyReLU, + nni.LinearTanh: nniq.LinearTanh, + nniqat.ConvBn1d: nnq.Conv1d, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBn3d: nnq.Conv3d, + nniqat.ConvBnReLU1d: nniq.ConvReLU1d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + nniqat.ConvBnReLU3d: nniq.ConvReLU3d, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.ConvReLU3d: nniq.ConvReLU3d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.LinearBn1d: nnq.Linear, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, + nnqat.Conv3d: nnq.Conv3d, +} + +# Default map for swapping float module to qat modules 
+DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.Conv2d: nnqat.Conv2d, + nn.Conv3d: nnqat.Conv3d, + nn.Linear: nnqat.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear, + # Intrinsic modules: + nni.ConvBn1d: nniqat.ConvBn1d, + nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBn3d: nniqat.ConvBn3d, + nni.ConvBnReLU1d: nniqat.ConvBnReLU1d, + nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, + nni.ConvBnReLU3d: nniqat.ConvBnReLU3d, + nni.ConvReLU2d: nniqat.ConvReLU2d, + nni.ConvReLU3d: nniqat.ConvReLU3d, + nni.LinearReLU: nniqat.LinearReLU, + nni.LinearBn1d: nniqat.LinearBn1d, +} + +# Default map for swapping dynamic modules +DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.GRUCell: nnqd.GRUCell, + nn.Linear: nnqd.Linear, + nnqatd.Linear: nnqd.Linear, + nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, + nn.GRU: nnqd.GRU, + nn.LSTMCell: nnqd.LSTMCell, + nn.RNNCell: nnqd.RNNCell, + nni.LinearReLU: nniqd.LinearReLU, + nn.EmbeddingBag: nnq.EmbeddingBag, + nn.Embedding: nnq.Embedding, + # Don't want to enable these by default because the numerical + # accuracy is poor compared to other dynamic ops + # nn.Conv1d: nnqd.Conv1d, + # nn.Conv2d: nnqd.Conv2d, + # nn.Conv3d: nnqd.Conv3d, + # nn.ConvTranspose1d: nnqd.ConvTranspose1d, + # nn.ConvTranspose2d: nnqd.ConvTranspose2d, + # nn.ConvTranspose3d: nnqd.ConvTranspose3d, +} + +# Allowlist for propagating the qconfig +_INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = { + nn.Sequential, +} + +# Default mapping from floating point function or torch ops to quantized ops +# TODO: merge with default static mapping +DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = { + F.elu: torch.ops.quantized.elu, + F.hardswish: torch.ops.quantized.hardswish, + F.instance_norm: torch.ops.quantized.instance_norm, + F.layer_norm: torch.ops.quantized.layer_norm, + F.leaky_relu: torch.ops.quantized.leaky_relu, + F.dropout: 
torch.ops.quantized.dropout, +} + +# mapping from module to output activation post process class +DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = { + nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant, + nn.Softmax: default_fixed_qparams_range_0to1_fake_quant, + nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant, +} + +# Default map for swapping float module to static sparse quantized ones +DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.Linear +} + +# Default map for swapping float module to dynamic sparse quantized ones +DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = { + nn.Linear: ao_nn.sparse.quantized.dynamic.Linear +} + +def no_observer_set() -> Set[Any]: + r"""These modules cannot have observers inserted by default.""" + no_observers = { + nn.quantizable.LSTM, + nn.quantizable.MultiheadAttention + } + return no_observers + +def get_default_static_quant_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping for post training static quantization + ''' + return copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + +def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]: + ''' Get reference module mapping for post training static quantization + ''' + return copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS) + +def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping, including mapping for embedding QAT + ''' + mapping = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS) + mapping[nnqat.EmbeddingBag] = nnq.EmbeddingBag + mapping[nnqat.Embedding] = nnq.Embedding + return mapping + +def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping for post training static sparse quantization + ''' + return copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS) + +def 
get_static_quant_module_class( + float_module_class: Callable, + additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None, + is_reference: bool = False) -> Any: + r"""n Get the statically quantized module class corresponding to + the floating point module class + """ + if additional_static_quant_mapping is None: + additional_static_quant_mapping = {} + all_mappings = get_combined_dict( + DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference + else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS, additional_static_quant_mapping) + static_quant_module_class = all_mappings.get(float_module_class, None) + assert static_quant_module_class is not None, \ + f"Floating point module class {str(float_module_class)}" + \ + " does not have a corresponding quantized module class" + return copy.deepcopy(static_quant_module_class) + +def get_dynamic_quant_module_class( + float_module_class: Callable, + additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any: + r"""n Get the dynamically quantized module class corresponding to + the floating point module class + """ + if additional_dynamic_quant_mapping is None: + additional_dynamic_quant_mapping = {} + all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping) + dynamic_quant_module_class = all_mappings.get(float_module_class, None) + assert dynamic_quant_module_class is not None, \ + f"Floating point module class {str(float_module_class)}" + \ + " does not have a corresponding quantized module class" + return copy.deepcopy(dynamic_quant_module_class) + +def get_default_qat_module_mappings() -> Dict[Callable, Any]: + ''' Get default module mapping for quantization aware training + ''' + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + +def get_embedding_qat_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping for quantization aware training + This is includes default values in addition to + enabling qat for embeddings. 
+ ''' + mapping = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) + mapping[nn.EmbeddingBag] = nnqat.EmbeddingBag + mapping[nn.Embedding] = nnqat.Embedding + return mapping + +def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping for post training dynamic quantization + ''' + return DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS + +def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]: + ''' Get module mapping for post training dynamic sparse quantization + ''' + return DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS + +def get_default_qconfig_propagation_list() -> Set[Callable]: + ''' Get the default list of module types that we'll attach qconfig + attribute to in prepare + ''' + QCONFIG_PROPAGATE_MODULE_CLASS_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) | + set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) | + _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(QCONFIG_PROPAGATE_MODULE_CLASS_LIST) + +def get_default_compare_output_module_list() -> Set[Callable]: + ''' Get list of module class types that we will record output + in numeric suite + ''' + NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST = ( + set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.values()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.values()) + | set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_QAT_MODULE_MAPPINGS.keys()) + | set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys()) + | _INCLUDE_QCONFIG_PROPAGATE_LIST + ) + return copy.deepcopy(NUMERIC_SUITE_COMPARE_MODEL_OUTPUT_MODULE_LIST) + +def get_default_float_to_quantized_operator_mappings( +) -> Dict[Union[Callable, str], Callable]: + return copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS) + +# TODO: merge with get_static_quant_module_class +def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: + ''' Get the quantized operator corresponding to 
the float operator + ''' + quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None) + assert quantized_op is not None, \ + f'Operator {str(float_op)} does not have corresponding quantized op' + return quantized_op + +def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]: + r""" Get the special activation post process for `module`, this has + higher priority than the activation post process in `qconfig` + e.g. + input: torch.nn.Sigmoid + output: default_affine_fixed_qparam_fake_quant + """ + return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module), None) + +def _has_special_act_post_process(module: torch.nn.Module) -> bool: + return module.training and type(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..794cb142220dfb424cb58b06b736ed02cd14dad2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize.py @@ -0,0 +1,664 @@ +import copy +import itertools +import warnings + +import torch +import torch.nn as nn +import torch.ao.nn.quantized as nnq +from torch.ao.nn.intrinsic import _FusedModule + +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_static_quant_module_mappings, + get_default_static_quant_reference_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, + no_observer_set, + _has_special_act_post_process, + _get_special_act_post_process, +) +from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations +from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper +from torch.ao.quantization.qconfig import ( + 
_add_module_to_qconfig_obs_ctr, + default_dynamic_qconfig, + float16_dynamic_qconfig, + float_qparams_weight_only_qconfig, + float_qparams_weight_only_qconfig_4bit, + _activation_is_memoryless) +from torch.nn.utils.parametrize import type_before_parametrizations +from torch.ao.quantization.observer import _is_activation_post_process + +# TODO remove this once BC is no longer required to avoid a SEV +from torch.ao.quantization.observer import ( # noqa: F401 + _is_activation_post_process as is_activation_post_process +) + +__all__ = [ + "get_default_custom_config_dict", + "propagate_qconfig_", + "add_quant_dequant", + "prepare", + "quantize", + "quantize_dynamic", + "prepare_qat", + "quantize_qat", + "convert", + "swap_module", +] + +_DEFAULT_CUSTOM_CONFIG_DICT = { + 'float_to_observed_custom_module_class': { + nn.LSTM: nn.quantizable.LSTM, + nn.MultiheadAttention: nn.quantizable.MultiheadAttention, + }, + 'observed_to_quantized_custom_module_class': { + nn.quantizable.LSTM: nn.quantized.LSTM, + nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention, + } +} + +def get_default_custom_config_dict(): + r"""Defines the default custom config dict. 
+ """ + return _DEFAULT_CUSTOM_CONFIG_DICT + +def _propagate_qconfig_helper(module, qconfig_dict, + qconfig_parent=None, prefix='', prepare_custom_config_dict=None): + r"""This is a helper function for `propagate_qconfig_` + + Args: + module: input module + qconfig_dict: dictionary that maps from name of submodule to quantization + configuration + qconfig_parent: quantization config of parent module, we will fallback to + this config when there is no specified config for current + module + prefix: corresponding prefix of the current module, used as key in + qconfig_dict + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + + module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent) + module_qconfig = qconfig_dict.get(prefix, module_qconfig) + module_qconfig = getattr(module, 'qconfig', module_qconfig) + + torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module) + + qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module) + module.qconfig = qconfig_with_device_check + + for name, child in module.named_children(): + module_prefix = prefix + '.' 
+ name if prefix else name + # do no not propagate qconfig to child if child is non traceable + if prepare_custom_config_dict is None or not ( + name in prepare_custom_config_dict.get("non_traceable_module_name", []) + or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", []) + ): + _propagate_qconfig_helper( + child, qconfig_dict, qconfig_with_device_check, module_prefix + ) + +def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): + r"""Propagate qconfig through the module hierarchy and assign `qconfig` + attribute on each leaf module + + Args: + module: input module + qconfig_dict: dictionary that maps from name or type of submodule to + quantization configuration, qconfig applies to all submodules of a + given module unless qconfig for the submodules are specified (when + the submodule already has qconfig attribute) + prepare_custom_config_dict: dictionary for custom handling of modules + see docs for :func:`~torch.ao.quantization.prepare_fx` + + Return: + None, module is modified inplace with qconfig attached + """ + if qconfig_dict is None: + qconfig_dict = {} + if prepare_custom_config_dict is None: + prepare_custom_config_dict = {} + _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict) + +def _observer_forward_hook(self, input, output): + r"""Forward hook that calls observer on the output + """ + return self.activation_post_process(output) + +def _observer_forward_pre_hook(self, input): + r"""Forward pre hook that calls observer on the output + """ + return self.activation_post_process(input[0]) + +def _register_activation_post_process_hook(module, pre_hook=False): + assert hasattr(module, 'activation_post_process'), \ + 'Expect activation_post_process attribute already attached to the module' + if pre_hook: + handle = module.register_forward_pre_hook( + _observer_forward_pre_hook, prepend=True + ) + else: + handle = module.register_forward_hook( + 
_observer_forward_hook, prepend=True + ) + + +def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None): + r"""Add observer for the leaf child of the module. + + This function insert observer module to all leaf child module that + has a valid qconfig attribute. + + Args: + module: input module with qconfig attributes for all the leaf modules that we want to quantize + qconfig_propagation_list: a list of quantizable modules that will have observers added to them + if they are leaf nodes + device: parent device, if any + non_leaf_module_list: list of non-leaf modules we want to add observer + + Return: + None, module is modified inplace with added observer modules and forward_hooks + """ + if qconfig_propagation_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + + if custom_module_class_mapping is None: + custom_module_class_mapping = {} + + # respect device affinity when adding observers + if device is None: + devices = _get_unique_devices_(module) + assert len(devices) <= 1, ( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + + def get_activation_post_process(qconfig, device, special_act_post_process=None): + activation = qconfig.activation() if special_act_post_process is None else special_act_post_process() + if device is not None: + activation.to(device) + return activation + + def needs_observation(m): + return hasattr(m, 'qconfig') and m.qconfig is not None + + def insert_activation_post_process(m, special_act_post_process=None): + """ Adds an activation post process module and register + a pre or post hook that calls the module + """ + # We don't insert observer/fake_quantize for DeQuantStub + if needs_observation(m) and not isinstance(m, DeQuantStub): + # observer and hook will be gone after we swap the module + 
m.add_module('activation_post_process', get_activation_post_process( + m.qconfig, device, special_act_post_process)) + # Register observer as the first entry in the hook list + # All post forward hooks are preserved and will be executed after the observer before convert + _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig)) + + for name, child in module.named_children(): + # TODO remove Dropout special after codebase stable + if type_before_parametrizations(child) in [nn.Dropout]: + continue + elif issubclass(type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional)): + if needs_observation(child): + assert hasattr(child, "activation_post_process"), ( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) + child.activation_post_process = get_activation_post_process(child.qconfig, device) + elif isinstance(child, _FusedModule): + # activation_post_process are now added directly to nn.Sequential/_FusedModule + if needs_observation(child): + insert_activation_post_process(child) + elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list: + if needs_observation(child): + insert_activation_post_process(child) + elif _has_special_act_post_process(child): + special_act_post_process = _get_special_act_post_process(child) + insert_activation_post_process(child, special_act_post_process) + elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping: + observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child) + setattr(module, name, observed_child) + # TODO: These are the modules that cannot be observed + # Once there are more, we should move them to a separate list + if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set(): + insert_activation_post_process(observed_child) + else: + _add_observer_(child, 
qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping) + + # Insert observers only for leaf nodes, note that this observer is for + # the output of the module, for input QuantStub will observe them + if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \ + and type_before_parametrizations(module) in qconfig_propagation_list: + insert_activation_post_process(module) + +def _get_unique_devices_(module): + return {p.device for p in module.parameters()} | \ + {p.device for p in module.buffers()} + +def add_quant_dequant(module): + r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig + Note that this function will modify the children of module inplace and it + can return a new module which wraps the input module as well. + + Args: + module: input module with qconfig attributes for all the leaf modules + that we want to quantize + + Return: + Either the inplace modified module with submodules wrapped in + `QuantWrapper` based on qconfig or a new `QuantWrapper` module which + wraps the input module, the latter case only happens when the input + module is a leaf module and we want to quantize it. + """ + if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig: + return QuantWrapper(module) + + for name, child in module.named_children(): + module._modules[name] = add_quant_dequant(child) + return module + +def prepare(model, inplace=False, allow_list=None, + observer_non_leaf_module_list=None, + prepare_custom_config_dict=None): + r"""Prepares a copy of the model for quantization calibration or quantization-aware training. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + The model will be attached with observer or fake quant modules, and qconfig + will be propagated. 
+ + Args: + `model`: input model to be modified in-place + `inplace`: carry out model transformations in-place, the original module is mutated + `allow_list`: list of quantizable modules + `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer + `prepare_custom_config_dict`: customization configuration dictionary for prepare function + + .. code-block:: python + + # Example of prepare_custom_config_dict: + prepare_custom_config_dict = { + # user will manually define the corresponding observed + # module class which has a from_float class method that converts + # float custom module to observed custom module + "float_to_observed_custom_module_class": { + CustomModule: ObservedCustomModule + } + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare") + if prepare_custom_config_dict is None: + prepare_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) + + if not inplace: + model = copy.deepcopy(model) + + # TODO: remove allow_list + qconfig_propagation_list = allow_list + if allow_list is None: + qconfig_propagation_list = get_default_qconfig_propagation_list() + propagate_qconfig_(model, qconfig_dict=None) + + # sanity check common API misusage + if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()): + warnings.warn("None of the submodule got qconfig applied. 
Make sure you " + "passed correct configuration through `qconfig_dict` or " + "by assigning the `.qconfig` attribute directly on submodules") + + _add_observer_( + model, qconfig_propagation_list, observer_non_leaf_module_list, + custom_module_class_mapping=custom_module_class_mapping) + return model + +def _remove_activation_post_process(module): + # TODO: maybe we should change activation_post_process to _activation_post_process + # to prevent it from being used by user + if hasattr(module, 'activation_post_process') and \ + _is_activation_post_process(module.activation_post_process): + delattr(module, 'activation_post_process') + + # remove activation_post_process pre and post hooks + def remove_hooks(pre_hook=False): + hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks + observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook + handle_ids_to_remove = set() + for handle_id, hook_fn in hook_map.items(): + if hook_fn is observer_hook: + handle_ids_to_remove.add(handle_id) + for handle_id in handle_ids_to_remove: + hook_map.pop(handle_id) + + remove_hooks(pre_hook=True) + remove_hooks(pre_hook=False) + +# TODO: rename to something more general +def _remove_qconfig(module): + r"""Clean up the qconfig left in the module so that new qconfig can be + propagated. + + Args: + module: module to be cleaned up + """ + for child in module.children(): + _remove_qconfig(child) + + if hasattr(module, "qconfig"): + del module.qconfig + + _remove_activation_post_process(module) + +def quantize(model, run_fn, run_args, mapping=None, inplace=False): + r"""Quantize the input float model with post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. 
+ + Args: + model: input float model + run_fn: a calibration function for calibrating the prepared model + run_args: positional arguments for `run_fn` + inplace: carry out model transformations in-place, the original module is mutated + mapping: correspondence between original module types and quantized counterparts + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize") + if mapping is None: + mapping = get_default_static_quant_module_mappings() + if not inplace: + model = copy.deepcopy(model) + model.eval() + prepare(model, inplace=True) + run_fn(model, *run_args) + convert(model, mapping, inplace=True) + return model + +def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, + mapping=None, inplace=False): + r"""Converts a float model to dynamic (i.e. weights-only) quantized model. + + Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. + + For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization + by default is performed for layers with large weights size - i.e. Linear and RNN variants. + + Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. + If `qconfig` is provided, the `dtype` argument is ignored. + + Args: + model: input model + qconfig_spec: Either: + + - A dictionary that maps from name or type of submodule to quantization + configuration, qconfig applies to all submodules of a given + module unless qconfig for the submodules are specified (when the + submodule already has qconfig attribute). Entries in the dictionary + need to be QConfig instances. 
+ + - A set of types and/or submodule names to apply dynamic quantization to, + in which case the `dtype` argument is used to specify the bit-width + + inplace: carry out model transformations in-place, the original module is mutated + mapping: maps type of a submodule to a type of corresponding dynamically quantized version + with which the submodule needs to be replaced + + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic") + if qconfig_spec is None: + if dtype == torch.qint8: + qconfig_spec = { + nn.Linear : default_dynamic_qconfig, + nn.LSTM : default_dynamic_qconfig, + nn.GRU : default_dynamic_qconfig, + nn.LSTMCell : default_dynamic_qconfig, + nn.RNNCell : default_dynamic_qconfig, + nn.GRUCell : default_dynamic_qconfig, + } + elif dtype == torch.float16: + qconfig_spec = { + nn.Linear : float16_dynamic_qconfig, + nn.LSTM : float16_dynamic_qconfig, + nn.GRU : float16_dynamic_qconfig, + nn.LSTMCell : float16_dynamic_qconfig, + nn.RNNCell : float16_dynamic_qconfig, + nn.GRUCell : float16_dynamic_qconfig, + } + elif dtype == torch.quint8: + qconfig_spec = { + nn.EmbeddingBag : float_qparams_weight_only_qconfig, + nn.Embedding : float_qparams_weight_only_qconfig, + } + elif dtype == torch.quint4x2: + qconfig_spec = { + nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit, + } + else: + raise ValueError( + f"Don't know how to quantize with default settings for {dtype}. 
Provide full qconfig please") + elif isinstance(qconfig_spec, set): + if dtype is torch.qint8: + default_qconfig = default_dynamic_qconfig + elif dtype is torch.float16: + default_qconfig = float16_dynamic_qconfig + elif dtype is torch.quint8: + default_qconfig = float_qparams_weight_only_qconfig + elif dtype is torch.quint4x2: + default_qconfig = float_qparams_weight_only_qconfig_4bit + else: + raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype)) + qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) + + if mapping is None: + mapping = get_default_dynamic_quant_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + model.eval() + propagate_qconfig_(model, qconfig_spec) + convert(model, mapping, inplace=True) + return model + +def prepare_qat(model, mapping=None, inplace=False): + r""" + Prepares a copy of the model for quantization calibration or + quantization-aware training and converts it to quantized version. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + Args: + model: input model to be modified in-place + mapping: dictionary that maps float modules to quantized modules to be + replaced. 
+ inplace: carry out model transformations in-place, the original module + is mutated + """ + torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") + assert model.training, "prepare_qat only works on models in training mode" + if mapping is None: + mapping = get_default_qat_module_mappings() + + if not inplace: + model = copy.deepcopy(model) + + propagate_qconfig_(model, qconfig_dict=None) + convert(model, mapping=mapping, inplace=True, remove_qconfig=False) + prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True) + return model + +def quantize_qat(model, run_fn, run_args, inplace=False): + r"""Do quantization aware training and output a quantized model + + Args: + model: input model + run_fn: a function for evaluating the prepared model, can be a + function that simply runs the prepared model or a training + loop + run_args: positional arguments for `run_fn` + + Return: + Quantized model. + """ + torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat") + if not inplace: + model = copy.deepcopy(model) + model.train() + prepare_qat(model, inplace=True) + run_fn(model, *run_args) + convert(model, inplace=True) + return model + +def convert( + module, mapping=None, inplace=False, remove_qconfig=True, + is_reference=False, convert_custom_config_dict=None): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class. And remove qconfig at the + end if remove_qconfig is set to True. + + Args: + `module`: prepared and calibrated module + `mapping`: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + `inplace`: carry out model transformations in-place, the original module + is mutated + `convert_custom_config_dict`: custom configuration dictionary for convert function + + .. 
code-block:: python + + # Example of convert_custom_config_dict: + convert_custom_config_dict = { + # user will manually define the corresponding quantized + # module class which has a from_observed class method that converts + # observed custom module to quantized custom module + "observed_to_quantized_custom_module_class": { + ObservedCustomModule: QuantizedCustomModule + } + } + + """ + torch._C._log_api_usage_once("quantization_api.quantize.convert") + if not inplace: + module = copy.deepcopy(module) + _convert( + module, mapping, inplace=True, is_reference=is_reference, + convert_custom_config_dict=convert_custom_config_dict) + if remove_qconfig: + _remove_qconfig(module) + return module + +def _convert( + module, mapping=None, inplace=False, + is_reference=False, convert_custom_config_dict=None): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_float` method on the target module class + + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + is_reference: a flag to enable quantized reference module + + """ + if mapping is None: + mapping = get_default_static_quant_reference_module_mappings() if is_reference \ + else get_default_static_quant_module_mappings() + if convert_custom_config_dict is None: + convert_custom_config_dict = get_default_custom_config_dict() + custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {}) + + if not inplace: + module = copy.deepcopy(module) + reassign = {} + for name, mod in module.named_children(): + # both fused modules and observed custom modules are + # swapped as one unit + if not isinstance(mod, _FusedModule) and \ + type_before_parametrizations(mod) not in custom_module_class_mapping: + _convert(mod, mapping, 
True, # inplace + is_reference, convert_custom_config_dict) + reassign[name] = swap_module(mod, mapping, custom_module_class_mapping) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + +def swap_module(mod, mapping, custom_module_class_mapping): + r"""Swaps the module if it has a quantized counterpart and it has an + `observer` attached. + + Args: + mod: input module + mapping: a dictionary that maps from nn module to nnq module + + Return: + The corresponding quantized module of `mod` + """ + new_mod = mod + if hasattr(mod, 'qconfig') and mod.qconfig is not None: + swapped = False + if type_before_parametrizations(mod) in custom_module_class_mapping: + new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod) + swapped = True + elif type_before_parametrizations(mod) in mapping: + qmod = mapping[type_before_parametrizations(mod)] + if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE: + assert mod.qconfig is not None + weight_post_process = mod.qconfig.weight() + weight_post_process(mod.weight) + weight_qparams = get_qparam_dict(weight_post_process) + new_mod = qmod.from_float(mod, weight_qparams) + else: + new_mod = qmod.from_float(mod) + swapped = True + + if swapped: + # Preserve module's pre forward hooks. 
They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + if hook_fn is not _observer_forward_hook: + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = _get_unique_devices_(mod) + assert len(devices) <= 1, ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + return new_mod + +def _get_observer_dict(mod, target_dict, prefix=""): + r"""Traverse the modules and save all observers into dict. + This is mainly used for quantization accuracy debug + Args: + mod: the top module we want to save all observers + prefix: the prefix for the current module + target_dict: the dictionary used to save all the observers + """ + def get_prefix(prefix): + return prefix if prefix == "" else prefix + '.' 
+ + if hasattr(mod, 'activation_post_process'): + target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_observer_dict(child, target_dict, module_prefix) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_jit.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..632fc1db2327235fc66a2dc1d49b76c0316b5be1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantize_jit.py @@ -0,0 +1,335 @@ + +import torch +from torch.ao.quantization.qconfig import QConfig +from torch.ao.quantization.quant_type import QuantType +from torch.jit._recursive import wrap_cpp_module + +__all__ = [ + "script_qconfig", + "script_qconfig_dict", + "fuse_conv_bn_jit", + "prepare_jit", + "prepare_dynamic_jit", + "convert_jit", + "convert_dynamic_jit", + "quantize_jit", + "quantize_dynamic_jit", +] + +def _check_is_script_module(model): + if not isinstance(model, torch.jit.ScriptModule): + raise ValueError('input must be a script module, got: ' + str(type(model))) + +def _check_forward_method(model): + if not model._c._has_method('forward'): + raise ValueError('input script module does not have forward method') + +def script_qconfig(qconfig): + r"""Instantiate the activation and weight observer modules and script + them, these observer module instances will be deepcopied during + prepare_jit step. + """ + return QConfig( + activation=torch.jit.script(qconfig.activation())._c, + weight=torch.jit.script(qconfig.weight())._c) + +def script_qconfig_dict(qconfig_dict): + r"""Helper function used by `prepare_jit`. + Apply `script_qconfig` for all entries in `qconfig_dict` that is + not None. 
+ """ + return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()} + +def fuse_conv_bn_jit(model, inplace=False): + r""" Fuse conv - bn module + Works for eval model only. + + Args: + model: TorchScript model from scripting or tracing + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit") + model_c = model._c + model_c = torch._C._jit_pass_fold_convbn(model_c) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + +def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + _check_forward_method(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError('qconfig_dict should only contain names(str) as keys.') + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observers(model._c, + 'forward', + scripted_qconfig_dict, + inplace, + quant_type) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + +def _prepare_ondevice_jit(model, qconfig_dict, method_name='forward', inplace=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + if not all(isinstance(x, str) for x in qconfig_dict.keys()): + raise ValueError('qconfig_dict should only contain names(str) as keys.') + scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) + method_graph = model._c._get_method(method_name).graph + torch._C._jit_pass_inline(method_graph) + model = fuse_conv_bn_jit(model, inplace) + model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(model._c, + method_name, + scripted_qconfig_dict, + inplace, + quant_type) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + +def prepare_jit(model, qconfig_dict, inplace=False): + 
torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC) + +def prepare_dynamic_jit(model, qconfig_dict, inplace=False): + torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit") + return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC) + + +def _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False): + return _prepare_ondevice_jit(model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC) + +def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC, + preserved_attrs=None): + _check_is_script_module(model) + model.eval() + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type) + if not debug: + is_xpu = all(p.device.type == 'xpu' for p in model.parameters()) + if not is_xpu: + # Moving model parameters to CPU since quantized operators + # are only supported on CPU and XPU right now + model.cpu() + if preserved_attrs is None: + preserved_attrs = [] + model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + + +def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC): + _check_is_script_module(model) + assert quant_type == QuantType.DYNAMIC, "This API, while should work for static quant, is only tested for dynamic quant." + assert not method_name.startswith("observe_"), "Pass in valid method to be quantized, e.g. 
forward" + observe_method_name = "observe_" + method_name + quantize_method_name = "quantize_" + method_name + model_c = model._c + model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq( + model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC) + model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(model_c, QuantType.DYNAMIC, quantize_method_name) + if inplace: + model._reconstruct(model_c) + else: + model = wrap_cpp_module(model_c) + return model + +def convert_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit") + return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs) + +def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None): + torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit") + return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs) + + +def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False): + return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC) + + +def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False): + model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace) + model = _convert_ondevice_dynamic_jit(model, method_name, inplace) + return model + +def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC): + # Always do inplace convert because the Tensor is already + # copied in prepare_jit when inplace is False + if quant_type == QuantType.DYNAMIC: + model = prepare_dynamic_jit(model, qconfig_dict, inplace) + model = convert_dynamic_jit(model, True, debug) + else: + assert run_fn, "Must provide calibration function for post training static quantization" + assert run_args, "Must provide calibration 
dataset for post training static quantization" + model = prepare_jit(model, qconfig_dict, inplace) + run_fn(model, *run_args) + model = convert_jit(model, True, debug) + + torch._C._jit_pass_constant_propagation(model.graph) + torch._C._jit_pass_dce(model.graph) + return model + +def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training static quantization. + + First it will prepare the model for calibration, then it calls + `run_fn` which will run the calibration step, after that we will + convert the model to a quantized model. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, empty key means the qconfig will be applied + to whole model unless it's overwritten by more specific configurations, the + qconfig for each module is either found in the dictionary or fallback to + the qconfig of parent module. + + Right now qconfig_dict is the only way to configure how the model is quantized, + and it is done in the granularity of module, that is, we only support one type + of qconfig for each torch.nn.Module, and the qconfig for sub module will + override the qconfig for parent module, empty string means global configuration. + `run_fn`: a calibration function for calibrating the prepared model + `run_args`: positional arguments for `run_fn` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. 
+ + Example: + ```python + import torch + from torch.ao.quantization import get_default_qconfig + from torch.ao.quantization import quantize_jit + + ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig('fbgemm') + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + quantized_model = quantize_jit( + ts_model, + {'': qconfig}, + calibrate, + [data_loader_test]) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit") + return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC) + +def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False): + r"""Quantize the input float TorchScript model with + post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + `debug`: flag for producing a debug friendly model (preserve weight attribute) + + Return: + Quantized TorchSciprt model. 
+ + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization import quantize_dynamic_jit + + ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig('fbgemm') + def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in data_loader: + model(image) + + quantized_model = quantize_dynamic_jit( + ts_model, + {'': qconfig}, + calibrate, + [data_loader_test]) + ``` + """ + torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit") + return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC) + + +def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False): + r"""Prepares the input float TorchScript model with + *on-device* post training dynamic quantization. + Currently only qint8 quantization of torch.nn.Linear is supported. + + Args: + `model`: input float TorchScript model + `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and + qconfig for that module as value, please see detailed + `method_name`: Name of the method within the model, to be prepared for quantization + descriptions in :func:`~torch.ao.quantization.quantize_jit` + `inplace`: carry out model transformations in-place, the original module is + mutated + + Return: + TorchScript model that is ready for on device quantization. + This means that the returned + model has: + - Method is inlined. + - Model has observer modules inserted in the model. + - Model has packed params inserted in the model. However they are empty as in they dont + contain valid quantized weights. + - observe_ is added that observe the values to be quantized. + - reset_observers_ to reset observers. + - quantize_ is added to the model. + - This method extract scale, zero points. + - Quantizes observed weights. 
+ - Creates packed params from it and update the attribute of the model with the new values + for the packed params. + - Reset the original fp32 weights with empty tensor using SetAttr. + - quantized_ is added to the model. + - This method uses quantized weights and quantized linear ops instead of fp32 op. + - This method should be used for inference post PTQ. + - Note that all method's signatures should be the same as method_name. + + Later on device: + - Run reset_observers_ + - Run observe_ + - Run quantize_ + - Now model can be saved and loaded later. + - Run model with quantized_ + + Example: + ```python + import torch + from torch.ao.quantization import per_channel_dynamic_qconfig + from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit + + ts_model = torch.jit.script(float_model.eval()) # or torch.jit.trace(float_model, input) + qconfig = get_default_qconfig('fbgemm') + quant_ready_model = _quantize_ondevice_dynamic_jit( + ts_model, + {'': qconfig}, + 'forward', + True) + ``` + """ + return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/stubs.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..10a63fb8f0ee43dc3f2367a2f6b0f56164ff4715 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/stubs.py @@ -0,0 +1,64 @@ + +from torch import nn + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. 
+ + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + def __init__(self, qconfig=None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x): + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + def __init__(self, qconfig=None): + super().__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x): + return x + + +class QuantWrapper(nn.Module): + r"""A wrapper class that wraps the input module, adds QuantStub and + DeQuantStub and surround the call to module with call to quant and dequant + modules. + + This is used by the `quantization` utility functions to add the quant and + dequant modules, before `convert` function `QuantStub` will just be observer, + it observes the input tensor, after `convert`, `QuantStub` + will be swapped to `nnq.Quantize` which does actual quantization. Similarly + for `DeQuantStub`. 
+ """ + quant: QuantStub + dequant: DeQuantStub + module: nn.Module + + def __init__(self, module): + super().__init__() + qconfig = getattr(module, "qconfig", None) + self.add_module('quant', QuantStub(qconfig)) + self.add_module('dequant', DeQuantStub(qconfig)) + self.add_module('module', module) + self.train(module.training) + + def forward(self, X): + X = self.quant(X) + X = self.module(X) + return self.dequant(X) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/serialization.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bcf2977da56684b1631e788473df368d015d56 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/serialization.py @@ -0,0 +1,1456 @@ +import difflib +import os +import io +import shutil +import struct +import sys +import torch +import tarfile +import tempfile +import warnings +from contextlib import closing, contextmanager +from enum import Enum +from ._utils import _import_dotted_name +from torch._sources import get_source_lines_and_file +from torch.types import Storage +from torch.storage import _get_dtype_from_pickle_storage_type +from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List +from typing_extensions import TypeAlias, TypeGuard # Python 3.10+ +import copyreg +import pickle +import torch._weights_only_unpickler as _weights_only_unpickler + +DEFAULT_PROTOCOL = 2 + +LONG_SIZE = struct.Struct('=l').size +INT_SIZE = struct.Struct('=i').size +SHORT_SIZE = struct.Struct('=h').size + +MAGIC_NUMBER = 0x1950a86a20f9469cfc6c +PROTOCOL_VERSION = 1001 +STORAGE_KEY_SEPARATOR = ',' + +FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]] +MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] +STORAGE: TypeAlias = Union[Storage, 
torch.storage.TypedStorage, torch.UntypedStorage] + +__all__ = [ + 'SourceChangeWarning', + 'mkdtemp', + 'register_package', + 'check_module_version_greater_or_equal', + 'validate_cuda_device', + 'validate_hpu_device', + 'location_tag', + 'default_restore_location', + 'normalize_storage_type', + 'storage_to_tensor_type', + 'save', + 'load', + 'StorageType', + 'LoadEndianness', + 'get_default_load_endianness', + 'set_default_load_endianness', +] + + +class SourceChangeWarning(Warning): + pass + + +@contextmanager +def mkdtemp(): + path = tempfile.mkdtemp() + try: + yield path + finally: + shutil.rmtree(path) + + +_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = [] + +class LoadEndianness(Enum): + NATIVE = 1 + LITTLE = 2 + BIG = 3 + +_default_load_endian: Optional[LoadEndianness] = None + +def get_default_load_endianness() -> Optional[LoadEndianness]: + ''' + Get fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Returns: + default_load_endian: Optional[LoadEndianness] + ''' + return _default_load_endian + +def set_default_load_endianness(endianness): + ''' + Set fallback byte order for loading files + + If byteorder mark is not present in saved checkpoint, + this byte order is used as fallback. + By default, it's "native" byte order. + + Args: + endianness: the new fallback byte order + ''' + global _default_load_endian + if not isinstance(endianness, LoadEndianness) and endianness is not None: + raise TypeError("Invalid argument type in function set_default_load_endianness") + _default_load_endian = endianness + +def _is_zipfile(f) -> bool: + # This is a stricter implementation than zipfile.is_zipfile(). + # zipfile.is_zipfile() is True if the magic number appears anywhere in the + # binary. 
Since we expect the files here to be generated by torch.save or + # torch.jit.save, it's safe to only check the start bytes and avoid + # collisions and assume the zip has only 1 file. + # See bugs.python.org/issue28494. + + start = f.tell() + # Read the first few bytes and match against the ZIP file signature + local_header_magic_number = b'PK\x03\x04' + read_bytes = f.read(len(local_header_magic_number)) + f.seek(start) + return read_bytes == local_header_magic_number + + +def register_package( + priority: int, + tagger: Callable[[STORAGE], Optional[str]], + deserializer: Callable[[STORAGE, str], Optional[STORAGE]] +): + ''' + Registers callables for tagging and deserializing storage objects with an associated priority. + Tagging associates a device with a storage object at save time while deserializing moves a + storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer` + are run in the order given by their :attr:`priority` until a tagger/deserializer returns a + value that is not `None`. + + To override the deserialization behavior for a device in the global registry, one can register a + tagger with a higher priority than the existing tagger. + + This function can also be used to register a tagger and deserializer for new devices. + + Args: + priority: Indicates the priority associated with the tagger and deserializer, where a lower + value indicates higher priority. + tagger: Callable that takes in a storage object and returns its tagged device as a string + or None. + deserializer: Callable that takes in storage object and a device string and returns a storage + object on the appropriate device or None. 
+ + Returns: + `None` + + Example: + >>> def ipu_tag(obj): + >>> if obj.device.type == 'ipu': + >>> return 'ipu' + >>> def ipu_deserialize(obj, location): + >>> if location.startswith('ipu'): + >>> ipu = getattr(torch, "ipu", None) + >>> assert ipu is not None, "IPU device module is not loaded" + >>> assert torch.ipu.is_available(), "ipu is not available" + >>> return obj.ipu(location) + >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) + ''' + queue_elem = (priority, tagger, deserializer) + _package_registry.append(queue_elem) + _package_registry.sort() + + +def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True): + ''' + Check if a module's version satisfies requirements + + Usually, a module's version string will be like 'x.y.z', which would be represented + as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version + string does not match the given tuple's format up to the length of the tuple, then + error and exit or emit a warning. 
+ + Args: + module: the module to check the version of + req_version_tuple: tuple (usually of ints) representing the required version + error_if_malformed: whether we should exit if module version string is malformed + + Returns: + requirement_is_met: bool + ''' + try: + version_strs = module.__version__.split('.') + # Cast module version fields to match the types of the required version + module_version = tuple( + type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple) + ) + requirement_is_met = module_version >= req_version_tuple + + except Exception as e: + message = ( + f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared" + f" with tuple {str(req_version_tuple)}" + ) + if error_if_malformed: + raise RuntimeError(message) from e + else: + warnings.warn(message + ', but continuing assuming that requirement is met') + requirement_is_met = True + + return requirement_is_met + + +def _cpu_tag(obj): + if obj.device.type == 'cpu': + return 'cpu' + + +def _cuda_tag(obj): + if obj.device.type == 'cuda': + return 'cuda:' + str(obj.device.index) + +def _hpu_tag(obj): + if obj.device.type == 'hpu': + return 'hpu:' + str(obj.device.index) + +def _mps_tag(obj): + if obj.device.type == 'mps': + return 'mps' + + +def _meta_tag(obj): + if obj.device.type == 'meta': + return 'meta' + + +def _privateuse1_tag(obj): + backend_name = torch._C._get_privateuse1_backend_name() + if obj.device.type == backend_name: + if obj.device.index is None: + return backend_name + else: + return backend_name + ':' + str(obj.device.index) + + +def _cpu_deserialize(obj, location): + if location == 'cpu': + return obj + + +def validate_cuda_device(location): + device = torch.cuda._utils._get_device_index(location, True) + + if not torch.cuda.is_available(): + raise RuntimeError('Attempting to deserialize object on a CUDA ' + 'device but torch.cuda.is_available() is False. 
' + 'If you are running on a CPU-only machine, ' + 'please use torch.load with map_location=torch.device(\'cpu\') ' + 'to map your storages to the CPU.') + device_count = torch.cuda.device_count() + if device >= device_count: + raise RuntimeError('Attempting to deserialize object on CUDA device ' + f'{device} but torch.cuda.device_count() is {device_count}. Please use ' + 'torch.load with map_location to map your storages ' + 'to an existing device.') + return device + + +def _cuda_deserialize(obj, location): + if location.startswith('cuda'): + device = validate_cuda_device(location) + if getattr(obj, "_torch_load_uninitialized", False): + with torch.cuda.device(device): + return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) + else: + return obj.cuda(device) + + +def validate_hpu_device(location): + hpu = getattr(torch, "hpu", None) + assert hpu is not None, "HPU device module is not loaded" + device = hpu._utils._get_device_index(location, optional=True) + + if not hpu.is_available(): + raise RuntimeError('Attempting to deserialize object on a HPU ' + 'device but torch.hpu.is_available() is False. ' + 'If you are running on a CPU-only machine, ' + 'please use torch.load with map_location=torch.device(\'cpu\') ' + 'to map your storages to the CPU.') + device_count = hpu.device_count() + if device >= device_count: + raise RuntimeError('Attempting to deserialize object on HPU device ' + f'{device} but torch.hpu.device_count() is {device_count}. 
Please use ' + 'torch.load with map_location to map your storages ' + 'to an existing device.') + return device + + +def _hpu_deserialize(obj, location): + if location.startswith('hpu'): + hpu = getattr(torch, "hpu", None) + assert hpu is not None, "HPU device module is not loaded" + device = validate_hpu_device(location) + if getattr(obj, "_torch_load_uninitialized", False): + with hpu.device(device): + return torch.UntypedStorage(obj.nbytes(), device=torch.device(location)) + else: + return obj.hpu(device) + + +def _mps_deserialize(obj, location): + if location.startswith('mps'): + return obj.mps() + + +def _meta_deserialize(obj, location): + if location == 'meta': + return torch.UntypedStorage(obj.nbytes(), device='meta') + + +def _validate_privateuse1_device(location, backend_name): + ''' + Check whether the device index of privateuse1 is valid + + Register a device_module of privateuse1 by torch._register_device_module. + Implement the following methods in device_module like cuda: + device_module._utils._get_device_index(location, True), + device_module.device_count(). + + Args: + location: string of device + backend_name: the name of privateuse1, which can be renamed + + Returns: + device_index: int + ''' + if not hasattr(torch, backend_name): + raise RuntimeError(f'The {backend_name.upper()} device module is not registered. 
' + 'If you are running on a CPU-only machine, ' + 'please use torch.load with map_location=torch.device(\'cpu\') ' + 'to map your storages to the CPU.') + device_module = getattr(torch, backend_name) + if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'): + device_index = device_module._utils._get_device_index(location, True) + else: + device = torch.device(location) + device_index = device.index if device.index else 0 + if hasattr(device_module, 'is_available') and not device_module.is_available(): + raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} ' + f'device but torch.{backend_name}.is_available() is False. ' + 'If you are running on a CPU-only machine, ' + 'please use torch.load with map_location=torch.device(\'cpu\') ' + 'to map your storages to the CPU.') + if hasattr(device_module, 'device_count'): + device_count = device_module.device_count() + if device_index >= device_count: + raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device ' + f'{device_index} but torch.{backend_name}.device_count() is {device_count}. ' + 'Please use torch.load with map_location to map your storages ' + 'to an existing device.') + return device_index + + +def _privateuse1_deserialize(obj, location): + backend_name = torch._C._get_privateuse1_backend_name() + if location.startswith(backend_name): + if not hasattr(obj, backend_name): + raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device ' + f'but torch.storage._StorageBase.{backend_name}() or ' + f'torch.storage.TypedStorage.{backend_name}() is not generated. 
' + 'Please use torch.utils.generate_methods_for_privateuse1_backend ' + f'to generate storage.{backend_name}() method first.') + device_index = _validate_privateuse1_device(location, backend_name) + return getattr(obj, backend_name)(device_index) + + +register_package(10, _cpu_tag, _cpu_deserialize) +register_package(20, _cuda_tag, _cuda_deserialize) +register_package(21, _mps_tag, _mps_deserialize) +register_package(22, _meta_tag, _meta_deserialize) +register_package(23, _privateuse1_tag, _privateuse1_deserialize) +register_package(24, _hpu_tag, _hpu_deserialize) + + +def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]): + for _, tagger, _ in _package_registry: + location = tagger(storage) + if location: + return location + raise RuntimeError("don't know how to determine data location of " + + torch.typename(storage)) + + +def default_restore_location(storage, location): + for _, _, fn in _package_registry: + result = fn(storage, location) + if result is not None: + return result + raise RuntimeError("don't know how to restore data location of " + + torch.typename(storage) + " (tagged with " + + location + ")") + + +def normalize_storage_type(storage_type): + return getattr(torch, storage_type.__name__) + + +def storage_to_tensor_type(storage): + storage_type = type(storage) + module = _import_dotted_name(storage_type.__module__) + return getattr(module, storage_type.__name__.replace('Storage', 'Tensor')) + + +def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]: + return isinstance(name_or_buffer, (str, os.PathLike)) + + +class _opener: + def __init__(self, file_like): + self.file_like = file_like + + def __enter__(self): + return self.file_like + + def __exit__(self, *args): + pass + + +class _open_file(_opener): + def __init__(self, name, mode): + super().__init__(open(name, mode)) + + def __exit__(self, *args): + self.file_like.close() + + +class _open_buffer_reader(_opener): + def __init__(self, buffer): + 
super().__init__(buffer) + _check_seekable(buffer) + + +class _open_buffer_writer(_opener): + def __exit__(self, *args): + self.file_like.flush() + + +def _open_file_like(name_or_buffer, mode): + if _is_path(name_or_buffer): + return _open_file(name_or_buffer, mode) + else: + if 'w' in mode: + return _open_buffer_writer(name_or_buffer) + elif 'r' in mode: + return _open_buffer_reader(name_or_buffer) + else: + raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") + + +class _open_zipfile_reader(_opener): + def __init__(self, name_or_buffer) -> None: + super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) + + +class _open_zipfile_writer_file(_opener): + def __init__(self, name) -> None: + self.file_stream = None + self.name = str(name) + try: + self.name.encode('ascii') + except UnicodeEncodeError: + # PyTorchFileWriter only supports ascii filename. + # For filenames with non-ascii characters, we rely on Python + # for writing out the file. + self.file_stream = io.FileIO(self.name, mode='w') + super().__init__(torch._C.PyTorchFileWriter(self.file_stream)) + else: + super().__init__(torch._C.PyTorchFileWriter(self.name)) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + if self.file_stream is not None: + self.file_stream.close() + + +class _open_zipfile_writer_buffer(_opener): + def __init__(self, buffer) -> None: + if not callable(getattr(buffer, "write", None)): + msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'" + if not hasattr(buffer, "write"): + raise AttributeError(msg) + raise TypeError(msg) + self.buffer = buffer + super().__init__(torch._C.PyTorchFileWriter(buffer)) + + def __exit__(self, *args) -> None: + self.file_like.write_end_of_file() + self.buffer.flush() + + +def _open_zipfile_writer(name_or_buffer): + container: Type[_opener] + if _is_path(name_or_buffer): + container = _open_zipfile_writer_file + else: + container = _open_zipfile_writer_buffer + return 
container(name_or_buffer) + + +def _is_compressed_file(f) -> bool: + compress_modules = ['gzip'] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + + +def _should_read_directly(f): + """ + Checks if f is a file that should be read directly. It should be read + directly if it is backed by a real file (has a fileno) and is not a + a compressed file (e.g. gzip) + """ + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + + +def _check_seekable(f) -> bool: + + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only torch.load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + + +def _check_dill_version(pickle_module) -> None: + '''Checks if using dill as the pickle module, and if so, checks if it is the correct version. + If dill version is lower than 0.3.1, a ValueError is raised. + + Args: + pickle_module: module used for pickling metadata and objects + + ''' + if pickle_module is not None and pickle_module.__name__ == 'dill': + required_dill_version = (0, 3, 1) + if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False): + raise ValueError(( + "'torch' supports dill >= {}, but you have dill {}." 
+ " Please upgrade dill or switch to 'pickle'" + ).format( + '.'.join([str(num) for num in required_dill_version]), + pickle_module.__version__ + )) + + +def _check_save_filelike(f): + if not _is_path(f) and not hasattr(f, 'write'): + raise AttributeError( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute") + + +def save( + obj: object, + f: FILE_LIKE, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, + _use_new_zipfile_serialization: bool = True, + _disable_byteorder_record: bool = False +) -> None: + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("makes cwd dirty") + >>> # Save to file + >>> x = torch.tensor([0, 1, 2, 3, 4]) + >>> torch.save(x, 'tensor.pt') + >>> # Save to io.BytesIO buffer + >>> buffer = io.BytesIO() + >>> torch.save(x, buffer) + """ + torch._C._log_api_usage_once("torch.save") + _check_dill_version(pickle_module) + _check_save_filelike(f) + + if _use_new_zipfile_serialization: + with _open_zipfile_writer(f) as opened_zipfile: + _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record) + return + else: + with _open_file_like(f, 'wb') as opened_file: + _legacy_save(obj, opened_file, pickle_module, pickle_protocol) + + +def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: + import torch.nn as nn + serialized_container_types = {} + serialized_storages = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. 
+ # TODO: This feature could be added in the future + storage_dtypes: Dict[int, torch.dtype] = {} + + def persistent_id(obj: Any) -> Optional[Tuple]: + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, type) and issubclass(obj, nn.Module): + if obj in serialized_container_types: + return None + serialized_container_types[obj] = True + source_file = source = None + try: + source_lines, _, source_file = get_source_lines_and_file(obj) + source = ''.join(source_lines) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn("Couldn't retrieve source code for container of " + "type " + obj.__name__ + ". It won't be checked " + "for correctness upon loading.") + return ('module', obj, source_file, source) + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + storage: torch.UntypedStorage + + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + dtype = obj.dtype + storage_numel = obj._size() + + elif isinstance(obj, torch.UntypedStorage): + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + dtype = torch.uint8 + storage_numel = storage.nbytes() + else: + raise TypeError(f'type not recognized: {type(obj)}') + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + 'Cannot save multiple tensors or storages that ' + 'view the same data as different types') + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + view_metadata: Optional[Tuple[str, int, int]] + + # Offset is always 0, but we keep it for backwards compatibility + # with the old serialization format (which supported storage views) + offset = 0 + storage_key = str(storage._cdata) + location = location_tag(storage) + + # TODO: There's an issue here with FC. It might be impossible to + # solve, but it's worth noting. Imagine we save a list `[storage, + # tensor]`, where `tensor.storage()` is the same as `storage`, and + # `tensor.element_size() > 1`. Let's say that `tensor.dtype == + # torch.float`. The storage will be serialized with element size + # of 1, since we're choosing to serialize the first occurance of + # a duplicate storage. Since this legacy serialization format saves + # the numel of the storage, rather than nbytes directly, we'll be + # effectively saving nbytes in this case. We'll be able to load it + # and the tensor back up with no problems in _this_ and future + # versions of pytorch, but in older versions, here's the problem: + # the storage will be loaded up as a UntypedStorage, and then the + # FloatTensor will loaded and the UntypedStorage will be assigned to + # it. Since the storage dtype does not match the tensor dtype, this + # will cause an error. If we reverse the list, like `[tensor, + # storage]`, then we will save the `tensor.storage()` as a faked + # `FloatStorage`, and the saved size will be the correct + # dtype-specific numel count that old versions expect. `tensor` + # will be able to load up properly in old versions, pointing to + # a FloatStorage. 
However, `storage` is still being translated to + # a UntypedStorage, and it will try to resolve to the same + # FloatStorage that `tensor` contains. This will also cause an + # error. It doesn't seem like there's any way around this. + # Probably, we just cannot maintain FC for the legacy format if the + # saved list contains both a tensor and a storage that point to the + # same data. We should still be able to maintain FC for lists of + # just tensors, as long as all views share the same dtype as the + # tensor they are viewing. + + if storage_key not in serialized_storages: + serialized_storages[storage_key] = (storage, dtype) + is_view = storage._cdata != storage._cdata + if is_view: + view_metadata = (str(storage._cdata), offset, storage.nbytes()) + else: + view_metadata = None + + res = ('storage', + storage_type, + storage_key, + location, + storage_numel, + view_metadata) + return res + return None + + sys_info = dict( + protocol_version=PROTOCOL_VERSION, + little_endian=sys.byteorder == 'little', + type_sizes=dict( + short=SHORT_SIZE, + int=INT_SIZE, + long=LONG_SIZE, + ), + ) + + pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) + pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) + pickle_module.dump(sys_info, f, protocol=pickle_protocol) + pickler = pickle_module.Pickler(f, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + pickler.dump(obj) + + serialized_storage_keys = sorted(serialized_storages.keys()) + pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol) + f.flush() + for key in serialized_storage_keys: + storage, dtype = serialized_storages[key] + storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype)) + + +def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): + serialized_storages = {} + id_map: Dict[int, str] = {} + + # Since loading storages that view the same data with different dtypes is + # not supported, we need 
to keep track of the dtype associated with each + # storage data_ptr and throw an error if the dtype is ever different. + # TODO: This feature could be added in the future + storage_dtypes: Dict[int, torch.dtype] = {} + + def persistent_id(obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): + + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. 
If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in storage_dtypes: + if storage_dtype != storage_dtypes[storage.data_ptr()]: + raise RuntimeError( + 'Cannot save multiple tensors or storages that ' + 'view the same data as different types') + else: + storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) + location = location_tag(storage) + serialized_storages[storage_key] = storage + + return ('storage', + storage_type, + storage_key, + location, + storage_numel) + + return None + + # Write the pickle data for `obj` + data_buf = io.BytesIO() + pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + pickler.dump(obj) + data_value = data_buf.getvalue() + zip_file.write_record('data.pkl', data_value, len(data_value)) + + # Write byte order marker + if not _disable_byteorder_record: + if sys.byteorder not in ['little', 'big']: + raise ValueError('Unknown endianness type: ' + sys.byteorder) + + zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder)) + + # Write each tensor to a file named tensor/the_tensor_key in the zip archive + for key in sorted(serialized_storages.keys()): + name = f'data/{key}' + storage = serialized_storages[key] + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != 'cpu': + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + num_bytes = storage.nbytes() + zip_file.write_record(name, storage, num_bytes) + + +def load( + f: FILE_LIKE, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = False, + mmap: Optional[bool] = None, + **pickle_load_args: Any +) -> Any: + # Reference: 
https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `pickle`s path from + # the build environment (e.g. `>> # xdoctest: +SKIP("undefined filepaths") + >>> torch.load('tensors.pt', weights_only=True) + # Load all tensors onto the CPU + >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True) + # Load all tensors onto the CPU, using a function + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True) + # Load all tensors onto GPU 1 + >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True) + # Map tensors from GPU 1 to GPU 0 + >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True) + # Load tensor from io.BytesIO object + # Loading from a buffer setting weights_only=False, warning this can be unsafe + >>> with open('tensor.pt', 'rb') as f: + ... buffer = io.BytesIO(f.read()) + >>> torch.load(buffer, weights_only=False) + # Load a module with 'ascii' encoding for unpickling + # Loading from a module setting weights_only=False, warning this can be unsafe + >>> torch.load('module.pt', encoding='ascii', weights_only=False) + """ + torch._C._log_api_usage_once("torch.load") + UNSAFE_MESSAGE = ( + "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" + " will likely succeed, but it can result in arbitrary code execution." + "Do it only if you get the file from a trusted source. 
WeightsUnpickler error: " + ) + # Add ability to force safe only weight loads via environment variable + if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']: + weights_only = True + + if weights_only: + if pickle_module is not None: + raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") + else: + if pickle_module is None: + pickle_module = pickle + + # make flipping default BC-compatible + if mmap is None: + mmap = False + + _check_dill_version(pickle_module) + + if 'encoding' not in pickle_load_args.keys(): + pickle_load_args['encoding'] = 'utf-8' + + with _open_file_like(f, 'rb') as opened_file: + if _is_zipfile(opened_file): + # The zipfile reader is going to advance the current file position. + # If we want to actually tail call to torch.jit.load, we need to + # reset back to the original position. + orig_position = opened_file.tell() + overall_storage = None + with _open_zipfile_reader(opened_file) as opened_zipfile: + if _is_torchscript_zip(opened_zipfile): + warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive" + " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" + " silence this warning)", UserWarning) + opened_file.seek(orig_position) + return torch.jit.load(opened_file, map_location=map_location) + if mmap: + if not _is_path(f): + raise ValueError("f must be a file path in order to use the mmap argument") + size = os.path.getsize(f) + overall_storage = torch.UntypedStorage.from_file(os.fspath(f), False, size) + if weights_only: + try: + return _load(opened_zipfile, + map_location, + _weights_only_unpickler, + overall_storage=overall_storage, + **pickle_load_args) + except RuntimeError as e: + raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None + return _load(opened_zipfile, + map_location, + pickle_module, + overall_storage=overall_storage, + **pickle_load_args) + if mmap: + f_name = "" if not isinstance(f, str) else 
f"{f}, " + raise RuntimeError("mmap can only be used with files saved with " + f"`torch.save({f_name}_use_new_zipfile_serialization=True), " + "please torch.save your checkpoint with this option in order to use mmap.") + if weights_only: + try: + return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args) + except RuntimeError as e: + raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None + return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) + + +# Register pickling support for layout instances such as +# torch.sparse_coo, etc +def _get_layout(name): + """Get layout extension object from its string representation. + """ + cache = _get_layout.cache # type: ignore[attr-defined] + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] + +# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087 +_get_layout.cache = {} # type: ignore[attr-defined] +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) + + +def _legacy_load(f, map_location, pickle_module, **pickle_load_args): + deserialized_objects: Dict[int, Any] = {} + + restore_location = _get_restore_location(map_location) + + class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] + + def find_class(self, mod_name, name): + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: + pass + return super().find_class(mod_name, name) + + def _check_container_source(container_type, source_file, original_source): + try: + current_source = ''.join(get_source_lines_and_file(container_type)[0]) + except Exception: # saving the source is optional, so we can ignore any errors + warnings.warn("Couldn't retrieve source code for container of " + "type " + container_type.__name__ + ". 
It won't be checked " + "for correctness upon loading.") + return + if original_source != current_source: + if container_type.dump_patches: + file_name = container_type.__name__ + '.patch' + diff = difflib.unified_diff(current_source.split('\n'), + original_source.split('\n'), + source_file, + source_file, lineterm="") + lines = '\n'.join(diff) + try: + with open(file_name, 'a+') as f: + file_size = f.seek(0, 2) + f.seek(0) + if file_size == 0: + f.write(lines) + elif file_size != len(lines) or f.read() != lines: + raise OSError + msg = ("Saved a reverse patch to " + file_name + ". " + "Run `patch -p0 < " + file_name + "` to revert your " + "changes.") + except OSError: + msg = ("Tried to save a patch, but couldn't create a " + "writable file " + file_name + ". Make sure it " + "doesn't exist and your working directory is " + "writable.") + else: + msg = ("you can retrieve the original source code by " + "accessing the object's source attribute or set " + "`torch.nn.Module.dump_patches = True` and use the " + "patch tool to revert the changes.") + msg = f"source code of class '{torch.typename(container_type)}' has changed. 
{msg}" + warnings.warn(msg, SourceChangeWarning) + + def legacy_load(f): + deserialized_objects: Dict[int, Any] = {} + + def persistent_load(saved_id): + if isinstance(saved_id, tuple): + # Ignore containers that don't have any sources saved + if all(saved_id[1:]): + _check_container_source(*saved_id) + return saved_id[0] + return deserialized_objects[int(saved_id)] + + with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ + mkdtemp() as tmpdir: + + tar.extract('storages', path=tmpdir) + with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f: + num_storages = pickle_module.load(f, **pickle_load_args) + for i in range(num_storages): + args = pickle_module.load(f, **pickle_load_args) + key, location, storage_type = args + dtype = storage_type._dtype + obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype)) + obj = restore_location(obj, location) + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[key] = torch.storage.TypedStorage( + wrap_storage=obj, + dtype=dtype, + _internal=True) + + storage_views = pickle_module.load(f, **pickle_load_args) + for target_cdata, root_cdata, offset, numel in storage_views: + root = deserialized_objects[root_cdata] + element_size = torch._utils._element_size(root.dtype) + offset_bytes = offset * element_size + # TODO: Once we decide to break serialization FC, we can + # stop wrapping with TypedStorage + deserialized_objects[target_cdata] = torch.storage.TypedStorage( + wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size], + dtype=root.dtype, + _internal=True) + + tar.extract('tensors', path=tmpdir) + with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f: + num_tensors = pickle_module.load(f, **pickle_load_args) + for _ in range(num_tensors): + args = pickle_module.load(f, **pickle_load_args) + key, storage_id, original_tensor_type = args + storage = 
# NOTE(review): this span began with the truncated tail of an earlier
# definition. It is incomplete here, so it is preserved verbatim as a comment
# rather than guessed at:
#     deserialized_objects[storage_id]
#     ndim, = struct.unpack('


def _maybe_decode_ascii(bytes_str) -> str:
    """Return ``bytes_str`` as ``str``, ASCII-decoding it when it is ``bytes``.

    When using encoding='bytes' in Py3, some **internal** keys stored as
    strings in Py2 are loaded as bytes. This function decodes them with
    ascii encoding, one that Py3 uses by default.

    NOTE: This should only be used on internal keys (e.g., `typename` and
    `location` in `persistent_load` below)!

    Args:
        bytes_str: a ``bytes`` or ``str`` value.

    Returns:
        The decoded string for ``bytes`` input; any other input is returned
        unchanged.
    """
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str


def _get_restore_location(map_location):
    """Build the ``restore_location(storage, location)`` callable for ``map_location``.

    Args:
        map_location: one of
            * ``None`` -- ``default_restore_location`` is returned as is;
            * a ``dict`` -- the saved location is looked up in it (falling
              back to the saved location) before restoring;
            * a ``str``/``bytes`` -- every storage is restored to that location;
            * a ``torch.device`` -- every storage is restored to ``str(device)``;
            * anything else -- treated as a callable
              ``map_location(storage, location)``; when it returns ``None``
              the default restore is used instead.

    Returns:
        A callable taking ``(storage, location)``.
    """
    if map_location is None:
        return default_restore_location

    if isinstance(map_location, dict):
        def restore_location(storage, location):
            remapped = map_location.get(location, location)
            return default_restore_location(storage, remapped)
    elif isinstance(map_location, (str, bytes)):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            restored = map_location(storage, location)
            if restored is None:
                # The user callable declined; fall back to the default.
                restored = default_restore_location(storage, location)
            return restored

    return restore_location


class StorageType:
    """Stand-in for a pickled storage class name that only carries its dtype."""

    def __init__(self, name):
        # Raises whatever `_get_dtype_from_pickle_storage_type` raises for an
        # unknown name; `_load`'s `find_class` relies on KeyError for that.
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    @property
    def dtype(self):
        return self._dtype

    def __str__(self):
        return f'StorageType(dtype={self.dtype})'


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
    """Unpickle ``pickle_file`` from ``zip_file``, loading storages on demand.

    Args:
        zip_file: archive reader; this function uses its ``has_record``,
            ``get_record``, ``get_record_offset``, ``get_storage_from_record``
            and ``serialization_id`` methods.
        map_location: forwarded to ``_get_restore_location``.
        pickle_module: module providing the ``Unpickler`` class to subclass.
        pickle_file: name of the record holding the pickle stream.
        overall_storage: when not ``None``, storages are sliced out of it at
            the record's offset instead of being read from ``zip_file``.
        **pickle_load_args: forwarded to the ``Unpickler`` constructor.

    Returns:
        Whatever the unpickler's ``load()`` returns.

    Raises:
        ValueError: for an unknown ``byteorder`` record or an invalid default
            load endianness.

    NOTE: unpickling can execute code chosen by the pickle stream; only use
    this on data you trust.
    """
    restore_location = _get_restore_location(map_location)

    # Storages already loaded during this call, keyed by their record key.
    loaded_storages = {}

    # check if byteswapping is needed
    byteordername = 'byteorder'
    byteorderdata = None
    if zip_file.has_record(byteordername):
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b'little', b'big']:
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
    else:
        # No byteorder mark in the archive: fall back to the configured
        # default, read once so every check below sees the same value.
        default_endianness = get_default_load_endianness()
        if default_endianness is None or default_endianness == LoadEndianness.LITTLE:
            byteorderdata = b'little'
            if default_endianness is None and sys.byteorder == 'big':
                # Default behaviour was changed
                # See https://github.com/pytorch/pytorch/issues/101688
                warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                              "on big endian machines was changed from 'native' to 'little' endian, "
                              "to avoid this behavior please use "
                              "torch.serialization.set_default_load_endianness to set "
                              "the desired default load endianness",
                              UserWarning)
        elif default_endianness == LoadEndianness.BIG:
            byteorderdata = b'big'
        elif default_endianness != LoadEndianness.NATIVE:
            raise ValueError('Invalid load endianness type')
        # LoadEndianness.NATIVE leaves byteorderdata as None: never swap.

    # Decided once here instead of being re-derived for every tensor.
    needs_byteswap = byteorderdata is not None and byteorderdata.decode() != sys.byteorder

    def load_tensor(dtype, numel, key, location):
        name = f'data/{key}'
        if torch._guards.detect_fake_mode(None) is not None:
            # NOTE(review): `persistent_load` below already passes a byte
            # count as `numel`, so the element size is applied a second time
            # here -- confirm this is intended before changing it.
            nbytes = numel * torch._utils._element_size(dtype)
            storage = torch.UntypedStorage(nbytes, device='meta')
        elif overall_storage is not None:
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset:storage_offset + numel]
        else:
            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
        # swap here if byteswapping is needed
        if needs_byteswap:
            storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage
        typed_storage = torch.storage.TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype,
            _internal=True)

        # Only storages with a non-null data pointer are cached for reuse.
        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = saved_id[1:]
        dtype = torch.uint8 if storage_type is torch.UntypedStorage else storage_type.dtype

        if key in loaded_storages:
            return loaded_storages[key]

        nbytes = numel * torch._utils._element_size(dtype)
        return load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    # Not a known storage type name: resolve it normally.
                    pass
            return super().find_class(load_module_mapping.get(mod_name, mod_name), name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result


def _is_torchscript_zip(zip_file):
    """Return whether ``zip_file`` lists a ``'constants.pkl'`` record."""
    return 'constants.pkl' in zip_file.get_all_records()