Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because the change set contains too many files.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py +187 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py +118 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py +9 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py +106 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py +189 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py +144 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py +160 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py +81 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py +279 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py +113 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py +546 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py +606 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py +1880 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py +83 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py +211 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py +489 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py +350 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc
ADDED
|
Binary file (16 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc
ADDED
|
Binary file (50 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import weakref
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils._pytree as pytree
|
| 5 |
+
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
|
| 6 |
+
from torch._ops import OpOverload
|
| 7 |
+
from torch.library import Library
|
| 8 |
+
from torchgen.model import (
|
| 9 |
+
BaseTy,
|
| 10 |
+
BaseType,
|
| 11 |
+
FunctionSchema,
|
| 12 |
+
OperatorName,
|
| 13 |
+
OptionalType,
|
| 14 |
+
SchemaKind,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from .autograd import autograd_not_implemented
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Register a functional (non-mutating) variant of ``mutable_op``.

    The functional variant is defined on ``lib`` under ``new_op_name`` and is
    linked to the mutable operator for the purposes of functionalization: a
    ``Functionalize`` kernel is installed on ``mutable_op`` that calls the
    functional variant under the hood.

    Arguments:
        lib (Library): Should be a torch.library.Library object that has
            the same namespace as ``mutable_op``'s namespace.
            lib will be used to register the new functional op as well
            as a functionalization kernel for the ``mutable_op``
            If you don't have a library handy, use
            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
        new_op_name (str): The name of the functional operator (without the
            namespace). If no namespace, the new functional variant will be
            accessible under ``torch.ops.{lib.ns}.new_op_name``.
        mutable_op (OpOverload): The mutable custom operator. Note
            that you may need to add a `.default` to it, like
            `torch.ops.aten.abs_.default`.
    """
    validate(mutable_op)

    # Define the functional op and give it a composite implementation that
    # clones the would-be-mutated inputs.
    lib.define(functional_schema(new_op_name, mutable_op))
    lib.impl(new_op_name, construct_functional_impl(mutable_op), 'CompositeExplicitAutograd')

    functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default

    # There's no easy way for us to generate an autograd kernel, so we use
    # autograd_not_implemented. This also prevents the user from registering
    # an autograd formula themselves — fine as long as they don't call the
    # functional op directly in their program, but may need revisiting.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    # Teach functionalization to rewrite the mutable op into the functional
    # one. weakref.proxy avoids the kernel closure keeping mutable_op alive.
    functionalize_kernel = construct_functionalization_kernel(
        weakref.proxy(mutable_op), functional_op)
    lib.impl(mutable_op, functionalize_kernel, 'Functionalize')
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def construct_functional_impl(mutable_op):
    """Build the CompositeExplicitAutograd impl for the functional variant.

    The returned impl clones every input that ``mutable_op`` would mutate,
    runs ``mutable_op`` on the clones, and appends the (now mutated) clones
    to the outputs, so callers observe no in-place changes.
    """
    def functional_impl(*args):
        call_args = []
        mutated_clones = []
        for writes, arg in zip(mutable_args(mutable_op), args):
            if not writes:
                call_args.append(arg)
                continue
            # Optional tensor inputs may be absent (None); only clone real ones.
            clone = arg.clone() if arg is not None else None
            call_args.append(clone)
            mutated_clones.append(clone)
        out = mutable_op(*call_args)
        # Normalize: real outputs first, then the mutated clones.
        if out is None:
            return tuple(mutated_clones)
        if isinstance(out, tuple):
            return (*out, *mutated_clones)
        return (out, *mutated_clones)
    return functional_impl
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def construct_functionalization_kernel(mutable_op, functional_op):
    """Build the ``Functionalize`` kernel that redirects ``mutable_op``.

    The kernel unwraps FunctionalTensorWrapper inputs, runs the pure
    ``functional_op`` on the unwrapped tensors, then writes the extra
    "mutated clone" outputs back into the wrappers so the mutation becomes
    visible to functionalization.
    """
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            # BUGFIX: this message was missing its f-prefix, so it printed the
            # literal text "{mutable_op}" instead of the op name.
            raise RuntimeError(f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        # Sync and unwrap every functional tensor input.
        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped_args.append(torch._from_functional_tensor(arg))
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        # The functional op returns the real outputs first, followed by one
        # "new value" per mutated input (see construct_functional_impl).
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # An absent optional input has nothing to write back.
            if arg is None and new_value is None:
                continue
            # BUGFIX: the original `continue`d whenever BOTH sides were
            # present, so mutations were never written back into the wrappers
            # (and the only fall-through case -- exactly one side None --
            # would have crashed below). Propagate whenever there is a value
            # to write back.
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        # Match the mutable op's return convention: scalar-like single
        # return, None for no returns, tuple otherwise.
        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def validate(mutable_op: OpOverload):
    """Reject operators that register_functional_op cannot handle."""
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them have their own conventions:
    # - inplace (first input modified in-place and returned as only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity.
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")

    # construct_functionalization_kernel assumes no returned value aliases
    # or is mutated.
    if any(ret.annotation is not None for ret in schema.returns):
        raise NotImplementedError(
            "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
            "Please file an issue (and as a workaround, modify your operator to "
            "not return the mutated value or aliases)")

    # construct_functionalization_kernel also assumes every tensor-like
    # argument is a plain Tensor or Tensor? (no tensor lists).
    plain_tensor_types = (BaseType(BaseTy.Tensor), OptionalType(BaseType(BaseTy.Tensor)))
    for arg in schema.arguments.flat_all:
        if arg.type.is_tensor_like() and arg.type not in plain_tensor_types:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input."
                "Please file an issue.")
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def functional_schema(new_op_name, op: OpOverload):
|
| 180 |
+
schema = FunctionSchema.parse(str(op._schema))
|
| 181 |
+
schema = schema.signature().with_name(OperatorName.parse(new_op_name))
|
| 182 |
+
return str(schema)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def mutable_args(op: OpOverload):
    """Per-argument booleans: True iff the op writes to that argument."""
    return tuple(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in op._schema.arguments
    )
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch._prims_common as utils
|
| 3 |
+
|
| 4 |
+
# Utilities should come BEFORE this import
|
| 5 |
+
from torch._decomp import register_decomposition
|
| 6 |
+
|
| 7 |
+
from torch._prims_common import TensorLikeType
|
| 8 |
+
from torch._prims_common.wrappers import out_wrapper
|
| 9 |
+
from torch._refs import _broadcast_shapes
|
| 10 |
+
|
| 11 |
+
# Data conversion references.
|
| 12 |
+
#
|
| 13 |
+
# Note: this module breaks the usual _refs to torch naming scheme where
|
| 14 |
+
# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
|
| 15 |
+
# part of _refs/__init__.py to avoid name clashes with Python builtin types
|
| 16 |
+
# (like int).
|
| 17 |
+
|
| 18 |
+
# Public names: Tensor.<dtype>()-style conversion helpers plus the complex
# constructors. Order mirrors the definitions below.
__all__ = [
    # dtypes
    "bfloat16", "bool", "byte", "cdouble", "cfloat", "chalf", "char",
    "double", "float", "half", "int", "long", "short",
    # misc
    "complex", "polar",
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _make_conversion_method(name: str, dtype: torch.dtype):
|
| 40 |
+
def fn(
|
| 41 |
+
self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
|
| 42 |
+
) -> TensorLikeType:
|
| 43 |
+
return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload]
|
| 44 |
+
|
| 45 |
+
fn.__name__ = name
|
| 46 |
+
return fn
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# One conversion reference per dtype, mirroring the Tensor method names.
# NOTE: several of these (bool/int/float) intentionally shadow Python
# builtins; this module is looked up by attribute, not star-imported.
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
bool = _make_conversion_method("bool", torch.bool)
byte = _make_conversion_method("byte", torch.uint8)
cdouble = _make_conversion_method("cdouble", torch.cdouble)
cfloat = _make_conversion_method("cfloat", torch.cfloat)
chalf = _make_conversion_method("chalf", torch.complex32)
char = _make_conversion_method("char", torch.int8)
double = _make_conversion_method("double", torch.double)
float = _make_conversion_method("float", torch.float)
half = _make_conversion_method("half", torch.half)
int = _make_conversion_method("int", torch.int)
long = _make_conversion_method("long", torch.long)
short = _make_conversion_method("short", torch.short)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Reference for torch.complex: build a complex tensor from its parts."""
    _supported = (torch.float32, torch.float64, torch.float16)
    torch._check(
        real.dtype in _supported and imag.dtype in _supported,
        lambda: (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        ),
    )
    out_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    # Allocate the complex result on real's device/layout and fill the parts.
    out = real.new_empty(
        out_shape,
        dtype=out_dtype,
        layout=real.layout,
        device=real.device,
        # pin_memory=real.is_pinned(),  # NYI
    )
    out.real = real
    out.imag = imag
    return out
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Reference for torch.polar: complex tensor from magnitude and phase."""
    # torch.complex performs the dtype validation and allocates the complex
    # result; overwrite its parts with the polar->cartesian conversion.
    out = torch.complex(abs, angle)
    out.real = abs * torch.cos(angle)
    out.imag = abs * torch.sin(angle)
    return out
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
__all__: List[str] = []  # torch._refs.nn exposes nothing directly; see submodules
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (351 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (49 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .activation import MultiheadAttention
|
| 2 |
+
from .rnn import LSTM
|
| 3 |
+
from .rnn import LSTMCell
|
| 4 |
+
|
| 5 |
+
# Public quantizable modules re-exported from this package.
__all__ = ['LSTM', 'LSTMCell', 'MultiheadAttention']
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.ao.nn.intrinsic as nni
|
| 3 |
+
|
| 4 |
+
# Only the concrete quantized batchnorm modules are public.
__all__ = ["BatchNorm2d", "BatchNorm3d"]
|
| 8 |
+
|
| 9 |
+
class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
|
| 10 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
|
| 11 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
| 12 |
+
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
|
| 13 |
+
self.register_buffer('scale', torch.tensor(1.0, **factory_kwargs))
|
| 14 |
+
self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))
|
| 15 |
+
|
| 16 |
+
@staticmethod
|
| 17 |
+
def from_float(cls, mod):
|
| 18 |
+
activation_post_process = mod.activation_post_process
|
| 19 |
+
if type(mod) == cls._NNI_BN_RELU_MODULE:
|
| 20 |
+
mod = mod[0]
|
| 21 |
+
scale, zero_point = activation_post_process.calculate_qparams()
|
| 22 |
+
new_mod = cls(mod.num_features, mod.eps)
|
| 23 |
+
new_mod.weight = mod.weight
|
| 24 |
+
new_mod.bias = mod.bias
|
| 25 |
+
new_mod.running_mean = mod.running_mean
|
| 26 |
+
new_mod.running_var = mod.running_var
|
| 27 |
+
new_mod.scale = scale
|
| 28 |
+
new_mod.zero_point = zero_point
|
| 29 |
+
return new_mod
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
def from_reference(cls, bn, output_scale, output_zero_point):
|
| 33 |
+
qbn = cls(
|
| 34 |
+
bn.num_features,
|
| 35 |
+
bn.eps,
|
| 36 |
+
bn.momentum,
|
| 37 |
+
device=bn.weight.device,
|
| 38 |
+
dtype=bn.weight.dtype
|
| 39 |
+
)
|
| 40 |
+
qbn.weight = bn.weight
|
| 41 |
+
qbn.bias = bn.bias
|
| 42 |
+
qbn.running_mean = bn.running_mean
|
| 43 |
+
qbn.running_var = bn.running_var
|
| 44 |
+
qbn.scale = output_scale
|
| 45 |
+
qbn.zero_point = output_zero_point
|
| 46 |
+
return qbn
|
| 47 |
+
|
| 48 |
+
class BatchNorm2d(_BatchNorm):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
    """

    # Fused module type recognized by from_float (BN followed by ReLU).
    _NNI_BN_RELU_MODULE = nni.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
        super().__init__(num_features, eps, momentum, device=device, dtype=dtype)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    def _check_input_dim(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # disabling this since this is not symbolically traceable
        # self._check_input_dim(input)
        qbn2d = torch.ops.quantized.batch_norm2d
        return qbn2d(input, self.weight, self.bias, self.running_mean,
                     self.running_var, self.eps, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        # Delegate to the shared base implementation.
        return _BatchNorm.from_float(cls, mod)
|
| 77 |
+
|
| 78 |
+
class BatchNorm3d(_BatchNorm):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
    """

    # Fused module type recognized by from_float (BN followed by ReLU).
    _NNI_BN_RELU_MODULE = nni.BNReLU3d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    def _check_input_dim(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            # BUGFIX: the message previously said `(N, C, H, W)` (the 4-D
            # shape) even though this module validates a 5-D input.
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # disabling this since this is not symbolically traceable
        # self._check_input_dim(input)
        return torch.ops.quantized.batch_norm3d(
            input, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        # Delegate to the shared base implementation.
        return _BatchNorm.from_float(cls, mod)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Optional, Dict, Any
|
| 5 |
+
from .utils import ReferenceQuantizedModule
|
| 6 |
+
|
| 7 |
+
__all__ = ['Linear']  # only the reference quantized Linear is public
|
| 8 |
+
|
| 9 |
+
class Linear(nn.Linear, ReferenceQuantizedModule):
    """Reference quantized linear module for the FX Graph Mode Quantization
    workflow.

    Activations stay floating point and the weight is stored in floating
    point too; ``forward`` quantize/dequantizes the weight before calling the
    float functional linear, so a backend can pattern-match and fuse the
    quant/dequant/linear sequence into a真正 quantized linear.
    """
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        super().__init__(in_features, out_features, bias_, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.linear --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized linear
        """
        dequant_weight = self.get_weight()
        return F.linear(x, dequant_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        # Copy the float module's parameters into a fresh reference module.
        ref = Linear(
            float_linear.in_features, float_linear.out_features,
            float_linear.bias is not None, device=float_linear.weight.device,
            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
        ref.weight = torch.nn.Parameter(float_linear.weight.detach())
        if float_linear.bias is not None:
            ref.bias = torch.nn.Parameter(float_linear.bias.detach())
        return ref
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (318 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (215 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F403
|
| 2 |
+
|
| 3 |
+
from .fake_quantize import * # noqa: F403
|
| 4 |
+
from .fuse_modules import fuse_modules # noqa: F403
|
| 5 |
+
from .fuse_modules import fuse_modules_qat # noqa: F403
|
| 6 |
+
from .fuser_method_mappings import * # noqa: F403
|
| 7 |
+
from .observer import * # noqa: F403
|
| 8 |
+
from .qconfig import * # noqa: F403
|
| 9 |
+
from .qconfig_mapping import * # noqa: F403
|
| 10 |
+
from .quant_type import * # noqa: F403
|
| 11 |
+
from .quantization_mappings import * # type: ignore[no-redef]
|
| 12 |
+
from .quantize import * # noqa: F403
|
| 13 |
+
from .quantize_jit import * # noqa: F403
|
| 14 |
+
from .stubs import * # noqa: F403
|
| 15 |
+
from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
|
| 16 |
+
from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
|
| 17 |
+
from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
|
| 18 |
+
from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle # noqa: F401
|
| 19 |
+
from typing import Union, List, Callable, Tuple, Optional
|
| 20 |
+
from torch import Tensor
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
|
| 24 |
+
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"DeQuantStub",
|
| 28 |
+
"FakeQuantize",
|
| 29 |
+
"FakeQuantizeBase",
|
| 30 |
+
"FixedQParamsFakeQuantize",
|
| 31 |
+
"FixedQParamsObserver",
|
| 32 |
+
"FusedMovingAvgObsFakeQuantize",
|
| 33 |
+
"HistogramObserver",
|
| 34 |
+
"MatchAllNode",
|
| 35 |
+
"MinMaxObserver",
|
| 36 |
+
"MovingAverageMinMaxObserver",
|
| 37 |
+
"MovingAveragePerChannelMinMaxObserver",
|
| 38 |
+
"NoopObserver",
|
| 39 |
+
"ObserverBase",
|
| 40 |
+
"ObserverOrFakeQuantize",
|
| 41 |
+
"Pattern",
|
| 42 |
+
"PerChannelMinMaxObserver",
|
| 43 |
+
"PlaceholderObserver",
|
| 44 |
+
"QConfig",
|
| 45 |
+
"QConfigAny",
|
| 46 |
+
"QConfigDynamic",
|
| 47 |
+
"QConfigMapping",
|
| 48 |
+
"QuantStub",
|
| 49 |
+
"QuantType",
|
| 50 |
+
"QuantWrapper",
|
| 51 |
+
"RecordingObserver",
|
| 52 |
+
"ReuseInputObserver",
|
| 53 |
+
"UniformQuantizationObserverBase",
|
| 54 |
+
"add_quant_dequant",
|
| 55 |
+
"convert",
|
| 56 |
+
"convert_dynamic_jit",
|
| 57 |
+
"convert_jit",
|
| 58 |
+
"default_affine_fixed_qparams_fake_quant",
|
| 59 |
+
"default_affine_fixed_qparams_observer",
|
| 60 |
+
"default_debug_observer",
|
| 61 |
+
"default_dynamic_fake_quant",
|
| 62 |
+
"default_dynamic_quant_observer",
|
| 63 |
+
"default_embedding_fake_quant",
|
| 64 |
+
"default_embedding_fake_quant_4bit",
|
| 65 |
+
"default_eval_fn",
|
| 66 |
+
"default_fake_quant",
|
| 67 |
+
"default_fixed_qparams_range_0to1_fake_quant",
|
| 68 |
+
"default_fixed_qparams_range_0to1_observer",
|
| 69 |
+
"default_fixed_qparams_range_neg1to1_fake_quant",
|
| 70 |
+
"default_fixed_qparams_range_neg1to1_observer",
|
| 71 |
+
"default_float_qparams_observer",
|
| 72 |
+
"default_float_qparams_observer_4bit",
|
| 73 |
+
"default_fused_act_fake_quant",
|
| 74 |
+
"default_fused_per_channel_wt_fake_quant",
|
| 75 |
+
"default_fused_wt_fake_quant",
|
| 76 |
+
"default_histogram_fake_quant",
|
| 77 |
+
"default_histogram_observer",
|
| 78 |
+
"default_observer",
|
| 79 |
+
"default_per_channel_weight_fake_quant",
|
| 80 |
+
"default_per_channel_weight_observer",
|
| 81 |
+
"default_placeholder_observer",
|
| 82 |
+
"default_reuse_input_observer",
|
| 83 |
+
"default_symmetric_fixed_qparams_fake_quant",
|
| 84 |
+
"default_symmetric_fixed_qparams_observer",
|
| 85 |
+
"default_weight_fake_quant",
|
| 86 |
+
"default_weight_observer",
|
| 87 |
+
"disable_fake_quant",
|
| 88 |
+
"disable_observer",
|
| 89 |
+
"enable_fake_quant",
|
| 90 |
+
"enable_observer",
|
| 91 |
+
"fuse_conv_bn",
|
| 92 |
+
"fuse_conv_bn_jit",
|
| 93 |
+
"fuse_conv_bn_relu",
|
| 94 |
+
"fuse_convtranspose_bn",
|
| 95 |
+
"fuse_linear_bn",
|
| 96 |
+
"fuse_modules",
|
| 97 |
+
"fuse_modules_qat",
|
| 98 |
+
"fused_per_channel_wt_fake_quant_range_neg_127_to_127",
|
| 99 |
+
"fused_wt_fake_quant_range_neg_127_to_127",
|
| 100 |
+
"get_combined_dict",
|
| 101 |
+
"get_default_compare_output_module_list",
|
| 102 |
+
"get_default_custom_config_dict",
|
| 103 |
+
"get_default_dynamic_quant_module_mappings",
|
| 104 |
+
"get_default_dynamic_sparse_quant_module_mappings",
|
| 105 |
+
"get_default_float_to_quantized_operator_mappings",
|
| 106 |
+
"get_default_qat_module_mappings",
|
| 107 |
+
"get_default_qat_qconfig",
|
| 108 |
+
"get_default_qat_qconfig_dict",
|
| 109 |
+
"get_default_qat_qconfig_mapping",
|
| 110 |
+
"get_default_qconfig",
|
| 111 |
+
"get_default_qconfig_dict",
|
| 112 |
+
"get_default_qconfig_mapping",
|
| 113 |
+
"get_default_qconfig_propagation_list",
|
| 114 |
+
"get_default_static_quant_module_mappings",
|
| 115 |
+
"get_default_static_quant_reference_module_mappings",
|
| 116 |
+
"get_default_static_sparse_quant_module_mappings",
|
| 117 |
+
"get_dynamic_quant_module_class",
|
| 118 |
+
"get_embedding_qat_module_mappings",
|
| 119 |
+
"get_embedding_static_quant_module_mappings",
|
| 120 |
+
"get_fuser_method",
|
| 121 |
+
"get_fuser_method_new",
|
| 122 |
+
"get_observer_state_dict",
|
| 123 |
+
"get_quantized_operator",
|
| 124 |
+
"get_static_quant_module_class",
|
| 125 |
+
"load_observer_state_dict",
|
| 126 |
+
"move_exported_model_to_eval",
|
| 127 |
+
"move_exported_model_to_train",
|
| 128 |
+
"allow_exported_model_train_eval",
|
| 129 |
+
"no_observer_set",
|
| 130 |
+
"per_channel_weight_observer_range_neg_127_to_127",
|
| 131 |
+
"prepare",
|
| 132 |
+
"prepare_dynamic_jit",
|
| 133 |
+
"prepare_jit",
|
| 134 |
+
"prepare_qat",
|
| 135 |
+
"propagate_qconfig_",
|
| 136 |
+
"qconfig_equals",
|
| 137 |
+
"quantize",
|
| 138 |
+
"quantize_dynamic",
|
| 139 |
+
"quantize_dynamic_jit",
|
| 140 |
+
"quantize_jit",
|
| 141 |
+
"quantize_qat",
|
| 142 |
+
"script_qconfig",
|
| 143 |
+
"script_qconfig_dict",
|
| 144 |
+
"swap_module",
|
| 145 |
+
"weight_observer_range_neg_127_to_127",
|
| 146 |
+
"generate_numeric_debug_handle",
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
def default_eval_fn(model, calib_data):
    r"""Default evaluation function.

    Iterates ``calib_data`` — a torch.utils.data.Dataset or a list of
    (input, target) pairs — and runs ``model`` on each input. Targets are
    ignored; only the forward passes matter (e.g. for calibration).
    """
    for data, _target in calib_data:
        model(data)
|
| 157 |
+
|
| 158 |
+
class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""Observer whose quantization parameters are derived from other observers.

    This class collects no statistics of its own; ``calculate_qparams``
    simply delegates to the user-supplied ``derive_qparams_fn`` applied to
    the wrapped observers / fake quantizes.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        qscheme: Optional[torch.qscheme] = None,
        ch_axis: Optional[int] = None
    ):
        super().__init__(dtype)
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.ch_axis = ch_axis

        # Local import — presumably to avoid an import cycle at module load;
        # confirm before hoisting to the top of the file.
        from .utils import is_per_channel
        if is_per_channel(self.qscheme):
            assert self.ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Pass-through: this observer never inspects the data it sees.
        return x

    def calculate_qparams(self):
        return self.derive_qparams_fn(self.obs_or_fqs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc
ADDED
|
Binary file (34.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.ao.nn.quantized as nnq
|
| 4 |
+
|
| 5 |
+
import torch.ao.quantization
|
| 6 |
+
import torch.ao.ns._numeric_suite as ns
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"get_module",
|
| 10 |
+
"parent_child_names",
|
| 11 |
+
"get_param",
|
| 12 |
+
"MeanShadowLogger",
|
| 13 |
+
"bias_correction",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
_supported_modules = {nn.Linear, nn.Conv2d}
|
| 17 |
+
_supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
|
| 18 |
+
|
| 19 |
+
def get_module(model, name):
    """Return the submodule of ``model`` registered under the full name ``name``.

    Raises ``KeyError`` if no submodule with that name exists.
    """
    name_to_module = {mod_name: mod for mod_name, mod in model.named_modules()}
    return name_to_module[name]
|
| 22 |
+
|
| 23 |
+
def parent_child_names(name):
    """Split a submodule's full name into (parent full name, submodule name).

    A top-level name has no parent, so the parent component is ``''``.
    """
    # rpartition splits on the last '.' and yields ('', '', name) when
    # there is no separator — exactly the two cases we need.
    parent, _sep, child = name.rpartition('.')
    return parent, child
|
| 30 |
+
|
| 31 |
+
def get_param(module, attr):
    """Fetch the parameter named ``attr`` from ``module``.

    Some modules expose the raw tensor directly, while others expose a
    zero-argument accessor that returns it; both cases are handled here.
    Returns ``None`` when the attribute does not exist.
    """
    raw_or_getter = getattr(module, attr, None)
    if callable(raw_or_getter):
        return raw_or_getter()
    return raw_or_getter
|
| 42 |
+
|
| 43 |
+
class MeanShadowLogger(ns.Logger):
    """Mean Logger for a Shadow module.

    A logger for a Shadow module whose purpose is to record the rolling mean
    of the data passed to the floating point and quantized models
    """

    def __init__(self):
        """Set up initial values for float and quantized stats, count, float sum, and quant sum."""
        super().__init__()
        # Rolling means, exposed via the base Logger's ``stats`` dict.
        self.stats["float"] = None
        self.stats["quantized"] = None
        # Number of forward calls folded into the running sums below.
        self.count = 0
        # Raw accumulators; the means above are recomputed from these on
        # every forward call.
        self.float_sum = None
        self.quant_sum = None

    def forward(self, x, y):
        """Compute the average of quantized and floating-point data from modules.

        The inputs x,y are output data from the quantized and floating-point modules.
        x is for the quantized module, y is for the floating point module
        """
        if x.is_quantized:
            x = x.dequantize()

        self.count += 1
        if self.stats["quantized"] is None:
            # First observation: the accumulator *is* the input tensor — no
            # copy is made. NOTE(review): later ``+=`` on tensors accumulates
            # in place, so the first x/y passed in may be mutated on
            # subsequent calls; confirm callers never reuse these tensors.
            self.stats["quantized"] = x
            self.quant_sum = x
        else:
            self.quant_sum += x
            self.stats["quantized"] = self.quant_sum / self.count

        if self.stats["float"] is None:
            self.stats["float"] = y
            self.float_sum = y
        else:
            self.float_sum += y
            self.stats["float"] = self.float_sum / self.count

    def clear(self):
        # Reset all accumulated state so the logger can be reused.
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None
|
| 89 |
+
|
| 90 |
+
def bias_correction(float_model, quantized_model, img_data, target_modules=_supported_modules_quantized, neval_batches=None):
    """Perform bias correction on a module.

    Using numeric suite shadow module, the expected output of the floating point and quantized modules
    is recorded. Using that data the bias of supported modules is shifted to compensate for the drift caused
    by quantization
    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)

    Args:
        float_model: a trained model that serves as a reference to what bias correction should aim for
        quantized_model: quantized form of float_model that bias correction is to applied to
        img_data: calibration data to estimate the expected output (used to find quantization error)
        target_modules: specifies what submodules in quantized_model need bias correction (can be extended to
                unquantized submodules)
        neval_batches: a cap to the number of batches you want to be used for estimating the expected output
    """
    # Attach shadow copies of the float modules (with MeanShadowLogger) onto
    # the quantized model so per-module outputs of both models are recorded.
    ns.prepare_model_with_stubs(float_model, quantized_model, _supported_modules, MeanShadowLogger)

    # Gather the submodules whose bias will be corrected.
    uncorrected_modules = {}
    for name, submodule in quantized_model.named_modules():
        if type(submodule) in target_modules:
            uncorrected_modules[name] = submodule

    # Correct each target module's bias in turn.
    for uncorrected_module in uncorrected_modules:
        quantized_submodule = get_module(quantized_model, uncorrected_module)
        bias = get_param(quantized_submodule, 'bias')
        if bias is not None:

            # Run calibration batches (capped at neval_batches when given) so
            # the shadow loggers accumulate float/quantized output means.
            count = 0
            for data in img_data:
                quantized_model(data[0])
                count += 1
                if count == neval_batches:
                    break
            ob_dict = ns.get_logger_dict(quantized_model)
            parent_name, _ = parent_child_names(uncorrected_module)

            # Recorded rolling means for this module's shadow logger.
            float_data = ob_dict[parent_name + '.stats']['float']
            quant_data = ob_dict[parent_name + '.stats']['quantized']

            # math for expected_error
            quantization_error = quant_data - float_data
            dims = list(range(quantization_error.dim()))
            # Note: we don't want to take the mean over the output channel dimension
            dims.remove(1)
            expected_error = torch.mean(quantization_error, dims)

            # Shift the bias to cancel the mean quantization drift.
            updated_bias = bias.data - expected_error

            bias.data = updated_bias

    # Resets the data contained in the loggers
    for name, submodule in quantized_model.named_modules():
        if isinstance(submodule, MeanShadowLogger):
            submodule.clear()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc
ADDED
|
Binary file (32.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc
ADDED
|
Binary file (34.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc
ADDED
|
Binary file (245 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ._common_operator_config_utils import (
|
| 3 |
+
_get_binary_op_configs,
|
| 4 |
+
_get_bn_configs,
|
| 5 |
+
_get_cat_config,
|
| 6 |
+
_get_conv_configs,
|
| 7 |
+
_get_default_op_configs,
|
| 8 |
+
_get_embedding_op_configs,
|
| 9 |
+
_get_fixed_qparams_op_configs,
|
| 10 |
+
_get_linear_configs,
|
| 11 |
+
_get_rnn_op_configs,
|
| 12 |
+
_get_share_qparams_op_configs,
|
| 13 |
+
)
|
| 14 |
+
from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"get_qnnpack_backend_config",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
# ===================
|
| 21 |
+
# | DTYPE CONFIGS |
|
| 22 |
+
# ===================
|
| 23 |
+
|
| 24 |
+
qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
|
| 25 |
+
input_dtype=torch.quint8,
|
| 26 |
+
output_dtype=torch.quint8,
|
| 27 |
+
weight_dtype=torch.qint8,
|
| 28 |
+
bias_dtype=torch.float,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
qnnpack_default_op_quint8_dtype_config = DTypeConfig(
|
| 32 |
+
input_dtype=torch.quint8,
|
| 33 |
+
output_dtype=torch.quint8,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
qnnpack_default_op_fp16_dtype_config = DTypeConfig(
|
| 37 |
+
input_dtype=torch.float16,
|
| 38 |
+
output_dtype=torch.float16,
|
| 39 |
+
weight_dtype=torch.float16,
|
| 40 |
+
bias_dtype=torch.float16,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
|
| 44 |
+
input_dtype=torch.quint8,
|
| 45 |
+
output_dtype=torch.float,
|
| 46 |
+
weight_dtype=torch.qint8,
|
| 47 |
+
bias_dtype=torch.float,
|
| 48 |
+
is_dynamic=True,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
|
| 52 |
+
input_dtype=torch.float16,
|
| 53 |
+
output_dtype=torch.float,
|
| 54 |
+
weight_dtype=torch.float16,
|
| 55 |
+
bias_dtype=torch.float,
|
| 56 |
+
is_dynamic=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
|
| 60 |
+
input_dtype=torch.float,
|
| 61 |
+
output_dtype=torch.float,
|
| 62 |
+
weight_dtype=torch.quint8,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
|
| 66 |
+
input_dtype=torch.float,
|
| 67 |
+
output_dtype=torch.float,
|
| 68 |
+
weight_dtype=torch.quint4x2,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# xnnpack compatible dtype configs
|
| 72 |
+
|
| 73 |
+
# We restrict scale values to be 2 ** -12 to ensure the
|
| 74 |
+
# requantization scale never falls below the xnnpack lower
|
| 75 |
+
# threshold. Additionally, for qint8 weight, we restrict
|
| 76 |
+
# the quantization values to [-127, +127], excluding -128.
|
| 77 |
+
# For more detail, refer to the description of
|
| 78 |
+
# `default_symmetric_qnnpack_qconfig`.
|
| 79 |
+
|
| 80 |
+
# TODO: add additional restriction on qscheme to ensure it
|
| 81 |
+
# is either per_tensor_symmetric or per_channel_symmetric
|
| 82 |
+
|
| 83 |
+
qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
|
| 84 |
+
dtype=torch.qint8,
|
| 85 |
+
scale_min_lower_bound=2 ** -12,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
|
| 89 |
+
dtype=torch.qint8,
|
| 90 |
+
quant_min_lower_bound=-127,
|
| 91 |
+
quant_max_upper_bound=127,
|
| 92 |
+
scale_min_lower_bound=2 ** -12,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
|
| 96 |
+
input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
|
| 97 |
+
output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
|
| 98 |
+
weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
|
| 99 |
+
bias_dtype=torch.float,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
|
| 103 |
+
input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
|
| 104 |
+
output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# =====================
|
| 109 |
+
# | BACKEND CONFIGS |
|
| 110 |
+
# =====================
|
| 111 |
+
|
| 112 |
+
def get_qnnpack_backend_config() -> BackendConfig:
|
| 113 |
+
"""
|
| 114 |
+
Return the `BackendConfig` for PyTorch's native QNNPACK backend.
|
| 115 |
+
"""
|
| 116 |
+
conv_dtype_configs = [
|
| 117 |
+
qnnpack_weighted_op_qint8_symmetric_dtype_config,
|
| 118 |
+
qnnpack_weighted_op_quint8_dtype_config,
|
| 119 |
+
]
|
| 120 |
+
linear_dtype_configs = [
|
| 121 |
+
qnnpack_weighted_op_qint8_symmetric_dtype_config,
|
| 122 |
+
qnnpack_weighted_op_quint8_dtype_config,
|
| 123 |
+
qnnpack_default_dynamic_int8_dtype_config,
|
| 124 |
+
qnnpack_default_dynamic_float16_dtype_config,
|
| 125 |
+
]
|
| 126 |
+
binary_op_dtype_configs = [
|
| 127 |
+
qnnpack_default_op_qint8_symmetric_dtype_config,
|
| 128 |
+
qnnpack_default_op_quint8_dtype_config,
|
| 129 |
+
]
|
| 130 |
+
default_op_dtype_configs = [
|
| 131 |
+
qnnpack_default_op_qint8_symmetric_dtype_config,
|
| 132 |
+
qnnpack_default_op_quint8_dtype_config,
|
| 133 |
+
]
|
| 134 |
+
fixed_qparams_op_dtype_configs = [
|
| 135 |
+
qnnpack_default_op_qint8_symmetric_dtype_config,
|
| 136 |
+
qnnpack_default_op_quint8_dtype_config,
|
| 137 |
+
]
|
| 138 |
+
share_qparams_op_dtype_configs = [
|
| 139 |
+
qnnpack_default_op_qint8_symmetric_dtype_config,
|
| 140 |
+
qnnpack_default_op_quint8_dtype_config,
|
| 141 |
+
]
|
| 142 |
+
rnn_op_dtype_configs = [
|
| 143 |
+
qnnpack_default_dynamic_int8_dtype_config,
|
| 144 |
+
qnnpack_default_dynamic_float16_dtype_config,
|
| 145 |
+
]
|
| 146 |
+
embedding_op_dtype_configs = [
|
| 147 |
+
qnnpack_weight_only_quint8_dtype_config,
|
| 148 |
+
qnnpack_weight_only_quint4x2_dtype_config,
|
| 149 |
+
]
|
| 150 |
+
return BackendConfig("qnnpack") \
|
| 151 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 152 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 153 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 154 |
+
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
| 155 |
+
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
| 156 |
+
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
| 157 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 158 |
+
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
| 159 |
+
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
| 160 |
+
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .backend_config import (
|
| 3 |
+
BackendConfig,
|
| 4 |
+
BackendPatternConfig,
|
| 5 |
+
DTypeConfig,
|
| 6 |
+
ObservationType
|
| 7 |
+
)
|
| 8 |
+
from ._common_operator_config_utils import (
|
| 9 |
+
_get_binary_op_configs,
|
| 10 |
+
_get_linear_configs,
|
| 11 |
+
_get_conv_configs,
|
| 12 |
+
_get_share_qparams_op_configs,
|
| 13 |
+
_get_tensor_info_op_configs,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"get_tensorrt_backend_config",
|
| 18 |
+
"get_tensorrt_backend_config_dict",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
def get_tensorrt_backend_config() -> BackendConfig:
|
| 22 |
+
"""
|
| 23 |
+
Return the `BackendConfig` for the TensorRT backend.
|
| 24 |
+
NOTE: Current api will change in the future, it's just to unblock experimentation for
|
| 25 |
+
new backends, please don't use it right now.
|
| 26 |
+
TODO: add a README when it's more stable
|
| 27 |
+
"""
|
| 28 |
+
# dtype configs
|
| 29 |
+
weighted_op_qint8_dtype_config = DTypeConfig(
|
| 30 |
+
input_dtype=torch.qint8,
|
| 31 |
+
output_dtype=torch.qint8,
|
| 32 |
+
weight_dtype=torch.qint8,
|
| 33 |
+
bias_dtype=torch.float,
|
| 34 |
+
)
|
| 35 |
+
non_weighted_op_qint8_dtype_config = DTypeConfig(
|
| 36 |
+
input_dtype=torch.qint8,
|
| 37 |
+
output_dtype=torch.qint8,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
addmm_config = BackendPatternConfig(torch.addmm) \
|
| 41 |
+
.set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
|
| 42 |
+
.add_dtype_config(weighted_op_qint8_dtype_config) \
|
| 43 |
+
._set_input_type_to_index({
|
| 44 |
+
"bias": 0,
|
| 45 |
+
"input": 1,
|
| 46 |
+
"weight": 2,
|
| 47 |
+
})
|
| 48 |
+
cat_config = BackendPatternConfig(torch.cat) \
|
| 49 |
+
.set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
|
| 50 |
+
.add_dtype_config(non_weighted_op_qint8_dtype_config)
|
| 51 |
+
conv_dtype_configs = [
|
| 52 |
+
weighted_op_qint8_dtype_config,
|
| 53 |
+
]
|
| 54 |
+
linear_dtype_configs = [
|
| 55 |
+
weighted_op_qint8_dtype_config,
|
| 56 |
+
]
|
| 57 |
+
binary_op_dtype_configs = [
|
| 58 |
+
weighted_op_qint8_dtype_config,
|
| 59 |
+
]
|
| 60 |
+
share_qparams_op_dtype_configs = [
|
| 61 |
+
non_weighted_op_qint8_dtype_config,
|
| 62 |
+
]
|
| 63 |
+
tensor_info_op_dtype_configs = [
|
| 64 |
+
non_weighted_op_qint8_dtype_config,
|
| 65 |
+
]
|
| 66 |
+
# there might be things not supported in fx2trt, but it will error out
|
| 67 |
+
# during fx2trt conversion and can support them after that
|
| 68 |
+
return BackendConfig("tensorrt") \
|
| 69 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 70 |
+
.set_backend_pattern_config(addmm_config) \
|
| 71 |
+
.set_backend_pattern_config(cat_config) \
|
| 72 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 73 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 74 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 75 |
+
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
|
| 76 |
+
|
| 77 |
+
def get_tensorrt_backend_config_dict():
    """
    Return the `BackendConfig` for the TensorRT backend in dictionary form.
    """
    backend_config = get_tensorrt_backend_config()
    return backend_config.to_dict()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, List, Callable, Union, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from .backend_config import (
|
| 7 |
+
BackendConfig,
|
| 8 |
+
BackendPatternConfig,
|
| 9 |
+
DTypeConfig,
|
| 10 |
+
)
|
| 11 |
+
from ..utils import Pattern
|
| 12 |
+
from ..fuser_method_mappings import (
|
| 13 |
+
_reverse2,
|
| 14 |
+
_reverse3,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"get_pattern_to_dtype_configs",
|
| 19 |
+
"get_qat_module_classes",
|
| 20 |
+
"get_fused_module_classes",
|
| 21 |
+
"get_pattern_to_input_type_to_index",
|
| 22 |
+
"get_root_module_to_quantized_reference_module",
|
| 23 |
+
"get_fuser_method_mapping",
|
| 24 |
+
"get_module_to_qat_module",
|
| 25 |
+
"get_fusion_pattern_to_root_node_getter",
|
| 26 |
+
"get_fusion_pattern_to_extra_inputs_getter",
|
| 27 |
+
"remove_boolean_dispatch_from_name",
|
| 28 |
+
"pattern_to_human_readable",
|
| 29 |
+
"entry_to_pretty_str",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]:
|
| 33 |
+
pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
|
| 34 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 35 |
+
pattern_to_dtype_configs[pattern] = config.dtype_configs
|
| 36 |
+
return pattern_to_dtype_configs
|
| 37 |
+
|
| 38 |
+
def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
|
| 39 |
+
qat_module_classes = []
|
| 40 |
+
for config in backend_config.configs:
|
| 41 |
+
if config.qat_module is not None:
|
| 42 |
+
qat_module_classes.append(config.qat_module)
|
| 43 |
+
return tuple(set(qat_module_classes))
|
| 44 |
+
|
| 45 |
+
def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
|
| 46 |
+
fused_module_classes = []
|
| 47 |
+
for config in backend_config.configs:
|
| 48 |
+
if config.fused_module is not None:
|
| 49 |
+
fused_module_classes.append(config.fused_module)
|
| 50 |
+
return tuple(set(fused_module_classes))
|
| 51 |
+
|
| 52 |
+
def get_pattern_to_input_type_to_index(backend_config: BackendConfig) -> Dict[Pattern, Dict[str, int]]:
|
| 53 |
+
pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {}
|
| 54 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 55 |
+
pattern_to_input_type_to_index[pattern] = config._input_type_to_index
|
| 56 |
+
return pattern_to_input_type_to_index
|
| 57 |
+
|
| 58 |
+
def get_root_module_to_quantized_reference_module(
|
| 59 |
+
backend_config: BackendConfig) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
|
| 60 |
+
mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {}
|
| 61 |
+
for config in backend_config.configs:
|
| 62 |
+
if config.root_module is not None and config.reference_quantized_module is not None:
|
| 63 |
+
mapping[config.root_module] = config.reference_quantized_module
|
| 64 |
+
return mapping
|
| 65 |
+
|
| 66 |
+
def get_fuser_method_mapping(backend_config: BackendConfig) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
|
| 67 |
+
fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = {}
|
| 68 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 69 |
+
if config.fuser_method is not None:
|
| 70 |
+
# Note: both the fuser method and the pattern are specified in forward order in the
|
| 71 |
+
# BackendConfig, but the internal pattern matching code uses the reversed nested tuple
|
| 72 |
+
# format, so we need to convert both to the internal format
|
| 73 |
+
fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config)
|
| 74 |
+
fuser_method_mapping[pattern] = fuser_method
|
| 75 |
+
return fuser_method_mapping
|
| 76 |
+
|
| 77 |
+
def get_module_to_qat_module(backend_config: BackendConfig) -> Dict[Pattern, Type[torch.nn.Module]]:
|
| 78 |
+
module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {}
|
| 79 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 80 |
+
if config.qat_module is not None:
|
| 81 |
+
module_to_qat_module[pattern] = config.qat_module
|
| 82 |
+
return module_to_qat_module
|
| 83 |
+
|
| 84 |
+
def get_fusion_pattern_to_root_node_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
|
| 85 |
+
""" Get a map from fusion pattern to a function that returns the root node
|
| 86 |
+
from the fusion pattern, e.g. the most common one is:
|
| 87 |
+
def get_root_node(node_pattern):
|
| 88 |
+
while not isinstance(node_pattern[-1], Node):
|
| 89 |
+
node_pattern = node_pattern[-1]
|
| 90 |
+
return node_pattern[-1]
|
| 91 |
+
This can work for all patterns whose root node is the "last node" in the pattern,
|
| 92 |
+
e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d))
|
| 93 |
+
"""
|
| 94 |
+
root_node_getter_mapping: Dict[Pattern, Callable] = {}
|
| 95 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 96 |
+
if config._root_node_getter is not None:
|
| 97 |
+
root_node_getter_mapping[pattern] = config._root_node_getter
|
| 98 |
+
return root_node_getter_mapping
|
| 99 |
+
|
| 100 |
+
def get_fusion_pattern_to_extra_inputs_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
|
| 101 |
+
""" Get a map from fusion pattern to a function that returns extra input nodes
|
| 102 |
+
from the fusion pattern, in the order required by the root node. This is optional,
|
| 103 |
+
if not specified, we will not copy over any extra inputs for the root node.
|
| 104 |
+
Example:
|
| 105 |
+
# Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
|
| 106 |
+
# and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra
|
| 107 |
+
# argument to the fused module, we can unpack the pattern and return the node at
|
| 108 |
+
# MatchAllNode here
|
| 109 |
+
# we can implement extra_inputs_getter as follows:
|
| 110 |
+
def extra_inputs_getter(pattern) -> List[Any]:
|
| 111 |
+
add, extra_input, conv_pattern = pattern
|
| 112 |
+
return [extra_input]
|
| 113 |
+
"""
|
| 114 |
+
extra_inputs_getter_mapping: Dict[Pattern, Callable] = {}
|
| 115 |
+
for pattern, config in backend_config._pattern_complex_format_to_config.items():
|
| 116 |
+
if config._extra_inputs_getter is not None:
|
| 117 |
+
extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter
|
| 118 |
+
return extra_inputs_getter_mapping
|
| 119 |
+
|
| 120 |
+
def remove_boolean_dispatch_from_name(p) -> Any:
|
| 121 |
+
"""
|
| 122 |
+
Some ops have a default string representation such as
|
| 123 |
+
'<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>',
|
| 124 |
+
this function replaces them with the hardcoded function names.
|
| 125 |
+
"""
|
| 126 |
+
if p is F.fractional_max_pool2d:
|
| 127 |
+
return "torch.nn.functional.fractional_max_pool2d"
|
| 128 |
+
elif p is F.fractional_max_pool3d:
|
| 129 |
+
return "torch.nn.functional.fractional_max_pool3d"
|
| 130 |
+
elif p is F.max_pool1d:
|
| 131 |
+
return "torch.nn.functional.max_pool1d"
|
| 132 |
+
elif p is F.max_pool2d:
|
| 133 |
+
return "torch.nn.functional.max_pool2d"
|
| 134 |
+
elif p is F.max_pool3d:
|
| 135 |
+
return "torch.nn.functional.max_pool3d"
|
| 136 |
+
elif p is F.adaptive_max_pool1d:
|
| 137 |
+
return "torch.nn.functional.adaptive_max_pool1d"
|
| 138 |
+
elif p is F.adaptive_max_pool2d:
|
| 139 |
+
return "torch.nn.functional.adaptive_max_pool2d"
|
| 140 |
+
elif p is F.adaptive_max_pool3d:
|
| 141 |
+
return "torch.nn.functional.adaptive_max_pool3d"
|
| 142 |
+
assert "boolean_dispatch" not in str(p), \
|
| 143 |
+
f"{p} does not have a human readable representation in " + \
|
| 144 |
+
"quantization documentation"
|
| 145 |
+
return p
|
| 146 |
+
|
| 147 |
+
def pattern_to_human_readable(p) -> Any:
|
| 148 |
+
if isinstance(p, tuple):
|
| 149 |
+
# nested patterns, recurse
|
| 150 |
+
return tuple(pattern_to_human_readable(inner_p) for inner_p in p)
|
| 151 |
+
elif isinstance(p, str):
|
| 152 |
+
# method names are already human readable
|
| 153 |
+
return p
|
| 154 |
+
else:
|
| 155 |
+
p = remove_boolean_dispatch_from_name(p)
|
| 156 |
+
return p
|
| 157 |
+
|
| 158 |
+
# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
|
| 159 |
+
# the corresponding __str__ function
|
| 160 |
+
def entry_to_pretty_str(entry) -> str:
|
| 161 |
+
"""
|
| 162 |
+
Given a backend_config_dict entry, returns a string with the human readable
|
| 163 |
+
representation of it.
|
| 164 |
+
"""
|
| 165 |
+
s = "{\n"
|
| 166 |
+
|
| 167 |
+
# always output the pattern first
|
| 168 |
+
if "pattern" in entry:
|
| 169 |
+
pattern_str = pattern_to_human_readable(entry["pattern"])
|
| 170 |
+
|
| 171 |
+
s += f" 'pattern': {pattern_str},\n"
|
| 172 |
+
|
| 173 |
+
# custom output for dtype_configs to make it look nice
|
| 174 |
+
if "dtype_configs" in entry:
|
| 175 |
+
s += " 'dtype_configs': [\n"
|
| 176 |
+
for dtype_config in entry["dtype_configs"]:
|
| 177 |
+
s += " {\n"
|
| 178 |
+
for k, v in dtype_config.items():
|
| 179 |
+
s += f" '{k}': {v},\n"
|
| 180 |
+
s += " },\n"
|
| 181 |
+
s += " ],\n"
|
| 182 |
+
|
| 183 |
+
# custom output for num_tensor_args_to_observation_type to make it look nice
|
| 184 |
+
if "num_tensor_args_to_observation_type" in entry:
|
| 185 |
+
s += " 'num_tensor_args_to_observation_type': {\n"
|
| 186 |
+
for k, v in entry["num_tensor_args_to_observation_type"].items():
|
| 187 |
+
s += f" {k}: {v},\n"
|
| 188 |
+
s += " },\n"
|
| 189 |
+
|
| 190 |
+
# output all the other fields
|
| 191 |
+
custom_handled_fields = [
|
| 192 |
+
"pattern",
|
| 193 |
+
"dtype_configs",
|
| 194 |
+
"num_tensor_args_to_observation_type",
|
| 195 |
+
]
|
| 196 |
+
for field_name in entry:
|
| 197 |
+
if field_name in custom_handled_fields:
|
| 198 |
+
continue
|
| 199 |
+
s += f" '{field_name}': {entry[field_name]},\n"
|
| 200 |
+
|
| 201 |
+
s += "}"
|
| 202 |
+
return s
|
| 203 |
+
|
| 204 |
+
def _get_pattern_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Pattern:
|
| 205 |
+
"""
|
| 206 |
+
Return the pattern specified in the given config in the reversed nested tuple format
|
| 207 |
+
used internally in the quantization pattern matching code.
|
| 208 |
+
|
| 209 |
+
If the pattern is not a tuple, or the pattern is already specified in the reversed
|
| 210 |
+
nested tuple format, return the pattern as is. Otherwise:
|
| 211 |
+
|
| 212 |
+
For 2-tuples (a, b), return (b, a).
|
| 213 |
+
For 3-tuples (a, b, c), return (c, (b, a)).
|
| 214 |
+
|
| 215 |
+
For example:
|
| 216 |
+
* Given nn.Linear, return nn.Linear
|
| 217 |
+
* Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
|
| 218 |
+
* Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
|
| 219 |
+
(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
|
| 220 |
+
|
| 221 |
+
For context, the reason why this is needed is the user-facing BackendConfig
|
| 222 |
+
API accepts the flat 2-or-3-tuple format in forward order. While this simple
|
| 223 |
+
format handles the vast majority of use cases, it does not handle the more
|
| 224 |
+
complex ones, and so the internal pattern matching code for quantization uses
|
| 225 |
+
the following, more general reversed nested tuple format instead:
|
| 226 |
+
|
| 227 |
+
operator = module_type | functional | torch op | native op | MatchAllNode
|
| 228 |
+
Pattern = (operator, Pattern, Pattern, ...) | operator
|
| 229 |
+
|
| 230 |
+
In the future, we expect to replace the above complex format with the one used
|
| 231 |
+
by the subgraph rewriter in torch.fx, so we don't have to maintain our own
|
| 232 |
+
complex pattern matching code. Then we won't need this helper function anymore.
|
| 233 |
+
"""
|
| 234 |
+
if config._pattern_complex_format is not None:
|
| 235 |
+
return config._pattern_complex_format
|
| 236 |
+
if config.pattern is None:
|
| 237 |
+
raise ValueError("Either 'pattern' or 'pattern_complex_format' must be specified")
|
| 238 |
+
if not isinstance(config.pattern, tuple):
|
| 239 |
+
return config.pattern
|
| 240 |
+
|
| 241 |
+
# Pattern is specified in the simple tuple format, need to convert
|
| 242 |
+
if len(config.pattern) == 2:
|
| 243 |
+
(a, b) = config.pattern
|
| 244 |
+
return (b, a)
|
| 245 |
+
elif len(config.pattern) == 3:
|
| 246 |
+
(a, b, c) = config.pattern
|
| 247 |
+
return (c, (b, a))
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
|
| 250 |
+
|
| 251 |
+
def _get_fuser_method_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Callable:
|
| 252 |
+
"""
|
| 253 |
+
Return the fuser method specified in the given config in the reversed nested
|
| 254 |
+
tuple format used internally in the quantization pattern matching code.
|
| 255 |
+
|
| 256 |
+
If pattern is specified in the reversed nested tuple format, we assume the
|
| 257 |
+
fuser method is also specified in this format and simply return it as is.
|
| 258 |
+
Otherwise, we convert the fuser method as follows:
|
| 259 |
+
|
| 260 |
+
* Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
|
| 261 |
+
* Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
|
| 262 |
+
where bn_conv is a 2-tuple (bn, conv)
|
| 263 |
+
|
| 264 |
+
The first argument of a fuser method is always `is_qat` and is not affected
|
| 265 |
+
in the conversion. We currently only support functions with 3 or 4 arguments.
|
| 266 |
+
"""
|
| 267 |
+
assert config.fuser_method is not None
|
| 268 |
+
if config._pattern_complex_format is not None:
|
| 269 |
+
return config.fuser_method
|
| 270 |
+
if not isinstance(config.pattern, tuple):
|
| 271 |
+
raise ValueError("Expected pattern to be a tuple, got: ", config.pattern)
|
| 272 |
+
|
| 273 |
+
# Pattern is specified in the simple tuple format, need to convert
|
| 274 |
+
if len(config.pattern) == 2:
|
| 275 |
+
return _reverse2(config.fuser_method)
|
| 276 |
+
elif len(config.pattern) == 3:
|
| 277 |
+
return _reverse3(config.fuser_method)
|
| 278 |
+
else:
|
| 279 |
+
raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from ._common_operator_config_utils import (
|
| 3 |
+
_get_binary_op_configs,
|
| 4 |
+
_get_bn_configs,
|
| 5 |
+
_get_cat_config,
|
| 6 |
+
_get_conv_configs,
|
| 7 |
+
_get_default_op_configs,
|
| 8 |
+
_get_embedding_op_configs,
|
| 9 |
+
_get_fixed_qparams_op_configs,
|
| 10 |
+
_get_linear_configs,
|
| 11 |
+
_get_rnn_op_configs,
|
| 12 |
+
_get_share_qparams_op_configs,
|
| 13 |
+
_get_tensor_info_op_configs,
|
| 14 |
+
)
|
| 15 |
+
from .backend_config import BackendConfig, DTypeConfig
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"get_x86_backend_config",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
# ===================
|
| 22 |
+
# | DTYPE CONFIGS |
|
| 23 |
+
# ===================
|
| 24 |
+
|
| 25 |
+
# X86 aligns with FBGEMM for now
|
| 26 |
+
|
| 27 |
+
x86_weighted_op_int8_dtype_config = DTypeConfig(
|
| 28 |
+
input_dtype=torch.quint8,
|
| 29 |
+
output_dtype=torch.quint8,
|
| 30 |
+
weight_dtype=torch.qint8,
|
| 31 |
+
bias_dtype=torch.float,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
x86_default_op_quint8_dtype_config = DTypeConfig(
|
| 35 |
+
input_dtype=torch.quint8,
|
| 36 |
+
output_dtype=torch.quint8,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
x86_default_op_fp16_dtype_config = DTypeConfig(
|
| 40 |
+
input_dtype=torch.float16,
|
| 41 |
+
output_dtype=torch.float16,
|
| 42 |
+
weight_dtype=torch.float16,
|
| 43 |
+
bias_dtype=torch.float16,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
x86_default_dynamic_int8_dtype_config = DTypeConfig(
|
| 47 |
+
input_dtype=torch.quint8,
|
| 48 |
+
output_dtype=torch.float,
|
| 49 |
+
weight_dtype=torch.qint8,
|
| 50 |
+
bias_dtype=torch.float,
|
| 51 |
+
is_dynamic=True,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
x86_default_dynamic_float16_dtype_config = DTypeConfig(
|
| 55 |
+
input_dtype=torch.float16,
|
| 56 |
+
output_dtype=torch.float,
|
| 57 |
+
weight_dtype=torch.float16,
|
| 58 |
+
bias_dtype=torch.float,
|
| 59 |
+
is_dynamic=True,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
x86_weight_only_quint8_dtype_config = DTypeConfig(
|
| 63 |
+
input_dtype=torch.float,
|
| 64 |
+
output_dtype=torch.float,
|
| 65 |
+
weight_dtype=torch.quint8,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
x86_weight_only_quint4x2_dtype_config = DTypeConfig(
|
| 69 |
+
input_dtype=torch.float,
|
| 70 |
+
output_dtype=torch.float,
|
| 71 |
+
weight_dtype=torch.quint4x2,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# =====================
|
| 76 |
+
# | BACKEND CONFIGS |
|
| 77 |
+
# =====================
|
| 78 |
+
|
| 79 |
+
def get_x86_backend_config() -> BackendConfig:
|
| 80 |
+
"""
|
| 81 |
+
Return the `BackendConfig` for PyTorch's native x86 backend.
|
| 82 |
+
"""
|
| 83 |
+
conv_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
| 84 |
+
linear_dtype_configs = [
|
| 85 |
+
x86_weighted_op_int8_dtype_config,
|
| 86 |
+
x86_default_dynamic_int8_dtype_config,
|
| 87 |
+
x86_default_dynamic_float16_dtype_config,
|
| 88 |
+
]
|
| 89 |
+
binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
| 90 |
+
default_op_dtype_configs = [x86_default_op_quint8_dtype_config]
|
| 91 |
+
fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
|
| 92 |
+
share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config]
|
| 93 |
+
tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config]
|
| 94 |
+
rnn_op_dtype_configs = [
|
| 95 |
+
x86_default_dynamic_int8_dtype_config,
|
| 96 |
+
x86_default_dynamic_float16_dtype_config,
|
| 97 |
+
]
|
| 98 |
+
embedding_op_dtype_configs = [
|
| 99 |
+
x86_weight_only_quint8_dtype_config,
|
| 100 |
+
x86_weight_only_quint4x2_dtype_config,
|
| 101 |
+
]
|
| 102 |
+
return BackendConfig("x86") \
|
| 103 |
+
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
|
| 104 |
+
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
|
| 105 |
+
.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
|
| 106 |
+
.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
|
| 107 |
+
.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
|
| 108 |
+
.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
|
| 109 |
+
.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
|
| 110 |
+
.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
|
| 111 |
+
.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
|
| 112 |
+
.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
|
| 113 |
+
.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Implements modules used to perform fake quantization."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn import Module
|
| 5 |
+
from torch.ao.quantization.observer import (
|
| 6 |
+
MovingAverageMinMaxObserver,
|
| 7 |
+
HistogramObserver,
|
| 8 |
+
MovingAveragePerChannelMinMaxObserver,
|
| 9 |
+
FixedQParamsObserver,
|
| 10 |
+
default_fixed_qparams_range_0to1_observer,
|
| 11 |
+
default_fixed_qparams_range_neg1to1_observer,
|
| 12 |
+
_with_args,
|
| 13 |
+
)
|
| 14 |
+
import re
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from typing import Any, Tuple
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"FakeQuantizeBase",
|
| 20 |
+
"FakeQuantize",
|
| 21 |
+
"FixedQParamsFakeQuantize",
|
| 22 |
+
"FusedMovingAvgObsFakeQuantize",
|
| 23 |
+
"disable_fake_quant",
|
| 24 |
+
"disable_observer",
|
| 25 |
+
"enable_fake_quant",
|
| 26 |
+
"enable_observer",
|
| 27 |
+
"default_fake_quant",
|
| 28 |
+
"default_weight_fake_quant",
|
| 29 |
+
"default_dynamic_fake_quant",
|
| 30 |
+
"default_fixed_qparams_range_neg1to1_fake_quant",
|
| 31 |
+
"default_fixed_qparams_range_0to1_fake_quant",
|
| 32 |
+
"default_symmetric_fixed_qparams_fake_quant",
|
| 33 |
+
"default_affine_fixed_qparams_fake_quant",
|
| 34 |
+
"default_per_channel_weight_fake_quant",
|
| 35 |
+
"default_embedding_fake_quant",
|
| 36 |
+
"default_embedding_fake_quant_4bit",
|
| 37 |
+
"default_histogram_fake_quant",
|
| 38 |
+
"default_fused_act_fake_quant",
|
| 39 |
+
"default_fused_wt_fake_quant",
|
| 40 |
+
"default_fused_per_channel_wt_fake_quant",
|
| 41 |
+
"fused_wt_fake_quant_range_neg_127_to_127",
|
| 42 |
+
"fused_per_channel_wt_fake_quant_range_neg_127_to_127",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
|
| 46 |
+
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
|
| 47 |
+
|
| 48 |
+
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
|
| 49 |
+
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
|
| 50 |
+
|
| 51 |
+
def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
|
| 52 |
+
return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
|
| 53 |
+
|
| 54 |
+
def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
|
| 55 |
+
return qscheme in [torch.per_channel_affine_float_qparams, ]
|
| 56 |
+
|
| 57 |
+
class FakeQuantizeBase(ABC, Module):
|
| 58 |
+
r"""Base fake quantize module.
|
| 59 |
+
|
| 60 |
+
Base fake quantize module
|
| 61 |
+
Any fake quantize implementation should derive from this class.
|
| 62 |
+
|
| 63 |
+
Concrete fake quantize module should follow the same API. In forward, they will update
|
| 64 |
+
the statistics of the observed Tensor and fake quantize the input. They should also provide a
|
| 65 |
+
`calculate_qparams` function that computes the quantization parameters given
|
| 66 |
+
the collected statistics.
|
| 67 |
+
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
fake_quant_enabled: torch.Tensor
|
| 71 |
+
observer_enabled: torch.Tensor
|
| 72 |
+
|
| 73 |
+
def __init__(self):
|
| 74 |
+
"""Set fake_quant_enabled and observer_enabled."""
|
| 75 |
+
super().__init__()
|
| 76 |
+
# fake_quant_enabled and observer_enabled are buffers to support their
|
| 77 |
+
# replication in DDP. Data type is uint8 because NCCL does not support
|
| 78 |
+
# bool tensors.
|
| 79 |
+
self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
|
| 80 |
+
self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
|
| 81 |
+
|
| 82 |
+
@abstractmethod
|
| 83 |
+
def forward(self, x):
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def calculate_qparams(self, **kwargs):
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
@torch.jit.export
|
| 91 |
+
def enable_fake_quant(self, enabled: bool = True) -> None:
|
| 92 |
+
self.fake_quant_enabled[0] = 1 if enabled else 0
|
| 93 |
+
|
| 94 |
+
@torch.jit.export
|
| 95 |
+
def disable_fake_quant(self):
|
| 96 |
+
self.enable_fake_quant(False)
|
| 97 |
+
|
| 98 |
+
@torch.jit.export
|
| 99 |
+
def enable_observer(self, enabled: bool = True) -> None:
|
| 100 |
+
self.observer_enabled[0] = 1 if enabled else 0
|
| 101 |
+
|
| 102 |
+
@torch.jit.export
|
| 103 |
+
def disable_observer(self):
|
| 104 |
+
self.enable_observer(False)
|
| 105 |
+
|
| 106 |
+
@classmethod
|
| 107 |
+
def with_args(cls, **kwargs):
|
| 108 |
+
fake_quant_constructor = _with_args(cls, **kwargs)
|
| 109 |
+
# need to assign the correct module to fake_quantize
|
| 110 |
+
# constructors to satisfy public v private requirements
|
| 111 |
+
fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
|
| 112 |
+
return fake_quant_constructor
|
| 113 |
+
|
| 114 |
+
class FakeQuantize(FakeQuantizeBase):
|
| 115 |
+
r"""Simulate the quantize and dequantize operations in training time.
|
| 116 |
+
|
| 117 |
+
The output of this module is given by::
|
| 118 |
+
|
| 119 |
+
x_out = (
|
| 120 |
+
clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
|
| 121 |
+
) * scale
|
| 122 |
+
|
| 123 |
+
* :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization
|
| 124 |
+
operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)
|
| 125 |
+
|
| 126 |
+
* :attr:`scale` defines the scale factor used for quantization.
|
| 127 |
+
|
| 128 |
+
* :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
|
| 129 |
+
|
| 130 |
+
* :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
|
| 131 |
+
statistics can still be updated.
|
| 132 |
+
|
| 133 |
+
* :attr:`observer_enabled` controls statistics collection on tensors
|
| 134 |
+
|
| 135 |
+
* :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
|
| 136 |
+
allowable values are torch.qint8 and torch.quint8.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
|
| 140 |
+
observer (module): Module for observing statistics on input tensors and calculating scale
|
| 141 |
+
and zero-point.
|
| 142 |
+
observer_kwargs (optional): Arguments for the observer module
|
| 143 |
+
|
| 144 |
+
Attributes:
|
| 145 |
+
activation_post_process (Module): User provided module that collects statistics on the input tensor and
|
| 146 |
+
provides a method to calculate scale and zero-point.
|
| 147 |
+
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
scale: torch.Tensor
|
| 151 |
+
zero_point: torch.Tensor
|
| 152 |
+
|
| 153 |
+
def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, is_dynamic=False, **observer_kwargs):
|
| 154 |
+
super().__init__()
|
| 155 |
+
# Populate quant_min/quant_max to observer_kwargs if valid
|
| 156 |
+
if quant_min is not None and quant_max is not None:
|
| 157 |
+
assert quant_min <= quant_max, \
|
| 158 |
+
'quant_min must be less than or equal to quant_max'
|
| 159 |
+
dtype = observer_kwargs.get("dtype", torch.quint8)
|
| 160 |
+
if hasattr(observer, "p"):
|
| 161 |
+
# In case observer is _PartialWrapper, dtype can be stored in
|
| 162 |
+
# observer.p.keywords["dtype"]
|
| 163 |
+
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
|
| 164 |
+
"dtype", dtype
|
| 165 |
+
)
|
| 166 |
+
assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
|
| 167 |
+
assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
|
| 168 |
+
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
|
| 169 |
+
observer_kwargs["is_dynamic"] = is_dynamic
|
| 170 |
+
self.activation_post_process = observer(**observer_kwargs)
|
| 171 |
+
# TODO: keeping self.quant_min/max for BC; remove after a couple releases
|
| 172 |
+
# Users should use self.activation_post_process.quant_min
|
| 173 |
+
self.quant_min = self.activation_post_process.quant_min
|
| 174 |
+
self.quant_max = self.activation_post_process.quant_max
|
| 175 |
+
self.is_dynamic = self.activation_post_process.is_dynamic
|
| 176 |
+
if _is_float_qparams(self.activation_post_process.qscheme):
|
| 177 |
+
zero_point_dtype = torch.float
|
| 178 |
+
else:
|
| 179 |
+
zero_point_dtype = torch.int
|
| 180 |
+
self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
|
| 181 |
+
self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
|
| 182 |
+
self.dtype = self.activation_post_process.dtype
|
| 183 |
+
self.qscheme = self.activation_post_process.qscheme
|
| 184 |
+
self.ch_axis = self.activation_post_process.ch_axis \
|
| 185 |
+
if hasattr(self.activation_post_process, 'ch_axis') else -1
|
| 186 |
+
assert _is_per_channel(self.qscheme) or \
|
| 187 |
+
_is_per_tensor(self.qscheme), \
|
| 188 |
+
'Only per channel and per tensor quantization are supported in fake quantize' + \
|
| 189 |
+
' got qscheme: ' + str(self.qscheme)
|
| 190 |
+
self.is_per_channel = _is_per_channel(self.qscheme)
|
| 191 |
+
|
| 192 |
+
@torch.jit.export
def calculate_qparams(self):
    """Delegate qparam computation to the attached observer.

    Returns the (scale, zero_point) pair computed by
    ``self.activation_post_process`` from the statistics it has
    collected so far.
    """
    return self.activation_post_process.calculate_qparams()
|
| 195 |
+
|
| 196 |
+
def forward(self, X):
    """Observe ``X`` (if observation is enabled) and fake-quantize it (if enabled).

    Two independent toggles control behavior:
    - ``observer_enabled[0] == 1``: feed a detached copy of ``X`` to the
      observer and refresh the ``scale`` / ``zero_point`` buffers from it.
    - ``fake_quant_enabled[0] == 1``: apply per-channel or per-tensor
      fake quantization to ``X`` using the current buffers.
    Returns ``X`` unchanged when fake quant is disabled.
    """
    if self.observer_enabled[0] == 1:
        # detach so observation never contributes gradients
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        # observer may live on a different device than the registered buffers
        _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
        if self.scale.shape != _scale.shape:
            # buffers start out with shape [1]; per-channel qparams need a resize
            # before copy_ can succeed
            self.scale.resize_(_scale.shape)
            self.zero_point.resize_(_zero_point.shape)
        self.scale.copy_(_scale)
        self.zero_point.copy_(_zero_point)

    if self.fake_quant_enabled[0] == 1:
        if self.is_per_channel:
            X = torch.fake_quantize_per_channel_affine(
                X, self.scale, self.zero_point,
                self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
        else:
            X = torch.fake_quantize_per_tensor_affine(
                X, self.scale, self.zero_point,
                self.activation_post_process.quant_min, self.activation_post_process.quant_max)
    return X
|
| 217 |
+
|
| 218 |
+
@torch.jit.export
def extra_repr(self):
    """Summarize the module's quantization state for ``repr`` output."""
    template = (
        'fake_quant_enabled={}, observer_enabled={}, '
        'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, '
        'scale={}, zero_point={}'
    )
    post_process = self.activation_post_process
    return template.format(
        self.fake_quant_enabled, self.observer_enabled,
        post_process.quant_min, post_process.quant_max,
        self.dtype, self.qscheme, self.ch_axis,
        self.scale, self.zero_point,
    )
|
| 226 |
+
|
| 227 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """Serialize scale and zero_point manually.

    Scalar values cannot currently be registered as buffers, so they are
    written into the state dict here in addition to the parent's output.
    """
    super()._save_to_state_dict(destination, prefix, keep_vars)
    for name, value in (('scale', self.scale), ('zero_point', self.zero_point)):
        destination[prefix + name] = value
|
| 233 |
+
|
| 234 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Load state while tolerating size changes in the scale/zero_point buffers.

    These buffers start out with numel 0/1 and only take their final shape
    after the first forward pass, so the default loading code would raise a
    size-mismatch error. We resize them to match the incoming tensors first,
    then let the parent implementation perform the actual copy.
    """
    # Removing this function throws an error that the size of the loaded tensor does not match the original size
    # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
    local_state = ['scale', 'zero_point']
    for name in local_state:
        key = prefix + name
        if key in state_dict:
            val = state_dict[key]
            # Custom handling to allow loading scale and zero_point
            # of size N into uninitialized buffers of size 0. The
            # buffers are resized here, and the values are copied in
            # the default state_dict loading code of the parent.
            if name == 'scale':
                self.scale.resize_(val.shape)
            else:
                assert name == 'zero_point'
                self.zero_point.resize_(val.shape)
            # For torchscript module we need to update the attributes here since we do not
            # call the `_load_from_state_dict` function defined module.py
            if torch.jit.is_scripting():
                if name == 'scale':
                    self.scale.copy_(val)
                else:
                    assert name == 'zero_point'
                    self.zero_point.copy_(val)
        elif strict:
            missing_keys.append(key)
    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize in training time.

    Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
    """

    # TODO: rename observer to observer_ctr
    def __init__(self, observer):
        """Build a fake-quant whose scale/zero_point come from a FixedQParamsObserver.

        ``observer`` is a constructor/partial that must produce exactly a
        FixedQParamsObserver (checked with ``type(...) ==``, so subclasses
        are rejected as well).
        """
        super().__init__(observer=observer)
        assert type(self.activation_post_process) == FixedQParamsObserver, \
            f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
        self._observer_ctr = observer
        # qparams are fixed: mirror them from the observer once, up front
        self.scale = self.activation_post_process.scale
        self.zero_point = self.activation_post_process.zero_point
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        # qparams never change, so return the cached values directly
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Define a string representation of the object's attributes."""
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Define a fused module to observe the tensor.

    Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point and fake_quantize the tensor.
    This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
    to compute the min/max values in order to compute the scale/zero_point.
    The qscheme input in the observer is used to differentiate between symmetric/affine
    quantization scheme.

    The output of this module is given by
    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

    Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
    base class.

    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        # only moving-average observers are supported by the fused kernel
        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        # re-register toggles as long tensors; the fused ATen op reads them directly
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (scale, zero_point) computed by the underlying observer."""
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        """Summarize the fused module's quantization state for ``repr`` output."""
        return (
            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
                self.fake_quant_enabled,
                self.observer_enabled,
                self.scale,
                self.zero_point,
                self.dtype,
                self.activation_post_process.quant_min,
                self.activation_post_process.quant_max,
                self.qscheme,
                self.activation_post_process.reduce_range,
            )
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Observe and fake-quantize ``X`` in a single fused ATen kernel call."""
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )
|
| 368 |
+
|
| 369 |
+
# ---------------------------------------------------------------------------
# Pre-configured fake-quantize constructors (partials created via with_args).
# Each is a factory, not an instance; calling it builds a fresh module.
# ---------------------------------------------------------------------------

default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
"""
Default fake_quant for activations.
"""

default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
"""
Default fake_quant for weights.
Observer is memoryless since averaging_constant is 1.
"""

default_dynamic_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, is_dynamic=True,
    dtype=torch.quint8, averaging_constant=1)
"""
Default dynamic fake_quant for activations.
"""

default_fixed_qparams_range_neg1to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
"""
Default fake_quant for per-channel weights.
Observer is memoryless since averaging_constant is 1.
"""
default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                      qscheme=torch.per_channel_affine_float_qparams,
                                                      dtype=torch.quint8,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      ch_axis=0,
                                                      averaging_constant=1)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""

# 4-bit variant of the embedding fake-quant above (packed quint4x2 dtype)
default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                           qscheme=torch.per_channel_affine_float_qparams,
                                                           ch_axis=0,
                                                           dtype=torch.quint4x2,
                                                           averaging_constant=1)

default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
"""
Fake_quant for activations using a histogram..
"""


default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                       quant_min=0,
                                                                       quant_max=255,
                                                                       dtype=torch.quint8,)

"""
Fused version of `default_fake_quant`, with improved performance.
"""


default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                      quant_min=-128,
                                                                      quant_max=127,
                                                                      dtype=torch.qint8,
                                                                      qscheme=torch.per_tensor_symmetric)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""

default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                                  quant_min=-128,
                                                                                  quant_max=127,
                                                                                  dtype=torch.qint8,
                                                                                  qscheme=torch.per_channel_symmetric)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""

fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                   quant_min=-127,
                                                                                   quant_max=127,
                                                                                   dtype=torch.qint8,
                                                                                   qscheme=torch.per_tensor_symmetric,
                                                                                   eps=2 ** -12)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                            quant_min=-127,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_channel_symmetric,
                                            eps=2 ** -12)

"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""
| 489 |
+
|
| 490 |
+
def _is_fake_quant_script_module(mod):
|
| 491 |
+
"""Return true if given mod is an instance of FakeQuantize script module."""
|
| 492 |
+
if isinstance(mod, torch.jit.RecursiveScriptModule):
|
| 493 |
+
# qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
|
| 494 |
+
suffix = mod._c.qualified_name.split('.', 1)[1]
|
| 495 |
+
name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
|
| 496 |
+
return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
|
| 497 |
+
name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
|
| 498 |
+
return False
|
| 499 |
+
|
| 500 |
+
def disable_fake_quant(mod):
    """Disable fake quantization for the module.

    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_fake_quant)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.disable_fake_quant()
| 512 |
+
def enable_fake_quant(mod):
    """Enable fake quantization for the module.

    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_fake_quant)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.enable_fake_quant()
| 524 |
+
def disable_observer(mod):
    """Disable observation for this module.

    Disable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_observer)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.disable_observer()
| 536 |
+
def enable_observer(mod):
    """Enable observation for this module.

    Enable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_observer)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.enable_observer()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Set, Tuple, Callable
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
import torch
|
| 4 |
+
from torch.ao.quantization.fx._model_report.detector import (
|
| 5 |
+
DetectorBase,
|
| 6 |
+
DETECTOR_OBS_ARGS_KEY,
|
| 7 |
+
DETECTOR_OBS_TO_INSERT_KEY,
|
| 8 |
+
DETECTOR_IS_POST_OBS_KEY,
|
| 9 |
+
DETECTOR_TARGET_NODE_KEY,
|
| 10 |
+
DetectorQConfigInfo
|
| 11 |
+
)
|
| 12 |
+
from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
|
| 13 |
+
from torch.ao.quantization.fx.graph_module import GraphModule
|
| 14 |
+
from torch.ao.quantization.observer import ObserverBase
|
| 15 |
+
from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig
|
| 16 |
+
from torch.ao.quantization.fx._equalize import EqualizationQConfig
|
| 17 |
+
|
| 18 |
+
class ModelReport:
|
| 19 |
+
r"""
|
| 20 |
+
The ModelReport class aims to provide users an easy way to diagnose issues that they run into
|
| 21 |
+
with their models. The class works with all traceable GraphModules to help diagnose issues,
|
| 22 |
+
though the requirements on the type of model more-so depends on the specific report the user
|
| 23 |
+
is trying to generate. With respect to the reports, the ModelReport class is initialized with
|
| 24 |
+
a set of Detector classes, each of which generate reports on quantization configuration
|
| 25 |
+
issues a use might have.
|
| 26 |
+
|
| 27 |
+
Currently supports generating reports on:
|
| 28 |
+
- Suggestions for per-channel vs. per-tensor quantization (nn.Module)
|
| 29 |
+
- Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
|
| 30 |
+
- Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
|
| 31 |
+
- Suggestions for outlier detection for all layers (Graph Modules)
|
| 32 |
+
|
| 33 |
+
The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
|
| 34 |
+
where needed for each detector to gather the information it needs, and then after callibration, the ModelReport
|
| 35 |
+
class compiles the report generated by each Detector class into a single report to return to the user. It also
|
| 36 |
+
has the capability to remove all the observers it inserted as well.
|
| 37 |
+
|
| 38 |
+
* :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
|
| 39 |
+
|
| 40 |
+
* :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
|
| 41 |
+
Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
|
| 42 |
+
|
| 43 |
+
* :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
|
| 44 |
+
This set is generated by calling the get_detector_name() of each detector
|
| 45 |
+
|
| 46 |
+
* :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
|
| 47 |
+
The purpose of this is to keep track of what observers were inserted for each detector, so that they
|
| 48 |
+
can be removed at the end if desired
|
| 49 |
+
|
| 50 |
+
* :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
|
| 51 |
+
This is to ensure we only insert observers once with the ModelReport instance
|
| 52 |
+
|
| 53 |
+
* :attr:`_removed_observers` A boolean to track if we have removed observers already
|
| 54 |
+
The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
|
| 55 |
+
instance. This also allows the functionality where we can generate the report multiple times
|
| 56 |
+
as long as we haven't removed the observers yet.
|
| 57 |
+
|
| 58 |
+
Note:
|
| 59 |
+
This class was initially designed to work with the Fx Graph Mode workflow in mind. However,
|
| 60 |
+
full functionality is available as long as there is a traceable GraphModule that is being used.
|
| 61 |
+
One method to get a traceable GraphModule without going through the Fx workflow is to use
|
| 62 |
+
the QuantizationTracer class.
|
| 63 |
+
|
| 64 |
+
General Flow for Fx workflow:
|
| 65 |
+
1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
|
| 66 |
+
2.) Prepare your model with prepare_fx
|
| 67 |
+
3.) Call model_report.prepare_detailed_calibration to add relevant observers
|
| 68 |
+
4.) Callibrate your model with data
|
| 69 |
+
5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
|
| 70 |
+
Optional
|
| 71 |
+
6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
|
| 72 |
+
7.) To help in parsing report information and debugging, view report info as a:
|
| 73 |
+
- Table
|
| 74 |
+
- Histogram
|
| 75 |
+
- Line plot
|
| 76 |
+
8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions
|
| 77 |
+
|
| 78 |
+
Example (with QuantizationTracer):
|
| 79 |
+
>>> # xdoctest: +SKIP
|
| 80 |
+
>>> # get the necessary qconfig
|
| 81 |
+
>>> config = PrepareCustomConfig()
|
| 82 |
+
>>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False)
|
| 83 |
+
|
| 84 |
+
>>> # initialize our model and get GraphModule
|
| 85 |
+
>>> model = SomeModel()
|
| 86 |
+
>>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
|
| 87 |
+
>>> graph_module = GraphModule(model, tracer.trace(model))
|
| 88 |
+
|
| 89 |
+
>>> # get our set of detectors and ModelReport instance
|
| 90 |
+
>>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)])
|
| 91 |
+
>>> tracer_reporter = ModelReport(graph_module, tracer_detector_set)
|
| 92 |
+
|
| 93 |
+
>>> # now we insert the observers and callibrate the model
|
| 94 |
+
>>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
|
| 95 |
+
>>> for i in range(num_callibration_batches):
|
| 96 |
+
>>> example_input = get_callibration_input()
|
| 97 |
+
>>> tracer_model_with_observers(example_input)
|
| 98 |
+
|
| 99 |
+
>>> # finally we generate the reports and optionally remove the observers we inserted
|
| 100 |
+
>>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
|
| 101 |
+
|
| 102 |
+
>>> # Optional: we can generate the qconfig mapping based on the suggestions
|
| 103 |
+
>>> qconfigs = model_report.generate_qconfig_mapping()
|
| 104 |
+
|
| 105 |
+
>>> # Optional: we can generate the equalization mapping based on the suggestions
|
| 106 |
+
>>> qconfigs = model_report.generate_equalization_mapping()
|
| 107 |
+
|
| 108 |
+
>>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
|
| 109 |
+
>>> model_report_visualizer = tracer_reporter.generate_visualizer()
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
    """Initialize the report generator for ``model``.

    Args:
        model: traceable GraphModule the reports will be generated for
        desired_report_detectors: non-empty set of detector instances,
            one per desired report (detector types should be unique)

    Raises:
        ValueError: if ``desired_report_detectors`` is empty
    """
    if len(desired_report_detectors) == 0:
        raise ValueError("Should include at least 1 desired report")

    # keep track of the model we wish to generate report for
    self._model: GraphModule = model

    # keep the reports private so they can't be modified
    self._desired_report_detectors = desired_report_detectors
    self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}

    # keep a mapping of desired reports to observers of interest
    # this is to get the readings, and to remove them, can create a large set
    # this set can then be used to traverse the graph and remove added observers
    self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}

    # initialize each report to have empty set of observers of interest
    for desired_report in self._desired_detector_names:
        self._detector_name_to_observer_fqns[desired_report] = set()

    # flags to ensure that we can only prepare and remove observers once
    self._prepared_flag = False
    self._removed_observers = False

    # store the reports that we generated for visualization purposes
    # initially empty since no reports generated
    self._generated_reports: Dict[str, Dict] = {}
|
| 141 |
+
|
| 142 |
+
def get_desired_reports_names(self) -> Set[str]:
    """ Returns a copy of the desired reports for viewing """
    # hand back a fresh set so callers cannot mutate internal state
    return set(self._desired_detector_names)
|
| 145 |
+
|
| 146 |
+
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
    """ Returns a copy of the observers of interest for viewing """
    # shallow copy: mapping is fresh, the contained sets are shared
    return dict(self._detector_name_to_observer_fqns)
|
| 149 |
+
|
| 150 |
+
def prepare_detailed_calibration(self) -> GraphModule:
    r"""
    Takes in a graph model and inserts the following observers:
    - ModelReportObserver

    Each observer is inserted based on the desired_reports into the relevant locations

    Right now, each report in self._desired_detector_names has independent insertions
    However, if a module already has a Observer of the same type, the insertion will not occur
    This is because all of the same type of Observer collect same information, so redundant

    Returns the same GraphModule with the observers inserted

    Raises:
        ValueError: if this instance has already prepared the model once
    """

    # if already prepared once, cannot prepare again
    if self._prepared_flag:
        raise ValueError("Already ran preparing detailed callibration. Run the report generation next after callibration.")

    # loop through each detector, find where placements should be, and keep track
    insert_observers_fqns: Dict[str, Any] = {}

    for detector in self._desired_report_detectors:
        # determine observer points for each detector
        obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
        # map each insert point to the observer to use
        # NOTE: later detectors overwrite earlier entries for the same fqn
        insert_observers_fqns.update(obs_fqn_to_info)
        # update the set of observers this report cares about
        self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())

    # now insert all the observers at their desired locations
    for observer_fqn in insert_observers_fqns:
        target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
        insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
        insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
        observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
        self._insert_observer_around_module(
            observer_fqn, target_node, insert_obs, observer_args, insert_post
        )

    self._prepared_flag = True

    return self._model
|
| 192 |
+
|
| 193 |
+
def _insert_observer_around_module(
|
| 194 |
+
self,
|
| 195 |
+
obs_fqn: str,
|
| 196 |
+
target_node: torch.fx.node.Node,
|
| 197 |
+
obs_to_insert: ObserverBase,
|
| 198 |
+
observer_args: Tuple,
|
| 199 |
+
insert_post: bool
|
| 200 |
+
):
|
| 201 |
+
r"""
|
| 202 |
+
Helper function that inserts the observer into both the graph structure and the module of the model
|
| 203 |
+
|
| 204 |
+
Args
|
| 205 |
+
node_fqn (str): The fully qualified name of the observer we want to insert
|
| 206 |
+
target_node (torch.fx.node.Node): The node in model we are inserting observers around
|
| 207 |
+
obs_to_insert (ObserverBase): The observer we are inserting around target_node
|
| 208 |
+
observer_args (Tuple): The arguments we want to pass into the observer
|
| 209 |
+
insert_post (bool): whether this is meant to be a post observer for this node
|
| 210 |
+
"""
|
| 211 |
+
# if we are inserting post, then our target node is the next node
|
| 212 |
+
if insert_post:
|
| 213 |
+
target_node = target_node.next
|
| 214 |
+
|
| 215 |
+
with self._model.graph.inserting_before(target_node):
|
| 216 |
+
self._model.add_submodule(obs_fqn, obs_to_insert)
|
| 217 |
+
self._model.graph.create_node(op="call_module", target=obs_fqn, args=observer_args)
|
| 218 |
+
|
| 219 |
+
# recompile model after inserts are made
|
| 220 |
+
self._model.recompile()
|
| 221 |
+
|
| 222 |
+
def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
|
| 223 |
+
r"""
|
| 224 |
+
Takes in a node fqn and returns the node based on the fqn
|
| 225 |
+
|
| 226 |
+
Args
|
| 227 |
+
node_fqn (str): The fully qualified name of the node we want to find in model
|
| 228 |
+
|
| 229 |
+
Returns the Node object of the given node_fqn otherwise returns None
|
| 230 |
+
"""
|
| 231 |
+
node_to_return = None
|
| 232 |
+
for node in self._model.graph.nodes:
|
| 233 |
+
# if the target matches the fqn, it's the node we are looking for
|
| 234 |
+
if node.target == node_fqn:
|
| 235 |
+
node_to_return = node
|
| 236 |
+
break
|
| 237 |
+
|
| 238 |
+
if node_to_return is None:
|
| 239 |
+
raise ValueError("The node_fqn is was not found within the module.")
|
| 240 |
+
|
| 241 |
+
# assert for MyPy
|
| 242 |
+
assert isinstance(node_to_return, torch.fx.node.Node)
|
| 243 |
+
|
| 244 |
+
return node_to_return
|
| 245 |
+
|
| 246 |
+
def generate_model_report(
|
| 247 |
+
self, remove_inserted_observers: bool
|
| 248 |
+
) -> Dict[str, Tuple[str, Dict]]:
|
| 249 |
+
r"""
|
| 250 |
+
Generates all the requested reports.
|
| 251 |
+
|
| 252 |
+
Note:
|
| 253 |
+
You should have callibrated the model with relevant data before calling this
|
| 254 |
+
|
| 255 |
+
The reports generated are specified by the desired_reports specified in desired_reports
|
| 256 |
+
|
| 257 |
+
Can optionally remove all the observers inserted by the ModelReport instance
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
|
| 261 |
+
|
| 262 |
+
Returns a mapping of each desired report name to a tuple with:
|
| 263 |
+
The textual summary of that report information
|
| 264 |
+
A dictionary containing relevant statistics or information for that report
|
| 265 |
+
|
| 266 |
+
Note:
|
| 267 |
+
Throws exception if we try to generate report on model we already removed observers from
|
| 268 |
+
Throws exception if we try to generate report without preparing for callibration
|
| 269 |
+
"""
|
| 270 |
+
# if we haven't prepped model for callibration, then we shouldn't generate report yet
|
| 271 |
+
if not self._prepared_flag:
|
| 272 |
+
raise Exception("Cannot generate report without preparing model for callibration")
|
| 273 |
+
|
| 274 |
+
# if we already removed the observers, we cannot generate report
|
| 275 |
+
if self._removed_observers:
|
| 276 |
+
raise Exception("Cannot generate report on model you already removed observers from")
|
| 277 |
+
|
| 278 |
+
# keep track of all the reports of interest and their outputs
|
| 279 |
+
reports_of_interest = {}
|
| 280 |
+
|
| 281 |
+
for detector in self._desired_report_detectors:
|
| 282 |
+
# generate the individual report for the detector
|
| 283 |
+
report_output = detector.generate_detector_report(self._model)
|
| 284 |
+
reports_of_interest[detector.get_detector_name()] = report_output
|
| 285 |
+
|
| 286 |
+
# if user wishes to remove inserted observers, go ahead and remove
|
| 287 |
+
if remove_inserted_observers:
|
| 288 |
+
self._removed_observers = True
|
| 289 |
+
# get the set of all Observers inserted by this instance of ModelReport
|
| 290 |
+
all_observers_of_interest: Set[str] = set()
|
| 291 |
+
for desired_report in self._detector_name_to_observer_fqns:
|
| 292 |
+
observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
|
| 293 |
+
all_observers_of_interest.update(observers_of_interest)
|
| 294 |
+
|
| 295 |
+
# go through all_observers_of_interest and remove them from the graph and model
|
| 296 |
+
for observer_fqn in all_observers_of_interest:
|
| 297 |
+
# remove the observer from the model
|
| 298 |
+
self._model.delete_submodule(observer_fqn)
|
| 299 |
+
|
| 300 |
+
# remove the observer from the graph structure
|
| 301 |
+
node_obj = self._get_node_from_fqn(observer_fqn)
|
| 302 |
+
|
| 303 |
+
if node_obj:
|
| 304 |
+
self._model.graph.erase_node(node_obj)
|
| 305 |
+
else:
|
| 306 |
+
raise ValueError("Node no longer exists in GraphModule structure")
|
| 307 |
+
|
| 308 |
+
# remember to recompile the model
|
| 309 |
+
self._model.recompile()
|
| 310 |
+
|
| 311 |
+
# save the generated reports for visualization purposes
|
| 312 |
+
saved_reports: Dict[str, Dict] = {
|
| 313 |
+
report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items()
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
self._generated_reports = saved_reports
|
| 317 |
+
|
| 318 |
+
# return the reports of interest
|
| 319 |
+
return reports_of_interest
|
| 320 |
+
|
| 321 |
+
def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
|
| 322 |
+
r"""
|
| 323 |
+
Takes in two dictionaries and ensures that any common keys between the two have the same
|
| 324 |
+
values.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
info_dict_a (Dict): First dictionary we wish to compare
|
| 328 |
+
info_dict_b (Dict): Second dictionary we wish to compare
|
| 329 |
+
|
| 330 |
+
Returns True if all shared keys have same values, false otherwise
|
| 331 |
+
"""
|
| 332 |
+
# get the set of keys for both
|
| 333 |
+
dict_a_keys: Set = set(info_dict_a.keys())
|
| 334 |
+
dict_b_keys: Set = set(info_dict_b.keys())
|
| 335 |
+
|
| 336 |
+
# get the insersection keys and check if same value for both dicts
|
| 337 |
+
intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
|
| 338 |
+
|
| 339 |
+
for key in intersecting_keys:
|
| 340 |
+
dict_a_val = info_dict_a[key]
|
| 341 |
+
dict_b_val = info_dict_b[key]
|
| 342 |
+
|
| 343 |
+
# if it's a tensor we have to handle separately
|
| 344 |
+
if type(dict_a_val) == torch.Tensor:
|
| 345 |
+
# if dict_b_val not tensor, automatically false
|
| 346 |
+
if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0:
|
| 347 |
+
return False
|
| 348 |
+
else:
|
| 349 |
+
# for non-tensor vals
|
| 350 |
+
if dict_a_val != dict_b_val:
|
| 351 |
+
return False
|
| 352 |
+
|
| 353 |
+
# if no non matching shared keys found, return true
|
| 354 |
+
return True
|
| 355 |
+
|
| 356 |
+
def _reformat_reports_for_visualizer(self) -> OrderedDict:
|
| 357 |
+
r"""
|
| 358 |
+
Takes the generated reports and reformats them into the format that is desired by the
|
| 359 |
+
ModelReportVisualizer
|
| 360 |
+
|
| 361 |
+
Returns an OrderedDict mapping module_fqns to their features
|
| 362 |
+
"""
|
| 363 |
+
# we want to reorder and reformat the information so it is ordered in terms of order
|
| 364 |
+
# found in the model
|
| 365 |
+
|
| 366 |
+
# first create new dict with all modules as keys and features under respective module
|
| 367 |
+
module_fqns_to_features: Dict[str, Dict] = {}
|
| 368 |
+
|
| 369 |
+
for report_name in self._generated_reports:
|
| 370 |
+
# get mod -> feature dict and go through
|
| 371 |
+
module_info = self._generated_reports[report_name]
|
| 372 |
+
|
| 373 |
+
for module_fqn in module_info:
|
| 374 |
+
# check if already in our accumulation dict
|
| 375 |
+
if module_fqn in module_fqns_to_features:
|
| 376 |
+
# we merge all the features together
|
| 377 |
+
new_info: Dict = module_info[module_fqn]
|
| 378 |
+
present_info: Dict = module_fqns_to_features[module_fqn]
|
| 379 |
+
|
| 380 |
+
# merge them together into the new unioned dict
|
| 381 |
+
# same features keys -> same info, so okay if override
|
| 382 |
+
|
| 383 |
+
# do safety check to make sure shared keys have same info
|
| 384 |
+
if self._is_same_info_for_same_key(new_info, present_info):
|
| 385 |
+
module_fqns_to_features[module_fqn] = {**new_info, **present_info}
|
| 386 |
+
else:
|
| 387 |
+
error_str = "You have the same key with different values across detectors. "
|
| 388 |
+
error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
|
| 389 |
+
raise ValueError(error_str)
|
| 390 |
+
else:
|
| 391 |
+
# we just set it
|
| 392 |
+
module_fqns_to_features[module_fqn] = module_info[module_fqn]
|
| 393 |
+
|
| 394 |
+
# our ordered dict so that modules can be ordered in order of how they appear in model
|
| 395 |
+
features_by_module: OrderedDict[str, Dict] = OrderedDict()
|
| 396 |
+
|
| 397 |
+
# we loop through modules in graph in order
|
| 398 |
+
for fqn, module in self._model.named_modules():
|
| 399 |
+
# find that fqn in fqns_to_features
|
| 400 |
+
if fqn in module_fqns_to_features:
|
| 401 |
+
# add it to our ordered dict
|
| 402 |
+
features_by_module[fqn] = module_fqns_to_features[fqn]
|
| 403 |
+
|
| 404 |
+
# return the ordered dict of info we created
|
| 405 |
+
return features_by_module
|
| 406 |
+
|
| 407 |
+
def generate_visualizer(self) -> ModelReportVisualizer:
|
| 408 |
+
r"""
|
| 409 |
+
Generates a ModelReportVisualizer instance using the reports generated
|
| 410 |
+
by the generate_model_report() method.
|
| 411 |
+
|
| 412 |
+
Returns the generated ModelReportVisualizer instance initialized
|
| 413 |
+
|
| 414 |
+
Note:
|
| 415 |
+
Throws exception if attempt to get visualizers without generating report
|
| 416 |
+
"""
|
| 417 |
+
# check if user has generated reports at least once
|
| 418 |
+
if len(self._generated_reports) == 0:
|
| 419 |
+
raise Exception("Unable to generate visualizers without first generating reports")
|
| 420 |
+
|
| 421 |
+
# get the ordered dict mapping modules to their full set of collected features / stats
|
| 422 |
+
module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
|
| 423 |
+
|
| 424 |
+
# create and return ModelReportVisualizer instance
|
| 425 |
+
visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features)
|
| 426 |
+
|
| 427 |
+
return visualizer
|
| 428 |
+
|
| 429 |
+
def _generate_qconfig_mapping_helper(
|
| 430 |
+
self,
|
| 431 |
+
detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo],
|
| 432 |
+
generation_function: Callable
|
| 433 |
+
) -> QConfigMapping:
|
| 434 |
+
r"""
|
| 435 |
+
This helper takes in the compiled detector qconfig info that
|
| 436 |
+
has been compiled together and merges it into a QConfigMapping
|
| 437 |
+
"""
|
| 438 |
+
# keep track of the qconfigmapping
|
| 439 |
+
qconfig_mapping = QConfigMapping()
|
| 440 |
+
|
| 441 |
+
# loop through each module / fqn and attempt to create QConfigMapping
|
| 442 |
+
for fqn, module in self._model.named_modules():
|
| 443 |
+
# if we have a qconfig info for this module
|
| 444 |
+
if fqn in detector_qconfig_info_combined:
|
| 445 |
+
qconfig_info_compiled = detector_qconfig_info_combined[fqn]
|
| 446 |
+
|
| 447 |
+
# now generate the qconfig and add it to the mapping
|
| 448 |
+
generated_qconfig = generation_function(qconfig_info_compiled, module)
|
| 449 |
+
|
| 450 |
+
# add to our config
|
| 451 |
+
qconfig_mapping.set_module_name(fqn, generated_qconfig)
|
| 452 |
+
|
| 453 |
+
# return compiled mapping
|
| 454 |
+
return qconfig_mapping
|
| 455 |
+
|
| 456 |
+
def _update_detector_quantizaiton_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
|
| 457 |
+
r"""
|
| 458 |
+
Takes in the old and new information and updates the combined information.
|
| 459 |
+
|
| 460 |
+
Args:
|
| 461 |
+
combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
|
| 462 |
+
new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
|
| 463 |
+
into it
|
| 464 |
+
"""
|
| 465 |
+
combined_info.is_activation_dynamic = combined_info.is_activation_dynamic or new_info.is_activation_dynamic
|
| 466 |
+
combined_info.is_weight_per_channel = combined_info.is_weight_per_channel or new_info.is_weight_per_channel
|
| 467 |
+
|
| 468 |
+
def _update_detector_equalization_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
|
| 469 |
+
r"""
|
| 470 |
+
Takes in the old and new information and updates the combined information.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
|
| 474 |
+
new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
|
| 475 |
+
into it
|
| 476 |
+
"""
|
| 477 |
+
is_equalization_recommended = combined_info.is_equalization_recommended or new_info.is_equalization_recommended
|
| 478 |
+
combined_info.is_equalization_recommended = is_equalization_recommended
|
| 479 |
+
|
| 480 |
+
def _generate_module_fqn_to_detector_info_mapping(
|
| 481 |
+
self,
|
| 482 |
+
update_qconfig_info_function: Callable
|
| 483 |
+
) -> Dict[str, DetectorQConfigInfo]:
|
| 484 |
+
r"""
|
| 485 |
+
Generates a QConfigMapping based on the suggestions of the
|
| 486 |
+
ModelReport API. The generated mapping encompasses all the
|
| 487 |
+
different types of feedback from the different detectors
|
| 488 |
+
all into one place.
|
| 489 |
+
|
| 490 |
+
These configs are based on the suggestions provided by the ModelReport API
|
| 491 |
+
and can only be generated once the reports have been generated.
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo
|
| 495 |
+
and updates the one that is being compiled
|
| 496 |
+
|
| 497 |
+
Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
|
| 498 |
+
|
| 499 |
+
Note:
|
| 500 |
+
Throws exception if we try to generate mapping on model we already removed observers from
|
| 501 |
+
Throws exception if we try to generate mapping without preparing for callibration
|
| 502 |
+
"""
|
| 503 |
+
# if we haven't prepped model for callibration, then we shouldn't generate mapping yet
|
| 504 |
+
if not self._prepared_flag:
|
| 505 |
+
raise Exception("Cannot generate report without preparing model for callibration")
|
| 506 |
+
|
| 507 |
+
# if we already removed the observers, we cannot mapping
|
| 508 |
+
if self._removed_observers:
|
| 509 |
+
raise Exception("Cannot generate report on model you already removed observers from")
|
| 510 |
+
|
| 511 |
+
# keep track of qconfig info for each module across detectors
|
| 512 |
+
detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {}
|
| 513 |
+
|
| 514 |
+
for detector in self._desired_report_detectors:
|
| 515 |
+
# get the info from the detector
|
| 516 |
+
detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(self._model)
|
| 517 |
+
|
| 518 |
+
# we go through the modules
|
| 519 |
+
for module_fqn in detector_info:
|
| 520 |
+
# see if we already have info on it
|
| 521 |
+
if module_fqn in detector_qconfig_info_combined:
|
| 522 |
+
# we combine the current options with what is there
|
| 523 |
+
current_options = detector_qconfig_info_combined[module_fqn]
|
| 524 |
+
detector_options = detector_info[module_fqn]
|
| 525 |
+
|
| 526 |
+
update_qconfig_info_function(current_options, detector_options)
|
| 527 |
+
else:
|
| 528 |
+
# we just use this for now
|
| 529 |
+
detector_qconfig_info_combined[module_fqn] = detector_info[module_fqn]
|
| 530 |
+
|
| 531 |
+
return detector_qconfig_info_combined
|
| 532 |
+
|
| 533 |
+
def generate_qconfig_mapping(self) -> QConfigMapping:
|
| 534 |
+
r"""
|
| 535 |
+
Generates a QConfigMapping based on the suggestions of the
|
| 536 |
+
ModelReport API. The generated mapping encompasses all the
|
| 537 |
+
different types of feedback from the different detectors
|
| 538 |
+
all into one place.
|
| 539 |
+
|
| 540 |
+
These configs are based on the suggestions provided by the ModelReport API
|
| 541 |
+
and can only be generated once the reports have been generated.
|
| 542 |
+
|
| 543 |
+
Returns a QConfigMapping for the quantization configuration
|
| 544 |
+
|
| 545 |
+
Note:
|
| 546 |
+
Throws exception if we try to generate mapping on model we already removed observers from
|
| 547 |
+
Throws exception if we try to generate mapping without preparing for callibration
|
| 548 |
+
"""
|
| 549 |
+
# get the mapping info
|
| 550 |
+
detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
|
| 551 |
+
self._update_detector_quantizaiton_qconfig_info
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
# we will do a bit of processing and remove fqns that don't have input weight recommended
|
| 555 |
+
|
| 556 |
+
# now we generate the QConfig for each of the options
|
| 557 |
+
mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
|
| 558 |
+
detector_qconfig_info_combined,
|
| 559 |
+
self._quantization_config_generator
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
# return the generated mapping
|
| 563 |
+
return mapping
|
| 564 |
+
|
| 565 |
+
def _quantization_config_generator(self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module) -> QConfig:
|
| 566 |
+
r"""
|
| 567 |
+
Returns the quantization configuration generated by the DetectorQConfigInfo object
|
| 568 |
+
"""
|
| 569 |
+
return detector_qconfig_info.generate_quantization_qconfig(module)
|
| 570 |
+
|
| 571 |
+
def _equalization_config_generator(
|
| 572 |
+
self,
|
| 573 |
+
detector_qconfig_info: DetectorQConfigInfo,
|
| 574 |
+
module: torch.nn.Module
|
| 575 |
+
) -> EqualizationQConfig:
|
| 576 |
+
r"""
|
| 577 |
+
We ignore the module argument here, and only focus on thedetector_qconfig_info
|
| 578 |
+
|
| 579 |
+
Returns the equalization configuration generated by the DetectorQConfigInfo object
|
| 580 |
+
"""
|
| 581 |
+
return detector_qconfig_info.generate_equalization_qconfig()
|
| 582 |
+
|
| 583 |
+
def generate_equalization_mapping(self) -> QConfigMapping:
|
| 584 |
+
r"""
|
| 585 |
+
Generates a QConfigMapping based on the suggestions of the
|
| 586 |
+
ModelReport API for equalization. The generated mapping encompasses all the
|
| 587 |
+
different types of feedback from the input-weight equalization detector.
|
| 588 |
+
|
| 589 |
+
These configs are based on the suggestions provided by the ModelReport API
|
| 590 |
+
and can only be generated once the reports have been generated.
|
| 591 |
+
|
| 592 |
+
Returns a QConfigMapping for the equalization configuration
|
| 593 |
+
"""
|
| 594 |
+
# get the mapping info
|
| 595 |
+
detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
|
| 596 |
+
self._update_detector_equalization_qconfig_info
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
# now we generate the QConfig for each of the options
|
| 600 |
+
mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
|
| 601 |
+
detector_qconfig_info_combined,
|
| 602 |
+
self._equalization_config_generator
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# return the generated mapping
|
| 606 |
+
return mapping
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py
ADDED
|
@@ -0,0 +1,1880 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import torch
|
| 3 |
+
import warnings
|
| 4 |
+
from torch.fx import (
|
| 5 |
+
GraphModule,
|
| 6 |
+
)
|
| 7 |
+
from torch.fx.graph import (
|
| 8 |
+
Graph,
|
| 9 |
+
Node,
|
| 10 |
+
)
|
| 11 |
+
from torch.fx.node import Argument
|
| 12 |
+
|
| 13 |
+
from ..quantize import (
|
| 14 |
+
propagate_qconfig_,
|
| 15 |
+
)
|
| 16 |
+
from ..observer import (
|
| 17 |
+
_is_activation_post_process,
|
| 18 |
+
_PartialWrapper,
|
| 19 |
+
)
|
| 20 |
+
from ..qconfig import (
|
| 21 |
+
_is_reuse_input_qconfig,
|
| 22 |
+
QConfigAny,
|
| 23 |
+
)
|
| 24 |
+
from ..qconfig_mapping import (
|
| 25 |
+
QConfigMapping,
|
| 26 |
+
)
|
| 27 |
+
from .qconfig_mapping_utils import (
|
| 28 |
+
_generate_node_name_to_qconfig,
|
| 29 |
+
_update_qconfig_for_fusion,
|
| 30 |
+
_get_flattened_qconfig_dict,
|
| 31 |
+
_update_qconfig_for_qat,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from .quantize_handler import (
|
| 35 |
+
_default_root_node_getter,
|
| 36 |
+
_get_pattern_to_quantize_handlers,
|
| 37 |
+
QuantizeHandler,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
from torch.ao.quantization import (
|
| 41 |
+
ObserverBase,
|
| 42 |
+
FixedQParamsObserver,
|
| 43 |
+
FixedQParamsFakeQuantize,
|
| 44 |
+
_DerivedObserverOrFakeQuantize,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
from torch.ao.quantization.utils import (
|
| 48 |
+
Pattern,
|
| 49 |
+
NodePattern,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
from ._equalize import (
|
| 53 |
+
is_equalization_observer,
|
| 54 |
+
node_supports_equalization,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
from .pattern_utils import (
|
| 58 |
+
_sorted_patterns_dict,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
from .match_utils import (
|
| 62 |
+
_MatchResultWithQConfig,
|
| 63 |
+
_find_matches,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
from .utils import (
|
| 67 |
+
_insert_dequant_stubs_for_custom_module_lstm_output,
|
| 68 |
+
_is_custom_module_lstm,
|
| 69 |
+
_maybe_get_custom_module_lstm_from_node_arg,
|
| 70 |
+
_qconfig_satisfies_dtype_config_constraints,
|
| 71 |
+
get_custom_module_class_keys,
|
| 72 |
+
all_node_args_have_no_tensors,
|
| 73 |
+
assert_and_get_unique_device,
|
| 74 |
+
get_non_observable_arg_indexes_and_types,
|
| 75 |
+
get_new_attr_name_with_prefix,
|
| 76 |
+
node_arg_is_weight,
|
| 77 |
+
node_arg_is_bias,
|
| 78 |
+
NON_QUANTIZABLE_WEIGHT_OPS,
|
| 79 |
+
ObservedGraphModuleAttrs,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
from torch.ao.quantization import (
|
| 83 |
+
PlaceholderObserver
|
| 84 |
+
)
|
| 85 |
+
from torch.ao.quantization.quantize import (
|
| 86 |
+
convert
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
from ..utils import (
|
| 90 |
+
_parent_name,
|
| 91 |
+
get_qconfig_dtypes,
|
| 92 |
+
get_swapped_custom_module_class,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
from ..backend_config.utils import (
|
| 96 |
+
get_pattern_to_dtype_configs,
|
| 97 |
+
get_module_to_qat_module,
|
| 98 |
+
get_fusion_pattern_to_root_node_getter,
|
| 99 |
+
)
|
| 100 |
+
from ..backend_config import (
|
| 101 |
+
BackendConfig,
|
| 102 |
+
DTypeConfig,
|
| 103 |
+
get_native_backend_config,
|
| 104 |
+
)
|
| 105 |
+
from .custom_config import (
|
| 106 |
+
PrepareCustomConfig,
|
| 107 |
+
StandaloneModuleConfigEntry,
|
| 108 |
+
)
|
| 109 |
+
from torch.ao.quantization.quantizer import (
|
| 110 |
+
EdgeOrNode,
|
| 111 |
+
QuantizationSpec,
|
| 112 |
+
QuantizationSpecBase,
|
| 113 |
+
FixedQParamsQuantizationSpec,
|
| 114 |
+
SharedQuantizationSpec,
|
| 115 |
+
DerivedQuantizationSpec,
|
| 116 |
+
)
|
| 117 |
+
from torch.ao.quantization import ObserverOrFakeQuantize
|
| 118 |
+
|
| 119 |
+
from torch._subclasses import FakeTensor
|
| 120 |
+
|
| 121 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
| 122 |
+
from dataclasses import asdict
|
| 123 |
+
|
| 124 |
+
# Public surface of this module; every other name is a private helper.
__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]


# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
# dtypes for which an observer/fake-quant may be inserted
_OBS_DTYPE_LIST = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.float16,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32
]

# fallback observer constructor used when a node carries no explicit dtype info
_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

# note: the following default target dtype info dicts are temporary,
# should be moved to the new programmable API class soon
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation
}

_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation
}
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _get_observer_kwargs(quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]):
|
| 160 |
+
kwargs_dict = asdict(quant_spec)
|
| 161 |
+
return copy.deepcopy(kwargs_dict)
|
| 162 |
+
|
| 163 |
+
def _get_qspec_for_arg(
    arg: Node,
    input_qspec_map: Dict[Node, QuantizationSpecBase],
    named_modules: Dict[str, torch.nn.Module]
) -> Optional[QuantizationSpecBase]:
    """Look up the quantization spec configured for `arg`, skipping over any
    observer/fake-quant nodes already inserted between the producer and `arg`.

    Returns None when no spec is configured for the (de-observed) argument.
    """
    current = arg
    while _is_activation_post_process_node(current, named_modules):
        current = current.args[0]  # type: ignore[assignment]
    return input_qspec_map.get(current)
| 171 |
+
|
| 172 |
+
def _create_obs_or_fq_from_qspec(
|
| 173 |
+
quantization_spec: Optional[QuantizationSpecBase],
|
| 174 |
+
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
|
| 175 |
+
is_qat: bool,
|
| 176 |
+
):
|
| 177 |
+
""" Create observer or fake quantize objects based on quantization spec
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
quantization_spec: used to store parameters to create the observer or fake quantizer
|
| 181 |
+
obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant
|
| 182 |
+
instance, it may be reused for different edge/output depending on configuration
|
| 183 |
+
"""
|
| 184 |
+
if quantization_spec is None:
|
| 185 |
+
return None
|
| 186 |
+
if isinstance(quantization_spec, SharedQuantizationSpec):
|
| 187 |
+
edge_or_node = quantization_spec.edge_or_node
|
| 188 |
+
assert edge_or_node in obs_or_fq_map, \
|
| 189 |
+
"please make sure only refer to edge or node that has " \
|
| 190 |
+
f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
|
| 191 |
+
return obs_or_fq_map[edge_or_node]
|
| 192 |
+
elif isinstance(quantization_spec, DerivedQuantizationSpec):
|
| 193 |
+
# can't use asdict, so not calling get_observer_kwargs here
|
| 194 |
+
kwargs = {
|
| 195 |
+
"dtype": quantization_spec.dtype,
|
| 196 |
+
"derive_qparams_fn": quantization_spec.derive_qparams_fn,
|
| 197 |
+
"quant_min": quantization_spec.quant_min,
|
| 198 |
+
"quant_max": quantization_spec.quant_max,
|
| 199 |
+
"qscheme": quantization_spec.qscheme,
|
| 200 |
+
"ch_axis": quantization_spec.ch_axis,
|
| 201 |
+
}
|
| 202 |
+
edge_or_nodes = quantization_spec.derived_from
|
| 203 |
+
obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
|
| 204 |
+
kwargs["obs_or_fqs"] = obs_or_fqs
|
| 205 |
+
return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
|
| 206 |
+
elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
|
| 207 |
+
kwargs = _get_observer_kwargs(quantization_spec)
|
| 208 |
+
observer_ctr = FixedQParamsObserver.with_args(**kwargs)
|
| 209 |
+
if is_qat:
|
| 210 |
+
return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)
|
| 211 |
+
else:
|
| 212 |
+
return observer_ctr()
|
| 213 |
+
|
| 214 |
+
assert isinstance(quantization_spec, QuantizationSpec)
|
| 215 |
+
observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
|
| 216 |
+
kwargs = _get_observer_kwargs(quantization_spec)
|
| 217 |
+
kwargs.pop("observer_or_fake_quant_ctr")
|
| 218 |
+
# we will remove is_dynamic from QuantizationSpec because
|
| 219 |
+
# it seems that dynamic range quantization
|
| 220 |
+
obs_or_fq_class = observer_or_fake_quant_ctr
|
| 221 |
+
if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
|
| 222 |
+
obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment]
|
| 223 |
+
if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr]
|
| 224 |
+
kwargs.pop("ch_axis")
|
| 225 |
+
return observer_or_fake_quant_ctr.with_args(**kwargs)()
|
| 226 |
+
|
| 227 |
+
def _needs_obs_or_fq(
|
| 228 |
+
prev_output_dtype: Any,
|
| 229 |
+
prev_output_is_dynamic: bool,
|
| 230 |
+
cur_target_dtype: Any,
|
| 231 |
+
cur_target_is_dynamic: bool,
|
| 232 |
+
reuse_input_obs_or_fq: bool,
|
| 233 |
+
is_zeroth_arg: bool = False) -> bool:
|
| 234 |
+
"""
|
| 235 |
+
note: we will treat "not specified" as torch.float for now
|
| 236 |
+
utility function that checks if we should insert an observer or fake quant node
|
| 237 |
+
base on the requested dtype for the nodes from user
|
| 238 |
+
|
| 239 |
+
is_zeroth_arg: we only dynamically quantize the first arg of the node right now
|
| 240 |
+
this should be removed when we enable configuring dynamic quantization
|
| 241 |
+
for a specific argument, this can be removed if we deprecate fx graph mode
|
| 242 |
+
quantization
|
| 243 |
+
|
| 244 |
+
"""
|
| 245 |
+
|
| 246 |
+
# need to insert placeholder observer for dynamic quantization so that it can
|
| 247 |
+
# be converted to choose_qparams -> q -> dq in convert step
|
| 248 |
+
if cur_target_is_dynamic:
|
| 249 |
+
assert cur_target_dtype in _OBS_DTYPE_LIST, \
|
| 250 |
+
f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
|
| 251 |
+
assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
|
| 252 |
+
return is_zeroth_arg
|
| 253 |
+
if reuse_input_obs_or_fq:
|
| 254 |
+
return False
|
| 255 |
+
# non dynamic quantization
|
| 256 |
+
if cur_target_dtype in _OBS_DTYPE_LIST:
|
| 257 |
+
return prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] and cur_target_dtype != prev_output_dtype
|
| 258 |
+
|
| 259 |
+
# lots of error checking are skipped here for now
|
| 260 |
+
return False
|
| 261 |
+
|
| 262 |
+
def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
|
| 263 |
+
return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
|
| 264 |
+
_is_activation_post_process(named_modules[str(node.target)])
|
| 265 |
+
|
| 266 |
+
def _get_dtype_and_is_dynamic(obs_or_fq: Optional[ObserverOrFakeQuantize]) -> Tuple[Optional[torch.dtype], bool]:
|
| 267 |
+
""" Given a constructor for observer or fake quant module, returns
|
| 268 |
+
a Tuple of dtype and is_dynamic
|
| 269 |
+
"""
|
| 270 |
+
# TODO: instead of instantiating the instance, we can use inspect to get the default args
|
| 271 |
+
if obs_or_fq is None:
|
| 272 |
+
return None, False
|
| 273 |
+
else:
|
| 274 |
+
return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value]
|
| 275 |
+
|
| 276 |
+
def _is_input_arg_dtype_supported_by_backend(
    arg: Argument,
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the configured qconfig for the argument
    is supported by the backend or not
    """
    if isinstance(arg, (list, tuple)):
        # a container argument is supported only if every element is supported
        return all(_is_input_arg_dtype_supported_by_backend(
            a, node, qconfig,
            dtype_config, backend_config) for a in arg)
    if not isinstance(arg, Node):
        # non-Node arguments (python scalars, strings, ...) impose no constraint
        return True
    # TODO: support check for standalone module
    is_weight = node_arg_is_weight(node, arg)
    is_bias = node_arg_is_bias(node, arg)
    is_activation = not is_weight and not is_bias
    if is_activation:
        # instantiate the configured observer ctr (if any) to read its dtype/is_dynamic
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        input_act_obs_or_fq = input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq)
        # TODO(future PR): remove the cast to bool below after figuring
        # out why backend_config has is_dynamic set to None in some cases.
        return (dtype_config.input_dtype is None) or (
            dtype_config.input_dtype == qconfig_dtype and
            bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
            _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
        )
    elif is_weight:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
        backend_config_weight_dtype = dtype_config.weight_dtype
        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
        # a backend with no weight dtype requirement accepts anything
        return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
    else:  # bias
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
        backend_config_bias_dtype = dtype_config.bias_dtype
        return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype
| 324 |
+
|
| 325 |
+
def _is_output_dtype_supported_by_backend(
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
) -> bool:
    """ Check if the configured qconfig for the output
    is supported by the backend or not

    A backend with no output dtype requirement (`output_dtype is None`)
    accepts any configuration.
    """
    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
    backend_config_output_dtype = dtype_config.output_dtype
    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
    # from input activation check can be reused here
    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    # (removed a dead `qconfig_output_dtype = None` assignment that was
    # unconditionally overwritten by the line below)
    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
    # linear op which has fp32 output dtype, this should be removed if we generalize
    # the structure of qconfig in the future
    if qconfig_output_is_dynamic:
        qconfig_output_dtype = torch.float32
    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
        qconfig, dtype_config.output_dtype_with_constraints)
    return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
| 351 |
+
|
| 352 |
+
def _is_observer_in_same_graph(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat,
):
    """ Check if observer in same graph
    when the node output is not fp32 and input is 'placeholder'
    the input is assumed to be quantized, so it is observed
    in a different place rather than not observed.
    """
    output_dtype = _get_arg_target_dtype_as_output(node, named_modules, obs_or_fq_map, is_qat)
    first_arg = node.args[0] if len(node.args) > 0 else None
    quantized_input_from_placeholder = (
        isinstance(first_arg, Node)
        and output_dtype in [torch.quint8, torch.uint8]
        and first_arg.op == 'placeholder'
    )
    return not quantized_input_from_placeholder
| 368 |
+
|
| 369 |
+
def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Optional[Pattern],
    matched_node_pattern: Optional[List[Node]],
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the dtype configuration of a pattern is supported by
    the backend or not, and whether the qconfig satisfies constraints
    specified in the corresponding dtype config.
    """
    if backend_config is None or pattern is None:
        return True
    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)

    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
    root_node = root_node_getter(matched_node_pattern)
    # the root node of a fused pattern is the one that consumes the
    # pattern's external inputs
    input_node = root_node
    # matched_node_pattern is ordered with the last (output) node first
    output_node = matched_node_pattern[0]
    # the pattern is supported if ANY dtype config accepts both all of the
    # input dtypes and the output dtype
    for dtype_config in dtype_configs:
        # check if arg dtype are supported
        supported = True
        for arg in list(input_node.args) + list(input_node.kwargs.values()):
            supported = supported and _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config)
        # check if output dtype is supported
        supported = supported and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config)
        if supported:
            return True
    return False
| 402 |
+
|
| 403 |
+
def _get_standalone_module_configs(
|
| 404 |
+
node: Node,
|
| 405 |
+
named_modules: Dict[str, torch.nn.Module],
|
| 406 |
+
prepare_custom_config: PrepareCustomConfig,
|
| 407 |
+
parent_qconfig: QConfigAny,
|
| 408 |
+
parent_backend_config: Optional[BackendConfig],
|
| 409 |
+
) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
|
| 410 |
+
"""
|
| 411 |
+
Returns the standalone module QConfigMapping and PrepareCustomConfig
|
| 412 |
+
for `node`, assuming that the module pointed to by `node` is
|
| 413 |
+
a standalone modules.
|
| 414 |
+
"""
|
| 415 |
+
module_name = str(node.target)
|
| 416 |
+
module_type = type(named_modules[module_name]) # type: ignore[index]
|
| 417 |
+
# name config has precedence over type config
|
| 418 |
+
config_entry = StandaloneModuleConfigEntry(None, (), None, None)
|
| 419 |
+
config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
|
| 420 |
+
config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
|
| 421 |
+
# fallback to use parent module's qconfig if user didn't specify qconfig dict
|
| 422 |
+
qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
|
| 423 |
+
example_inputs = config_entry.example_inputs
|
| 424 |
+
prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
|
| 425 |
+
backend_config = config_entry.backend_config or parent_backend_config
|
| 426 |
+
return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
|
| 427 |
+
|
| 428 |
+
def _qat_swap_modules(
        root: torch.nn.Module,
        module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None:
    # In-place swap of float modules in `root` for their QAT counterparts per
    # `module_to_qat_module`. qconfig attributes are kept (remove_qconfig=False)
    # because later prepare passes still read them off the swapped modules.
    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
| 432 |
+
|
| 433 |
+
def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
|
| 434 |
+
if isinstance(matched_node_pattern, Node):
|
| 435 |
+
s.add(matched_node_pattern.name)
|
| 436 |
+
elif isinstance(matched_node_pattern, (list, tuple)):
|
| 437 |
+
for maybe_node in matched_node_pattern:
|
| 438 |
+
_add_matched_node_name_to_set(maybe_node, s)
|
| 439 |
+
|
| 440 |
+
def _insert_obs_or_fq(
    node: Node,
    obs_or_fq: ObserverOrFakeQuantize,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `obs_or_fq` to `model`, and creates a node which calls
    `obs_or_fq` on the output of `node`.

    obs_or_fq: an instance of Observer or FakeQuantize module

    Returns the newly created call_module node.
    """
    # move the observer's buffers onto the model's (unique) device so the
    # observed statistics live alongside the model parameters
    model_device = assert_and_get_unique_device(model)
    if model_device:
        obs_or_fq.to(model_device)
    # add obs_or_fq module as attribute
    # equalization observers get a distinct, node-derived prefix so they can
    # be told apart from regular activation observers later
    if is_equalization_observer(obs_or_fq):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = 'activation_post_process_'
    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
    obs_or_fq_name = get_new_obs_or_fq_name(model)
    setattr(model, obs_or_fq_name, obs_or_fq)
    # keep the caller's name->module map in sync with the model attribute
    named_modules[obs_or_fq_name] = obs_or_fq
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', obs_or_fq_name, (node,), {})
    return new_obs
| 469 |
+
|
| 470 |
+
def _set_target_dtype_info_for_matched_node_pattern(
    matched_node_pattern: NodePattern,
    last_node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    backend_config: BackendConfig,
    named_modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
    processed_nodes: Set[Node],
) -> None:
    """ Sets the target_dtype_info for each node in matched_node_pattern
    Note: processed_nodes is used to ensure we only process each node once
    """
    if isinstance(matched_node_pattern, (list, tuple)):
        # recurse into nested patterns so every constituent node is handled
        for node_pattern in matched_node_pattern:
            _set_target_dtype_info_for_matched_node_pattern(
                node_pattern,
                last_node,
                qconfig,
                qhandler,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )

    # set target_dtype_info if matched_node_pattern is a Node
    # other types of matched object, e.g. int, float literals, are ignored
    elif isinstance(matched_node_pattern, Node):
        # for pyre
        assert isinstance(matched_node_pattern, Node)
        node = matched_node_pattern
        if node in processed_nodes:
            return
        processed_nodes.add(node)

        # a node without a qconfig is left unannotated
        if qconfig is None:
            return
        # TODO: refactor the following code in terms of apply a qconfig to a pattern
        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
        # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act,
        # and set output_obs_or_fq_ctr based on qconfig.output_act
        # this also requires we extend the structure of QConfig to support more fine
        # grained configurations
        target_dtype_info: Dict[str, Any] = (
            _get_target_activation_dtype_for_node(
                node,
                qconfig,
                qhandler,
                named_modules,
                backend_config,
                cache_for_no_tensor_check,
            )
        )
        node.meta["target_dtype_info"] = target_dtype_info
| 525 |
+
|
| 526 |
+
def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Dict[str, Any]:
    """
    For each op attribute in the op's input activation, output activation,
    weight, bias - returns the settings of dtype and is_dynamic we expect
    for the `quantize` call in the reference model representation, or None
    if there is no `quantize` call needed.

    For example, if we have a node corresponding to `op0` in

      x0 -> op0 -> x1

    And we want a reference quantized representation to be

      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1

    Then this function will return

      {
        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
      }

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    args_have_no_tensors = \
        all_node_args_have_no_tensors(
            node, named_modules, cache_for_no_tensor_check)
    if args_have_no_tensors:
        # nothing tensor-valued flows through this node, so nothing to observe
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }
    # get qconfig to determine the eventual dtype of this node
    if qconfig is not None:
        act_dtype, weight_dtype, input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)

        # Currently `QConfig` only has one `activation` field.
        # For static quantization, it is reused for both input
        # and output activation. For dynamic quantization, this
        # field is currently only used for the input activation,
        # with the output activation being in fp32.
        # In the future this may change as we add more fields
        # to the `QConfig` object.
        # (an unused local that derived the output dtype from the above was
        # removed; the returned ctrs carry the dtype information instead)

        # a fp16 bias is only meaningful when the whole op runs in fp16 statically
        bias_dtype = torch.float16 \
            if (
                act_dtype == torch.float16
                and weight_dtype == torch.float16
                and (not input_act_is_dynamic)
            ) else torch.float

        is_general_tensor_value_op = \
            (qhandler is not None and qhandler.is_general_tensor_value_op())

        _is_standalone_module = (
            qhandler is not None and qhandler.is_standalone_module()
        )

        # look up the weight/bias argument indices once (the two lookups
        # previously duplicated the same pattern-config membership test)
        weight_index = None
        bias_index = None
        if isinstance(node, Node) and node.op == "call_function" and \
                node.target in backend_config._pattern_complex_format_to_config:
            input_type_to_index = \
                backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index
            weight_index = input_type_to_index.get("weight")
            bias_index = input_type_to_index.get("bias")

        return {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
            "weight_index": weight_index,
            "bias_index": bias_index,
            "output_act_obs_or_fq_ctr": qconfig.activation,
            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
            "input_output_share_observers": is_general_tensor_value_op,
            "_is_standalone_module": _is_standalone_module,
        }
    # no qconfig: default to fp32 placeholder observers on input and output
    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
| 617 |
+
|
| 618 |
+
def _get_output_act_obs_or_fq(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize:
    """ Get the constructor for observer or fake quant object for
    the argument in the original graph as the output of previous node,
    skipping inserted observers

    We are assuming that the observers are inserted correctly, and the dtype for
    argument in quantized graph will match what is specified by the qconfig
    """
    assert isinstance(arg, Node)
    # pt2e path: the annotation on the node itself wins
    if "quantization_annotation" in arg.meta:
        return _create_obs_or_fq_from_qspec(arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)

    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    # Since we modified the graph in this case, we must trace back from the args through
    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
    output_act_obs_or_fq_ctr = None
    if custom_module_lstm_node is not None:
        # use the dtype info recorded on the original LSTM node
        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    elif _is_activation_post_process_node(arg, named_modules):
        # `arg` is an inserted observer: look through it at the node it observes
        observed_arg = arg.args[0]
        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
        if "quantization_annotation" in observed_arg.meta:
            output_act_obs_or_fq = \
                _create_obs_or_fq_from_qspec(
                    observed_arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)
        else:
            assert "target_dtype_info" in observed_arg.meta
            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
            output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    else:
        # plain node: read its recorded dtype info, defaulting to fp32
        if "target_dtype_info" in arg.meta:
            output_act_obs_or_fq_ctr = \
                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
        else:
            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None

    return output_act_obs_or_fq
| 665 |
+
|
| 666 |
+
def _get_arg_target_dtype_as_output(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[torch.dtype]:
    """Return the dtype `arg` is expected to carry as the *output* of its
    producer node, or None when no observer/fake-quant is configured."""
    producer_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
    dtype, _ = _get_dtype_and_is_dynamic(producer_obs_or_fq)
    return dtype
| 675 |
+
|
| 676 |
+
def _get_arg_as_input_act_obs_or_fq(
    arg: Node,
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[ObserverOrFakeQuantize]:
    """ Get the observer or fake quant constructor for the Argument `arg`, as input
    to Node `node`
    """
    assert isinstance(arg, Node)
    # pt2e path: "input_qspec_map" is the more general design, a map from an
    # input argument node to its observer/fake-quant constructor. E.g. for
    #   x -> conv -> output
    # the conv node may be annotated as:
    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
    if "quantization_annotation" in node.meta:
        qspec_map = node.meta["quantization_annotation"].input_qspec_map
        arg_qspec = _get_qspec_for_arg(arg, qspec_map, named_modules)
        if arg_qspec is None:
            return _DEFAULT_FP32_OBS_OR_FQ_CTR()
        return _create_obs_or_fq_from_qspec(arg_qspec, obs_or_fq_map, is_qat)

    # Legacy path below; removable once fx graph mode quantization is
    # no longer used.
    weight_arg = node_arg_is_weight(node, arg)
    bias_arg = node_arg_is_bias(node, arg)
    ctr = None
    if not weight_arg and not bias_arg:
        # activation input
        ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    elif weight_arg:
        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
            ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    else:
        ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    return ctr() if ctr else None
|
| 718 |
+
|
| 719 |
+
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Returns the argument to use in place of `arg`: either `arg` itself
    (no observation needed), a newly inserted observer node, or an
    existing compatible observer node that is reused.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules,
                graph,
                qhandler,
                prepare_custom_config,
                obs_or_fq_map,
                is_qat,
                backend_config)
            new_arg_to_return.append(new_inner_arg)
        # preserve the original container type (list vs tuple)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
    # TODO: move this to a separate function
    if not is_standalone_module:
        # Note: qconfig can be None in this branch since we are getting act/fq from
        # node.meta now
        # regular flow for most nodes, except standalone modules

        if "quantization_annotation" in node.meta:
            reuse_input_obs_or_fq = node.meta["quantization_annotation"]._reuse_input_obs_or_fq
        else:
            assert "target_dtype_info" in node.meta
            # TODO: we are assuming "target_dtype_info" exists here, maybe
            # a default value also need to be provided here
            target_dtype_info = node.meta["target_dtype_info"]
            # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
            # we'll default to False, this makes configuring this field optional for users
            reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
        # dtype/dynamism that `node` expects for this input ...
        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
        arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

        # ... versus the dtype/dynamism that `arg` produces as an output;
        # an observer is needed when these disagree (decided by _needs_obs_or_fq)
        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
        arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)


        needs_obs_or_fq = _needs_obs_or_fq(
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
            reuse_input_obs_or_fq,
            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
        )

    else:
        assert qconfig is not None
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = \
            _get_standalone_module_configs(
                node, named_modules, prepare_custom_config, qconfig, backend_config)
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            needs_obs_or_fq = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules, obs_or_fq_map, is_qat)
            # only the indexes listed in input_quantized_indexes expect a
            # quantized input; everything else stays fp32
            arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
                else torch.float
            needs_obs_or_fq = (
                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
                (arg_as_input_target_dtype != torch.float)
            )

        act_post_process_ctr = qconfig.activation
        arg_as_input_act_obs_or_fq = act_post_process_ctr() if act_post_process_ctr else None

    if needs_obs_or_fq:

        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this is looking into how the value is used in the future
        # we should remove this
        # removing this means we insert one observer for each use, even if they
        # have the same dtype, we can have an extra pass that removes the extra observers
        for maybe_obs_node in arg.users.keys():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
                    maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        assert arg_as_input_act_obs_or_fq is not None
        # record the observer for this (producer, consumer) edge
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph)
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    Note: backend_config only needed for standalone_module node
    """
    # Any input whose produced dtype disagrees with what `node` expects
    # gets an observer inserted in between.
    def _observed(input_arg: Argument) -> Argument:
        return _maybe_insert_input_observer_for_arg_or_kwarg(
            node, input_arg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            obs_or_fq_map,
            is_qat,
            backend_config)

    # compute everything first: the helper inspects the original node.args,
    # so we only assign back once all inputs have been processed
    observed_args = [_observed(a) for a in node.args]
    observed_kwargs = {key: _observed(val) for key, val in node.kwargs.items()}

    # assign the new args and kwargs to the node, inplace
    node.args = tuple(observed_args)
    node.kwargs = observed_kwargs
|
| 910 |
+
|
| 911 |
+
def _maybe_insert_input_equalization_observers_for_node(
    node: Node,
    equalization_qconfig: Any,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    is_branch: bool,
) -> None:
    """
    If `node` needs to be equalized, find the input/weight observers it needs in
    `equalization_qconfig`, creates them, and inserts it into `graph`.

    If `node` does not need an equalization observer, returns None.
    """
    # guard: no equalization configured, or op doesn't support it
    if equalization_qconfig is None or not node_supports_equalization(node, named_modules):
        return

    # equalization cannot handle nodes whose input is shared with other ops
    if is_branch:
        warnings.warn(
            f"Cannot equalize {node} because it is part of a branch."
        )
        return

    updated_args = []
    for candidate in node.args:
        # bias args and non-Node args pass through unchanged
        if not isinstance(candidate, Node) or node_arg_is_bias(node, candidate):
            updated_args.append(candidate)
            continue

        # pick the weight- or activation-equalization observer constructor
        if node_arg_is_weight(node, candidate):
            eq_obs_ctr = equalization_qconfig.weight
        else:
            eq_obs_ctr = equalization_qconfig.input_activation

        eq_obs_node = _insert_obs_or_fq(
            candidate, eq_obs_ctr(), model, named_modules, graph)
        updated_args.append(eq_obs_node)

    # assign the new args and kwargs to the node, inplace
    node.args = tuple(updated_args)
|
| 953 |
+
|
| 954 |
+
def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    """
    If `node` needs an output observer, creates it, inserts it into `graph`
    and returns it.

    If `node` does not need an output observer, returns None.

    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
    quantization code path right now
    """
    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'

    is_standalone_module = False
    # pt2e path: observer/fq comes from the QuantizationAnnotation's output_qspec
    if "quantization_annotation" in node.meta:
        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
        )
    else:
        # legacy fx path: observer/fq ctr comes from node.meta["target_dtype_info"]
        assert "target_dtype_info" in node.meta
        is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False)
        output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # uncomment after we support reuse_input_obs_or_fq properly by having separate
    # implementations for this key instead of reusing the input_output_share_observers
    # code
    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
    # for now we set this to False since reuse_input_obs_or_fq for
    # the output of a node is implemented in the same code path as observer sharing,
    # we should refactor this part to make it clearer in the future
    # and we would be able to read this from config directly
    reuse_input_obs_or_fq = False

    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
    # because the prev_output is the output of an fp32 op, although technically
    # we should get the dtype of the output from node.meta["val"] in the future
    # if we deprecate fx graph mode quantization
    needs_obs_or_fq = _needs_obs_or_fq(torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq)
    # currently the activation in QConfig(activation=...,) is for both input
    # and output, and when the activation is configured to be dynamic quantization
    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
    # the input should be dynamically quantized, but output should not be quantized
    #
    # there is no way we can specify different observer/fq for input and output
    # activation through QConfig today, this limitation is lifted in the
    # quantizer/annotation API in pytorch 2.0 export quantization code path,
    # but since this code is reused, annotating output to be dynamically quantized
    # would not work either for that.
    # we can change QConfig to support input/output activation if we want
    # to remove the following check, or if we can deprecate fx graph mode quantization
    if target_is_dynamic:
        needs_obs_or_fq = False

    # we never insert observers to output of standalone module, we assume
    # if needed, they are inserted inside the standalone module
    needs_obs_or_fq = needs_obs_or_fq and \
        (not is_standalone_module)

    if needs_obs_or_fq:
        # record the observer for this node's output edge before inserting it
        obs_or_fq_map[node] = output_act_obs_or_fq
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    else:
        return None
|
| 1024 |
+
|
| 1025 |
+
def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.

    Mutates `graph_output_node.args` in place; each Node found in the
    (possibly nested) output structure is replaced by its observer node
    when its produced dtype differs from the configured target dtype.
    """

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers.

        Raises:
            TypeError: if an unsupported type (not Node/list/tuple/dict/None)
                appears in the output structure.
        """
        if isinstance(maybe_node, Node):
            # dtype this node currently produces
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(
                maybe_node, named_modules, obs_or_fq_map, is_qat)
            observer_mod = None
            arg_as_input_target_dtype = torch.float
            if "target_dtype_info" in maybe_node.meta:
                observer_cls = maybe_node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", None)
                if observer_cls is not None:
                    observer_mod = observer_cls()
                    arg_as_input_target_dtype = observer_mod.dtype
            # TODO: this does not handle dynamic quantization yet
            need_obs = (
                arg_as_output_target_dtype != arg_as_input_target_dtype and
                arg_as_input_target_dtype != torch.float
            )
            if need_obs:
                assert observer_mod is not None
                # insert observer and return it in place of the node
                return _insert_obs_or_fq(
                    maybe_node, observer_mod, model, named_modules, graph)
            return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            results = [
                _recursive_maybe_replace_node_with_obs(
                    inner_node, model, named_modules, graph)
                for inner_node in maybe_node
            ]
            # preserve list-ness vs tuple-ness of the original container
            return results if isinstance(maybe_node, list) else tuple(results)
        elif isinstance(maybe_node, dict):
            return {
                k: _recursive_maybe_replace_node_with_obs(
                    inner_v, model, named_modules, graph)
                for k, inner_v in maybe_node.items()
            }
        elif maybe_node is None:
            return None
        else:
            # TypeError (an Exception subclass, so existing `except Exception`
            # handlers still work) with a properly formatted message instead of
            # a tuple payload
            raise TypeError(f"Unhandled type for returned node: {maybe_node}")

    new_args = [
        _recursive_maybe_replace_node_with_obs(
            old_arg, model, named_modules, graph)
        for old_arg in graph_output_node.args
    ]

    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: Union[torch.dtype, type],
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
    is a general tensor shape op, also call this function recursively on
    the first argument, to propagate the dtype to the caller.
    """
    dtype_info = node.meta["target_dtype_info"]
    dtype_info["input_act_obs_or_fq_ctr"] = None
    dtype_info["output_act_obs_or_fq_ctr"] = None
    # copy nodes forward the dtype to their first input
    match_result = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None))
    qhandler = match_result[3]
    # TODO: probably need to remove `is_general_tensor_value_op`
    if qhandler is not None and qhandler.is_general_tensor_value_op():
        source_node = node.args[0]
        if isinstance(source_node, Node):
            _maybe_propagate_dtype_for_node(
                source_node, target_dtype, node_name_to_match_result_with_qconfig)
|
| 1134 |
+
|
| 1135 |
+
def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Currently we assume that inputs to the graph are either `torch.float` or
    `torch.quint8`, which is not always correct. For ops such as
    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
    `BoolTensor`. Propagate this information throughout the graph.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
    """
    for node in graph.nodes:
        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)

        for arg_type, index_getter in non_observable_arg_dict.items():
            for index in index_getter(node):
                arg = node.args[index]

                # a tuple/list argument does not show up as a separate node,
                # so its elements have to be visited manually
                if isinstance(arg, (tuple, list)):
                    args_to_visit = list(arg)
                else:
                    args_to_visit = [arg]

                for cur_arg in args_to_visit:
                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
                    if isinstance(cur_arg, torch.fx.node.Node):
                        _maybe_propagate_dtype_for_node(
                            cur_arg, arg_type, node_name_to_match_result_with_qconfig)
|
| 1170 |
+
|
| 1171 |
+
def _maybe_make_input_output_share_observers(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
) -> bool:
    """
    Ensures that we share an observer
    for all input arguments as well as the output argument. In detail, given
    a graph of

      x0 -> obs0 -> op -> x2
                    /
      x1 -> obs1 /

    where node obs0 points to observer instance observer0,
    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
    and ob2 point to observer0.
    Returns: whether the operation succeeded or not
    """
    first_arg = None
    # find the first arg that can carry an observer: a Node, or a
    # list/tuple of Nodes (e.g. the input list of torch.cat)
    for i in range(len(node.args)):
        if isinstance(node.args[i], (Node, list, tuple)):
            first_arg = node.args[i]
            break

    # if there is no such arg, there is nothing to share; return directly
    if first_arg is None:
        return False

    if isinstance(first_arg, (list, tuple)):
        first_arg_arg = first_arg[0]
    elif isinstance(first_arg, Node):
        first_arg_arg = first_arg
    else:
        return False

    # if we have a graph such as
    # observed_node -> non_observed_node -> cat
    # we need to navigate up to the first observer
    iteration_guard = 0
    while not _is_activation_post_process_node(first_arg_arg, named_modules):
        if not isinstance(first_arg_arg, Node):
            return False
        # did not find an activation_post_process for the op
        if first_arg_arg.op == "placeholder":
            return False
        # trace back the args until we found the first Tensor/Node
        trace_back_node = None
        for i in range(len(first_arg_arg.args)):
            trace_back_node = first_arg_arg.args[i]
            if isinstance(trace_back_node, Node):
                break
        if trace_back_node is None:
            return False
        first_arg_arg = trace_back_node

        # guard against cycles / pathological graphs
        iteration_guard += 1
        if iteration_guard > 10000:
            raise AssertionError('Unable to find observer of previous node')

    assert isinstance(first_arg_arg, Node)
    target_to_use = first_arg_arg.target
    assert isinstance(target_to_use, str)
    # this is the observer module every other input/output observer node
    # will be re-pointed to
    obs_mod_to_use = named_modules[target_to_use]

    if isinstance(first_arg, (list, tuple)):
        # set all other input observer nodes to use that module
        for input_idx, input_arg in enumerate(first_arg):
            if input_idx == 0:
                continue
            iteration_guard = 0
            while not _is_activation_post_process_node(input_arg, named_modules):
                # failed to trace back since no input arg for the current node
                if len(input_arg.args) < 1:
                    return False
                input_arg = input_arg.args[0]
                iteration_guard += 1
                if iteration_guard > 10000:
                    raise AssertionError('Unable to find observer of previous node')

            parent_name, name = _parent_name(input_arg.target)
            setattr(named_modules[parent_name], name, obs_mod_to_use)

    # set the output observer node to use that module
    for output_obs_node in node.users.keys():
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        parent_name, name = _parent_name(output_obs_node.target)
        setattr(named_modules[parent_name], name, obs_mod_to_use)

    # TODO(future PR): delete the orphaned observer modules
    return True
|
| 1263 |
+
|
| 1264 |
+
def _remove_output_observer(
        node: Node,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module]):
    """Erase the observer node(s) attached to `node`'s output, rewiring
    each observer's consumers back to `node` itself."""
    # snapshot the users first: erasing a node mutates `node.users`
    observer_nodes = list(node.users.keys())
    for obs_node in observer_nodes:
        assert _is_activation_post_process_node(obs_node, named_modules)
        obs_node.replace_all_uses_with(node)
        model.graph.erase_node(obs_node)  # type: ignore[union-attr, operator]
|
| 1273 |
+
|
| 1274 |
+
def _swap_custom_module_to_observed(
        node: Node,
        qconfig: QConfigAny,
        named_modules: Dict[str, torch.nn.Module],
        prepare_custom_config: PrepareCustomConfig):
    """Swap the float custom module referenced by `node` for its observed
    counterpart, as configured in `prepare_custom_config`."""
    float_custom_module = named_modules[node.target]  # type: ignore[index]
    mapping = prepare_custom_config.float_to_observed_mapping
    observed_cls = get_swapped_custom_module_class(
        float_custom_module, mapping, qconfig)
    observed_module = observed_cls.from_float(float_custom_module)
    # replace the module on its parent, in place
    parent_name, attr_name = _parent_name(node.target)
    setattr(named_modules[parent_name], attr_name, observed_module)
|
| 1288 |
+
|
| 1289 |
+
def insert_observers_for_model(
|
| 1290 |
+
model: GraphModule,
|
| 1291 |
+
node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
|
| 1292 |
+
node_name_to_qconfig: Dict[str, QConfigAny],
|
| 1293 |
+
prepare_custom_config: PrepareCustomConfig,
|
| 1294 |
+
equalization_config_map: Dict[str, Any],
|
| 1295 |
+
backend_config: BackendConfig,
|
| 1296 |
+
observed_node_names: Set[str],
|
| 1297 |
+
is_qat: bool,
|
| 1298 |
+
) -> Optional[Node]:
|
| 1299 |
+
"""
|
| 1300 |
+
Inserts observers, using the following high level algorithm:
|
| 1301 |
+
|
| 1302 |
+
For each node in the graph:
|
| 1303 |
+
1. determine the target dtype of this node in the quantized graph, and save
|
| 1304 |
+
it for future steps
|
| 1305 |
+
2. determine the target dtype or all args and kwargs of this node
|
| 1306 |
+
3. if any arg or kwarg's target dtype does not match the current node's
|
| 1307 |
+
dtype, insert an observer
|
| 1308 |
+
4. if the current node needs an output observer, insert it
|
| 1309 |
+
|
| 1310 |
+
For example:
|
| 1311 |
+
|
| 1312 |
+
- starting graph:
|
| 1313 |
+
x0 -> linear -> x1
|
| 1314 |
+
|
| 1315 |
+
- observed graph after processing x0:
|
| 1316 |
+
x0(fp32)
|
| 1317 |
+
|
| 1318 |
+
- observed graph after processing linear:
|
| 1319 |
+
x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
|
| 1320 |
+
|
| 1321 |
+
- observed graph after processing x1:
|
| 1322 |
+
x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
|
| 1323 |
+
|
| 1324 |
+
After a node is processed, the naive observer placement is guaranteed to be
|
| 1325 |
+
complete for that node and all of its predecessors. There can be future
|
| 1326 |
+
passes which optimize the graph by deduplicating observers, etc.
|
| 1327 |
+
"""
|
| 1328 |
+
|
| 1329 |
+
# node.meta["target_dtype_info"] stores the target dtype information
|
| 1330 |
+
# that's derived from qconfig for the Node, for example, if we have
|
| 1331 |
+
# a conv2d node that has a qconfig
|
| 1332 |
+
# qconfig = QConfig(activation=..., weight=...)
|
| 1333 |
+
# # information for input and bias node omitted
|
| 1334 |
+
# # for getattr node
|
| 1335 |
+
# # weight = getattr(self, 'weight')
|
| 1336 |
+
# weight.meta["target_dtype_info"] = {
|
| 1337 |
+
# 'output_act_obs_or_fq_ctr': qconfig.weight,
|
| 1338 |
+
# }
|
| 1339 |
+
# # for conv2d node
|
| 1340 |
+
# # conv2d = call_function[target=torch.nn.functional.conv2d](
|
| 1341 |
+
# # args=(input, weight, bias))
|
| 1342 |
+
# conv2d.meta["target_dtype_info"] = {
|
| 1343 |
+
# 'input_act_obs_or_fq_ctr': qconfig.activation
|
| 1344 |
+
# 'weight_obs_or_fq_ctr': qconfig.weight,
|
| 1345 |
+
# 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
|
| 1346 |
+
# 'output_act_obs_or_fq_ctr': qconfig.activation,
|
| 1347 |
+
# }
|
| 1348 |
+
#
|
| 1349 |
+
cache_for_no_tensor_check: Dict[Node, bool] = {}
|
| 1350 |
+
|
| 1351 |
+
# first, populate the dtype map based only on qconfig and qhandler
|
| 1352 |
+
# this assumes:
|
| 1353 |
+
# graph inputs are fp32 by default, and int8 where overriden
|
| 1354 |
+
# other nodes output dtype is specified by the qconfig
|
| 1355 |
+
named_modules = dict(model.named_modules(remove_duplicate=False))
|
| 1356 |
+
|
| 1357 |
+
input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
|
| 1358 |
+
output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
|
| 1359 |
+
processed_nodes: Set[Node] = set()
|
| 1360 |
+
# initialize target_dtype_info
|
| 1361 |
+
for node in model.graph.nodes:
|
| 1362 |
+
node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
|
| 1363 |
+
|
| 1364 |
+
inputs_seen_counter = 0
|
| 1365 |
+
outputs_seen_counter = 0
|
| 1366 |
+
placeholder_node_to_input_index: Dict[Node, int] = {}
|
| 1367 |
+
# TODO: we probably don't need this counter since each graph will only have
|
| 1368 |
+
# one output node?
|
| 1369 |
+
output_node_to_output_index: Dict[Node, int] = {}
|
| 1370 |
+
for node in model.graph.nodes:
|
| 1371 |
+
if node.op == "placeholder":
|
| 1372 |
+
placeholder_node_to_input_index[node] = inputs_seen_counter
|
| 1373 |
+
inputs_seen_counter += 1
|
| 1374 |
+
if node.op == "output":
|
| 1375 |
+
output_node_to_output_index[node] = outputs_seen_counter
|
| 1376 |
+
outputs_seen_counter += 1
|
| 1377 |
+
|
| 1378 |
+
# Step 1, set the observer or fake quantize module constructor for each node in the
|
| 1379 |
+
# matched_node_pattern
|
| 1380 |
+
|
| 1381 |
+
for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
|
| 1382 |
+
last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
|
| 1383 |
+
assert qhandler is not None
|
| 1384 |
+
_set_target_dtype_info_for_matched_node_pattern(
|
| 1385 |
+
matched_node_pattern,
|
| 1386 |
+
last_node,
|
| 1387 |
+
qconfig,
|
| 1388 |
+
qhandler,
|
| 1389 |
+
backend_config,
|
| 1390 |
+
named_modules,
|
| 1391 |
+
cache_for_no_tensor_check,
|
| 1392 |
+
processed_nodes
|
| 1393 |
+
)
|
| 1394 |
+
|
| 1395 |
+
# Step 2. Special cases for some operators, we might be able to remove them
|
| 1396 |
+
# in the future if we know dtype information of each node better
|
| 1397 |
+
|
| 1398 |
+
# Step 2.1. some settings are not based on patterns, we need to process each node
|
| 1399 |
+
# instead
|
| 1400 |
+
for node in model.graph.nodes:
|
| 1401 |
+
if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs:
|
| 1402 |
+
# users are not supposed to call calculate_qparams on PlaceholderObserver, and
|
| 1403 |
+
# this is OK because we are using this as a way to encode the dtypes of input
|
| 1404 |
+
# tensor, we won't actually insert these observers in the graph and won't
|
| 1405 |
+
# actually call calculate_qparams
|
| 1406 |
+
node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
|
| 1407 |
+
elif node.op in ("call_module", "call_method", "call_function"):
|
| 1408 |
+
args_have_no_tensors = \
|
| 1409 |
+
all_node_args_have_no_tensors(
|
| 1410 |
+
node, named_modules, cache_for_no_tensor_check)
|
| 1411 |
+
if args_have_no_tensors:
|
| 1412 |
+
node.meta["target_dtype_info"] = {
|
| 1413 |
+
"input_act_obs_or_fq_ctr": None,
|
| 1414 |
+
"output_act_obs_or_fq_ctr": None,
|
| 1415 |
+
}
|
| 1416 |
+
elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs:
|
| 1417 |
+
# TODO(future PR): update the output_quantized_idxs API to match
|
| 1418 |
+
# arbitrary data structures. There is always a single output, and
|
| 1419 |
+
# that output can have arbitrary nesting of values. List[int] is
|
| 1420 |
+
# not the right data type for this.
|
| 1421 |
+
|
| 1422 |
+
# TODO(future PR): support more dtypes in model outputs, if necessary
|
| 1423 |
+
node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
|
| 1424 |
+
|
| 1425 |
+
# Step 2.2, for nodes with known input dtypes, propagate them throughout the
|
| 1426 |
+
# graph. For example, if there is a call such as
|
| 1427 |
+
# x1 = x0.masked_fill(mask, 1)
|
| 1428 |
+
# we propagate the type of mask to be torch.bool
|
| 1429 |
+
propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig)
|
| 1430 |
+
|
| 1431 |
+
# Step 3, check if the requested target_dtype_info is supported by backend or not
|
| 1432 |
+
# if not, we'll reset the target_dtye_info to use the default (float Tensor)
|
| 1433 |
+
|
| 1434 |
+
# reset the counters and set of processed_nodes
|
| 1435 |
+
processed_nodes: Set[Node] = set()
|
| 1436 |
+
for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
|
| 1437 |
+
last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
|
| 1438 |
+
is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
|
| 1439 |
+
pattern, matched_node_pattern, qconfig, backend_config)
|
| 1440 |
+
assert qhandler is not None
|
| 1441 |
+
|
| 1442 |
+
# get output_act_dtype so that we don't also reset the special typed nodes
|
| 1443 |
+
# TODO: we might want to handle these more uniformly with the default path
|
| 1444 |
+
# this can be improved if we can use node.meta["val"]
|
| 1445 |
+
output_act_or_fq_ctr = node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
|
| 1446 |
+
output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
|
| 1447 |
+
output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
|
| 1448 |
+
if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]:
|
| 1449 |
+
# restore target_dtype_info to default if it is not supported by backend
|
| 1450 |
+
_set_target_dtype_info_for_matched_node_pattern(
|
| 1451 |
+
matched_node_pattern,
|
| 1452 |
+
last_node,
|
| 1453 |
+
torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
|
| 1454 |
+
None,
|
| 1455 |
+
backend_config,
|
| 1456 |
+
named_modules,
|
| 1457 |
+
cache_for_no_tensor_check,
|
| 1458 |
+
processed_nodes
|
| 1459 |
+
)
|
| 1460 |
+
|
| 1461 |
+
# After this point, the current node and all of its arguments
|
| 1462 |
+
# have a target_dtype_info assigned. Now, we insert observers for inputs
|
| 1463 |
+
# of this node (if needed for this node), and the output of this node
|
| 1464 |
+
# (if needed for this node).
|
| 1465 |
+
|
| 1466 |
+
# Since we are mutating the graph as we go, we iterate over the original
|
| 1467 |
+
# nodes before observer insertion, instead of model.graph.nodes.
|
| 1468 |
+
nodes_before_observation = list(model.graph.nodes)
|
| 1469 |
+
|
| 1470 |
+
# Avoid duplicates custom module swaps for multiple nodes with same target.
|
| 1471 |
+
custom_module_names_already_swapped: Set[str] = set()
|
| 1472 |
+
|
| 1473 |
+
# TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
|
| 1474 |
+
# reset inputs/outputs counters
|
| 1475 |
+
inputs_seen_counter = 0
|
| 1476 |
+
outputs_seen_counter = 0
|
| 1477 |
+
results_node = None
|
| 1478 |
+
obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
|
| 1479 |
+
|
| 1480 |
+
# TODO: change this to insert obs/fq by pattern instead of by node
|
| 1481 |
+
for node in nodes_before_observation:
|
| 1482 |
+
|
| 1483 |
+
if node.op == 'placeholder':
|
| 1484 |
+
# if a graph input is in fp32, it does not need observation
|
| 1485 |
+
# if a graph input is in int8, we assume the observation happens
|
| 1486 |
+
# outside of the graph, and no additional observation is needed
|
| 1487 |
+
pass
|
| 1488 |
+
|
| 1489 |
+
elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
|
| 1490 |
+
# check for matches
|
| 1491 |
+
last_node, matched_node_pattern, pattern, qhandler, qconfig = (
|
| 1492 |
+
node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None)) # type: ignore[assignment]
|
| 1493 |
+
)
|
| 1494 |
+
equalization_qconfig = equalization_config_map.get(node.name, None)
|
| 1495 |
+
|
| 1496 |
+
this_node_dtype_info = node.meta["target_dtype_info"]
|
| 1497 |
+
if "val" in node.meta:
|
| 1498 |
+
output_is_a_tensor = (
|
| 1499 |
+
this_node_dtype_info is not None and
|
| 1500 |
+
isinstance(node.meta["val"], FakeTensor)
|
| 1501 |
+
)
|
| 1502 |
+
else:
|
| 1503 |
+
output_is_a_tensor = this_node_dtype_info is not None
|
| 1504 |
+
|
| 1505 |
+
skip_inserting_observers = (
|
| 1506 |
+
(qconfig is None) or
|
| 1507 |
+
not output_is_a_tensor
|
| 1508 |
+
) and (
|
| 1509 |
+
not node.op == 'output'
|
| 1510 |
+
)
|
| 1511 |
+
|
| 1512 |
+
# TODO: take a closer look to see if we can remove this check
|
| 1513 |
+
# right now it is here because of `observed_node_names`, we are using
|
| 1514 |
+
# it as an indicator for swapping the modules to reference modules in
|
| 1515 |
+
# convert
|
| 1516 |
+
is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
|
| 1517 |
+
pattern, matched_node_pattern, qconfig, backend_config)
|
| 1518 |
+
|
| 1519 |
+
if not skip_inserting_observers and is_supported_by_backend:
|
| 1520 |
+
named_modules = dict(model.named_modules(remove_duplicate=False))
|
| 1521 |
+
if node.op != 'output':
|
| 1522 |
+
assert matched_node_pattern is not None
|
| 1523 |
+
# add matched nodes to the observed node name set
|
| 1524 |
+
_add_matched_node_name_to_set(matched_node_pattern, observed_node_names)
|
| 1525 |
+
|
| 1526 |
+
# This is currently only used for equalization.
|
| 1527 |
+
# Checks if the current node is in a branch in which the two
|
| 1528 |
+
# first layers are both being quantized.
|
| 1529 |
+
#
|
| 1530 |
+
# ex. conv2
|
| 1531 |
+
# /
|
| 1532 |
+
# x -> conv1
|
| 1533 |
+
#
|
| 1534 |
+
# If this is the case, we will not apply equalization to the
|
| 1535 |
+
# initial two layers.
|
| 1536 |
+
is_quantized_branch = False
|
| 1537 |
+
if (
|
| 1538 |
+
len(node.args) > 0 and
|
| 1539 |
+
isinstance(node.args[0], Node) and
|
| 1540 |
+
len(node.args[0].users) > 1
|
| 1541 |
+
):
|
| 1542 |
+
for user in node.args[0].users:
|
| 1543 |
+
# Checks if there exists another user being quantized
|
| 1544 |
+
is_user_quantized = (
|
| 1545 |
+
node_name_to_qconfig.get(user.name, None) is not None or
|
| 1546 |
+
(user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase))
|
| 1547 |
+
)
|
| 1548 |
+
if user != node and is_user_quantized:
|
| 1549 |
+
is_quantized_branch = True
|
| 1550 |
+
|
| 1551 |
+
pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
|
| 1552 |
+
root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
|
| 1553 |
+
root_node = root_node_getter(matched_node_pattern)
|
| 1554 |
+
is_input_node_of_the_pattern = node is root_node
|
| 1555 |
+
if is_input_node_of_the_pattern:
|
| 1556 |
+
# this modifies node inplace
|
| 1557 |
+
_maybe_insert_input_observers_for_node(
|
| 1558 |
+
node, qconfig, model, named_modules, model.graph,
|
| 1559 |
+
qhandler,
|
| 1560 |
+
prepare_custom_config,
|
| 1561 |
+
obs_or_fq_map,
|
| 1562 |
+
is_qat,
|
| 1563 |
+
backend_config)
|
| 1564 |
+
|
| 1565 |
+
# insert equalization input observers if needed
|
| 1566 |
+
_maybe_insert_input_equalization_observers_for_node(
|
| 1567 |
+
node, equalization_qconfig, model, named_modules, model.graph,
|
| 1568 |
+
is_quantized_branch)
|
| 1569 |
+
|
| 1570 |
+
is_last_node_of_pattern = node is last_node
|
| 1571 |
+
input_output_share_observers = node.meta["target_dtype_info"].get("input_output_share_observers", False)
|
| 1572 |
+
reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
|
| 1573 |
+
|
| 1574 |
+
if is_last_node_of_pattern:
|
| 1575 |
+
if _is_custom_module_lstm(node, named_modules, qconfig, qhandler):
|
| 1576 |
+
# Currently custom module outputs are assumed to be already quantized,
|
| 1577 |
+
# so we need to insert a DeQuantStub after the output. For custom module
|
| 1578 |
+
# LSTM specifically, the outputs are also a nested tuple, so we must first
|
| 1579 |
+
# break down the tuple to insert DeQuantStubs after the internal nodes.
|
| 1580 |
+
|
| 1581 |
+
# TODO: This currently diverges from how custom modules are handled today,
|
| 1582 |
+
# where we insert observers after the output instead of DeQuantStubs, and
|
| 1583 |
+
# replace these observers with "dequantize" nodes during convert. Conceptually,
|
| 1584 |
+
# these output observers are the same as DeQuantStubs. In the future, we
|
| 1585 |
+
# should resolve this inconsistency by inserting DeQuantStubs for all custom
|
| 1586 |
+
# modules, not just for LSTM.
|
| 1587 |
+
_insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
|
| 1588 |
+
if node.target not in custom_module_names_already_swapped:
|
| 1589 |
+
custom_module_names_already_swapped.add(node.target)
|
| 1590 |
+
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
| 1591 |
+
else:
|
| 1592 |
+
# this returns the new observer node if it was needed
|
| 1593 |
+
maybe_output_obs_node = _maybe_insert_output_observer_for_node(
|
| 1594 |
+
node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
|
| 1595 |
+
|
| 1596 |
+
if maybe_output_obs_node is not None:
|
| 1597 |
+
# Update users of original node to use the output observer
|
| 1598 |
+
# instead. For example, change
|
| 1599 |
+
#
|
| 1600 |
+
# next_node
|
| 1601 |
+
# /
|
| 1602 |
+
# cur_node -> obs
|
| 1603 |
+
#
|
| 1604 |
+
# to
|
| 1605 |
+
#
|
| 1606 |
+
# next_node
|
| 1607 |
+
# /
|
| 1608 |
+
# cur_node -> obs
|
| 1609 |
+
#
|
| 1610 |
+
# We need to save orig users before updating uses because
|
| 1611 |
+
# the list of users will change as we update uses
|
| 1612 |
+
orig_users = list(node.users.keys())
|
| 1613 |
+
for user_node in orig_users:
|
| 1614 |
+
if user_node is maybe_output_obs_node:
|
| 1615 |
+
continue
|
| 1616 |
+
user_node.replace_input_with(node, maybe_output_obs_node)
|
| 1617 |
+
|
| 1618 |
+
_is_observer_in_same_graph_ = _is_observer_in_same_graph(
|
| 1619 |
+
node, named_modules, obs_or_fq_map, is_qat)
|
| 1620 |
+
|
| 1621 |
+
# for ops whose inputs and outputs share observer/fqs, we modify the graph
|
| 1622 |
+
# to make all inputs and outputs use the first input's
|
| 1623 |
+
# observer/fq
|
| 1624 |
+
if (input_output_share_observers and _is_observer_in_same_graph_) or \
|
| 1625 |
+
reuse_input_obs_or_fq:
|
| 1626 |
+
if not _maybe_make_input_output_share_observers(node, model, named_modules):
|
| 1627 |
+
_remove_output_observer(node, model, named_modules)
|
| 1628 |
+
|
| 1629 |
+
if qhandler is not None and qhandler.is_custom_module():
|
| 1630 |
+
if node.target not in custom_module_names_already_swapped:
|
| 1631 |
+
custom_module_names_already_swapped.add(node.target)
|
| 1632 |
+
_swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
|
| 1633 |
+
|
| 1634 |
+
else: # output
|
| 1635 |
+
_maybe_insert_observers_before_graph_output(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
|
| 1636 |
+
|
| 1637 |
+
#
|
| 1638 |
+
# After this point, the current node has input and output observers
|
| 1639 |
+
# that it needs for itself inserted.
|
| 1640 |
+
#
|
| 1641 |
+
|
| 1642 |
+
# increment the counters, so future inputs and outputs are assigned
|
| 1643 |
+
# correct dtypes
|
| 1644 |
+
if node.op == 'placeholder':
|
| 1645 |
+
inputs_seen_counter += 1
|
| 1646 |
+
elif node.op == 'output':
|
| 1647 |
+
outputs_seen_counter += 1
|
| 1648 |
+
results_node = node
|
| 1649 |
+
|
| 1650 |
+
return results_node
|
| 1651 |
+
|
| 1652 |
+
def _run_prepare_fx_on_standalone_modules(
|
| 1653 |
+
model: torch.nn.Module,
|
| 1654 |
+
is_qat: bool,
|
| 1655 |
+
named_modules: Dict[str, torch.nn.Module],
|
| 1656 |
+
node_name_to_match_result_with_qconfig: Any,
|
| 1657 |
+
prepare_custom_config: PrepareCustomConfig,
|
| 1658 |
+
backend_config: BackendConfig,
|
| 1659 |
+
) -> None:
|
| 1660 |
+
"""
|
| 1661 |
+
Runs prepare_fx on each standalone module. Note: this does
|
| 1662 |
+
not modify the graph, it just replaces the unobserved modules with
|
| 1663 |
+
their observed versions.
|
| 1664 |
+
"""
|
| 1665 |
+
for (root_node, _, pattern, qhandler, qconfig) in node_name_to_match_result_with_qconfig.values():
|
| 1666 |
+
if qhandler is None:
|
| 1667 |
+
continue
|
| 1668 |
+
elif not qhandler.is_standalone_module():
|
| 1669 |
+
continue
|
| 1670 |
+
|
| 1671 |
+
sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config, \
|
| 1672 |
+
sm_backend_config = _get_standalone_module_configs(
|
| 1673 |
+
root_node, named_modules, prepare_custom_config, qconfig, backend_config)
|
| 1674 |
+
|
| 1675 |
+
standalone_module = named_modules[root_node.target]
|
| 1676 |
+
prepare = \
|
| 1677 |
+
torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined]
|
| 1678 |
+
observed_standalone_module = \
|
| 1679 |
+
prepare(
|
| 1680 |
+
standalone_module,
|
| 1681 |
+
sm_qconfig_mapping,
|
| 1682 |
+
is_qat,
|
| 1683 |
+
example_inputs=sm_example_inputs,
|
| 1684 |
+
prepare_custom_config=sm_prepare_custom_config,
|
| 1685 |
+
backend_config=sm_backend_config)
|
| 1686 |
+
parent_name, name = _parent_name(root_node.target)
|
| 1687 |
+
setattr(named_modules[parent_name], name, observed_standalone_module)
|
| 1688 |
+
named_modules[root_node.target] = observed_standalone_module
|
| 1689 |
+
|
| 1690 |
+
def _save_state(
|
| 1691 |
+
observed: GraphModule,
|
| 1692 |
+
node_name_to_qconfig: Dict[str, QConfigAny],
|
| 1693 |
+
node_name_to_scope: Dict[str, Tuple[str, type]],
|
| 1694 |
+
prepare_custom_config: PrepareCustomConfig,
|
| 1695 |
+
equalization_node_name_to_qconfig: Dict[str, Any],
|
| 1696 |
+
qconfig_mapping: QConfigMapping,
|
| 1697 |
+
is_qat: bool,
|
| 1698 |
+
observed_node_names: Set[str],
|
| 1699 |
+
) -> None:
|
| 1700 |
+
observed.meta["_observed_graph_module_attrs"] = (
|
| 1701 |
+
ObservedGraphModuleAttrs(
|
| 1702 |
+
node_name_to_qconfig=node_name_to_qconfig,
|
| 1703 |
+
node_name_to_scope=node_name_to_scope,
|
| 1704 |
+
prepare_custom_config=prepare_custom_config,
|
| 1705 |
+
equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
|
| 1706 |
+
qconfig_mapping=qconfig_mapping,
|
| 1707 |
+
is_qat=is_qat,
|
| 1708 |
+
observed_node_names=observed_node_names,
|
| 1709 |
+
)
|
| 1710 |
+
)
|
| 1711 |
+
|
| 1712 |
+
def prepare(
|
| 1713 |
+
model: GraphModule,
|
| 1714 |
+
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
|
| 1715 |
+
is_qat: bool,
|
| 1716 |
+
node_name_to_scope: Dict[str, Tuple[str, type]],
|
| 1717 |
+
example_inputs: Tuple[Any, ...],
|
| 1718 |
+
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
|
| 1719 |
+
_equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
|
| 1720 |
+
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
|
| 1721 |
+
is_standalone_module: bool = False) -> GraphModule:
|
| 1722 |
+
""" standalone_module means it a submodule that is not inlined in
|
| 1723 |
+
parent module, and will be quantized separately as one unit.
|
| 1724 |
+
|
| 1725 |
+
How the standalone module is observed is specified by `input_quantized_idxs` and
|
| 1726 |
+
`output_quantized_idxs` in the prepare_custom_config for the standalone module
|
| 1727 |
+
Args:
|
| 1728 |
+
node_name_to_scope: mapping from node name to the scope of the module which contains the node.
|
| 1729 |
+
The scope is a tuple of fully qualified path of the module and the type of the module
|
| 1730 |
+
Returns:
|
| 1731 |
+
model(GraphModule): prepared standalone module
|
| 1732 |
+
attributes related to standalone module
|
| 1733 |
+
in model.meta["_observed_graph_module_attrs"]:
|
| 1734 |
+
is_observed_standalone_module (bool): boolean value that shows whether the
|
| 1735 |
+
current model is a observed standalone module or not
|
| 1736 |
+
standalone_module_input_quantized_idxs(List[Int]): a list of
|
| 1737 |
+
indexes for the graph input that is expected to be quantized,
|
| 1738 |
+
same as input_quantized_idxs configuration provided
|
| 1739 |
+
for the standalone module
|
| 1740 |
+
standalone_module_output_quantized_idxs(List[Int]): a list of
|
| 1741 |
+
indexs for the graph output that is quantized
|
| 1742 |
+
same as input_quantized_idxs configuration provided
|
| 1743 |
+
for the standalone module
|
| 1744 |
+
"""
|
| 1745 |
+
if prepare_custom_config is None:
|
| 1746 |
+
prepare_custom_config = PrepareCustomConfig()
|
| 1747 |
+
if _equalization_config is None:
|
| 1748 |
+
_equalization_config = QConfigMapping()
|
| 1749 |
+
|
| 1750 |
+
if isinstance(qconfig_mapping, Dict):
|
| 1751 |
+
warnings.warn(
|
| 1752 |
+
"Passing a QConfig dictionary to prepare is deprecated and will not be supported "
|
| 1753 |
+
"in a future version. Please pass in a QConfigMapping instead.")
|
| 1754 |
+
qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)
|
| 1755 |
+
|
| 1756 |
+
if isinstance(_equalization_config, Dict):
|
| 1757 |
+
warnings.warn(
|
| 1758 |
+
"Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
|
| 1759 |
+
"be supported in a future version. Please pass in a QConfigMapping instead.")
|
| 1760 |
+
_equalization_config = QConfigMapping.from_dict(_equalization_config)
|
| 1761 |
+
|
| 1762 |
+
if isinstance(prepare_custom_config, Dict):
|
| 1763 |
+
warnings.warn(
|
| 1764 |
+
"Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
|
| 1765 |
+
"in a future version. Please pass in a PrepareCustomConfig instead.")
|
| 1766 |
+
prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
|
| 1767 |
+
|
| 1768 |
+
if isinstance(backend_config, Dict):
|
| 1769 |
+
warnings.warn(
|
| 1770 |
+
"Passing a backend_config_dict to prepare is deprecated and will not be supported "
|
| 1771 |
+
"in a future version. Please pass in a BackendConfig instead.")
|
| 1772 |
+
backend_config = BackendConfig.from_dict(backend_config)
|
| 1773 |
+
|
| 1774 |
+
assert isinstance(qconfig_mapping, QConfigMapping)
|
| 1775 |
+
assert isinstance(_equalization_config, QConfigMapping)
|
| 1776 |
+
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
| 1777 |
+
_equalization_config = copy.deepcopy(_equalization_config)
|
| 1778 |
+
|
| 1779 |
+
# mapping from a tuple of nodes in reverse order to uninitialized
|
| 1780 |
+
# QuantizeHandler subclass. For example,
|
| 1781 |
+
# {
|
| 1782 |
+
# # match a single node
|
| 1783 |
+
# (<class 'torch.nn.modules.conv.Conv3d'>:
|
| 1784 |
+
# <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
|
| 1785 |
+
# # match multiple nodes in reverse order
|
| 1786 |
+
# ((<function relu at 0x7f766a7360d0>, <built-in function add>):
|
| 1787 |
+
# <class 'torch.ao.quantization.fx.quantize.Add'>),
|
| 1788 |
+
# }
|
| 1789 |
+
|
| 1790 |
+
pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
|
| 1791 |
+
if backend_config is None:
|
| 1792 |
+
backend_config = get_native_backend_config()
|
| 1793 |
+
pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
|
| 1794 |
+
pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)
|
| 1795 |
+
|
| 1796 |
+
root_node_getter_mapping = \
|
| 1797 |
+
get_fusion_pattern_to_root_node_getter(backend_config)
|
| 1798 |
+
|
| 1799 |
+
_update_qconfig_for_fusion(model, qconfig_mapping)
|
| 1800 |
+
_update_qconfig_for_fusion(model, _equalization_config)
|
| 1801 |
+
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
|
| 1802 |
+
# TODO: support regex as well
|
| 1803 |
+
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
| 1804 |
+
|
| 1805 |
+
if is_qat:
|
| 1806 |
+
module_to_qat_module = get_module_to_qat_module(backend_config)
|
| 1807 |
+
_qat_swap_modules(model, module_to_qat_module)
|
| 1808 |
+
_update_qconfig_for_qat(qconfig_mapping, backend_config)
|
| 1809 |
+
|
| 1810 |
+
# mapping from fully qualified module name to module instance
|
| 1811 |
+
# for example,
|
| 1812 |
+
# {
|
| 1813 |
+
# '': Model(...),
|
| 1814 |
+
# 'linear': Linear(...),
|
| 1815 |
+
# 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
|
| 1816 |
+
# }
|
| 1817 |
+
named_modules = dict(model.named_modules(remove_duplicate=False))
|
| 1818 |
+
|
| 1819 |
+
# fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
|
| 1820 |
+
equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
|
| 1821 |
+
model, named_modules, model.graph, _equalization_config, node_name_to_scope)
|
| 1822 |
+
node_name_to_qconfig = _generate_node_name_to_qconfig(model, named_modules, model.graph, qconfig_mapping, node_name_to_scope)
|
| 1823 |
+
|
| 1824 |
+
# match the patterns that will get quantized
|
| 1825 |
+
standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
|
| 1826 |
+
standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())
|
| 1827 |
+
|
| 1828 |
+
custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
| 1829 |
+
matches_without_qconfig = _find_matches(
|
| 1830 |
+
model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping,
|
| 1831 |
+
standalone_module_names, standalone_module_classes, custom_module_classes)
|
| 1832 |
+
|
| 1833 |
+
# map qconfig instances to matches
|
| 1834 |
+
node_name_to_match_result_with_qconfig = {}
|
| 1835 |
+
for node_name, match_without_qconfig in matches_without_qconfig.items():
|
| 1836 |
+
match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
|
| 1837 |
+
node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig
|
| 1838 |
+
|
| 1839 |
+
_run_prepare_fx_on_standalone_modules(
|
| 1840 |
+
model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config)
|
| 1841 |
+
|
| 1842 |
+
# record names for the set of observed node, so that in convert step
|
| 1843 |
+
# we know whether we need to convert a floating point module to reference
|
| 1844 |
+
# quantized module or not
|
| 1845 |
+
observed_node_names: Set[str] = set()
|
| 1846 |
+
|
| 1847 |
+
result_node = insert_observers_for_model(
|
| 1848 |
+
model,
|
| 1849 |
+
node_name_to_match_result_with_qconfig,
|
| 1850 |
+
node_name_to_qconfig,
|
| 1851 |
+
prepare_custom_config,
|
| 1852 |
+
equalization_node_name_to_qconfig,
|
| 1853 |
+
backend_config,
|
| 1854 |
+
observed_node_names,
|
| 1855 |
+
is_qat,
|
| 1856 |
+
)
|
| 1857 |
+
model = GraphModule(model, model.graph)
|
| 1858 |
+
|
| 1859 |
+
_save_state(model, node_name_to_qconfig, node_name_to_scope,
|
| 1860 |
+
prepare_custom_config, equalization_node_name_to_qconfig,
|
| 1861 |
+
qconfig_mapping, is_qat, observed_node_names)
|
| 1862 |
+
|
| 1863 |
+
if is_standalone_module:
|
| 1864 |
+
assert result_node is not None
|
| 1865 |
+
assert isinstance(result_node.args[0], Node), \
|
| 1866 |
+
"standalone module only supports returning simple value currently"\
|
| 1867 |
+
"(not tuple, dict etc.)"
|
| 1868 |
+
# these inputs are observed in parent
|
| 1869 |
+
# converting List[int] to Tensor since module attribute is
|
| 1870 |
+
# Union[Tensor, Module]
|
| 1871 |
+
input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
|
| 1872 |
+
output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
|
| 1873 |
+
observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
|
| 1874 |
+
# inplace modification
|
| 1875 |
+
observed_graph_module_attrs.is_observed_standalone_module = True
|
| 1876 |
+
observed_graph_module_attrs.standalone_module_input_quantized_idxs = \
|
| 1877 |
+
input_quantized_idxs
|
| 1878 |
+
observed_graph_module_attrs.standalone_module_output_quantized_idxs = \
|
| 1879 |
+
output_quantized_idxs
|
| 1880 |
+
return model
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc
ADDED
|
Binary file (4.62 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc
ADDED
|
Binary file (38 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (26.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import operator
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from torch.ao.quantization.pt2e.utils import (
|
| 7 |
+
_filter_sym_size_users,
|
| 8 |
+
_is_valid_annotation,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from torch.fx.node import map_arg
|
| 12 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Module-level logger; warnings and above only.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

__all__ = ["DuplicateDQPass"]

# Quantize ops (decomposed `quantized_decomposed` namespace) recognized when
# checking whether a dequantize node is fed by a quantize node.
_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

# Dequantize ops that are candidates for duplication by this pass.
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _maybe_duplicate_dq(
|
| 34 |
+
gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
|
| 35 |
+
):
|
| 36 |
+
annotation = user.meta.get("quantization_annotation", None)
|
| 37 |
+
if not _is_valid_annotation(annotation):
|
| 38 |
+
return
|
| 39 |
+
with gm.graph.inserting_after(dq_node):
|
| 40 |
+
new_node = gm.graph.node_copy(dq_node)
|
| 41 |
+
|
| 42 |
+
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
|
| 43 |
+
if n == dq_node:
|
| 44 |
+
return new_node
|
| 45 |
+
else:
|
| 46 |
+
return n
|
| 47 |
+
|
| 48 |
+
new_args = map_arg(user.args, maybe_replace_node)
|
| 49 |
+
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
|
| 50 |
+
user.args = new_args
|
| 51 |
+
user.kwargs = new_kwargs
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DuplicateDQPass(PassBase):
    """Graph pass that duplicates a dequantize node once per (annotated) user.

    A dequantize node shared by several consumers prevents each consumer from
    being fused/lowered with its own q-dq pair; duplicating it makes every
    consumer self-contained. Dequantize nodes that terminate a dynamic
    quantization chain (choose_qparams -> getitem -> quantize -> dequantize)
    are intentionally left alone.
    """

    @staticmethod
    def _is_dynamic_quant_dq(dq_node: torch.fx.Node) -> bool:
        """Return True when `dq_node` sits at the end of the dynamic
        quantization pattern: choose_qparam - getitem - q - dq."""
        producer = dq_node.args[0]
        if not (producer.op == "call_function" and producer.target in _QUANTIZE_OPS):
            return False
        maybe_getitem = producer.args[1]
        if not (
            isinstance(maybe_getitem, torch.fx.node.Node)
            and maybe_getitem.op == "call_function"
            and maybe_getitem.target == operator.getitem
        ):
            return False
        maybe_choose = maybe_getitem.args[0]
        return (
            isinstance(maybe_choose, torch.fx.node.Node)
            and maybe_choose.op == "call_function"
            and maybe_choose.target
            == torch.ops.quantized_decomposed.choose_qparams.tensor
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in _DEQUANTIZE_OPS:
                continue
            dq_users = _filter_sym_size_users(node)
            # A single user needs no duplication.
            if len(dq_users) <= 1:
                continue
            # Do not duplicate dq for dynamic quantization
            # Pattern: choose_qparam - getitem - q - dq
            if self._is_dynamic_quant_dq(node):
                continue
            for user in dq_users:
                _maybe_duplicate_dq(graph_module, node, user)
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Public surface of this module.
# NOTE(review): `_WrapperModule` is underscore-prefixed yet exported —
# presumably consumed by sibling pt2e modules; confirm before renaming.
__all__ = [
    "model_is_exported",
    "_WrapperModule",
]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class _WrapperModule(torch.nn.Module):
|
| 14 |
+
"""Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
|
| 15 |
+
are trying to export a callable.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, fn):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.fn = fn
|
| 21 |
+
|
| 22 |
+
def forward(self, *args, **kwargs):
|
| 23 |
+
"""Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
|
| 24 |
+
return self.fn(*args, **kwargs)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def model_is_exported(m: torch.nn.Module) -> bool:
    """
    Return True if the `torch.nn.Module` was exported, False otherwise
    (e.g. if the model was FX symbolically traced or not traced at all).
    """
    # Anything that is not a GraphModule was certainly not exported.
    if not isinstance(m, torch.fx.GraphModule):
        return False
    # Exported graphs carry fake-tensor values under the "val" meta key;
    # plain symbolic traces do not.
    return any("val" in node.meta for node in m.graph.nodes)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch dropout patterns in the model between train and eval modes.

    Dropout has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the dropout behavior between the two modes, so here we need to rewrite the aten
    dropout patterns manually to achieve the same effect.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import get_aten_graph_module

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # Handle the out-of-place and inplace dropout variants separately, since
    # they trace to different aten patterns.
    for inplace in [False, True]:

        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        # p=0.5 is only a placeholder literal for tracing the pattern; literal
        # values are not required to match because ignore_literals=True below.
        example_inputs = (torch.randn(1),)
        if train_to_eval:
            match_pattern = get_aten_graph_module(
                _WrapperModule(dropout_train), example_inputs
            )
            replacement_pattern = get_aten_graph_module(
                _WrapperModule(dropout_eval), example_inputs
            )
        else:
            # Reverse direction: rewrite eval-mode dropout back to train mode.
            match_pattern = get_aten_graph_module(
                _WrapperModule(dropout_eval), example_inputs
            )
            replacement_pattern = get_aten_graph_module(
                _WrapperModule(dropout_train), example_inputs
            )

        from torch.fx.subgraph_rewriter import replace_pattern_with_filters

        replace_pattern_with_filters(
            m,
            match_pattern,
            replacement_pattern,
            match_filters=[],
            ignore_literals=True,
        )
        m.recompile()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Switch batchnorm patterns in the model between train and eval modes.

    Batchnorm has different behavior in train vs eval mode. For exported models,
    however, calling `model.train()` or `model.eval()` does not automatically switch
    the batchnorm behavior between the two modes, so here we need to rewrite the aten
    batchnorm patterns manually to achieve the same effect.
    """
    # TODO(Leslie): This function still fails to support custom momentum and eps value.
    # Enable this support in future updates.

    # Avoid circular dependencies
    from .utils import get_aten_graph_module

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # Pattern endpoints: the same functional batch_norm call, differing only
    # in the `training` flag.
    def bn_train(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )

    def bn_eval(
        x: torch.Tensor,
        bn_weight: torch.Tensor,
        bn_bias: torch.Tensor,
        bn_running_mean: torch.Tensor,
        bn_running_var: torch.Tensor,
    ):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
        )

    # Minimal shapes sufficient to trace the pattern; actual sizes in the
    # target graph are matched structurally, not by these literals.
    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),  # bn_weight
        torch.randn(1),  # bn_bias
        torch.randn(1),  # bn_running_mean
        torch.randn(1),  # bn_running_var
    )
    if train_to_eval:
        match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs)
        replacement_pattern = get_aten_graph_module(
            _WrapperModule(bn_eval), example_inputs
        )
    else:
        match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs)
        replacement_pattern = get_aten_graph_module(
            _WrapperModule(bn_train), example_inputs
        )

    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# TODO: expose these under this namespace?
|
| 163 |
+
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
    QAT users should call this before performing inference on the model.
    """
    # Rewrite each mode-sensitive op pattern from its train form to its eval form.
    for _rewrite in (_replace_dropout, _replace_batchnorm):
        _rewrite(model, train_to_eval=True)
    return model
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _move_exported_model_to_train(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to train mode.

    This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
    QAT users should call this before performing training on the model.
    """
    # Rewrite each mode-sensitive op pattern from its eval form to its train form.
    for _rewrite in (_replace_dropout, _replace_batchnorm):
        _rewrite(model, train_to_eval=False)
    return model
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
|
| 188 |
+
"""
|
| 189 |
+
Allow users to call `model.train()` and `model.eval()` on an exported model,
|
| 190 |
+
but with the effect of changing behavior between the two modes limited to special
|
| 191 |
+
ops only, which are currently dropout and batchnorm.
|
| 192 |
+
|
| 193 |
+
Note: This does not achieve the same effect as what `model.train()` and `model.eval()`
|
| 194 |
+
does in eager models, but only provides an approximation. In particular, user code
|
| 195 |
+
branching on `training` flag will not function correctly in general because the branch
|
| 196 |
+
is already specialized at export time. Additionally, other ops beyond dropout and batchnorm
|
| 197 |
+
that have different train/eval behavior will also not be converted properly.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def _train(self, mode: bool = True):
|
| 201 |
+
if mode:
|
| 202 |
+
_move_exported_model_to_train(self)
|
| 203 |
+
else:
|
| 204 |
+
_move_exported_model_to_eval(self)
|
| 205 |
+
|
| 206 |
+
def _eval(self):
|
| 207 |
+
_move_exported_model_to_eval(self)
|
| 208 |
+
|
| 209 |
+
model.train = types.MethodType(_train, model) # type: ignore[method-assign]
|
| 210 |
+
model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
|
| 211 |
+
return model
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch._subclasses import FakeTensor
|
| 3 |
+
from torch.ao.quantization.fx.prepare import (
|
| 4 |
+
_insert_obs_or_fq,
|
| 5 |
+
_save_state,
|
| 6 |
+
_is_activation_post_process_node,
|
| 7 |
+
_create_obs_or_fq_from_qspec,
|
| 8 |
+
)
|
| 9 |
+
from torch.fx import (
|
| 10 |
+
GraphModule,
|
| 11 |
+
Graph,
|
| 12 |
+
Node,
|
| 13 |
+
)
|
| 14 |
+
from torch.fx.node import Argument
|
| 15 |
+
|
| 16 |
+
from torch.ao.quantization import QConfigMapping
|
| 17 |
+
from torch.ao.quantization.qconfig import QConfigAny
|
| 18 |
+
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
|
| 19 |
+
from typing import Dict, Tuple, Union, Any, Optional
|
| 20 |
+
from torch.ao.quantization.quantizer import (
|
| 21 |
+
EdgeOrNode,
|
| 22 |
+
SharedQuantizationSpec,
|
| 23 |
+
QuantizationSpecBase,
|
| 24 |
+
)
|
| 25 |
+
from torch.ao.quantization import ObserverOrFakeQuantize
|
| 26 |
+
|
| 27 |
+
# TODO: make pt2e folder private?
# `prepare` is the only public entry point of this module; everything else
# (union-find helpers, observer insertion) is internal.
__all__ = [
    "prepare",
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
|
| 34 |
+
"""Find the root node for the sharing tree
|
| 35 |
+
Args:
|
| 36 |
+
edge_or_node: edge/node that we want to find the root
|
| 37 |
+
shared_with_map: each edge/node points to the parent, the root node will points to itself
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
root edge/node
|
| 41 |
+
"""
|
| 42 |
+
parent = shared_with_map[edge_or_node]
|
| 43 |
+
if parent == edge_or_node:
|
| 44 |
+
return edge_or_node
|
| 45 |
+
root = _find_root_edge_or_node(parent, shared_with_map)
|
| 46 |
+
# path compression
|
| 47 |
+
shared_with_map[edge_or_node] = root
|
| 48 |
+
return root
|
| 49 |
+
|
| 50 |
+
def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge the subtree for `child` with `parent`, the order is important here
    """
    # Resolve both roots (parent first, matching the documented ordering),
    # then hang the child's root off the parent's root.
    parent_root = _find_root_edge_or_node(parent, shared_with_map)
    shared_with_map[_find_root_edge_or_node(child, shared_with_map)] = parent_root
|
| 57 |
+
|
| 58 |
+
def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
    configuration and established the relationship between `edge_or_node` with the edge/node that it
    is pointing to, we'll use this information in the end to get the group id
    """
    # Only SharedQuantizationSpec expresses a sharing relationship; any other
    # qspec kind leaves the map untouched.
    if not isinstance(qspec, SharedQuantizationSpec):
        return
    # qspec for `a` = SharedQuantizationSpec(b) means `a` points to `b`.
    _union(qspec.edge_or_node, child, shared_with_map)
|
| 68 |
+
|
| 69 |
+
def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
    if qspec is SharedQuantizationSpec
    (1). tries to find the root edge or node for the node that the qspec points to
    (2). recursively find the root qspec based on the qspec for the root node
    """
    # Iterative unwrap: keep chasing SharedQuantizationSpec links until a
    # concrete (non-shared) qspec is reached.
    while isinstance(qspec, SharedQuantizationSpec):
        root = _find_root_edge_or_node(qspec.edge_or_node, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
    return qspec
|
| 85 |
+
|
| 86 |
+
def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
|
| 87 |
+
return (
|
| 88 |
+
hasattr(qspec_a, "dtype") and
|
| 89 |
+
hasattr(qspec_b, "dtype") and
|
| 90 |
+
qspec_a.dtype == qspec_b.dtype
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
|
| 94 |
+
return (
|
| 95 |
+
hasattr(qspec_a, "is_dynamic") and
|
| 96 |
+
hasattr(qspec_b, "is_dynamic") and
|
| 97 |
+
qspec_a.is_dynamic == qspec_b.is_dynamic
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
|
| 101 |
+
"""Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
|
| 102 |
+
"""
|
| 103 |
+
edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
|
| 104 |
+
for n in model.graph.nodes:
|
| 105 |
+
if hasattr(n, "meta") and "quantization_annotation" in n.meta:
|
| 106 |
+
qa = n.meta["quantization_annotation"]
|
| 107 |
+
for input_to_n, qspec in qa.input_qspec_map.items():
|
| 108 |
+
input_edge = (input_to_n, n)
|
| 109 |
+
edge_or_node_to_qspec[input_edge] = qspec
|
| 110 |
+
if qa.output_qspec is not None:
|
| 111 |
+
output_node = n
|
| 112 |
+
qspec = qa.output_qspec
|
| 113 |
+
edge_or_node_to_qspec[output_node] = qspec
|
| 114 |
+
return edge_or_node_to_qspec
|
| 115 |
+
|
| 116 |
+
def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
    """Union input edge with another edge or node, used in implicit sharing to point the current input
    edge to other user edges of the producer node, or the output of producer node since these are
    referring to the same Tensor
    """
    # Nothing to share with when the candidate carries no annotation.
    if edge_or_node not in edge_or_node_to_qspec:
        return
    candidate_root_qspec = _unwrap_shared_qspec(
        edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec, shared_with_map
    )
    # TODO: add assertions for types of root qspecs
    if (
        _has_same_dtype(candidate_root_qspec, input_edge_root_qspec)
        and _has_same_is_dynamic(candidate_root_qspec, input_edge_root_qspec)
    ):
        # the input arg to the node should reuse the existing output observer for arg
        # since dtype is the same (we may want to extend this to be a more strict check
        # in the future)
        # so we point from `input_edge` to `arg` (output of the argument)
        _union(edge_or_node, input_edge, shared_with_map)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations,
    edge/node with the same group ID should use the same observer/fake_quant instance

    This is applying SharedQuantizationSpec configuration and map each edge/node to a group
    There is another implicit sharing that's built in the quantization, when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.

    Figuring out the correct group ID for all edge/node is a standard union find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that
        belongs to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
         op1 /        /
             op3
        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspc,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
        # connects the two sharing group around cat1 and cat2 op due to transitive sharing
    """
    # means the observer of key should be shared with observer with value, by default it will
    # be shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            # A bare Node key is an op output annotation.
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            # A tuple key is an (arg, consumer) input edge annotation.
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)

            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here, we first share with other users and then share with previous
                # output because the reverse order could cause circular dependency
                # e.g node1 -> node2
                #          \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
                # have a circular dependency
                # the following order works around this issue, but this does not allow arbitrary configuration
                # of sharing so it might break in a different case in the future, when it breaks
                # quantizer writer can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map
                    )

                # sharing with output of producer node
                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)

            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we get the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            # First time this root is seen: allocate a fresh group id.
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id
|
| 248 |
+
|
| 249 |
+
def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the EdgeOrNode to observer/fake_quant instances
    Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
    instances
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    # Cache of the single instance created per group id.
    instance_by_group: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id in instance_by_group:
            shared_instance = instance_by_group[group_id]
        else:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            shared_instance = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
            instance_by_group[group_id] = shared_instance
        obs_or_fq_map[edge_or_node] = shared_instance
    return obs_or_fq_map
|
| 268 |
+
|
| 269 |
+
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Returns the (possibly replaced) argument: either the original `arg`, an
    existing observer node that already wraps it, or a freshly inserted one.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            # Recurse on each element; container type is preserved below.
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    # Non-Node arguments (constants, literals) are returned untouched.
    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    # The sharing map is keyed on the *original* producer, not any observer
    # already layered on top of it.
    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    # NOTE(review): `existing_obs_node` is assigned but never used below —
    # candidate for cleanup.
    existing_obs_node = None
    # skip inserting new observers if the same observer instance is inserted before for another user
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                 \ -> conv3
    #
    # instead of inserting new observers we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                 \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        # Identity comparison: reuse only when it is literally the same
        # observer instance selected for this input edge.
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    return new_arg
|
| 339 |
+
|
| 340 |
+
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """Insert observers/fake-quants in front of `node`'s inputs where required.

    Rewrites ``node`` in place: each positional argument that needs
    observation is rerouted through a (possibly newly inserted)
    observer/fake-quant node, i.e.

        prev_node -> cur_node   becomes   prev_node -> obs -> cur_node
    """
    # Per-occurrence list of (possibly observed) args, plus an old-arg ->
    # new-arg mapping used to retarget the numeric debug handle metadata.
    observed_args = []
    arg_replacements = {}
    for original_arg in node.args:
        observed = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, original_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        observed_args.append(observed)
        arg_replacements[original_arg] = observed

    if "numeric_debug_handle" in node.meta:
        # Re-key the debug handles so they follow the inserted observers.
        old_handle = node.meta["numeric_debug_handle"]
        node.meta["numeric_debug_handle"] = {
            arg_replacements.get(key, key): value for key, value in old_handle.items()
        }

    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
    # that persist in the exported graph; every other aten op in this IR is
    # expected to carry no kwargs.
    assert (
        node.target in (torch.ops.aten.clone.default, torch.ops.aten.zeros_like.default)
        or len(node.kwargs) == 0
    ), " expecting kwargs for aten op IR to be empty"

    # Assign the rewritten args back onto the node, in place.
    node.args = tuple(observed_args)
|
| 391 |
+
|
| 392 |
+
def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    """Insert an output observer/fake-quant after `node` when one is assigned.

    Returns the freshly inserted observer node, or ``None`` when `node`'s
    output does not appear in ``obs_or_fq_map``.
    """
    if node not in obs_or_fq_map:
        return None
    return _insert_obs_or_fq(node, obs_or_fq_map[node], model, named_modules, graph)
|
| 404 |
+
|
| 405 |
+
def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    """Observe `node`'s inputs and, when it produces a tensor, its output."""
    # Nodes without a quantization annotation are left untouched.
    if node.meta.get("quantization_annotation") is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    # Only tensor-valued outputs ("val" holding a FakeTensor) can be observed.
    if not isinstance(node.meta.get("val"), FakeTensor):
        return

    # Returns the new observer node if one was needed, else None.
    output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
    if output_obs_node is None:
        return

    # Reroute every original consumer of `node` (except the observer itself)
    # through the new observer, so downstream nodes read the observed value.
    # Snapshot the user list first, because it mutates as edges are rewritten.
    for consumer in list(node.users):
        if consumer is not output_obs_node:
            consumer.replace_input_with(node, output_obs_node)
|
| 455 |
+
|
| 456 |
+
def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    """Insert observers/fake-quants into `model` and return the prepared module."""
    # Observer insertion mutates the graph, so walk a snapshot of the node
    # list taken before any rewriting happens instead of model.graph.nodes.
    original_nodes = list(model.graph.nodes)

    # Construct one observer/fake-quant instance per sharing group: every
    # edge/node in the same group maps to the same instance, so insertion
    # below is just a dictionary lookup.
    qspec_map = _get_edge_or_node_to_qspec(model)
    group_id_map = _get_edge_or_node_to_group_id(qspec_map)
    obs_or_fq_map = _get_obs_or_fq_map(group_id_map, qspec_map, is_qat)

    for node in original_nodes:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    # Re-trace so the freshly inserted submodules/nodes are materialized.
    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set(),  # observed_node_names
    )
    return model
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Re-export the reference-representation rewrite as this package's sole
# public API.
from .rewrite import reference_representation_rewrite

__all__ = [
    "reference_representation_rewrite",
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from typing import Any, Callable, Dict, Tuple, Union, List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .fake_quantize import (
|
| 8 |
+
default_weight_fake_quant,
|
| 9 |
+
FixedQParamsFakeQuantize,
|
| 10 |
+
)
|
| 11 |
+
from .observer import (
|
| 12 |
+
_PartialWrapper,
|
| 13 |
+
default_fixed_qparams_range_0to1_observer,
|
| 14 |
+
default_fixed_qparams_range_neg1to1_observer,
|
| 15 |
+
default_placeholder_observer,
|
| 16 |
+
default_weight_observer,
|
| 17 |
+
)
|
| 18 |
+
from .qconfig import (
|
| 19 |
+
default_reuse_input_qconfig,
|
| 20 |
+
default_symmetric_qnnpack_qconfig,
|
| 21 |
+
default_symmetric_qnnpack_qat_qconfig,
|
| 22 |
+
get_default_qconfig,
|
| 23 |
+
get_default_qat_qconfig,
|
| 24 |
+
QConfig,
|
| 25 |
+
QConfigAny,
|
| 26 |
+
default_quint8_weight_qconfig
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Public API of this module.
__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]


# Dictionary keys used by QConfigMapping.to_dict() / from_dict().
# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

# Modules, functionals, and Tensor-method names whose outputs are quantized
# with fixed quantization parameters, mapped to the observer carrying those
# fixed qparams: range [0, 1] for sigmoid/hardsigmoid/softmax, [-1, 1] for
# tanh (per the observer names).
# TODO: derive this map from the BackendConfig
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping:
|
| 63 |
+
"""
|
| 64 |
+
Return the default QConfigMapping for the given quantization type and backend.
|
| 65 |
+
"""
|
| 66 |
+
if is_qat:
|
| 67 |
+
qconfig = get_default_qat_qconfig(backend, version)
|
| 68 |
+
else:
|
| 69 |
+
qconfig = get_default_qconfig(backend, version)
|
| 70 |
+
default_weight = default_weight_fake_quant if is_qat else default_weight_observer
|
| 71 |
+
|
| 72 |
+
# default_per_channel_weight_observer is not currently compatible with fbgemm backend
|
| 73 |
+
# so we have to modify the weight observer to default_weight_observer or another
|
| 74 |
+
# per tensor supported observer.
|
| 75 |
+
# see https://github.com/pytorch/pytorch/issues/47535
|
| 76 |
+
if backend in ("fbgemm", "x86"):
|
| 77 |
+
qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
|
| 78 |
+
else:
|
| 79 |
+
qconfig_transpose = qconfig
|
| 80 |
+
|
| 81 |
+
# currently layernorm only supports float weights
|
| 82 |
+
# we have to add this because otherwise there will be a extra quantize-dequantize pair
|
| 83 |
+
qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)
|
| 84 |
+
|
| 85 |
+
qconfig_mapping = QConfigMapping() \
|
| 86 |
+
.set_global(qconfig) \
|
| 87 |
+
.set_object_type("reshape", default_reuse_input_qconfig) \
|
| 88 |
+
.set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
|
| 89 |
+
.set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
|
| 90 |
+
.set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
|
| 91 |
+
.set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
|
| 92 |
+
.set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
|
| 93 |
+
.set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
|
| 94 |
+
.set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
|
| 95 |
+
.set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
|
| 96 |
+
.set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \
|
| 97 |
+
|
| 98 |
+
# Use special observers for ops with fixed qparams
|
| 99 |
+
fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
|
| 100 |
+
for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
|
| 101 |
+
if observer in fixed_qparams_observer_to_qconfig:
|
| 102 |
+
fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
|
| 103 |
+
else:
|
| 104 |
+
if is_qat:
|
| 105 |
+
activation = FixedQParamsFakeQuantize.with_args(observer=observer)
|
| 106 |
+
else:
|
| 107 |
+
activation = observer
|
| 108 |
+
fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight)
|
| 109 |
+
fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
|
| 110 |
+
qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)
|
| 111 |
+
|
| 112 |
+
# TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
|
| 113 |
+
# Need to be able to support fusion of ops with different qconfigs
|
| 114 |
+
|
| 115 |
+
return qconfig_mapping
|
| 116 |
+
|
| 117 |
+
def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping:
    """Build the default QConfigMapping for post training quantization.

    Args:
        * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
           one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
        * ``version`` (int) : the version for the default qconfig mapping
    """
    # TODO: add assert for backend choices
    return _get_default_qconfig_mapping(is_qat=False, backend=backend, version=version)
|
| 128 |
+
|
| 129 |
+
def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping:
    """Build the default QConfigMapping for quantization aware training.

    Args:
        * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
           one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
        * ``version`` (int) : the version for the default qconfig mapping
    """
    return _get_default_qconfig_mapping(is_qat=True, backend=backend, version=version)
|
| 139 |
+
|
| 140 |
+
def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
|
| 141 |
+
"""
|
| 142 |
+
Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
|
| 143 |
+
as the default QConfig.
|
| 144 |
+
"""
|
| 145 |
+
default_qconfig = default_symmetric_qnnpack_qconfig
|
| 146 |
+
return _get_default_qconfig_mapping_with_default_qconfig(False, "qnnpack", default_qconfig)
|
| 147 |
+
|
| 148 |
+
def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
|
| 149 |
+
"""
|
| 150 |
+
Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
|
| 151 |
+
as the default QConfig.
|
| 152 |
+
"""
|
| 153 |
+
default_qconfig = default_symmetric_qnnpack_qat_qconfig
|
| 154 |
+
return _get_default_qconfig_mapping_with_default_qconfig(True, "qnnpack", default_qconfig)
|
| 155 |
+
|
| 156 |
+
def _get_default_qconfig_mapping_with_default_qconfig(
|
| 157 |
+
is_qat: bool,
|
| 158 |
+
backend: str,
|
| 159 |
+
default_qconfig: QConfig,
|
| 160 |
+
) -> QConfigMapping:
|
| 161 |
+
"""
|
| 162 |
+
Return a QConfigMapping that uses the provided qconfig as the default QConfig.
|
| 163 |
+
"""
|
| 164 |
+
if is_qat:
|
| 165 |
+
qconfig_mapping = get_default_qat_qconfig_mapping(backend)
|
| 166 |
+
else:
|
| 167 |
+
qconfig_mapping = get_default_qconfig_mapping(backend)
|
| 168 |
+
qconfig_mapping.set_global(default_qconfig)
|
| 169 |
+
for pattern in qconfig_mapping.object_type_qconfigs.keys():
|
| 170 |
+
if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
|
| 171 |
+
qconfig_mapping.set_object_type(pattern, default_qconfig)
|
| 172 |
+
return qconfig_mapping
|
| 173 |
+
|
| 174 |
+
# Attribute names of QConfigMapping, listed from lowest to highest match
# priority; QConfigMapping.__repr__ prints the styles in this order.
_QCONFIG_STYLE_ORDER: List[str] = [
    "global_qconfig",
    "object_type_qconfigs",
    "module_name_regex_qconfigs",
    "module_name_qconfigs",
    "module_name_object_type_order_qconfigs",
]
|
| 181 |
+
|
| 182 |
+
class QConfigMapping:
    """
    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.

    QConfigs can be registered through the following methods, listed in
    increasing match priority:

        ``set_global`` : sets the global (default) QConfig

        ``set_object_type`` : sets the QConfig for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string

        ``set_module_name`` : sets the QConfig for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    Example usage::

        qconfig_mapping = QConfigMapping()
            .set_global(global_qconfig)
            .set_object_type(torch.nn.Linear, qconfig1)
            .set_object_type(torch.nn.ReLU, qconfig1)
            .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
            .set_module_name_regex("foo.*", qconfig2)
            .set_module_name("module1", qconfig1)
            .set_module_name("module2", qconfig2)
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)

    """

    def __init__(self):
        # Registries ordered from lowest to highest match priority.
        self.global_qconfig: QConfigAny = None
        self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = OrderedDict()
        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_object_type_order_qconfigs: OrderedDict[Tuple[str, Callable, int], QConfigAny] = OrderedDict()

    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
        """
        Register the global (default) QConfig and return ``self``.
        """
        self.global_qconfig = global_qconfig
        return self

    def set_object_type(self, object_type: Union[Callable, str], qconfig: QConfigAny) -> QConfigMapping:
        """
        Register a QConfig for a module type, function, or method name.
        Re-registering an existing object type overwrites its QConfig.
        """
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(self, module_name_regex: str, qconfig: QConfigAny) -> QConfigMapping:
        """
        Register a QConfig for modules whose names match the given regex.

        Regexes are tried in registration order, so register the more
        specific patterns first, e.g.::

            qconfig_mapping = QConfigMapping()
                .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1)
                .set_module_name_regex("foo.*bar.*", qconfig2)
                .set_module_name_regex("foo.*", qconfig3)

        Here "foo.bar.conv0" matches qconfig1, "foo.bar.linear" matches
        qconfig2, and "foo.baz.relu" matches qconfig3.

        Re-registering an existing regex overwrites its QConfig while keeping
        the regex's original position in the match order.
        """
        self.module_name_regex_qconfigs[module_name_regex] = qconfig
        return self

    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
        """
        Register a QConfig for modules matching the exact module name.
        Re-registering an existing module name overwrites its QConfig.
        """
        self.module_name_qconfigs[module_name] = qconfig
        return self

    def set_module_name_object_type_order(
            self,
            module_name: str,
            object_type: Callable,
            index: int,
            qconfig: QConfigAny) -> QConfigMapping:
        """
        Register a QConfig keyed on the combination of module name, object
        type, and the index at which the module appears.

        Re-registering an existing (module name, object type, index) triple
        overwrites its QConfig.
        """
        self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig
        return self

    def __repr__(self) -> str:
        # Render each registry under its style name, in priority order.
        pieces = [self.__class__.__name__ + " ("]
        for style_name in _QCONFIG_STYLE_ORDER:
            pieces.append(f"\n {style_name}")
            entries = getattr(self, style_name)
            if isinstance(entries, OrderedDict) and entries:
                for key, qconfig in entries.items():
                    pieces.append(f"\n {key}: {qconfig}")
            else:
                pieces.append(f"\n {entries}")
        pieces.append("\n)")
        return "".join(pieces)

    # TODO: remove this
    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this ``QConfigMapping`` to a dictionary with the keys
        "" (global QConfig), "object_type", "module_name_regex",
        "module_name", and "module_name_object_type_order"; every value
        other than the global QConfig is a list of tuples.
        """
        flattened_order = [
            k + (v,) for k, v in self.module_name_object_type_order_qconfigs.items()
        ]
        return {
            _GLOBAL_DICT_KEY: self.global_qconfig,
            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: flattened_order,
        }

    # TODO: remove this
    @classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """
        Deserialize a ``QConfigMapping`` from a dictionary with the same
        (all optional) keys produced by ``to_dict``:
        "" (global QConfig), "object_type", "module_name_regex",
        "module_name", and "module_name_object_type_order", where every
        value other than the global QConfig is a list of tuples.
        """
        mapping = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            mapping.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for obj_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            mapping.set_object_type(obj_type, qconfig)
        for regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []):
            mapping.set_module_name_regex(regex, qconfig)
        for name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            mapping.set_module_name(name, qconfig)
        for name, obj_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
            mapping.set_module_name_object_type_order(name, obj_type, index, qconfig)
        return mapping
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (663 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc
ADDED
|
Binary file (19.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc
ADDED
|
Binary file (4.41 kB). View file
|
|
|