koichi12 commited on
Commit
0c36bb3
·
verified ·
1 Parent(s): 8caf96a

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py +0 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc +0 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc +0 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc +0 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py +187 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py +118 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py +3 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc +0 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py +9 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py +106 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py +57 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py +189 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc +0 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py +144 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc +0 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py +160 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py +81 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py +279 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py +113 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py +546 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py +606 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py +1880 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py +83 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py +211 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py +489 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py +5 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py +350 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc ADDED
Binary file (16 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc ADDED
Binary file (50 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weakref
2
+
3
+ import torch
4
+ import torch.utils._pytree as pytree
5
+ from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
6
+ from torch._ops import OpOverload
7
+ from torch.library import Library
8
+ from torchgen.model import (
9
+ BaseTy,
10
+ BaseType,
11
+ FunctionSchema,
12
+ OperatorName,
13
+ OptionalType,
14
+ SchemaKind,
15
+ )
16
+
17
+ from .autograd import autograd_not_implemented
18
+
19
+
20
def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Given a mutable operator, registers the functional variant.

    This API also correctly links the functional variant with the mutable
    operator for the purposes of functionalization.

    All of the new registrations are performed on the ``lib`` passed in.

    Arguments:
        lib (Library): Should be a torch.library.Library object that has
            the same namespace as ``mutable_op``'s namespace.
            lib will be used to register the new functional op as well
            as a functionalization kernel for the ``mutable_op``
            If you don't have a library handy, use
            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
        new_op_name (str): The name of the functional operator (without the
            namespace). If no namespace, the new functional variant will be
            accessible under ``torch.ops.{lib.ns}.new_op_name``.
        mutable_op (OpOverload): The mutable custom operator. Note
            that you may need to add a `.default` to it, like
            `torch.ops.aten.abs_.default`.

    """
    # Reject schemas that the machinery below cannot handle (non-mutable
    # schema kinds, aliased returns, tensor-list args).
    validate(mutable_op)
    schema = functional_schema(new_op_name, mutable_op)
    lib.define(schema)

    # The functional variant clones the mutated args, runs the mutable op,
    # and returns the clones as extra outputs.
    functional_impl = construct_functional_impl(mutable_op)
    lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')

    functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default

    # There's no easy way for us to generate the autograd kernel, so we
    # use autograd_not_implemented. Also, this makes it so that the user
    # is unable to register an autograd formula themselves. This shouldn't
    # be a problem if the user doesn't use the functional op directly
    # in their program, but we may need to revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    # weakref.proxy presumably avoids the kernel closure keeping the
    # OpOverload alive — TODO confirm the intended lifetime contract.
    f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)

    lib.impl(mutable_op, f_kernel, 'Functionalize')
66
+
67
+
68
def construct_functional_impl(mutable_op):
    """Return a functional implementation for ``mutable_op``.

    The returned callable leaves its inputs untouched: written-to args are
    cloned, the mutable op runs on the clones, and the clones are appended
    to the op's outputs.
    """
    def functional_impl(*args):
        # Plan:
        # - substitute a clone for every arg the schema marks as written
        # - run the mutable op on the substituted arg list
        # - hand the (now mutated) clones back as additional outputs
        call_args = []
        cloned_outputs = []
        for writes, value in zip(mutable_args(mutable_op), args):
            if not writes:
                call_args.append(value)
                continue
            copy = None if value is None else value.clone()
            call_args.append(copy)
            cloned_outputs.append(copy)

        result = mutable_op(*call_args)

        if result is None:
            return tuple(cloned_outputs)
        if isinstance(result, tuple):
            return result + tuple(cloned_outputs)
        return (result,) + tuple(cloned_outputs)
    return functional_impl
90
+
91
+
92
def construct_functionalization_kernel(mutable_op, functional_op):
    """Build a Functionalize-key kernel for ``mutable_op``.

    The kernel unwraps FunctionalTensorWrapper args, runs ``functional_op``
    (which returns the mutated values as extra outputs), and commits those
    new values back into the wrapped inputs.
    """
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            # BUG FIX: this message was missing the f-prefix, so the literal
            # text "{mutable_op}" was raised instead of the op name.
            raise RuntimeError(
                f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        # The first len(returns) outputs mirror the mutable op's own returns;
        # the rest are the new values for the mutated args.
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # BUG FIX: the old condition also `continue`d when both the arg
            # and its new value were Tensors -- the common case -- which made
            # the write-back below dead code. Skip only when an optional
            # Tensor arg was absent (nothing to propagate), and fail loudly
            # on a None/Tensor mismatch rather than crashing inside _C.
            if arg is None and new_value is None:
                continue
            if arg is None or new_value is None:
                raise RuntimeError(
                    f"{mutable_op}: a mutated optional Tensor arg and its new "
                    f"value must either both be None or both be Tensors")
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        # Match the mutable op's return convention: scalar-like single
        # return, None for no returns, tuple otherwise.
        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel
143
+
144
+
145
def validate(mutable_op: OpOverload):
    """Reject operators that ``register_functional_op`` cannot handle.

    Raises TypeError for non-OpOverload inputs, RuntimeError for the wrong
    schema kind, and NotImplementedError for unsupported schema features.
    """
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each of them have their own conventions:
    # - inplace (first input modified in-place and returned as only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we'll just support the last
    # option right now for simplicity.
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")

    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")

    plain_tensor = BaseType(BaseTy.Tensor)
    optional_tensor = OptionalType(plain_tensor)
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and arg.type not in (plain_tensor, optional_tensor):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input."
                "Please file an issue.")
177
+
178
+
179
def functional_schema(new_op_name, op: OpOverload):
    """Derive the schema string of the functional variant of ``op``.

    The mutable op's schema is re-parsed, reduced to its functional
    signature (alias annotations dropped, mutated args returned), and
    renamed to ``new_op_name``.
    """
    parsed = FunctionSchema.parse(str(op._schema))
    renamed = parsed.signature().with_name(OperatorName.parse(new_op_name))
    return str(renamed)
183
+
184
+
185
def mutable_args(op: OpOverload):
    """Return one bool per argument of ``op``: True iff the schema marks
    that argument as written to (e.g. ``Tensor(a!)``)."""
    def _is_written(arg):
        info = arg.alias_info
        return False if info is None else info.is_write
    return tuple(_is_written(arg) for arg in op._schema.arguments)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch._prims_common as utils
3
+
4
+ # Utilities should come BEFORE this import
5
+ from torch._decomp import register_decomposition
6
+
7
+ from torch._prims_common import TensorLikeType
8
+ from torch._prims_common.wrappers import out_wrapper
9
+ from torch._refs import _broadcast_shapes
10
+
11
+ # Data conversion references.
12
+ #
13
+ # Note: this module breaks the usual _refs to torch naming scheme where
14
+ # _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
15
+ # part of _refs/__init__.py to avoid name clashes with Python builtin types
16
+ # (like int).
17
+
18
# Public surface of this module. The dtype entries intentionally shadow
# Python builtins (bool/int/float/...) — that is why this module is kept
# separate from _refs/__init__.py (see the module comment above).
__all__ = [
    # dtypes
    "bfloat16",
    "bool",
    "byte",
    "cdouble",
    "cfloat",
    "chalf",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    # misc
    "complex",
    "polar",
]
37
+
38
+
39
+ def _make_conversion_method(name: str, dtype: torch.dtype):
40
+ def fn(
41
+ self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
42
+ ) -> TensorLikeType:
43
+ return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload]
44
+
45
+ fn.__name__ = name
46
+ return fn
47
+
48
+
49
# One conversion reference per Tensor dtype-cast method. Several of these
# names shadow Python builtins (bool, int, float, ...) on purpose — see the
# module-level comment about why this file is separate from _refs/__init__.
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)

bool = _make_conversion_method("bool", torch.bool)

byte = _make_conversion_method("byte", torch.uint8)

cdouble = _make_conversion_method("cdouble", torch.cdouble)

cfloat = _make_conversion_method("cfloat", torch.cfloat)

# Tensor.chalf() -> complex32 (half-precision complex)
chalf = _make_conversion_method("chalf", torch.complex32)

char = _make_conversion_method("char", torch.int8)

double = _make_conversion_method("double", torch.double)

float = _make_conversion_method("float", torch.float)

half = _make_conversion_method("half", torch.half)

int = _make_conversion_method("int", torch.int)

long = _make_conversion_method("long", torch.long)

short = _make_conversion_method("short", torch.short)
74
+
75
+
76
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.complex``: combine matching real/imaginary
    float tensors into the corresponding complex tensor, broadcasting the
    two shapes."""
    supported = (torch.float32, torch.float64, torch.float16)
    torch._check(
        real.dtype in supported and imag.dtype in supported,
        lambda: (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        ),
    )
    out_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    # Allocate the broadcasted complex result, then fill the two components.
    result = real.new_empty(
        out_shape,
        dtype=out_dtype,
        layout=real.layout,
        device=real.device,
        # pin_memory=real.is_pinned(), # NYI
    )
    result.real = real
    result.imag = imag
    return result
108
+
109
+
110
@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Reference for ``torch.polar``: build a complex tensor from magnitude
    and phase (abs * exp(i * angle))."""
    # torch.complex performs the dtype checks and broadcasting; its numeric
    # contents are then overwritten with the polar -> cartesian values.
    result = torch.complex(abs, angle)
    real_part = abs * torch.cos(angle)
    imag_part = abs * torch.sin(angle)
    result.real = real_part
    result.imag = imag_part
    return result
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (16.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
from typing import List

# torch._refs.nn itself exposes no public names; the actual references live
# in its submodules (e.g. torch._refs.nn.functional).
__all__: List[str] = []
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (351 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (49 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (10.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Public quantizable modules: float modules restructured so that
# observation/quantization hooks can be attached to their internals.
from .activation import MultiheadAttention
from .rnn import LSTM
from .rnn import LSTMCell

__all__ = [
    'LSTM',
    'LSTMCell',
    'MultiheadAttention',
]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.ao.nn.intrinsic as nni
3
+
4
+ __all__ = [
5
+ "BatchNorm2d",
6
+ "BatchNorm3d"
7
+ ]
8
+
9
+ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
10
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
11
+ factory_kwargs = {'device': device, 'dtype': dtype}
12
+ super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
13
+ self.register_buffer('scale', torch.tensor(1.0, **factory_kwargs))
14
+ self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))
15
+
16
+ @staticmethod
17
+ def from_float(cls, mod):
18
+ activation_post_process = mod.activation_post_process
19
+ if type(mod) == cls._NNI_BN_RELU_MODULE:
20
+ mod = mod[0]
21
+ scale, zero_point = activation_post_process.calculate_qparams()
22
+ new_mod = cls(mod.num_features, mod.eps)
23
+ new_mod.weight = mod.weight
24
+ new_mod.bias = mod.bias
25
+ new_mod.running_mean = mod.running_mean
26
+ new_mod.running_var = mod.running_var
27
+ new_mod.scale = scale
28
+ new_mod.zero_point = zero_point
29
+ return new_mod
30
+
31
+ @classmethod
32
+ def from_reference(cls, bn, output_scale, output_zero_point):
33
+ qbn = cls(
34
+ bn.num_features,
35
+ bn.eps,
36
+ bn.momentum,
37
+ device=bn.weight.device,
38
+ dtype=bn.weight.dtype
39
+ )
40
+ qbn.weight = bn.weight
41
+ qbn.bias = bn.bias
42
+ qbn.running_mean = bn.running_mean
43
+ qbn.running_var = bn.running_var
44
+ qbn.scale = output_scale
45
+ qbn.zero_point = output_zero_point
46
+ return qbn
47
+
48
class BatchNorm2d(_BatchNorm):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
    """

    # Fused float module this class can be converted from in from_float.
    _NNI_BN_RELU_MODULE = nni.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
        super().__init__(num_features, eps, momentum, device=device, dtype=dtype)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    def _check_input_dim(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # disabling this since this is not symbolically traceable
        # self._check_input_dim(input)
        return torch.ops.quantized.batch_norm2d(
            input, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        # Delegate to the shared base helper, passing this class explicitly.
        return _BatchNorm.from_float(cls, mod)
77
+
78
class BatchNorm3d(_BatchNorm):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
    """

    # Fused float module this class can be converted from in from_float.
    _NNI_BN_RELU_MODULE = nni.BNReLU3d

    # Added `-> None` for consistency with BatchNorm2d.__init__.
    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    def _check_input_dim(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            # BUG FIX: the message previously said "(N, C, H, W)" — the 4-D
            # BatchNorm2d shape — while the check requires a 5-D input.
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # disabling this since this is not symbolically traceable
        # self._check_input_dim(input)
        return torch.ops.quantized.batch_norm3d(
            input, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        # Delegate to the shared base helper, passing this class explicitly.
        return _BatchNorm.from_float(cls, mod)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (14.5 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Dict, Any
5
+ from .utils import ReferenceQuantizedModule
6
+
7
+ __all__ = ['Linear']
8
+
9
class Linear(nn.Linear, ReferenceQuantizedModule):
    """ A reference quantized linear module that fits into the FX
    Graph Mode Quantization workflow
    activation will be floating point Tensor, we will store floating
    point weight as well in the module, but in forward we'll quantize
    and dequantize the weight before running the floating point functional
    linear operator.
    """
    _IS_REFERENCE = True

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias_: bool = True,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            weight_qparams: Optional[Dict[str, Any]] = None):
        super().__init__(in_features, out_features, bias_, device, dtype)
        # Store the weight quantization parameters (provided by
        # ReferenceQuantizedModule) alongside the float weight.
        self._init_weight_qparams(weight_qparams, device)

    def _get_name(self):
        return "QuantizedLinear(Reference)"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.linear ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.linear --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized linear
        """
        # get_weight() applies quantize->dequantize to the float weight.
        return F.linear(x, self.get_weight(), self.bias)

    @classmethod
    def from_float(cls, float_linear, weight_qparams):
        """Create a reference quantized Linear from a float nn.Linear."""
        has_bias = float_linear.bias is not None
        qref_linear = Linear(
            float_linear.in_features, float_linear.out_features,
            has_bias, device=float_linear.weight.device,
            dtype=float_linear.weight.dtype, weight_qparams=weight_qparams)
        qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
        if has_bias:
            qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
        return qref_linear
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (318 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (212 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (215 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F403
2
+
3
+ from .fake_quantize import * # noqa: F403
4
+ from .fuse_modules import fuse_modules # noqa: F403
5
+ from .fuse_modules import fuse_modules_qat # noqa: F403
6
+ from .fuser_method_mappings import * # noqa: F403
7
+ from .observer import * # noqa: F403
8
+ from .qconfig import * # noqa: F403
9
+ from .qconfig_mapping import * # noqa: F403
10
+ from .quant_type import * # noqa: F403
11
+ from .quantization_mappings import * # type: ignore[no-redef]
12
+ from .quantize import * # noqa: F403
13
+ from .quantize_jit import * # noqa: F403
14
+ from .stubs import * # noqa: F403
15
+ from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
16
+ from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
17
+ from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
18
+ from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle # noqa: F401
19
+ from typing import Union, List, Callable, Tuple, Optional
20
+ from torch import Tensor
21
+ import torch
22
+
23
# Type alias: anything that can observe activations for quantization —
# either a plain observer or a fake-quantize module.
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
# Make the alias introspect as belonging to this package.
ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
25
+
26
+ __all__ = [
27
+ "DeQuantStub",
28
+ "FakeQuantize",
29
+ "FakeQuantizeBase",
30
+ "FixedQParamsFakeQuantize",
31
+ "FixedQParamsObserver",
32
+ "FusedMovingAvgObsFakeQuantize",
33
+ "HistogramObserver",
34
+ "MatchAllNode",
35
+ "MinMaxObserver",
36
+ "MovingAverageMinMaxObserver",
37
+ "MovingAveragePerChannelMinMaxObserver",
38
+ "NoopObserver",
39
+ "ObserverBase",
40
+ "ObserverOrFakeQuantize",
41
+ "Pattern",
42
+ "PerChannelMinMaxObserver",
43
+ "PlaceholderObserver",
44
+ "QConfig",
45
+ "QConfigAny",
46
+ "QConfigDynamic",
47
+ "QConfigMapping",
48
+ "QuantStub",
49
+ "QuantType",
50
+ "QuantWrapper",
51
+ "RecordingObserver",
52
+ "ReuseInputObserver",
53
+ "UniformQuantizationObserverBase",
54
+ "add_quant_dequant",
55
+ "convert",
56
+ "convert_dynamic_jit",
57
+ "convert_jit",
58
+ "default_affine_fixed_qparams_fake_quant",
59
+ "default_affine_fixed_qparams_observer",
60
+ "default_debug_observer",
61
+ "default_dynamic_fake_quant",
62
+ "default_dynamic_quant_observer",
63
+ "default_embedding_fake_quant",
64
+ "default_embedding_fake_quant_4bit",
65
+ "default_eval_fn",
66
+ "default_fake_quant",
67
+ "default_fixed_qparams_range_0to1_fake_quant",
68
+ "default_fixed_qparams_range_0to1_observer",
69
+ "default_fixed_qparams_range_neg1to1_fake_quant",
70
+ "default_fixed_qparams_range_neg1to1_observer",
71
+ "default_float_qparams_observer",
72
+ "default_float_qparams_observer_4bit",
73
+ "default_fused_act_fake_quant",
74
+ "default_fused_per_channel_wt_fake_quant",
75
+ "default_fused_wt_fake_quant",
76
+ "default_histogram_fake_quant",
77
+ "default_histogram_observer",
78
+ "default_observer",
79
+ "default_per_channel_weight_fake_quant",
80
+ "default_per_channel_weight_observer",
81
+ "default_placeholder_observer",
82
+ "default_reuse_input_observer",
83
+ "default_symmetric_fixed_qparams_fake_quant",
84
+ "default_symmetric_fixed_qparams_observer",
85
+ "default_weight_fake_quant",
86
+ "default_weight_observer",
87
+ "disable_fake_quant",
88
+ "disable_observer",
89
+ "enable_fake_quant",
90
+ "enable_observer",
91
+ "fuse_conv_bn",
92
+ "fuse_conv_bn_jit",
93
+ "fuse_conv_bn_relu",
94
+ "fuse_convtranspose_bn",
95
+ "fuse_linear_bn",
96
+ "fuse_modules",
97
+ "fuse_modules_qat",
98
+ "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
99
+ "fused_wt_fake_quant_range_neg_127_to_127",
100
+ "get_combined_dict",
101
+ "get_default_compare_output_module_list",
102
+ "get_default_custom_config_dict",
103
+ "get_default_dynamic_quant_module_mappings",
104
+ "get_default_dynamic_sparse_quant_module_mappings",
105
+ "get_default_float_to_quantized_operator_mappings",
106
+ "get_default_qat_module_mappings",
107
+ "get_default_qat_qconfig",
108
+ "get_default_qat_qconfig_dict",
109
+ "get_default_qat_qconfig_mapping",
110
+ "get_default_qconfig",
111
+ "get_default_qconfig_dict",
112
+ "get_default_qconfig_mapping",
113
+ "get_default_qconfig_propagation_list",
114
+ "get_default_static_quant_module_mappings",
115
+ "get_default_static_quant_reference_module_mappings",
116
+ "get_default_static_sparse_quant_module_mappings",
117
+ "get_dynamic_quant_module_class",
118
+ "get_embedding_qat_module_mappings",
119
+ "get_embedding_static_quant_module_mappings",
120
+ "get_fuser_method",
121
+ "get_fuser_method_new",
122
+ "get_observer_state_dict",
123
+ "get_quantized_operator",
124
+ "get_static_quant_module_class",
125
+ "load_observer_state_dict",
126
+ "move_exported_model_to_eval",
127
+ "move_exported_model_to_train",
128
+ "allow_exported_model_train_eval",
129
+ "no_observer_set",
130
+ "per_channel_weight_observer_range_neg_127_to_127",
131
+ "prepare",
132
+ "prepare_dynamic_jit",
133
+ "prepare_jit",
134
+ "prepare_qat",
135
+ "propagate_qconfig_",
136
+ "qconfig_equals",
137
+ "quantize",
138
+ "quantize_dynamic",
139
+ "quantize_dynamic_jit",
140
+ "quantize_jit",
141
+ "quantize_qat",
142
+ "script_qconfig",
143
+ "script_qconfig_dict",
144
+ "swap_module",
145
+ "weight_observer_range_neg_127_to_127",
146
+ "generate_numeric_debug_handle",
147
+ ]
148
+
149
def default_eval_fn(model, calib_data):
    r"""Define the default evaluation function.

    Runs ``model`` on every sample drawn from ``calib_data`` (an iterable of
    ``(data, target)`` pairs).  Only the forward pass is needed for
    calibration, so targets are ignored and nothing is returned.
    """
    for data, _target in calib_data:
        model(data)
157
+
158
class _DerivedObserverOrFakeQuantize(ObserverBase):
    r"""Observer whose quantization parameters are derived from other observers.

    Instead of collecting statistics itself, this observer holds a list of
    source observers/fake-quantizes and a callback that computes (scale,
    zero_point) from them on demand.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        obs_or_fqs: List[ObserverOrFakeQuantize],
        derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        qscheme: Optional[torch.qscheme] = None,
        ch_axis: Optional[int] = None
    ):
        super().__init__(dtype)
        self.obs_or_fqs = obs_or_fqs
        self.derive_qparams_fn = derive_qparams_fn
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.ch_axis = ch_axis

        # Per-channel schemes need to know which axis holds the channels.
        from .utils import is_per_channel
        if is_per_channel(self.qscheme):
            assert self.ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel"

    def forward(self, x: Tensor) -> Tensor:
        # Pure pass-through: no statistics are recorded here.
        return x

    def calculate_qparams(self):
        # Delegate entirely to the user-supplied derivation callback.
        return self.derive_qparams_fn(self.obs_or_fqs)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc ADDED
Binary file (34.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.ao.nn.quantized as nnq
4
+
5
+ import torch.ao.quantization
6
+ import torch.ao.ns._numeric_suite as ns
7
+
8
+ __all__ = [
9
+ "get_module",
10
+ "parent_child_names",
11
+ "get_param",
12
+ "MeanShadowLogger",
13
+ "bias_correction",
14
+ ]
15
+
16
+ _supported_modules = {nn.Linear, nn.Conv2d}
17
+ _supported_modules_quantized = {nnq.Linear, nnq.Conv2d}
18
+
19
def get_module(model, name):
    """Return the submodule of ``model`` whose full dotted name is ``name``."""
    name_to_module = dict(model.named_modules())
    return name_to_module[name]
22
+
23
def parent_child_names(name):
    """Split a full submodule name into (parent_full_name, child_name).

    A top-level name has no parent, so the parent part is the empty string.
    """
    parent, sep, child = name.rpartition('.')
    if not sep:
        return '', name
    return parent, child
30
+
31
def get_param(module, attr):
    """Get the parameter given a module and attribute.

    Some modules expose weights/bias directly as a tensor attribute, while
    others expose a zero-argument accessor that returns the tensor; this
    helper handles both, returning ``None`` when the attribute is absent.
    """
    param = getattr(module, attr, None)
    return param() if callable(param) else param
42
+
43
class MeanShadowLogger(ns.Logger):
    """Mean Logger for a Shadow module.

    Records the running mean of the outputs produced by a quantized module
    and its floating point shadow.
    """

    def __init__(self):
        """Initialize running sums, the batch count, and the stats entries."""
        super().__init__()
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None

    def forward(self, x, y):
        """Fold one batch of module outputs into the running means.

        ``x`` is the quantized module's output (dequantized if needed),
        ``y`` is the floating point module's output.
        """
        if x.is_quantized:
            x = x.dequantize()

        self.count += 1

        # Quantized branch: first batch seeds the sum, later batches update
        # the rolling mean.
        if self.stats["quantized"] is None:
            self.stats["quantized"] = x
            self.quant_sum = x
        else:
            self.quant_sum += x
            self.stats["quantized"] = self.quant_sum / self.count

        # Float branch, same bookkeeping.
        if self.stats["float"] is None:
            self.stats["float"] = y
            self.float_sum = y
        else:
            self.float_sum += y
            self.stats["float"] = self.float_sum / self.count

    def clear(self):
        """Reset all recorded statistics to their initial state."""
        self.stats["float"] = None
        self.stats["quantized"] = None
        self.count = 0
        self.float_sum = None
        self.quant_sum = None
89
+
90
def bias_correction(float_model, quantized_model, img_data, target_modules=_supported_modules_quantized, neval_batches=None):
    """Perform bias correction on a module.

    Using numeric suite shadow modules, the expected output of the floating
    point and quantized modules is recorded; that data is then used to shift
    the bias of supported modules to compensate for quantization drift.
    Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2)

    Args:
        float_model: a trained model that serves as a reference to what bias correction should aim for
        quantized_model: quantized form of float_model that bias correction is to applied to
        img_data: calibration data to estimate the expected output (used to find quantization error)
        target_modules: specifies what submodules in quantized_model need bias correction (can be extended to
                unquantized submodules)
        neval_batches: a cap to the number of batches you want to be used for estimating the expected output
    """
    # Attach shadow loggers pairing each quantized submodule with its float twin.
    ns.prepare_model_with_stubs(float_model, quantized_model, _supported_modules, MeanShadowLogger)

    uncorrected_modules = {
        name: submodule
        for name, submodule in quantized_model.named_modules()
        if type(submodule) in target_modules
    }

    for uncorrected_module in uncorrected_modules:
        quantized_submodule = get_module(quantized_model, uncorrected_module)
        bias = get_param(quantized_submodule, 'bias')
        if bias is not None:
            # Run calibration batches so the shadow loggers accumulate means.
            count = 0
            for data in img_data:
                quantized_model(data[0])
                count += 1
                if count == neval_batches:
                    break
            ob_dict = ns.get_logger_dict(quantized_model)
            parent_name, _ = parent_child_names(uncorrected_module)

            float_data = ob_dict[parent_name + '.stats']['float']
            quant_data = ob_dict[parent_name + '.stats']['quantized']

            # Expected error = mean quantization error over every dimension
            # except the output-channel dimension (dim 1).
            quantization_error = quant_data - float_data
            dims = list(range(quantization_error.dim()))
            dims.remove(1)
            expected_error = torch.mean(quantization_error, dims)

            # Shift the bias to cancel the expected error.
            bias.data = bias.data - expected_error

            # Reset the loggers before correcting the next module.
            for name, submodule in quantized_model.named_modules():
                if isinstance(submodule, MeanShadowLogger):
                    submodule.clear()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc ADDED
Binary file (32.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc ADDED
Binary file (34.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc ADDED
Binary file (245 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc ADDED
Binary file (20.9 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ._common_operator_config_utils import (
3
+ _get_binary_op_configs,
4
+ _get_bn_configs,
5
+ _get_cat_config,
6
+ _get_conv_configs,
7
+ _get_default_op_configs,
8
+ _get_embedding_op_configs,
9
+ _get_fixed_qparams_op_configs,
10
+ _get_linear_configs,
11
+ _get_rnn_op_configs,
12
+ _get_share_qparams_op_configs,
13
+ )
14
+ from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints
15
+
16
+ __all__ = [
17
+ "get_qnnpack_backend_config",
18
+ ]
19
+
20
# ===================
# |  DTYPE CONFIGS  |
# ===================

qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

qnnpack_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

qnnpack_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)

# xnnpack compatible dtype configs

# We restrict scale values to be 2 ** -12 to ensure the
# requantization scale never falls below the xnnpack lower
# threshold. Additionally, for qint8 weight, we restrict
# the quantization values to [-127, +127], excluding -128.
# For more detail, refer to the description of
# `default_symmetric_qnnpack_qconfig`.

# TODO: add additional restriction on qscheme to ensure it
# is either per_tensor_symmetric or per_channel_symmetric

qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    scale_min_lower_bound=2 ** -12,
)

qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2 ** -12,
)

qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
    bias_dtype=torch.float,
)

qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
)
106
+
107
+
108
+ # =====================
109
+ # | BACKEND CONFIGS |
110
+ # =====================
111
+
112
def get_qnnpack_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
    """
    # Per-pattern-family dtype config lists.
    conv_dtype_configs = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        qnnpack_weighted_op_quint8_dtype_config,
    ]
    linear_dtype_configs = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        qnnpack_weighted_op_quint8_dtype_config,
        qnnpack_default_dynamic_int8_dtype_config,
        qnnpack_default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    default_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    fixed_qparams_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    rnn_op_dtype_configs = [
        qnnpack_default_dynamic_int8_dtype_config,
        qnnpack_default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        qnnpack_weight_only_quint8_dtype_config,
        qnnpack_weight_only_quint4x2_dtype_config,
    ]

    # The set_* methods mutate and return the same BackendConfig, so
    # sequential calls are equivalent to the chained-builder form.
    backend_config = BackendConfig("qnnpack")
    backend_config.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
    backend_config.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
    return backend_config
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .backend_config import (
3
+ BackendConfig,
4
+ BackendPatternConfig,
5
+ DTypeConfig,
6
+ ObservationType
7
+ )
8
+ from ._common_operator_config_utils import (
9
+ _get_binary_op_configs,
10
+ _get_linear_configs,
11
+ _get_conv_configs,
12
+ _get_share_qparams_op_configs,
13
+ _get_tensor_info_op_configs,
14
+ )
15
+
16
+ __all__ = [
17
+ "get_tensorrt_backend_config",
18
+ "get_tensorrt_backend_config_dict",
19
+ ]
20
+
21
def get_tensorrt_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for the TensorRT backend.
    NOTE: Current api will change in the future, it's just to unblock experimentation for
    new backends, please don't use it right now.
    TODO: add a README when it's more stable
    """
    # dtype configs
    weighted_op_qint8_dtype_config = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    non_weighted_op_qint8_dtype_config = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
    )

    # addmm computes bias + input @ weight, hence the argument index mapping.
    addmm_config = (
        BackendPatternConfig(torch.addmm)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(weighted_op_qint8_dtype_config)
        ._set_input_type_to_index({
            "bias": 0,
            "input": 1,
            "weight": 2,
        })
    )
    cat_config = (
        BackendPatternConfig(torch.cat)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .add_dtype_config(non_weighted_op_qint8_dtype_config)
    )

    conv_dtype_configs = [
        weighted_op_qint8_dtype_config,
    ]
    linear_dtype_configs = [
        weighted_op_qint8_dtype_config,
    ]
    binary_op_dtype_configs = [
        weighted_op_qint8_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        non_weighted_op_qint8_dtype_config,
    ]
    tensor_info_op_dtype_configs = [
        non_weighted_op_qint8_dtype_config,
    ]

    # there might be things not supported in fx2trt, but it will error out
    # during fx2trt conversion and can support them after that
    backend_config = BackendConfig("tensorrt")
    backend_config.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
    backend_config.set_backend_pattern_config(addmm_config)
    backend_config.set_backend_pattern_config(cat_config)
    backend_config.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
    return backend_config
76
+
77
def get_tensorrt_backend_config_dict():
    """
    Return the `BackendConfig` for the TensorRT backend in dictionary form.
    """
    backend_config = get_tensorrt_backend_config()
    return backend_config.to_dict()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List, Callable, Union, Tuple, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from .backend_config import (
7
+ BackendConfig,
8
+ BackendPatternConfig,
9
+ DTypeConfig,
10
+ )
11
+ from ..utils import Pattern
12
+ from ..fuser_method_mappings import (
13
+ _reverse2,
14
+ _reverse3,
15
+ )
16
+
17
+ __all__ = [
18
+ "get_pattern_to_dtype_configs",
19
+ "get_qat_module_classes",
20
+ "get_fused_module_classes",
21
+ "get_pattern_to_input_type_to_index",
22
+ "get_root_module_to_quantized_reference_module",
23
+ "get_fuser_method_mapping",
24
+ "get_module_to_qat_module",
25
+ "get_fusion_pattern_to_root_node_getter",
26
+ "get_fusion_pattern_to_extra_inputs_getter",
27
+ "remove_boolean_dispatch_from_name",
28
+ "pattern_to_human_readable",
29
+ "entry_to_pretty_str",
30
+ ]
31
+
32
def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]:
    """Map each registered pattern to its config's list of dtype configs."""
    return {
        pattern: config.dtype_configs
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
    }
37
+
38
def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    """Collect the distinct QAT module classes registered in the backend config."""
    qat_modules = {
        config.qat_module
        for config in backend_config.configs
        if config.qat_module is not None
    }
    return tuple(qat_modules)
44
+
45
def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
    """Collect the distinct fused module classes registered in the backend config."""
    fused_modules = {
        config.fused_module
        for config in backend_config.configs
        if config.fused_module is not None
    }
    return tuple(fused_modules)
51
+
52
def get_pattern_to_input_type_to_index(backend_config: BackendConfig) -> Dict[Pattern, Dict[str, int]]:
    """Map each registered pattern to its input-type-to-argument-index mapping."""
    return {
        pattern: config._input_type_to_index
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
    }
57
+
58
def get_root_module_to_quantized_reference_module(
        backend_config: BackendConfig) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]:
    """Map each root module class to its quantized reference module class.

    Only configs that specify both a root module and a reference quantized
    module contribute an entry.
    """
    return {
        config.root_module: config.reference_quantized_module
        for config in backend_config.configs
        if config.root_module is not None and config.reference_quantized_module is not None
    }
65
+
66
def get_fuser_method_mapping(backend_config: BackendConfig) -> Dict[Pattern, Union[nn.Sequential, Callable]]:
    """Map each registered pattern to its fuser method, converted to the
    internal reversed nested tuple convention."""
    fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is None:
            continue
        # Note: both the fuser method and the pattern are specified in forward
        # order in the BackendConfig, but the internal pattern matching code
        # uses the reversed nested tuple format, so convert before storing.
        fuser_method_mapping[pattern] = _get_fuser_method_in_reversed_nested_tuple_format(config)
    return fuser_method_mapping
76
+
77
def get_module_to_qat_module(backend_config: BackendConfig) -> Dict[Pattern, Type[torch.nn.Module]]:
    """Map each registered pattern to its QAT module class, where one is set."""
    return {
        pattern: config.qat_module
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
        if config.qat_module is not None
    }
83
+
84
def get_fusion_pattern_to_root_node_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    """ Get a map from fusion pattern to a function that returns the root node
    from the fusion pattern, e.g. the most common one is:
    def get_root_node(node_pattern):
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]
    This can work for all patterns whose root node is the "last node" in the pattern,
    e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d))
    """
    return {
        pattern: config._root_node_getter
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
        if config._root_node_getter is not None
    }
99
+
100
def get_fusion_pattern_to_extra_inputs_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    """ Get a map from fusion pattern to a function that returns extra input nodes
    from the fusion pattern, in the order required by the root node. This is optional,
    if not specified, we will not copy over any extra inputs for the root node.
    Example:
    # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
    # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra
    # argument to the fused module, we can unpack the pattern and return the node at
    # MatchAllNode here
    # we can implement extra_inputs_getter as follows:
    def extra_inputs_getter(pattern) -> List[Any]:
        add, extra_input, conv_pattern = pattern
        return [extra_input]
    """
    return {
        pattern: config._extra_inputs_getter
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
        if config._extra_inputs_getter is not None
    }
119
+
120
def remove_boolean_dispatch_from_name(p) -> Any:
    """
    Some ops have a default string representation such as
    '<function boolean_dispatch.<locals>.fn at 0x7ff1106bf280>',
    this function replaces them with the hardcoded function names.
    """
    # Identity comparison (not equality/hashing) because boolean_dispatch
    # wrappers may not be hashable and must match the exact object.
    known_names = (
        (F.fractional_max_pool2d, "torch.nn.functional.fractional_max_pool2d"),
        (F.fractional_max_pool3d, "torch.nn.functional.fractional_max_pool3d"),
        (F.max_pool1d, "torch.nn.functional.max_pool1d"),
        (F.max_pool2d, "torch.nn.functional.max_pool2d"),
        (F.max_pool3d, "torch.nn.functional.max_pool3d"),
        (F.adaptive_max_pool1d, "torch.nn.functional.adaptive_max_pool1d"),
        (F.adaptive_max_pool2d, "torch.nn.functional.adaptive_max_pool2d"),
        (F.adaptive_max_pool3d, "torch.nn.functional.adaptive_max_pool3d"),
    )
    for fn, readable_name in known_names:
        if p is fn:
            return readable_name
    assert "boolean_dispatch" not in str(p), \
        f"{p} does not have a human readable representation in " + \
        "quantization documentation"
    return p
146
+
147
def pattern_to_human_readable(p) -> Any:
    """Convert a (possibly nested) pattern into a human readable form."""
    if isinstance(p, tuple):
        # Nested pattern: convert every element recursively.
        return tuple(pattern_to_human_readable(inner) for inner in p)
    if isinstance(p, str):
        # Method names are already human readable.
        return p
    # Functions/classes: strip boolean_dispatch wrappers where known.
    return remove_boolean_dispatch_from_name(p)
157
+
158
# TODO(future PR): move backend_config_dict to use dataclass and move this logic to
# the corresponding __str__ function
def entry_to_pretty_str(entry) -> str:
    """
    Given a backend_config_dict entry, returns a string with the human readable
    representation of it.
    """
    s = "{\n"

    # always output the pattern first
    if "pattern" in entry:
        pattern_str = pattern_to_human_readable(entry["pattern"])
        s += f"  'pattern': {pattern_str},\n"

    # custom output for dtype_configs to make it look nice
    if "dtype_configs" in entry:
        s += "  'dtype_configs': [\n"
        for dtype_config in entry["dtype_configs"]:
            s += "    {\n"
            for k, v in dtype_config.items():
                s += f"      '{k}': {v},\n"
            s += "    },\n"
        s += "  ],\n"

    # custom output for num_tensor_args_to_observation_type to make it look nice
    if "num_tensor_args_to_observation_type" in entry:
        s += "  'num_tensor_args_to_observation_type': {\n"
        for k, v in entry["num_tensor_args_to_observation_type"].items():
            s += f"    {k}: {v},\n"
        s += "  },\n"

    # output all the other fields
    custom_handled_fields = [
        "pattern",
        "dtype_configs",
        "num_tensor_args_to_observation_type",
    ]
    for field_name in entry:
        if field_name in custom_handled_fields:
            continue
        s += f"  '{field_name}': {entry[field_name]},\n"

    s += "}"
    return s
203
+
204
def _get_pattern_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Pattern:
    """
    Return the pattern specified in the given config in the reversed nested tuple format
    used internally in the quantization pattern matching code.

    If the pattern is not a tuple, or the pattern is already specified in the reversed
    nested tuple format, return the pattern as is. Otherwise:

    For 2-tuples (a, b), return (b, a).
    For 3-tuples (a, b, c), return (c, (b, a)).

    For example:
    * Given nn.Linear, return nn.Linear
    * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear)
    * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return
      (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))

    For context, the reason why this is needed is the user-facing BackendConfig
    API accepts the flat 2-or-3-tuple format in forward order. While this simple
    format handles the vast majority of use cases, it does not handle the more
    complex ones, and so the internal pattern matching code for quantization uses
    the following, more general reversed nested tuple format instead:

        operator = module_type | functional | torch op | native op | MatchAllNode
        Pattern = (operator, Pattern, Pattern, ...) | operator

    In the future, we expect to replace the above complex format with the one used
    by the subgraph rewriter in torch.fx, so we don't have to maintain our own
    complex pattern matching code. Then we won't need this helper function anymore.
    """
    if config._pattern_complex_format is not None:
        # Already in the internal format, nothing to do.
        return config._pattern_complex_format
    pattern = config.pattern
    if pattern is None:
        raise ValueError("Either 'pattern' or 'pattern_complex_format' must be specified")
    if not isinstance(pattern, tuple):
        return pattern

    # Pattern is specified in the simple tuple format, need to convert
    if len(pattern) == 2:
        (a, b) = pattern
        return (b, a)
    if len(pattern) == 3:
        (a, b, c) = pattern
        return (c, (b, a))
    raise ValueError("Expected a tuple with 2 or 3 elements, got: ", pattern)
250
+
251
def _get_fuser_method_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Callable:
    """
    Return the fuser method specified in the given config in the reversed nested
    tuple format used internally in the quantization pattern matching code.

    If pattern is specified in the reversed nested tuple format, we assume the
    fuser method is also specified in this format and simply return it as is.
    Otherwise, we convert the fuser method as follows:

    * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv)
    * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv),
      where bn_conv is a 2-tuple (bn, conv)

    The first argument of a fuser method is always `is_qat` and is not affected
    in the conversion. We currently only support functions with 3 or 4 arguments.
    """
    assert config.fuser_method is not None
    if config._pattern_complex_format is not None:
        # Already in the internal format: trust the fuser method as-is.
        return config.fuser_method
    if not isinstance(config.pattern, tuple):
        raise ValueError("Expected pattern to be a tuple, got: ", config.pattern)

    # Pattern is specified in the simple tuple format, need to convert
    num_elements = len(config.pattern)
    if num_elements == 2:
        return _reverse2(config.fuser_method)
    if num_elements == 3:
        return _reverse3(config.fuser_method)
    raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ._common_operator_config_utils import (
3
+ _get_binary_op_configs,
4
+ _get_bn_configs,
5
+ _get_cat_config,
6
+ _get_conv_configs,
7
+ _get_default_op_configs,
8
+ _get_embedding_op_configs,
9
+ _get_fixed_qparams_op_configs,
10
+ _get_linear_configs,
11
+ _get_rnn_op_configs,
12
+ _get_share_qparams_op_configs,
13
+ _get_tensor_info_op_configs,
14
+ )
15
+ from .backend_config import BackendConfig, DTypeConfig
16
+
17
+ __all__ = [
18
+ "get_x86_backend_config",
19
+ ]
20
+
21
# ===================
# | DTYPE CONFIGS |
# ===================

# X86 aligns with FBGEMM for now

# Static int8 quantization for weighted ops (e.g. conv, linear):
# quantized activations/weights, float bias.
x86_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quint8 quantization for ops without weights.
x86_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# All-fp16 static config (input, output, weight and bias in float16).
x86_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization: quantized input, float output (is_dynamic=True).
x86_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

x86_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization: float compute with quantized weight storage
# (used for embedding ops below).
x86_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

x86_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
73
+
74
+
75
+ # =====================
76
+ # | BACKEND CONFIGS |
77
+ # =====================
78
+
79
def get_x86_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native x86 backend.
    """
    conv_dtype_configs = [x86_weighted_op_int8_dtype_config]
    linear_dtype_configs = [
        x86_weighted_op_int8_dtype_config,
        x86_default_dynamic_int8_dtype_config,
        x86_default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
    default_op_dtype_configs = [x86_default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config]
    share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config]
    tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        x86_default_dynamic_int8_dtype_config,
        x86_default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        x86_weight_only_quint8_dtype_config,
        x86_weight_only_quint4x2_dtype_config,
    ]
    # Register each operator family's pattern configs in turn; the setters
    # return the config itself, so building it statement-by-statement is
    # equivalent to the chained form.
    backend_config = BackendConfig("x86")
    backend_config.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
    # `cat` contributes a single pattern config rather than a list.
    backend_config.set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
    backend_config.set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
    return backend_config
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements modules used to perform fake quantization."""
2
+
3
+ import torch
4
+ from torch.nn import Module
5
+ from torch.ao.quantization.observer import (
6
+ MovingAverageMinMaxObserver,
7
+ HistogramObserver,
8
+ MovingAveragePerChannelMinMaxObserver,
9
+ FixedQParamsObserver,
10
+ default_fixed_qparams_range_0to1_observer,
11
+ default_fixed_qparams_range_neg1to1_observer,
12
+ _with_args,
13
+ )
14
+ import re
15
+ from abc import ABC, abstractmethod
16
+ from typing import Any, Tuple
17
+
18
# Public API of this module.
__all__ = [
    "FakeQuantizeBase",
    "FakeQuantize",
    "FixedQParamsFakeQuantize",
    "FusedMovingAvgObsFakeQuantize",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_dynamic_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_histogram_fake_quant",
    "default_fused_act_fake_quant",
    "default_fused_wt_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]
44
+
45
+ def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
46
+ return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
47
+
48
+ def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
49
+ return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
50
+
51
+ def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
52
+ return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]
53
+
54
+ def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
55
+ return qscheme in [torch.per_channel_affine_float_qparams, ]
56
+
57
class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Base fake quantize module
    Any fake quantize implementation should derive from this class.

    Concrete fake quantize module should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.

    """

    # 1 means enabled, 0 means disabled (uint8 tensors; see __init__ for why).
    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        """Set fake_quant_enabled and observer_enabled."""
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        # Subclasses must observe `x` (when enabled) and return its fake-quantized version.
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        # Subclasses must return quantization parameters derived from collected statistics.
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        # Toggle fake quantization via the uint8 buffer (1 = on, 0 = off).
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        # Toggle statistics collection independently of fake quantization.
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        # Return a constructor-like callable with `kwargs` pre-bound (see `_with_args`).
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public v private requirements
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor
113
+
114
class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (
          clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization
      operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq)

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8.

    Args:

        observer (module): Module for observing statistics on input tensors and calculating scale
          and zero-point.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): User provided module that collects statistics on the input tensor and
          provides a method to calculate scale and zero-point.

    """

    # Cached quantization parameters, refreshed from the observer in forward().
    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, is_dynamic=False, **observer_kwargs):
        super().__init__()
        # Populate quant_min/quant_max to observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert quant_min <= quant_max, \
                'quant_min must be less than or equal to quant_max'
            dtype = observer_kwargs.get("dtype", torch.quint8)
            if hasattr(observer, "p"):
                # In case observer is _PartialWrapper, dtype can be stored in
                # observer.p.keywords["dtype"]
                dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
                    "dtype", dtype
                )
            # Validate the requested range against the representable range of `dtype`.
            assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
            assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
            observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
        observer_kwargs["is_dynamic"] = is_dynamic
        self.activation_post_process = observer(**observer_kwargs)
        # TODO: keeping self.quant_min/max for BC; remove after a couple releases
        # Users should use self.activation_post_process.quant_min
        self.quant_min = self.activation_post_process.quant_min
        self.quant_max = self.activation_post_process.quant_max
        self.is_dynamic = self.activation_post_process.is_dynamic
        # Float-qparams schemes use a floating-point zero_point; all others use int.
        if _is_float_qparams(self.activation_post_process.qscheme):
            zero_point_dtype = torch.float
        else:
            zero_point_dtype = torch.int
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # ch_axis is only meaningful for per-channel observers; -1 otherwise.
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        # Delegate to the observer's statistics-based qparam computation.
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        # Observation pass: update statistics and refresh the cached scale/zero_point.
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            # Per-channel qparams can change shape (number of channels); resize buffers to match.
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        # Fake-quantization pass: quantize-dequantize X using the cached qparams.
        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.activation_post_process.quant_min, self.activation_post_process.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.activation_post_process.quant_min, self.activation_post_process.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded tensor does not match the original size
        # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript module we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
264
+
265
+
266
class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize in training time.

    Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
    """

    # TODO: rename observer to observer_ctr
    def __init__(self, observer):
        super().__init__(observer=observer)
        assert type(self.activation_post_process) == FixedQParamsObserver, \
            f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}"
        self._observer_ctr = observer
        # qparams come directly from the FixedQParamsObserver and never change.
        self.scale = self.activation_post_process.scale
        self.zero_point = self.activation_post_process.zero_point
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
            ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        # No statistics are involved: return the fixed scale/zero_point.
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        """Define a string representation of the object's attributes."""
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme)
297
+
298
+
299
class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Define a fused module to observe the tensor.

    Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point and fake_quantize the tensor.
    This module uses calculation similar MovingAverageMinMaxObserver for the inputs,
    to compute the min/max values in order to compute the scale/zero_point.
    The qscheme input in the observer is used to differentiate between symmetric/affine
    quantization scheme.

    The output of this module is given by
    x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

    Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the
    base class.

    """

    def __init__(
        self,
        observer: Any = MovingAverageMinMaxObserver,
        quant_min: int = 0,
        quant_max: int = 255,
        **observer_kwargs: Any
    ) -> None:
        super().__init__(observer, quant_min, quant_max, **observer_kwargs)
        assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \
            "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
        # Re-register the enable flags as int64 (long) buffers, replacing the
        # uint8 buffers registered by FakeQuantizeBase.__init__.
        self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
        self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
        self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
                self.fake_quant_enabled,
                self.observer_enabled,
                self.scale,
                self.zero_point,
                self.dtype,
                self.activation_post_process.quant_min,
                self.activation_post_process.quant_max,
                self.qscheme,
                self.activation_post_process.reduce_range,
            )
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Single fused op performs: observe min/max (moving average), compute
        # scale/zero_point, and fake-quantize X, gated by the enable buffers.
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )
368
+
369
# Pre-configured fake-quant constructors. Each `with_args(...)` call returns a
# zero-argument factory; qconfigs call the factory to instantiate the module.

default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
"""
Default fake_quant for activations.
"""

default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
"""
Default fake_quant for weights.
Observer is memoryless since averaging_constant is 1.
"""
# NOTE(review): averaging_constant is not passed above, so the "memoryless"
# claim in the docstring may not hold -- confirm against MovingAverageMinMaxObserver's default.

default_dynamic_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, is_dynamic=True,
    dtype=torch.quint8, averaging_constant=1)
"""
Default dynamic fake_quant for activations.
"""

default_fixed_qparams_range_neg1to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
"""
Default fake_quant for per-channel weights.
Observer is memoryless since averaging_constant is 1.
"""
default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                      qscheme=torch.per_channel_affine_float_qparams,
                                                      dtype=torch.quint8,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      ch_axis=0,
                                                      averaging_constant=1)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""

# 4-bit (quint4x2) variant of the embedding fake_quant above.
default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                           qscheme=torch.per_channel_affine_float_qparams,
                                                           ch_axis=0,
                                                           dtype=torch.quint4x2,
                                                           averaging_constant=1)

default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
"""
Fake_quant for activations using a histogram..
"""


default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                       quant_min=0,
                                                                       quant_max=255,
                                                                       dtype=torch.quint8,)

"""
Fused version of `default_fake_quant`, with improved performance.
"""


default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                      quant_min=-128,
                                                                      quant_max=127,
                                                                      dtype=torch.qint8,
                                                                      qscheme=torch.per_tensor_symmetric)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""

default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                                  quant_min=-128,
                                                                                  quant_max=127,
                                                                                  dtype=torch.qint8,
                                                                                  qscheme=torch.per_channel_symmetric)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""

fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                   quant_min=-127,
                                                                                   quant_max=127,
                                                                                   dtype=torch.qint8,
                                                                                   qscheme=torch.per_tensor_symmetric,
                                                                                   eps=2 ** -12)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                            quant_min=-127,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_channel_symmetric,
                                            eps=2 ** -12)

"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""
488
+
489
+
490
+ def _is_fake_quant_script_module(mod):
491
+ """Return true if given mod is an instance of FakeQuantize script module."""
492
+ if isinstance(mod, torch.jit.RecursiveScriptModule):
493
+ # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
494
+ suffix = mod._c.qualified_name.split('.', 1)[1]
495
+ name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
496
+ return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
497
+ name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
498
+ return False
499
+
500
def disable_fake_quant(mod):
    """Disable fake quantization for the module.

    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_fake_quant)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.disable_fake_quant()
511
+
512
def enable_fake_quant(mod):
    """Enable fake quantization for the module.

    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_fake_quant)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.enable_fake_quant()
523
+
524
def disable_observer(mod):
    """Disable observation for this module.

    Disable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_observer)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.disable_observer()
535
+
536
def enable_observer(mod):
    """Enable observation for this module.

    Enable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_observer)

    """
    applicable = isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod)
    if applicable:
        mod.enable_observer()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Set, Tuple, Callable
2
+ from collections import OrderedDict
3
+ import torch
4
+ from torch.ao.quantization.fx._model_report.detector import (
5
+ DetectorBase,
6
+ DETECTOR_OBS_ARGS_KEY,
7
+ DETECTOR_OBS_TO_INSERT_KEY,
8
+ DETECTOR_IS_POST_OBS_KEY,
9
+ DETECTOR_TARGET_NODE_KEY,
10
+ DetectorQConfigInfo
11
+ )
12
+ from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer
13
+ from torch.ao.quantization.fx.graph_module import GraphModule
14
+ from torch.ao.quantization.observer import ObserverBase
15
+ from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig
16
+ from torch.ao.quantization.fx._equalize import EqualizationQConfig
17
+
18
+ class ModelReport:
19
+ r"""
20
+ The ModelReport class aims to provide users an easy way to diagnose issues that they run into
21
+ with their models. The class works with all traceable GraphModules to help diagnose issues,
22
+ though the requirements on the type of model more-so depends on the specific report the user
23
+ is trying to generate. With respect to the reports, the ModelReport class is initialized with
24
+ a set of Detector classes, each of which generate reports on quantization configuration
25
+ issues a use might have.
26
+
27
+ Currently supports generating reports on:
28
+ - Suggestions for per-channel vs. per-tensor quantization (nn.Module)
29
+ - Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
30
+ - Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
31
+ - Suggestions for outlier detection for all layers (Graph Modules)
32
+
33
+ The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
34
+ where needed for each detector to gather the information it needs, and then after callibration, the ModelReport
35
+ class compiles the report generated by each Detector class into a single report to return to the user. It also
36
+ has the capability to remove all the observers it inserted as well.
37
+
38
+ * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
39
+
40
+ * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
41
+ Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
42
+
43
+ * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
44
+ This set is generated by calling the get_detector_name() of each detector
45
+
46
+ * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
47
+ The purpose of this is to keep track of what observers were inserted for each detector, so that they
48
+ can be removed at the end if desired
49
+
50
+ * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
51
+ This is to ensure we only insert observers once with the ModelReport instance
52
+
53
+ * :attr:`_removed_observers` A boolean to track if we have removed observers already
54
+ The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
55
+ instance. This also allows the functionality where we can generate the report multiple times
56
+ as long as we haven't removed the observers yet.
57
+
58
+ Note:
59
+ This class was initially designed to work with the Fx Graph Mode workflow in mind. However,
60
+ full functionality is available as long as there is a traceable GraphModule that is being used.
61
+ One method to get a traceable GraphModule without going through the Fx workflow is to use
62
+ the QuantizationTracer class.
63
+
64
+ General Flow for Fx workflow:
65
+ 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
66
+ 2.) Prepare your model with prepare_fx
67
+ 3.) Call model_report.prepare_detailed_calibration to add relevant observers
68
+ 4.) Callibrate your model with data
69
+ 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
70
+ Optional
71
+ 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
72
+ 7.) To help in parsing report information and debugging, view report info as a:
73
+ - Table
74
+ - Histogram
75
+ - Line plot
76
+ 8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions
77
+
78
+ Example (with QuantizationTracer):
79
+ >>> # xdoctest: +SKIP
80
+ >>> # get the necessary qconfig
81
+ >>> config = PrepareCustomConfig()
82
+ >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False)
83
+
84
+ >>> # initialize our model and get GraphModule
85
+ >>> model = SomeModel()
86
+ >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
87
+ >>> graph_module = GraphModule(model, tracer.trace(model))
88
+
89
+ >>> # get our set of detectors and ModelReport instance
90
+ >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)])
91
+ >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set)
92
+
93
+ >>> # now we insert the observers and callibrate the model
94
+ >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
95
+ >>> for i in range(num_callibration_batches):
96
+ >>> example_input = get_callibration_input()
97
+ >>> tracer_model_with_observers(example_input)
98
+
99
+ >>> # finally we generate the reports and optionally remove the observers we inserted
100
+ >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
101
+
102
+ >>> # Optional: we can generate the qconfig mapping based on the suggestions
103
+ >>> qconfigs = model_report.generate_qconfig_mapping()
104
+
105
+ >>> # Optional: we can generate the equalization mapping based on the suggestions
106
+ >>> qconfigs = model_report.generate_equalization_mapping()
107
+
108
+ >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
109
+ >>> model_report_visualizer = tracer_reporter.generate_visualizer()
110
+
111
+ """
112
+
113
+ def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
114
+
115
+ if len(desired_report_detectors) == 0:
116
+ raise ValueError("Should include at least 1 desired report")
117
+
118
+ # keep track of the model we wish to generate report for
119
+ self._model: GraphModule = model
120
+
121
+ # keep the reports private so they can't be modified
122
+ self._desired_report_detectors = desired_report_detectors
123
+ self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors}
124
+
125
+ # keep a mapping of desired reports to observers of interest
126
+ # this is to get the readings, and to remove them, can create a large set
127
+ # this set can then be used to traverse the graph and remove added observers
128
+ self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
129
+
130
+ # initialize each report to have empty set of observers of interest
131
+ for desired_report in self._desired_detector_names:
132
+ self._detector_name_to_observer_fqns[desired_report] = set()
133
+
134
+ # flags to ensure that we can only prepare and remove observers once
135
+ self._prepared_flag = False
136
+ self._removed_observers = False
137
+
138
+ # store the reports that we generated for visualization purposes
139
+ # initially empty since no reports generated
140
+ self._generated_reports: Dict[str, Dict] = {}
141
+
142
+ def get_desired_reports_names(self) -> Set[str]:
143
+ """ Returns a copy of the desired reports for viewing """
144
+ return self._desired_detector_names.copy()
145
+
146
+ def get_observers_of_interest(self) -> Dict[str, Set[str]]:
147
+ """ Returns a copy of the observers of interest for viewing """
148
+ return self._detector_name_to_observer_fqns.copy()
149
+
150
+ def prepare_detailed_calibration(self) -> GraphModule:
151
+ r"""
152
+ Takes in a graph model and inserts the following observers:
153
+ - ModelReportObserver
154
+
155
+ Each observer is inserted based on the desired_reports into the relevant locations
156
+
157
+ Right now, each report in self._desired_detector_names has independent insertions
158
+ However, if a module already has a Observer of the same type, the insertion will not occur
159
+ This is because all of the same type of Observer collect same information, so redundant
160
+
161
+ Returns the same GraphModule with the observers inserted
162
+ """
163
+
164
+ # if already prepared once, cannot prepare again
165
+ if self._prepared_flag:
166
+ raise ValueError("Already ran preparing detailed callibration. Run the report generation next after callibration.")
167
+
168
+ # loop through each detector, find where placements should be, and keep track
169
+ insert_observers_fqns: Dict[str, Any] = {}
170
+
171
+ for detector in self._desired_report_detectors:
172
+ # determine observer points for each detector
173
+ obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
174
+ # map each insert point to the observer to use
175
+ insert_observers_fqns.update(obs_fqn_to_info)
176
+ # update the set of observers this report cares about
177
+ self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
178
+
179
+ # now insert all the observers at their desired locations
180
+ for observer_fqn in insert_observers_fqns:
181
+ target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
182
+ insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
183
+ insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
184
+ observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
185
+ self._insert_observer_around_module(
186
+ observer_fqn, target_node, insert_obs, observer_args, insert_post
187
+ )
188
+
189
+ self._prepared_flag = True
190
+
191
+ return self._model
192
+
193
+ def _insert_observer_around_module(
194
+ self,
195
+ obs_fqn: str,
196
+ target_node: torch.fx.node.Node,
197
+ obs_to_insert: ObserverBase,
198
+ observer_args: Tuple,
199
+ insert_post: bool
200
+ ):
201
+ r"""
202
+ Helper function that inserts the observer into both the graph structure and the module of the model
203
+
204
+ Args
205
+ node_fqn (str): The fully qualified name of the observer we want to insert
206
+ target_node (torch.fx.node.Node): The node in model we are inserting observers around
207
+ obs_to_insert (ObserverBase): The observer we are inserting around target_node
208
+ observer_args (Tuple): The arguments we want to pass into the observer
209
+ insert_post (bool): whether this is meant to be a post observer for this node
210
+ """
211
+ # if we are inserting post, then our target node is the next node
212
+ if insert_post:
213
+ target_node = target_node.next
214
+
215
+ with self._model.graph.inserting_before(target_node):
216
+ self._model.add_submodule(obs_fqn, obs_to_insert)
217
+ self._model.graph.create_node(op="call_module", target=obs_fqn, args=observer_args)
218
+
219
+ # recompile model after inserts are made
220
+ self._model.recompile()
221
+
222
+ def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
223
+ r"""
224
+ Takes in a node fqn and returns the node based on the fqn
225
+
226
+ Args
227
+ node_fqn (str): The fully qualified name of the node we want to find in model
228
+
229
+ Returns the Node object of the given node_fqn otherwise returns None
230
+ """
231
+ node_to_return = None
232
+ for node in self._model.graph.nodes:
233
+ # if the target matches the fqn, it's the node we are looking for
234
+ if node.target == node_fqn:
235
+ node_to_return = node
236
+ break
237
+
238
+ if node_to_return is None:
239
+ raise ValueError("The node_fqn is was not found within the module.")
240
+
241
+ # assert for MyPy
242
+ assert isinstance(node_to_return, torch.fx.node.Node)
243
+
244
+ return node_to_return
245
+
246
+ def generate_model_report(
247
+ self, remove_inserted_observers: bool
248
+ ) -> Dict[str, Tuple[str, Dict]]:
249
+ r"""
250
+ Generates all the requested reports.
251
+
252
+ Note:
253
+ You should have callibrated the model with relevant data before calling this
254
+
255
+ The reports generated are specified by the desired_reports specified in desired_reports
256
+
257
+ Can optionally remove all the observers inserted by the ModelReport instance
258
+
259
+ Args:
260
+ remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
261
+
262
+ Returns a mapping of each desired report name to a tuple with:
263
+ The textual summary of that report information
264
+ A dictionary containing relevant statistics or information for that report
265
+
266
+ Note:
267
+ Throws exception if we try to generate report on model we already removed observers from
268
+ Throws exception if we try to generate report without preparing for callibration
269
+ """
270
+ # if we haven't prepped model for callibration, then we shouldn't generate report yet
271
+ if not self._prepared_flag:
272
+ raise Exception("Cannot generate report without preparing model for callibration")
273
+
274
+ # if we already removed the observers, we cannot generate report
275
+ if self._removed_observers:
276
+ raise Exception("Cannot generate report on model you already removed observers from")
277
+
278
+ # keep track of all the reports of interest and their outputs
279
+ reports_of_interest = {}
280
+
281
+ for detector in self._desired_report_detectors:
282
+ # generate the individual report for the detector
283
+ report_output = detector.generate_detector_report(self._model)
284
+ reports_of_interest[detector.get_detector_name()] = report_output
285
+
286
+ # if user wishes to remove inserted observers, go ahead and remove
287
+ if remove_inserted_observers:
288
+ self._removed_observers = True
289
+ # get the set of all Observers inserted by this instance of ModelReport
290
+ all_observers_of_interest: Set[str] = set()
291
+ for desired_report in self._detector_name_to_observer_fqns:
292
+ observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
293
+ all_observers_of_interest.update(observers_of_interest)
294
+
295
+ # go through all_observers_of_interest and remove them from the graph and model
296
+ for observer_fqn in all_observers_of_interest:
297
+ # remove the observer from the model
298
+ self._model.delete_submodule(observer_fqn)
299
+
300
+ # remove the observer from the graph structure
301
+ node_obj = self._get_node_from_fqn(observer_fqn)
302
+
303
+ if node_obj:
304
+ self._model.graph.erase_node(node_obj)
305
+ else:
306
+ raise ValueError("Node no longer exists in GraphModule structure")
307
+
308
+ # remember to recompile the model
309
+ self._model.recompile()
310
+
311
+ # save the generated reports for visualization purposes
312
+ saved_reports: Dict[str, Dict] = {
313
+ report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items()
314
+ }
315
+
316
+ self._generated_reports = saved_reports
317
+
318
+ # return the reports of interest
319
+ return reports_of_interest
320
+
321
+ def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
322
+ r"""
323
+ Takes in two dictionaries and ensures that any common keys between the two have the same
324
+ values.
325
+
326
+ Args:
327
+ info_dict_a (Dict): First dictionary we wish to compare
328
+ info_dict_b (Dict): Second dictionary we wish to compare
329
+
330
+ Returns True if all shared keys have same values, false otherwise
331
+ """
332
+ # get the set of keys for both
333
+ dict_a_keys: Set = set(info_dict_a.keys())
334
+ dict_b_keys: Set = set(info_dict_b.keys())
335
+
336
+ # get the insersection keys and check if same value for both dicts
337
+ intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
338
+
339
+ for key in intersecting_keys:
340
+ dict_a_val = info_dict_a[key]
341
+ dict_b_val = info_dict_b[key]
342
+
343
+ # if it's a tensor we have to handle separately
344
+ if type(dict_a_val) == torch.Tensor:
345
+ # if dict_b_val not tensor, automatically false
346
+ if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0:
347
+ return False
348
+ else:
349
+ # for non-tensor vals
350
+ if dict_a_val != dict_b_val:
351
+ return False
352
+
353
+ # if no non matching shared keys found, return true
354
+ return True
355
+
356
+ def _reformat_reports_for_visualizer(self) -> OrderedDict:
357
+ r"""
358
+ Takes the generated reports and reformats them into the format that is desired by the
359
+ ModelReportVisualizer
360
+
361
+ Returns an OrderedDict mapping module_fqns to their features
362
+ """
363
+ # we want to reorder and reformat the information so it is ordered in terms of order
364
+ # found in the model
365
+
366
+ # first create new dict with all modules as keys and features under respective module
367
+ module_fqns_to_features: Dict[str, Dict] = {}
368
+
369
+ for report_name in self._generated_reports:
370
+ # get mod -> feature dict and go through
371
+ module_info = self._generated_reports[report_name]
372
+
373
+ for module_fqn in module_info:
374
+ # check if already in our accumulation dict
375
+ if module_fqn in module_fqns_to_features:
376
+ # we merge all the features together
377
+ new_info: Dict = module_info[module_fqn]
378
+ present_info: Dict = module_fqns_to_features[module_fqn]
379
+
380
+ # merge them together into the new unioned dict
381
+ # same features keys -> same info, so okay if override
382
+
383
+ # do safety check to make sure shared keys have same info
384
+ if self._is_same_info_for_same_key(new_info, present_info):
385
+ module_fqns_to_features[module_fqn] = {**new_info, **present_info}
386
+ else:
387
+ error_str = "You have the same key with different values across detectors. "
388
+ error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
389
+ raise ValueError(error_str)
390
+ else:
391
+ # we just set it
392
+ module_fqns_to_features[module_fqn] = module_info[module_fqn]
393
+
394
+ # our ordered dict so that modules can be ordered in order of how they appear in model
395
+ features_by_module: OrderedDict[str, Dict] = OrderedDict()
396
+
397
+ # we loop through modules in graph in order
398
+ for fqn, module in self._model.named_modules():
399
+ # find that fqn in fqns_to_features
400
+ if fqn in module_fqns_to_features:
401
+ # add it to our ordered dict
402
+ features_by_module[fqn] = module_fqns_to_features[fqn]
403
+
404
+ # return the ordered dict of info we created
405
+ return features_by_module
406
+
407
+ def generate_visualizer(self) -> ModelReportVisualizer:
408
+ r"""
409
+ Generates a ModelReportVisualizer instance using the reports generated
410
+ by the generate_model_report() method.
411
+
412
+ Returns the generated ModelReportVisualizer instance initialized
413
+
414
+ Note:
415
+ Throws exception if attempt to get visualizers without generating report
416
+ """
417
+ # check if user has generated reports at least once
418
+ if len(self._generated_reports) == 0:
419
+ raise Exception("Unable to generate visualizers without first generating reports")
420
+
421
+ # get the ordered dict mapping modules to their full set of collected features / stats
422
+ module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
423
+
424
+ # create and return ModelReportVisualizer instance
425
+ visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features)
426
+
427
+ return visualizer
428
+
429
+ def _generate_qconfig_mapping_helper(
430
+ self,
431
+ detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo],
432
+ generation_function: Callable
433
+ ) -> QConfigMapping:
434
+ r"""
435
+ This helper takes in the compiled detector qconfig info that
436
+ has been compiled together and merges it into a QConfigMapping
437
+ """
438
+ # keep track of the qconfigmapping
439
+ qconfig_mapping = QConfigMapping()
440
+
441
+ # loop through each module / fqn and attempt to create QConfigMapping
442
+ for fqn, module in self._model.named_modules():
443
+ # if we have a qconfig info for this module
444
+ if fqn in detector_qconfig_info_combined:
445
+ qconfig_info_compiled = detector_qconfig_info_combined[fqn]
446
+
447
+ # now generate the qconfig and add it to the mapping
448
+ generated_qconfig = generation_function(qconfig_info_compiled, module)
449
+
450
+ # add to our config
451
+ qconfig_mapping.set_module_name(fqn, generated_qconfig)
452
+
453
+ # return compiled mapping
454
+ return qconfig_mapping
455
+
456
+ def _update_detector_quantizaiton_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
457
+ r"""
458
+ Takes in the old and new information and updates the combined information.
459
+
460
+ Args:
461
+ combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
462
+ new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
463
+ into it
464
+ """
465
+ combined_info.is_activation_dynamic = combined_info.is_activation_dynamic or new_info.is_activation_dynamic
466
+ combined_info.is_weight_per_channel = combined_info.is_weight_per_channel or new_info.is_weight_per_channel
467
+
468
+ def _update_detector_equalization_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo):
469
+ r"""
470
+ Takes in the old and new information and updates the combined information.
471
+
472
+ Args:
473
+ combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
474
+ new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info
475
+ into it
476
+ """
477
+ is_equalization_recommended = combined_info.is_equalization_recommended or new_info.is_equalization_recommended
478
+ combined_info.is_equalization_recommended = is_equalization_recommended
479
+
480
+ def _generate_module_fqn_to_detector_info_mapping(
481
+ self,
482
+ update_qconfig_info_function: Callable
483
+ ) -> Dict[str, DetectorQConfigInfo]:
484
+ r"""
485
+ Generates a QConfigMapping based on the suggestions of the
486
+ ModelReport API. The generated mapping encompasses all the
487
+ different types of feedback from the different detectors
488
+ all into one place.
489
+
490
+ These configs are based on the suggestions provided by the ModelReport API
491
+ and can only be generated once the reports have been generated.
492
+
493
+ Args:
494
+ update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo
495
+ and updates the one that is being compiled
496
+
497
+ Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
498
+
499
+ Note:
500
+ Throws exception if we try to generate mapping on model we already removed observers from
501
+ Throws exception if we try to generate mapping without preparing for callibration
502
+ """
503
+ # if we haven't prepped model for callibration, then we shouldn't generate mapping yet
504
+ if not self._prepared_flag:
505
+ raise Exception("Cannot generate report without preparing model for callibration")
506
+
507
+ # if we already removed the observers, we cannot mapping
508
+ if self._removed_observers:
509
+ raise Exception("Cannot generate report on model you already removed observers from")
510
+
511
+ # keep track of qconfig info for each module across detectors
512
+ detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {}
513
+
514
+ for detector in self._desired_report_detectors:
515
+ # get the info from the detector
516
+ detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(self._model)
517
+
518
+ # we go through the modules
519
+ for module_fqn in detector_info:
520
+ # see if we already have info on it
521
+ if module_fqn in detector_qconfig_info_combined:
522
+ # we combine the current options with what is there
523
+ current_options = detector_qconfig_info_combined[module_fqn]
524
+ detector_options = detector_info[module_fqn]
525
+
526
+ update_qconfig_info_function(current_options, detector_options)
527
+ else:
528
+ # we just use this for now
529
+ detector_qconfig_info_combined[module_fqn] = detector_info[module_fqn]
530
+
531
+ return detector_qconfig_info_combined
532
+
533
+ def generate_qconfig_mapping(self) -> QConfigMapping:
534
+ r"""
535
+ Generates a QConfigMapping based on the suggestions of the
536
+ ModelReport API. The generated mapping encompasses all the
537
+ different types of feedback from the different detectors
538
+ all into one place.
539
+
540
+ These configs are based on the suggestions provided by the ModelReport API
541
+ and can only be generated once the reports have been generated.
542
+
543
+ Returns a QConfigMapping for the quantization configuration
544
+
545
+ Note:
546
+ Throws exception if we try to generate mapping on model we already removed observers from
547
+ Throws exception if we try to generate mapping without preparing for callibration
548
+ """
549
+ # get the mapping info
550
+ detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
551
+ self._update_detector_quantizaiton_qconfig_info
552
+ )
553
+
554
+ # we will do a bit of processing and remove fqns that don't have input weight recommended
555
+
556
+ # now we generate the QConfig for each of the options
557
+ mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
558
+ detector_qconfig_info_combined,
559
+ self._quantization_config_generator
560
+ )
561
+
562
+ # return the generated mapping
563
+ return mapping
564
+
565
+ def _quantization_config_generator(self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module) -> QConfig:
566
+ r"""
567
+ Returns the quantization configuration generated by the DetectorQConfigInfo object
568
+ """
569
+ return detector_qconfig_info.generate_quantization_qconfig(module)
570
+
571
+ def _equalization_config_generator(
572
+ self,
573
+ detector_qconfig_info: DetectorQConfigInfo,
574
+ module: torch.nn.Module
575
+ ) -> EqualizationQConfig:
576
+ r"""
577
+ We ignore the module argument here, and only focus on thedetector_qconfig_info
578
+
579
+ Returns the equalization configuration generated by the DetectorQConfigInfo object
580
+ """
581
+ return detector_qconfig_info.generate_equalization_qconfig()
582
+
583
+ def generate_equalization_mapping(self) -> QConfigMapping:
584
+ r"""
585
+ Generates a QConfigMapping based on the suggestions of the
586
+ ModelReport API for equalization. The generated mapping encompasses all the
587
+ different types of feedback from the input-weight equalization detector.
588
+
589
+ These configs are based on the suggestions provided by the ModelReport API
590
+ and can only be generated once the reports have been generated.
591
+
592
+ Returns a QConfigMapping for the equalization configuration
593
+ """
594
+ # get the mapping info
595
+ detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping(
596
+ self._update_detector_equalization_qconfig_info
597
+ )
598
+
599
+ # now we generate the QConfig for each of the options
600
+ mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
601
+ detector_qconfig_info_combined,
602
+ self._equalization_config_generator
603
+ )
604
+
605
+ # return the generated mapping
606
+ return mapping
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py ADDED
@@ -0,0 +1,1880 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import warnings
4
+ from torch.fx import (
5
+ GraphModule,
6
+ )
7
+ from torch.fx.graph import (
8
+ Graph,
9
+ Node,
10
+ )
11
+ from torch.fx.node import Argument
12
+
13
+ from ..quantize import (
14
+ propagate_qconfig_,
15
+ )
16
+ from ..observer import (
17
+ _is_activation_post_process,
18
+ _PartialWrapper,
19
+ )
20
+ from ..qconfig import (
21
+ _is_reuse_input_qconfig,
22
+ QConfigAny,
23
+ )
24
+ from ..qconfig_mapping import (
25
+ QConfigMapping,
26
+ )
27
+ from .qconfig_mapping_utils import (
28
+ _generate_node_name_to_qconfig,
29
+ _update_qconfig_for_fusion,
30
+ _get_flattened_qconfig_dict,
31
+ _update_qconfig_for_qat,
32
+ )
33
+
34
+ from .quantize_handler import (
35
+ _default_root_node_getter,
36
+ _get_pattern_to_quantize_handlers,
37
+ QuantizeHandler,
38
+ )
39
+
40
+ from torch.ao.quantization import (
41
+ ObserverBase,
42
+ FixedQParamsObserver,
43
+ FixedQParamsFakeQuantize,
44
+ _DerivedObserverOrFakeQuantize,
45
+ )
46
+
47
+ from torch.ao.quantization.utils import (
48
+ Pattern,
49
+ NodePattern,
50
+ )
51
+
52
+ from ._equalize import (
53
+ is_equalization_observer,
54
+ node_supports_equalization,
55
+ )
56
+
57
+ from .pattern_utils import (
58
+ _sorted_patterns_dict,
59
+ )
60
+
61
+ from .match_utils import (
62
+ _MatchResultWithQConfig,
63
+ _find_matches,
64
+ )
65
+
66
+ from .utils import (
67
+ _insert_dequant_stubs_for_custom_module_lstm_output,
68
+ _is_custom_module_lstm,
69
+ _maybe_get_custom_module_lstm_from_node_arg,
70
+ _qconfig_satisfies_dtype_config_constraints,
71
+ get_custom_module_class_keys,
72
+ all_node_args_have_no_tensors,
73
+ assert_and_get_unique_device,
74
+ get_non_observable_arg_indexes_and_types,
75
+ get_new_attr_name_with_prefix,
76
+ node_arg_is_weight,
77
+ node_arg_is_bias,
78
+ NON_QUANTIZABLE_WEIGHT_OPS,
79
+ ObservedGraphModuleAttrs,
80
+ )
81
+
82
+ from torch.ao.quantization import (
83
+ PlaceholderObserver
84
+ )
85
+ from torch.ao.quantization.quantize import (
86
+ convert
87
+ )
88
+
89
+ from ..utils import (
90
+ _parent_name,
91
+ get_qconfig_dtypes,
92
+ get_swapped_custom_module_class,
93
+ )
94
+
95
+ from ..backend_config.utils import (
96
+ get_pattern_to_dtype_configs,
97
+ get_module_to_qat_module,
98
+ get_fusion_pattern_to_root_node_getter,
99
+ )
100
+ from ..backend_config import (
101
+ BackendConfig,
102
+ DTypeConfig,
103
+ get_native_backend_config,
104
+ )
105
+ from .custom_config import (
106
+ PrepareCustomConfig,
107
+ StandaloneModuleConfigEntry,
108
+ )
109
+ from torch.ao.quantization.quantizer import (
110
+ EdgeOrNode,
111
+ QuantizationSpec,
112
+ QuantizationSpecBase,
113
+ FixedQParamsQuantizationSpec,
114
+ SharedQuantizationSpec,
115
+ DerivedQuantizationSpec,
116
+ )
117
+ from torch.ao.quantization import ObserverOrFakeQuantize
118
+
119
+ from torch._subclasses import FakeTensor
120
+
121
+ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
122
+ from dataclasses import asdict
123
+
124
# Public API of this module; all underscore-prefixed helpers below are internal.
__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]


# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
# dtypes for which an observer/fake-quant may need to be inserted on an edge
_OBS_DTYPE_LIST = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.float16,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32
]

# Default constructor used when a node has no dtype configured: a pass-through
# placeholder observer that keeps activations in fp32 and records nothing.
_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

# note: the following default target dtype info dicts are temporary,
# should be moved to the new programmable API class soon
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation
}

_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation
}
157
+
158
+
159
+ def _get_observer_kwargs(quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]):
160
+ kwargs_dict = asdict(quant_spec)
161
+ return copy.deepcopy(kwargs_dict)
162
+
163
def _get_qspec_for_arg(
    arg: Node,
    input_qspec_map: Dict[Node, QuantizationSpecBase],
    named_modules: Dict[str, torch.nn.Module]
) -> Optional[QuantizationSpecBase]:
    """Look up the input QuantizationSpec for `arg`, first walking past any
    observer/fake-quant nodes that were already inserted in front of it.

    Returns None when no spec is registered for the underlying argument.
    """
    current = arg
    while _is_activation_post_process_node(current, named_modules):
        current = current.args[0]  # type: ignore[assignment]
    return input_qspec_map.get(current)
171
+
172
def _create_obs_or_fq_from_qspec(
    quantization_spec: Optional[QuantizationSpecBase],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    """ Create observer or fake quantize objects based on quantization spec

    Args:
       quantization_spec: used to store parameters to create the observer or fake quantizer
       obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant
       instance, it may be reused for different edge/output depending on configuration
       is_qat: whether we are preparing for quantization-aware training (only
       affects the FixedQParamsQuantizationSpec branch below)

    Returns None when no quantization is requested for this edge/node.
    """
    if quantization_spec is None:
        return None
    if isinstance(quantization_spec, SharedQuantizationSpec):
        # reuse the observer/fake_quant instance already created for the
        # referenced edge or node instead of creating a new one
        edge_or_node = quantization_spec.edge_or_node
        assert edge_or_node in obs_or_fq_map, \
            "please make sure only refer to edge or node that has " \
            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
        return obs_or_fq_map[edge_or_node]
    elif isinstance(quantization_spec, DerivedQuantizationSpec):
        # can't use asdict, so not calling get_observer_kwargs here
        kwargs = {
            "dtype": quantization_spec.dtype,
            "derive_qparams_fn": quantization_spec.derive_qparams_fn,
            "quant_min": quantization_spec.quant_min,
            "quant_max": quantization_spec.quant_max,
            "qscheme": quantization_spec.qscheme,
            "ch_axis": quantization_spec.ch_axis,
        }
        edge_or_nodes = quantization_spec.derived_from
        # the qparams of this observer are derived from the observers of
        # these source edges/nodes, which must already be in the map
        obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
        kwargs["obs_or_fqs"] = obs_or_fqs
        return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
    elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
        kwargs = _get_observer_kwargs(quantization_spec)
        observer_ctr = FixedQParamsObserver.with_args(**kwargs)
        if is_qat:
            # NOTE(review): unlike the other branches this returns the
            # constructor (with_args partial), not an instance — confirm
            # callers handle both, or whether a trailing `()` is missing
            return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)
        else:
            return observer_ctr()

    assert isinstance(quantization_spec, QuantizationSpec)
    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
    kwargs = _get_observer_kwargs(quantization_spec)
    # the ctor itself is not an observer kwarg
    kwargs.pop("observer_or_fake_quant_ctr")
    # we will remove is_dynamic from QuantizationSpec because
    # it seems that dynamic range quantization
    # (NOTE(review): this upstream comment is truncated/unfinished)
    obs_or_fq_class = observer_or_fake_quant_ctr
    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
        # unwrap a partially-applied constructor to inspect the real class
        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
        # per-tensor observers do not accept a channel-axis argument
        kwargs.pop("ch_axis")
    return observer_or_fake_quant_ctr.with_args(**kwargs)()
226
+
227
def _needs_obs_or_fq(
    prev_output_dtype: Any,
    prev_output_is_dynamic: bool,
    cur_target_dtype: Any,
    cur_target_is_dynamic: bool,
    reuse_input_obs_or_fq: bool,
    is_zeroth_arg: bool = False) -> bool:
    """
    note: we will treat "not specified" as torch.float for now
    utility function that checks if we should insert an observer or fake quant node
    base on the requested dtype for the nodes from user

    Args:
        prev_output_dtype: dtype produced by the previous (producer) node
        prev_output_is_dynamic: whether the producer output is dynamically quantized
        cur_target_dtype: dtype the current node requests for this input
        cur_target_is_dynamic: whether the current node requests dynamic quantization
        reuse_input_obs_or_fq: if True the node reuses its input observer, so
            no new observer is needed on this edge
        is_zeroth_arg: we only dynamically quantize the first arg of the node right now
            this should be removed when we enable configuring dynamic quantization
            for a specific argument, this can be removed if we deprecate fx graph mode
            quantization

    """

    # need to insert placeholder observer for dynamic quantization so that it can
    # be converted to choose_qparams -> q -> dq in convert step
    if cur_target_is_dynamic:
        # bug fix: the previous message claimed cur_target_dtype had to be
        # torch.float, which contradicted the actual membership check
        assert cur_target_dtype in _OBS_DTYPE_LIST, \
            f"Expected cur_target_dtype to be one of {_OBS_DTYPE_LIST}, but got: {cur_target_dtype}"
        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
        return is_zeroth_arg
    if reuse_input_obs_or_fq:
        return False
    # non dynamic quantization
    if cur_target_dtype in _OBS_DTYPE_LIST:
        # observe only when the producer output is observable (fp32 or an
        # observable quantized dtype) and the dtype actually changes
        return prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] and cur_target_dtype != prev_output_dtype

    # lots of error checking are skipped here for now
    return False
261
+
262
def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
    """Return True iff `node` is a call_module node whose target module is an
    activation post process (observer or fake-quant)."""
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "call_module":
        return False
    return _is_activation_post_process(named_modules[str(node.target)])
265
+
266
+ def _get_dtype_and_is_dynamic(obs_or_fq: Optional[ObserverOrFakeQuantize]) -> Tuple[Optional[torch.dtype], bool]:
267
+ """ Given a constructor for observer or fake quant module, returns
268
+ a Tuple of dtype and is_dynamic
269
+ """
270
+ # TODO: instead of instantiating the instance, we can use inspect to get the default args
271
+ if obs_or_fq is None:
272
+ return None, False
273
+ else:
274
+ return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value]
275
+
276
def _is_input_arg_dtype_supported_by_backend(
    arg: Argument,
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the configured qconfig for the argument
    is supported by the backend or not

    Args:
        arg: one input argument of `node` (may be a nested list/tuple, e.g. for cat)
        node: the consumer node; its node.meta["target_dtype_info"] holds the
            configured observer/fake-quant constructors
        qconfig: qconfig configured for `node`
        dtype_config: one dtype configuration supported by the backend
        backend_config: the full backend configuration (passed through for recursion)
    """
    if isinstance(arg, (list, tuple)):
        # every element of a sequence argument must be supported
        return all(_is_input_arg_dtype_supported_by_backend(
            a, node, qconfig,
            dtype_config, backend_config) for a in arg)
    if not isinstance(arg, Node):
        # non-Node arguments (literals etc.) impose no dtype constraint
        return True
    # TODO: support check for standalone module
    is_weight = node_arg_is_weight(node, arg)
    is_bias = node_arg_is_bias(node, arg)
    is_activation = not is_weight and not is_bias
    if is_activation:
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        input_act_obs_or_fq = input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq)
        # TODO(future PR): remove the cast to bool below after figuring
        # out why backend_config has is_dynamic set to None in some cases.
        return (dtype_config.input_dtype is None) or (
            dtype_config.input_dtype == qconfig_dtype and
            bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
            _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
        )
    elif is_weight:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
        backend_config_weight_dtype = dtype_config.weight_dtype
        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
        # a backend with no weight dtype constraint accepts anything
        return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
    else:  # bias
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
        backend_config_bias_dtype = dtype_config.bias_dtype
        return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype
324
+
325
def _is_output_dtype_supported_by_backend(
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
) -> bool:
    """ Check if the configured qconfig for the output
    is supported by the backend or not

    Reads the configured output observer/fake-quant constructor from
    node.meta["target_dtype_info"] and compares its dtype against the
    backend's dtype_config.output_dtype.
    """
    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
    backend_config_output_dtype = dtype_config.output_dtype
    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
    # from input activation check can be reused here
    qconfig_output_dtype = None
    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
    # linear op which has fp32 output dtype, this should be removed if we generalize
    # the structure of qconfig in the future
    if qconfig_output_is_dynamic:
        qconfig_output_dtype = torch.float32
    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
    # the qconfig must also satisfy extra constraints (e.g. quant_min/quant_max bounds)
    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
        qconfig, dtype_config.output_dtype_with_constraints)
    return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
351
+
352
def _is_observer_in_same_graph(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat,
):
    """ Check if observer in same graph
    when the node output is not fp32 and input is 'placeholder'
    the input is assumed to be quantized, so it is observed
    in a different place rather than not observed.
    """
    node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules, obs_or_fq_map, is_qat)
    first_arg = node.args[0] if len(node.args) > 0 else None
    if isinstance(first_arg, Node) and first_arg.op == 'placeholder':
        if node_output_dtype in [torch.quint8, torch.uint8]:
            return False
    return True
368
+
369
def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Optional[Pattern],
    matched_node_pattern: Optional[List[Node]],
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the dtype configuration of a pattern is supported by
    the backend or not, and whether the qconfig satisfies constraints
    specified in the corresponding dtype config.

    Returns True as soon as any one of the backend's dtype configs for
    `pattern` accepts both the inputs of the pattern's root node and the
    output of the pattern's output node.
    """
    if backend_config is None or pattern is None:
        return True
    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)

    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
    # the root node is the one that receives the pattern's external inputs
    root_node = root_node_getter(matched_node_pattern)
    input_node = root_node
    # NOTE(review): output node is taken as matched_node_pattern[0] — the
    # matched pattern appears to be stored output-first; confirm upstream
    output_node = matched_node_pattern[0]
    for dtype_config in dtype_configs:
        # check if arg dtype are supported
        supported = True
        for arg in list(input_node.args) + list(input_node.kwargs.values()):
            supported = supported and _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config)
        # check if output dtype is supported
        supported = supported and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config)
        if supported:
            return True
    return False
402
+
403
+ def _get_standalone_module_configs(
404
+ node: Node,
405
+ named_modules: Dict[str, torch.nn.Module],
406
+ prepare_custom_config: PrepareCustomConfig,
407
+ parent_qconfig: QConfigAny,
408
+ parent_backend_config: Optional[BackendConfig],
409
+ ) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
410
+ """
411
+ Returns the standalone module QConfigMapping and PrepareCustomConfig
412
+ for `node`, assuming that the module pointed to by `node` is
413
+ a standalone modules.
414
+ """
415
+ module_name = str(node.target)
416
+ module_type = type(named_modules[module_name]) # type: ignore[index]
417
+ # name config has precedence over type config
418
+ config_entry = StandaloneModuleConfigEntry(None, (), None, None)
419
+ config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
420
+ config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
421
+ # fallback to use parent module's qconfig if user didn't specify qconfig dict
422
+ qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
423
+ example_inputs = config_entry.example_inputs
424
+ prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
425
+ backend_config = config_entry.backend_config or parent_backend_config
426
+ return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
427
+
428
+ def _qat_swap_modules(
429
+ root: torch.nn.Module,
430
+ module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None:
431
+ convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)
432
+
433
+ def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
434
+ if isinstance(matched_node_pattern, Node):
435
+ s.add(matched_node_pattern.name)
436
+ elif isinstance(matched_node_pattern, (list, tuple)):
437
+ for maybe_node in matched_node_pattern:
438
+ _add_matched_node_name_to_set(maybe_node, s)
439
+
440
def _insert_obs_or_fq(
    node: Node,
    obs_or_fq: ObserverOrFakeQuantize,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `obs_or_fq` to `model`, and creates a node which calls
    `obs_or_fq` on the output of `node`.

    obs_or_fq: an instance of Observer or FakeQuantize module

    Side effects: registers a new, uniquely-named submodule on `model`,
    records it in `named_modules`, and inserts a new call_module node
    into `graph` right after `node`.

    Returns the newly created call_module node.
    """
    model_device = assert_and_get_unique_device(model)
    if model_device:
        # keep the observer's buffers on the same device as the model
        obs_or_fq.to(model_device)
    # add obs_or_fq module as attribute
    if is_equalization_observer(obs_or_fq):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = 'activation_post_process_'
    # generate a fresh attribute name that does not collide with existing ones
    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
    obs_or_fq_name = get_new_obs_or_fq_name(model)
    setattr(model, obs_or_fq_name, obs_or_fq)
    named_modules[obs_or_fq_name] = obs_or_fq
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', obs_or_fq_name, (node,), {})
    return new_obs
469
+
470
def _set_target_dtype_info_for_matched_node_pattern(
    matched_node_pattern: NodePattern,
    last_node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    backend_config: BackendConfig,
    named_modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
    processed_nodes: Set[Node],
) -> None:
    """ Sets the target_dtype_info for each node in matched_node_pattern
    Note: processed_nodes is used to ensure we only process each node once

    Mutates node.meta["target_dtype_info"] for every Node contained in the
    (possibly nested) pattern; nodes with a None qconfig are left untouched.
    """
    if isinstance(matched_node_pattern, (list, tuple)):
        # recurse into nested patterns, e.g. (conv, relu)
        for node_pattern in matched_node_pattern:
            _set_target_dtype_info_for_matched_node_pattern(
                node_pattern,
                last_node,
                qconfig,
                qhandler,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )

    # set target_dtype_info if matched_node_pattern is a Node
    # other types of matched object, e.g. int, float literals, are ignored
    elif isinstance(matched_node_pattern, Node):
        # for pyre
        assert isinstance(matched_node_pattern, Node)
        node = matched_node_pattern
        if node in processed_nodes:
            return
        processed_nodes.add(node)

        if qconfig is None:
            return
        # TODO: refactor the following code in terms of apply a qconfig to a pattern
        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
        # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act,
        # and set output_obs_or_fq_ctr based on qconfig.output_act
        # this also requires we extend the structure of QConfig to support more fine
        # grained configurations
        target_dtype_info: Dict[str, Any] = (
            _get_target_activation_dtype_for_node(
                node,
                qconfig,
                qhandler,
                named_modules,
                backend_config,
                cache_for_no_tensor_check,
            )
        )
        node.meta["target_dtype_info"] = target_dtype_info
525
+
526
def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Dict[str, Any]:
    """
    For each op attribute in the op's input activation, output activation,
    weight, bias - returns the settings of dtype and is_dynamic we expect
    for the `quantize` call in the reference model representation, or None
    if there is no `quantize` call needed.

    For example, if we have a node corresponding to `op0` in

      x0 -> op0 -> x1

    And we want a reference quantized representation to be

      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1

    Then this function will return

      {
        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
      }

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    args_have_no_tensors = \
        all_node_args_have_no_tensors(
            node, named_modules, cache_for_no_tensor_check)
    if args_have_no_tensors:
        # nothing to observe: the op consumes no tensors at all
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }
    # get qconfig to determine the eventual dtype of this node
    if qconfig is not None:
        act_dtype, weight_dtype, input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)

        # Currently `QConfig` only has one `activation` field.
        # For static quantization, it is reused for both input
        # and output activation. For dynamic quantization, this
        # field is currently only used for the input activation,
        # with the output activation being in fp32.
        # In the future this may change as we add more fields
        # to the `QConfig` object.
        # (note: the output ctor below therefore reuses qconfig.activation)

        # bias is only kept in fp16 when the whole op runs in fp16 statically
        bias_dtype = torch.float16 \
            if (
                act_dtype == torch.float16
                and weight_dtype == torch.float16
                and (not input_act_is_dynamic)
            ) else torch.float

        is_general_tensor_value_op = \
            (qhandler is not None and qhandler.is_general_tensor_value_op())

        _is_standalone_module = (
            qhandler is not None and qhandler.is_standalone_module()
        )

        # Look up (once) where the backend expects weight/bias to appear in
        # this op's argument list, if such a mapping is configured.
        # (previously this identical condition + dict lookup was performed
        # twice, once for weight_index and once for bias_index)
        weight_index = None
        bias_index = None
        if isinstance(node, Node) and node.op == "call_function" and \
                node.target in backend_config._pattern_complex_format_to_config:
            input_type_to_index = \
                backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index
            weight_index = input_type_to_index.get("weight")
            bias_index = input_type_to_index.get("bias")

        return {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
            "weight_index": weight_index,
            "bias_index": bias_index,
            "output_act_obs_or_fq_ctr": qconfig.activation,
            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
            "input_output_share_observers": is_general_tensor_value_op,
            "_is_standalone_module": _is_standalone_module,
        }
    # no qconfig: default to fp32 pass-through on both input and output
    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
617
+
618
def _get_output_act_obs_or_fq(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize:
    """ Get the constructor for observer or fake quant object for
    the argument in the original graph as the output of previous node,
    skipping inserted observers

    We are assuming that the observers are inserted correctly, and the dtype for
    argument in quantized graph will match what is specified by the qconfig

    Resolution order: pt2e quantization_annotation on `arg`, then custom-module
    LSTM tracing, then the node observed by an already-inserted observer,
    then `arg`'s own target_dtype_info, then the fp32 default.
    """
    assert isinstance(arg, Node)
    if "quantization_annotation" in arg.meta:
        # pt2e path: the annotation carries the output qspec directly
        return _create_obs_or_fq_from_qspec(arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)

    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    # Since we modified the graph in this case, we must trace back from the args through
    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
    output_act_obs_or_fq_ctr = None
    if custom_module_lstm_node is not None:
        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    elif _is_activation_post_process_node(arg, named_modules):
        # `arg` is itself an inserted observer: look through it at the node it observes
        observed_arg = arg.args[0]
        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
        if "quantization_annotation" in observed_arg.meta:
            output_act_obs_or_fq = \
                _create_obs_or_fq_from_qspec(
                    observed_arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)
        else:
            assert "target_dtype_info" in observed_arg.meta
            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
            output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    else:
        if "target_dtype_info" in arg.meta:
            output_act_obs_or_fq_ctr = \
                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
        else:
            # unconfigured nodes default to fp32 pass-through
            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None

    return output_act_obs_or_fq
665
+
666
def _get_arg_target_dtype_as_output(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[torch.dtype]:
    """Return the dtype that `arg` is expected to have as the output of its
    producing node, or None when no observer/fake-quant is configured."""
    obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
    dtype, _is_dynamic = _get_dtype_and_is_dynamic(obs_or_fq)
    return dtype
675
+
676
def _get_arg_as_input_act_obs_or_fq(
    arg: Node,
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[ObserverOrFakeQuantize]:
    """ Get the observer or fake quant constructor for the Argument `arg`, as input
    to Node `node`

    Returns None when no observation is configured for this edge.
    """
    assert isinstance(arg, Node)
    # "input_qspec_map" is the more general design we'll use for pt2e path
    # it is a map from input argument node to observer or fake quant constructor, for example
    # for the following graph:
    # x -> conv -> output
    #
    # we may annotate conv node like the following:
    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
    #
    if "quantization_annotation" in node.meta:
        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
        if input_arg_qspec is None:
            # un-annotated inputs default to fp32 pass-through
            input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR()
        else:
            input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(input_arg_qspec, obs_or_fq_map, is_qat)
        return input_arg_obs_or_fq

    # we can remove the following path in the future if fx graph mode quantization is
    # no longer used
    # fx path: pick the ctor by the role of `arg` (activation / weight / bias)
    is_weight = node_arg_is_weight(node, arg)
    is_bias = node_arg_is_bias(node, arg)
    is_activation = not is_weight and not is_bias
    obs_or_fq_ctr = None
    if is_activation:
        obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    elif is_weight:
        # some ops (e.g. embedding lookups) have weights that must not be quantized
        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
            obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    else:
        obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    return obs_or_fq_ctr() if obs_or_fq_ctr else None
718
+
719
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Returns the argument `node` should use instead of `arg`: either `arg`
    itself (no observer needed), a newly inserted observer node, or an
    existing observer node of the right type already attached to `arg`.
    Also records the chosen observer in `obs_or_fq_map` under (arg, node).
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules,
                graph,
                qhandler,
                prepare_custom_config,
                obs_or_fq_map,
                is_qat,
                backend_config)
            new_arg_to_return.append(new_inner_arg)
        # preserve the container type (list vs tuple)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
    # TODO: move this to a separate function
    if not is_standalone_module:
        # Note: qconfig can be None in this branch since we are getting act/fq from
        # node.meta now
        # regular flow for most nodes, except standalone modules

        if "quantization_annotation" in node.meta:
            reuse_input_obs_or_fq = node.meta["quantization_annotation"]._reuse_input_obs_or_fq
        else:
            assert "target_dtype_info" in node.meta
            # TODO: we are assuming "target_dtype_info" exists here, maybe
            # a default value also need to be provided here
            target_dtype_info = node.meta["target_dtype_info"]
            # for nodes that doesn't have `reuse_input_obs_or_fq` configured,
            # we'll default to False, this makes configuring this field optional for users
            reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
        # dtype `arg` should have as input to `node` ...
        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
        arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

        # ... versus dtype `arg` has as output of its producer
        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
        arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)


        needs_obs_or_fq = _needs_obs_or_fq(
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
            reuse_input_obs_or_fq,
            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
        )

    else:
        assert qconfig is not None
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = \
            _get_standalone_module_configs(
                node, named_modules, prepare_custom_config, qconfig, backend_config)
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            # kwarg to a standalone module: never observed
            needs_obs_or_fq = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules, obs_or_fq_map, is_qat)
            # the standalone module's config decides which positional inputs
            # arrive quantized
            arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
                else torch.float
            needs_obs_or_fq = (
                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
                (arg_as_input_target_dtype != torch.float)
            )

        act_post_process_ctr = qconfig.activation
        arg_as_input_act_obs_or_fq = act_post_process_ctr() if act_post_process_ctr else None

    if needs_obs_or_fq:

        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this is looking into how the value is used in the future
        # we should remove this
        # removing this means we insert one observer for each use, even if they
        # have the same dtype, we can have an extra pass that removes the extra observers
        for maybe_obs_node in arg.users.keys():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
                    maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        assert arg_as_input_act_obs_or_fq is not None
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph)
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
855
+
856
+
857
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    Note: backend_config only needed for standalone_module node
    """
    # Each positional arg and each kwarg goes through the same per-input
    # check: if its target dtype disagrees with what this node expects,
    # an observer (or fake-quant) node is spliced in between.
    def _observe(input_arg):
        return _maybe_insert_input_observer_for_arg_or_kwarg(
            node, input_arg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            obs_or_fq_map,
            is_qat,
            backend_config)

    observed_args = [_observe(arg) for arg in node.args]
    observed_kwargs = {name: _observe(kwarg) for name, kwarg in node.kwargs.items()}

    # assign the new args and kwargs to the node, inplace
    node.args = tuple(observed_args)
    node.kwargs = observed_kwargs
911
def _maybe_insert_input_equalization_observers_for_node(
    node: Node,
    equalization_qconfig: Any,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    is_branch: bool,
) -> None:
    """
    If `node` needs to be equalized, find the input/weight observers it needs in
    `equalization_qconfig`, creates them, and inserts it into `graph`.

    If `node` does not need an equalization observer, returns None.
    """
    # nothing to do when equalization is disabled or the op cannot be equalized
    if equalization_qconfig is None or not node_supports_equalization(node, named_modules):
        return

    # equalization is not defined for nodes sitting on a branch
    if is_branch:
        warnings.warn(
            f"Cannot equalize {node} because it is part of a branch."
        )
        return

    updated_args = []
    for input_arg in node.args:
        # non-Node literals and bias inputs are passed through untouched
        if not isinstance(input_arg, Node) or node_arg_is_bias(node, input_arg):
            updated_args.append(input_arg)
            continue

        # weights and activations use different observer constructors
        if node_arg_is_weight(node, input_arg):
            eq_obs_ctr = equalization_qconfig.weight
        else:
            eq_obs_ctr = equalization_qconfig.input_activation

        eq_obs_node = _insert_obs_or_fq(
            input_arg, eq_obs_ctr(), model, named_modules, graph)
        updated_args.append(eq_obs_node)

    # assign the new args to the node, inplace
    node.args = tuple(updated_args)
954
def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    """
    If `node` needs an output observer, creates it, inserts it into `graph`
    and returns it.

    If `node` does not need an output observer, returns None.

    Note: inserting dynamic quantization ops for output is not supported in fx graph mode
    quantization code path right now
    """
    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'

    is_standalone_module = False
    if "quantization_annotation" in node.meta:
        # pt2e path: the observer comes from the annotated output qspec
        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
        )
    else:
        # fx graph mode path: the observer constructor was stored in
        # target_dtype_info during the dtype-propagation step
        assert "target_dtype_info" in node.meta
        is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False)
        obs_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
        output_act_obs_or_fq = obs_ctr() if obs_ctr else None
    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # uncomment after we support reuse_input_obs_or_fq properly by having separate
    # implementations for this key instead of reusing the input_output_share_observers
    # code
    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
    # for now we set this to False since reuse_input_obs_or_fq for
    # the output of a node is implemented in the same code path as observer sharing,
    # we should refactor this part to make it clearer in the future
    # and we would be able to read this from config directly
    reuse_input_obs_or_fq = False

    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
    # because the prev_output is the output of an fp32 op, although technically
    # we should get the dtype of the output from node.meta["val"] in the future
    # if we deprecate fx graph mode quantization
    needs_obs_or_fq = _needs_obs_or_fq(torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq)
    # currently the activation in QConfig(activation=...,) is for both input
    # and output, and when the activation is configured to be dynamic quantization
    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
    # the input should be dynamically quantized, but output should not be quantized
    #
    # there is no way we can specify different observer/fq for input and output
    # activation through QConfig today, this limitation is lifted in the
    # quantizer/annotation API in pytorch 2.0 export quantization code path,
    # but since this code is reused, annotating output to be dynamically quantized
    # would not work either for that.
    # we can change QConfig to support input/output activation if we want
    # to remove the following check, or if we can deprecate fx graph mode quantization
    if target_is_dynamic:
        needs_obs_or_fq = False

    # we never insert observers to output of standalone module, we assume
    # if needed, they are inserted inside the standalone module
    if is_standalone_module:
        needs_obs_or_fq = False

    if not needs_obs_or_fq:
        return None

    obs_or_fq_map[node] = output_act_obs_or_fq
    return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
1025
def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.
    """

    def _observe_or_passthrough(
        maybe_node: Argument,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Walk an arbitrary nesting of lists, tuples and dicts. Containers are
        recursed into; the first Node reached on each path is observed if its
        dtype requires it, and recursion stops there.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if an observer was inserted the original node is replaced by it.

        Returns the same structure with all nodes needing observation
        replaced by their observers.
        """
        if isinstance(maybe_node, Node):
            # compare this node's produced dtype against the dtype the graph
            # output wants to consume it at
            produced_dtype = _get_arg_target_dtype_as_output(maybe_node, named_modules, obs_or_fq_map, is_qat)
            observer_mod = None
            consumed_dtype = torch.float
            if "target_dtype_info" in maybe_node.meta:
                observer_cls = maybe_node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", None)
                if observer_cls is not None:
                    observer_mod = observer_cls()
                    consumed_dtype = observer_mod.dtype
            # TODO: this does not handle dynamic quantization yet
            if produced_dtype != consumed_dtype and consumed_dtype != torch.float:
                assert observer_mod is not None
                # insert observer and return it in place of the original node
                return _insert_obs_or_fq(
                    maybe_node, observer_mod, model, named_modules, graph)
            return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            mapped = [
                _observe_or_passthrough(inner, model, named_modules, graph)
                for inner in maybe_node
            ]
            # preserve the original container type
            return mapped if isinstance(maybe_node, list) else tuple(mapped)
        elif isinstance(maybe_node, dict):
            return {
                key: _observe_or_passthrough(inner, model, named_modules, graph)
                for key, inner in maybe_node.items()
            }
        elif maybe_node is None:
            return None
        else:
            raise Exception("Unhandled type for returned node:", maybe_node)

    graph_output_node.args = tuple(
        _observe_or_passthrough(old_arg, model, named_modules, graph)
        for old_arg in graph_output_node.args
    )  # type: ignore[assignment]
1113
def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: Union[torch.dtype, type],
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
    is a general tensor shape op, also call this function recursively on
    the first argument, to propagate the dtype to the caller.
    """
    # clearing both constructors marks this node's dtype as known, so no
    # observer will be created for it later
    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
    # if this is a copy node, propagate to first arg
    match_result = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None))
    qhandler = match_result[3]
    # TODO: probably need to remove `is_general_tensor_value_op`
    if qhandler is not None and qhandler.is_general_tensor_value_op():
        prev_node = node.args[0]
        if isinstance(prev_node, Node):
            _maybe_propagate_dtype_for_node(
                prev_node, target_dtype, node_name_to_match_result_with_qconfig)
1135
def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Currently we assume that inputs to the graph are either `torch.float` or
    `torch.quint8`, which is not always correct. For ops such as
    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
    `BoolTensor`. Propagate this information throughout the graph.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
    """
    for node in graph.nodes:
        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)

        # each entry maps a known arg type to a callable returning the
        # indices of this node's args that carry that type
        for arg_type, index_getter in non_observable_arg_dict.items():
            for index in index_getter(node):
                arg = node.args[index]

                # a tuple argument does not show up as a separate node in the
                # graph, so its elements have to be walked manually
                args_to_process = list(arg) if isinstance(arg, (tuple, list)) else [arg]

                for cur_arg in args_to_process:
                    # hard coded arguments show up but aren't `Node` typed
                    # and do not need dtype propagated
                    if isinstance(cur_arg, torch.fx.node.Node):
                        _maybe_propagate_dtype_for_node(
                            cur_arg, arg_type, node_name_to_match_result_with_qconfig)
1171
def _maybe_make_input_output_share_observers(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
) -> bool:
    """
    Ensures that we share an observer
    for all input arguments as well as the output argument. In detail, given
    a graph of

      x0 -> obs0 -> op -> x2
                  /
      x1 -> obs1 /

    where node obs0 points to observer instance observer0,
    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
    and ob2 point to observer0.
    Returns: whether the operation succeeded or not
    """
    # locate the first arg that can carry a Tensor: a Node, or a
    # list/tuple of Nodes; scalar-only args are skipped
    first_arg = None
    for candidate in node.args:
        if isinstance(candidate, (Node, list, tuple)):
            first_arg = candidate
            break

    # no Tensor-carrying arg at all, nothing to share
    if first_arg is None:
        return False

    if isinstance(first_arg, (list, tuple)):
        first_arg_arg = first_arg[0]
    elif isinstance(first_arg, Node):
        first_arg_arg = first_arg
    else:
        return False

    # if we have a graph such as
    #   observed_node -> non_observed_node -> cat
    # we need to navigate up to the first observer
    iteration_guard = 0
    while not _is_activation_post_process_node(first_arg_arg, named_modules):
        if not isinstance(first_arg_arg, Node):
            return False
        # did not find an activation_post_process for the op
        if first_arg_arg.op == "placeholder":
            return False
        # trace back the args until we found the first Tensor/Node
        trace_back_node = None
        for prev_arg in first_arg_arg.args:
            trace_back_node = prev_arg
            if isinstance(trace_back_node, Node):
                break
        if trace_back_node is None:
            return False
        first_arg_arg = trace_back_node

        iteration_guard += 1
        if iteration_guard > 10000:
            raise AssertionError('Unable to find observer of previous node')

    assert isinstance(first_arg_arg, Node)
    target_to_use = first_arg_arg.target
    assert isinstance(target_to_use, str)
    obs_mod_to_use = named_modules[target_to_use]

    if isinstance(first_arg, (list, tuple)):
        # rewire every other input's observer module to the shared one
        for input_idx, input_arg in enumerate(first_arg):
            if input_idx == 0:
                continue
            iteration_guard = 0
            while not _is_activation_post_process_node(input_arg, named_modules):
                # failed to trace back since no input arg for the current node
                if len(input_arg.args) < 1:
                    return False
                input_arg = input_arg.args[0]
                iteration_guard += 1
                if iteration_guard > 10000:
                    raise AssertionError('Unable to find observer of previous node')

            parent_name, name = _parent_name(input_arg.target)
            setattr(named_modules[parent_name], name, obs_mod_to_use)

    # set the output observer node to use that module
    for output_obs_node in node.users.keys():
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        parent_name, name = _parent_name(output_obs_node.target)
        setattr(named_modules[parent_name], name, obs_mod_to_use)

    # TODO(future PR): delete the orphaned observer modules
    return True
1264
def _remove_output_observer(
        node: Node,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module]):
    """Detach every observer attached to `node`'s output: each observer's
    users are rewired to consume `node` directly, then the observer node is
    erased from the graph."""
    # snapshot users first: erasing nodes mutates node.users while we iterate
    for output_obs_node, _ in list(node.users.items()):
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        output_obs_node.replace_all_uses_with(node)
        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]
1274
def _swap_custom_module_to_observed(
        node: Node,
        qconfig: QConfigAny,
        named_modules: Dict[str, torch.nn.Module],
        prepare_custom_config: PrepareCustomConfig):
    """Replace the float custom module that `node` targets with its observed
    counterpart, looked up through the float-to-observed class mapping in
    `prepare_custom_config`."""
    custom_module = named_modules[node.target]  # type: ignore[index]
    class_mapping = prepare_custom_config.float_to_observed_mapping
    observed_cls = get_swapped_custom_module_class(
        custom_module, class_mapping, qconfig)
    observed_custom_module = observed_cls.from_float(custom_module)
    # swap the module on its parent, in place
    parent_name, name = _parent_name(node.target)
    setattr(named_modules[parent_name], name, observed_custom_module)
1289
+ def insert_observers_for_model(
1290
+ model: GraphModule,
1291
+ node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
1292
+ node_name_to_qconfig: Dict[str, QConfigAny],
1293
+ prepare_custom_config: PrepareCustomConfig,
1294
+ equalization_config_map: Dict[str, Any],
1295
+ backend_config: BackendConfig,
1296
+ observed_node_names: Set[str],
1297
+ is_qat: bool,
1298
+ ) -> Optional[Node]:
1299
+ """
1300
+ Inserts observers, using the following high level algorithm:
1301
+
1302
+ For each node in the graph:
1303
+ 1. determine the target dtype of this node in the quantized graph, and save
1304
+ it for future steps
1305
+ 2. determine the target dtype or all args and kwargs of this node
1306
+ 3. if any arg or kwarg's target dtype does not match the current node's
1307
+ dtype, insert an observer
1308
+ 4. if the current node needs an output observer, insert it
1309
+
1310
+ For example:
1311
+
1312
+ - starting graph:
1313
+ x0 -> linear -> x1
1314
+
1315
+ - observed graph after processing x0:
1316
+ x0(fp32)
1317
+
1318
+ - observed graph after processing linear:
1319
+ x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
1320
+
1321
+ - observed graph after processing x1:
1322
+ x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
1323
+
1324
+ After a node is processed, the naive observer placement is guaranteed to be
1325
+ complete for that node and all of its predecessors. There can be future
1326
+ passes which optimize the graph by deduplicating observers, etc.
1327
+ """
1328
+
1329
+ # node.meta["target_dtype_info"] stores the target dtype information
1330
+ # that's derived from qconfig for the Node, for example, if we have
1331
+ # a conv2d node that has a qconfig
1332
+ # qconfig = QConfig(activation=..., weight=...)
1333
+ # # information for input and bias node omitted
1334
+ # # for getattr node
1335
+ # # weight = getattr(self, 'weight')
1336
+ # weight.meta["target_dtype_info"] = {
1337
+ # 'output_act_obs_or_fq_ctr': qconfig.weight,
1338
+ # }
1339
+ # # for conv2d node
1340
+ # # conv2d = call_function[target=torch.nn.functional.conv2d](
1341
+ # # args=(input, weight, bias))
1342
+ # conv2d.meta["target_dtype_info"] = {
1343
+ # 'input_act_obs_or_fq_ctr': qconfig.activation
1344
+ # 'weight_obs_or_fq_ctr': qconfig.weight,
1345
+ # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
1346
+ # 'output_act_obs_or_fq_ctr': qconfig.activation,
1347
+ # }
1348
+ #
1349
+ cache_for_no_tensor_check: Dict[Node, bool] = {}
1350
+
1351
+ # first, populate the dtype map based only on qconfig and qhandler
1352
+ # this assumes:
1353
+ # graph inputs are fp32 by default, and int8 where overriden
1354
+ # other nodes output dtype is specified by the qconfig
1355
+ named_modules = dict(model.named_modules(remove_duplicate=False))
1356
+
1357
+ input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
1358
+ output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
1359
+ processed_nodes: Set[Node] = set()
1360
+ # initialize target_dtype_info
1361
+ for node in model.graph.nodes:
1362
+ node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)
1363
+
1364
+ inputs_seen_counter = 0
1365
+ outputs_seen_counter = 0
1366
+ placeholder_node_to_input_index: Dict[Node, int] = {}
1367
+ # TODO: we probably don't need this counter since each graph will only have
1368
+ # one output node?
1369
+ output_node_to_output_index: Dict[Node, int] = {}
1370
+ for node in model.graph.nodes:
1371
+ if node.op == "placeholder":
1372
+ placeholder_node_to_input_index[node] = inputs_seen_counter
1373
+ inputs_seen_counter += 1
1374
+ if node.op == "output":
1375
+ output_node_to_output_index[node] = outputs_seen_counter
1376
+ outputs_seen_counter += 1
1377
+
1378
+ # Step 1, set the observer or fake quantize module constructor for each node in the
1379
+ # matched_node_pattern
1380
+
1381
+ for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
1382
+ last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
1383
+ assert qhandler is not None
1384
+ _set_target_dtype_info_for_matched_node_pattern(
1385
+ matched_node_pattern,
1386
+ last_node,
1387
+ qconfig,
1388
+ qhandler,
1389
+ backend_config,
1390
+ named_modules,
1391
+ cache_for_no_tensor_check,
1392
+ processed_nodes
1393
+ )
1394
+
1395
+ # Step 2. Special cases for some operators, we might be able to remove them
1396
+ # in the future if we know dtype information of each node better
1397
+
1398
+ # Step 2.1. some settings are not based on patterns, we need to process each node
1399
+ # instead
1400
+ for node in model.graph.nodes:
1401
+ if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs:
1402
+ # users are not supposed to call calculate_qparams on PlaceholderObserver, and
1403
+ # this is OK because we are using this as a way to encode the dtypes of input
1404
+ # tensor, we won't actually insert these observers in the graph and won't
1405
+ # actually call calculate_qparams
1406
+ node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
1407
+ elif node.op in ("call_module", "call_method", "call_function"):
1408
+ args_have_no_tensors = \
1409
+ all_node_args_have_no_tensors(
1410
+ node, named_modules, cache_for_no_tensor_check)
1411
+ if args_have_no_tensors:
1412
+ node.meta["target_dtype_info"] = {
1413
+ "input_act_obs_or_fq_ctr": None,
1414
+ "output_act_obs_or_fq_ctr": None,
1415
+ }
1416
+ elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs:
1417
+ # TODO(future PR): update the output_quantized_idxs API to match
1418
+ # arbitrary data structures. There is always a single output, and
1419
+ # that output can have arbitrary nesting of values. List[int] is
1420
+ # not the right data type for this.
1421
+
1422
+ # TODO(future PR): support more dtypes in model outputs, if necessary
1423
+ node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
1424
+
1425
+ # Step 2.2, for nodes with known input dtypes, propagate them throughout the
1426
+ # graph. For example, if there is a call such as
1427
+ # x1 = x0.masked_fill(mask, 1)
1428
+ # we propagate the type of mask to be torch.bool
1429
+ propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig)
1430
+
1431
+ # Step 3, check if the requested target_dtype_info is supported by backend or not
1432
+ # if not, we'll reset the target_dtye_info to use the default (float Tensor)
1433
+
1434
+ # reset the counters and set of processed_nodes
1435
+ processed_nodes: Set[Node] = set()
1436
+ for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
1437
+ last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
1438
+ is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
1439
+ pattern, matched_node_pattern, qconfig, backend_config)
1440
+ assert qhandler is not None
1441
+
1442
+ # get output_act_dtype so that we don't also reset the special typed nodes
1443
+ # TODO: we might want to handle these more uniformly with the default path
1444
+ # this can be improved if we can use node.meta["val"]
1445
+ output_act_or_fq_ctr = node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
1446
+ output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
1447
+ output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
1448
+ if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]:
1449
+ # restore target_dtype_info to default if it is not supported by backend
1450
+ _set_target_dtype_info_for_matched_node_pattern(
1451
+ matched_node_pattern,
1452
+ last_node,
1453
+ torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
1454
+ None,
1455
+ backend_config,
1456
+ named_modules,
1457
+ cache_for_no_tensor_check,
1458
+ processed_nodes
1459
+ )
1460
+
1461
+ # After this point, the current node and all of its arguments
1462
+ # have a target_dtype_info assigned. Now, we insert observers for inputs
1463
+ # of this node (if needed for this node), and the output of this node
1464
+ # (if needed for this node).
1465
+
1466
+ # Since we are mutating the graph as we go, we iterate over the original
1467
+ # nodes before observer insertion, instead of model.graph.nodes.
1468
+ nodes_before_observation = list(model.graph.nodes)
1469
+
1470
+ # Avoid duplicates custom module swaps for multiple nodes with same target.
1471
+ custom_module_names_already_swapped: Set[str] = set()
1472
+
1473
+ # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
1474
+ # reset inputs/outputs counters
1475
+ inputs_seen_counter = 0
1476
+ outputs_seen_counter = 0
1477
+ results_node = None
1478
+ obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
1479
+
1480
+ # TODO: change this to insert obs/fq by pattern instead of by node
1481
+ for node in nodes_before_observation:
1482
+
1483
+ if node.op == 'placeholder':
1484
+ # if a graph input is in fp32, it does not need observation
1485
+ # if a graph input is in int8, we assume the observation happens
1486
+ # outside of the graph, and no additional observation is needed
1487
+ pass
1488
+
1489
+ elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
1490
+ # check for matches
1491
+ last_node, matched_node_pattern, pattern, qhandler, qconfig = (
1492
+ node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None)) # type: ignore[assignment]
1493
+ )
1494
+ equalization_qconfig = equalization_config_map.get(node.name, None)
1495
+
1496
+ this_node_dtype_info = node.meta["target_dtype_info"]
1497
+ if "val" in node.meta:
1498
+ output_is_a_tensor = (
1499
+ this_node_dtype_info is not None and
1500
+ isinstance(node.meta["val"], FakeTensor)
1501
+ )
1502
+ else:
1503
+ output_is_a_tensor = this_node_dtype_info is not None
1504
+
1505
+ skip_inserting_observers = (
1506
+ (qconfig is None) or
1507
+ not output_is_a_tensor
1508
+ ) and (
1509
+ not node.op == 'output'
1510
+ )
1511
+
1512
+ # TODO: take a closer look to see if we can remove this check
1513
+ # right now it is here because of `observed_node_names`, we are using
1514
+ # it as an indicator for swapping the modules to reference modules in
1515
+ # convert
1516
+ is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
1517
+ pattern, matched_node_pattern, qconfig, backend_config)
1518
+
1519
+ if not skip_inserting_observers and is_supported_by_backend:
1520
+ named_modules = dict(model.named_modules(remove_duplicate=False))
1521
+ if node.op != 'output':
1522
+ assert matched_node_pattern is not None
1523
+ # add matched nodes to the observed node name set
1524
+ _add_matched_node_name_to_set(matched_node_pattern, observed_node_names)
1525
+
1526
+ # This is currently only used for equalization.
1527
+ # Checks if the current node is in a branch in which the two
1528
+ # first layers are both being quantized.
1529
+ #
1530
+ # ex. conv2
1531
+ # /
1532
+ # x -> conv1
1533
+ #
1534
+ # If this is the case, we will not apply equalization to the
1535
+ # initial two layers.
1536
+ is_quantized_branch = False
1537
+ if (
1538
+ len(node.args) > 0 and
1539
+ isinstance(node.args[0], Node) and
1540
+ len(node.args[0].users) > 1
1541
+ ):
1542
+ for user in node.args[0].users:
1543
+ # Checks if there exists another user being quantized
1544
+ is_user_quantized = (
1545
+ node_name_to_qconfig.get(user.name, None) is not None or
1546
+ (user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase))
1547
+ )
1548
+ if user != node and is_user_quantized:
1549
+ is_quantized_branch = True
1550
+
1551
+ pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
1552
+ root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
1553
+ root_node = root_node_getter(matched_node_pattern)
1554
+ is_input_node_of_the_pattern = node is root_node
1555
+ if is_input_node_of_the_pattern:
1556
+ # this modifies node inplace
1557
+ _maybe_insert_input_observers_for_node(
1558
+ node, qconfig, model, named_modules, model.graph,
1559
+ qhandler,
1560
+ prepare_custom_config,
1561
+ obs_or_fq_map,
1562
+ is_qat,
1563
+ backend_config)
1564
+
1565
+ # insert equalization input observers if needed
1566
+ _maybe_insert_input_equalization_observers_for_node(
1567
+ node, equalization_qconfig, model, named_modules, model.graph,
1568
+ is_quantized_branch)
1569
+
1570
+ is_last_node_of_pattern = node is last_node
1571
+ input_output_share_observers = node.meta["target_dtype_info"].get("input_output_share_observers", False)
1572
+ reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
1573
+
1574
+ if is_last_node_of_pattern:
1575
+ if _is_custom_module_lstm(node, named_modules, qconfig, qhandler):
1576
+ # Currently custom module outputs are assumed to be already quantized,
1577
+ # so we need to insert a DeQuantStub after the output. For custom module
1578
+ # LSTM specifically, the outputs are also a nested tuple, so we must first
1579
+ # break down the tuple to insert DeQuantStubs after the internal nodes.
1580
+
1581
+ # TODO: This currently diverges from how custom modules are handled today,
1582
+ # where we insert observers after the output instead of DeQuantStubs, and
1583
+ # replace these observers with "dequantize" nodes during convert. Conceptually,
1584
+ # these output observers are the same as DeQuantStubs. In the future, we
1585
+ # should resolve this inconsistency by inserting DeQuantStubs for all custom
1586
+ # modules, not just for LSTM.
1587
+ _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
1588
+ if node.target not in custom_module_names_already_swapped:
1589
+ custom_module_names_already_swapped.add(node.target)
1590
+ _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
1591
+ else:
1592
+ # this returns the new observer node if it was needed
1593
+ maybe_output_obs_node = _maybe_insert_output_observer_for_node(
1594
+ node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
1595
+
1596
+ if maybe_output_obs_node is not None:
1597
+ # Update users of original node to use the output observer
1598
+ # instead. For example, change
1599
+ #
1600
+ # next_node
1601
+ # /
1602
+ # cur_node -> obs
1603
+ #
1604
+ # to
1605
+ #
1606
+ # next_node
1607
+ # /
1608
+ # cur_node -> obs
1609
+ #
1610
+ # We need to save orig users before updating uses because
1611
+ # the list of users will change as we update uses
1612
+ orig_users = list(node.users.keys())
1613
+ for user_node in orig_users:
1614
+ if user_node is maybe_output_obs_node:
1615
+ continue
1616
+ user_node.replace_input_with(node, maybe_output_obs_node)
1617
+
1618
+ _is_observer_in_same_graph_ = _is_observer_in_same_graph(
1619
+ node, named_modules, obs_or_fq_map, is_qat)
1620
+
1621
+ # for ops whose inputs and outputs share observer/fqs, we modify the graph
1622
+ # to make all inputs and outputs use the first input's
1623
+ # observer/fq
1624
+ if (input_output_share_observers and _is_observer_in_same_graph_) or \
1625
+ reuse_input_obs_or_fq:
1626
+ if not _maybe_make_input_output_share_observers(node, model, named_modules):
1627
+ _remove_output_observer(node, model, named_modules)
1628
+
1629
+ if qhandler is not None and qhandler.is_custom_module():
1630
+ if node.target not in custom_module_names_already_swapped:
1631
+ custom_module_names_already_swapped.add(node.target)
1632
+ _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
1633
+
1634
+ else: # output
1635
+ _maybe_insert_observers_before_graph_output(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
1636
+
1637
+ #
1638
+ # After this point, the current node has input and output observers
1639
+ # that it needs for itself inserted.
1640
+ #
1641
+
1642
+ # increment the counters, so future inputs and outputs are assigned
1643
+ # correct dtypes
1644
+ if node.op == 'placeholder':
1645
+ inputs_seen_counter += 1
1646
+ elif node.op == 'output':
1647
+ outputs_seen_counter += 1
1648
+ results_node = node
1649
+
1650
+ return results_node
1651
+
1652
def _run_prepare_fx_on_standalone_modules(
    model: torch.nn.Module,
    is_qat: bool,
    named_modules: Dict[str, torch.nn.Module],
    node_name_to_match_result_with_qconfig: Any,
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> None:
    """
    Run prepare_fx on every matched standalone module.

    The parent graph itself is not modified; each unobserved standalone
    submodule is simply swapped for its observed counterpart in place.
    """
    for root_node, _, pattern, qhandler, qconfig in node_name_to_match_result_with_qconfig.values():
        # Only standalone-module matches are relevant here.
        if qhandler is None or not qhandler.is_standalone_module():
            continue

        (sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config,
         sm_backend_config) = _get_standalone_module_configs(
             root_node, named_modules, prepare_custom_config, qconfig, backend_config)

        standalone_module = named_modules[root_node.target]
        # local alias named so it does not shadow this module's `prepare`
        prepare_fn = \
            torch.ao.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
        observed_standalone_module = prepare_fn(
            standalone_module,
            sm_qconfig_mapping,
            is_qat,
            example_inputs=sm_example_inputs,
            prepare_custom_config=sm_prepare_custom_config,
            backend_config=sm_backend_config)
        # Install the observed module on the parent and refresh the cache.
        parent_name, name = _parent_name(root_node.target)
        setattr(named_modules[parent_name], name, observed_standalone_module)
        named_modules[root_node.target] = observed_standalone_module
1690
def _save_state(
    observed: GraphModule,
    node_name_to_qconfig: Dict[str, QConfigAny],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    prepare_custom_config: PrepareCustomConfig,
    equalization_node_name_to_qconfig: Dict[str, Any],
    qconfig_mapping: QConfigMapping,
    is_qat: bool,
    observed_node_names: Set[str],
) -> None:
    """Record prepare-time metadata on the observed GraphModule.

    Everything the convert step needs later is bundled into a single
    ObservedGraphModuleAttrs instance stored under
    ``observed.meta["_observed_graph_module_attrs"]``.
    """
    attrs = ObservedGraphModuleAttrs(
        node_name_to_qconfig=node_name_to_qconfig,
        node_name_to_scope=node_name_to_scope,
        prepare_custom_config=prepare_custom_config,
        equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
        qconfig_mapping=qconfig_mapping,
        is_qat=is_qat,
        observed_node_names=observed_node_names,
    )
    observed.meta["_observed_graph_module_attrs"] = attrs
1711
+
1712
def prepare(
        model: GraphModule,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
        is_qat: bool,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
        _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False) -> GraphModule:
    """ standalone_module means it a submodule that is not inlined in
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module
    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
        The scope is a tuple of fully qualified path of the module and the type of the module
    Returns:
        model(GraphModule): prepared standalone module
        attributes related to standalone module
        in model.meta["_observed_graph_module_attrs"]:
            is_observed_standalone_module (bool): boolean value that shows whether the
            current model is a observed standalone module or not
            standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph input that is expected to be quantized,
                same as input_quantized_idxs configuration provided
                for the standalone module
            standalone_module_output_quantized_idxs(List[Int]): a list of
                indexs for the graph output that is quantized
                same as input_quantized_idxs configuration provided
                for the standalone module
    """
    # Phase 0: normalize arguments -- fill in defaults and upgrade
    # deprecated dict-style configs to their config-object equivalents.
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

    if isinstance(_equalization_config, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
            "be supported in a future version. Please pass in a QConfigMapping instead.")
        _equalization_config = QConfigMapping.from_dict(_equalization_config)

    if isinstance(prepare_custom_config, Dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    assert isinstance(qconfig_mapping, QConfigMapping)
    assert isinstance(_equalization_config, QConfigMapping)
    # deep-copy so the caller's config objects are never mutated below
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    _equalization_config = copy.deepcopy(_equalization_config)

    # mapping from a tuple of nodes in reverse order to uninitialized
    # QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.ao.quantization.fx.quantize.Add'>),
    # }

    # Phase 1: build the pattern -> handler tables from the backend config.
    pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
    if backend_config is None:
        backend_config = get_native_backend_config()
    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)

    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)

    # Phase 2: propagate qconfigs to modules (fusion-aware).
    _update_qconfig_for_fusion(model, qconfig_mapping)
    _update_qconfig_for_fusion(model, _equalization_config)
    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

    if is_qat:
        # swap float modules for their QAT variants before matching
        module_to_qat_module = get_module_to_qat_module(backend_config)
        _qat_swap_modules(model, module_to_qat_module)
        _update_qconfig_for_qat(qconfig_mapping, backend_config)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, _equalization_config, node_name_to_scope)
    node_name_to_qconfig = _generate_node_name_to_qconfig(model, named_modules, model.graph, qconfig_mapping, node_name_to_scope)

    # Phase 3: match the patterns that will get quantized
    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
    standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())

    custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
    matches_without_qconfig = _find_matches(
        model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)

    # map qconfig instances to matches
    node_name_to_match_result_with_qconfig = {}
    for node_name, match_without_qconfig in matches_without_qconfig.items():
        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig

    # Phase 4: recurse into standalone modules (observed in isolation).
    _run_prepare_fx_on_standalone_modules(
        model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config)

    # record names for the set of observed node, so that in convert step
    # we know whether we need to convert a floating point module to reference
    # quantized module or not
    observed_node_names: Set[str] = set()

    # Phase 5: insert observers/fake-quants into the graph itself.
    result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat,
    )
    # re-create the GraphModule so the modified graph is recompiled
    model = GraphModule(model, model.graph)

    _save_state(model, node_name_to_qconfig, node_name_to_scope,
                prepare_custom_config, equalization_node_name_to_qconfig,
                qconfig_mapping, is_qat, observed_node_names)

    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently"\
            "(not tuple, dict etc.)"
        # these inputs are observed in parent
        # converting List[int] to Tensor since module attribute is
        # Union[Tensor, Module]
        input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
        output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
        # inplace modification
        observed_graph_module_attrs.is_observed_standalone_module = True
        observed_graph_module_attrs.standalone_module_input_quantized_idxs = \
            input_quantized_idxs
        observed_graph_module_attrs.standalone_module_output_quantized_idxs = \
            output_quantized_idxs
    return model
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc ADDED
Binary file (4.62 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc ADDED
Binary file (18.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc ADDED
Binary file (38 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc ADDED
Binary file (26.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import operator
3
+
4
+ import torch
5
+
6
+ from torch.ao.quantization.pt2e.utils import (
7
+ _filter_sym_size_users,
8
+ _is_valid_annotation,
9
+ )
10
+
11
+ from torch.fx.node import map_arg
12
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+ logger.setLevel(logging.WARNING)
17
+
18
+ __all__ = ["DuplicateDQPass"]
19
+
20
# Decomposed quantize ops; used below to recognize the producer of a
# dequantize node when detecting the dynamic-quantization pattern.
_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

# Decomposed dequantize ops; nodes with these targets are the candidates
# that DuplicateDQPass may duplicate per-user.
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
31
+
32
+
33
+ def _maybe_duplicate_dq(
34
+ gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
35
+ ):
36
+ annotation = user.meta.get("quantization_annotation", None)
37
+ if not _is_valid_annotation(annotation):
38
+ return
39
+ with gm.graph.inserting_after(dq_node):
40
+ new_node = gm.graph.node_copy(dq_node)
41
+
42
+ def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
43
+ if n == dq_node:
44
+ return new_node
45
+ else:
46
+ return n
47
+
48
+ new_args = map_arg(user.args, maybe_replace_node)
49
+ new_kwargs = map_arg(user.kwargs, maybe_replace_node)
50
+ user.args = new_args
51
+ user.kwargs = new_kwargs
52
+
53
+
54
class DuplicateDQPass(PassBase):
    """Graph pass that duplicates multi-user dequantize nodes.

    Each (validly annotated) user of a dequantize node with more than one
    user receives its own private copy of that node, keeping q/dq pairs
    self-contained for downstream pattern matching. Dynamic-quantization
    chains (choose_qparams -> getitem -> quantize -> dequantize) are
    deliberately left untouched.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
                # ignore sym_size-only consumers when counting users
                dq_users = _filter_sym_size_users(node)
                if len(dq_users) <= 1:
                    continue
                # Do not duplicate dq for dynamic quantization
                # Pattern: choose_qparam - getitem - q - dq
                q_node = node.args[0]
                if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
                    # NOTE(review): args[1] of the quantize op is presumably
                    # the scale argument -- a getitem on choose_qparams output
                    # in the dynamic case; verify against the decomposed op
                    # schema.
                    getitem_node = q_node.args[1]
                    if (
                        isinstance(getitem_node, torch.fx.node.Node)
                        and getitem_node.op == "call_function"
                        and getitem_node.target == operator.getitem
                    ):
                        choose_qparam_node = getitem_node.args[0]
                        if (
                            isinstance(choose_qparam_node, torch.fx.node.Node)
                            and choose_qparam_node.op == "call_function"
                            and choose_qparam_node.target
                            == torch.ops.quantized_decomposed.choose_qparams.tensor
                        ):
                            # dynamic quantization detected: skip this dq node
                            continue
                for user in dq_users:
                    _maybe_duplicate_dq(graph_module, node, user)
        # drop any dq nodes left without users after the redirection
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ __all__ = [
8
+ "model_is_exported",
9
+ "_WrapperModule",
10
+ ]
11
+
12
+
13
+ class _WrapperModule(torch.nn.Module):
14
+ """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you
15
+ are trying to export a callable.
16
+ """
17
+
18
+ def __init__(self, fn):
19
+ super().__init__()
20
+ self.fn = fn
21
+
22
+ def forward(self, *args, **kwargs):
23
+ """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`."""
24
+ return self.fn(*args, **kwargs)
25
+
26
+
27
def model_is_exported(m: torch.nn.Module) -> bool:
    """
    Return True if the `torch.nn.Module` was exported, False otherwise
    (e.g. if the model was FX symbolically traced or not traced at all).
    """
    # Export (unlike plain FX tracing) stamps a "val" entry into node meta,
    # so its presence on any node distinguishes the two.
    if not isinstance(m, torch.fx.GraphModule):
        return False
    return any("val" in node.meta for node in m.graph.nodes)
35
+
36
+
37
def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Rewrite aten dropout patterns between train and eval form.

    Exported models do not react to `model.train()` / `model.eval()`, so the
    dropout subgraphs themselves must be rewritten to flip behavior.

    See https://github.com/pytorch/pytorch/issues/103681.
    """
    # Avoid circular dependencies
    from .utils import get_aten_graph_module
    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    # both the functional and the in-place variant must be rewritten
    for inplace in (False, True):

        def dropout_train(x):
            return F.dropout(x, p=0.5, training=True, inplace=inplace)

        def dropout_eval(x):
            return F.dropout(x, p=0.5, training=False, inplace=inplace)

        example_inputs = (torch.randn(1),)
        source_fn, target_fn = (
            (dropout_train, dropout_eval)
            if train_to_eval
            else (dropout_eval, dropout_train)
        )
        match_pattern = get_aten_graph_module(_WrapperModule(source_fn), example_inputs)
        replacement_pattern = get_aten_graph_module(
            _WrapperModule(target_fn), example_inputs
        )

        replace_pattern_with_filters(
            m,
            match_pattern,
            replacement_pattern,
            match_filters=[],
            ignore_literals=True,
        )
        m.recompile()
89
+
90
+
91
def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
    """
    Rewrite aten batchnorm patterns between train and eval form.

    Exported models do not react to `model.train()` / `model.eval()`, so the
    batchnorm subgraphs themselves must be rewritten to flip behavior.
    """
    # TODO(Leslie): This function still fails to support custom momentum and eps value.
    # Enable this support in future updates.

    # Avoid circular dependencies
    from .utils import get_aten_graph_module
    from torch.fx.subgraph_rewriter import replace_pattern_with_filters

    # Needed to ensure subgraph matches are self-contained
    m.graph.eliminate_dead_code()
    m.recompile()

    def bn_train(x, bn_weight, bn_bias, bn_running_mean, bn_running_var):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
        )

    def bn_eval(x, bn_weight, bn_bias, bn_running_mean, bn_running_var):
        return F.batch_norm(
            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
        )

    example_inputs = (
        torch.randn(1, 1, 3, 3),  # x
        torch.randn(1),  # bn_weight
        torch.randn(1),  # bn_bias
        torch.randn(1),  # bn_running_mean
        torch.randn(1),  # bn_running_var
    )
    source_fn, target_fn = (
        (bn_train, bn_eval) if train_to_eval else (bn_eval, bn_train)
    )
    match_pattern = get_aten_graph_module(_WrapperModule(source_fn), example_inputs)
    replacement_pattern = get_aten_graph_module(
        _WrapperModule(target_fn), example_inputs
    )

    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()
160
+
161
+
162
+ # TODO: expose these under this namespace?
163
# TODO: expose these under this namespace?
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to eval mode.

    Equivalent to ``model.eval()`` but only for the special ops (dropout,
    batchnorm) whose subgraphs must be rewritten after export. QAT users
    should call this before running inference.
    """
    for _rewrite in (_replace_dropout, _replace_batchnorm):
        _rewrite(model, train_to_eval=True)
    return model
173
+
174
+
175
def _move_exported_model_to_train(model: torch.fx.GraphModule):
    """
    Move an exported GraphModule to train mode.

    Equivalent to ``model.train()`` but only for the special ops (dropout,
    batchnorm) whose subgraphs must be rewritten after export. QAT users
    should call this before training the model.
    """
    for _rewrite in (_replace_dropout, _replace_batchnorm):
        _rewrite(model, train_to_eval=False)
    return model
185
+
186
+
187
def _allow_exported_model_train_eval(model: torch.fx.GraphModule):
    """
    Patch `model.train()` / `model.eval()` on an exported model so they flip
    behavior for the special ops (currently dropout and batchnorm) only.

    Note: this is an approximation of eager-mode train/eval, not an exact
    equivalent. User code branching on the `training` flag was specialized
    away at export time, and ops beyond dropout/batchnorm with distinct
    train/eval behavior are not converted.
    """

    def _train(self, mode: bool = True):
        mover = (
            _move_exported_model_to_train if mode else _move_exported_model_to_eval
        )
        mover(self)

    def _eval(self):
        _move_exported_model_to_eval(self)

    # bind the rewriting helpers as the instance's train/eval methods
    model.train = types.MethodType(_train, model)  # type: ignore[method-assign]
    model.eval = types.MethodType(_eval, model)  # type: ignore[method-assign]
    return model
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch._subclasses import FakeTensor
3
+ from torch.ao.quantization.fx.prepare import (
4
+ _insert_obs_or_fq,
5
+ _save_state,
6
+ _is_activation_post_process_node,
7
+ _create_obs_or_fq_from_qspec,
8
+ )
9
+ from torch.fx import (
10
+ GraphModule,
11
+ Graph,
12
+ Node,
13
+ )
14
+ from torch.fx.node import Argument
15
+
16
+ from torch.ao.quantization import QConfigMapping
17
+ from torch.ao.quantization.qconfig import QConfigAny
18
+ from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
19
+ from typing import Dict, Tuple, Union, Any, Optional
20
+ from torch.ao.quantization.quantizer import (
21
+ EdgeOrNode,
22
+ SharedQuantizationSpec,
23
+ QuantizationSpecBase,
24
+ )
25
+ from torch.ao.quantization import ObserverOrFakeQuantize
26
+
27
+ # TODO: make pt2e folder private?
28
+ __all__ = [
29
+ "prepare",
30
+ ]
31
+
32
+
33
+ def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
34
+ """Find the root node for the sharing tree
35
+ Args:
36
+ edge_or_node: edge/node that we want to find the root
37
+ shared_with_map: each edge/node points to the parent, the root node will points to itself
38
+
39
+ Returns:
40
+ root edge/node
41
+ """
42
+ parent = shared_with_map[edge_or_node]
43
+ if parent == edge_or_node:
44
+ return edge_or_node
45
+ root = _find_root_edge_or_node(parent, shared_with_map)
46
+ # path compression
47
+ shared_with_map[edge_or_node] = root
48
+ return root
49
+
50
def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge `child`'s sharing tree into `parent`'s.

    Argument order matters: after the merge, `parent`'s root is the root of
    the combined tree.
    """
    child_root = _find_root_edge_or_node(child, shared_with_map)
    parent_root = _find_root_edge_or_node(parent, shared_with_map)
    # repoint the child's root at the parent's root
    shared_with_map[child_root] = parent_root
57
+
58
def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Apply a SharedQuantizationSpec to the union-find structure.

    When `qspec` is a SharedQuantizationSpec, `child` is linked under the
    edge/node the spec points at; this relationship is later used to derive
    group ids. Any other spec kind is ignored.
    """
    if not isinstance(qspec, SharedQuantizationSpec):
        return
    # qspec for `child` = SharedQuantizationSpec(target) means `child`
    # points to `target` in the sharing forest
    _union(qspec.edge_or_node, child, shared_with_map)
68
+
69
def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Resolve a (possibly chained) SharedQuantizationSpec to its concrete
    root qspec.

    While the current spec is a SharedQuantizationSpec, look up the root
    edge/node it shares with and continue from that root's spec; the first
    non-shared spec encountered is returned.
    """
    while isinstance(qspec, SharedQuantizationSpec):
        shared_root = _find_root_edge_or_node(qspec.edge_or_node, shared_with_map)
        qspec = edge_or_node_to_qspec[shared_root]
    return qspec
85
+
86
+ def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
87
+ return (
88
+ hasattr(qspec_a, "dtype") and
89
+ hasattr(qspec_b, "dtype") and
90
+ qspec_a.dtype == qspec_b.dtype
91
+ )
92
+
93
+ def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
94
+ return (
95
+ hasattr(qspec_a, "is_dynamic") and
96
+ hasattr(qspec_b, "is_dynamic") and
97
+ qspec_a.is_dynamic == qspec_b.is_dynamic
98
+ )
99
+
100
+ def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
101
+ """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
102
+ """
103
+ edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
104
+ for n in model.graph.nodes:
105
+ if hasattr(n, "meta") and "quantization_annotation" in n.meta:
106
+ qa = n.meta["quantization_annotation"]
107
+ for input_to_n, qspec in qa.input_qspec_map.items():
108
+ input_edge = (input_to_n, n)
109
+ edge_or_node_to_qspec[input_edge] = qspec
110
+ if qa.output_qspec is not None:
111
+ output_node = n
112
+ qspec = qa.output_qspec
113
+ edge_or_node_to_qspec[output_node] = qspec
114
+ return edge_or_node_to_qspec
115
+
116
def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
    """Implicitly share an input edge with another edge/node.

    The input edge and `edge_or_node` refer to the same Tensor (e.g. the
    output of the producer node, or a sibling user edge), so if their root
    qspecs agree on dtype and is_dynamic, the two are unioned so they end up
    with a single observer/fake-quant.
    """
    if edge_or_node not in edge_or_node_to_qspec:
        return
    root_qspec = _unwrap_shared_qspec(
        edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec, shared_with_map
    )
    # TODO: add assertions for types of root qspecs
    if root_qspec is None:
        return
    if (
        _has_same_dtype(root_qspec, input_edge_root_qspec)
        and _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
    ):
        # same dtype / dynamism: reuse the existing observer by pointing
        # `input_edge` at `edge_or_node` in the sharing forest (we may want
        # a stricter compatibility check in the future)
        _union(edge_or_node, input_edge, shared_with_map)
136
+
137
+
138
def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations,
    edge/node with the same group ID should use the same observer/fake_quant instance

    This is applying SharedQuantizationSpec configuration and map each edge/node to a group
    There is another implicit sharing that's built in the quantization, when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.

    Figuring out the correct group ID for all edge/node is a standard union find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that
        belongs to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
           op1 /        /
                     op3
        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
        # connects the two sharing group around cat1 and cat2 op due to transitive sharing
    """
    # means the observer of key should be shared with observer with value, by default it will
    # be shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            # Output annotation: just record any explicit SharedQuantizationSpec.
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            # Input-edge annotation: resolve its root qspec first so implicit
            # sharing can compare dtypes/is_dynamic against other candidates.
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)

            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here, we first share with other users and then share with previous
                # output because the reverse order could cause circular dependency
                # e.g node1 -> node2
                #          \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
                # have a circular dependency
                # the following order works around this issue, but this does not allow arbitrary configuration
                # of sharing so it might break in a different case in the future, when it breaks
                # quantizer writer can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map
                    )

                # sharing with output of producer node
                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)

            # Finally apply any explicit SharedQuantizationSpec on this edge.
            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we get the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id
248
+
249
def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the EdgeOrNode to observer/fake_quant instances.

    Every EdgeOrNode in the same group shares exactly one observer/fake_quant
    instance: the first member of a group encountered creates the instance,
    all later members reuse it.
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    instance_per_group: Dict[int, ObserverOrFakeQuantize] = {}
    for key, qspec in edge_or_node_to_qspec.items():
        gid = edge_or_node_to_group_id[key]
        if gid not in instance_per_group:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            instance_per_group[gid] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
        obs_or_fq_map[key] = instance_per_group[gid]
    return obs_or_fq_map
268
+
269
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.

    Returns the argument `node` should use in its place: either `arg` itself
    (no observation needed or already observed), an existing observer node
    wrapping the same instance, or a freshly inserted observer node.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        # preserve the container type (list vs tuple) of the original arg
        return type(arg)(new_arg_to_return)

    # non-Node arguments (constants etc.) pass through unchanged
    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    # obs_or_fq_map is keyed on edges of ORIGINAL nodes, hence the unwrap above
    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    existing_obs_node = None
    # skip inserting new observers if the same observer instance is inserted before for another user
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                 \ -> conv3
    #
    # instead of inserting new observers we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #                               \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        # identity comparison: reuse only when it is literally the same instance
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    return new_arg
339
+
340
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    # map from old arg to new arg, used for updating the numeric debug handle map
    remap = {}
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        new_args.append(new_arg)
        remap[arg] = new_arg

    if "numeric_debug_handle" in node.meta:

        # rewrite keys of the handle map so they point at the (possibly new)
        # observer nodes; values (the handles) are unchanged
        def remap_fn(x):
            return remap.get(x, x)

        numeric_debug_handle = node.meta["numeric_debug_handle"]
        node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}

    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
    # that persist in exported graph. This is just a work around for these.
    assert (
        node.target == torch.ops.aten.clone.default or
        node.target == torch.ops.aten.zeros_like.default or
        len(node.kwargs) == 0
    ), " expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, inplace
    node.args = tuple(new_args)
391
+
392
def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    """Insert an observer/fake_quant node after `node` when its output is
    annotated in `obs_or_fq_map`; return the inserted node, or None when the
    output is not observed.
    """
    if node not in obs_or_fq_map:
        return None
    return _insert_obs_or_fq(node, obs_or_fq_map[node], model, named_modules, graph)
404
+
405
def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    """Insert observers around a single annotated node: first on its inputs,
    then (when the output is a single FakeTensor) on its output, rewiring all
    original users to consume the output observer. Mutates the graph inplace.
    """
    this_node_quantization_annotation = node.meta["quantization_annotation"] if "quantization_annotation" in node.meta else None
    if this_node_quantization_annotation is None:
        # unannotated nodes are left untouched
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    # only single-tensor outputs get an output observer here
    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
    if not output_is_a_tensor:
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

    if maybe_output_obs_node is None:
        return
    # Update users of original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    # cur_node -> obs
    #
    # to
    #
    #                 next_node
    #                /
    # cur_node -> obs
    #
    # We need to save orig users before updating uses because
    # the list of users will change as we update uses
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        # the observer itself must keep consuming cur_node directly
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)
455
+
456
def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    """Insert observers/fake_quants into `model` according to the
    quantization annotations on its nodes, and record prepare-time state on
    the returned GraphModule.

    Args:
        model: annotated fx GraphModule to prepare
        node_name_to_scope: map from node name to (module path, module type),
            stored for later steps via _save_state
        is_qat: whether to create fake_quants (QAT) instead of observers (PTQ)
    """
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance
    # all edge/nodes that belongs to the same group will use the same instance
    # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
    # instance
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    # re-wrap so the newly inserted observer submodules are registered
    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set()  # observed_node_names
    )
    return model
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Package entry point for pt2e reference representations: re-export the
# rewrite function so callers can import it from this package directly.
from .rewrite import reference_representation_rewrite

# Explicit public API of this package.
__all__ = [
    "reference_representation_rewrite",
]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from collections import OrderedDict
3
+ from typing import Any, Callable, Dict, Tuple, Union, List
4
+
5
+ import torch
6
+
7
+ from .fake_quantize import (
8
+ default_weight_fake_quant,
9
+ FixedQParamsFakeQuantize,
10
+ )
11
+ from .observer import (
12
+ _PartialWrapper,
13
+ default_fixed_qparams_range_0to1_observer,
14
+ default_fixed_qparams_range_neg1to1_observer,
15
+ default_placeholder_observer,
16
+ default_weight_observer,
17
+ )
18
+ from .qconfig import (
19
+ default_reuse_input_qconfig,
20
+ default_symmetric_qnnpack_qconfig,
21
+ default_symmetric_qnnpack_qat_qconfig,
22
+ get_default_qconfig,
23
+ get_default_qat_qconfig,
24
+ QConfig,
25
+ QConfigAny,
26
+ default_quint8_weight_qconfig
27
+ )
28
+
29
+
30
# Public API of this module.
__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]


# Keys of the dict produced/consumed by QConfigMapping.to_dict/from_dict.
# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"

# Ops whose outputs use fixed quantization parameters, mapped to the observer
# that encodes that fixed range (0..1 or -1..1); consumed by
# _get_default_qconfig_mapping below.
# TODO: derive this map from the BackendConfig
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}
+ }
60
+
61
+
62
def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping:
    """
    Return the default QConfigMapping for the given quantization type and backend.

    Args:
        is_qat: True -> QAT qconfigs (fake_quants), False -> PTQ qconfigs (observers)
        backend: backend name passed through to get_default_(qat_)qconfig
        version: qconfig version passed through to get_default_(qat_)qconfig
    """
    if is_qat:
        qconfig = get_default_qat_qconfig(backend, version)
    else:
        qconfig = get_default_qconfig(backend, version)
    default_weight = default_weight_fake_quant if is_qat else default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with fbgemm backend
    # so we have to modify the weight observer to default_weight_observer or another
    # per tensor supported observer.
    # see https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
    else:
        qconfig_transpose = qconfig

    # currently layernorm only supports float weights
    # we have to add this because otherwise there will be a extra quantize-dequantize pair
    qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)

    # base mapping: global default plus per-op overrides for transpose convs,
    # layernorm, reshape (reuse input) and PReLU
    qconfig_mapping = QConfigMapping() \
        .set_global(qconfig) \
        .set_object_type("reshape", default_reuse_input_qconfig) \
        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \
        .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \
        .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \

    # Use special observers for ops with fixed qparams
    # (cache: ops sharing the same fixed-qparams observer share one QConfig)
    fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {}
    for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items():
        if observer in fixed_qparams_observer_to_qconfig:
            fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer]
        else:
            if is_qat:
                activation = FixedQParamsFakeQuantize.with_args(observer=observer)
            else:
                activation = observer
            fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight)
            fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig
        qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig)

    # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
    # Need to be able to support fusion of ops with different qconfigs

    return qconfig_mapping
+ return qconfig_mapping
116
+
117
def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping:
    """
    Return the default QConfigMapping for post training quantization.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
        one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
    """
    # TODO: add assert for backend choices
    is_qat = False
    return _get_default_qconfig_mapping(is_qat, backend, version)
+
129
def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping:
    """
    Return the default QConfigMapping for quantization aware training.

    Args:
      * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be
        one of ["x86" (default), "fbgemm", "qnnpack", "onednn"]
      * ``version`` (int) : the version for the default qconfig mapping
    """
    is_qat = True
    return _get_default_qconfig_mapping(is_qat, backend, version)
139
+
140
def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig`
    as the default QConfig.
    """
    return _get_default_qconfig_mapping_with_default_qconfig(
        False, "qnnpack", default_symmetric_qnnpack_qconfig
    )
147
+
148
def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping:
    """
    Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig`
    as the default QConfig.
    """
    return _get_default_qconfig_mapping_with_default_qconfig(
        True, "qnnpack", default_symmetric_qnnpack_qat_qconfig
    )
155
+
156
def _get_default_qconfig_mapping_with_default_qconfig(
    is_qat: bool,
    backend: str,
    default_qconfig: QConfig,
) -> QConfigMapping:
    """
    Return a QConfigMapping that uses the provided qconfig as the default QConfig.

    Starts from the stock mapping for `backend`, then replaces the global
    qconfig and every object-type override except the fixed-qparams ops
    (whose observers encode required ranges and must be kept).
    """
    base = (
        get_default_qat_qconfig_mapping(backend)
        if is_qat
        else get_default_qconfig_mapping(backend)
    )
    base.set_global(default_qconfig)
    for pattern in list(base.object_type_qconfigs):
        if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER:
            base.set_object_type(pattern, default_qconfig)
    return base
173
+
174
# Attribute names of QConfigMapping, listed in increasing match priority;
# used by QConfigMapping.__repr__ to render all configured styles in order.
_QCONFIG_STYLE_ORDER: List[str] = [
    "global_qconfig",
    "object_type_qconfigs",
    "module_name_regex_qconfigs",
    "module_name_qconfigs",
    "module_name_object_type_order_qconfigs",
]
181
+
182
class QConfigMapping:
    """
    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.

    The user can specify QConfigs using the following methods (in increasing match priority):

        ``set_global`` : sets the global (default) QConfig

        ``set_object_type`` : sets the QConfig for a given module type, function, or method name

        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string

        ``set_module_name`` : sets the QConfig for modules matching the given module name

        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
        of the given module name, object type, and the index at which the module appears

    All setters return ``self``, so calls can be chained fluently::

        qconfig_mapping = QConfigMapping()
            .set_global(global_qconfig)
            .set_object_type(torch.nn.Linear, qconfig1)
            .set_module_name_regex("foo.*", qconfig2)
            .set_module_name("module1", qconfig1)
            .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)

    """

    def __init__(self):
        # Storage for each style, in increasing match priority.  Insertion
        # order of the OrderedDicts is meaningful (e.g. regex match order).
        self.global_qconfig: QConfigAny = None
        self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = OrderedDict()
        self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict()
        self.module_name_object_type_order_qconfigs: OrderedDict[Tuple[str, Callable, int], QConfigAny] = OrderedDict()

    def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping:
        """Set the global (default) QConfig and return ``self``."""
        self.global_qconfig = global_qconfig
        return self

    def set_object_type(self, object_type: Union[Callable, str], qconfig: QConfigAny) -> QConfigMapping:
        """Set the QConfig for a given module type, function, or method name.

        Re-setting an existing object type overrides its previous QConfig.
        """
        self.object_type_qconfigs[object_type] = qconfig
        return self

    def set_module_name_regex(self, module_name_regex: str, qconfig: QConfigAny) -> QConfigMapping:
        """Set the QConfig for modules matching the given regex string.

        Regexes are matched in registration order, so register more specific
        patterns first.  Re-setting an existing regex overrides its QConfig
        while keeping its original position in the order.
        """
        self.module_name_regex_qconfigs[module_name_regex] = qconfig
        return self

    def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping:
        """Set the QConfig for modules matching the given module name.

        Re-setting an existing module name overrides its previous QConfig.
        """
        self.module_name_qconfigs[module_name] = qconfig
        return self

    def set_module_name_object_type_order(
            self,
            module_name: str,
            object_type: Callable,
            index: int,
            qconfig: QConfigAny) -> QConfigMapping:
        """Set the QConfig for modules matching a combination of the given
        module name, object type, and the index at which the module appears.

        Re-setting an existing (module name, object type, index) triple
        overrides its previous QConfig.
        """
        self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig
        return self

    def __repr__(self) -> str:
        # Render each style section in priority order.
        pieces = [self.__class__.__name__ + " ("]
        for style_name in _QCONFIG_STYLE_ORDER:
            pieces.append(f"\n {style_name}")
            qconfigs = getattr(self, style_name)
            if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0:
                pieces.extend(f"\n {key}: {qconfig}" for key, qconfig in qconfigs.items())
            else:
                pieces.append(f"\n {qconfigs}")
        pieces.append("\n)")
        return "".join(pieces)

    # TODO: remove this
    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``QConfigMapping`` to a dictionary with the following keys:

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The values of this dictionary are lists of tuples.
        """
        return {
            _GLOBAL_DICT_KEY: self.global_qconfig,
            _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()),
            _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()),
            _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()),
            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
                (*key, qconfig)
                for key, qconfig in self.module_name_object_type_order_qconfigs.items()
            ],
        }

    # TODO: remove this
    @classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """
        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The values of this dictionary are expected to be lists of tuples.
        """
        mapping = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            mapping.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for otype, qc in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            mapping.set_object_type(otype, qc)
        for regex, qc in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []):
            mapping.set_module_name_regex(regex, qc)
        for name, qc in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            mapping.set_module_name(name, qc)
        for name, otype, idx, qc in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
            mapping.set_module_name_object_type_order(name, otype, idx, qc)
        return mapping
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (663 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc ADDED
Binary file (4.41 kB). View file