lsmpp commited on
Commit
613726b
·
verified ·
1 Parent(s): 5fa88dc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi +4 -0
  2. .venv/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi +13 -0
  3. .venv/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi +71 -0
  4. .venv/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi +191 -0
  5. .venv/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi +9 -0
  6. .venv/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi +22 -0
  7. .venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc +0 -0
  8. .venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc +0 -0
  9. .venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc +0 -0
  10. .venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc +0 -0
  11. .venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc +0 -0
  12. .venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc +0 -0
  13. .venv/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc +0 -0
  14. .venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py +35 -0
  15. .venv/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc +0 -0
  16. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py +41 -0
  17. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc +0 -0
  18. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py +41 -0
  19. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  20. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc +0 -0
  21. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py +287 -0
  22. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py +1 -0
  23. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc +0 -0
  24. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py +32 -0
  25. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  26. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc +0 -0
  27. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc +0 -0
  28. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  29. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +1064 -0
  30. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +193 -0
  31. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +52 -0
  32. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py +15 -0
  33. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc +0 -0
  34. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
  35. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc +0 -0
  36. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py +6 -0
  37. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  38. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  39. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +61 -0
  40. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py +18 -0
  41. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  42. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc +0 -0
  43. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc +0 -0
  44. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc +0 -0
  45. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  46. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +107 -0
  47. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +147 -0
  48. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +266 -0
  49. .venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +190 -0
  50. .venv/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py +1 -0
.venv/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import compiled_autograd, eval_frame, guards # noqa: F401
2
+
3
+ def strip_function_call(name: str) -> str: ...
4
+ def is_valid_var_name(name: str) -> bool | int: ...
.venv/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ from torch import Tensor
4
+ from torch._dynamo.compiled_autograd import AutogradCompilerInstance
5
+
6
+ def set_autograd_compiler(
7
+ autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
8
+ dynamic: bool,
9
+ ) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
10
+ def clear_cache() -> None: ...
11
+ def is_cache_empty() -> bool: ...
12
+ def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
13
+ def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...
.venv/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import types
3
+ from typing import Optional, overload
4
+
5
+ from torch._dynamo.types import (
6
+ DynamoCallback,
7
+ DynamoGuardCompleteHook,
8
+ DynamoGuardHook,
9
+ GuardFn,
10
+ )
11
+
12
+ def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
13
+ def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
14
+ def get_eval_frame_callback() -> DynamoCallback: ...
15
+ def reset_code(code: types.CodeType) -> None: ...
16
+ def unsupported(obj1: object, obj2: object) -> object: ...
17
+ def set_code_exec_strategy(
18
+ code: types.CodeType, strategy: _FrameExecStrategy
19
+ ) -> None: ...
20
+ def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
21
+ def set_guard_complete_hook(
22
+ hook: Optional[DynamoGuardCompleteHook],
23
+ ) -> Optional[DynamoGuardCompleteHook]: ...
24
+ def raise_sigtrap() -> None: ...
25
+
26
+ class _CacheEntry:
27
+ def check_fn(self, *args: object, **kwargs: object) -> bool: ...
28
+ code: types.CodeType
29
+ next: _CacheEntry | None
30
+
31
+ class _ExtraState:
32
+ def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...
33
+
34
+ class _FrameAction(enum.IntEnum):
35
+ DEFAULT = 0
36
+ SKIP = 1
37
+ RUN_ONLY = 2
38
+
39
+ class _FrameExecStrategy:
40
+ cur_action: _FrameAction
41
+ recursive_action: _FrameAction
42
+
43
+ @overload
44
+ def __init__(self) -> None: ...
45
+ @overload
46
+ def __init__(
47
+ self, cur_action: _FrameAction, recursive_action: _FrameAction
48
+ ) -> None: ...
49
+
50
+ # This is an object that encapsulates the Python FrameType, and exposes
51
+ # properties Dynamo cares about for a frame.
52
+ class _PyInterpreterFrame:
53
+ f_code: types.CodeType
54
+ f_locals: dict[str, object]
55
+ f_globals: dict[str, object]
56
+ f_builtins: dict[str, object]
57
+ f_lasti: int
58
+ f_lineo: int
59
+ f_back: types.FrameType
60
+ # A tuple containing cell objects captured by this frame.
61
+ closure: tuple[types.CellType]
62
+
63
+ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
64
+
65
+ py_opcode_caches: list[int]
66
+
67
+ def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
68
+ def _load_precompile_entry(
69
+ code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
70
+ ) -> None: ...
71
+ def _reset_precompile_entries(code: types.CodeType) -> None: ...
.venv/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Any, Callable
3
+
4
+ import torch
5
+
6
+ class GlobalStateGuard:
7
+ def check(self) -> bool: ...
8
+ def reason(self) -> str: ...
9
+
10
+ class LeafGuard: ...
11
+ class GuardDebugInfo: ...
12
+
13
+ class GuardManager:
14
+ def check(self, value) -> bool: ...
15
+ def check_verbose(self, value) -> GuardDebugInfo: ...
16
+
17
+ # Accessors
18
+ def globals_dict_manager(
19
+ self,
20
+ f_globals: dict[str, Any],
21
+ source,
22
+ example_value,
23
+ guard_manager_enum,
24
+ ) -> GuardManager: ...
25
+ def framelocals_manager(
26
+ self,
27
+ key: tuple[str, int],
28
+ source,
29
+ example_value,
30
+ guard_manager_enum,
31
+ ) -> GuardManager: ...
32
+ def dict_getitem_manager(
33
+ self,
34
+ key,
35
+ source,
36
+ example_value,
37
+ guard_manager_enum,
38
+ ) -> GuardManager: ...
39
+ def global_weakref_manager(
40
+ self,
41
+ global_name: str,
42
+ source,
43
+ example_value,
44
+ guard_manager_enum,
45
+ ) -> GuardManager: ...
46
+ def type_manager(
47
+ self,
48
+ source,
49
+ example_value,
50
+ guard_manager_enum,
51
+ ) -> GuardManager: ...
52
+ def getattr_manager(
53
+ self,
54
+ attr: str,
55
+ source,
56
+ example_value,
57
+ guard_manager_enum,
58
+ ) -> GuardManager: ...
59
+ def tensor_property_size_manager(
60
+ self,
61
+ idx: int,
62
+ source,
63
+ example_value,
64
+ guard_manager_enum,
65
+ ) -> GuardManager: ...
66
+ def tensor_property_shape_manager(
67
+ self,
68
+ idx: int,
69
+ source,
70
+ example_value,
71
+ guard_manager_enum,
72
+ ) -> GuardManager: ...
73
+ def tensor_property_storage_offset_manager(
74
+ self,
75
+ idx: None,
76
+ source,
77
+ example_value,
78
+ guard_manager_enum,
79
+ ) -> GuardManager: ...
80
+ def indexed_manager(
81
+ self,
82
+ idx: int,
83
+ source,
84
+ example_value,
85
+ guard_manager_enum,
86
+ ) -> GuardManager: ...
87
+ def lambda_manager(
88
+ self,
89
+ python_lambda,
90
+ source,
91
+ example_value,
92
+ guard_manager_enum,
93
+ ) -> GuardManager: ...
94
+
95
+ # Leaf guards
96
+ def add_lambda_guard(self, user_lambda, verbose_code_parts: list[str]) -> None: ...
97
+ def add_id_match_guard(self, id_val, verbose_code_parts: list[str]) -> None: ...
98
+ def add_equals_match_guard(
99
+ self,
100
+ equals_val,
101
+ verbose_code_parts: list[str],
102
+ ) -> None: ...
103
+ def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ...
104
+ def add_torch_function_mode_stack_guard(
105
+ self, initial_stack, verbose_code_parts: list[str]
106
+ ) -> None: ...
107
+ def add_mapping_keys_guard(sef, value, verbose_code_parts: list[str]) -> None: ...
108
+
109
+ class RootGuardManager(GuardManager):
110
+ def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
111
+ def add_epilogue_lambda_guard(
112
+ self,
113
+ guard: LeafGuard,
114
+ verbose_code_parts: list[str],
115
+ ) -> None: ...
116
+ def clone_manager(
117
+ self, clone_filter_fn: Callable[[GuardManager], bool]
118
+ ) -> RootGuardManager: ...
119
+
120
+ class DictGuardManager(GuardManager):
121
+ def get_key_manager(
122
+ self,
123
+ index,
124
+ source,
125
+ example_value,
126
+ guard_manager_enum,
127
+ ) -> GuardManager: ...
128
+ def get_value_manager(
129
+ self,
130
+ index,
131
+ source,
132
+ example_value,
133
+ guard_manager_enum,
134
+ ) -> GuardManager: ...
135
+
136
+ def install_object_aliasing_guard(
137
+ guard_managers: list[GuardManager],
138
+ tensor_names: list[str],
139
+ verbose_code_parts: list[str],
140
+ ): ...
141
+ def install_no_tensor_aliasing_guard(
142
+ guard_managers: list[GuardManager],
143
+ tensor_names: list[str],
144
+ verbose_code_parts: list[str],
145
+ ): ...
146
+ def install_storage_overlapping_guard(
147
+ overlapping_guard_managers: list[GuardManager],
148
+ non_overlapping_guard_managers: list[GuardManager],
149
+ verbose_code_parts: list[str],
150
+ ): ...
151
+ def install_symbolic_shape_guard(
152
+ guard_managers: list[GuardManager],
153
+ nargs_int: int,
154
+ nargs_float: int,
155
+ py_addr: int,
156
+ py_addr_keep_alive: Any,
157
+ verbose_code_parts: list[str],
158
+ ): ...
159
+ def profile_guard_manager(
160
+ guard_manager: GuardManager,
161
+ f_locals: dict[str, Any],
162
+ n_iters: int,
163
+ ) -> float: ...
164
+
165
+ class TensorGuards:
166
+ def __init__(
167
+ self,
168
+ *,
169
+ dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
170
+ dynamic_dims_strides: list[torch.SymInt | None] | None = None,
171
+ ) -> None: ...
172
+ def check(self, *args) -> bool: ...
173
+ def check_verbose(self, *args, tensor_check_names=None) -> bool | str: ...
174
+
175
+ def assert_size_stride(
176
+ item: torch.Tensor,
177
+ size: torch.types._size,
178
+ stride: torch.types._size,
179
+ op_name: str | None = None,
180
+ ): ...
181
+ def assert_alignment(
182
+ item: torch.Tensor,
183
+ alignment: int,
184
+ op_name: str | None = None,
185
+ ): ...
186
+ def check_obj_id(obj: object, expected: int) -> bool: ...
187
+ def check_type_id(obj: object, expected: int) -> bool: ...
188
+ def dict_version(d: dict[Any, Any]) -> int: ...
189
+ def compute_overlapping_tensors(
190
+ tensors: list[torch.Tensor], symbolic: bool = True
191
+ ) -> set[int]: ...
.venv/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/export/pybind.cpp
2
+ class CppExportedProgram: ...
3
+
4
+ def deserialize_exported_program(
5
+ serialized_program: str,
6
+ ) -> CppExportedProgram: ...
7
+ def serialize_exported_program(
8
+ cpp_exported_program: CppExportedProgram,
9
+ ) -> str: ...
.venv/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/export/pt2_archive_constants.h
2
+
3
+ ARCHIVE_ROOT_NAME: str = ...
4
+ ARCHIVE_FORMAT_PATH: str = ...
5
+ ARCHIVE_FORMAT_VALUE: str = ...
6
+ ARCHIVE_VERSION_PATH: str = ...
7
+ ARCHIVE_VERSION_VALUE: str = ...
8
+ MODELS_DIR: str = ...
9
+ MODELS_FILENAME_FORMAT: str = ...
10
+ AOTINDUCTOR_DIR: str = ...
11
+ MTIA_DIR: str = ...
12
+ WEIGHTS_DIR: str = ...
13
+ WEIGHT_FILENAME_PREFIX: str = ...
14
+ CONSTANTS_DIR: str = ...
15
+ TENSOR_CONSTANT_FILENAME_PREFIX: str = ...
16
+ CUSTOM_OBJ_FILENAME_PREFIX: str = ...
17
+ SAMPLE_INPUTS_DIR: str = ...
18
+ SAMPLE_INPUTS_FILENAME_FORMAT: str = ...
19
+ EXTRA_DIR: str = ...
20
+ MODULE_INFO_PATH: str = ...
21
+ XL_MODEL_WEIGHTS_DIR: str = ...
22
+ XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ...
.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (193 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc ADDED
Binary file (32.3 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (195 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-312.pyc ADDED
Binary file (15 kB). View file
 
.venv/lib/python3.12/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-312.pyc ADDED
Binary file (9.85 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (847 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # We are exposing all subpackages to the end-user.
2
+ # Because of possible inter-dependency, we want to avoid
3
+ # the cyclic imports, thus implementing lazy version
4
+ # as per https://peps.python.org/pep-0562/
5
+
6
+ from typing import TYPE_CHECKING as _TYPE_CHECKING
7
+
8
+
9
+ if _TYPE_CHECKING:
10
+ from types import ModuleType
11
+
12
+ from torch.ao.nn import ( # noqa: TC004
13
+ intrinsic as intrinsic,
14
+ qat as qat,
15
+ quantizable as quantizable,
16
+ quantized as quantized,
17
+ sparse as sparse,
18
+ )
19
+
20
+
21
+ __all__ = [
22
+ "intrinsic",
23
+ "qat",
24
+ "quantizable",
25
+ "quantized",
26
+ "sparse",
27
+ ]
28
+
29
+
30
+ def __getattr__(name: str) -> "ModuleType":
31
+ if name in __all__:
32
+ import importlib
33
+
34
+ return importlib.import_module("." + name, __name__)
35
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
.venv/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (879 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ from .modules import * # noqa: F403
4
+ from .modules.fused import _FusedModule # noqa: F403
5
+
6
+
7
+ # # Subpackages
8
+ # from . import qat # noqa: F403
9
+ # from . import quantized # noqa: F403
10
+
11
+ __all__ = [
12
+ "ConvBn1d",
13
+ "ConvBn2d",
14
+ "ConvBn3d",
15
+ "ConvBnReLU1d",
16
+ "ConvBnReLU2d",
17
+ "ConvBnReLU3d",
18
+ "ConvReLU1d",
19
+ "ConvReLU2d",
20
+ "ConvReLU3d",
21
+ "LinearReLU",
22
+ "BNReLU2d",
23
+ "BNReLU3d",
24
+ "LinearBn1d",
25
+ "LinearLeakyReLU",
26
+ "LinearTanh",
27
+ "ConvAdd2d",
28
+ "ConvAddReLU2d",
29
+ ]
30
+
31
+
32
+ # We are exposing all subpackages to the end-user.
33
+ # Because of possible inter-dependency, we want to avoid
34
+ # the cyclic imports, thus implementing lazy version
35
+ # as per https://peps.python.org/pep-0562/
36
+ def __getattr__(name: str) -> types.ModuleType:
37
+ if name in __all__:
38
+ import importlib
39
+
40
+ return importlib.import_module("." + name, __name__)
41
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (993 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .fused import ( # noqa: F401
2
+ _FusedModule,
3
+ BNReLU2d,
4
+ BNReLU3d,
5
+ ConvAdd2d,
6
+ ConvAddReLU2d,
7
+ ConvBn1d,
8
+ ConvBn2d,
9
+ ConvBn3d,
10
+ ConvBnReLU1d,
11
+ ConvBnReLU2d,
12
+ ConvBnReLU3d,
13
+ ConvReLU1d,
14
+ ConvReLU2d,
15
+ ConvReLU3d,
16
+ LinearBn1d,
17
+ LinearLeakyReLU,
18
+ LinearReLU,
19
+ LinearTanh,
20
+ )
21
+
22
+
23
+ __all__ = [
24
+ "ConvBn1d",
25
+ "ConvBn2d",
26
+ "ConvBn3d",
27
+ "ConvBnReLU1d",
28
+ "ConvBnReLU2d",
29
+ "ConvBnReLU3d",
30
+ "ConvReLU1d",
31
+ "ConvReLU2d",
32
+ "ConvReLU3d",
33
+ "LinearReLU",
34
+ "BNReLU2d",
35
+ "BNReLU3d",
36
+ "LinearBn1d",
37
+ "LinearLeakyReLU",
38
+ "LinearTanh",
39
+ "ConvAdd2d",
40
+ "ConvAddReLU2d",
41
+ ]
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (744 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc ADDED
Binary file (13.1 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ from torch.nn import (
4
+ BatchNorm1d,
5
+ BatchNorm2d,
6
+ BatchNorm3d,
7
+ Conv1d,
8
+ Conv2d,
9
+ Conv3d,
10
+ Linear,
11
+ ReLU,
12
+ )
13
+ from torch.nn.utils.parametrize import type_before_parametrizations
14
+
15
+
16
+ __all__ = [
17
+ "ConvReLU1d",
18
+ "ConvReLU2d",
19
+ "ConvReLU3d",
20
+ "LinearReLU",
21
+ "ConvBn1d",
22
+ "ConvBn2d",
23
+ "ConvBnReLU1d",
24
+ "ConvBnReLU2d",
25
+ "ConvBn3d",
26
+ "ConvBnReLU3d",
27
+ "BNReLU2d",
28
+ "BNReLU3d",
29
+ "LinearBn1d",
30
+ "LinearLeakyReLU",
31
+ "LinearTanh",
32
+ "ConvAdd2d",
33
+ "ConvAddReLU2d",
34
+ ]
35
+
36
+
37
+ # Used for identifying intrinsic modules used in quantization
38
+ class _FusedModule(torch.nn.Sequential):
39
+ pass
40
+
41
+
42
+ class ConvReLU1d(_FusedModule):
43
+ r"""This is a sequential container which calls the Conv1d and ReLU modules.
44
+ During quantization this will be replaced with the corresponding fused module."""
45
+
46
+ def __init__(self, conv, relu):
47
+ assert (
48
+ type_before_parametrizations(conv) == Conv1d
49
+ and type_before_parametrizations(relu) == ReLU
50
+ ), (
51
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
52
+ f"{type_before_parametrizations(relu)}"
53
+ )
54
+ super().__init__(conv, relu)
55
+
56
+
57
+ class ConvReLU2d(_FusedModule):
58
+ r"""This is a sequential container which calls the Conv2d and ReLU modules.
59
+ During quantization this will be replaced with the corresponding fused module."""
60
+
61
+ def __init__(self, conv, relu):
62
+ assert (
63
+ type_before_parametrizations(conv) == Conv2d
64
+ and type_before_parametrizations(relu) == ReLU
65
+ ), (
66
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
67
+ f"{type_before_parametrizations(relu)}"
68
+ )
69
+ super().__init__(conv, relu)
70
+
71
+
72
+ class ConvReLU3d(_FusedModule):
73
+ r"""This is a sequential container which calls the Conv3d and ReLU modules.
74
+ During quantization this will be replaced with the corresponding fused module."""
75
+
76
+ def __init__(self, conv, relu):
77
+ assert (
78
+ type_before_parametrizations(conv) == Conv3d
79
+ and type_before_parametrizations(relu) == ReLU
80
+ ), (
81
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
82
+ f"{type_before_parametrizations(relu)}"
83
+ )
84
+ super().__init__(conv, relu)
85
+
86
+
87
+ class LinearReLU(_FusedModule):
88
+ r"""This is a sequential container which calls the Linear and ReLU modules.
89
+ During quantization this will be replaced with the corresponding fused module."""
90
+
91
+ def __init__(self, linear, relu):
92
+ assert (
93
+ type_before_parametrizations(linear) == Linear
94
+ and type_before_parametrizations(relu) == ReLU
95
+ ), (
96
+ f"Incorrect types for input modules{type_before_parametrizations(linear)}"
97
+ f"{type_before_parametrizations(relu)}"
98
+ )
99
+ super().__init__(linear, relu)
100
+
101
+
102
+ class ConvBn1d(_FusedModule):
103
+ r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
104
+ During quantization this will be replaced with the corresponding fused module."""
105
+
106
+ def __init__(self, conv, bn):
107
+ assert (
108
+ type_before_parametrizations(conv) == Conv1d
109
+ and type_before_parametrizations(bn) == BatchNorm1d
110
+ ), (
111
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
112
+ f"{type_before_parametrizations(bn)}"
113
+ )
114
+ super().__init__(conv, bn)
115
+
116
+
117
+ class ConvBn2d(_FusedModule):
118
+ r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
119
+ During quantization this will be replaced with the corresponding fused module."""
120
+
121
+ def __init__(self, conv, bn):
122
+ assert (
123
+ type_before_parametrizations(conv) == Conv2d
124
+ and type_before_parametrizations(bn) == BatchNorm2d
125
+ ), (
126
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
127
+ f"{type_before_parametrizations(bn)}"
128
+ )
129
+ super().__init__(conv, bn)
130
+
131
+
132
+ class ConvBnReLU1d(_FusedModule):
133
+ r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
134
+ During quantization this will be replaced with the corresponding fused module."""
135
+
136
+ def __init__(self, conv, bn, relu):
137
+ assert (
138
+ type_before_parametrizations(conv) == Conv1d
139
+ and type_before_parametrizations(bn) == BatchNorm1d
140
+ and type_before_parametrizations(relu) == ReLU
141
+ ), (
142
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
143
+ f"{type_before_parametrizations(bn)}"
144
+ f"{type_before_parametrizations(relu)}"
145
+ )
146
+ super().__init__(conv, bn, relu)
147
+
148
+
149
+ class ConvBnReLU2d(_FusedModule):
150
+ r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
151
+ During quantization this will be replaced with the corresponding fused module."""
152
+
153
+ def __init__(self, conv, bn, relu):
154
+ assert (
155
+ type_before_parametrizations(conv) == Conv2d
156
+ and type_before_parametrizations(bn) == BatchNorm2d
157
+ and type_before_parametrizations(relu) == ReLU
158
+ ), (
159
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
160
+ f"{type_before_parametrizations(bn)}"
161
+ f"{type_before_parametrizations(relu)}"
162
+ )
163
+ super().__init__(conv, bn, relu)
164
+
165
+
166
+ class ConvBn3d(_FusedModule):
167
+ r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
168
+ During quantization this will be replaced with the corresponding fused module."""
169
+
170
+ def __init__(self, conv, bn):
171
+ assert (
172
+ type_before_parametrizations(conv) == Conv3d
173
+ and type_before_parametrizations(bn) == BatchNorm3d
174
+ ), (
175
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
176
+ f"{type_before_parametrizations(bn)}"
177
+ )
178
+ super().__init__(conv, bn)
179
+
180
+
181
+ class ConvBnReLU3d(_FusedModule):
182
+ r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
183
+ During quantization this will be replaced with the corresponding fused module."""
184
+
185
+ def __init__(self, conv, bn, relu):
186
+ assert (
187
+ type_before_parametrizations(conv) == Conv3d
188
+ and type_before_parametrizations(bn) == BatchNorm3d
189
+ and type_before_parametrizations(relu) == ReLU
190
+ ), (
191
+ f"Incorrect types for input modules{type_before_parametrizations(conv)}"
192
+ f"{type_before_parametrizations(bn)}"
193
+ f"{type_before_parametrizations(relu)}"
194
+ )
195
+ super().__init__(conv, bn, relu)
196
+
197
+
198
+ class BNReLU2d(_FusedModule):
199
+ r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
200
+ During quantization this will be replaced with the corresponding fused module."""
201
+
202
+ def __init__(self, batch_norm, relu):
203
+ assert (
204
+ type_before_parametrizations(batch_norm) == BatchNorm2d
205
+ and type_before_parametrizations(relu) == ReLU
206
+ ), (
207
+ f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}"
208
+ f"{type_before_parametrizations(relu)}"
209
+ )
210
+ super().__init__(batch_norm, relu)
211
+
212
+
213
+ class BNReLU3d(_FusedModule):
214
+ r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
215
+ During quantization this will be replaced with the corresponding fused module."""
216
+
217
+ def __init__(self, batch_norm, relu):
218
+ assert (
219
+ type_before_parametrizations(batch_norm) == BatchNorm3d
220
+ and type_before_parametrizations(relu) == ReLU
221
+ ), (
222
+ f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}"
223
+ f"{type_before_parametrizations(relu)}"
224
+ )
225
+ super().__init__(batch_norm, relu)
226
+
227
+
228
+ class LinearBn1d(_FusedModule):
229
+ r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
230
+ During quantization this will be replaced with the corresponding fused module."""
231
+
232
+ def __init__(self, linear, bn):
233
+ assert (
234
+ type_before_parametrizations(linear) == Linear
235
+ and type_before_parametrizations(bn) == BatchNorm1d
236
+ ), (
237
+ f"Incorrect types for input modules{type_before_parametrizations(linear)}"
238
+ f"{type_before_parametrizations(bn)}"
239
+ )
240
+ super().__init__(linear, bn)
241
+
242
+
243
+ class LinearLeakyReLU(_FusedModule):
244
+ r"""This is a sequential container which calls the Linear and LeakyReLU modules.
245
+ During quantization this will be replaced with the corresponding fused module."""
246
+
247
+ def __init__(self, linear, leaky_relu):
248
+ assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, (
249
+ f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}"
250
+ )
251
+ super().__init__(linear, leaky_relu)
252
+
253
+
254
+ class LinearTanh(_FusedModule):
255
+ r"""This is a sequential container which calls the Linear and Tanh modules.
256
+ During quantization this will be replaced with the corresponding fused module."""
257
+
258
+ def __init__(self, linear, tanh):
259
+ assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, (
260
+ f"Incorrect types for input modules{type(linear)}{type(tanh)}"
261
+ )
262
+ super().__init__(linear, tanh)
263
+
264
+
265
+ class ConvAdd2d(_FusedModule):
266
+ r"""This is a sequential container which calls the Conv2d modules with extra Add.
267
+ During quantization this will be replaced with the corresponding fused module."""
268
+
269
+ def __init__(self, conv, add):
270
+ super().__init__(conv)
271
+ self.add = add
272
+
273
+ def forward(self, x1, x2): # type: ignore[override]
274
+ return self.add(self[0](x1), x2)
275
+
276
+
277
+ class ConvAddReLU2d(_FusedModule):
278
+ r"""This is a sequential container which calls the Conv2d, add, Relu.
279
+ During quantization this will be replaced with the corresponding fused module."""
280
+
281
+ def __init__(self, conv, add, relu):
282
+ super().__init__(conv)
283
+ self.add = add
284
+ self.relu = relu
285
+
286
+ def forward(self, x1, x2): # type: ignore[override]
287
+ return self.relu(self.add(self[0](x1), x2))
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (231 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .conv_fused import (
2
+ ConvBn1d,
3
+ ConvBn2d,
4
+ ConvBn3d,
5
+ ConvBnReLU1d,
6
+ ConvBnReLU2d,
7
+ ConvBnReLU3d,
8
+ ConvReLU1d,
9
+ ConvReLU2d,
10
+ ConvReLU3d,
11
+ freeze_bn_stats,
12
+ update_bn_stats,
13
+ )
14
+ from .linear_fused import LinearBn1d
15
+ from .linear_relu import LinearReLU
16
+
17
+
18
+ __all__ = [
19
+ "LinearReLU",
20
+ "LinearBn1d",
21
+ "ConvReLU1d",
22
+ "ConvReLU2d",
23
+ "ConvReLU3d",
24
+ "ConvBn1d",
25
+ "ConvBn2d",
26
+ "ConvBn3d",
27
+ "ConvBnReLU1d",
28
+ "ConvBnReLU2d",
29
+ "ConvBnReLU3d",
30
+ "update_bn_stats",
31
+ "freeze_bn_stats",
32
+ ]
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (683 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc ADDED
Binary file (34.2 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc ADDED
Binary file (8.64 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (3.24 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py ADDED
@@ -0,0 +1,1064 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ from typing import ClassVar, Optional
4
+
5
+ import torch
6
+ import torch.ao.nn.intrinsic as nni
7
+ import torch.ao.nn.qat as nnqat
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn import init
11
+ from torch.nn.modules.utils import _pair, _single, _triple
12
+ from torch.nn.parameter import Parameter
13
+ from torch.nn.utils import fuse_conv_bn_weights
14
+
15
+
16
+ __all__ = [
17
+ "ConvBn1d",
18
+ "ConvBnReLU1d",
19
+ "ConvReLU1d",
20
+ "ConvBn2d",
21
+ "ConvBnReLU2d",
22
+ "ConvReLU2d",
23
+ "ConvBn3d",
24
+ "ConvBnReLU3d",
25
+ "ConvReLU3d",
26
+ "update_bn_stats",
27
+ "freeze_bn_stats",
28
+ ]
29
+ _BN_CLASS_MAP = {
30
+ 1: nn.BatchNorm1d,
31
+ 2: nn.BatchNorm2d,
32
+ 3: nn.BatchNorm3d,
33
+ }
34
+
35
+
36
+ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
37
+ _version = 2
38
+ _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]]
39
+
40
+ def __init__(
41
+ self,
42
+ # ConvNd args
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size,
46
+ stride,
47
+ padding,
48
+ dilation,
49
+ transposed,
50
+ output_padding,
51
+ groups,
52
+ bias,
53
+ padding_mode,
54
+ # BatchNormNd args
55
+ # num_features: out_channels
56
+ eps=1e-05,
57
+ momentum=0.1,
58
+ # affine: True
59
+ # track_running_stats: True
60
+ # Args for this module
61
+ freeze_bn=False,
62
+ qconfig=None,
63
+ dim=2,
64
+ ):
65
+ nn.modules.conv._ConvNd.__init__(
66
+ self,
67
+ in_channels,
68
+ out_channels,
69
+ kernel_size,
70
+ stride,
71
+ padding,
72
+ dilation,
73
+ transposed,
74
+ output_padding,
75
+ groups,
76
+ False,
77
+ padding_mode,
78
+ )
79
+ assert qconfig, "qconfig must be provided for QAT module"
80
+ self.qconfig = qconfig
81
+ self.freeze_bn = freeze_bn if self.training else True
82
+ self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
83
+ self.weight_fake_quant = self.qconfig.weight()
84
+ if bias:
85
+ self.bias = Parameter(torch.empty(out_channels))
86
+ else:
87
+ self.register_parameter("bias", None)
88
+ self.reset_bn_parameters()
89
+
90
+ # this needs to be called after reset_bn_parameters,
91
+ # as they modify the same state
92
+ if self.training:
93
+ if freeze_bn:
94
+ self.freeze_bn_stats()
95
+ else:
96
+ self.update_bn_stats()
97
+ else:
98
+ self.freeze_bn_stats()
99
+
100
+ self._enable_slow_path_for_better_numerical_stability = False
101
+
102
+ def reset_running_stats(self):
103
+ self.bn.reset_running_stats()
104
+
105
+ def reset_bn_parameters(self):
106
+ self.bn.reset_running_stats()
107
+ init.uniform_(self.bn.weight)
108
+ init.zeros_(self.bn.bias)
109
+ # note: below is actually for conv, not BN
110
+ if self.bias is not None:
111
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
112
+ bound = 1 / math.sqrt(fan_in)
113
+ init.uniform_(self.bias, -bound, bound)
114
+
115
+ def reset_parameters(self):
116
+ super().reset_parameters()
117
+
118
+ def update_bn_stats(self):
119
+ self.freeze_bn = False
120
+ self.bn.training = True
121
+ return self
122
+
123
+ def freeze_bn_stats(self):
124
+ self.freeze_bn = True
125
+ self.bn.training = False
126
+ return self
127
+
128
+ def _forward(self, input):
129
+ if self._enable_slow_path_for_better_numerical_stability:
130
+ return self._forward_slow(input)
131
+ return self._forward_approximate(input)
132
+
133
+ def _forward_approximate(self, input):
134
+ """Approximated method to fuse conv and bn. It requires only one forward pass.
135
+ conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
136
+ """
137
+ assert self.bn.running_var is not None
138
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
139
+ scale_factor = self.bn.weight / running_std
140
+ weight_shape = [1] * len(self.weight.shape)
141
+ weight_shape[0] = -1
142
+ bias_shape = [1] * len(self.weight.shape)
143
+ bias_shape[1] = -1
144
+ scaled_weight = self.weight_fake_quant(
145
+ self.weight * scale_factor.reshape(weight_shape)
146
+ )
147
+ # using zero bias here since the bias for original conv
148
+ # will be added later
149
+ if self.bias is not None:
150
+ zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
151
+ else:
152
+ zero_bias = torch.zeros(
153
+ self.out_channels, device=scaled_weight.device, dtype=input.dtype
154
+ )
155
+ conv = self._conv_forward(input, scaled_weight, zero_bias)
156
+ conv_orig = conv / scale_factor.reshape(bias_shape)
157
+ if self.bias is not None:
158
+ conv_orig = conv_orig + self.bias.reshape(bias_shape)
159
+ conv = self.bn(conv_orig)
160
+ return conv
161
+
162
+ def _forward_slow(self, input):
163
+ """
164
+ A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
165
+ It requires two forward passes but handles the case bn.weight == 0
166
+
167
+ Conv: Y = WX + B_c
168
+ Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c
169
+
170
+ Batch statistics:
171
+ mean_Y = Y.mean()
172
+ = Y0.mean() + B_c
173
+ var_Y = (Y - mean_Y)^2.mean()
174
+ = (Y0 - Y0.mean())^2.mean()
175
+ BN (r: bn.weight, beta: bn.bias):
176
+ Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
177
+ = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta
178
+
179
+ Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
180
+ Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
181
+ = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta
182
+
183
+ Fused Conv BN inference (running_std = sqrt(running_var + eps)):
184
+ Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta
185
+
186
+ QAT with fused conv bn:
187
+ Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
188
+ = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
189
+ Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
190
+ """
191
+
192
+ assert self.bn.running_var is not None
193
+ assert self.bn.running_mean is not None
194
+
195
+ # using zero bias here since the bias for original conv
196
+ # will be added later
197
+ zero_bias = torch.zeros(
198
+ self.out_channels, device=self.weight.device, dtype=input.dtype
199
+ )
200
+
201
+ weight_shape = [1] * len(self.weight.shape)
202
+ weight_shape[0] = -1
203
+ bias_shape = [1] * len(self.weight.shape)
204
+ bias_shape[1] = -1
205
+
206
+ if self.bn.training:
207
+ # needed to compute batch mean/std
208
+ conv_out = self._conv_forward(input, self.weight, zero_bias)
209
+ # update bn statistics
210
+ with torch.no_grad():
211
+ conv_out_bias = (
212
+ conv_out
213
+ if self.bias is None
214
+ else conv_out + self.bias.reshape(bias_shape)
215
+ )
216
+ self.bn(conv_out_bias)
217
+
218
+ # fused conv + bn without bias using bn running statistics
219
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
220
+ scale_factor = self.bn.weight / running_std
221
+ scaled_weight = self.weight_fake_quant(
222
+ self.weight * scale_factor.reshape(weight_shape)
223
+ )
224
+ # fused conv without bias for inference: (r * W / running_std) * X
225
+ conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
226
+
227
+ avg_dims = [0] + list(range(2, len(self.weight.shape)))
228
+ batch_mean = conv_out.mean(avg_dims)
229
+ batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
230
+ avg_dims
231
+ )
232
+ batch_std = torch.sqrt(batch_var + self.bn.eps)
233
+
234
+ # scale to use batch std in training mode
235
+ # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
236
+ unscale_factor = running_std / batch_std
237
+ conv_bn *= unscale_factor.reshape(bias_shape)
238
+
239
+ fused_mean = batch_mean
240
+ fused_std = batch_std
241
+ else:
242
+ # fused conv + bn without bias using bn running statistics
243
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
244
+ scale_factor = self.bn.weight / running_std
245
+ scaled_weight = self.weight_fake_quant(
246
+ self.weight * scale_factor.reshape(weight_shape)
247
+ )
248
+ # fused conv without bias for inference: (r * W / running_std) * X
249
+ conv_bn = self._conv_forward(input, scaled_weight, zero_bias)
250
+
251
+ fused_mean = self.bn.running_mean - (
252
+ self.bias if self.bias is not None else 0
253
+ )
254
+ fused_std = running_std
255
+
256
+ # fused bias = beta - r * mean / std
257
+ fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
258
+ conv_bn += fused_bias.reshape(bias_shape)
259
+
260
+ # HACK to let conv bias participate in loss to avoid DDP error (parameters
261
+ # were not used in producing loss)
262
+ if self.bias is not None:
263
+ conv_bn += (self.bias - self.bias).reshape(bias_shape)
264
+
265
+ return conv_bn
266
+
267
+ def extra_repr(self):
268
+ # TODO(jerryzh): extend
269
+ return super().extra_repr()
270
+
271
+ def forward(self, input):
272
+ return self._forward(input)
273
+
274
+ def train(self, mode=True):
275
+ """
276
+ Batchnorm's training behavior is using the self.training flag. Prevent
277
+ changing it if BN is frozen. This makes sure that calling `model.train()`
278
+ on a model with a frozen BN will behave properly.
279
+ """
280
+ self.training = mode
281
+ if not self.freeze_bn:
282
+ for module in self.children():
283
+ module.train(mode)
284
+ return self
285
+
286
+ # ===== Serialization version history =====
287
+ #
288
+ # Version 1/None
289
+ # self
290
+ # |--- weight : Tensor
291
+ # |--- bias : Tensor
292
+ # |--- gamma : Tensor
293
+ # |--- beta : Tensor
294
+ # |--- running_mean : Tensor
295
+ # |--- running_var : Tensor
296
+ # |--- num_batches_tracked : Tensor
297
+ #
298
+ # Version 2
299
+ # self
300
+ # |--- weight : Tensor
301
+ # |--- bias : Tensor
302
+ # |--- bn : Module
303
+ # |--- weight : Tensor (moved from v1.self.gamma)
304
+ # |--- bias : Tensor (moved from v1.self.beta)
305
+ # |--- running_mean : Tensor (moved from v1.self.running_mean)
306
+ # |--- running_var : Tensor (moved from v1.self.running_var)
307
+ # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
308
+ def _load_from_state_dict(
309
+ self,
310
+ state_dict,
311
+ prefix,
312
+ local_metadata,
313
+ strict,
314
+ missing_keys,
315
+ unexpected_keys,
316
+ error_msgs,
317
+ ):
318
+ version = local_metadata.get("version", None)
319
+ if version is None or version == 1:
320
+ # BN related parameters and buffers were moved into the BN module for v2
321
+ v2_to_v1_names = {
322
+ "bn.weight": "gamma",
323
+ "bn.bias": "beta",
324
+ "bn.running_mean": "running_mean",
325
+ "bn.running_var": "running_var",
326
+ "bn.num_batches_tracked": "num_batches_tracked",
327
+ }
328
+ for v2_name, v1_name in v2_to_v1_names.items():
329
+ if prefix + v1_name in state_dict:
330
+ state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
331
+ state_dict.pop(prefix + v1_name)
332
+ elif prefix + v2_name in state_dict:
333
+ # there was a brief period where forward compatibility
334
+ # for this module was broken (between
335
+ # https://github.com/pytorch/pytorch/pull/38478
336
+ # and https://github.com/pytorch/pytorch/pull/38820)
337
+ # and modules emitted the v2 state_dict format while
338
+ # specifying that version == 1. This patches the forward
339
+ # compatibility issue by allowing the v2 style entries to
340
+ # be used.
341
+ pass
342
+ elif strict:
343
+ missing_keys.append(prefix + v2_name)
344
+
345
+ super()._load_from_state_dict(
346
+ state_dict,
347
+ prefix,
348
+ local_metadata,
349
+ strict,
350
+ missing_keys,
351
+ unexpected_keys,
352
+ error_msgs,
353
+ )
354
+
355
+ @classmethod
356
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
357
+ r"""Create a qat module from a float module or qparams_dict
358
+
359
+ Args: `mod` a float module, either produced by torch.ao.quantization utilities
360
+ or directly from user
361
+ """
362
+ # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
363
+ # has no __name__ (code is fine though)
364
+ assert type(mod) == cls._FLOAT_MODULE, (
365
+ "qat."
366
+ + cls.__name__
367
+ + ".from_float only works for "
368
+ + cls._FLOAT_MODULE.__name__
369
+ )
370
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
371
+ assert mod.qconfig, "Input float module must have a valid qconfig"
372
+ qconfig = mod.qconfig
373
+ conv, bn = mod[0], mod[1] # type: ignore[index]
374
+ qat_convbn = cls(
375
+ conv.in_channels,
376
+ conv.out_channels,
377
+ conv.kernel_size,
378
+ conv.stride,
379
+ conv.padding,
380
+ conv.dilation,
381
+ conv.groups,
382
+ conv.bias is not None,
383
+ conv.padding_mode,
384
+ bn.eps,
385
+ bn.momentum,
386
+ False,
387
+ qconfig,
388
+ )
389
+ qat_convbn.weight = conv.weight
390
+ qat_convbn.bias = conv.bias
391
+ qat_convbn.bn.weight = bn.weight
392
+ qat_convbn.bn.bias = bn.bias
393
+ qat_convbn.bn.running_mean = bn.running_mean
394
+ qat_convbn.bn.running_var = bn.running_var
395
+ # mypy error: Cannot determine type of 'num_batches_tracked'
396
+ qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
397
+ return qat_convbn
398
+
399
+ def to_float(self):
400
+ cls = type(self)
401
+ conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined]
402
+ self.in_channels,
403
+ self.out_channels,
404
+ self.kernel_size,
405
+ self.stride,
406
+ self.padding,
407
+ self.dilation,
408
+ self.groups,
409
+ self.bias is not None,
410
+ self.padding_mode,
411
+ )
412
+ conv.weight = torch.nn.Parameter(self.weight.detach())
413
+ if self.bias is not None:
414
+ conv.bias = torch.nn.Parameter(self.bias.detach())
415
+
416
+ if cls._FLOAT_BN_MODULE: # type: ignore[attr-defined]
417
+ # fuse bn into conv
418
+ assert self.bn.running_var is not None and self.bn.running_mean is not None
419
+ conv.weight, conv.bias = fuse_conv_bn_weights(
420
+ conv.weight,
421
+ conv.bias,
422
+ self.bn.running_mean,
423
+ self.bn.running_var,
424
+ self.bn.eps,
425
+ self.bn.weight,
426
+ self.bn.bias,
427
+ )
428
+
429
+ if cls._FLOAT_RELU_MODULE: # type: ignore[attr-defined]
430
+ modules = []
431
+ modules.append(conv)
432
+ relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
433
+ modules.append(relu)
434
+ conv_relu = cls._FUSED_FLOAT_MODULE(*modules) # type: ignore[attr-defined]
435
+ conv_relu.train(self.training)
436
+ return conv_relu
437
+ else:
438
+ conv.train(self.training)
439
+ return conv
440
+
441
+
442
+ class ConvBn1d(_ConvBnNd, nn.Conv1d):
443
+ r"""
444
+ A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
445
+ attached with FakeQuantize modules for weight,
446
+ used in quantization aware training.
447
+
448
+ We combined the interface of :class:`torch.nn.Conv1d` and
449
+ :class:`torch.nn.BatchNorm1d`.
450
+
451
+ Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
452
+ to default.
453
+
454
+ Attributes:
455
+ freeze_bn:
456
+ weight_fake_quant: fake quant module for weight
457
+
458
+ """
459
+
460
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d
461
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None
462
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment]
463
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
464
+
465
+ def __init__(
466
+ self,
467
+ # Conv1d args
468
+ in_channels,
469
+ out_channels,
470
+ kernel_size,
471
+ stride=1,
472
+ padding=0,
473
+ dilation=1,
474
+ groups=1,
475
+ bias=None,
476
+ padding_mode="zeros",
477
+ # BatchNorm1d args
478
+ # num_features: out_channels
479
+ eps=1e-05,
480
+ momentum=0.1,
481
+ # affine: True
482
+ # track_running_stats: True
483
+ # Args for this module
484
+ freeze_bn=False,
485
+ qconfig=None,
486
+ ):
487
+ kernel_size = _single(kernel_size)
488
+ stride = _single(stride)
489
+ padding = _single(padding)
490
+ dilation = _single(dilation)
491
+ _ConvBnNd.__init__(
492
+ self,
493
+ in_channels,
494
+ out_channels,
495
+ kernel_size,
496
+ stride,
497
+ padding,
498
+ dilation,
499
+ False,
500
+ _single(0),
501
+ groups,
502
+ bias,
503
+ padding_mode,
504
+ eps,
505
+ momentum,
506
+ freeze_bn,
507
+ qconfig,
508
+ dim=1,
509
+ )
510
+
511
+
512
+ class ConvBnReLU1d(ConvBn1d):
513
+ r"""
514
+ A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
515
+ attached with FakeQuantize modules for weight,
516
+ used in quantization aware training.
517
+
518
+ We combined the interface of :class:`torch.nn.Conv1d` and
519
+ :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
520
+
521
+ Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
522
+ default.
523
+
524
+ Attributes:
525
+ weight_fake_quant: fake quant module for weight
526
+
527
+ """
528
+
529
+ # base class defines _FLOAT_MODULE as "ConvBn1d"
530
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d
531
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
532
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d
533
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU
534
+ # module class after fusing bn into conv
535
+ _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nn.Module]]] = nni.ConvReLU1d
536
+
537
+ def __init__(
538
+ self,
539
+ # Conv1d args
540
+ in_channels,
541
+ out_channels,
542
+ kernel_size,
543
+ stride=1,
544
+ padding=0,
545
+ dilation=1,
546
+ groups=1,
547
+ bias=None,
548
+ padding_mode="zeros",
549
+ # BatchNorm1d args
550
+ # num_features: out_channels
551
+ eps=1e-05,
552
+ momentum=0.1,
553
+ # affine: True
554
+ # track_running_stats: True
555
+ # Args for this module
556
+ freeze_bn=False,
557
+ qconfig=None,
558
+ ):
559
+ super().__init__(
560
+ in_channels,
561
+ out_channels,
562
+ kernel_size,
563
+ stride,
564
+ padding,
565
+ dilation,
566
+ groups,
567
+ bias,
568
+ padding_mode,
569
+ eps,
570
+ momentum,
571
+ freeze_bn,
572
+ qconfig,
573
+ )
574
+
575
+ def forward(self, input):
576
+ return F.relu(self._forward(input))
577
+
578
+ @classmethod
579
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
580
+ return super().from_float(mod, use_precomputed_fake_quant)
581
+
582
+
583
+ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
584
+ r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
585
+ FakeQuantize modules for weight for
586
+ quantization aware training.
587
+
588
+ We combined the interface of :class:`~torch.nn.Conv1d` and
589
+ :class:`~torch.nn.BatchNorm1d`.
590
+
591
+ Attributes:
592
+ weight_fake_quant: fake quant module for weight
593
+
594
+ """
595
+
596
+ _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment]
597
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
598
+ _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None
599
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU
600
+
601
+ def __init__(
602
+ self,
603
+ in_channels,
604
+ out_channels,
605
+ kernel_size,
606
+ stride=1,
607
+ padding=0,
608
+ dilation=1,
609
+ groups=1,
610
+ bias=True,
611
+ padding_mode="zeros",
612
+ qconfig=None,
613
+ ):
614
+ super().__init__(
615
+ in_channels,
616
+ out_channels,
617
+ kernel_size,
618
+ stride=stride,
619
+ padding=padding,
620
+ dilation=dilation,
621
+ groups=groups,
622
+ bias=bias,
623
+ padding_mode=padding_mode,
624
+ qconfig=qconfig,
625
+ )
626
+ assert qconfig, "qconfig must be provided for QAT module"
627
+ self.qconfig = qconfig
628
+ self.weight_fake_quant = self.qconfig.weight()
629
+
630
+ def forward(self, input):
631
+ return F.relu(
632
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
633
+ )
634
+
635
+ @classmethod
636
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
637
+ return super().from_float(
638
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
639
+ )
640
+
641
+
642
+ class ConvBn2d(_ConvBnNd, nn.Conv2d):
643
+ r"""
644
+ A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
645
+ attached with FakeQuantize modules for weight,
646
+ used in quantization aware training.
647
+
648
+ We combined the interface of :class:`torch.nn.Conv2d` and
649
+ :class:`torch.nn.BatchNorm2d`.
650
+
651
+ Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
652
+ to default.
653
+
654
+ Attributes:
655
+ freeze_bn:
656
+ weight_fake_quant: fake quant module for weight
657
+
658
+ """
659
+
660
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment]
661
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
662
+ _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm2d
663
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None
664
+
665
+ def __init__(
666
+ self,
667
+ # ConvNd args
668
+ in_channels,
669
+ out_channels,
670
+ kernel_size,
671
+ stride=1,
672
+ padding=0,
673
+ dilation=1,
674
+ groups=1,
675
+ bias=None,
676
+ padding_mode="zeros",
677
+ # BatchNorm2d args
678
+ # num_features: out_channels
679
+ eps=1e-05,
680
+ momentum=0.1,
681
+ # affine: True
682
+ # track_running_stats: True
683
+ # Args for this module
684
+ freeze_bn=False,
685
+ qconfig=None,
686
+ ):
687
+ kernel_size = _pair(kernel_size)
688
+ stride = _pair(stride)
689
+ padding = _pair(padding)
690
+ dilation = _pair(dilation)
691
+ _ConvBnNd.__init__(
692
+ self,
693
+ in_channels,
694
+ out_channels,
695
+ kernel_size,
696
+ stride,
697
+ padding,
698
+ dilation,
699
+ False,
700
+ _pair(0),
701
+ groups,
702
+ bias,
703
+ padding_mode,
704
+ eps,
705
+ momentum,
706
+ freeze_bn,
707
+ qconfig,
708
+ dim=2,
709
+ )
710
+
711
+
712
+ class ConvBnReLU2d(ConvBn2d):
713
+ r"""
714
+ A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
715
+ attached with FakeQuantize modules for weight,
716
+ used in quantization aware training.
717
+
718
+ We combined the interface of :class:`torch.nn.Conv2d` and
719
+ :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
720
+
721
+ Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
722
+ default.
723
+
724
+ Attributes:
725
+ weight_fake_quant: fake quant module for weight
726
+
727
+ """
728
+
729
+ # base class defines _FLOAT_MODULE as "ConvBn2d"
730
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment]
731
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
732
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm2d]] = nn.BatchNorm2d
733
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU
734
+ # module class after fusing bn into conv
735
+ _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU2d]]] = nni.ConvReLU2d
736
+
737
+ def __init__(
738
+ self,
739
+ # Conv2d args
740
+ in_channels,
741
+ out_channels,
742
+ kernel_size,
743
+ stride=1,
744
+ padding=0,
745
+ dilation=1,
746
+ groups=1,
747
+ bias=None,
748
+ padding_mode="zeros",
749
+ # BatchNorm2d args
750
+ # num_features: out_channels
751
+ eps=1e-05,
752
+ momentum=0.1,
753
+ # affine: True
754
+ # track_running_stats: True
755
+ # Args for this module
756
+ freeze_bn=False,
757
+ qconfig=None,
758
+ ):
759
+ super().__init__(
760
+ in_channels,
761
+ out_channels,
762
+ kernel_size,
763
+ stride,
764
+ padding,
765
+ dilation,
766
+ groups,
767
+ bias,
768
+ padding_mode,
769
+ eps,
770
+ momentum,
771
+ freeze_bn,
772
+ qconfig,
773
+ )
774
+
775
+ def forward(self, input):
776
+ return F.relu(self._forward(input))
777
+
778
+ @classmethod
779
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
780
+ return super().from_float(mod, use_precomputed_fake_quant)
781
+
782
+
783
+ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
784
+ r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
785
+ FakeQuantize modules for weight for
786
+ quantization aware training.
787
+
788
+ We combined the interface of :class:`~torch.nn.Conv2d` and
789
+ :class:`~torch.nn.BatchNorm2d`.
790
+
791
+ Attributes:
792
+ weight_fake_quant: fake quant module for weight
793
+
794
+ """
795
+
796
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment]
797
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
798
+ _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None
799
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU
800
+
801
+ def __init__(
802
+ self,
803
+ in_channels,
804
+ out_channels,
805
+ kernel_size,
806
+ stride=1,
807
+ padding=0,
808
+ dilation=1,
809
+ groups=1,
810
+ bias=True,
811
+ padding_mode="zeros",
812
+ qconfig=None,
813
+ ):
814
+ super().__init__(
815
+ in_channels,
816
+ out_channels,
817
+ kernel_size,
818
+ stride=stride,
819
+ padding=padding,
820
+ dilation=dilation,
821
+ groups=groups,
822
+ bias=bias,
823
+ padding_mode=padding_mode,
824
+ qconfig=qconfig,
825
+ )
826
+ assert qconfig, "qconfig must be provided for QAT module"
827
+ self.qconfig = qconfig
828
+ self.weight_fake_quant = self.qconfig.weight()
829
+
830
+ def forward(self, input):
831
+ return F.relu(
832
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
833
+ )
834
+
835
+ @classmethod
836
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
837
+ return super().from_float(
838
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
839
+ )
840
+
841
+
842
+ class ConvBn3d(_ConvBnNd, nn.Conv3d):
843
+ r"""
844
+ A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
845
+ attached with FakeQuantize modules for weight,
846
+ used in quantization aware training.
847
+
848
+ We combined the interface of :class:`torch.nn.Conv3d` and
849
+ :class:`torch.nn.BatchNorm3d`.
850
+
851
+ Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
852
+ to default.
853
+
854
+ Attributes:
855
+ freeze_bn:
856
+ weight_fake_quant: fake quant module for weight
857
+
858
+ """
859
+
860
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment]
861
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
862
+ _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.BatchNorm3d
863
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = None
864
+
865
+ def __init__(
866
+ self,
867
+ # ConvNd args
868
+ in_channels,
869
+ out_channels,
870
+ kernel_size,
871
+ stride=1,
872
+ padding=0,
873
+ dilation=1,
874
+ groups=1,
875
+ bias=None,
876
+ padding_mode="zeros",
877
+ # BatchNorm3d args
878
+ # num_features: out_channels
879
+ eps=1e-05,
880
+ momentum=0.1,
881
+ # affine: True
882
+ # track_running_stats: True
883
+ # Args for this module
884
+ freeze_bn=False,
885
+ qconfig=None,
886
+ ):
887
+ kernel_size = _triple(kernel_size)
888
+ stride = _triple(stride)
889
+ padding = _triple(padding)
890
+ dilation = _triple(dilation)
891
+ _ConvBnNd.__init__(
892
+ self,
893
+ in_channels,
894
+ out_channels,
895
+ kernel_size,
896
+ stride,
897
+ padding,
898
+ dilation,
899
+ False,
900
+ _triple(0),
901
+ groups,
902
+ bias,
903
+ padding_mode,
904
+ eps,
905
+ momentum,
906
+ freeze_bn,
907
+ qconfig,
908
+ dim=3,
909
+ )
910
+
911
+
912
+ class ConvBnReLU3d(ConvBn3d):
913
+ r"""
914
+ A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
915
+ attached with FakeQuantize modules for weight,
916
+ used in quantization aware training.
917
+
918
+ We combined the interface of :class:`torch.nn.Conv3d` and
919
+ :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
920
+
921
+ Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
922
+ default.
923
+
924
+ Attributes:
925
+ weight_fake_quant: fake quant module for weight
926
+
927
+ """
928
+
929
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment]
930
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
931
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d
932
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.ReLU]]] = nn.ReLU
933
+ # module class after fusing bn into conv
934
+ _FUSED_FLOAT_MODULE: ClassVar[Optional[type[nni.ConvReLU3d]]] = nni.ConvReLU3d
935
+
936
+ def __init__(
937
+ self,
938
+ # Conv3d args
939
+ in_channels,
940
+ out_channels,
941
+ kernel_size,
942
+ stride=1,
943
+ padding=0,
944
+ dilation=1,
945
+ groups=1,
946
+ bias=None,
947
+ padding_mode="zeros",
948
+ # BatchNorm3d args
949
+ # num_features: out_channels
950
+ eps=1e-05,
951
+ momentum=0.1,
952
+ # affine: True
953
+ # track_running_stats: True
954
+ # Args for this module
955
+ freeze_bn=False,
956
+ qconfig=None,
957
+ ):
958
+ super().__init__(
959
+ in_channels,
960
+ out_channels,
961
+ kernel_size,
962
+ stride,
963
+ padding,
964
+ dilation,
965
+ groups,
966
+ bias,
967
+ padding_mode,
968
+ eps,
969
+ momentum,
970
+ freeze_bn,
971
+ qconfig,
972
+ )
973
+
974
+ def forward(self, input):
975
+ return F.relu(ConvBn3d._forward(self, input))
976
+
977
+ @classmethod
978
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
979
+ return super().from_float(
980
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
981
+ )
982
+
983
+
984
+ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
985
+ r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
986
+ FakeQuantize modules for weight for
987
+ quantization aware training.
988
+
989
+ We combined the interface of :class:`~torch.nn.Conv3d` and
990
+ :class:`~torch.nn.BatchNorm3d`.
991
+
992
+ Attributes:
993
+ weight_fake_quant: fake quant module for weight
994
+
995
+ """
996
+
997
+ _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment]
998
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
999
+ _FLOAT_BN_MODULE: ClassVar[Optional[type[nn.Module]]] = None
1000
+ _FLOAT_RELU_MODULE: ClassVar[Optional[type[nn.Module]]] = nn.ReLU
1001
+
1002
+ def __init__(
1003
+ self,
1004
+ in_channels,
1005
+ out_channels,
1006
+ kernel_size,
1007
+ stride=1,
1008
+ padding=0,
1009
+ dilation=1,
1010
+ groups=1,
1011
+ bias=True,
1012
+ padding_mode="zeros",
1013
+ qconfig=None,
1014
+ ):
1015
+ super().__init__(
1016
+ in_channels,
1017
+ out_channels,
1018
+ kernel_size,
1019
+ stride=stride,
1020
+ padding=padding,
1021
+ dilation=dilation,
1022
+ groups=groups,
1023
+ bias=bias,
1024
+ padding_mode=padding_mode,
1025
+ qconfig=qconfig,
1026
+ )
1027
+ assert qconfig, "qconfig must be provided for QAT module"
1028
+ self.qconfig = qconfig
1029
+ self.weight_fake_quant = self.qconfig.weight()
1030
+
1031
+ def forward(self, input):
1032
+ return F.relu(
1033
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
1034
+ )
1035
+
1036
+ @classmethod
1037
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
1038
+ return super().from_float(
1039
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
1040
+ )
1041
+
1042
+
1043
+ def update_bn_stats(mod):
1044
+ if type(mod) in {
1045
+ ConvBnReLU1d,
1046
+ ConvBnReLU2d,
1047
+ ConvBnReLU3d,
1048
+ ConvBn1d,
1049
+ ConvBn2d,
1050
+ ConvBn3d,
1051
+ }:
1052
+ mod.update_bn_stats()
1053
+
1054
+
1055
+ def freeze_bn_stats(mod):
1056
+ if type(mod) in {
1057
+ ConvBnReLU1d,
1058
+ ConvBnReLU2d,
1059
+ ConvBnReLU3d,
1060
+ ConvBn1d,
1061
+ ConvBn2d,
1062
+ ConvBn3d,
1063
+ }:
1064
+ mod.freeze_bn_stats()
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.utils.fusion import fuse_linear_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "LinearBn1d",
13
+ ]
14
+
15
+
16
class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
    r"""
    A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
    with FakeQuantize modules for weight, used in quantization aware training.

    We combined the interface of :class:`torch.nn.Linear` and
    :class:`torch.nn.BatchNorm1d`.

    Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
    to default.

    Attributes:
        freeze_bn: whether BatchNorm running statistics are frozen; always
            True when the module is constructed in eval mode
        weight_fake_quant: fake quant module for weight

    """

    def __init__(
        self,
        # Linear args
        in_features,
        out_features,
        bias=True,
        # BatchNorm1d args
        # num_features: out_features
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
    ):
        nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # BN stats are always frozen when the module is created in eval mode.
        self.freeze_bn = freeze_bn if self.training else True
        self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

    def reset_running_stats(self):
        """Reset the BN running mean/var (and batch counter)."""
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        """Reset BN running stats and re-initialize BN affine parameters."""
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)

    def reset_parameters(self):
        # Only the Linear weight/bias are reset here; BN has its own reset.
        super().reset_parameters()

    def update_bn_stats(self):
        """Unfreeze BN so its running statistics update during training."""
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        """Freeze BN so its running statistics are not updated."""
        self.freeze_bn = True
        self.bn.training = False
        return self

    def forward(self, input):
        assert self.bn.running_var is not None

        # Scale the linear weights by BN's running statistics to reduce
        # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
        # for motivation.
        #
        # Instead of
        #
        #   x1 = F.linear(x0, fq(w), b)
        #   x2 = self.bn(x1)
        #
        # We have
        #
        #   # scale the weight by previous batch's running statistics
        #   scale_factor = bn.w / bn.running_std_from_prev_batch
        #   # do the linear transformation without bias
        #   x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
        #   # reverse the scaling and add original bias
        #   x1_orig = x1_scaled / scale_factor + b
        #   x2 = self.bn(x1_orig)

        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # weight is (out_features, in_features); scale along dim 0.
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        # bias broadcasts along the feature (last) dimension of the output.
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias)
        else:
            zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
        linear_out = F.linear(input, scaled_weight, zero_bias)
        # Undo the weight scaling, then add the original bias back.
        linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
        bn_out = self.bn(linear_out_orig)
        return bn_out

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float fused module (nni.LinearBn1d), either produced by
            torch.ao.quantization utilities or directly from user
        """
        assert type(mod) == nni.LinearBn1d, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + nni.LinearBn1d.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid config"
        qconfig = mod.qconfig
        # mod is a fused container: index 0 is the Linear, index 1 the BN.
        linear, bn = mod[0], mod[1]
        qat_linearbn = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_linearbn.weight = linear.weight  # type: ignore[assignment]
        qat_linearbn.bias = linear.bias  # type: ignore[assignment]
        qat_linearbn.bn.weight = bn.weight  # type: ignore[assignment]
        qat_linearbn.bn.bias = bn.bias  # type: ignore[assignment]
        qat_linearbn.bn.running_mean = bn.running_mean  # type: ignore[assignment]
        qat_linearbn.bn.running_var = bn.running_var  # type: ignore[assignment]
        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[assignment]
        return qat_linearbn

    def to_float(self):
        """Fold the BN statistics into a plain float ``nn.Linear`` and return it."""
        linear = torch.nn.Linear(self.in_features, self.out_features)
        assert self.bn.running_var is not None and self.bn.running_mean is not None
        linear.weight, linear.bias = fuse_linear_bn_weights(
            self.weight,
            self.bias,
            self.bn.running_mean,
            self.bn.running_var,
            self.bn.eps,
            self.bn.weight,
            self.bn.bias,
        )
        return linear
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.qat as nnqat
5
+ import torch.nn.functional as F
6
+
7
+
8
class LinearReLU(nnqat.Linear, nni._FusedModule):
    r"""Fused Linear + ReLU module for quantization aware training.

    Adopts the same interface as :class:`torch.nn.Linear`, with a
    FakeQuantize module (initialized to default) applied to the weight in
    the forward pass; see `torch.ao.nn.intrinsic.LinearReLU` for the float
    counterpart.

    Attributes:
        weight: fake quant module for weight

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.qat.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, qconfig=None):
        super().__init__(in_features, out_features, bias, qconfig)

    def forward(self, input):
        # Fake-quantize the weight, apply the affine transform, then ReLU.
        fq_weight = self.weight_fake_quant(self.weight)
        return F.relu(F.linear(input, fq_weight, self.bias))

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(mod, use_precomputed_fake_quant)

    def to_float(self):
        # Rebuild a float Linear from detached copies of the current
        # parameters and pair it with a fresh ReLU in the fused container.
        float_linear = torch.nn.Linear(
            self.in_features, self.out_features, self.bias is not None
        )
        float_linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            float_linear.bias = torch.nn.Parameter(self.bias.detach())
        return torch.ao.nn.intrinsic.LinearReLU(float_linear, torch.nn.ReLU())
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modules import * # noqa: F403
2
+
3
+
4
+ __all__ = [
5
+ "BNReLU2d",
6
+ "BNReLU3d",
7
+ "ConvReLU1d",
8
+ "ConvReLU2d",
9
+ "ConvReLU3d",
10
+ "LinearReLU",
11
+ "LinearLeakyReLU",
12
+ "LinearTanh",
13
+ "ConvAdd2d",
14
+ "ConvAddReLU2d",
15
+ ]
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (386 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (245 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .linear_relu import LinearReLU
2
+
3
+
4
+ __all__ = [
5
+ "LinearReLU",
6
+ ]
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (305 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (3.37 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.quantized.dynamic as nnqd
5
+
6
+
7
+ __all__ = ["LinearReLU"]
8
+
9
+
10
class LinearReLU(nnqd.Linear):
    r"""Dynamically quantized fused Linear + ReLU module.

    Supports both FP16 and INT8 quantization; adopts the same interface as
    :class:`torch.ao.nn.quantized.dynamic.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.dynamic.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dispatch on the dtype the weights were packed with.
        packed_dtype = self._packed_params.dtype
        if packed_dtype == torch.qint8:
            # TODO check if we should set reduce_rage = True by default here
            out = torch.ops.quantized.linear_relu_dynamic(
                x, self._packed_params._packed_params, reduce_range=True
            )
        elif packed_dtype == torch.float16:
            out = torch.ops.quantized.linear_relu_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
        return out.to(x.dtype)

    def _get_name(self):
        return "DynamicQuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qlinear_relu):  # type: ignore[override]
        # The reference module is a fused container; index 0 is the Linear.
        return super().from_reference(ref_qlinear_relu[0])
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .bn_relu import BNReLU2d, BNReLU3d
2
+ from .conv_add import ConvAdd2d, ConvAddReLU2d
3
+ from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
4
+ from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh
5
+
6
+
7
+ __all__ = [
8
+ "LinearReLU",
9
+ "ConvReLU1d",
10
+ "ConvReLU2d",
11
+ "ConvReLU3d",
12
+ "BNReLU2d",
13
+ "BNReLU3d",
14
+ "LinearLeakyReLU",
15
+ "LinearTanh",
16
+ "ConvAdd2d",
17
+ "ConvAddReLU2d",
18
+ ]
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (608 Bytes). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc ADDED
Binary file (4.61 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc ADDED
Binary file (5.28 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc ADDED
Binary file (10.6 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (9.9 kB). View file
 
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+
8
+
9
+ __all__ = ["BNReLU2d", "BNReLU3d"]
10
+
11
+
12
class BNReLU2d(nnq.BatchNorm2d):
    r"""Quantized fused module combining BatchNorm2d and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) == 4:
            # Single fused quantized op: batch norm followed by ReLU.
            return torch.ops.quantized.batch_norm2d_relu(
                input,
                self.weight,
                self.bias,
                self.running_mean,
                self.running_var,
                self.eps,
                self.scale,
                self.zero_point,
            )
        raise ValueError("Input shape must be `(N, C, H, W)`!")

    def _get_name(self):
        return "QuantizedBNReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        # TODO: Add qat support for BNReLU2d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        # bn_relu is a fused container; index 0 holds the BatchNorm module.
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)
59
+
60
+
61
class BNReLU3d(nnq.BatchNorm3d):
    r"""Quantized fused module combining BatchNorm3d and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm3d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) == 5:
            # Single fused quantized op: batch norm followed by ReLU.
            return torch.ops.quantized.batch_norm3d_relu(
                input,
                self.weight,
                self.bias,
                self.running_mean,
                self.running_var,
                self.eps,
                self.scale,
                self.zero_point,
            )
        raise ValueError("Input shape must be `(N, C, D, H, W)`!")

    def _get_name(self):
        return "QuantizedBNReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        # TODO: Add qat support for BNReLU3d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        # bn_relu is a fused container; index 0 holds the BatchNorm module.
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic
4
+ import torch.ao.nn.intrinsic.qat
5
+ import torch.ao.nn.quantized as nnq
6
+ import torch.nn.functional as F
7
+
8
+
9
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
10
+
11
+
12
class ConvAdd2d(nnq.Conv2d):
    r"""Quantized fused module combining Conv2d and an elementwise Add.

    Adopts the same interface as :class:`torch.ao.nn.quantized.Conv2d`;
    ``forward`` additionally takes the tensor to add.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):  # type: ignore[override]
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied eagerly via F.pad, since the
            # fused quantized op only supports zero padding internally.
            pad_spec = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_spec, mode=self.padding_mode)
        return torch.ops.quantized.conv2d_add(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAdd2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        # ref_qconv is a fused container; index 0 holds the conv module.
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
79
+
80
+
81
class ConvAddReLU2d(nnq.Conv2d):
    r"""Quantized fused module combining Conv2d, an elementwise Add, and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.Conv2d`;
    ``forward`` additionally takes the tensor to add.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input, extra_input):  # type: ignore[override]
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied eagerly via F.pad, since the
            # fused quantized op only supports zero padding internally.
            pad_spec = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_spec, mode=self.padding_mode)
        return torch.ops.quantized.conv2d_add_relu(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvAddReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        # ref_qconv is a fused container; index 0 holds the conv module.
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import fuse_conv_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "ConvReLU1d",
13
+ "ConvReLU2d",
14
+ "ConvReLU3d",
15
+ ]
16
+
17
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
18
+
19
+
20
+ # TODO: factor out the common parts to ConvNd
21
class ConvReLU1d(nnq.Conv1d):
    r"""Quantized fused module combining Conv1d and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != "zeros":
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            pad_spec = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, pad_spec, mode=self.padding_mode)
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            bn = mod.bn
            assert bn.running_mean is not None and bn.running_var is not None
            # Fold the BN statistics into the conv parameters before packing.
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, (
            "BatchNorm1d should be fused into Conv1d before converting to reference module"
        )
        # ref_qconv is a fused container; index 0 holds the conv module.
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
101
+
102
+
103
class ConvReLU2d(nnq.Conv2d):
    r"""Quantized fused module combining Conv2d and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied eagerly via F.pad, since the
            # fused quantized op only supports zero padding internally.
            pad_spec = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_spec, mode=self.padding_mode)
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            bn = mod.bn
            assert bn.running_mean is not None and bn.running_var is not None
            # Fold the BN statistics into the conv parameters before packing.
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, (
            "BatchNorm2d should be fused into Conv2d before converting to reference module"
        )
        # ref_qconv is a fused container; index 0 holds the conv module.
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
184
+
185
+
186
class ConvReLU3d(nnq.Conv3d):
    r"""Quantized fused module combining Conv3d and ReLU.

    Adopts the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes: Same as torch.ao.nn.quantized.Conv3d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Reflection padding is not implemented for 3d convolutions.
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != "zeros":
            # Non-zero padding modes are applied eagerly via F.pad, since the
            # fused quantized op only supports zero padding internally.
            pad_spec = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_spec, mode=self.padding_mode)
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            bn = mod.bn
            assert bn.running_mean is not None and bn.running_var is not None
            # Fold the BN statistics into the conv parameters before packing.
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, (
            "BatchNorm3d should be fused into Conv3d before converting to reference module"
        )
        # ref_qconv is a fused container; index 0 holds the conv module.
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
.venv/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.quantized as nnq
5
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
6
+
7
+
8
+ __all__ = [
9
+ "LinearReLU",
10
+ "LinearLeakyReLU",
11
+ "LinearTanh",
12
+ ]
13
+
14
+
15
class LinearReLU(nnq.Linear):
    r"""
    A quantized module fusing a Linear layer with a ReLU activation.

    Interface and attributes are the same as :class:`torch.ao.nn.quantized.Linear`.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single fused kernel: quantized linear followed by clamping at zero.
        packed = self._packed_params._packed_params
        return torch.ops.quantized.linear_relu(x, packed, self.scale, self.zero_point)

    def _get_name(self):
        return "QuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # The quantized Linear conversion already handles the fused module.
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
        # The reference module is a sequence; element 0 is the Linear.
        linear = ref_linear_relu[0]
        return super().from_reference(linear, output_scale, output_zero_point)
56
+
57
+
58
class LinearLeakyReLU(nnq.Linear):
    r"""
    For onednn backend only
    A quantized module fusing Linear and LeakyReLU.
    Interface is the same as :class:`torch.ao.nn.quantized.Linear`.
    Attributes:
        Same as torch.ao.nn.quantized.Linear
        + negative_slope
    Examples::
        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearLeakyReLU  # type: ignore[assignment]

    def __init__(
        self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
    ):
        super().__init__(in_features, out_features, bias, dtype)
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fused quantized linear + leaky-relu kernel (onednn backend).
        return torch.ops.quantized.linear_leaky_relu(
            x,
            self._packed_params._packed_params,
            self.scale,
            self.zero_point,
            self.negative_slope,
        )

    def _get_name(self):
        return "QuantizedLinearLeakyReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) == nni.LinearLeakyReLU, (
            "Input float module should be LinearLeakyReLU"
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        # Read the fused module's observer before narrowing to its children.
        act_observer = mod.activation_post_process
        leaky = mod[1]
        float_linear = mod[0]
        w_observer = float_linear.qconfig.weight()  # type: ignore[union-attr, operator]
        w_observer(float_linear.weight)
        w_dtype = w_observer.dtype
        act_scale, act_zp = act_observer.calculate_qparams()  # type: ignore[union-attr,operator]
        assert w_dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(float_linear.weight.float(), w_observer)
        new_mod = cls(
            float_linear.in_features,
            float_linear.out_features,
            leaky.negative_slope,
            dtype=w_dtype,
        )
        new_mod.set_weight_bias(qweight, float_linear.bias)  # type: ignore[arg-type]
        new_mod.scale = float(act_scale)
        new_mod.zero_point = int(act_zp)
        return new_mod

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        # ref_mod is a sequence: (reference Linear, LeakyReLU).
        linear = ref_mod[0]
        leaky = ref_mod[1]
        new_mod = cls(linear.in_features, linear.out_features, leaky.negative_slope)
        new_mod.set_weight_bias(linear.get_quantized_weight(), linear.bias)
        new_mod.scale = float(output_scale)
        new_mod.zero_point = int(output_zero_point)
        return new_mod
130
+
131
+
132
class LinearTanh(nnq.Linear):
    r"""
    A quantized module fusing Linear and Tanh.

    Interface is the same as :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearTanh(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearTanh  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fused quantized linear + tanh kernel.
        return torch.ops.quantized.linear_tanh(
            x, self._packed_params._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        return "QuantizedLinearTanh"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        # Read the fused module's observer before narrowing to the Linear child.
        act_observer = mod.activation_post_process
        float_linear = mod[0]
        w_observer = float_linear.qconfig.weight()  # type: ignore[union-attr,operator]
        w_observer(float_linear.weight)
        w_dtype = w_observer.dtype
        act_scale, act_zp = act_observer.calculate_qparams()  # type: ignore[union-attr,operator]
        assert w_dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(float_linear.weight.float(), w_observer)
        new_mod = cls(float_linear.in_features, float_linear.out_features, dtype=w_dtype)
        new_mod.set_weight_bias(qweight, float_linear.bias)  # type: ignore[arg-type]
        new_mod.scale = float(act_scale)
        new_mod.zero_point = int(act_zp)
        return new_mod

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        # ref_mod is a sequence whose first element is the reference Linear.
        linear = ref_mod[0]
        new_mod = cls(linear.in_features, linear.out_features)
        new_mod.set_weight_bias(linear.get_quantized_weight(), linear.bias)
        new_mod.scale = float(output_scale)
        new_mod.zero_point = int(output_zero_point)
        return new_mod
.venv/lib/python3.12/site-packages/torch/ao/nn/qat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403