koichi12 commited on
Commit
8caf96a
·
verified ·
1 Parent(s): 6229f35

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so +3 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc +3 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py +976 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py +590 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/__pycache__/__init__.cpython-311.pyc +0 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/__init__.py +1 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py +18 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py +533 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-311.pyc +0 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc +0 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc +0 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc +0 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py +160 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py +116 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py +204 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py +259 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py +3 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc +0 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc +0 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc +0 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc +0 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc +0 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc +0 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py +820 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc +0 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc +0 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc +0 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py +1539 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py +666 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py +197 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py +885 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__init__.py +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -69,3 +69,6 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/distl
69
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
70
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
71
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
 
 
 
 
69
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/rich/__pycache__/console.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
70
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Debugger/__pycache__/libpython.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
71
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/FlowControl.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
72
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
73
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
74
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8597f1985804f6c0c55b84d29a8744f0e2bc6600aaa695402499fbbbcba1decc
3
+ size 374848
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38073c63ab8f022926f58f7cb39c565005f382bdfacd85822e7502a5256b6671
3
+ size 1509528
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1e41fea31e2f114e2b8bb3065092e62588a33b909a8fa70bc578e734128e529
3
+ size 176864
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/impl.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import dataclasses
import functools
import inspect
import sys
import typing
import weakref

from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy

import torch
import torch._C as _C
import torch.library as library
from torch._library.abstract_impl import AbstractImplCtx
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel

# NOTE: this string is placed after the imports, so it is a plain expression
# statement rather than the module docstring.
"""
For a detailed guide on custom ops, please see
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

This file includes pieces of the implementation of our custom operator API.
"""

__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]

# Maps a device-type string (as accepted by CustomOp.impl) to the
# corresponding dispatch-key name used for kernel registration.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}
43
+
44
+
45
def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Creates a new CustomOp object.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In PyTorch, defining an op (short for "operator") is a two step-process:
    - we need to define (create) the op
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the CustomOp object (the first step);
    you must then perform the second step by calling various methods on
    the CustomOp object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module. The operator_name must be
            the same as the name of the function you pass to custom_op
            (see examples).
        manual_schema (Optional[str]): Each PyTorch operator needs a schema that
            tells PyTorch the types of the inputs/outputs. If None (default),
            we will infer the schema from the type annotations on the function
            (see examples). Otherwise, if you don't want to use type annotations,
            you may provide us the schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda

    """

    def inner(func):
        # Only plain Python functions are accepted: we need a real function
        # object to read its name, annotations, and module below.
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        # The prototype function's name must match the operator name so the
        # resulting object is unambiguous to readers.
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        # Either infer the schema from type annotations or take the one the
        # user wrote by hand (in which case we cross-check it against `func`).
        schema = infer_schema(func) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        # Register the operator with the C++ dispatcher and wrap the handle.
        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        # Make the CustomOp introspect like the prototype function it wraps.
        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # NB: Use weakref.proxy so these dispatcher-held callbacks do not keep
        # the CustomOp alive (the global registry is the owning reference).
        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner
159
+
160
+
161
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}
168
+
169
+
170
+ class CustomOp:
171
+ r"""Class for custom operators in PyTorch.
172
+
173
+ Use the CustomOp API to create user-defined custom operators that behave
174
+ just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
175
+ comes to various PyTorch subsystems (like torch.compile).
176
+
177
+ To construct a `CustomOp`, use `custom_op`.
178
+ """
179
+
180
+ def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
181
+ super().__init__()
182
+ if not _private_access:
183
+ raise RuntimeError(
184
+ "The CustomOp constructor is private and we do not guarantee "
185
+ "BC for it. Please use custom_op(...) to create a CustomOp object"
186
+ )
187
+ name = f"{cpp_ns}::{operator_name}"
188
+ self._schema = schema
189
+ self._cpp_ns = cpp_ns
190
+ self._lib: library.Library = lib
191
+ self._ophandle: _C._DispatchOperatorHandle = ophandle
192
+ # Has the name of the op, e.g. "foo". We cache here for convenience.
193
+ self._opname: str = operator_name
194
+ # this is _opname but with namespace. e.g. "custom::foo"
195
+ self._qualname: str = name
196
+ self.__name__ = None # mypy requires this
197
+ # NB: Some of these impls are registered as kernels to DispatchKeys.
198
+ # Modifying the _impls dict directly won't do anything in that case.
199
+ self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
200
+ # See NOTE [CustomOp autograd kernel indirection]
201
+ self._registered_autograd_kernel_indirection = False
202
+
203
+ global_registry[self._qualname] = self
204
+
205
+ def _register_autograd_kernel_indirection(self):
206
+ assert not self._registered_autograd_kernel_indirection
207
+ self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
208
+ self._registered_autograd_kernel_indirection = True
209
+
210
+ # Records the impl and the source location in self._impls
211
+ # Note that this doesn't cause torch.library to use the impl, that
212
+ # needs to be done in a separate self._lib.impl call.
213
+ def _register_impl(self, kind, func, stacklevel=2):
214
+ if self._has_impl(kind):
215
+ func_and_location = self._impls[kind]
216
+ assert func_and_location is not None # Pacify mypy
217
+ location = func_and_location.location
218
+ raise RuntimeError(
219
+ f"Attempting to register a {kind} impl for operator {self._qualname} "
220
+ f"that already has a {kind} impl registered from Python at "
221
+ f"{location}. This is not supported."
222
+ )
223
+ frame = inspect.getframeinfo(sys._getframe(stacklevel))
224
+ location = f"{frame.filename}:{frame.lineno}"
225
+ self._impls[kind] = FuncAndLocation(func, location)
226
+
227
+ def _get_impl(self, kind):
228
+ return self._impls[kind]
229
+
230
+ def _has_impl(self, kind):
231
+ return kind in self._impls
232
+
233
+ def _destroy(self):
234
+ # NOTE: [CustomOp lifetime]
235
+ # A CustomOp, once created, lives forever. The mechanism is that the
236
+ # global registry holds a reference to it. However, to make testing
237
+ # easier, we want to be able to destroy CustomOp objects.
238
+ # CustomOp._destroy does the job, though it leaves the CustomOp
239
+ # in a garbage state.
240
+ del self._lib
241
+
242
+ opnamespace = getattr(torch.ops, self._cpp_ns)
243
+ if hasattr(opnamespace, self._opname):
244
+ delattr(opnamespace, self._opname)
245
+
246
+ del global_registry[self._qualname]
247
+
248
+ def __repr__(self):
249
+ return f'<CustomOp(op="{self._qualname}")>'
250
+
251
+ def __call__(self, *args, **kwargs):
252
+ # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
253
+ # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
254
+ # issues from caching operators that make testing CustomOp difficult).
255
+ result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
256
+ return result
257
+
258
+ def impl(
259
+ self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
260
+ ) -> typing.Callable:
261
+ r"""Register an implementation for a device type for this CustomOp object.
262
+
263
+ WARNING: if you're a user, please do not use this directly
264
+ (instead use the torch._custom_ops APIs).
265
+ Also please see the following for a detailed guide on custom ops.
266
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
267
+
268
+ If the CustomOp is passed multiple Tensor inputs with different device
269
+ types, it will dispatch to the registered implementation for the highest
270
+ priority device type among those present.
271
+ The supported device types, in order of priority, are {'cuda', 'cpu'}.
272
+
273
+ This API is used as a decorator (see examples).
274
+
275
+ Arguments:
276
+ device_types (str or Iterable[str]): the device type(s) to register the function for.
277
+
278
+ Examples::
279
+ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
280
+ >>> import numpy as np
281
+ >>> from torch import Tensor
282
+ >>>
283
+ >>> @custom_op("my_library::numpy_cos")
284
+ >>> def numpy_cos(x: Tensor) -> Tensor:
285
+ >>> ...
286
+ >>>
287
+ >>> # Register an implementation for CPU Tensors
288
+ >>> @numpy_cos.impl('cpu')
289
+ >>> def numpy_cos_impl_cpu(x):
290
+ >>> return torch.from_numpy(np.cos(x.numpy()))
291
+ >>>
292
+ >>> # Register an implementation for CUDA Tensors
293
+ >>> @numpy_cos.impl('cuda')
294
+ >>> def numpy_cos_impl_cuda(x):
295
+ >>> return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
296
+ >>>
297
+ >>> x = torch.randn(3)
298
+ >>> numpy_cos(x) # calls numpy_cos_impl_cpu
299
+ >>>
300
+ >>> x_cuda = x.cuda()
301
+ >>> numpy_cos(x) # calls numpy_cos_impl_cuda
302
+
303
+ """
304
+ if isinstance(device_types, str):
305
+ device_types = [device_types]
306
+ for device_type in device_types:
307
+ validate_device_type(device_type)
308
+
309
+ def inner(f):
310
+ for device_type in set(device_types):
311
+ self._check_doesnt_have_library_impl(device_type)
312
+ self._register_impl(device_type, f, stacklevel=_stacklevel)
313
+ dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
314
+ library.impl(self._lib, self._opname, dispatch_key)(f)
315
+ return f
316
+
317
+ return inner
318
+
319
+ def _check_doesnt_have_library_impl(self, device_type):
320
+ if self._has_impl(device_type):
321
+ return
322
+ key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
323
+ if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
324
+ raise RuntimeError(
325
+ f"impl(..., device_types={device_type}): the operator {self._qualname} "
326
+ f"already has an implementation for this device type via a "
327
+ f"pre-existing torch.library or TORCH_LIBRARY registration.")
328
+
329
+ def impl_factory(self) -> typing.Callable:
330
+ r"""Register an implementation for a factory function."""
331
+
332
+ def inner(f):
333
+ self._register_impl("factory", f)
334
+ library.impl(self._lib, self._opname, "BackendSelect")(f)
335
+ return f
336
+
337
+ return inner
338
+
339
+ def impl_abstract(self, _stacklevel=2) -> typing.Callable:
340
+ r"""Register an abstract implementation for this operator.
341
+
342
+ WARNING: please do not use this directly (and instead use the torch._custom_ops
343
+ APIs). Also please see the following for a detailed guide on custom ops.
344
+ https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
345
+
346
+ An "abstract implementation" specifies the behavior of this operator on
347
+ Tensors that carry no data. Given some input Tensors with certain properties
348
+ (sizes/strides/storage_offset/device), it specifies what the properties of
349
+ the output Tensors are.
350
+
351
+ The abstract implementation has the same signature as the operator.
352
+ It is run for both FakeTensors and meta tensors. To write an abstract
353
+ implementation, assume that all Tensor inputs to the operator are
354
+ regular CPU/CUDA/Meta tensors, but they do not have storage, and
355
+ you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
356
+ The abstract implementation must consist of only PyTorch operations
357
+ (and may not directly access the storage or data of any input or
358
+ intermediate Tensors).
359
+
360
+ This API is used as a decorator (see examples).
361
+
362
+ Examples::
363
+ >>> import numpy as np
364
+ >>> from torch import Tensor
365
+ >>>
366
+ >>> # Example 1: an operator without data-dependent output shape
367
+ >>> @custom_op('my_library::custom_linear')
368
+ >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
369
+ >>> ...
370
+ >>>
371
+ >>> @custom_linear.impl_abstract()
372
+ >>> def custom_linear_abstract(x, weight):
373
+ >>> assert x.dim() == 2
374
+ >>> assert weight.dim() == 2
375
+ >>> assert bias.dim() == 1
376
+ >>> assert x.shape[1] == weight.shape[1]
377
+ >>> assert weight.shape[0] == bias.shape[0]
378
+ >>> assert x.device == weight.device
379
+ >>>
380
+ >>> return (x @ weight.t()) + bias
381
+ >>>
382
+ >>> # Example 2: an operator with data-dependent output shape
383
+ >>> @custom_op('my_library::custom_nonzero')
384
+ >>> def custom_nonzero(x: Tensor) -> Tensor:
385
+ >>> ...
386
+ >>>
387
+ >>> @custom_nonzero.impl_abstract()
388
+ >>> def custom_nonzero_abstract(x):
389
+ >>> # Number of nonzero-elements is data-dependent.
390
+ >>> # Since we cannot peek at the data in an abstract impl,
391
+ >>> # we use the ctx object to construct a new symint that
392
+ >>> # represents the data-dependent size.
393
+ >>> ctx = torch._custom_op.get_ctx()
394
+ >>> nnz = ctx.create_unbacked_symint()
395
+ >>> shape = [x.dim(), nnz]
396
+ >>> result = x.new_empty(shape, dtype=torch.long)
397
+ >>> return result
398
+ >>>
399
+ >>> @custom_nonzero.impl(['cpu', 'cuda'])
400
+ >>> def custom_nonzero_impl(x):
401
+ >>> x_np = to_numpy(x)
402
+ >>> res = np.stack(np.nonzero(x_np), axis=1)
403
+ >>> # unbacked symbolic ints in PyTorch must be >= 2, so we
404
+ >>> # constrain the range to at least 2
405
+ >>> if res.shape[0] <= 1:
406
+ >>> raise RuntimeError("not supported")
407
+ >>> return torch.tensor(res, device=x.device)
408
+
409
+ """
410
+
411
+ def inner(f):
412
+ self._check_doesnt_have_library_meta_impl()
413
+ self._register_impl("abstract", f, stacklevel=_stacklevel)
414
+ location = self._get_impl("abstract").location
415
+
416
+ qualname = self._qualname
417
+
418
+ # Handle DispatchKey.Meta registration
419
+ @functools.wraps(f)
420
+ def f_with_ctx(*args, **kwargs):
421
+ def error_on_ctx():
422
+ raise RuntimeError(
423
+ f"Attempted to call get_ctx() for the meta implementation "
424
+ f"for {qualname}."
425
+ f"You have presumably called get_ctx() because the operator "
426
+ f"has a data-dependent output shape; if so, there is no "
427
+ f"such meta implementation and this error is the correct "
428
+ f"behavior. Otherwise, please remove the call to get_ctx() "
429
+ f"in the implementation registered with impl_abstract "
430
+ f"at {location}"
431
+ )
432
+
433
+ with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
434
+ return f(*args, **kwargs)
435
+
436
+ self._lib.impl(self._opname, f_with_ctx, "Meta")
437
+ return f
438
+
439
+ return inner
440
+
441
+ def _check_can_register_backward(self):
442
+ def error(detail):
443
+ raise RuntimeError(
444
+ f"Cannot use torch._custom_ops APIs to register backward "
445
+ f"formula for {detail}. Got operator "
446
+ f"{self._qualname} with schema: {schema}"
447
+ )
448
+
449
+ schema = self._schema
450
+ if schema.kind() != SchemaKind.functional:
451
+ error("non-functional operator")
452
+
453
+ rets = schema.returns
454
+ if not schema.returns:
455
+ error("operator with no returns")
456
+
457
+ assert len(rets) > 0
458
+ is_non_mutating_view = any(
459
+ r.annotation is not None and not r.annotation.is_write for r in rets
460
+ )
461
+ if is_non_mutating_view:
462
+ error("operator that returns views")
463
+
464
+ # We make assumptions about the schema's return types.
465
+ allowed_return_types = {
466
+ BaseType(BaseTy.int): "int",
467
+ BaseType(BaseTy.SymInt): "SymInt",
468
+ BaseType(BaseTy.bool): "bool",
469
+ BaseType(BaseTy.float): "float",
470
+ BaseType(BaseTy.Tensor): "Tensor",
471
+ ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
472
+ }
473
+ for ret in schema.returns:
474
+ if ret.type in allowed_return_types:
475
+ continue
476
+ error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
477
+
478
+ def _check_doesnt_have_library_autograd_impl(self):
479
+ if self._registered_autograd_kernel_indirection:
480
+ return
481
+
482
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
483
+ raise RuntimeError(
484
+ f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
485
+ f"already has an implementation for this device type via a "
486
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
487
+ f"CompositeImplicitAutograd operators do not need an autograd formula; "
488
+ f"instead, the operator will decompose into its constituents and those "
489
+ f"can have autograd formulas defined on them.")
490
+
491
+ # We can improve this by adding "all Autograd<BACKEND> keys", but
492
+ # realistically people will just be using this API for CPU/CUDA for now.
493
+ for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
494
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
495
+ raise RuntimeError(
496
+ f"impl_backward/impl_save_for_backward: "
497
+ f"the operator {self._qualname} already has an Autograd kernel "
498
+ f"registered to DispatchKey::{key} vi a pre-existing "
499
+ f"torch.library or TORCH_LIBRARY registration. Please either "
500
+ f"remove those registrations or don't use the torch._custom_ops APIs")
501
+
502
+ def _check_doesnt_have_library_meta_impl(self):
503
+ if self._has_impl("abstract"):
504
+ return
505
+
506
+ # If the user's operator is CompositeExplicitAutograd,
507
+ # allow them to impl_abstract. This is being pragmatic
508
+ # (existing custom ops may have CompositeExplicitAutograd
509
+ # registration that don't work with Meta kernels, so this
510
+ # gives them an escape hatch).
511
+ if (
512
+ _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
513
+ and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
514
+ ):
515
+ return
516
+
517
+ # Otherwise, if the user's already has a Meta kernel or their
518
+ # op is CompositeImplicitAutograd or some other alias dispatch key,
519
+ # raise.
520
+
521
+ # Special case for CompositeImplicitAutograd
522
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
523
+ raise RuntimeError(
524
+ f"impl_abstract(...): the operator {self._qualname} "
525
+ f"already has an implementation for this device type via a "
526
+ f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
527
+ f"CompositeImplicitAutograd operators do not need an abstract impl; "
528
+ f"instead, the operator will decompose into its constituents and those "
529
+ f"can have abstract impls defined on them.")
530
+
531
+ if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
532
+ raise RuntimeError(
533
+ f"impl_abstract(...): the operator {self._qualname} "
534
+ f"already has an DispatchKey::Meta implementation via a "
535
+ f"pre-existing torch.library or TORCH_LIBRARY registration. "
536
+ f"Please either remove that registration or don't call impl_abstract.")
537
+
538
    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        """Build and register the combined "autograd" kernel.

        Precondition: both the "backward" and "save_for_backward" impls have
        already been registered (asserted below).
        """
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        # Combine the user's save_for_backward/backward functions into one
        # autograd kernel for this operator.
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)
554
+
555
    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            # Refuse to register if a conflicting autograd registration
            # already exists outside this API.
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            # Once both halves are present, build the actual autograd kernel.
            if self._has_impl("backward"):
                self._register_autograd_kernel()
        return inner
569
+
570
    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.

        """
        if output_differentiability is not None:
            # Local helper so all three validation failures below raise the
            # same error message.
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            # Refuse to register if a conflicting autograd registration
            # already exists outside this API.
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            # Once both halves are present, build the actual autograd kernel.
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
        return inner
633
+
634
+
635
@dataclasses.dataclass
class FuncAndLocation:
    """A registered callable paired with a description of where it was registered."""
    # The user-supplied implementation function.
    func: typing.Callable
    # Human-readable registration site, used for error reporting.
    location: str
639
+
640
+
641
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::operator_name``.

    Raises (via the C++ binding) if no such schema has been registered.
    """
    # The dispatcher expects "" (not None) for the default overload.
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )
648
+
649
+
650
def validate_namespace(ns: str) -> None:
    """Validate a user-provided operator namespace.

    A namespace must not contain '.' and must not collide with one of the
    namespaces reserved by PyTorch (RESERVED_NS).

    Raises:
        ValueError: if the namespace is malformed or reserved.
    """
    error_message = None
    if "." in ns:
        error_message = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    elif ns in RESERVED_NS:
        error_message = (
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
    if error_message is not None:
        raise ValueError(error_message)
661
+
662
def validate_schema(schema: FunctionSchema) -> None:
    """Reject schemas that the custom-op API cannot support.

    Only functional operators (no mutation, no views, at least one return)
    and operators without a `self` argument are allowed.

    Raises:
        ValueError: if the schema is non-functional or has a `self` argument.
    """
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
677
+
678
+
679
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split a qualified operator name "ns::opname" into (ns, opname).

    Raises:
        ValueError: if the '::' separator is missing, or if the operator name
            contains a '.' (overloads are not supported by this API).
    """
    pieces = qualname.split("::", 1)
    if len(pieces) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    namespace, opname = pieces
    if '.' in opname:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return namespace, opname
690
+
691
+
692
def validate_device_type(device_type: str) -> None:
    """Raise ValueError unless ``device_type`` is a supported device key.

    Supported device types are the keys of SUPPORTED_DEVICE_TYPE_TO_KEY.
    """
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )
698
+
699
+
700
def supported_param(param: inspect.Parameter) -> bool:
    """Return True if ``param`` has a parameter kind custom ops can handle.

    Only plain positional-or-keyword and keyword-only parameters are
    supported (no positional-only parameters, *args, or **kwargs).
    """
    allowed_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    return param.kind in allowed_kinds
705
+
706
+
707
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that ``func``'s signature matches a manually-provided schema.

    Enforced below:
    - only positional-or-keyword and keyword-only parameters,
    - no type annotations on parameters or the return value,
    - parameter names and order match the schema,
    - no default arguments on either side.

    Raises:
        ValueError: on any mismatch.
    """
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    # Annotations are banned because the schema is the single source of truth.
    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    # Split the signature the same way the schema splits its arguments.
    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        # Positionally compare names and reject defaults on either side.
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)
770
+
771
+
772
def infer_schema(prototype_function: typing.Callable) -> str:
    """Infer an operator schema string from a fully-annotated Python function.

    Every parameter must carry a supported type annotation (see
    SUPPORTED_PARAM_TYPES) and the return annotation must be one of
    SUPPORTED_RETURN_TYPES.

    Raises:
        ValueError: if any parameter or the return type is unsupported.
    """
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig})"
        )

    # Each parameter contributes "<schema type> <name>".
    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"
785
+
786
+
787
def parse_param(name, param, error_fn):
    """Render one annotated parameter as a "<schema type> <name>" fragment.

    ``error_fn`` is called (and raises) for unsupported parameter kinds,
    missing annotations, unsupported types, or default values.
    """
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    if param.annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    # Default values are rejected rather than translated into the schema.
    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"
808
+
809
+
810
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """Return (python type, schema type string) pairs derived from a base type.

    Always includes the base type itself and its Optional variant; the three
    boolean flags control whether the Sequence ("T[]"),
    Sequence-of-Optional ("T?[]"), and Optional-Sequence ("T[]?") variants
    are also emitted.
    """
    pairs = [
        (base_type, cpp_type),
        (typing.Optional[base_type], f"{cpp_type}?"),
    ]
    conditional_variants = [
        (list_base, (typing.Sequence[base_type], f"{cpp_type}[]")),  # type: ignore[valid-type]
        (optional_base_list, (typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")),  # type: ignore[valid-type]
        (optional_list_base, (typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")),  # type: ignore[valid-type]
    ]
    pairs.extend(entry for enabled, entry in conditional_variants if enabled)
    return pairs
824
+
825
+
826
def get_supported_param_types():
    """Build the mapping from Python annotation to schema type string.

    Each table row expands (via derived_types) into the base type, its
    Optional variant, and whichever list variants the flags enable.
    """
    data = [
        # (python type, schema type, type[] variant, type?[] variant, type[]? variant
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    result = []
    for line in data:
        result.extend(derived_types(*line))
    return dict(result)
842
+
843
+
844
# Mapping from supported Python return annotation -> schema return type string.
SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}
852
+
853
+
854
def parse_return(annotation, error_fn):
    """Render a return annotation as a schema return-type string.

    A Tuple annotation becomes "(T1, T2, ...)"; any other annotation must be
    a key of SUPPORTED_RETURN_TYPES. ``error_fn`` raises on unsupported types.
    """
    origin = typing.get_origin(annotation)
    if origin is not tuple:
        if annotation not in SUPPORTED_RETURN_TYPES.keys():
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )
        return SUPPORTED_RETURN_TYPES[annotation]

    # Tuple return: every element must itself be a supported return type.
    args = typing.get_args(annotation)
    for arg in args:
        if arg not in SUPPORTED_RETURN_TYPES:
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )

    return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
873
+
874
+
875
# Annotation -> schema type string table, built once from the rows above.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
876
+
877
+
878
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a descriptive NotImplementedError for a dispatch-key miss.

    Installed as the dispatcher's error callback for a custom op; ``key`` is
    the dispatch key that had no kernel registered.
    """
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    # Fallback for every other dispatch key.
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )
906
+
907
+
908
def custom_op_from_existing(op):
    """Wrap an already-registered operator overload in a CustomOp object."""
    ns = op.namespace
    # FRAGMENT: attach to the existing library namespace instead of defining one.
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)
917
+
918
+
919
def get_op(qualname):
    """Return the default overload of the operator named ``qualname``.

    Raises:
        ValueError: if the namespace, operator, or default overload does not
            exist under torch.ops.
    """
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    # Walk torch.ops.<ns>.<name>.default, failing loudly at each missing step.
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default
936
+
937
+
938
def _find_custom_op(qualname, also_check_torch_library=False):
    """Find the CustomOp object for ``qualname``.

    Looks in the custom-op global registry first; if ``also_check_torch_library``
    is True, falls back to wrapping an existing torch.library registration.

    Raises:
        RuntimeError: if not in the registry and the fallback is disabled.
    """
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f"Could not find custom op \"{qualname}\". Did you register it via "
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result
948
+
949
+
950
+ def get_abstract_impl(qualname):
951
+ if qualname not in torch._custom_op.impl.global_registry:
952
+ return None
953
+ custom_op = torch._custom_op.impl.global_registry[qualname]
954
+ if custom_op is None:
955
+ return None
956
+ if not custom_op._has_impl("abstract"):
957
+ return None
958
+ return custom_op._get_impl("abstract").func
959
+
960
+
961
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define a new operator from an explicit schema string and return it.

    Registers the schema under ``qualname``'s namespace, sets up the autograd
    kernel indirection and the dispatcher error callback, then returns the
    freshly created default overload.
    """
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    # Tag so that inductor/export preserve the input stride order when requested.
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    # weakref.proxy avoids the callback keeping the CustomOp alive.
    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/fft.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch._prims as prims
7
+ import torch._prims_common as utils
8
+ from torch._decomp import register_decomposition
9
+ from torch._prims_common import DimsType, ShapeType, TensorLikeType
10
+ from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
11
+
12
+ __all__ = [
13
+ # Transforms
14
+ "fft",
15
+ "fft2",
16
+ "fftn",
17
+ "hfft",
18
+ "hfft2",
19
+ "hfftn",
20
+ "rfft",
21
+ "rfft2",
22
+ "rfftn",
23
+ "ifft",
24
+ "ifft2",
25
+ "ifftn",
26
+ "ihfft",
27
+ "ihfft2",
28
+ "ihfftn",
29
+ "irfft",
30
+ "irfft2",
31
+ "irfftn",
32
+ # Helpers
33
+ "fftshift",
34
+ "ifftshift",
35
+ ]
36
+
37
# Accepted values for the `norm` argument of the fft functions below.
NormType = Union[None, Literal["forward", "backward", "ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}
# ATen operator namespace, used by the @register_decomposition decorators.
aten = torch._ops.ops.aten
40
+
41
+
42
def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    # "ortho" scales by 1/sqrt(N) in both directions.
    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    # Otherwise scale by 1/N on the inverse transform under the default
    # ("backward") norm, or on the forward transform when norm == "forward".
    if forward:
        needs_scaling = norm == "forward"
    else:
        needs_scaling = norm is None or norm == "backward"
    return x * (1 / signal_numel) if needs_scaling else x
55
+
56
+
57
+ def _promote_type_fft(
58
+ dtype: torch.dtype, require_complex: bool, device: torch.device
59
+ ) -> torch.dtype:
60
+ """Helper to promote a dtype to one supported by the FFT primitives"""
61
+ if dtype.is_complex:
62
+ return dtype
63
+
64
+ # Promote integral to default float type
65
+ if not dtype.is_floating_point:
66
+ dtype = torch.get_default_dtype()
67
+
68
+ allowed_types = [torch.float32, torch.float64]
69
+ maybe_support_half = device.type in ["cuda", "meta"]
70
+
71
+ if maybe_support_half:
72
+ allowed_types.append(torch.float16)
73
+ torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
74
+
75
+ if require_complex:
76
+ dtype = utils.corresponding_complex_dtype(dtype)
77
+
78
+ return dtype
79
+
80
+
81
def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex, t.device)
    # No-op when the dtype is already acceptable.
    return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]
88
+
89
+
90
+ def _resize_fft_input(
91
+ x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
92
+ ) -> TensorLikeType:
93
+ """
94
+ Fixes the shape of x such that x.size(dims[i]) == sizes[i],
95
+ either by zero-padding, or by slicing x starting from 0.
96
+ """
97
+ assert len(dims) == len(sizes)
98
+ must_copy = False
99
+ x_sizes = x.shape
100
+ pad_amount = [0] * len(x_sizes) * 2
101
+ for i in range(len(dims)):
102
+ if sizes[i] == -1:
103
+ continue
104
+
105
+ if x_sizes[dims[i]] < sizes[i]:
106
+ must_copy = True
107
+ pad_idx = len(pad_amount) - 2 * dims[i] - 1
108
+ pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
109
+
110
+ if x_sizes[dims[i]] > sizes[i]:
111
+ x = x.narrow(dims[i], 0, sizes[i])
112
+
113
+ return torch.constant_pad_nd(x, pad_amount) if must_copy else x
114
+
115
+
116
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    # Default output length for a one-sided input of size m is 2*(m - 1).
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        # The complex input holds only n // 2 + 1 one-sided coefficients.
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        # The forward c2r transform (hfft) conjugates the input first.
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
141
+
142
+
143
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, dim_size, forward)
    # The inverse r2c transform (ihfft) is the conjugate of the forward one.
    return ret if forward else torch.conj(ret)
170
+
171
+
172
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        # Pad or trim the transformed dimension to exactly n points.
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, dim_size, forward)
196
+
197
+
198
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.fft (1-D discrete Fourier transform)."""
    # Complex input: plain c2c transform; real input: full (two-sided) r2c.
    if input.dtype.is_complex:
        return _fft_c2c("fft", input, n, dim, norm, forward=True)
    else:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
210
+
211
+
212
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.ifft (1-D inverse discrete Fourier transform)."""
    # Complex input: plain c2c transform; real input: full (two-sided) r2c.
    if input.dtype.is_complex:
        return _fft_c2c("ifft", input, n, dim, norm, forward=False)
    else:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
224
+
225
+
226
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.rfft (real input, one-sided output)."""
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)
235
+
236
+
237
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.irfft (inverse of rfft, real output)."""
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)
246
+
247
+
248
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.hfft (forward complex-to-real transform)."""
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)
257
+
258
+
259
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.ihfft (inverse of hfft, one-sided output)."""
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)
268
+
269
+
270
class _ShapeAndDims(NamedTuple):
    """Canonicalized (shape, dims) pair for an n-dimensional FFT."""
    shape: Tuple[int, ...]
    dims: Tuple[int, ...]
273
+
274
+
275
def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        # Accept a bare int as a 1-element dim list.
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is not None:
        # Accept a bare int as a 1-element shape list.
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim: transform every dimension.
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim: take the current size of each requested dim.
        ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]
329
+
330
+
331
+ def _prod(xs: Iterable[int]) -> int:
332
+ """Compute product of a list"""
333
+ prod = 1
334
+ for x in xs:
335
+ prod *= x
336
+ return prod
337
+
338
+
339
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    # Normalization uses the total number of transformed elements.
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)
356
+
357
+
358
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.fftn (n-dimensional discrete Fourier transform)."""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    # Real input is promoted to complex before the c2c transform.
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)
369
+
370
+
371
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.ifftn (n-dimensional inverse Fourier transform)."""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    # Real input is promoted to complex before the c2c transform.
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)
382
+
383
+
384
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.rfftn (n-dimensional FFT of a real input)."""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    # One-sided r2c over all requested dims at once.
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)
401
+
402
+
403
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of torch.fft.ihfftn (n-dimensional inverse of hfftn)."""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    # One-sided r2c over the last transformed dimension only ...
    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    # ... then conjugate and run an inverse c2c transform over the rest.
    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
429
+
430
+
431
+ class _CanonicalizeC2rReturn(NamedTuple):
432
+ shape: Tuple[int, ...]
433
+ dim: Tuple[int, ...]
434
+ last_dim_size: int
435
+
436
+
437
+ def _canonicalize_fft_c2r_shape_and_dim_args(
438
+ fname: str,
439
+ input: TensorLikeType,
440
+ s: Optional[ShapeType],
441
+ dim: Optional[DimsType],
442
+ ) -> _CanonicalizeC2rReturn:
443
+ """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
444
+ as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
445
+ (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
446
+ torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
447
+
448
+ if s is None or s[-1] == -1:
449
+ last_dim_size = 2 * (input.shape[dim[-1]] - 1)
450
+ else:
451
+ last_dim_size = shape[-1]
452
+
453
+ torch._check(
454
+ last_dim_size >= 1,
455
+ lambda: f"Invalid number of data points ({last_dim_size}) specified",
456
+ )
457
+
458
+ shape_list = list(shape)
459
+ shape_list[-1] = last_dim_size // 2 + 1
460
+ return _CanonicalizeC2rReturn(
461
+ shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
462
+ )
463
+
464
+
465
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.fft.irfftn` (complex-to-real N-dim inverse FFT)."""
    shape, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, shape)
    out = prims.fft_c2r(x, dim=dims, last_dim_size=last_dim_size)
    # Normalization uses the element count of the *output* signal.
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dims), forward=False)
480
+
481
+
482
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.fft.hfftn` (Hermitian N-dim FFT)."""
    shape, dims, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    x = _resize_fft_input(x, dims, shape)

    # Forward c2c over all but the last dimension, then a conjugated c2r
    # transform over the last dimension.
    if len(dims) > 1:
        tmp = prims.fft_c2c(x, dim=dims[:-1], forward=True)
    else:
        tmp = x
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dims[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
501
+
502
+
503
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D forward FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D real-to-complex FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D complex-to-real inverse FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D Hermitian FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    """2-D inverse Hermitian FFT: thin wrapper delegating to the N-dim variant."""
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
567
+
568
+
569
+ def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
570
+ """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
571
+ if dim is None:
572
+ return list(range(x.ndim))
573
+ elif not isinstance(dim, Sequence):
574
+ return [dim]
575
+ else:
576
+ return list(dim)
577
+
578
+
579
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Roll each transformed dimension so the zero-frequency term is centered."""
    roll_dims = _default_alldims(dim, input)
    amounts = [input.shape[d] // 2 for d in roll_dims]
    return torch.roll(input, amounts, roll_dims)
584
+
585
+
586
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Inverse of :func:`fftshift`: undo the centering roll (shifts by ceil(n/2))."""
    roll_dims = _default_alldims(dim, input)
    amounts = [(input.shape[d] + 1) // 2 for d in roll_dims]
    return torch.roll(input, amounts, roll_dims)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (706 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (367 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/qat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (6.64 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .modules import *  # noqa: F403

# Names re-exported as the public API of this reference-modules package.
__all__ = [
    "Linear",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "LSTM",
    "GRU",
    "Embedding",
    "EmbeddingBag",
]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (814 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/utils.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import operator
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.ao.nn.intrinsic.quantized as nniq
7
+ import torch.ao.nn.quantized as nnq
8
+
9
+ toq = torch.ops.quantized
10
+ from typing import Tuple, Callable, Dict, Set, List, Optional, Union
11
+
12
+ from torch.fx import GraphModule
13
+ from torch.fx.graph import Node
14
+ from torch.ao.quantization import (
15
+ ObserverBase,
16
+ FakeQuantizeBase,
17
+ )
18
+ from torch.ao.quantization.utils import getattr_from_fqn
19
+ from torch.ao.quantization.observer import _is_activation_post_process
20
+
21
+ from .ns_types import NSNodeTargetType, NSResultsType
22
+
23
# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    """Coarse dtype classification for a node's first input and its output."""

    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We will likely need some kind of dtype
    # propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc
36
+
37
+
38
def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """Infer the coarse dtypes of `node`'s first input and of its output.

    Nodes whose io type depends on their producer (observers, loggers,
    pass-through ops) are resolved by recursing into the first input.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    _unknown = (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    def _prev_output_type(arg_idx: int) -> NodeInputOrOutputType:
        # Resolve the output type of the node feeding input `arg_idx`.
        prev = get_normalized_nth_input(node, gm, arg_idx)
        assert isinstance(prev, Node)
        (
            _prev_node_input_type,
            prev_node_output_type,
        ) = get_node_first_input_and_output_type(
            prev, gm, logger_cls, node_type_to_io_type_map
        )
        return prev_node_output_type

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        if node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        if node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            prev_out = _prev_output_type(0)
            return (prev_out, prev_out)
        return _unknown

    if node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or is_known_fp32_or_int8_input_module
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            prev_out = _prev_output_type(0)
            return (prev_out, prev_out)
        if any(isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32):  # type: ignore[arg-type]
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if any(isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8):  # type: ignore[arg-type]
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        return _unknown

    if node.op == "call_method":
        if node.target == "dequantize":
            # Dequantize is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance.
            return (_prev_output_type(0), NodeInputOrOutputType.FP32)

        if node.target == "to":
            # to is a special node because it allows multiple input types.
            # So, we look up the output type of the previous node and return that
            # as the input type of this node instance. We also look up the target
            # of to and return the correct output type.
            prev_out = _prev_output_type(0)

            cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
            assert (
                cur_node_dtype_target is torch.float16
            ), f"{cur_node_dtype_target} handling needs to be added"

            return (prev_out, NodeInputOrOutputType.FP16)

        if node.target in METHS_IO_TYPE_FP32_OR_INT8:
            prev_out = _prev_output_type(0)
            return (prev_out, prev_out)

        return _unknown

    return _unknown
160
+
161
+
162
def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Returns the qparams (scale, zero_point) of the first input to `node`,
    if they can be inferred from the graph.
    """
    prev_node = get_normalized_nth_input(node, gm, 0)

    if not isinstance(prev_node, Node):
        return None

    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _read_scale_zp_args(src_node, root, scale_arg_idx, zp_arg_idx):
        # Scale / zero_point are stored as graph attributes; resolve both.
        scale_node = get_normalized_nth_input(src_node, root, scale_arg_idx)
        zp_node = get_normalized_nth_input(src_node, root, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        return (
            getattr_from_fqn(root, scale_node.target),
            getattr_from_fqn(root, zp_node.target),
        )

    if prev_node.op == "call_function":
        # quantize - read the args directly
        if prev_node.target == torch.quantize_per_tensor:
            return _read_scale_zp_args(prev_node, gm, 1, 2)
        if prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _read_scale_zp_args(prev_node, gm, 2, 3)

        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input
        return None

    if prev_node.op == "call_module":
        # get type of the module
        assert isinstance(prev_node.target, str)
        module_obj = getattr_from_fqn(gm, prev_node.target)
        if isinstance(
            module_obj,
            (
                nnq.Linear,
                nnq.Conv1d,
                nnq.Conv2d,
                nniq.ConvReLU2d,
                nnq.Conv3d,
                nnq.BatchNorm2d,
                nnq.BatchNorm3d,
                nnq.ConvTranspose1d,
                nnq.ConvTranspose2d,
                nnq.ELU,
                nnq.GroupNorm,
                nnq.InstanceNorm1d,
                nnq.InstanceNorm2d,
                nnq.InstanceNorm3d,
                nnq.LayerNorm,
                nnq.Hardswish,
                nnq.LeakyReLU,
                nnq.ReLU6,
                nniq.BNReLU2d,
                nniq.BNReLU3d,
                nniq.ConvReLU1d,
                nniq.ConvReLU2d,
                nniq.ConvReLU3d,
                nniq.LinearReLU,
            ),
        ):
            # Quantized modules carry their output qparams directly.
            return (module_obj.scale, module_obj.zero_point)  # type: ignore[return-value]

        # Modules which pass qparams through unchanged: keep walking up.
        is_known_fp32_or_int8_input_module = any(
            isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
        )
        if is_known_fp32_or_int8_input_module:
            return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)

    return None
242
+
243
+
244
def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it. If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer. For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
    """
    if node.op != "call_module":
        return node
    if not _is_activation_post_process(getattr_from_fqn(gm, node.target)):  # type: ignore[arg-type]
        return node
    # Hop over the first observer / fake-quant.
    assert len(node.args) == 1
    assert isinstance(node.args[0], Node)
    node = node.args[0]
    # The producer may itself be an observer (e.g. obs -> fq chains); at
    # most one more hop is performed, mirroring the original unrolled logic.
    assert isinstance(node.target, str)
    if _is_activation_post_process(getattr_from_fqn(gm, node.target)):
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node
271
+
272
+
273
def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Assumes that all non-param args occur first. Returns the number of
    non-param args expected for a node. For example, for

      F.linear(x, weight, bias)

    Returns 1, because x is a non-param arg and weight and bias are params.
    For

      lstm_mod(x, hid)

    Returns 2, because both x and hid are non-param args.
    """
    if node.op == "call_module":
        # LSTM modules consume (input, hidden): two non-param args.
        if isinstance(getattr_from_fqn(gm, node.target), nn.LSTM):  # type: ignore[arg-type]
            return 2

    # Everything else is assumed to take a single non-param arg.
    return 1
297
+
298
+
299
def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function" and (
        # TODO(future PR): use relationship map instead of hardcoding
        node.target in (torch.add, torch.ops.quantized.add, operator.add)
        or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
    ):
        # Binary ops: log each of the first two args that is a graph node
        # (scalar constants are skipped). Use isinstance rather than an
        # exact type comparison so Node subclasses are handled too.
        return [i for i in range(2) if isinstance(node.args[i], Node)]
    return [0]
324
+
325
+
326
def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns a string representation of the type of the function or module
    pointed to by this node, or '' for other node types.
    """
    if node.op in ("call_function", "call_method"):
        return torch.typename(node.target)
    if node.op == "call_module":
        assert isinstance(node.target, str)
        # Resolve the module instance so we can name its concrete class.
        return torch.typename(getattr_from_fqn(gm, node.target))
    # placeholder / get_attr / output nodes have no meaningful target type here.
    return ""
339
+
340
+
341
def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Rekeys the layer name of a results dictionary to use node names
    from `model_name`.

    For example, transforms

      {'base_op_1_0': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    into

      {'linear1': {'node_output': {'model_a':
        [{'ref_node_name': 'linear1', ...}]}}}

    Note: we cannot use these node names directly because they are not
    guaranteed to be consistent across models. This is why we extract
    the results first and rekey afterwards.
    """
    rekeyed: NSResultsType = {}
    for old_layer_name, result_type_to_results in results.items():
        node_name = None
        # Scan all result types for an entry belonging to `model_name` and
        # take its recorded ref_node_name as the new key.
        for model_name_to_results in result_type_to_results.values():
            for cur_model_name, results_list in model_name_to_results.items():
                if cur_model_name != model_name:
                    continue
                assert len(results_list)
                node_name = results_list[0]["ref_node_name"]
        # Fall back to the original key when `model_name` has no entry.
        new_key = old_layer_name if node_name is None else node_name
        rekeyed[new_key] = result_type_to_results
    return rekeyed
378
+
379
+
380
def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    If `fqn` entries are filled in for one of the models in `results`, copies
    them over to any models which do not have them filled out.

    A common use case benefitting from this is comparing a model prepared by
    quantization to a quantized model. In this case, the model prepared by
    quantization would have `fqn` entries, and the quantized model would not.
    """

    # Inspect only the very first model entry of the first result type
    # (mirrors the original break-out-of-all-loops behavior).
    model_name_with_fqns = None
    first_rtr = next(iter(results.values()), None)
    if first_rtr is not None:
        first_mntr = next(iter(first_rtr.values()), None)
        if first_mntr is not None:
            first_entry = next(iter(first_mntr.items()), None)
            if first_entry is not None:
                candidate_name, candidate_results = first_entry
                if len(candidate_results) > 0 and candidate_results[0]["fqn"] is not None:
                    model_name_with_fqns = candidate_name

    if model_name_with_fqns:
        # Copy fqns index-by-index from the reference model to every other model.
        for result_type_to_results in results.values():
            for model_name_to_results in result_type_to_results.values():
                ref_model_results = model_name_to_results[model_name_with_fqns]
                for model_name, model_results in model_name_to_results.items():
                    if model_name == model_name_with_fqns:
                        continue
                    for i in range(len(model_results)):
                        model_results[i]["fqn"] = ref_model_results[i]["fqn"]
412
+
413
+
414
def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """Decorator: dequantize the first two tensor args of `f`, map `f`
    elementwise over paired tuples/lists, and return None when either
    argument is not a float tensor."""

    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        both_tuples = isinstance(a0, tuple) and isinstance(a1, tuple)
        both_lists = isinstance(a0, list) and isinstance(a1, list)
        if both_tuples or both_lists:
            # Apply pairwise over the two containers and collect results.
            return [inner(el0, el1, *a_other, **kwargs) for el0, el1 in zip(a0, a1)]

        if isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

            # for the purposes of this util, only handle floats
            if a0.dtype != torch.float or a1.dtype != torch.float:
                return None

            return f(a0, a1, *a_other, **kwargs)

    return inner


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    signal_power = torch.norm(x)
    noise_power = torch.norm(x - y)
    # SQNR in decibels; +inf when x == y (noise power is zero).
    return 20 * torch.log10(signal_power / noise_power)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    squared_error = ((x - y) ** 2).sum()
    reference_energy = (x ** 2).sum()
    return torch.sqrt(squared_error / reference_energy)


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        float or tuple of floats
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    flat_x = x.reshape(1, -1)
    flat_y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(flat_x, flat_y)
493
+
494
def op_type_supports_shadowing(node: Node) -> bool:
    """Return False for multi-tensor-input ops that shadowing cannot handle."""
    multi_tensor_input_ops = (
        torch.add,
        torch.mul,
        operator.add,
        operator.mul,
        torch.cat,
        torch.stack,
    )
    if node.op == 'call_function' and node.target in multi_tensor_input_ops:
        # shadowing for ops with multiple tensor inputs is not implemented yet
        return False
    return True
500
+
501
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True)
        if norm_args_and_kwargs is not None:
            norm_args, norm_kwargs = norm_args_and_kwargs
            assert len(norm_args) + len(norm_kwargs) > idx
            if idx < len(norm_args):
                return norm_args[idx]
            else:
                # note: in Python 3.7+ dicts are ordered
                return list(norm_kwargs.values())[idx]
        else:
            assert len(node.args) + len(node.kwargs) > idx
            if idx < len(node.args):
                return node.args[idx]  # type: ignore[return-value]
            else:
                # kwargs follow the positional args, so index into the kwargs
                # list relative to the end of node.args.
                # Bug fix: this was `idx + len(node.args)`, which over-indexed
                # the kwargs list and raised IndexError for any kwarg input.
                kwargs_idx = idx - len(node.args)
                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        assert len(node.args) + len(node.kwargs) > idx
        if idx < len(node.args):
            return node.args[idx]  # type: ignore[return-value]
        else:
            # same offset fix as above for the fallback path
            kwargs_idx = idx - len(node.args)
            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (6.95 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_correct_bias.cpython-311.pyc ADDED
Binary file (7.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/_equalize.cpython-311.pyc ADDED
Binary file (9.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-311.pyc ADDED
Binary file (13.4 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quant_type.cpython-311.pyc ADDED
Binary file (1.43 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_pt2e.cpython-311.pyc ADDED
Binary file (10.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/stubs.cpython-311.pyc ADDED
Binary file (4.07 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/native.cpython-311.pyc ADDED
Binary file (7.07 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/_qnnpack_pt2e.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import operator
from typing import List

import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
    ObservationType,
    BackendPatternConfig,
)

# Shared dtype configuration for weighted quantized ops:
# quint8 activations, qint8 weights, float bias.
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)


def get_linear_configs():
    """Build the backend pattern configs for decomposed linear ops."""
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]

    # TODO: need to fix the way we insert observers for this pattern
    # should be solved in the new fusion API
    # reason that this doesn't work: the pattern is a bit complicated and we don't
    # have a way to specify which input of the pattern we would like to observe
    # pattern:
    # bias input weight
    # \ | /
    #  \ | t
    #   \ | /
    #    addmm
    # we want to observe "weight" as weight, but there is not way to convey this
    # information with current pattern language
    #
    # right now:
    # original:
    #   weight - t \
    #   input - addmm
    # observed (no hack):
    #   weight - t - observer \
    #   input - observer - addmm
    # target:
    #   weight - observer - t \
    #   input - observer - addmm

    # def root_node_getter(node_pattern):
    #     addmm, bias, act, weight = node_pattern
    #     return addmm

    # linear_configs.append(
    #     BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default))
    #     .set_observation_type(observation_type)  # noqa: E131
    #     .set_dtype_configs(dtype_configs)
    #     ._set_root_node_getter(root_node_getter))

    linear_configs = [
        # linear with bias decomposes to addmm
        BackendPatternConfig(torch.ops.aten.addmm.default)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 0}),
        # linear is decomposed to `t - mm` if bias is not present
        BackendPatternConfig(torch.ops.aten.mm.default)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 1}),
    ]
    return linear_configs
71
+
72
def get_conv_configs():
    """Return BackendPatternConfigs for aten convolution and conv-relu."""
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]
    patterns = [
        torch.ops.aten.convolution.default,
        (torch.ops.aten.convolution.default, torch.ops.aten.relu.default),
        # TODO: remove when functionalization is supported in PT2 mode
        (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default),
    ]
    conv_configs = []
    for pattern in patterns:
        conv_configs.append(
            BackendPatternConfig(pattern)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )
    return conv_configs
96
+
97
def get_pooling_configs():
    """Return the config for the max_pool2d_with_indices + getitem pattern."""
    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]

    def root_node_getter(node_pattern):
        # pattern is (getitem, maxpool, index); the maxpool node is the root
        _, maxpool, _ = node_pattern
        return maxpool

    pool_config = (
        BackendPatternConfig()
        ._set_pattern_complex_format(
            (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
        )
        .set_observation_type(observation_type)
        .set_dtype_configs(dtype_configs)
        ._set_root_node_getter(root_node_getter)
    )
    return [pool_config]
115
+
116
def get_relu_configs():
    """Return the BackendPatternConfig for a standalone aten relu."""
    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
    dtype_configs = [weighted_op_quint8_dtype_config]
    relu_config = (
        BackendPatternConfig(torch.ops.aten.relu.default)
        .set_observation_type(observation_type)
        .set_dtype_configs(dtype_configs)
    )
    return [relu_config]
125
+
126
def get_binary_op_configs():
    """Return BackendPatternConfigs for quantizable binary ops (add variants)."""
    dtype_configs = [weighted_op_quint8_dtype_config]
    num_tensor_args_to_observation_type_mapping = {
        # TODO: this is not used right now since we have extra check in prepare
        # will need to change this to NO_OBSERVER later after we implemented
        # Tensor dtype inference properly
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    binary_op_configs: List[BackendPatternConfig] = []
    for bop in [torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor]:
        patterns = [
            (bop, torch.ops.aten.relu.default),
            bop,
            # TODO: remove when functionalization is supported in pt2_mode
            (bop, torch.ops.aten.relu_.default),
        ]
        for pattern in patterns:
            binary_op_configs.append(
                BackendPatternConfig(pattern)
                .set_dtype_configs(dtype_configs)
                ._set_num_tensor_args_to_observation_type(
                    num_tensor_args_to_observation_type_mapping
                )
            )
    return binary_op_configs
151
+
152
def get_qnnpack_pt2e_backend_config():
    """Assemble the BackendConfig used for the QNNPACK PT2 export flow."""
    config = BackendConfig("qnnpack_pytorch_2.0_export")
    # set_backend_pattern_configs mutates and returns `config`, so these can
    # be plain statements instead of one long fluent chain.
    config.set_backend_pattern_configs(get_linear_configs())
    config.set_backend_pattern_configs(get_binary_op_configs())
    config.set_backend_pattern_configs(get_conv_configs())
    config.set_backend_pattern_configs(get_pooling_configs())
    config.set_backend_pattern_configs(get_relu_configs())
    return config
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/fbgemm.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ._common_operator_config_utils import (
3
+ _get_binary_op_configs,
4
+ _get_bn_configs,
5
+ _get_cat_config,
6
+ _get_conv_configs,
7
+ _get_default_op_configs,
8
+ _get_embedding_op_configs,
9
+ _get_fixed_qparams_op_configs,
10
+ _get_linear_configs,
11
+ _get_rnn_op_configs,
12
+ _get_share_qparams_op_configs,
13
+ _get_tensor_info_op_configs,
14
+ )
15
+ from .backend_config import BackendConfig, DTypeConfig
16
+
17
+ __all__ = [
18
+ "get_fbgemm_backend_config",
19
+ ]
20
+
21
# ===================
# |  DTYPE CONFIGS  |
# ===================

# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py
# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max,
# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
# values to within [0, 127].

# Weighted ops (linear, conv): quint8 activations, qint8 weights, float bias.
fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Weight-less ops: quint8 activations only.
fbgemm_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Everything in fp16 (activations, weights, bias).
fbgemm_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization: quantized input/weight, float output.
fbgemm_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

fbgemm_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization (e.g. embeddings): float activations.
fbgemm_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
76
+
77
+
78
+ # =====================
79
+ # | BACKEND CONFIGS |
80
+ # =====================
81
+
82
def get_fbgemm_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native FBGEMM backend.
    """
    # Group the dtype configs per operator family first, then register every
    # family on a single BackendConfig.
    conv_dtype_configs = [fbgemm_weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        fbgemm_weighted_op_quint8_dtype_config,
        fbgemm_default_dynamic_int8_dtype_config,
        fbgemm_default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
    default_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
    share_qparams_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
    tensor_info_op_dtype_configs = [fbgemm_default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        fbgemm_default_dynamic_int8_dtype_config,
        fbgemm_default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        fbgemm_weight_only_quint8_dtype_config,
        fbgemm_weight_only_quint4x2_dtype_config,
    ]
    return (
        BackendConfig("fbgemm")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
    )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/native.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ._common_operator_config_utils import (
3
+ _get_binary_op_configs,
4
+ _get_bn_configs,
5
+ _get_cat_config,
6
+ _get_conv_configs,
7
+ _get_default_op_configs,
8
+ _get_embedding_op_configs,
9
+ _get_fixed_qparams_op_configs,
10
+ _get_linear_configs,
11
+ _get_ln_configs,
12
+ _get_rnn_op_configs,
13
+ _get_share_qparams_op_configs,
14
+ _get_tensor_info_op_configs,
15
+ )
16
+ from .backend_config import BackendConfig, DTypeConfig
17
+
18
+ __all__ = [
19
+ "get_test_only_legacy_native_backend_config",
20
+ "default_op_quint8_dtype_config",
21
+ "default_op_fp16_dtype_config",
22
+ "default_dynamic_int8_dtype_config",
23
+ "default_dynamic_float16_dtype_config",
24
+ "input_output_only_quint8_dtype_config",
25
+ "weight_only_quint8_dtype_config",
26
+ "weight_only_quint4x2_dtype_config",
27
+ "get_native_backend_config",
28
+ "get_native_backend_config_dict",
29
+ "get_test_only_legacy_native_backend_config_dict",
30
+ ]
31
+
32
+ # ===================
33
+ # | DTYPE CONFIGS |
34
+ # ===================
35
+
36
# weighted op int8 dtype config
# this is config for ops that has quantized weights, like linear, conv
weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Weight-less ops: quint8 activations only.
default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Everything in fp16 (activations, weights, bias).
default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization: quantized input/weight, float output.
# currently the dtype check is not yet enabled, so we provided the
# dtype_configs but it is not really used yet; we will enable it a bit later
# after we moved everything to backend_config_dict
default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

# Weight-only quantization (e.g. embeddings): float activations.
weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
98
+
99
+
100
+ # =====================
101
+ # | BACKEND CONFIGS |
102
+ # =====================
103
+
104
def get_test_only_legacy_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
    """
    # Per-family dtype configs; the fp16 variants are what distinguishes this
    # "legacy test only" config from get_native_backend_config().
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
        default_op_fp16_dtype_config,
    ]
    binary_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        default_op_quint8_dtype_config,
        default_op_fp16_dtype_config,
    ]
    tensor_info_op_dtype_configs = [
        default_op_quint8_dtype_config,
    ]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return (
        BackendConfig("_native_and_fp16")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
    )
153
+
154
def get_native_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
    """
    # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
    conv_dtype_configs = [weighted_op_quint8_dtype_config]
    linear_dtype_configs = [
        weighted_op_quint8_dtype_config,
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [default_op_quint8_dtype_config]
    default_op_dtype_configs = [default_op_quint8_dtype_config]
    fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
    share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
    tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
    rnn_op_dtype_configs = [
        default_dynamic_int8_dtype_config,
        default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        weight_only_quint8_dtype_config,
        weight_only_quint4x2_dtype_config,
    ]
    layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
    return (
        BackendConfig("native")
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs))
        .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
    )
192
+
193
def get_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
    """
    config = get_native_backend_config()
    return config.to_dict()
198
+
199
def get_test_only_legacy_native_backend_config_dict():
    """
    Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
    fp16 ops in dictionary form.
    """
    config = get_test_only_legacy_native_backend_config()
    return config.to_dict()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fuser_method_mappings.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.ao.nn.intrinsic as nni
3
+
4
+ from typing import Any, Union, Callable, List, Tuple, Dict, Optional, Type
5
+ from torch.ao.quantization.utils import Pattern, get_combined_dict, MatchAllNode
6
+ import itertools
7
+
8
+ __all__ = [
9
+ "fuse_conv_bn",
10
+ "fuse_conv_bn_relu",
11
+ "fuse_linear_bn",
12
+ "fuse_convtranspose_bn",
13
+ "get_fuser_method",
14
+ "get_fuser_method_new",
15
+ ]
16
+
17
def fuse_conv_bn(is_qat, conv, bn):
    r"""Return the fused conv and bn modules.
    Given the conv and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        conv: Module instance of type conv1d/conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn(False, m1, b1)
    """
    assert conv.training == bn.training, \
        "Conv and BN both must be in the same mode (train or eval)."

    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }

    if is_qat:
        # QAT keeps the BN inside the fused module, so its statistics and
        # affine parameters must be usable as-is.
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        # Bug fix: dropped the redundant parentheses around `type(conv)`.
        fused_module_class = fused_module_class_map.get(type(conv), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn)}")
    else:
        # PTQ (eval) fusion folds the BN parameters directly into the conv weights.
        return nn.utils.fuse_conv_bn_eval(conv, bn)
54
+
55
def fuse_conv_bn_relu(is_qat, conv, bn, relu):
    r"""Return the fused conv and bn modules.

    Given the conv and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        conv: Module instance of type conv1d/conv2d/conv3d
        bn: Spatial BN instance that needs to be fused with the conv
        relu: ReLU instance that follows the bn

    Examples::

        >>> m1 = nn.Conv2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> r1 = nn.ReLU(inplace=False)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_conv_bn_relu(False, m1, b1, r1)
    """
    assert conv.training == bn.training == relu.training, \
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module : Optional[Type[nn.Sequential]] = None
    if is_qat:
        # QAT keeps the BN inside the fused module, so it must be usable as-is.
        map_to_fused_module_train = {
            nn.Conv1d: nni.ConvBnReLU1d,
            nn.Conv2d: nni.ConvBnReLU2d,
            nn.Conv3d: nni.ConvBnReLU3d,
        }
        assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
        assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
        fused_module = map_to_fused_module_train.get(type(conv), None)
        if fused_module is not None:
            return fused_module(conv, bn, relu)
        else:
            raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, relu)}")
    else:
        # PTQ folds the BN into the conv weights first, then pairs the result
        # with the ReLU in the fused eval module.
        map_to_fused_module_eval = {
            nn.Conv1d: nni.ConvReLU1d,
            nn.Conv2d: nni.ConvReLU2d,
            nn.Conv3d: nni.ConvReLU3d,
        }
        fused_module = map_to_fused_module_eval.get(type(conv), None)
        if fused_module is not None:
            fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
            return fused_module(fused_conv, relu)
        else:
            raise NotImplementedError(f"Cannot fuse eval modules: {(conv, bn, relu)}")
103
+
104
def fuse_linear_bn(is_qat, linear, bn):
    r"""Return the fused linear and bn modules.
    Given the linear and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer

    Examples::

        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_linear_bn(False, m1, b1)
    """
    assert linear.training == bn.training, \
        "Linear and BN both must be in the same mode (train or eval)."

    if is_qat:
        # QAT keeps the BN inside the fused module, so its statistics and
        # affine parameters must be usable as-is.
        assert bn.num_features == linear.out_features, \
            "Output features of Linear must match num_features of BatchNorm1d"
        assert bn.affine, "Only support fusing BatchNorm1d with affine set to True"
        assert bn.track_running_stats, \
            "Only support fusing BatchNorm1d with tracking_running_stats set to True"
        return nni.LinearBn1d(linear, bn)
    else:
        # PTQ folds the BN parameters into the linear weights.
        return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
133
+
134
def fuse_convtranspose_bn(is_qat, convt, bn):
    r"""Return the fused ConvTranspose and bn modules.
    Given ConvTranspose and bn modules, fuses them and returns the fused module

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion (QAT is not supported yet)
        convt: Module instance of type ConvTransposeNd
        bn: BatchNormNd instance that needs to be fused with the linear layer.
            batch norm N should match the ConvTranspose N

    Examples::

        >>> m1 = nn.ConvTranspose2d(10, 20, 3)
        >>> b1 = nn.BatchNorm2d(20)
        >>> # xdoctest: +SKIP
        >>> m2 = fuse_convtranspose_bn(False, m1, b1)
    """
    assert convt.training == bn.training, \
        "ConvTranspose and BN both must be in the same mode (train or eval)."

    if is_qat:
        raise Exception("Fusing ConvTranspose+BatchNorm not yet supported in QAT.")
    else:
        # transpose=True makes the helper fold the BN over the
        # (in_channels, out_channels, ...) weight layout of ConvTransposeNd.
        return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True)
157
+
158
+ def _sequential_wrapper2(sequential):
159
+ """Return a sequential wrapped that for is_qat and two modules.
160
+ Given a sequential class for two modules, return a function that takes
161
+ is_qat, and then two modules as argument, that ignores the is_qat flag
162
+ and always returns the sequential that combines the two input modules
163
+ """
164
+ def fuser_method(is_qat, m1, m2):
165
+ return sequential(m1, m2)
166
+ return fuser_method
167
+
168
# Default mapping from a tuple of module types to the fuser method that
# combines them into a single fused module.
_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    # conv + bn (+ relu): fold the batch norm into the conv
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    # op + relu pairs: simply wrap the two modules in the fused sequential type
    (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d),
    # conv transpose + bn: folds the bn (eval-mode fusion only)
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}
186
+
187
def get_fuser_method(op_list, additional_fuser_method_mapping=None):
    """Look up the fuser method registered for the given tuple of module types.

    The default mapping is combined with ``additional_fuser_method_mapping``
    (entries in the latter win); an AssertionError is raised when no fuser
    method is registered for ``op_list``.
    """
    extra_mapping = additional_fuser_method_mapping or {}
    combined = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD, extra_mapping)
    fuser_method = combined.get(op_list, None)
    assert fuser_method is not None, f"did not find fuser method for: {op_list} "
    return fuser_method
200
+
201
+ def _reverse2(f):
202
+ def reversed(is_qat, x, y):
203
+ return f(is_qat, y, x)
204
+ return reversed
205
+
206
+ def _reverse3(f):
207
+ def reversed(is_qat, x, w):
208
+ y, z = w
209
+ return f(is_qat, z, y, x)
210
+ return reversed
211
+
212
+ def _get_valid_patterns(op_pattern):
213
+ """Return a list of valid patterns generated from the op_pattern.
214
+
215
+ Returns a list of valid patterns generated from the op_pattern,
216
+ since MatchAllNode can match all types of nodes,
217
+ e.g. pattern (torch.nn.Conv2d, torch.add) should also be able to match keys like
218
+ (MatchAllNode, torch.add) and (torch.nn.Conv2d, MatchAllNode)
219
+
220
+ Example Input:
221
+ (torch.add, (torch.nn.ReLU, torch.nn.Conv2d))
222
+
223
+ Example Output:
224
+ [(torch.add, (torch.nn.ReLU, torch.nn.Conv2d)),
225
+ (torch.add, (torch.nn.ReLU, MatchAllNode)),
226
+ (torch.add, (MatchAllNode, torch.nn.Conv2d)),
227
+ (torch.add, (MatchAllNode, MatchAllNode)),
228
+ (MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)),
229
+ (MatchAllNode, (torch.nn.ReLU, MatchAllNode)),
230
+ (MatchAllNode, (MatchAllNode, torch.nn.Conv2d)),
231
+ (MatchAllNode, (MatchAllNode, MatchAllNode)),
232
+ ]
233
+ """
234
+ result: List[Any]
235
+ if isinstance(op_pattern, (tuple, list)):
236
+ sub_combs = []
237
+ for sub_pattern in op_pattern:
238
+ sub_combs.append(_get_valid_patterns(sub_pattern))
239
+ result = list(itertools.product(*sub_combs))
240
+ else:
241
+ result = [op_pattern, MatchAllNode]
242
+ return result
243
+
244
def get_fuser_method_new(
        op_pattern: Pattern,
        fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]):
    """Get fuser method.

    This will be made default after we deprecate the get_fuser_method
    Would like to implement this first and have a separate PR for deprecation

    Bug fix: the loop previously rebound the ``op_pattern`` parameter, so on a
    lookup failure the assertion message reported the last candidate (all
    MatchAllNode) instead of the pattern the caller actually asked about.
    """
    op_patterns = _get_valid_patterns(op_pattern)
    fuser_method = None
    for candidate in op_patterns:
        fuser_method = fuser_method_mapping.get(candidate, None)
        if fuser_method is not None:
            break
    assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
    return fuser_method
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .prepare import prepare
2
+ from .convert import convert
3
+ from .fuse import fuse
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_decomposed.cpython-311.pyc ADDED
Binary file (46.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/_equalize.cpython-311.pyc ADDED
Binary file (40.1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/custom_config.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/fuse.cpython-311.pyc ADDED
Binary file (7.17 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/graph_module.cpython-311.pyc ADDED
Binary file (10.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/match_utils.cpython-311.pyc ADDED
Binary file (8.89 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-311.pyc ADDED
Binary file (4.54 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/prepare.cpython-311.pyc ADDED
Binary file (65.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-311.pyc ADDED
Binary file (9.98 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_equalize.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from collections import namedtuple
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.ao.nn.intrinsic as nni
10
+ from torch.fx import GraphModule
11
+ from torch.fx.graph import Node
12
+ from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
13
+
14
+ from ..observer import _with_args, ObserverBase, PerChannelMinMaxObserver
15
+ from ..utils import _parent_name, check_min_max_valid
16
+
17
+ from .utils import (
18
+ get_new_attr_name_with_prefix,
19
+ maybe_get_next_module,
20
+ node_arg_is_weight,
21
+ )
22
+
23
# User-extensible registry of custom module classes that support input-weight
# equalization; consulted by custom_module_supports_equalization().
CUSTOM_MODULE_SUPP_LIST: List[Any] = []
24
+
25
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
    """Return ``scale`` reshaped so it broadcasts against ``input`` along ``axis``.

    The result has the same rank as ``input``, with size 1 in every dimension
    except ``axis``, where it matches ``input``'s size.
    """
    broadcast_shape = [
        input.size(axis) if dim == axis else 1 for dim in range(input.ndim)
    ]
    return scale.view(broadcast_shape)
31
+
32
# Maps each per-tensor qscheme to its per-channel counterpart; used when
# building the per-column observers that back the equalization observers.
# NOTE: the "qsheme" spelling (sic) is kept as-is because this name is
# referenced throughout the module.
qsheme_mapping_per_tensor_to_per_channel = {
    torch.per_tensor_affine: torch.per_channel_affine,
    torch.per_tensor_symmetric: torch.per_channel_symmetric,
}
36
+
37
+
38
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme (must be per-tensor; it is mapped to the
            corresponding per-channel scheme internally)
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.
    This observer is intended to be used along with a WeightEqualizationObserver
    to calculate the equalization scale.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, factory_kwargs=None) -> None:
        super().__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        self.dtype = dtype
        self.qscheme = qscheme

        # Track per-column statistics (ch_axis=1) using the per-channel
        # counterpart of the requested per-tensor qscheme.
        per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                  qscheme=per_channel_qscheme,
                                                  quant_min=quant_min,
                                                  quant_max=quant_max,
                                                  factory_kwargs=factory_kwargs)

        # A scalar 1 is the sentinel for "no equalization scale computed yet".
        self.equalization_scale = torch.tensor(1)
        self.equalization_shape: List[int] = []

    def forward(self, x_orig):
        # Only 2D (Linear) through 5D (Conv3d) activations are supported.
        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
            raise ValueError("InputEqualizationObserver only supports Linear and Conv layers")

        # Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
        self.equalization_shape = [1] * x_orig.ndim
        self.equalization_shape[1] = x_orig.size(1)

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        # Per-column running (min, max) recorded by the inner observer.
        return (self.input_obs.min_val, self.input_obs.max_val)

    def set_equalization_scale(self, equalization_scale):
        # Reshape the equalization scale along axis=1 so that it can be
        # multiplied with the input along axis=1.
        # A scalar 1 means "no scale computed"; keep the sentinel unchanged.
        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
            return
        self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape)

    def calculate_scaled_minmax(self):
        r""" Returns the scaled min/max inputs, or (None, None) if no
        equalization scale has been computed yet.
        """
        if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1):
            warnings.warn(
                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " +
                "Will not scale the next quantization observer."
            )
            return None, None

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index
        (min_inputs, max_inputs) = self.get_input_minmax()
        equalization_scale_reshaped = reshape_scale(self.equalization_scale, 0, min_inputs)
        min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
        max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))

        return min_input_scaled, max_input_scaled

    with_args = classmethod(_with_args)
118
+
119
+
120
+ class _WeightEqualizationObserver(nn.Module):
121
+ r"""Observer for tracking the running min/max values of weight columns and
122
+ rows, and computing the quantization parameters for the weight rows.
123
+
124
+ Args:
125
+ dtype: Quantized data type
126
+ qscheme: Quantization scheme
127
+ quant_min: Minimum quantization value. If unspecified, it will
128
+ follow the 8-bit setup.
129
+ quant_max: Maximum quantization value. If unspecified, it will
130
+ follow the 8-bit setup.
131
+
132
+ This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
133
+ to record the running minimum and maximum of columns of incoming weight
134
+ tensors. This observer is intended to be used along with an
135
+ InputEqualizationObserver to calculate the equalization scale.
136
+
137
+ The running minimum/maximum :math:`w_\text{min/max}` are computed in the
138
+ same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
139
+ """
140
+
141
+ def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
142
+ quant_max=None, factory_kwargs=None) -> None:
143
+ super().__init__()
144
+
145
+ self.dtype = dtype
146
+ self.qscheme = qscheme
147
+ self.ch_axis = 1
148
+
149
+ per_channel_qscheme = qscheme
150
+ if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
151
+ per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
152
+ self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
153
+ qscheme=per_channel_qscheme,
154
+ quant_min=quant_min,
155
+ quant_max=quant_max,
156
+ factory_kwargs=factory_kwargs)
157
+
158
+ self.equalization_scale = torch.tensor(1)
159
+
160
+ def forward(self, w_orig):
161
+ if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
162
+ raise ValueError("InputEqualizationObserver only supports Linear and Conv layers")
163
+
164
+ return self.weight_col_obs(w_orig)
165
+
166
+ def get_weight_col_minmax(self):
167
+ return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
168
+
169
+ def set_equalization_scale(self, equalization_scale):
170
+ self.equalization_scale = equalization_scale
171
+
172
+ with_args = classmethod(_with_args)
173
+
174
+
175
def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
                                 weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
    r""" Calculates the equalization scale and sets the equalization_scale value
    in the observers.

    The scale for each column is sqrt(weight_range / input_range), computed from
    the per-column min/max values recorded by the two observers.

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns
    """

    min_inputs, max_inputs = input_obs.get_input_minmax()
    min_weights, max_weights = weight_obs.get_weight_col_minmax()

    # Both observers must have seen data before a scale can be computed.
    inputs_valid = check_min_max_valid(min_inputs, max_inputs)
    weights_valid = check_min_max_valid(min_weights, max_weights)
    if not (inputs_valid and weights_valid):
        warnings.warn(
            "Must run observer before calling calculate_equalization_scale. " +
            "Returning default equalization scale torch.tensor(1)."
        )
        return torch.tensor(1)

    if min_inputs.shape != min_weights.shape:
        raise ValueError(
            "Input and Weight must have the same column dimension. " +
            f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
        )

    weight_ranges = max_weights - min_weights
    input_ranges = max_inputs - min_inputs
    equalization_scale = torch.sqrt(weight_ranges / input_ranges)
    # A 0/inf/nan entry would zero out or blow up an entire column, so fall
    # back to the identity scale of 1 for those entries.
    equalization_scale = torch.where(
        equalization_scale == 0.,
        torch.ones_like(equalization_scale),
        equalization_scale,
    )
    return torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
206
+
207
+
208
class EqualizationQConfig(namedtuple('EqualizationQConfig', ['input_activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network specifically for
    input-weight equalization by providing settings (observer classes) for
    inputs, outputs, and weights.

    Note that EqualizationQConfig needs to contain observer **classes** (like
    MinMaxObserver) or a callable that returns instances on invocation, not the
    concrete observer instances themselves.
    Quantization function will instantiate observers multiple times for each of
    the layers.

    Observer classes have usually reasonable default arguments, but they can be
    overwritten with `with_args` method (that behaves like functools.partial):

    my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
                                    weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
    """
    def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
        # Reject already-constructed observers: the quantization machinery must
        # be able to instantiate a fresh observer per layer.
        if any(isinstance(obs, nn.Module) for obs in (input_activation, weight)):
            raise ValueError("EqualizationQConfig received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, input_activation, weight)
232
+
233
+
234
# Default observer partials and qconfig for input-weight equalization:
# symmetric quint8 activations and symmetric per-channel qint8 weights.
input_equalization_observer = _InputEqualizationObserver.with_args(
    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
default_equalization_qconfig = EqualizationQConfig(input_activation=input_equalization_observer,
                                                   weight=weight_equalization_observer)
240
+
241
+
242
def fused_module_supports_equalization(module) -> bool:
    """Return True if ``module`` is a fused Linear/Conv+ReLU type that can be equalized."""
    supported_fused_types = (nni.LinearReLU, nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)
    return type(module) in supported_fused_types
245
+
246
def nn_module_supports_equalization(module) -> bool:
    """Return True if ``module`` is a plain nn.Linear/nn.ConvNd type that can be equalized."""
    supported_nn_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    return type(module) in supported_nn_types
249
+
250
def custom_module_supports_equalization(module) -> bool:
    """Return True if ``module``'s type was registered in CUSTOM_MODULE_SUPP_LIST."""
    registered_types = CUSTOM_MODULE_SUPP_LIST
    return type(module) in registered_types
253
+
254
+
255
def node_supports_equalization(node: Node, modules) -> bool:
    """ Checks if the current node supports equalization.
    Currently we only support nn.Linear/F.linear and nn.ConvNd/F.convNd layers.
    """
    if node.op == 'call_module':
        # Look the module up once and test it against every supported family.
        module = modules[str(node.target)]
        return (nn_module_supports_equalization(module)
                or fused_module_supports_equalization(module)
                or custom_module_supports_equalization(module))
    if node.op == 'call_function':
        return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
    return False
266
+
267
def is_equalization_observer(observer: nn.Module) -> bool:
    """Return True if ``observer`` is one of the two equalization observer types."""
    eq_observer_types = (_InputEqualizationObserver, _WeightEqualizationObserver)
    return isinstance(observer, eq_observer_types)
269
+
270
+
271
+ ###############################################################################
272
+ # Functions for equalization during convert #
273
+ ###############################################################################
274
+
275
def get_op_node_and_weight_eq_obs(
    input_eq_obs_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
    """ Gets the following weight equalization observer. There should always
    exist a weight equalization observer after an input equalization observer.

    Returns the operation node that follows the input equalization observer node
    and the weight equalization observer, or (None, None) if no functional
    weight observer can be located.
    """

    # Find the op node that comes directly after the input equalization observer
    op_node = None
    for user in input_eq_obs_node.users.keys():
        if node_supports_equalization(user, modules):
            op_node = user
            break

    assert op_node is not None
    if op_node.op == 'call_module':
        # If the op_node is a nn.Linear layer, then it must have a
        # WeightEqualizationObserver configuration
        maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(model, "equalization_node_name_to_qconfig")
        assert maybe_equalization_node_name_to_config is not None
        equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config  # type: ignore[assignment]
        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
        # Instantiate a fresh weight observer from the op's equalization
        # qconfig; it is calibrated later by the caller
        # (update_obs_for_equalization).
        weight_eq_obs = equalization_node_name_to_qconfig.get(op_node.name, None).weight()

        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
        return op_node, weight_eq_obs

    elif op_node.op == 'call_function':
        # Functional ops already carry their weight equalization observer as a
        # module in the graph; locate it through the op's weight argument.
        weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
        if weight_node is not None:
            weight_eq_obs = modules[str(weight_node.target)]
            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
            return op_node, weight_eq_obs

    return None, None
315
+
316
def maybe_get_weight_eq_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> Optional[Node]:
    """ Return the weight equalization observer node feeding ``op_node``, if any.
    """
    assert op_node.op == 'call_function'
    # The weight argument of a functional op is produced by its weight
    # equalization observer; find the first such argument.
    weight_arg = next(
        (arg for arg in op_node.args if node_arg_is_weight(op_node, arg)), None)
    if weight_arg is None:
        return None
    assert (isinstance(weight_arg, Node) and weight_arg.op == 'call_module' and
            isinstance(modules[str(weight_arg.target)], _WeightEqualizationObserver))
    return weight_arg
326
+
327
def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
    """ Gets the following input equalization observer if it exists.

    For example, in the case of connecting linear layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    If the node being passed in is the linear1 node, then we want to return eq_obs2,
    the following equalization observer for linear2.

    However, if there are no connecting layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
    Then we want to return None.

    In the case of an unfused linear-relu layer with a connecting linear layer:
        linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    Since it is unfused, we want to skip over the relu layer and return eq_obs2,
    the following equalization observer for linear2.
    """

    assert node_supports_equalization(node, modules)

    # Locate the following nn.ReLU or F.relu node if it exists
    maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
    if maybe_relu_node is None:
        maybe_relu_node = maybe_get_next_module(node, modules, target_functional_type=F.relu)

    # Locate the following output observer if it exists.
    # We will skip the relu node if it exists.
    maybe_obs_node = (
        maybe_get_next_module(node, modules, ObserverBase)
        if maybe_relu_node is None
        else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
    )
    if maybe_obs_node is None:
        return None

    maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
    if maybe_eq_obs_node is None:
        return None

    # NOTE(review): this indexes `modules` by str(node) (the node's *name*),
    # not str(node.target); it presumably relies on observer node names
    # matching their module paths — confirm against how observers are named
    # during prepare.
    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
    return maybe_eq_obs
369
+
370
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
    """ Return the next InputEqualizationObserver's equalization scale, if set.

    This is used in the case where there are two connecting linear layers:
        linear1 -> LinearOutObs -> InputEqObs -> linear2
    Given linear1, locate the InputEqObs and return its scale; return None when
    no such observer exists or its scale is still the default of 1.
    """
    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
    if next_inp_eq_obs is None:
        return None
    scale = next_inp_eq_obs.equalization_scale
    # A scalar value of 1 is the "never computed" sentinel; treat it as absent.
    if scale.nelement() == 1 and scale == torch.tensor(1):
        return None
    return scale
385
+
386
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
    """ Copy the scaled min/max values from the input equalization observer at
    ``node`` into the preceding input quantization observer, so that the
    quantization parameters computed later account for the equalization scale.
    """
    eq_obs = modules[str(node.target)]
    assert isinstance(eq_obs, _InputEqualizationObserver)

    quant_obs_node = node.args[0]
    assert isinstance(quant_obs_node, Node)

    quant_obs = modules[str(quant_obs_node.target)]
    if not isinstance(quant_obs, ObserverBase):
        return

    scaled_min, scaled_max = eq_obs.calculate_scaled_minmax()
    # (None, None) means no equalization scale was computed; leave the
    # quantization observer untouched in that case.
    if scaled_min is None and scaled_max is None:
        return
    quant_obs.min_val = scaled_min
    quant_obs.max_val = scaled_max
406
+
407
def scale_weight_node(
    node: Node,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scale the weights for input-weight equalization by multiplying the
    weight by 1/equalization_scale and next_equalization_scale

    Args:
        node: Current node whose weights we want to scale
        modules: Mapping from module path to module for the observed model
        equalization_scale: Current node's calculated equalization scale
            (may be None, in which case this function is a no-op)
        next_equalization_scale: Next node's calculated equalization scale if
            the following node needs to be equalized, 1 otherwise
    """
    if equalization_scale is None:
        return

    # For fused modules (e.g. LinearReLU) the weight lives on the first
    # submodule.
    if fused_module_supports_equalization(modules[str(node.target)]):
        op_module = modules[str(node.target)][0]  # type: ignore[index]
    else:
        op_module = modules[str(node.target)]
    assert nn_module_supports_equalization(op_module) or custom_module_supports_equalization(op_module)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    weight = op_module.weight
    assert isinstance(weight, torch.Tensor)

    # Scale the weights by the reciprocal of the equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        op_module.weight = nn.Parameter(scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    op_module.weight = nn.Parameter(scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias = op_module.bias
    if bias is None:
        return
    assert isinstance(bias, torch.Tensor)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    op_module.bias = nn.Parameter(scaled_bias)
462
+
463
def scale_weight_functional(
    op_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scales the weight value for functional layers.

    ``equalization_scale`` may be None, in which case this function is a no-op.
    """
    if equalization_scale is None:
        return

    # From the given op_node, the path looks like:
    #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
    # So we want to trace back from the op_node to get the equalization observer
    # node, then the quantization observer node, and then finally the weight
    # node which contains the weight values.

    # Get the equalization observer node
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    # Get the quantization observer node
    weight_quant_obs_node = weight_eq_obs_node.args[0]
    if weight_quant_obs_node is None:
        return
    assert (isinstance(weight_quant_obs_node, Node) and
            isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase))

    # Get the get_attr(weight) node
    weight_node = weight_quant_obs_node.args[0]
    if weight_node is None:
        return
    assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'

    weight_parent_name, weight_name = _parent_name(weight_node.target)
    weight = getattr(modules[weight_parent_name], weight_name)

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=1
    equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))

    if next_equalization_scale is None:
        setattr(modules[weight_parent_name], weight_name, scaled_weight)
        return

    # Multiply the weights row wise by the next equalization scale
    # Reshape the equalization scale so that we can multiply it to the weight along axis=0
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, scaled_weight)
    scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

    setattr(modules[weight_parent_name], weight_name, scaled_weight)
    # Sanity check: the graph's weight buffer must reflect the rescaled values.
    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)

    # Multiply the bias element wise by the next equalization scale
    bias_node = None
    for node in op_node.args:
        # Find the get_attr node containing the bias values
        if isinstance(node, Node) and node.op == 'get_attr' and 'bias' in node.name:
            bias_node = node
            break
    if bias_node is None:
        return

    bias_parent_name, bias_name = _parent_name(bias_node.target)
    bias = getattr(modules[bias_parent_name], bias_name)

    # Reshape the equalization scale so that we can multiply it element-wise to the bias
    next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
    scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
    setattr(modules[bias_parent_name], bias_name, scaled_bias)
537
+
538
def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
    """ Reset the min/max statistics of the weight quantization observer that
    feeds ``op_node``, so they can be re-collected on the rescaled weights.
    """
    weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
    if weight_eq_obs_node is None:
        return

    quant_obs_node = weight_eq_obs_node.args[0]
    if quant_obs_node is None:
        return
    assert isinstance(quant_obs_node, Node)

    quant_obs = modules[str(quant_obs_node.target)]
    assert isinstance(quant_obs, ObserverBase)
    quant_obs.reset_min_max_vals()  # type: ignore[operator]
554
+
555
def remove_node(model: GraphModule, node: Node, prev_node: Node):
    """ Detach ``node`` from the graph by rewiring every one of its users to
    consume ``prev_node`` instead, then erase ``node``.
    """
    # Snapshot the users first: replace_input_with mutates node.users while we
    # iterate over it.
    for consumer in list(node.users.keys()):
        consumer.replace_input_with(node, prev_node)

    # With no remaining users the node can be safely erased.
    model.graph.erase_node(node)
567
+
568
def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
    """ Update all of the observer's equalization scale. For each
    InputEqualizationObserver, we will find the location of the next
    WeightEqualizationObserver, create it, and calculate the equalization scale
    based on the two observers.

    We will then return a dictionary mapping operation node names to
    the corresponding WeightEqualizationObservers for that operation.
    """
    weight_eq_obs_dict = {}
    for node in model.graph.nodes:
        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
            input_eq_obs = modules[node.target]
            assert isinstance(input_eq_obs, _InputEqualizationObserver)
            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)

            if op_node is None or weight_eq_obs is None:
                continue

            if op_node.op == 'call_module':
                # Calibrate the weight equalization observer since it has just
                # been created (functional weight observers already live in the
                # graph and were calibrated during calibration runs).
                if fused_module_supports_equalization(modules[str(op_node.target)]):
                    # Fused modules keep the weighted layer as submodule [0].
                    module = modules[str(op_node.target)][0]  # type: ignore[index]
                    assert nn_module_supports_equalization(module)
                    weight_eq_obs(module.weight)
                else:
                    weight_eq_obs(modules[str(op_node.target)].weight)

            # Calculate and set the equalization scale values
            equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
            input_eq_obs.set_equalization_scale(equalization_scale)
            weight_eq_obs.set_equalization_scale(equalization_scale)

            weight_eq_obs_dict[op_node.name] = weight_eq_obs

    return weight_eq_obs_dict
605
+
606
+ def convert_eq_obs(
607
+ model: GraphModule,
608
+ modules: Dict[str, nn.Module],
609
+ weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
610
+ ) -> None:
611
+ """ Converts the equalization operations and updates the other nodes in the
612
+ following way:
613
+ - Removes the input equalization observers and inserts a mul operator
614
+ along with an equalization scale node wherever applicable (we do not
615
+ want to insert a mul operator between connecting linear layers).
616
+ - Updates the input quantization observers with the scaled input min/max
617
+ values.
618
+ - Scales the weights by the current and next equalization scales.
619
+ - Removes the weight equalization observer node if it exists.
620
+
621
+ Before (after prepare):
622
+ weight values
623
+ |
624
+ WeightQuantObs
625
+ |
626
+ WeightEqObs
627
+ |
628
+ x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
629
+
630
+ After this function:
631
+ scaled weight values
632
+ |
633
+ equalization scale WeightQuantObs
634
+ | |
635
+ x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
636
+
637
+ After convert:
638
+ equalization scale scaled weight values
639
+ | |
640
+ x -> mul -> quantize_per_tensor -> quantized::linear
641
+
642
+ Note that although the equalization observer appeared after the quantization
643
+ observer after prepare_fx, the mul node appears before the quantization node
644
+ after convert_fx. This is because placing the equalization observer after
645
+ the quantization observer in prepare_fx would allow us to keep the invariant
646
+ that the graph before the current node inserts its observers is not
647
+ modified.
648
+
649
+ Having the equalization observer before the quantization observer would also
650
+ cause some inconsistences between the ordering of the quantization and
651
+ equalization observers.
652
+ For example, a single linear layer would look like:
653
+ x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
654
+ But between two connected linear layers, it would look like:
655
+ linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
656
+ """
657
+ for node in model.graph.nodes:
658
+ if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
659
+ inp_quant_obs_node = node.args[0]
660
+ prev_node = inp_quant_obs_node.args[0]
661
+
662
+ # If the previous node is a layer that needs to be equalized, then
663
+ # we will remove the current node because we do not need to add any
664
+ # equalization nodes between two layers that need to be equalized
665
+
666
+ # Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
667
+ # After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
668
+ if node_supports_equalization(prev_node, modules) or "relu" in prev_node.name:
669
+ remove_node(model, node, inp_quant_obs_node)
670
+ continue
671
+
672
+ # Update the following input quantization observer's min/max values
673
+ scale_input_observer(node, modules)
674
+
675
+ # Remove the InputEqualization node and add a mul operator before
676
+ # the quantization observer node that appears before the equalization node
677
+ # Before: x -> input_quant_obs -> input_eq_obs -> linear
678
+ # After: x -> mul -> input_quant_obs -> linear
679
+
680
+ # Create a node containing the equalization scale
681
+ with model.graph.inserting_before(inp_quant_obs_node):
682
+ get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
683
+ name = get_new_eq_scale_name(modules)
684
+ setattr(model, name, modules[node.target].equalization_scale)
685
+ eq_scale_node = model.graph.create_node('get_attr', name)
686
+
687
+ # Create a node multiplying the input with the equalization scale
688
+ with model.graph.inserting_after(eq_scale_node):
689
+ inputs = (prev_node, eq_scale_node)
690
+ mul_node = model.graph.create_node("call_function", torch.mul, inputs)
691
+
692
+ # Set the mul nod to be the input_quant_obs_node's input instead of
693
+ # the previous node
694
+ inp_quant_obs_node.replace_input_with(prev_node, mul_node)
695
+ remove_node(model, node, inp_quant_obs_node)
696
+
697
+ elif weight_eq_obs_dict.get(node.name, None) is not None:
698
+ weight_eq_obs = weight_eq_obs_dict.get(node.name)
699
+ assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
700
+ equalization_scale = weight_eq_obs.equalization_scale
701
+
702
+ if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
703
+ equalization_scale = None # type: ignore[assignment]
704
+ maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
705
+
706
+ # Scale the weight nodes
707
+ if node.op == 'call_module':
708
+ scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
709
+ elif node.op == 'call_function':
710
+ scale_weight_functional(node, model, modules, equalization_scale, maybe_next_equalization_scale)
711
+
712
+ weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
713
+ if weight_eq_obs_node is None:
714
+ return
715
+ assert isinstance(modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver)
716
+
717
+ # Clear the quantization observer's min/max values so that they
718
+ # can get updated later based on the new scale values
719
+ clear_weight_quant_obs_node(node, modules)
720
+
721
+ # Erase the weight equalization observer node
722
+ prev_node = weight_eq_obs_node.args[0]
723
+ remove_node(model, weight_eq_obs_node, prev_node)
724
+ else:
725
+ raise ValueError("Expected operation node to be 'call_module' or 'call_function" +
726
+ f"Instead got node {node.name} as '{node.op}'.")
727
+
728
def _convert_equalization_ref(model: GraphModule):
    """Apply the graph changes needed for equalization without quantizing
    any nodes (reference implementation).
    """
    named_mods = dict(model.named_modules(remove_duplicate=False))

    # First compute the equalization scales and rescale the observers and
    # weights, then rewrite the graph to fold/remove the equalization observers.
    weight_observers = update_obs_for_equalization(model, named_mods)
    convert_eq_obs(model, named_mods, weight_observers)

    return GraphModule(model, model.graph)
740
+
741
+
742
+ ###############################################################################
743
+ # Functions for running the equalized model on the Numeric Suite #
744
+ ###############################################################################
745
+
746
def get_layer_sqnr_dict(model_a: nn.Module, model_b: nn.Module, x: torch.Tensor) -> Dict[str, float]:
    """Run the Numeric Suite on ``model_a`` and ``model_b`` and return a
    dictionary mapping layer names to the SQNR between the two models.

    Note: to support equalized models this deliberately refuses to match any
    ``torch.mul`` operators.  Equalized models contain extra mul operators that
    scale the input by the equalization scale, and this edge case has not been
    resolved yet within the Numeric Suite code.

    Args:
        model_a: A float model
        model_b: A quantized model
        x: Inputs to use during calibration
    """
    import torch.ao.ns._numeric_suite_fx as ns
    from torch.ao.ns.fx.mappings import get_unmatchable_types_map

    # Keep torch.mul nodes out of the layer matching process (see note above).
    unmatchable_types_map = get_unmatchable_types_map()
    unmatchable_types_map["funs_unmatchable"].add(torch.mul)

    model_a_ns, model_b_ns = ns.add_loggers(
        'fp32', model_a,
        'int8', model_b,
        ns.OutputLogger,
        unmatchable_types_map=unmatchable_types_map
    )

    # Run calibration data through both logged models so the loggers record
    # each layer's activations.
    model_a_ns(x)
    model_b_ns(x)

    # Pull the recorded activations out and annotate them with SQNR values.
    comparison = ns.extract_logger_info(
        model_a_ns,
        model_b_ns,
        ns.OutputLogger,
        'int8')
    ns.extend_logger_results_with_comparison(
        comparison,
        'fp32', 'int8',
        torch.ao.ns.fx.utils.compute_sqnr, 'sqnr'
    )

    # Map each layer's fully qualified name to its first recorded SQNR value.
    return {
        results['node_output']['int8'][0]['fqn']:
            results['node_output']['int8'][0]['sqnr'][0]
        for results in comparison.values()
    }
795
+
796
def get_equalization_qconfig_dict(
    layer_sqnr_dict: Dict[str, float],
    num_layers_to_equalize: int
) -> Any:
    """Given the layer-to-SQNR dictionary, find the layers with the highest
    quantization errors and return an ``equalization_qconfig_dict`` that
    specifies equalizing only those top layers.

    Args:
        layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
            when comparing an equalized model against a float model)
        num_layers_to_equalize: Number of layers with the highest quantization
            errors to equalize
    """
    # The lowest SQNR values correspond to the highest quantization errors, so
    # an ascending sort puts the layers that benefit most from equalization first.
    worst_layers = sorted(layer_sqnr_dict, key=layer_sqnr_dict.__getitem__)
    worst_layers = worst_layers[:num_layers_to_equalize]

    # Only the selected layers receive the default equalization qconfig; all
    # other modules are left untouched.
    module_to_qconfig_list = [
        (layer_name, default_equalization_qconfig) for layer_name in worst_layers
    ]
    return {"module_name": module_to_qconfig_list}
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (239 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-311.pyc ADDED
Binary file (66.2 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-311.pyc ADDED
Binary file (29.8 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/detector.py ADDED
@@ -0,0 +1,1539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Set, Tuple, Callable, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.ao.nn.qat as nnqat
6
+ from abc import ABC, abstractmethod
7
+ from torch.ao.quantization.fake_quantize import FakeQuantize
8
+ from torch.ao.quantization.fx.graph_module import GraphModule
9
+ from torch.ao.quantization.fx._model_report.model_report_observer import ModelReportObserver
10
+ from torch.ao.quantization.qconfig import (
11
+ QConfig,
12
+ default_qconfig,
13
+ _assert_valid_qconfig,
14
+ )
15
+ from torch.ao.quantization.observer import (
16
+ ObserverBase,
17
+ default_dynamic_quant_observer,
18
+ default_per_channel_weight_observer,
19
+ default_observer,
20
+ default_weight_observer,
21
+ )
22
+ from torch.ao.quantization.fx._equalize import (
23
+ default_equalization_qconfig,
24
+ EqualizationQConfig,
25
+ )
26
+ from torch.ao.quantization.observer import _is_activation_post_process
27
+
28
# Names for observer insert keys.
# These keys index the per-observer info dicts returned by
# DetectorBase.determine_observer_insert_points() (see DynamicStaticDetector's
# implementation for the exact shape of each entry).
DETECTOR_TARGET_NODE_KEY = "target_node"  # fx node the observer is attached to
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"  # observer instance to insert
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"  # True -> after target node, False -> before
DETECTOR_OBS_ARGS_KEY = "observer_args"  # arguments meant to be passed into the observer
33
+
34
+ # Mapping related code
35
class DetectorQConfigInfo:
    r"""
    Holds the QConfig information a detector has gathered for a single module.

    The set of fields can grow as the qconfig mapping feature set grows, but
    currently includes:
    - whether the activation observer should be dynamic
    - whether the weight observer should be per channel

    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that this
            information contains info relevant to qconfig for
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # Detector recommendations; detectors flip these from their defaults
        # when they have something to recommend for this module.
        self.is_activation_dynamic = False
        self.is_weight_per_channel = False

        # equalization related options
        self.is_equalization_recommended = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Args:
            module (torch.nn.Module) The module we are generating
            the qconfig for

        Returns the generated quantization QConfig according to what a valid configuration is
        """
        # Candidate (dynamic activation, per-channel weight) combinations, in
        # priority order: both recommendations together, then dynamic only,
        # then per-channel only.
        candidates = (
            (self.is_activation_dynamic, self.is_weight_per_channel),
            (self.is_activation_dynamic, False),
            (False, self.is_weight_per_channel),
        )

        # Fall back to the default qconfig if no candidate turns out valid.
        chosen_qconfig = default_qconfig
        for use_dynamic, use_per_channel in candidates:
            activation = default_dynamic_quant_observer if use_dynamic else default_observer
            weight = default_per_channel_weight_observer if use_per_channel else default_weight_observer
            candidate_qconfig = QConfig(activation, weight)
            try:
                _assert_valid_qconfig(candidate_qconfig, module)
            except AssertionError:
                # Not a valid configuration for this module; move on to the
                # next candidate in priority order.
                continue
            chosen_qconfig = candidate_qconfig
            break

        # return the QConfig chosen
        return chosen_qconfig

    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        Return the equalization configuration for a module.

        For now this always returns the default, but as more equalization
        options become possible this method can grow more nuanced granularity.

        Returns the generated equalization QConfig according to what a valid configuration is
        """
        # Only modules for which equalization is a valid option ever reach
        # this point, so the default config is always safe to return.
        return default_equalization_qconfig
111
+
112
+ # Adding base class for detectors
113
class DetectorBase(ABC):
    r""" Base Detector Module
    Any detector class should derive from this class.

    Concrete detectors should follow the same general API, which includes:
    - A method to calculate and return observer insertion points
        - Should return both the fqns and the Observer class to insert
    - A method to return a report based on the detector
        - Should return a str-based report and dict info in Tuple[str, Dict] format
    """

    def __init__(self):
        super().__init__()
        # placeholder for detector-specific qconfig information
        self.detector_config_info = None

    @abstractmethod
    def determine_observer_insert_points(self, model) -> Dict:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
        This dict maps string keys to detector specific information
        """
        pass

    @abstractmethod
    def get_detector_name(self) -> str:
        r""" Returns the name of the current detector """
        pass

    @abstractmethod
    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        pass

    def _get_targeting_node(self, prepared_fx_model: GraphModule, target_fqn: str) -> torch.fx.node.Node:
        r"""
        Find the graph node whose target matches ``target_fqn``.

        If the fqn is not found directly, the module was most likely fused:
        e.g. a module fqn like ``x.linear.0`` may only appear in the graph as a
        node targeting ``x.linear``.  We therefore repeatedly strip the last
        fqn component (``foo.bar.baz`` -> ``foo.bar``) and search again, until
        a match is found or there is nothing left to strip.

        Args:
            prepared_fx_model (GraphModule): The prepared Fx GraphModule
            target_fqn (str): The fqn of the layer we are trying to target

        Returns the node object we are trying to add observers around
        """
        fqn_to_try = target_fqn
        while True:
            # scan the graph for a node whose target is the current fqn
            for node in prepared_fx_model.graph.nodes:
                if node.target == fqn_to_try:
                    return node

            # no match at this level: drop the last fqn component and retry,
            # or give up once we are already at the base name
            last_sep = fqn_to_try.rfind(".")
            if last_sep == -1:
                raise ValueError("passed in target_fqn not found in graph's targets.")
            fqn_to_try = fqn_to_try[:last_sep]

    @abstractmethod
    def generate_detector_report(self, model) -> Tuple[str, Dict[str, Any]]:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Tuple of two elements:
            Str: string report of the suggested improvements
            Dict: contains useful data collected by the observer pertinent to this report
        """
        pass
202
+
203
class PerChannelDetector(DetectorBase):
    r""" This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
    Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

    per_channel quantization can lead to major benefits in the form of accuracy.
    Therefore, if the backend used by the user supports it, it is recommended to use

    Args:
        backend (str, optional): the backend the user wishes to use in production
            Default value is current torch.backends.quantized.engine
    """

    # Keys for return dictionary
    BACKEND_KEY = "backend"
    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
    PER_CHAN_USED_KEY = "per_channel_quantization_used"

    # Default map of supported per-channel-quantizable modules per backend.
    # Every currently supported backend allows per-channel quantization for the
    # same Linear/Conv modules (and their QAT counterparts); the set literal is
    # re-evaluated per backend so each entry gets its own set object.
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: Dict[str, Set[Any]] = {
        backend_name: {
            nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
            nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d,
        }
        for backend_name in ("fbgemm", "qnnpack", "onednn", "x86")
    }

    def __init__(self, backend: str = torch.backends.quantized.engine):
        super().__init__()

        # store the backend information
        self.backend_chosen = backend
        self.supported_modules = set()
        if self.backend_chosen not in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            raise ValueError(f"Not configured to work with {self.backend_chosen}. Try a different default backend")
        self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[self.backend_chosen]

    def get_detector_name(self) -> str:
        r""" returns the string name of this detector"""
        return "per_channel_detector"

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # gather the raw per-channel support/usage information first
        per_channel_info = self._detect_per_channel_helper(model)

        # translate each entry into a DetectorQConfigInfo recommendation
        module_fqn_to_detector_qconfig_info = {}
        for module_fqn, module_info in per_channel_info.items():
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)
            # recommend per-channel weights whenever the backend supports it
            detector_qconfig_info.is_weight_per_channel = module_info[self.PER_CHAN_SUPPORTED_KEY]
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, model: nn.Module) -> Dict:
        r"""
        There is no observers inserted for the PerChannelDetector.

        Returns an empty dictionary since no observers are added or needed
        """
        return {}

    def _detect_per_channel_helper(self, model: nn.Module):
        r"""
        Determine, per module, whether per_channel quantization is supported
        and whether it is currently being used.

        Args:
            model: The current module that is being checked to see if it is per_channel quantizable

        Returns dictionary mapping fqns to if per_channel quantization is possible
        """
        per_channel_info: Dict = {}

        for fqn, module in model.named_modules():
            # skip modules that the chosen backend cannot per-channel quantize
            if not any(isinstance(module, mod_type) for mod_type in self.supported_modules):
                continue

            # assert statement for MyPy
            q_config_file = module.qconfig
            assert isinstance(q_config_file, QConfig)

            # this object should either be fake quant or observer
            q_or_s_obj = module.qconfig.weight.p.func()
            assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))

            per_channel_used = False  # flipped if found in qconfig
            if hasattr(q_or_s_obj, "ch_axis"):  # then per_channel quantization may be in use
                if isinstance(q_or_s_obj, FakeQuantize):
                    # all fake quants have a channel axis, so is_per_channel decides
                    if getattr(q_or_s_obj, "is_per_channel", False):
                        per_channel_used = True
                elif isinstance(q_or_s_obj, ObserverBase):
                    # an observer with a channel axis is per-channel
                    per_channel_used = True
                else:
                    raise ValueError("Should be either observer or fake quant")

            per_channel_info[fqn] = {
                self.PER_CHAN_SUPPORTED_KEY: True,
                self.PER_CHAN_USED_KEY: per_channel_used,
                self.BACKEND_KEY: self.backend_chosen,
            }

        return per_channel_info

    def generate_detector_report(self, model: nn.Module) -> Tuple[str, Dict[str, Any]]:
        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

        Looks at q_config format and backend to determine if per_channel can be utilized.
        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support

        Args:
            model: The prepared and calibrated model we want to check if using per_channel

        Returns a tuple with two elements:
            String report of potential actions to improve model (if per_channel quantization is available in backend)
            Dictionary mapping per_channel quantizable elements to:
                whether per_channel quantization is supported by the backend
                if it is being utilized in the current model
        """
        per_channel_info = self._detect_per_channel_helper(model)

        # String to let the user know of further optimizations
        further_optims_str = f"Further Optimizations for backend {self.backend_chosen}: \n"

        # modules where per-channel is supported but not currently used
        candidate_fqns = [
            fqn for fqn, fqn_dict in per_channel_info.items()
            if fqn_dict[self.PER_CHAN_SUPPORTED_KEY] and not fqn_dict[self.PER_CHAN_USED_KEY]
        ]
        for fqn in candidate_fqns:
            further_optims_str += f"Module {fqn} can be configured to use per_channel quantization.\n"

        if candidate_fqns:
            further_optims_str += (
                "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
            )
        else:
            further_optims_str += "No further per_channel optimizations possible."

        # return the string and the dictionary form of same information
        return (further_optims_str, per_channel_info)
373
+
374
+
375
+ class DynamicStaticDetector(DetectorBase):
376
+ r"""
377
+ Determines whether dynamic or static quantization is more appropriate for a given module.
378
+
379
+ Takes advantage of the ModelReportObserver that records range information.
380
+ Stationary distribution of data are strictly above tolerance level for the comparison statistic:
381
+
382
+ S = average_batch_activation_range/epoch_activation_range
383
+
384
+ Nonstationary distributions are below or at the tolerance level for this metric.
385
+
386
+ If the distribution of data right after the module is non-stationary, recommend dynamic quantization
387
+ Otherwise recommend static quantization
388
+
389
+ Args:
390
+ tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5
391
+ """
392
+ # names for the pre and post observers that are inserted
393
+ DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
394
+ DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"
395
+
396
+ # naming conventions for stationary vs non-stationary data
397
+ STATIONARY_STR = "stationary"
398
+ NON_STATIONARY_STR = "non-stationary"
399
+
400
+ # naming for activation
401
+ INPUT_ACTIVATION_PREFIX = "input_activation_"
402
+ OUTPUT_ACTIVATION_PREFIX = "output_activation_"
403
+
404
+ # naming conventions for the keys of the return module info
405
+ TOLERANCE_KEY = "dynamic_static_tolerance"
406
+ DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
407
+ PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
408
+ POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
409
+ PRE_OBS_DATA_DIST_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
410
+ POST_OBS_DATA_DIST_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
411
+ IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"
412
+
413
+ # modules that are supported both dynamic and static for this report function
414
+ DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}
415
+
416
+ # modules that will be supported soon for both
417
+ DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}
418
+
419
+ def __init__(self, tolerance=0.5):
420
+ super().__init__()
421
+
422
+ # set tolerance level and initialize a set to keep track of useful fqn locations
423
+ self.tolerance = tolerance
424
+ self.useful_observer_fqns: Set[str] = set()
425
+
426
+ def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
427
+ r"""
428
+ Determines where observers need to be inserted for the Dynamic vs Static detector.
429
+ For this detector, we want to place observers on either side of linear layers in the model.
430
+
431
+ Currently inserts observers for:
432
+ linear layers
433
+
434
+ Args:
435
+ prepared_fx_model (GraphModule): The prepared Fx GraphModule
436
+
437
+ Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
438
+ key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
439
+ key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
440
+ key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
441
+ key "observer_args" -> The arguments that are meant to be passed into the observer
442
+ """
443
+
444
+ # observer for this detector is ModelReportObserver
445
+ obs_ctr = ModelReportObserver
446
+
447
+ # return dict
448
+ obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
449
+
450
+ for fqn, module in prepared_fx_model.named_modules():
451
+ # make sure module is supported
452
+ if self._is_supported(module, insert=True):
453
+ # if it's a supported type, we want to get node and add observer insert locations
454
+ targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
455
+
456
+ # add entry for pre-observer
457
+ pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
458
+
459
+ obs_fqn_to_info[pre_obs_fqn] = {
460
+ DETECTOR_TARGET_NODE_KEY: targeted_node,
461
+ DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
462
+ DETECTOR_IS_POST_OBS_KEY: False,
463
+ DETECTOR_OBS_ARGS_KEY: targeted_node.args
464
+ }
465
+
466
+ # add entry for post-observer
467
+ post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
468
+
469
+ obs_fqn_to_info[post_obs_fqn] = {
470
+ DETECTOR_TARGET_NODE_KEY: targeted_node,
471
+ DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
472
+ DETECTOR_IS_POST_OBS_KEY: True,
473
+ DETECTOR_OBS_ARGS_KEY: (targeted_node,)
474
+ }
475
+
476
+ return obs_fqn_to_info
477
+
478
+ def get_detector_name(self) -> str:
479
+ r""" returns the string name of this detector"""
480
+ return "dynamic_vs_static_detector"
481
+
482
+
483
+ def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
484
+ r""" Returns the DetectorQConfigInfo for each module_fqn relevant
485
+ Args
486
+ model (nn.Module or subclass): model to find observer insertion points
487
+
488
+ Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
489
+ A DetectorQConfigInfo with the information to generate a QConfig for a specific module
490
+ """
491
+ # run the helper function to populate the dictionary
492
+ dynamic_static_info = self._generate_dict_info(model)
493
+
494
+ # we actually have a qconfig info object we are populating
495
+ module_fqn_to_detector_qconfig_info = {}
496
+
497
+ for module_fqn in dynamic_static_info:
498
+ # create a detector info instance
499
+ detector_qconfig_info = DetectorQConfigInfo(module_fqn)
500
+
501
+ # see if per channel quantization is supported
502
+ dynamic_static_recommended: bool = dynamic_static_info[module_fqn][self.DEFAULT_DYNAMIC_REC_KEY]
503
+ detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
504
+ module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info
505
+
506
+ return module_fqn_to_detector_qconfig_info
507
+
508
+ def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
509
+ r"""Returns whether the given module is supported for observers
510
+
511
+ Args
512
+ module: The module to check and ensure is supported
513
+ insert: True if this is check for observer insertion, false if for report gen
514
+
515
+ Returns True if the module is supported by observer, False otherwise
516
+ """
517
+ # check to see if module is of a supported type
518
+ is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0
519
+
520
+ # check if it will be supported
521
+ future_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED]) > 0
522
+
523
+ # supported
524
+ supported = is_supported_type or future_supported_type
525
+
526
+ # this is check for observer insertion
527
+ if insert:
528
+ return supported
529
+ else:
530
+ # this is for report gen and we also need to check if it contains observers
531
+ has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(module, self.DEFAULT_POST_OBSERVER_NAME)
532
+ return supported and has_obs
533
+
534
    def _generate_dict_info(self, model: GraphModule) -> Dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
            whether dynamic quantization is recommended
            their S metric of input to module
            whether input to module is stationary or non-stationary
            their S metric of output of module
            whether output of module is stationary or non-stationary
            the tolerance level to decide whether input/output is stationary or non-stationary
            whether it is currently supported or planned for the future
        """
        # store modules dynamic vs static information
        module_dynamic_static_info = {}

        # This for loop goes through the modules, and extracts all relevant information into module_dynamic_static_info
        # This information primarily includes whether the data distributions around a supported module are stationary or not
        # Based on this, it is recorded whether dynamic or static quantization is recommended

        # loop through all submodules, including nested ones
        for fqn, module in model.named_modules():
            # only consider modules of a supported type that have the ModelReportObservers attached
            if self._is_supported(module):
                # get pre and post observers for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
                post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)

                # get the S statistic (avg batch range / epoch range) for each side of the module
                pre_stat = pre_obs.get_batch_to_epoch_ratio()
                post_stat = post_obs.get_batch_to_epoch_ratio()

                # record module, pre and post stat, and whether to do dynamic or static based off it
                # true if post observer data distribution is non-stationary, false if it's stationary
                dynamic_recommended = post_stat <= self.tolerance

                # specify the classifications for whether data distributions considered stationary or non-stationary
                pre_obs_dist_classif = self.STATIONARY_STR if pre_stat > self.tolerance else self.NON_STATIONARY_STR
                post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR

                # check if current support or future support (future-only types reach here via _is_supported)
                is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0

                # store the set of important information for this module
                module_info = {
                    self.TOLERANCE_KEY: self.tolerance,
                    self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
                    self.PRE_OBS_COMP_STAT_KEY: pre_stat,
                    self.PRE_OBS_DATA_DIST_KEY: pre_obs_dist_classif,
                    self.POST_OBS_COMP_STAT_KEY: post_stat,
                    self.POST_OBS_DATA_DIST_KEY: post_obs_dist_classif,
                    self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
                }

                module_dynamic_static_info[fqn] = module_info

        return module_dynamic_static_info
595
+
596
    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether dynamic or static quantization is more appropriate for a given module.

        Takes advantage of the ModelReportObserver that records range information.
        Stationary distribution of data are strictly above tolerance level for the comparison statistic:

            S = average_batch_activation_range/epoch_activation_range

        Nonstationary distributions are below or at the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
            Otherwise recommend static quantization

        This will then generate suggestions for dynamic vs static quantization focused around Linear.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decide whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """

        # get the dictionary of the information to format the string report
        module_dynamic_static_info = self._generate_dict_info(model)

        dynamic_vs_static_string = "Dynamic vs. Static Quantization suggestions: \n"

        modules_added: bool = False  # check to make sure at least 1 module added.

        dynamic_benefit = " You will get more accurate results if you use dynamic quantization"
        static_benefit = " You can increase model efficiency if you use static quantization"
        future_support_str = ". This layer is not yet supported for dynamic quantization"
        # This for loop goes through the information collected in module_dynamic_static_info and:
        #   Populates the string based report with the information from module_dynamic_static_info
        #   Compiles the complete report by appending relevant formatted strings

        for module_fqn in module_dynamic_static_info.keys():

            # there is at least 1 module for suggestion
            modules_added = True
            module_info = module_dynamic_static_info[module_fqn]
            suggestion_string_template = "For module {} it is suggested to use {} quantization because {}.\n"

            # decide what string formatting values will be
            quantization_type = ""
            quantization_reasoning = "the distribution of data before {} is {} and the distribution after is {}."

            benefit_str = ""

            # strings for if dynamic quantized per tensor is needed
            recommend_per_tensor = ". We recommend to add a {} before this module if it is static."
            rec_lay_to_add = "dynamic quantize per tensor layer"
            dynamic_per_tensor_string = recommend_per_tensor.format(rec_lay_to_add)
            dynamic_per_tensor_reasoning_string = (
                " This is because the input to this module has a non-stationary distribution"
            )

            # start composing explanation
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                # check if currently supported or future supported
                benefit_str = dynamic_benefit
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            # now set the quantization explanation string
            quantization_reasoning = (
                quantization_reasoning.format(
                    module_fqn, module_info[self.PRE_OBS_DATA_DIST_KEY], module_info[self.POST_OBS_DATA_DIST_KEY]
                )
                + benefit_str
            )

            # if we have a non-stationary input -> linear -> stationary we suggested static
            # however, we want to also recommend they add a dynamic quantize per tensor right if this change is made
            if (
                module_info[self.PRE_OBS_DATA_DIST_KEY] == self.NON_STATIONARY_STR
                and module_info[self.POST_OBS_DATA_DIST_KEY] == self.STATIONARY_STR
            ):
                quantization_reasoning = (
                    quantization_reasoning + dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
                )

            # format the overall suggestion string with the specific inputs
            module_suggestion_string = suggestion_string_template.format(
                module_fqn, quantization_type, quantization_reasoning
            )

            # append to overall suggestion
            dynamic_vs_static_string += module_suggestion_string

        if not modules_added:
            dynamic_vs_static_string += "No applicable layers for suggestions. Only linear and conv are valid.\n"

        # return the string as well as the dictionary of information
        return (dynamic_vs_static_string, module_dynamic_static_info)
704
+
705
+
706
class InputWeightEqualizationDetector(DetectorBase):
    r"""
    Determines whether input-weight equalization can help improve quantization for certain modules.

    Specifically, this list of modules includes:
        linear
        conv

    Determines whether input-weight equalization is recommended based on the comp stat:
        s_c = sqrt(w_c/W)/sqrt(i_c/I)
        where:
            w_c is range of weight for channel c, W is range of weight over all channels
            i_c is range of input for channel c, I is range of input over all channels

        if s_c >= threshold or <= 1 / threshold, recommends input-weight equalization

    Args:
        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
            Should be between 0 and 1 (both non-inclusive)
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization

    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    # modules (float and QAT variants) this detector knows how to analyze
    SUPPORTED_MODULES: Set[Callable] = {nn.Linear,
                                        nn.Conv1d,
                                        nn.Conv2d,
                                        nn.Conv3d,
                                        nnqat.Linear,
                                        nnqat.Conv1d,
                                        nnqat.Conv2d,
                                        nnqat.Conv3d}

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # weight / activation prefix for each of the below info
    WEIGHT_PREFIX = "weight_"
    ACTIVATION_PREFIX = "input_activation_"

    # string names for keys of info dictionaries
    PER_CHANNEL_MAX_KEY = "per_channel_max"
    PER_CHANNEL_MIN_KEY = "per_channel_min"
    GLOBAL_MAX_KEY = "global_max"
    GLOBAL_MIN_KEY = "global_min"

    # keys for return dict of recommendations
    RECOMMENDED_KEY = "input_weight_equalization_recommended"
    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
    THRESHOLD_KEY = "input_weight_threshold"
    CHANNEL_KEY = "input_weight_channel_axis"

    # default weight and info strings
    WEIGHT_STR = "weight"
    INPUT_STR = "input"

    # default for what fraction of channels must benefit before we recommend input-weight equalization
    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4

    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
        # ensure passed in inputs are valid
        if ratio_threshold <= 0 or ratio_threshold >= 1:
            raise ValueError("Make sure threshold is > 0 and < 1")

        # initialize attributes based on args
        self.ratio_threshold: float = ratio_threshold
        self.ch_axis: int = ch_axis

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # check to see if module is of a supported type
        # note: exact type match (not isinstance) so subclasses are deliberately excluded
        is_supported_type = sum([type(module) is x for x in self.SUPPORTED_MODULES]) > 0

        # this is check for observer insertion
        if insert:
            return is_supported_type
        else:
            # this is for report gen and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
            return is_supported_type and has_obs

    def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
        r""" Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn in input_weight_equalization_info:
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # see if input-weight equalization is recommended
            # NOTE(review): RECOMMENDED_KEY maps to a per-channel list of bools
            # (see _generate_dict_info), not a single bool — confirm downstream
            # consumers of is_equalization_recommended handle a list
            input_weight_recommended: list = input_weight_equalization_info[module_fqn][self.RECOMMENDED_KEY]
            detector_qconfig_info.is_equalization_recommended = input_weight_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            linear layers
            conv layers

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer only; this detector needs no post-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "input_weight_equalization_detector"

    def _extract_input_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the input information for each observer returns it

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns (str) to a dict with keys:
            "input_activation_per_channel_max" : maps to the per_channel max values
            "input_activation_per_channel_min" : maps to the per_channel min values
            "input_activation_global_max" : maps to the global max recorded
            "input_activation_global_min" : maps to the global min recorded
        """

        # return dictionary mapping observer fqns to desired info
        input_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # get pre observer for the module
                pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

                # global max/min are reduced from the observer's per-channel extrema
                input_info[fqn] = {
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: pre_obs.max_val,
                    self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: pre_obs.min_val,
                    self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: max(pre_obs.max_val),
                    self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: min(pre_obs.min_val),
                }

        return input_info

    def _extract_weight_info(self, model: GraphModule) -> Dict[str, Dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "per_channel_max" : maps to the per_channel max values
            "per_channel_min" : maps to the per_channel min values
            "global_max" : maps to the global max recorded
            "global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: Dict[str, Dict] = {}

        for fqn, module in model.named_modules():
            # if module is supported and it has a pre-observer
            if self._is_supported(module):
                # we don't need actual observer, just the module weights
                # calculate min and max vals
                device = module.weight.device
                min_val: torch.Tensor = torch.tensor([float('inf')], device=device)
                max_val: torch.Tensor = torch.tensor([float('-inf')], device=device)
                x_copy = module.weight
                x_dim = x_copy.size()

                # move ch_axis to the front so we can flatten all other dims per channel
                new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
                new_axis_list[self.ch_axis] = 0
                new_axis_list[0] = self.ch_axis
                y = x_copy.permute(new_axis_list)

                # Need to match dtype of min/max because the updates to buffers
                # are done in place and types need to match for comparisons
                y = y.to(min_val.dtype)
                y = torch.flatten(y, start_dim=1)
                # NOTE(review): min_val/max_val are initialized with one element,
                # so numel() == 0 never holds here and the else branch always runs;
                # harmless since min/max against +/-inf is the identity — confirm intent
                if min_val.numel() == 0 or max_val.numel() == 0:
                    min_val, max_val = torch.aminmax(y, dim=1)
                else:
                    min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
                    min_val = torch.min(min_val_cur, min_val)
                    max_val = torch.max(max_val_cur, max_val)

                weight_info[fqn] = {
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
                    self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
                    self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max(max_val),
                    self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min(min_val),
                }

        return weight_info

    def _calculate_range_ratio(self, info_dict: Dict, info_str: str, module_fqn: str) -> torch.Tensor:
        r"""
        Takes in an info dict and calculates the per-channel to global range ratio.

        Args:
            info_dict (dict): A dictionary of either input or weight range info
            info_str (str): A str describing whether currently looking at weight or input info
                Either "weight" or "input"
            module_fqn (str): The fqn of the module we are looking at

        Returns a tensor of values, where each value is the range ratio for a different channel

        Raises:
            ValueError: if the global range is 0 (constant-valued data)
        """
        # calculate the ratios of the info
        # get the prefix str
        prefix_str = self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX

        per_channel_range = info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY] - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
        global_range = info_dict[prefix_str + self.GLOBAL_MAX_KEY] - info_dict[prefix_str + self.GLOBAL_MIN_KEY]

        # a zero global range means every value is identical; the ratio would be undefined
        if global_range == 0:
            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
            raise ValueError(
                "The range of the {} data for module {} is 0, which means you have a constant value channel. {}".format(
                    info_str, module_fqn, range_zero_explanation
                )
            )

        ratio = per_channel_range / global_range

        return ratio

    def _generate_comparison_values(self, input_info: Dict, weight_info: Dict) -> Dict[str, torch.Tensor]:
        r"""
        Takes in the information on the min and max values of the inputs and weights and:
            Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)

        Args:
            input_info (dict): A dict mapping each observer to input range information
            weight_info (dict): A dict mapping each observer to weight range information

        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
            Each value is a different s_c value for a different channel

        Raises:
            KeyError: if a module present in input_info is missing from weight_info
        """
        # create return dictionary for each observer
        module_fqn_to_channel: Dict[str, torch.Tensor] = {}

        # for each module (both passed in dicts should have same keys)
        for module_fqn in input_info:

            # raise error if not in weight info
            if module_fqn not in weight_info:
                raise KeyError(f"Unable to find weight range stats for module {module_fqn}")

            # calculate the ratios of the weight info and input info
            weight_ratio = self._calculate_range_ratio(weight_info[module_fqn], self.WEIGHT_STR, module_fqn)
            input_ratio = self._calculate_range_ratio(input_info[module_fqn], self.INPUT_STR, module_fqn)

            # if mismatched size, because of grouping, we want to replicate weight enough times
            weight_channels = len(weight_ratio)
            input_channels = len(input_ratio)
            if weight_channels != input_channels:
                # we try to replicate
                assert input_channels % weight_channels == 0, "input channels should be divisible by weight channels."
                # get replication factor
                rep_factor: int = input_channels // weight_channels

                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n
                weight_ratio = weight_ratio.repeat(rep_factor)

            # calculate the s metric per channel
            s = torch.sqrt(weight_ratio) / torch.sqrt(input_ratio)
            module_fqn_to_channel[module_fqn] = s

        # return compiled observer ratios
        return module_fqn_to_channel

    def _generate_dict_info(self, input_info: Dict, weight_info: Dict, comp_stats: Dict) -> Dict[str, Dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            input_info (dict): A dict mapping each module to input range information
            weight_info (dict): A dict mapping each module to weight range information
            comp_stats (dict): A dict mapping each module to its corresponding comp stat

        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
            whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
        """
        # store modules input weight equalization info
        input_weight_equalization_info: Dict[str, Dict] = {}

        # for each module we add separate set of suggestions
        for module_fqn in input_info:

            # get relevant info for this module
            mod_input_info: Dict = input_info[module_fqn]
            mod_weight_info: Dict = weight_info[module_fqn]
            mod_comp_stat: torch.Tensor = comp_stats[module_fqn]

            # decide if each channel should have input weight equalization or not
            channel_rec_vals: list = []

            for val in mod_comp_stat:
                float_rep: float = val.item()

                # recommend equalization when s_c lies within [threshold, 1/threshold]
                recommended: bool = float_rep >= self.ratio_threshold and float_rep <= 1 / self.ratio_threshold
                channel_rec_vals.append(recommended)

            # build the return dict input
            # also unpack input and weight dicts into it
            input_weight_equalization_info[module_fqn] = {
                self.RECOMMENDED_KEY: channel_rec_vals,
                self.COMP_METRIC_KEY: mod_comp_stat,
                self.THRESHOLD_KEY: self.ratio_threshold,
                self.CHANNEL_KEY: self.ch_axis,
                **mod_input_info,
                **mod_weight_info,
            }

        # return our compiled info for each module
        return input_weight_equalization_info

    def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReport Observer which records per channel information of input range
        It then uses the passed in weight info in conjunction to compute the desired ratio
        Finally, it gives suggestions based on this information for each module of interest

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether input weight equalization is recommended for certain modules
            Dictionary mapping modules of interest to:
                whether input weight equalization is recommended
                their s_c metric compared to the threshold
                the threshold used to make the recommendation
                the channel used for recording data
                the input channel range info
                the weight channel range info
        """

        # find the range of inputs
        input_values: Dict[str, Dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: Dict[str, Dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: Dict[str, torch.Tensor] = self._generate_comparison_values(input_values, weight_values)

        # generate the return dictionary
        input_weight_equalization_info: Dict[str, Dict] = self._generate_dict_info(input_values, weight_values, comp_stats)

        # now we can generate report based on this information
        input_weight_string = "Input-Weight Equalization suggestions: \n"

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tWe suggest {} input weight equalization because {}\n"
        use_str = "to use"
        no_use_str = "to not use"
        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
        input_weight_non_benefit_reasoning = "{}/{} channels benefitting from input-weight equalization being applied."
        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"

        # added module check
        added_module: bool = False

        # compile the suggestion string
        for module_fqn in input_weight_equalization_info:
            # we added at least 1 module
            added_module = True
            # add the module level description
            input_weight_string += module_suggestion_str.format(module_fqn, self.ch_axis)

            mod_info: Dict[str, Any] = input_weight_equalization_info[module_fqn]

            # gather info on how many channels would benefit from input weight equalization
            # note: RECOMMENDED_KEY holds the per-channel list of bools built in _generate_dict_info
            recommendation_per_channel: list = mod_info[self.RECOMMENDED_KEY]
            num_recs = sum(recommendation_per_channel)

            # recommend only when enough channels benefit (default 40% of channels)
            if num_recs / len(recommendation_per_channel) >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO:
                input_benefit_formatted = input_weight_benefit_str.format(num_recs, len(recommendation_per_channel))
                channel_str = channel_suggestion_str.format(use_str, input_benefit_formatted)
                input_weight_string += channel_str
            else:
                non_benefit_reason_formatted = input_weight_non_benefit_reasoning.format(num_recs, len(recommendation_per_channel))
                non_benefit_str = input_weight_non_benefit_str.format(non_benefit_reason_formatted)
                channel_str = channel_suggestion_str.format(no_use_str, non_benefit_str)
                input_weight_string += channel_str

        # if no modules looked at, amend return string
        if not added_module:
            input_weight_string += "No applicable layers for suggestions. Only linear and conv valid.\n"

        # return a tuple with the string explanation and the compiled dict info
        return (input_weight_string, input_weight_equalization_info)
1176
+
1177
+
1178
+ class OutlierDetector(DetectorBase):
1179
+ r"""
1180
+ Determines whether there are significant outliers in activation data around a certain layer.
1181
+
1182
+ This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
1183
+ If the data is stationary, and there are significant outliers, then we want to flag them
1184
+ We want to do this on a per channel basis for detecting outliers
1185
+
1186
+ Determines whether activation data is flagged as outlier based on if data is stationary and:
1187
+ p_r = avg(100th percentile / "reference_percentile"th percentile)
1188
+ where:
1189
+ p_r is average percentile ratio across all batches in the epoch
1190
+ reference_percentile is a percentile values between 0 and 100 exclusive
1191
+
1192
+ if p_r is above some threshold, then we consider the activations to have significant outliers
1193
+
1194
+ Args:
1195
+ ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
1196
+ Should be >= 1
1197
+ Default: 3.5
1198
+ reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
1199
+ Should be between 0 and 1
1200
+ Default: 0.975
1201
+ fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
1202
+ If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
1203
+ regardless of whether we detected outliers or not in channel to take a closer look at channel results
1204
+ Should be between 0 and 1
1205
+ Default: 0.95
1206
+ ch_axis (int, optional): The channel axis being observed to determine input weight equalization
1207
+ Default: 1
1208
+
1209
+ * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
1210
+ The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
1211
+ If it is significantly greater, then we consider it an outlier
1212
+ This threshold was calculated based on the ratio of the percentiles in a normal distribution
1213
+ The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
1214
+
1215
+ * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
1216
+ Should be between 0 and 1
1217
+ The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing
1218
+
1219
+ * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
1220
+ Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
1221
+ Should be between 0 and 1
1222
+
1223
+ * :attr:`ch_axis`: The channel axis being observed to determine outliers
1224
+
1225
+ * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
1226
+ """
1227
+
1228
+ # names for the pre observers that are inserted
1229
+ DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"
1230
+
1231
+ # pre activation prefix
1232
+ INPUT_ACTIVATION_PREFIX = "input_activation_"
1233
+
1234
+ # names for dict keys
1235
+ OUTLIER_KEY = "outliers_detected"
1236
+ NUM_BATCHES_KEY = "outlier_detection_batches_used"
1237
+ IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
1238
+ COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
1239
+ RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
1240
+ REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
1241
+ CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
1242
+ MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
1243
+ CONSTANT_COUNTS_KEY = "constant_batch_counts"
1244
+
1245
+ def __init__(
1246
+ self,
1247
+ ratio_threshold: float = 3.5,
1248
+ reference_percentile: float = 0.975,
1249
+ fraction_batches_used_threshold: float = 0.95,
1250
+ ch_axis: int = 1,
1251
+ ):
1252
+ # initialize the variables of interest
1253
+ self.ratio_threshold = ratio_threshold
1254
+
1255
+ # make sure passed in percentile is valid
1256
+ assert reference_percentile >= 0 and reference_percentile <= 1
1257
+ assert fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1
1258
+ self.reference_percentile = reference_percentile
1259
+ self.fraction_batches_used_threshold = fraction_batches_used_threshold
1260
+ self.ch_axis = ch_axis
1261
+
1262
+ def get_detector_name(self) -> str:
1263
+ r"""Returns the name of this detector"""
1264
+ return "outlier_detector"
1265
+
1266
+ def _supports_insertion(self, module: nn.Module) -> bool:
1267
+ r"""Returns whether the given module is supported for observers insertion
1268
+
1269
+ Any module that doesn't have children and isn't an observer itself is supported
1270
+
1271
+ Args
1272
+ module: The module to check and ensure is supported
1273
+
1274
+ Returns True if the module is supported by observer, False otherwise
1275
+ """
1276
+ # case for insertion of module
1277
+ # check if the module has any children and isn't observer
1278
+ num_children = len(list(module.children()))
1279
+ return num_children == 0 and not _is_activation_post_process(module)
1280
+
1281
+ def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]:
1282
+ r""" Returns the DetectorQConfigInfo for each module_fqn relevant
1283
+ Args
1284
+ model (nn.Module or subclass): model to find observer insertion points
1285
+
1286
+ Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
1287
+ A DetectorQConfigInfo with the information to generate a QConfig for a specific module
1288
+ """
1289
+ # currently doesn't do anything for outlier detector
1290
+ return {}
1291
+
1292
+ def _supports_report_gen(self, module: nn.Module) -> bool:
1293
+ r"""Returns whether the given module is supported for report generation
1294
+
1295
+ Any module that has a model report pre-observer is supported
1296
+
1297
+ Args
1298
+ module: The module to check and ensure is supported
1299
+
1300
+ Returns True if the module is supported by observer, False otherwise
1301
+ """
1302
+ return hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
1303
+
1304
+ def determine_observer_insert_points(self, prepared_fx_model: GraphModule) -> Dict[str, Dict[str, Any]]:
1305
+ r""" Determines where observers need to be inserted for the Outlier Detector.
1306
+
1307
+ For this detector, we want to place observers in front of supported layers.
1308
+
1309
+ Currently inserts observers for:
1310
+ all layers that do not have children (leaf level layers)
1311
+
1312
+ Args:
1313
+ prepared_fx_model (GraphModule): The prepared Fx GraphModule
1314
+
1315
+ Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
1316
+ key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
1317
+ key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
1318
+ key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
1319
+ key "observer_args" -> The arguments that are meant to be passed into the observer
1320
+ """
1321
+ # observer for this detector is ModelReportObserver
1322
+ obs_ctr = ModelReportObserver
1323
+
1324
+ # return dict
1325
+ obs_fqn_to_info: Dict[str, Dict[str, Any]] = {}
1326
+
1327
+ for fqn, module in prepared_fx_model.named_modules():
1328
+ # check to see if module is of a supported type
1329
+ if self._supports_insertion(module):
1330
+ # if it's a supported type, we want to get node and add observer insert locations
1331
+ targeted_node = self._get_targeting_node(prepared_fx_model, fqn)
1332
+
1333
+ # add entry for pre-observer
1334
+ pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
1335
+
1336
+ obs_fqn_to_info[pre_obs_fqn] = {
1337
+ DETECTOR_TARGET_NODE_KEY: targeted_node,
1338
+ DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis, comp_percentile=self.reference_percentile),
1339
+ DETECTOR_IS_POST_OBS_KEY: False,
1340
+ DETECTOR_OBS_ARGS_KEY: targeted_node.args,
1341
+ }
1342
+
1343
+ return obs_fqn_to_info
1344
+
1345
+ def _calculate_outlier_info(
1346
+ self,
1347
+ percentile_ratios: torch.Tensor,
1348
+ counted_batches: torch.Tensor,
1349
+ total_batches: int,
1350
+ ) -> Dict[str, List[bool]]:
1351
+ r"""
1352
+ Gives info on whether the percentile ratios calculated would be considered outliers
1353
+ Also gives information on whether the collected data is statistically significant to make this claim
1354
+
1355
+ Args:
1356
+ percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
1357
+ counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
1358
+ total_batches (int): The total number of batches that passed through observer in this epoch
1359
+
1360
+ Returns a dictionary mapping:
1361
+ "outliers_detected" : list of bools per channel that are true if it is considered an outlier
1362
+ "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
1363
+ where o_r = counted_batches / total_batches
1364
+ """
1365
+ outlier_dict: Dict[str, List[bool]] = {self.OUTLIER_KEY: [], self.IS_SUFFICIENT_BATCHES_KEY: []}
1366
+
1367
+ # get both as flattened lists for easy mapping
1368
+ ratios_list: List = percentile_ratios.tolist()
1369
+ num_batches_list: List = counted_batches.tolist()
1370
+
1371
+ # calculate whether channels were statistically significant
1372
+ significant_size = [
1373
+ batch_size / total_batches >= self.fraction_batches_used_threshold for batch_size in num_batches_list
1374
+ ]
1375
+ outlier_dict[self.IS_SUFFICIENT_BATCHES_KEY] = significant_size
1376
+
1377
+ # calculate for each channel whether it's an outlier or not based on ratio
1378
+ outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]
1379
+ outlier_dict[self.OUTLIER_KEY] = outlier_detected
1380
+
1381
+ # return the dictionary with the two lists
1382
+ return outlier_dict
1383
+
1384
+ def _generate_info_dict(self, model: GraphModule) -> Dict[str, Dict]:
1385
+ r"""
1386
+ Helper function for generate_detector_report that does the generation of the dictionary.
1387
+ This process is done as specified in generate_detector_report documentation
1388
+
1389
+ Args:
1390
+ model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
1391
+
1392
+ Returns a dict mapping relevant module fqns to:
1393
+ whether there were outliers found in activation before
1394
+ the number of batches used for each channel
1395
+ whether fraction of applicable batches used is above fraction_batches_used_threshold
1396
+ their p_r metric compared to the threshold
1397
+ the threshold used to make the recommendation
1398
+ the reference_percentile used to make the recommendation
1399
+ the channel axis used to determine individual channels
1400
+ the constant batch counts per channel
1401
+ the per channel max values
1402
+ """
1403
+ # return dictionary mapping observer fqns to desired info
1404
+ info_dict: Dict[str, Dict] = {}
1405
+
1406
+ for fqn, module in model.named_modules():
1407
+ # if module is supported and it has a pre-observer
1408
+ if self._supports_report_gen(module):
1409
+ # get pre observer for the module
1410
+ pre_obs: ModelReportObserver = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
1411
+
1412
+ # get the number of batches and calculated ratio thresholds
1413
+ num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
1414
+ average_ratios: torch.Tensor = pre_obs.average_percentile_ratio
1415
+ channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
1416
+ total_batches: int = pre_obs.num_batches_tracked
1417
+
1418
+ # also get the max values
1419
+ max_vals: torch.Tensor = pre_obs.max_val
1420
+
1421
+ # we have to specifically modify how we are recording negative ratio for pre-relu layers
1422
+ for index, ratio_val in enumerate(average_ratios):
1423
+ # check if we have a negative ratio
1424
+ # a ratio might be negative if we have a situation where the 100th percentile is
1425
+ # > 0 while the nth percentile is < 0, in which case this would not be detected
1426
+ # as an outlier. Since we care more about magnitude, we make it positive.
1427
+ if ratio_val.item() < 0:
1428
+ # first make it positive
1429
+ average_ratios[index] = -ratio_val
1430
+
1431
+ if ratio_val.item() < 1:
1432
+ # if it's less than 1 we have the flip it as well
1433
+ average_ratios[index] = 1 / ratio_val
1434
+
1435
+ outlier_calcs = self._calculate_outlier_info(average_ratios, num_batches, total_batches)
1436
+
1437
+ # calculate whether ratios were outliers
1438
+ info_dict[fqn] = {
1439
+ self.CHANNEL_AXIS_KEY: self.ch_axis,
1440
+ self.REF_PERCENTILE_KEY: self.reference_percentile,
1441
+ self.RATIO_THRES_KEY: self.ratio_threshold,
1442
+ self.COMP_METRIC_KEY: average_ratios,
1443
+ self.NUM_BATCHES_KEY: num_batches,
1444
+ self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
1445
+ self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[self.IS_SUFFICIENT_BATCHES_KEY],
1446
+ self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
1447
+ self.MAX_VALS_KEY: max_vals
1448
+ }
1449
+
1450
+ return info_dict
1451
+
1452
+ def generate_detector_report(self, model: GraphModule) -> Tuple[str, Dict[str, Any]]:
1453
+ r"""
1454
+ Determines whether input weight equalization is appropriate for a given module.
1455
+
1456
+ Takes advantage of the ModelReport Observer which records the relevant percentile information
1457
+
1458
+ Args:
1459
+ model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers
1460
+
1461
+ Returns a tuple with two elements:
1462
+ String report of of whether there are outliers in the activations around certain modules
1463
+ Dictionary mapping modules of interest to:
1464
+ whether there were outliers found in activation before
1465
+ the number of batches used for each channel
1466
+ whether fraction of applicable batches used is above fraction_batches_used_threshold
1467
+ their p_r metric compared to the threshold
1468
+ the threshold used to make the recommendation
1469
+ the reference_percentile used to make the recommendation
1470
+ the channel axis used to determine individual channels
1471
+ the constant batch counts per channel
1472
+ the per channel max values
1473
+ """
1474
+ # generate the information dictionary of outlier information
1475
+ info_dict = self._generate_info_dict(model)
1476
+
1477
+ # now we can generate report based on this information
1478
+ outlier_string = "Outlier detection report: \n"
1479
+
1480
+ # added module check
1481
+ added_module: bool = False
1482
+
1483
+ # some strings to be formatted depending on module we are adding
1484
+ module_suggestion_str = "For Module {} looked at with axis {}: \n"
1485
+ channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
1486
+ channel_max_value_str = "a max value across all batches of {}"
1487
+ note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
1488
+ note_distribution = "stationary distributions"
1489
+ note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"
1490
+
1491
+ # suggestion for constant batch check since that can make it no outliers
1492
+ constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
1493
+ constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."
1494
+
1495
+ # compile the suggestion string
1496
+ for module_fqn in info_dict:
1497
+ # get module specific info
1498
+ mod_info: Dict[str, Any] = info_dict[module_fqn]
1499
+ # check to see if we already added high level model desc
1500
+ added_model_desc = False
1501
+ # look at each individual channel and add a suggestion
1502
+ for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
1503
+ if outlier_detected:
1504
+ # we found at least 1 outlier
1505
+ if not added_model_desc:
1506
+ # add the module level description
1507
+ outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
1508
+ added_model_desc = True
1509
+
1510
+ # we mark that we found at least one outlier
1511
+ added_module = True
1512
+ max_value_found_str = channel_max_value_str.format(mod_info[self.MAX_VALS_KEY][index])
1513
+ channel_str = channel_suggestion_str.format(index, max_value_found_str)
1514
+ outlier_string += channel_str
1515
+
1516
+ # also check if we found constant batch
1517
+ if mod_info[self.CONSTANT_COUNTS_KEY][index] != 0:
1518
+ # make sure we add a module level highlight.
1519
+ if not added_model_desc:
1520
+ # add the module level description
1521
+ outlier_string += module_suggestion_str.format(module_fqn, self.ch_axis)
1522
+ added_model_desc = True
1523
+
1524
+ constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index]
1525
+ formatted_str = constant_str.format(index, constant_values_for_channel, constant_suggestion)
1526
+ outlier_string += formatted_str
1527
+ # we also added at least one thing to description
1528
+ added_module = True
1529
+
1530
+
1531
+ # if found outlier, give suggestion, else give default response
1532
+ if added_module:
1533
+ # compose the note string
1534
+ note_composed = note_string.format(note_distribution, note_rec)
1535
+ outlier_string += note_composed
1536
+ else:
1537
+ outlier_string += "There were no outliers found in the activations.\n"
1538
+
1539
+ return (outlier_string, info_dict)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report_visualizer.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Any, Set, Dict, List, Tuple, OrderedDict
3
+ from collections import OrderedDict as OrdDict
4
+
5
+ # try to import tablate
6
+ got_tabulate = True
7
+ try:
8
+ from tabulate import tabulate
9
+ except ImportError:
10
+ got_tabulate = False
11
+
12
+
13
+ # var to see if we could import matplotlib
14
+ got_matplotlib = True
15
+ try:
16
+ import matplotlib.pyplot as plt
17
+ except ImportError:
18
+ got_matplotlib = False
19
+
20
+ class ModelReportVisualizer:
21
+ r"""
22
+ The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics
23
+ that were generated by the ModelReport API. However, at a higher level, the class aims to provide
24
+ some level of visualization of statistics to PyTorch in order to make it easier to parse data and
25
+ diagnose any potential issues with data or a specific model. With respect to the visualizations,
26
+ the ModelReportVisualizer class currently supports several methods of visualizing data.
27
+
28
+ Supported Visualization Methods Include:
29
+ - Table format
30
+ - Plot format (line graph)
31
+ - Histogram format
32
+
33
+ For all of the existing visualization methods, there is the option to filter data based on:
34
+ - A module fqn prefix
35
+ - Feature [required for the plot and histogram]
36
+
37
+ * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
38
+ Ensure sure that features that are the same across different report contain the same name
39
+ Ensure that objects representing the same features are the same type / dimension (where applicable)
40
+
41
+ Note:
42
+ Currently, the ModelReportVisualizer class supports visualization of data generated by the
43
+ ModelReport class. However, this structure is extensible and should allow the visualization of
44
+ other information as long as the information is structured in the following general format:
45
+
46
+ Report Structure
47
+ -- module_fqn [module with attached detectors]
48
+ |
49
+ -- feature keys [not every detector extracts same information]
50
+ [same collected info has same keys, unless can be specific to detector]
51
+
52
+
53
+ The goal behind the class is that the generated visualizations can be used in conjunction with the generated
54
+ report for people to get a better understanding of issues and what the fix might be. It is also just to provide
55
+ a good visualization platform, since it might be hard to parse through the ModelReport returned dictionary as
56
+ that grows in size.
57
+
58
+ General Use Flow Expected
59
+ 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
60
+ 2.) Prepare your model with prepare_fx
61
+ 3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
62
+ 4.) Callibrate your model with data
63
+ 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
64
+ 6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance
65
+ 7.) Use instance to view different views of data as desired, applying filters as needed
66
+ 8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
67
+
68
+ """
69
+
70
+ # keys for table dict
71
+ TABLE_TENSOR_KEY = "tensor_level_info"
72
+ TABLE_CHANNEL_KEY = "channel_level_info"
73
+
74
+ # Constants for header vals
75
+ NUM_NON_FEATURE_TENSOR_HEADERS = 2
76
+ NUM_NON_FEATURE_CHANNEL_HEADERS = 3
77
+
78
+ # Constants for row index in header
79
+ CHANNEL_NUM_INDEX = 2
80
+
81
+ def __init__(self, generated_reports: OrderedDict[str, Any]):
82
+ r"""
83
+ Initializes the ModelReportVisualizer instance with the necessary reports.
84
+
85
+ Args:
86
+ generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
87
+ can also be a dictionary generated in another manner, as long as format is same
88
+ """
89
+ self.generated_reports = generated_reports
90
+
91
+ def get_all_unique_module_fqns(self) -> Set[str]:
92
+ r"""
93
+ The purpose of this method is to provide a user the set of all module_fqns so that if
94
+ they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
95
+ they don't need to manually parse the generated_reports dictionary to get this information.
96
+
97
+ Returns all the unique module fqns present in the reports the ModelReportVisualizer
98
+ instance was initialized with.
99
+ """
100
+ # returns the keys of the ordered dict
101
+ return set(self.generated_reports.keys())
102
+
103
+ def get_all_unique_feature_names(self, plottable_features_only: bool = True) -> Set[str]:
104
+ r"""
105
+ The purpose of this method is to provide a user the set of all feature names so that if
106
+ they wish to use the filtering capabilities of the generate_table_view(), or use either of
107
+ the generate_plot_view() or generate_histogram_view(), they don't need to manually parse
108
+ the generated_reports dictionary to get this information.
109
+
110
+ Args:
111
+ plottable_features_only (bool): True if the user is only looking for plottable features,
112
+ False otherwise
113
+ plottable features are those that are tensor values
114
+ Default: True (only return those feature names that are plottable)
115
+
116
+ Returns all the unique module fqns present in the reports the ModelReportVisualizer
117
+ instance was initialized with.
118
+ """
119
+ unique_feature_names = set()
120
+ for module_fqn in self.generated_reports:
121
+ # get dict of the features
122
+ feature_dict: Dict[str, Any] = self.generated_reports[module_fqn]
123
+
124
+ # loop through features
125
+ for feature_name in feature_dict:
126
+ # if we need plottable, ensure type of val is tensor
127
+ if not plottable_features_only or type(feature_dict[feature_name]) == torch.Tensor:
128
+ unique_feature_names.add(feature_name)
129
+
130
+ # return our compiled set of unique feature names
131
+ return unique_feature_names
132
+
133
+ def _get_filtered_data(self, feature_filter: str, module_fqn_filter: str) -> OrderedDict[str, Any]:
134
+ r"""
135
+ Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.
136
+
137
+ Args:
138
+ feature_filter (str): The feature filter, if we want to filter the set of data to only include
139
+ a certain set of features that include feature_filter
140
+ If feature = "", then we do not filter based on any features
141
+ module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
142
+ this prefix will be included
143
+ If module_fqn_filter = "" we do not filter based on module fqn, and include all modules
144
+
145
+ First, the data is filtered based on module_fqn, and then filtered based on feature
146
+ Returns an OrderedDict (sorted in order of model) mapping:
147
+ module_fqns -> feature_names -> values
148
+ """
149
+ # create return dict
150
+ filtered_dict: OrderedDict[str, Any] = OrdDict()
151
+
152
+ for module_fqn in self.generated_reports:
153
+ # first filter based on module
154
+ if module_fqn_filter == "" or module_fqn_filter in module_fqn:
155
+ # create entry for module and loop through features
156
+ filtered_dict[module_fqn] = {}
157
+ module_reports = self.generated_reports[module_fqn]
158
+ for feature_name in module_reports:
159
+ # check if filtering on features and do so if desired
160
+ if feature_filter == "" or feature_filter in feature_name:
161
+ filtered_dict[module_fqn][feature_name] = module_reports[feature_name]
162
+
163
+ # we have populated the filtered dict, and must return it
164
+
165
+ return filtered_dict
166
+
167
+ def _generate_tensor_table(
168
+ self,
169
+ filtered_data: OrderedDict[str, Dict[str, Any]],
170
+ tensor_features: List[str]
171
+ ) -> Tuple[List, List]:
172
+ r"""
173
+ Takes in the filtered data and features list and generates the tensor headers and table
174
+
175
+ Currently meant to generate the headers and table for both the tensor information.
176
+
177
+ Args:
178
+ filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
179
+ module_fqns -> feature_names -> values
180
+ tensor_features (List[str]): A list of the tensor level features
181
+
182
+ Returns a tuple with:
183
+ A list of the headers of the tensor table
184
+ A list of lists containing the table information row by row
185
+ The 0th index row will contain the headers of the columns
186
+ The rest of the rows will contain data
187
+ """
188
+ # now we compose the tensor information table
189
+ tensor_table: List[List[Any]] = []
190
+ tensor_headers: List[str] = []
191
+
192
+ # append the table row to the table only if we have features
193
+ if len(tensor_features) > 0:
194
+ # now we add all the data
195
+ for index, module_fqn in enumerate(filtered_data):
196
+ # we make a new row for the tensor table
197
+ tensor_table_row = [index, module_fqn]
198
+ for feature in tensor_features:
199
+ # we iterate in same order of added features
200
+
201
+ if feature in filtered_data[module_fqn]:
202
+ # add value if applicable to module
203
+ feature_val = filtered_data[module_fqn][feature]
204
+ else:
205
+ # add that it is not applicable
206
+ feature_val = "Not Applicable"
207
+
208
+ # if it's a tensor we want to extract val
209
+ if isinstance(feature_val, torch.Tensor):
210
+ feature_val = feature_val.item()
211
+
212
+ # we add to our list of values
213
+ tensor_table_row.append(feature_val)
214
+
215
+ tensor_table.append(tensor_table_row)
216
+
217
+ # add row of headers of we actually have something, otherwise just empty
218
+ if len(tensor_table) != 0:
219
+ tensor_headers = ["idx", "layer_fqn"] + tensor_features
220
+
221
+ return (tensor_headers, tensor_table)
222
+
223
+ def _generate_channels_table(
224
+ self,
225
+ filtered_data: OrderedDict[str, Any],
226
+ channel_features: List[str],
227
+ num_channels: int
228
+ ) -> Tuple[List, List]:
229
+ r"""
230
+ Takes in the filtered data and features list and generates the channels headers and table
231
+
232
+ Currently meant to generate the headers and table for both the channels information.
233
+
234
+ Args:
235
+ filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
236
+ module_fqns -> feature_names -> values
237
+ channel_features (List[str]): A list of the channel level features
238
+ num_channels (int): Number of channels in the channel data
239
+
240
+ Returns a tuple with:
241
+ A list of the headers of the channel table
242
+ A list of lists containing the table information row by row
243
+ The 0th index row will contain the headers of the columns
244
+ The rest of the rows will contain data
245
+ """
246
+ # now we compose the table for the channel information table
247
+ channel_table: List[List[Any]] = []
248
+ channel_headers: List[str] = []
249
+
250
+ # counter to keep track of number of entries in
251
+ channel_table_entry_counter: int = 0
252
+
253
+ if len(channel_features) > 0:
254
+ # now we add all channel data
255
+ for module_fqn in filtered_data:
256
+ # we iterate over all channels
257
+ for channel in range(num_channels):
258
+ # we make a new row for the channel
259
+ new_channel_row = [channel_table_entry_counter, module_fqn, channel]
260
+ for feature in channel_features:
261
+ if feature in filtered_data[module_fqn]:
262
+ # add value if applicable to module
263
+ feature_val = filtered_data[module_fqn][feature][channel]
264
+ else:
265
+ # add that it is not applicable
266
+ feature_val = "Not Applicable"
267
+
268
+ # if it's a tensor we want to extract val
269
+ if type(feature_val) is torch.Tensor:
270
+ feature_val = feature_val.item()
271
+
272
+ # add value to channel specific row
273
+ new_channel_row.append(feature_val)
274
+
275
+ # add to table and increment row index counter
276
+ channel_table.append(new_channel_row)
277
+ channel_table_entry_counter += 1
278
+
279
+ # add row of headers of we actually have something, otherwise just empty
280
+ if len(channel_table) != 0:
281
+ channel_headers = ["idx", "layer_fqn", "channel"] + channel_features
282
+
283
+ return (channel_headers, channel_table)
284
+
285
    def generate_filtered_tables(self, feature_filter: str = "", module_fqn_filter: str = "") -> Dict[str, Tuple[List, List]]:
        r"""
        Takes in optional filter values and generates two tables with desired information.

        The generated tables are presented in both a list-of-lists format

        The reason for the two tables are that they handle different things:
        1.) the first table handles all tensor level information
        2.) the second table handles and displays all channel based information

        The reasoning for this is that having all the info in one table can make it ambiguous which collected
        statistics are global, and which are actually per-channel, so it's better to split it up into two
        tables. This also makes the information much easier to digest given the plethora of statistics collected

        Tensor table columns:
         idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
        ----  ---------  ---------   ---------   ---------        ---------

        Per-Channel table columns:
         idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
        ----  ---------  -------  ---------   ---------   ---------        ---------

        Args:
            feature_filter (str, optional): Filters the features presented to only those that
                contain this filter substring
                Default = "", results in all the features being printed
            module_fqn_filter (str, optional): Only includes modules that contains this string
                Default = "", results in all the modules in the reports to be visible in the table

        Returns a dictionary with two keys:
            (Dict[str, Tuple[List, List]]) A dict containing two keys:
            "tensor_level_info", "channel_level_info"
                Each key maps to a tuple with:
                    A list of the headers of each table
                    A list of lists containing the table information row by row
                    The 0th index row will contain the headers of the columns
                    The rest of the rows will contain data

        Example Use:
            >>> # xdoctest: +SKIP("undefined variables")
            >>> mod_report_visualizer.generate_filtered_tables(
            ...     feature_filter = "per_channel_min",
            ...     module_fqn_filter = "block1"
            ... ) # generates table with per_channel_min info for all modules in block 1 of the model
        """
        # first get the filtered data
        filtered_data: OrderedDict[str, Any] = self._get_filtered_data(feature_filter, module_fqn_filter)

        # now we split into tensor and per-channel data
        tensor_features: Set[str] = set()
        channel_features: Set[str] = set()

        # keep track of the number of channels we have
        num_channels: int = 0

        for module_fqn in filtered_data:
            for feature_name in filtered_data[module_fqn]:
                # get the data for that specific feature
                feature_data = filtered_data[module_fqn][feature_name]

                # check if not zero dim tensor
                is_tensor: bool = isinstance(feature_data, torch.Tensor)
                is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0

                if is_not_zero_dim or isinstance(feature_data, list):
                    # a multi-element tensor or a list holds one value per channel
                    channel_features.add(feature_name)
                    # NOTE(review): this keeps the length of the *last* per-channel
                    # feature seen; assumes all per-channel features share the same
                    # channel count -- TODO confirm
                    num_channels = len(feature_data)
                else:
                    # scalar (zero-dim tensor or non-tensor) means per-tensor data
                    tensor_features.add(feature_name)

        # we make them lists for iteration purposes (sorted for stable column order)
        tensor_features_list: List[str] = sorted(tensor_features)
        channel_features_list: List[str] = sorted(channel_features)

        # get the tensor info
        tensor_headers, tensor_table = self._generate_tensor_table(filtered_data, tensor_features_list)

        # get the channel info
        channel_headers, channel_table = self._generate_channels_table(
            filtered_data, channel_features_list, num_channels
        )

        # let's now create the dictionary to return
        table_dict = {
            self.TABLE_TENSOR_KEY : (tensor_headers, tensor_table),
            self.TABLE_CHANNEL_KEY : (channel_headers, channel_table)
        }

        # return the two tables
        return table_dict
377
+
378
+ def generate_table_visualization(self, feature_filter: str = "", module_fqn_filter: str = ""):
379
+ r"""
380
+ Takes in optional filter values and prints out formatted tables of the information.
381
+
382
+ The reason for the two tables printed out instead of one large one are that they handle different things:
383
+ 1.) the first table handles all tensor level information
384
+ 2.) the second table handles and displays all channel based information
385
+
386
+ The reasoning for this is that having all the info in one table can make it ambiguous which collected
387
+ statistics are global, and which are actually per-channel, so it's better to split it up into two
388
+ tables. This also makes the information much easier to digest given the plethora of statistics collected
389
+
390
+ Tensor table columns:
391
+ idx layer_fqn feature_1 feature_2 feature_3 .... feature_n
392
+ ---- --------- --------- --------- --------- ---------
393
+
394
+ Per-Channel table columns:
395
+
396
+ idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n
397
+ ---- --------- ------- --------- --------- --------- ---------
398
+
399
+ Args:
400
+ feature_filter (str, optional): Filters the features presented to only those that
401
+ contain this filter substring
402
+ Default = "", results in all the features being printed
403
+ module_fqn_filter (str, optional): Only includes modules that contains this string
404
+ Default = "", results in all the modules in the reports to be visible in the table
405
+
406
+ Example Use:
407
+ >>> # xdoctest: +SKIP("undefined variables")
408
+ >>> mod_report_visualizer.generate_table_visualization(
409
+ ... feature_filter = "per_channel_min",
410
+ ... module_fqn_filter = "block1"
411
+ ... )
412
+ >>> # prints out neatly formatted table with per_channel_min info
413
+ >>> # for all modules in block 1 of the model
414
+ """
415
+ # see if we got tabulate
416
+ if not got_tabulate:
417
+ print("Make sure to install tabulate and try again.")
418
+ return None
419
+
420
+ # get the table dict and the specific tables of interest
421
+ table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
422
+ tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
423
+ channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
424
+
425
+ # get the table string and print it out
426
+ # now we have populated the tables for each one
427
+ # let's create the strings to be returned
428
+ table_str = ""
429
+ # the tables will have some headers columns that are non-feature
430
+ # ex. table index, module name, channel index, etc.
431
+ # we want to look at header columns for features, that come after those headers
432
+ if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
433
+ # if we have at least one tensor level feature to be added we add tensor table
434
+ table_str += "Tensor Level Information \n"
435
+ table_str += tabulate(tensor_table, headers=tensor_headers)
436
+ if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
437
+ # if we have at least one channel level feature to be added we add tensor table
438
+ table_str += "\n\n Channel Level Information \n"
439
+ table_str += tabulate(channel_table, headers=channel_headers)
440
+
441
+ # if no features at all, let user know
442
+ if table_str == "":
443
+ table_str = "No data points to generate table with."
444
+
445
+ print(table_str)
446
+
447
+ def _get_plottable_data(self, feature_filter: str, module_fqn_filter: str) -> Tuple[List, List[List], bool]:
448
+ r"""
449
+ Takes in the feature filters and module filters and outputs the x and y data for plotting
450
+
451
+ Args:
452
+ feature_filter (str): Filters the features presented to only those that
453
+ contain this filter substring
454
+ module_fqn_filter (str): Only includes modules that contains this string
455
+
456
+ Returns a tuple of three elements
457
+ The first is a list containing relevant x-axis data
458
+ The second is a list containing the corresponding y-axis data
459
+ If the data is per channel
460
+ """
461
+ # get the table dict and the specific tables of interest
462
+ table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
463
+ tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
464
+ channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
465
+
466
+ # make sure it is only 1 feature that is being plotted
467
+ # get the number of features in each of these
468
+ tensor_info_features_count = len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
469
+ channel_info_features_count = len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
470
+
471
+ # see if valid tensor or channel plot
472
+ is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
473
+ is_valid_per_channel_plot: bool = channel_info_features_count == 1
474
+
475
+ # offset should either be one of tensor or channel table or neither
476
+ feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
477
+ table = tensor_table
478
+
479
+ # if a per_channel plot, we have different offset and table
480
+ if is_valid_per_channel_plot:
481
+ feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
482
+ table = channel_table
483
+
484
+ x_data: List = []
485
+ y_data: List[List] = []
486
+ # the feature will either be a tensor feature or channel feature
487
+ if is_valid_per_tensor_plot:
488
+ for table_row_num, row in enumerate(table):
489
+ # get x_value to append
490
+ x_val_to_append = table_row_num
491
+ # the index of the feature will the 0 + num non feature columns
492
+ tensor_feature_index = feature_column_offset
493
+ row_value = row[tensor_feature_index]
494
+ if not type(row_value) == str:
495
+ x_data.append(x_val_to_append)
496
+ y_data.append(row_value)
497
+ elif is_valid_per_channel_plot:
498
+ # gather the x_data and multiple y_data
499
+ # calculate the number of channels
500
+ num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
501
+ for channel in range(num_channels):
502
+ y_data.append([]) # separate data list per channel
503
+
504
+ for table_row_num, row in enumerate(table):
505
+ # get x_value to append
506
+ x_val_to_append = table_row_num
507
+ current_channel = row[self.CHANNEL_NUM_INDEX] # initially chose current channel
508
+ new_module_index: int = table_row_num // num_channels
509
+ x_val_to_append = new_module_index
510
+
511
+ # the index of the feature will the 0 + num non feature columns
512
+ tensor_feature_index = feature_column_offset
513
+ row_value = row[tensor_feature_index]
514
+ if not type(row_value) == str:
515
+ # only append if new index we are appending
516
+ if len(x_data) == 0 or x_data[-1] != x_val_to_append:
517
+ x_data.append(x_val_to_append)
518
+
519
+ # append value for that channel
520
+ y_data[current_channel].append(row_value)
521
+ else:
522
+ # more than one feature was chosen
523
+ error_str = "Make sure to pick only a single feature with your filter to plot a graph."
524
+ error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
525
+ error_str += " Pick one of those features to plot."
526
+ raise ValueError(error_str)
527
+
528
+ # return x, y values, and if data is per-channel
529
+ return (x_data, y_data, is_valid_per_channel_plot)
530
+
531
+ def generate_plot_visualization(self, feature_filter: str, module_fqn_filter: str = ""):
532
+ r"""
533
+ Takes in a feature and optional module_filter and plots of the desired data.
534
+
535
+ For per channel features, it averages the value across the channels and plots a point
536
+ per module. The reason for this is that for models with hundreds of channels, it can
537
+ be hard to differentiate one channel line from another, and so the point of generating
538
+ a single average point per module is to give a sense of general trends that encourage
539
+ further deep dives.
540
+
541
+ Note:
542
+ Only features in the report that have tensor value data are plottable by this class
543
+ When the tensor information is plotted, it will plot:
544
+ idx as the x val, feature value as the y_val
545
+ When the channel information is plotted, it will plot:
546
+ the first idx of each module as the x val, feature value as the y_val [for each channel]
547
+ The reason for this is that we want to be able to compare values across the
548
+ channels for same layer, and it will be hard if values are staggered by idx
549
+ This means each module is represented by only 1 x value
550
+ Args:
551
+ feature_filter (str): Filters the features presented to only those that
552
+ contain this filter substring
553
+ module_fqn_filter (str, optional): Only includes modules that contains this string
554
+ Default = "", results in all the modules in the reports to be visible in the table
555
+
556
+ Example Use:
557
+ >>> # xdoctest: +SKIP("undefined variables")
558
+ >>> mod_report_visualizer.generate_plot_visualization(
559
+ ... feature_filter = "per_channel_min",
560
+ ... module_fqn_filter = "block1"
561
+ ... )
562
+ >>> # outputs line plot of per_channel_min information for all
563
+ >>> # modules in block1 of model each channel gets it's own line,
564
+ >>> # and it's plotted across the in-order modules on the x-axis
565
+ """
566
+ # checks if we have matplotlib and let's user know to install it if don't
567
+ if not got_matplotlib:
568
+ print("make sure to install matplotlib and try again.")
569
+ return None
570
+
571
+ # get the x and y data and if per channel
572
+ x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
573
+
574
+ # plot based on whether data is per channel or not
575
+ ax = plt.subplot()
576
+ ax.set_ylabel(feature_filter)
577
+ ax.set_title(feature_filter + " Plot")
578
+ plt.xticks(x_data) # only show ticks for actual points
579
+
580
+ if data_per_channel:
581
+ ax.set_xlabel("First idx of module")
582
+ # set the legend as well
583
+ # plot a single line that is average of the channel values
584
+ num_modules = len(y_data[0]) # all y_data have same length, so get num modules
585
+ num_channels = len(y_data) # we want num channels to be able to calculate average later
586
+
587
+ avg_vals = [sum(y_data[:][index]) / num_channels for index in range(num_modules)]
588
+
589
+ # plot the three things we measured
590
+ ax.plot(x_data, avg_vals, label=f"Average Value Across {num_channels} Channels")
591
+ ax.legend(loc='upper right')
592
+ else:
593
+ ax.set_xlabel("idx")
594
+ ax.plot(x_data, y_data)
595
+
596
+ # actually show the plot
597
+ plt.show()
598
+
599
+ def generate_histogram_visualization(self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10):
600
+ r"""
601
+ Takes in a feature and optional module_filter and plots the histogram of desired data.
602
+
603
+ Note:
604
+ Only features in the report that have tensor value data can be viewed as a histogram
605
+ If you want to plot a histogram from all the channel values of a specific feature for
606
+ a specific model, make sure to specify both the model and the feature properly
607
+ in the filters and you should be able to see a distribution of the channel data
608
+
609
+ Args:
610
+ feature_filter (str, optional): Filters the features presented to only those that
611
+ contain this filter substring
612
+ Default = "", results in all the features being printed
613
+ module_fqn_filter (str, optional): Only includes modules that contains this string
614
+ Default = "", results in all the modules in the reports to be visible in the table
615
+ num_bins (int, optional): The number of bins to create the histogram with
616
+ Default = 10, the values will be split into 10 equal sized bins
617
+
618
+ Example Use:
619
+ >>> # xdoctest: +SKIP
620
+ >>> mod_report_visualizer.generategenerate_histogram_visualization_plot_visualization(
621
+ ... feature_filter = "per_channel_min",
622
+ ... module_fqn_filter = "block1"
623
+ ... )
624
+ # outputs histogram of per_channel_min information for all modules in block1 of model
625
+ information is gathered across all channels for all modules in block 1 for the
626
+ per_channel_min and is displayed in a histogram of equally sized bins
627
+ """
628
+ # checks if we have matplotlib and let's user know to install it if don't
629
+ if not got_matplotlib:
630
+ print("make sure to install matplotlib and try again.")
631
+ return None
632
+
633
+ # get the x and y data and if per channel
634
+ x_data, y_data, data_per_channel = self._get_plottable_data(feature_filter, module_fqn_filter)
635
+
636
+ # for histogram, we just care about plotting the y data
637
+ # plot based on whether data is per channel or not
638
+ ax = plt.subplot()
639
+ ax.set_xlabel(feature_filter)
640
+ ax.set_ylabel("Frequency")
641
+ ax.set_title(feature_filter + " Histogram")
642
+
643
+ if data_per_channel:
644
+ # set the legend as well
645
+ # combine all the data
646
+ all_data = []
647
+ for channel_info in y_data:
648
+ all_data.extend(channel_info)
649
+
650
+ val, bins, _ = plt.hist(
651
+ all_data,
652
+ bins=num_bins,
653
+ stacked=True,
654
+ rwidth=0.8,
655
+ )
656
+ plt.xticks(bins)
657
+ else:
658
+ val, bins, _ = plt.hist(
659
+ y_data,
660
+ bins=num_bins,
661
+ stacked=False,
662
+ rwidth=0.8,
663
+ )
664
+ plt.xticks(bins)
665
+
666
+ plt.show()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/quantize_handler.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from typing import Callable, Dict, List, Optional, Type
3
+
4
+ import torch
5
+
6
+ from torch.ao.quantization.backend_config import (
7
+ BackendConfig,
8
+ DTypeConfig,
9
+ ObservationType,
10
+ )
11
+ from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
12
+ from torch.fx.graph import Node
13
+
14
+ from .utils import all_node_args_have_no_tensors
15
+
16
+
17
+ __all__ = [
18
+ "QuantizeHandler",
19
+ "BinaryOpQuantizeHandler",
20
+ "CatQuantizeHandler",
21
+ "ConvReluQuantizeHandler",
22
+ "LinearReLUQuantizeHandler",
23
+ "BatchNormQuantizeHandler",
24
+ "EmbeddingQuantizeHandler",
25
+ "RNNDynamicQuantizeHandler",
26
+ "DefaultNodeQuantizeHandler",
27
+ "FixedQParamsOpQuantizeHandler",
28
+ "CopyNodeQuantizeHandler",
29
+ "GeneralTensorShapeOpQuantizeHandler",
30
+ "CustomModuleQuantizeHandler",
31
+ "StandaloneModuleQuantizeHandler",
32
+ ]
33
+
34
+ def _default_root_node_getter(node_pattern):
35
+ if node_pattern is None:
36
+ return node_pattern
37
+ while not isinstance(node_pattern, Node):
38
+ node_pattern = node_pattern[-1]
39
+ return node_pattern
40
+
41
# Base Pattern Handler
class QuantizeHandler(ABC):  # noqa: B024
    """ Base handler class for the quantizer patterns
    """
    def __init__(
            self,
            node_pattern: NodePattern,
            modules: Dict[str, torch.nn.Module],
            root_node_getter: Optional[Callable] = None,
            is_custom_module=False,
            is_standalone_module=False):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        self.node_pattern = node_pattern
        self.modules = modules
        getter = _default_root_node_getter if root_node_getter is None else root_node_getter
        self.root_node = getter(node_pattern)
        self.is_custom_module_ = is_custom_module
        self.is_standalone_module_ = is_standalone_module
        # count how many of the root node's args are Tensors (versus scalars);
        # this distinguishes things like "x + y" from "x + 2" or "2 + x"
        self.num_tensor_args = 0
        if isinstance(self.root_node, Node):
            no_tensor_cache: Dict[Node, bool] = {}
            for candidate in self.root_node.args:
                is_tensor_arg = isinstance(candidate, Node) and not all_node_args_have_no_tensors(
                    candidate, self.modules, no_tensor_cache)
                if is_tensor_arg:
                    self.num_tensor_args += 1

    def is_general_tensor_value_op(self) -> bool:
        """
        Returns True if the operator works for both floating point and
        quantized input, and does some computation based on the input Tensor,
        or the ops that only re-arranges the Tensor values or query some metadata
        about the Tensor
        so we need to insert observer/fake_quant for the output of the
        operator (same observer instance as input)
        since the distribution of values is different for input and output
        Tensors (for HistogramObserver) while they share the same quantization
        parameters
        Example operator: avgpool2d, reshape, transpose, maxpool2d
        Example observed operator:
        observer_0 - avgpool2d - observer_0 (same observer instance as input)
        """
        return False

    def is_custom_module(self):
        # flag recorded at construction time
        return self.is_custom_module_

    def is_standalone_module(self):
        # flag recorded at construction time
        return self.is_standalone_module_
96
+
97
+ def _get_quantize_handler_cls(
98
+ observation_type: ObservationType,
99
+ dtype_configs: List[DTypeConfig],
100
+ num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> Type[QuantizeHandler]:
101
+ """
102
+ Return a configurable QuantizeHandler that matches the given specifications from the backend.
103
+ """
104
+
105
+ class ConfigurableQuantizeHandler(QuantizeHandler):
106
+ def __init__(
107
+ self,
108
+ node_pattern: NodePattern,
109
+ modules: Dict[str, torch.nn.Module],
110
+ root_node_getter: Optional[Callable] = None):
111
+ super().__init__(node_pattern, modules, root_node_getter)
112
+ if num_tensor_args_to_observation_type:
113
+ assert self.num_tensor_args in num_tensor_args_to_observation_type, \
114
+ f"Must provide observation_type config for tensor number {self.num_tensor_args}" \
115
+ f" in num_tensor_args_to_observation_type for {node_pattern}"
116
+ self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
117
+ else:
118
+ self.observation_type = observation_type
119
+ self.dtype_configs = dtype_configs
120
+
121
+ def is_general_tensor_value_op(self) -> bool:
122
+ return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
123
+
124
+ return ConfigurableQuantizeHandler
125
+
126
def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
    """
    Note: Quantize handler is just a holder for some check methods like
    (should_insert_observer_for_output), maybe this can be a enum as well,
    we can refactor this after we convert the path for fbgemm/qnnpack fully to the
    new path, this is not exposed to backend developers
    """
    # one configurable handler class per pattern registered on the backend
    return {
        pattern: _get_quantize_handler_cls(
            config.observation_type,
            config.dtype_configs,
            config._num_tensor_args_to_observation_type,
        )
        for pattern, config in backend_config._pattern_complex_format_to_config.items()
    }
144
+
145
# The classes below are no-op subclasses of QuantizeHandler. Per the TODOs,
# they appear to remain only because they are still exposed through the
# torch.ao.quantization namespace -- verify before removing.

# TODO: remove this class, this is still exposed in torch.ao.quantization
# but we should be able to break bc
class BinaryOpQuantizeHandler(QuantizeHandler):
    pass

class CatQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class ConvReluQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class LinearReLUQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class BatchNormQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class EmbeddingQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class RNNDynamicQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class DefaultNodeQuantizeHandler(QuantizeHandler):
    """ Common quantized op, first input and first output will be quantized
    """
    pass

# TODO: remove this class
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove
class CopyNodeQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
    pass

# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class CustomModuleQuantizeHandler(QuantizeHandler):
    pass

# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class StandaloneModuleQuantizeHandler(QuantizeHandler):
    pass
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/utils.py ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.ao.quantization import (
5
+ QConfigAny,
6
+ QuantType,
7
+ )
8
+ from torch.ao.quantization.backend_config import (
9
+ DTypeWithConstraints,
10
+ )
11
+ from torch.ao.quantization.fake_quantize import (
12
+ FakeQuantizeBase,
13
+ FixedQParamsFakeQuantize,
14
+ )
15
+ from torch.ao.quantization.observer import (
16
+ FixedQParamsObserver,
17
+ ObserverBase,
18
+ )
19
+ from torch.ao.quantization.qconfig import (
20
+ float16_static_qconfig,
21
+ float16_dynamic_qconfig,
22
+ qconfig_equals,
23
+ )
24
+ from torch.ao.quantization.stubs import DeQuantStub
25
+ from torch.ao.quantization.utils import (
26
+ activation_is_statically_quantized,
27
+ )
28
+ from torch.ao.quantization.observer import _is_activation_post_process
29
+ from torch.ao.quantization.qconfig_mapping import QConfigMapping
30
+
31
+ from torch.fx import GraphModule, map_arg
32
+
33
+ from torch.fx.graph import (
34
+ Graph,
35
+ Node,
36
+ )
37
+ from .custom_config import PrepareCustomConfig
38
+ # importing the lib so that the quantized_decomposed ops are registered
39
+ from ._decomposed import quantized_decomposed_lib # noqa: F401
40
+
41
+ from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
42
+ from dataclasses import dataclass
43
+ from collections import namedtuple
44
+ import operator
45
+ import warnings
46
+
47
+ # TODO: revisit this list. Many helper methods shouldn't be public
48
+ __all__ = [
49
+ "all_node_args_except_first",
50
+ "all_node_args_have_no_tensors",
51
+ "assert_and_get_unique_device",
52
+ "collect_producer_nodes",
53
+ "create_getattr_from_value",
54
+ "create_node_from_old_node_preserve_meta",
55
+ "EMPTY_ARG_DICT",
56
+ "get_custom_module_class_keys",
57
+ "get_linear_prepack_op_for_dtype",
58
+ "get_new_attr_name_with_prefix",
59
+ "get_non_observable_arg_indexes_and_types",
60
+ "get_qconv_prepack_op",
61
+ "get_skipped_module_name_and_classes",
62
+ "graph_module_from_producer_nodes",
63
+ "maybe_get_next_module",
64
+ "NodeInfo",
65
+ "node_arg_is_bias",
66
+ "node_arg_is_weight",
67
+ "NON_OBSERVABLE_ARG_DICT",
68
+ "NON_QUANTIZABLE_WEIGHT_OPS",
69
+ "return_arg_list",
70
+ "ObservedGraphModuleAttrs",
71
+ ]
72
+
73
# Functional normalization ops whose ``weight`` argument is an affine scale,
# not a conv/linear-style weight, so it is not quantizable.
NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}
74
+
75
@dataclass
class ObservedGraphModuleAttrs:
    """Bookkeeping attached to a GraphModule after observation/prepare.

    NOTE(review): field semantics inferred from names and usage elsewhere in
    the quantization flow -- confirm against the prepare/convert callers.
    """
    # per-node qconfig resolved during prepare
    node_name_to_qconfig: Dict[str, QConfigAny]
    # node name -> (module scope fqn, module type)
    node_name_to_scope: Dict[str, Tuple[str, type]]
    prepare_custom_config: PrepareCustomConfig
    equalization_node_name_to_qconfig: Dict[str, Any]
    qconfig_mapping: QConfigMapping
    is_qat: bool
    observed_node_names: Set[str]
    # the three fields below are only meaningful for observed standalone modules
    is_observed_standalone_module: bool = False
    standalone_module_input_quantized_idxs: Optional[List[int]] = None
    standalone_module_output_quantized_idxs: Optional[List[int]] = None
87
+
88
def node_arg_is_weight(node: Node, arg: Any) -> bool:
    """Returns if node arg is weight

    ``arg`` is the weight when either:
      * the node's "target_dtype_info" meta declares a positional
        ``weight_index`` and ``node.args[weight_index]`` is ``arg``, or
      * ``arg`` was passed as the ``weight`` keyword argument.
    """
    weight_index = None
    if "target_dtype_info" in node.meta:
        weight_index = node.meta["target_dtype_info"].get("weight_index", None)
    if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
        return True
    # BUGFIX: check key membership explicitly; with the previous
    # ``node.kwargs.get("weight") is arg``, a node without a "weight" kwarg
    # wrongly reported True for arg=None (None is None).
    return "weight" in node.kwargs and node.kwargs["weight"] is arg
96
+
97
def node_arg_is_bias(node: Node, arg: Any) -> bool:
    """Returns if node arg is bias

    ``arg`` is the bias when either:
      * the node's "target_dtype_info" meta declares a positional
        ``bias_index`` and ``node.args[bias_index]`` is ``arg``, or
      * ``arg`` was passed as the ``bias`` keyword argument.
    """
    bias_index = None
    if "target_dtype_info" in node.meta:
        bias_index = node.meta["target_dtype_info"].get("bias_index", None)
    if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
        return True
    # BUGFIX: check key membership explicitly; with the previous
    # ``node.kwargs.get("bias") is arg``, a node without a "bias" kwarg
    # wrongly reported True for arg=None (None is None).
    return "bias" in node.kwargs and node.kwargs["bias"] is arg
105
+
106
def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r""" Get all the unique custom module keys in the custom config dict
    e.g.
    Input:
    {
        QuantType.STATIC: {
            CustomModule1: ObservedCustomModule
        },
        QuantType.DYNAMIC: {
            CustomModule2: DynamicObservedCustomModule
        },
        QuantType.WEIGHT_ONLY: {
            CustomModule3: WeightOnlyObservedCustomModule
        },
    }

    Output:
    # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
    [CustomModule1, CustomModule2, CustomModule3]
    """
    # a set collapses classes that appear under more than one quant mode
    seen_classes: Set[Any] = set()
    for quant_mode in (QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY):
        seen_classes.update(custom_module_mapping.get(quant_mode, {}).keys())
    return list(seen_classes)
133
+
134
def get_linear_prepack_op_for_dtype(dtype):
    """Return the quantized linear prepack op matching the weight dtype.

    Supports ``torch.float16`` and ``torch.qint8``; anything else raises.
    """
    if dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    raise Exception("can't get linear prepack op for dtype:", dtype)
141
+
142
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    """Return the quantized prepack op for a functional conv / conv_transpose.

    Raises an AssertionError when ``conv_op`` is not one of the supported
    1d/2d/3d conv or conv_transpose functionals.
    """
    prepack_for = {
        torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
        torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
        torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
        torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
        torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
    }
    matched = prepack_for.get(conv_op, None)
    assert matched, f"Didn't find prepack op for {conv_op}"
    return matched
154
+
155
# Returns a function that can get a new attribute name for module with given
# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    # dots are not valid in attribute names, so fold them into underscores
    sanitized = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        # probe prefix0, prefix1, ... until an unused name is found
        index = 0
        while hasattr(module, f"{sanitized}{index}"):
            index += 1
        return f"{sanitized}{index}"
    return get_new_attr_name
173
+
174
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Starting from a target node, trace back until we hit input or
    getattr node. This is used to extract the chain of operators
    starting from getattr to the target node, for example
    def forward(self, x):
        observed = self.observer(self.weight)
        return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produces the observed node or None if we can't extract a self contained
    graph without free variables(inputs of the forward function).
    '''
    nodes = [node]
    frontier = [node]
    # Breadth-ish traversal over producers: pop a node, inspect all of its
    # inputs (positional and keyword), and keep expanding until only
    # getattr leaves remain.
    while frontier:
        node = frontier.pop()
        all_args = list(node.args) + list(node.kwargs.values())
        for arg in all_args:
            if not isinstance(arg, Node):
                continue
            if arg.op == 'placeholder':
                # hit input, can't fold in this case
                return None
            nodes.append(arg)
            # getattr call_function nodes are the leaves of the chain; do not
            # expand past them.
            if not (arg.op == 'call_function' and arg.target == getattr):
                frontier.append(arg)
    return nodes
200
+
201
def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph
    Return:
      A graph module constructed from the producer nodes
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # since we traced back from node to getattr
    producer_nodes.reverse()
    graph = Graph()
    # Maps original nodes to their copies in the new graph.
    env: Dict[Any, Any] = {}

    def load_arg(a):
        # Remap any Node args to the already-copied nodes in the new graph.
        return map_arg(a, lambda node: env[node])
    for producer_node in producer_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    # After the reverse above, the last producer node is the original target
    # node; it becomes the output of the extracted graph.
    graph.output(load_arg(producer_nodes[-1]))
    graph_module = GraphModule(root, graph)
    return graph_module
224
+
225
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Returns the unique device for a module, or None if no device is found.
    Throws an error if multiple devices are detected.
    """
    param_devices = {p.device for p in module.parameters()}
    buffer_devices = {b.device for b in module.buffers()}
    devices = param_devices | buffer_devices
    # As a temp workaround for AIMP HHC publish we added CPU check.
    # Remove it later. T163614564
    if devices == {torch.device("cpu"), torch.device("meta")}:
        warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
        devices = {torch.device("cpu")}
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        f"but got devices {devices}"
    )
    # Empty set means the module has no parameters/buffers at all.
    return next(iter(devices), None)
245
+
246
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.
    """
    attr_name = get_new_attr_name_with_prefix(prefix)(module)
    device = assert_and_get_unique_device(module)
    if isinstance(value, torch.Tensor):
        buffer_value = value.clone().detach()
    else:
        # Non-tensor values are wrapped in a tensor on the module's device.
        buffer_value = torch.tensor(value, device=device)
    module.register_buffer(attr_name, buffer_value)
    # Create get_attr with value
    return graph.create_node("get_attr", attr_name)
260
+
261
def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.
    """
    # `cache` memoizes results per node; an empty dict is treated the same as
    # no cache because of the truthiness check.
    if cache and node in cache:
        return cache[node]

    result = False  # will be overwritten
    if not isinstance(node, Node):
        result = True
    elif node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if _is_activation_post_process(modules[node.target]):
            # Observers are pass-through: look through them at their input.
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'call_module':
        # NOTE(review): this branch is unreachable — the preceding
        # `elif node.op == 'call_module'` already matches these nodes (a
        # non-observer call_module keeps the initial result=False).
        result = False
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        # Generic case: recurse into Node args (including Nodes inside list
        # args); ints are primitives, anything else is conservatively treated
        # as a tensor.
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = \
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        found_one_tensor = found_one_tensor or \
                            (not this_list_el_args_have_no_tensors)
                        # If found_one_tensor is True, there is no point in
                        # recursing further as the end result will always
                        # be True.
                        # TODO(future PR): remove this entire function and
                        # change to dtype inference without recursion.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
                    found_one_tensor = found_one_tensor or \
                        (not this_arg_args_have_no_tensors)
                    # If found_one_tensor is True, there is no point in
                    # recursing further as the end result will always
                    # be True.
                    # TODO(future PR): remove this entire function and
                    # change to dtype inference without recursion.
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    found_one_tensor = True
        result = not found_one_tensor
    if cache:
        cache[node] = result
    return result
335
+
336
def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns all node arg indices after first
    """
    arg_count = len(node.args)
    return [idx for idx in range(arg_count) if idx > 0]
341
+
342
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Constructs a function that takes a node as arg and returns the arg_indices
    that are valid for node.args
    """
    def arg_indices_func(node: Node) -> List[int]:
        # Drop any requested index that falls past the end of node.args.
        arg_count = len(node.args)
        return [idx for idx in arg_indices if idx < arg_count]
    return arg_indices_func
350
+
351
# Lightweight key identifying a node by its op kind and target.
NodeInfo = namedtuple("NodeInfo", "op target")

# this dict identifies which indices of a node are non tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors

NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill") : {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2])
    },
    NodeInfo("call_method", "permute") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "repeat") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "reshape") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "size") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "transpose") : {
        int: all_node_args_except_first
    },
    # NOTE(review): the `torch.transpose` / `torch.unsqueeze` entries below are
    # keyed with op "call_method", but a bare torch.* function target would
    # normally appear as a "call_function" node — confirm these entries are
    # actually reachable.
    NodeInfo("call_method", torch.transpose) : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "unsqueeze") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "unsqueeze_") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", torch.unsqueeze) : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "view") : {
        int: all_node_args_except_first
    },
}

# Fallback returned when a node has no non-observable-arg entry.
EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
395
+
396
def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Returns a dict with of non float tensor types as keys and values which correspond to a
    function to retrieve the list (which takes the node as an argument)
    """
    key = NodeInfo(node.op, node.target)
    return NON_OBSERVABLE_ARG_DICT.get(key, EMPTY_ARG_DICT)
404
+
405
def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Optional[Type[nn.Module]] = None,
    target_functional_type: Any = None,
) -> Optional[Node]:
    """ Gets the next module that matches what is needed in
    is_target_module_type if it exists

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """
    for user in node.users.keys():
        module_matches = (
            user.op == 'call_module'
            and target_module_type is not None
            and isinstance(modules[str(user.target)], target_module_type)
        )
        if module_matches:
            return user
        functional_matches = (
            user.op == 'call_function'
            and target_functional_type is not None
            and user.target == target_functional_type
        )
        if functional_matches:
            return user
    # No consumer matched either the module or the functional criterion.
    return None
429
+
430
def create_node_from_old_node_preserve_meta(
    quantized_graph: Graph,
    create_node_args: Tuple[Any, ...],
    old_node: Node,
) -> Node:
    """
    Creates `new_node` and copies the necessary metadata to it from `old_node`.
    """
    node = quantized_graph.create_node(*create_node_args)
    # Preserve the stack trace so debugging info survives the rewrite.
    node.stack_trace = old_node.stack_trace
    return node
441
+
442
def get_skipped_module_name_and_classes(
        prepare_custom_config: PrepareCustomConfig,
        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
    """Collect the module names and classes to skip when tracing/preparing."""
    skipped_names = copy.copy(prepare_custom_config.non_traceable_module_names)
    skipped_classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
    if not is_standalone_module:
        # standalone module and custom module config are applied in top level module
        skipped_names.extend(prepare_custom_config.standalone_module_names.keys())
        skipped_classes.extend(prepare_custom_config.standalone_module_classes.keys())
        skipped_classes.extend(get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping))

    return skipped_names, skipped_classes
454
+
455
def _is_custom_module_lstm(
        node: Node,
        named_modules: Dict[str, torch.nn.Module],
        qconfig: QConfigAny = None,
        # QuantizeHandler, but we cannot include the type here due to circular imports
        qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module LSTM flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is None or qhandler is None:
        # No qconfig/qhandler provided: check for the already-converted
        # quantizable LSTM module instead.
        return isinstance(mod, torch.ao.nn.quantizable.LSTM)
    assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
    return (
        isinstance(mod, torch.nn.LSTM)
        and activation_is_statically_quantized(qconfig)
        and qhandler.is_custom_module()
    )
473
+
474
def _is_custom_module_mha(
        node: Node,
        named_modules: Dict[str, torch.nn.Module],
        qconfig: QConfigAny = None,
        # QuantizeHandler, but we cannot include the type here due to circular imports
        qhandler: Optional[Any] = None,
) -> bool:
    """
    Return whether this refers to the custom module MultiheadAttention flow.
    """
    mod = _get_module(node, named_modules)
    if qconfig is None or qhandler is None:
        # No qconfig/qhandler provided: check for the already-converted
        # quantizable MultiheadAttention module instead.
        return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
    assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler)  # type: ignore[attr-defined]
    return (
        isinstance(mod, torch.nn.MultiheadAttention)
        and activation_is_statically_quantized(qconfig)
        and qhandler.is_custom_module()
    )
492
+
493
def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
    """
    If `node` refers to a call_module node, return the module, else None.
    """
    if node.op != "call_module":
        return None
    # dict.get returns None when the target is not a known module, matching
    # the explicit membership check of the original.
    return named_modules.get(str(node.target))
501
+
502
def _insert_dequant_stub(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attach a `DeQuantStub` to the model and create a node that calls this
    `DeQuantStub` on the output of `node`, similar to how observers are inserted.
    """
    name_factory = get_new_attr_name_with_prefix("dequant_stub_")
    stub_name = name_factory(model)
    stub = DeQuantStub()
    # Register the stub on both the model and the name -> module index.
    setattr(model, stub_name, stub)
    named_modules[stub_name] = stub
    with graph.inserting_after(node):
        return graph.call_module(stub_name, (node,))
520
+
521
def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)),
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:
                   lstm_output
                        |
                        v
                  original_user(s)
    After:
                   lstm_output
                  /           \\
                 /  (getitem)  \\
                /               \\
               v                 v
             output            hidden
               |               /   \\
         (DeQuantStub)        (getitem)
               |             /       \\
               v            v         v
           output_dq     hidden0    hidden1
               |            |         |
               |    (DeQuantStub) (DeQuantStub)
               |            |         |
               |            v         v
               |      hidden0_dq  hidden1_dq
               |            \\       /
               |              (tuple)
               |              \\   /
               |               v  v
               |             hidden_dq
               \\               /
                \\   (tuple)   /
                 v            v
                 lstm_output_dq
                       |
                       v
                original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:
      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
    # Skip `output` and `hidden`: those are the getitem nodes created above
    # and must keep consuming the raw LSTM node.
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes we added here may interfere with reference quantized
    # pattern matching, so we need to redirect the consumers of internal nodes to the
    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
    # in order to preserve reference patterns like "dequantize - consumer - quantize".
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq
614
+
615
def _maybe_get_custom_module_lstm_from_node_arg(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
    """
    Given an argument of a node, if the argument refers to the path through which the node
    is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.

    This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
    skip inserting input observers for this node. This is because custom module LSTM produces
    quantized outputs, so inserting an input observer for the consumer of custom module LSTM
    would unnecessarily quantize the outputs again.

      lstm -> consumer

    In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
    DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    This tuple can be consumed in one of four ways:

      lstm -> getitem -> DeQuantStub -> consumer                       # consume lstm[0]
      lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer   # consume lstm[1]
      lstm -> getitem -> getitem -> DeQuantStub -> consumer            # consume lstm[1][0] or lstm[1][1]
      lstm -> getitem -> DeQuantStub -> tuple -> consumer              # consume lstm

    Thus, we must match against the above patterns instead of simply checking the parent node
    to determine whether this node is a consumer of a custom module LSTM.
    """
    def match_dq(a):
        return isinstance(_get_module(a, named_modules), DeQuantStub)

    def match_lstm(a):
        # Called without qconfig/qhandler, so this matches the converted
        # quantizable LSTM module (see `_is_custom_module_lstm`).
        return _is_custom_module_lstm(a, named_modules)

    def match_getitem(a):
        return a.op == "call_function" and a.target == operator.getitem

    def match_tuple(a):
        return a.op == "call_function" and a.target == tuple

    def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
        """
        Traverse up the graph and match the args one by one.
        If there is a match, return the last matched node, or None otherwise.
        """
        a = arg
        for i, match in enumerate(match_pattern):
            if not match(a):
                return None
            # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
            if i < len(match_pattern) - 1:
                if match == match_tuple:
                    a = a.args[0][0]  # type: ignore[assignment,index]
                else:
                    a = a.args[0]  # type: ignore[assignment]
        return a

    # Patterns are written consumer-first; matching walks upwards towards
    # the LSTM node, which is returned on success.
    all_match_patterns = [
        [match_dq, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
        [match_dq, match_getitem, match_getitem, match_lstm],
        [match_tuple, match_dq, match_getitem, match_lstm],
    ]

    for p in all_match_patterns:
        matched_node = _match_pattern(p)
        if matched_node is not None:
            return matched_node
    return None
683
+
684
def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |   \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d
    """
    def find_patterns(
            node: Node,
            index_stack: List[int],
            current_pattern: List[Node],
            matched_patterns: List[List[Node]],
            seen: Set[Tuple[Node, Tuple[int, ...]]]):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        # An empty index stack with a non-empty pattern means every tuple index
        # was matched by a getitem: record the completed pattern.
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
            elif user.op == "call_function" and user.target == operator.getitem:
                if len(index_stack) > 0:
                    # Only a getitem whose index mirrors the most recent tuple
                    # index continues the pattern; otherwise we stop here.
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
        # NOTE(review): this return value is unused at the call site below;
        # results are accumulated through the `matched_patterns` argument.
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: List[List[Node]] = []
    seen: Set[Tuple[Node, Tuple[int, ...]]] = set()  # (node, index_stack)
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the correct input
    # of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)
772
+
773
def _get_observer_from_activation_post_process(
    activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
    """
    If `activation_post_process` is an observer, return the observer.
    If `activation_post_process` is a fake quantize, return the internal observer.
    """
    if not isinstance(activation_post_process, ObserverBase):
        assert isinstance(activation_post_process, FakeQuantizeBase)
        # A fake quantize wraps an observer; unwrap and return it.
        return activation_post_process.activation_post_process  # type: ignore[return-value]
    return activation_post_process
785
+
786
def _qconfig_satisfies_dtype_config_constraints(
        qconfig: QConfigAny,
        dtype_with_constraints: DTypeWithConstraints,
        is_activation: bool = True) -> bool:
    """
    Return whether `qconfig` satisfies the following constraints from the backend,
    specified through the activation and weight DTypeWithConstraints.

        1. QConfig specified a quantization range that falls within the backend's, if any
        2. QConfig specified a min scale value that is >= the backend's, if any
        3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
           scale and zero point that match the backend's, if any

    If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
    If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
    """
    # TODO: log warnings only when the user enabled a debug flag
    def _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process: Union[ObserverBase, FakeQuantizeBase],
            dtype_with_constraints: DTypeWithConstraints,
            debug_string: str) -> bool:
        # Compare the observer's qparams against each backend constraint in
        # turn; any violation warns and rejects the qconfig.
        observer = _get_observer_from_activation_post_process(activation_post_process)
        app_quant_min = getattr(observer, "quant_min", None)
        app_quant_max = getattr(observer, "quant_max", None)
        # TODO: for now, just use the existing eps value as scale_min. In the future, we should
        # resolve the differences between the two, either by renaming eps or some other way
        app_scale_min = getattr(observer, "eps", None)
        backend_quant_min = dtype_with_constraints.quant_min_lower_bound
        backend_quant_max = dtype_with_constraints.quant_max_upper_bound
        backend_scale_min = dtype_with_constraints.scale_min_lower_bound
        backend_scale_exact_match = dtype_with_constraints.scale_exact_match
        backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
        # check quantization ranges
        if backend_quant_min is not None and backend_quant_max is not None:
            if app_quant_min is None or app_quant_max is None:
                warnings.warn(f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}")
                return False
            elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
                warnings.warn(
                    f"QConfig {debug_string} quantization range must fall within the backend's:\n"
                    f"QConfig range = ({app_quant_min}, {app_quant_max}), "
                    f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
                    f"ignoring {qconfig}"
                )
                return False
        # check scale min
        if backend_scale_min is not None:
            if app_scale_min is None:
                warnings.warn(f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}")
                return False
            if app_scale_min < backend_scale_min:
                warnings.warn(
                    f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
                    f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
                )
                return False
        # check fixed scale and zero point
        if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None:
            # For tests only, accept the following qconfigs for now
            # TODO: handle fp16 qconfigs properly
            for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
                if qconfig_equals(qconfig, accepted_qconfig):
                    return True
            suggestion_str = (
                "Please use torch.ao.quantization.get_default_qconfig_mapping or "
                "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
                "    qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n"
                "    model = prepare_fx(model, qconfig_mapping, example_inputs)"
            )
            if not isinstance(activation_post_process, FixedQParamsObserver) and \
                    not isinstance(activation_post_process, FixedQParamsFakeQuantize):
                warnings.warn(
                    f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
                    f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
            if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match:
                warnings.warn(
                    f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
                    f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
                    f"ignoring {qconfig}.\n{suggestion_str}"
                )
                return False
        return True

    # Trivially satisfied when there is no qconfig or no backend dtype to
    # check against.
    if qconfig is None or dtype_with_constraints.dtype is None:
        return True

    activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
    debug_string = "activation" if is_activation else "weight"
    satisfies_constraints = True
    if activation_post_process_ctr is not None:
        # Instantiate the observer/fake-quantize from its constructor to
        # inspect its qparams.
        activation_post_process = activation_post_process_ctr()
        assert _is_activation_post_process(activation_post_process)
        # If dtypes don't match, don't check the activation_post_process and return True early
        if activation_post_process.dtype != dtype_with_constraints.dtype:
            return True
        satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
            activation_post_process, dtype_with_constraints, debug_string)
    return satisfies_constraints
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__init__.py ADDED
File without changes
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-311.pyc ADDED
Binary file (9.73 kB). View file