BryanW commited on
Commit
cde130c
·
verified ·
1 Parent(s): 3f10421

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_acc/__init__.pyi +15 -0
  2. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi +4 -0
  3. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi +13 -0
  4. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi +84 -0
  5. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi +452 -0
  6. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi +9 -0
  7. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi +25 -0
  8. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc +0 -0
  9. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc +0 -0
  10. URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc +0 -0
  11. URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc +0 -0
  12. URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc +0 -0
  13. URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc +0 -0
  14. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc +0 -0
  15. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__init__.py +35 -0
  16. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc +0 -0
  17. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py +41 -0
  18. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc +0 -0
  19. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py +41 -0
  20. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  21. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc +0 -0
  22. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py +289 -0
  23. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py +1 -0
  24. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc +0 -0
  25. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py +32 -0
  26. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  27. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc +0 -0
  28. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc +0 -0
  29. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  30. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +958 -0
  31. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +191 -0
  32. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py +74 -0
  33. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py +15 -0
  34. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc +0 -0
  35. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
  36. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc +0 -0
  37. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py +6 -0
  38. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  39. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  40. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py +72 -0
  41. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py +18 -0
  42. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc +0 -0
  43. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc +0 -0
  44. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc +0 -0
  45. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc +0 -0
  46. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc +0 -0
  47. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py +113 -0
  48. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +153 -0
  49. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py +276 -0
  50. URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py +190 -0
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_acc/__init__.pyi ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+ from torch.types import _dtype, _int, Device
3
+
4
+ # Defined in torch/csrc/acc/Module.cpp
5
+ class PrivateUse1Hooks:
6
+ def has_primary_context(self, device_index: _int) -> bool: ...
7
+ def is_built(self) -> bool: ...
8
+ def is_avaible(self) -> bool: ...
9
+
10
+ class DeviceGuard:
11
+ def type_(self) -> Device: ...
12
+
13
+ def register_python_privateuseone_device_guard(guard: DeviceGuard) -> bool: ...
14
+ def register_python_privateuseone_hook(hook: PrivateUse1Hooks) -> bool: ...
15
+ def create_empty_tensor(shape: tuple[_int, ...], dtype: _dtype) -> Tensor: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/__init__.pyi ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import compiled_autograd, eval_frame, guards # noqa: F401
2
+
3
+ def strip_function_call(name: str) -> str: ...
4
+ def is_valid_var_name(name: str) -> bool | int: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/compiled_autograd.pyi ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+
3
+ from torch import Tensor
4
+ from torch._dynamo.compiled_autograd import AutogradCompilerInstance
5
+
6
+ def set_autograd_compiler(
7
+ autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
8
+ dynamic: bool,
9
+ ) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
10
+ def clear_cache() -> None: ...
11
+ def is_cache_empty() -> bool: ...
12
+ def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
13
+ def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/eval_frame.pyi ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import types
3
+ from collections.abc import Callable
4
+ from typing import Optional, overload
5
+
6
+ from torch._dynamo.guards import GuardManagerWrapper
7
+ from torch._dynamo.types import DynamoCallback, DynamoGuardCompleteHook, DynamoGuardHook
8
+ from torch._guards import CompileId
9
+
10
+ def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
11
+ def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
12
+ def get_eval_frame_callback() -> DynamoCallback: ...
13
+ def reset_code(code: types.CodeType) -> None: ...
14
+ def unsupported(obj1: object, obj2: object) -> object: ...
15
+ def set_code_exec_strategy(
16
+ code: types.CodeType, strategy: _FrameExecStrategy
17
+ ) -> None: ...
18
+ def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
19
+ def set_guard_complete_hook(
20
+ hook: Optional[DynamoGuardCompleteHook],
21
+ ) -> Optional[DynamoGuardCompleteHook]: ...
22
+ def raise_sigtrap() -> None: ...
23
+ def set_c_recursion_limit(limit: int) -> None: ...
24
+ def get_c_recursion_limit() -> int: ...
25
+
26
+ class _CacheEntry:
27
+ def check_fn(self, *args: object, **kwargs: object) -> bool: ...
28
+ def update_diff_guard_root_manager(self) -> None: ...
29
+ code: types.CodeType
30
+ compile_id: CompileId
31
+ # If we run into circular issues, just use object
32
+ guard_manager: GuardManagerWrapper
33
+ backend: Callable
34
+ next: _CacheEntry | None
35
+
36
+ class _PrecompileEntry:
37
+ guard_manager: GuardManagerWrapper
38
+
39
+ class _ExtraState:
40
+ def invalidate(
41
+ self, cache_entry: _CacheEntry, guard_manager: GuardManagerWrapper
42
+ ) -> None: ...
43
+
44
+ class _FrameAction(enum.IntEnum):
45
+ DEFAULT = 0
46
+ SKIP = 1
47
+ RUN_ONLY = 2
48
+
49
+ class _FrameExecStrategy:
50
+ cur_action: _FrameAction
51
+ recursive_action: _FrameAction
52
+
53
+ @overload
54
+ def __init__(self) -> None: ...
55
+ @overload
56
+ def __init__(
57
+ self, cur_action: _FrameAction, recursive_action: _FrameAction
58
+ ) -> None: ...
59
+
60
+ # This is an object that encapsulates the Python FrameType, and exposes
61
+ # properties Dynamo cares about for a frame.
62
+ class _PyInterpreterFrame:
63
+ f_code: types.CodeType
64
+ f_locals: dict[str, object]
65
+ f_globals: dict[str, object]
66
+ f_builtins: dict[str, object]
67
+ f_lasti: int
68
+ f_lineno: int
69
+ f_back: types.FrameType
70
+ # A tuple containing cell objects captured by this frame.
71
+ closure: tuple[types.CellType]
72
+
73
+ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
74
+
75
+ py_opcode_caches: list[int]
76
+
77
+ def code_framelocals_names(code: types.CodeType) -> tuple[str, ...]: ...
78
+ def _load_precompile_entry(
79
+ code: types.CodeType,
80
+ guard_manager: GuardManagerWrapper,
81
+ dynamo_code: types.CodeType,
82
+ ) -> None: ...
83
+ def _reset_precompile_entries(code: types.CodeType) -> None: ...
84
+ def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_dynamo/guards.pyi ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from collections.abc import Callable
3
+ from typing import Any, Optional, TypeAlias
4
+
5
+ import torch
6
+
7
+ # TODO: We should move the `GuardManagerType`
8
+ # defined in `guards.py` here and update other
9
+ # imports
10
+ GuardManagerType: TypeAlias = enum.Enum
11
+
12
+ class GlobalStateGuard:
13
+ def check(self) -> bool: ...
14
+ def reason(self) -> str: ...
15
+
16
+ class LeafGuard:
17
+ def verbose_code_parts(self) -> list[str]: ...
18
+
19
+ class RelationalGuard: ...
20
+
21
+ class GuardDebugInfo:
22
+ verbose_code_parts: list[str]
23
+ result: bool
24
+ num_guards_executed: int
25
+
26
+ class GuardManager:
27
+ def check(self, value: Any) -> bool: ...
28
+ def check_verbose(self, value: Any) -> GuardDebugInfo: ...
29
+
30
+ # Accessors
31
+ def globals_dict_manager(
32
+ self,
33
+ f_globals: dict[str, Any],
34
+ source: str,
35
+ example_value: Any,
36
+ guard_manager_enum: GuardManagerType,
37
+ ) -> GuardManager: ...
38
+ def framelocals_manager(
39
+ self,
40
+ key: tuple[str, int],
41
+ source: str,
42
+ example_value: Any,
43
+ guard_manager_enum: GuardManagerType,
44
+ ) -> GuardManager: ...
45
+ def dict_getitem_manager(
46
+ self,
47
+ key: Any,
48
+ source: str,
49
+ example_value: Any,
50
+ guard_manager_enum: GuardManagerType,
51
+ ) -> GuardManager: ...
52
+ def grad_manager(
53
+ self,
54
+ source: str,
55
+ example_value: Any,
56
+ guard_manager_enum: GuardManagerType,
57
+ ) -> GuardManager: ...
58
+ def generic_getattr_manager(
59
+ self,
60
+ attr: str,
61
+ source: str,
62
+ example_value: Any,
63
+ guard_manager_enum: GuardManagerType,
64
+ ) -> GuardManager: ...
65
+ def getitem_manager(
66
+ self,
67
+ key: Any,
68
+ source: str,
69
+ example_value: Any,
70
+ guard_manager_enum: GuardManagerType,
71
+ ) -> GuardManager: ...
72
+ def get_generic_dict_manager(
73
+ self,
74
+ source: str,
75
+ example_value: Any,
76
+ guard_manager_enum: GuardManagerType,
77
+ ) -> GuardManager: ...
78
+ def list_getitem_manager(
79
+ self,
80
+ key: Any,
81
+ source: str,
82
+ example_value: Any,
83
+ guard_manager_enum: GuardManagerType,
84
+ ) -> GuardManager: ...
85
+ def tuple_getitem_manager(
86
+ self,
87
+ key: Any,
88
+ source: str,
89
+ example_value: Any,
90
+ guard_manager_enum: GuardManagerType,
91
+ ) -> GuardManager: ...
92
+ def set_getitem_manager(
93
+ self,
94
+ index: Any,
95
+ source: str,
96
+ example_value: Any,
97
+ guard_manager_enum: GuardManagerType,
98
+ ) -> GuardManager: ...
99
+ def func_defaults_manager(
100
+ self,
101
+ source: str,
102
+ example_value: Any,
103
+ guard_manager_enum: GuardManagerType,
104
+ ) -> GuardManager: ...
105
+ def func_kwdefaults_manager(
106
+ self,
107
+ source: str,
108
+ example_value: Any,
109
+ guard_manager_enum: GuardManagerType,
110
+ ) -> GuardManager: ...
111
+ def tuple_iterator_getitem_manager(
112
+ self,
113
+ index: Any,
114
+ source: str,
115
+ example_value: Any,
116
+ guard_manager_enum: GuardManagerType,
117
+ ) -> GuardManager: ...
118
+ def weakref_call_manager(
119
+ self,
120
+ source: str,
121
+ example_value: Any,
122
+ guard_manager_enum: GuardManagerType,
123
+ ) -> GuardManager: ...
124
+ def call_function_no_args_manager(
125
+ self,
126
+ source: str,
127
+ example_value: Any,
128
+ guard_manager_enum: GuardManagerType,
129
+ ) -> GuardManager: ...
130
+ def global_weakref_manager(
131
+ self,
132
+ global_name: str,
133
+ source: str,
134
+ example_value: Any,
135
+ guard_manager_enum: GuardManagerType,
136
+ ) -> GuardManager: ...
137
+ def type_manager(
138
+ self,
139
+ source: str,
140
+ example_value: Any,
141
+ guard_manager_enum: GuardManagerType,
142
+ ) -> GuardManager: ...
143
+ def getattr_manager(
144
+ self,
145
+ attr: str,
146
+ source: str,
147
+ example_value: Any,
148
+ guard_manager_enum: GuardManagerType,
149
+ ) -> GuardManager: ...
150
+ def tensor_property_size_manager(
151
+ self,
152
+ idx: int,
153
+ source: str,
154
+ example_value: Any,
155
+ guard_manager_enum: GuardManagerType,
156
+ ) -> GuardManager: ...
157
+ def tensor_property_shape_manager(
158
+ self,
159
+ idx: int,
160
+ source: str,
161
+ example_value: Any,
162
+ guard_manager_enum: GuardManagerType,
163
+ ) -> GuardManager: ...
164
+ def tensor_property_storage_offset_manager(
165
+ self,
166
+ idx: int,
167
+ source: str,
168
+ example_value: Any,
169
+ guard_manager_enum: GuardManagerType,
170
+ ) -> GuardManager: ...
171
+ def indexed_manager(
172
+ self,
173
+ idx: int,
174
+ source: str,
175
+ example_value: Any,
176
+ guard_manager_enum: GuardManagerType,
177
+ ) -> GuardManager: ...
178
+ def lambda_manager(
179
+ self,
180
+ python_lambda: Callable[..., Any],
181
+ source: str,
182
+ example_value: Any,
183
+ guard_manager_enum: GuardManagerType,
184
+ ) -> GuardManager: ...
185
+ def get_root(self) -> RootGuardManager: ...
186
+ def get_source(self) -> str: ...
187
+ def fail_count(self) -> int: ...
188
+ def get_child_managers(self) -> list[GuardManager]: ...
189
+ def repr(self) -> str: ...
190
+ def type_of_guarded_value(self) -> str: ...
191
+ def get_leaf_guards(self) -> list[LeafGuard]: ...
192
+ def get_accessors(self) -> list[GuardManager]: ...
193
+ def is_guarded_value_immutable(self) -> bool: ...
194
+ def is_tag_safe(self) -> bool: ...
195
+ def is_tag_safe_root(self) -> bool: ...
196
+ def has_no_accessors(self) -> bool: ...
197
+ def has_object_aliasing_guard(self) -> bool: ...
198
+ def get_type_of_guarded_value(self) -> type: ...
199
+ def type_dict_manager(
200
+ self,
201
+ source: str,
202
+ example_value: Any,
203
+ guard_manager_enum: GuardManagerType,
204
+ ) -> GuardManager: ...
205
+ def type_mro_manager(
206
+ self,
207
+ source: str,
208
+ example_value: Any,
209
+ guard_manager_enum: GuardManagerType,
210
+ ) -> GuardManager: ...
211
+ def code_manager(
212
+ self,
213
+ source: str,
214
+ example_value: Any,
215
+ guard_manager_enum: GuardManagerType,
216
+ ) -> GuardManager: ...
217
+ def closure_manager(
218
+ self,
219
+ source: str,
220
+ example_value: Any,
221
+ guard_manager_enum: GuardManagerType,
222
+ ) -> GuardManager: ...
223
+ # Leaf guards
224
+ def add_lambda_guard(
225
+ self, user_lambda: Callable[..., Any], verbose_code_parts: list[str]
226
+ ) -> None: ...
227
+ def add_id_match_guard(
228
+ self, id_val: int, verbose_code_parts: list[str]
229
+ ) -> None: ...
230
+ def add_equals_match_guard(
231
+ self,
232
+ equals_val: Any,
233
+ verbose_code_parts: list[str],
234
+ ) -> None: ...
235
+ def add_global_state_guard(
236
+ self, initial_state: Any, verbose_code_parts: list[str]
237
+ ) -> None: ...
238
+ def add_torch_function_mode_stack_guard(
239
+ self, initial_stack: list[Any], verbose_code_parts: list[str]
240
+ ) -> None: ...
241
+ def add_mapping_keys_guard(
242
+ self, value: Any, verbose_code_parts: list[str]
243
+ ) -> None: ...
244
+ def add_dict_length_check_guard(
245
+ self, value: int, verbose_code_parts: list[str]
246
+ ) -> None: ...
247
+ def add_length_check_guard(
248
+ self, value: int, verbose_code_parts: list[str]
249
+ ) -> None: ...
250
+ def add_true_match_guard(
251
+ self,
252
+ verbose_code_parts: list[str],
253
+ ) -> None: ...
254
+ def add_false_match_guard(
255
+ self,
256
+ verbose_code_parts: list[str],
257
+ ) -> None: ...
258
+ def add_none_match_guard(
259
+ self,
260
+ verbose_code_parts: list[str],
261
+ ) -> None: ...
262
+ def add_not_none_guard(
263
+ self,
264
+ verbose_code_parts: list[str],
265
+ ) -> None: ...
266
+ def add_dispatch_key_set_guard(
267
+ self,
268
+ dispatch_key: Any,
269
+ verbose_code_parts: list[str],
270
+ ) -> None: ...
271
+ def add_tensor_match_guard(
272
+ self,
273
+ value: Any,
274
+ sizes: list[int],
275
+ strides: list[int],
276
+ tensor_name: str,
277
+ verbose_code_parts: list[str],
278
+ ptype: Any,
279
+ dispatch_keys: Any,
280
+ ) -> None: ...
281
+ def add_dynamic_indices_guard(
282
+ self,
283
+ value: set[Any],
284
+ verbose_code_parts: list[str],
285
+ ) -> None: ...
286
+ def add_no_hasattr_guard(
287
+ self,
288
+ attr_name: str,
289
+ verbose_code_parts: list[str],
290
+ ) -> None: ...
291
+ def add_dict_contains_guard(
292
+ self,
293
+ contains: bool,
294
+ key: Any,
295
+ verbose_code_parts: list[str],
296
+ ) -> None: ...
297
+ def add_type_match_guard(
298
+ self,
299
+ value: int,
300
+ verbose_code_parts: list[str],
301
+ ) -> None: ...
302
+ def add_dict_version_guard(
303
+ self,
304
+ value: Any,
305
+ verbose_code_parts: list[str],
306
+ ) -> None: ...
307
+ def add_set_contains_guard(
308
+ self,
309
+ contains: bool,
310
+ item: Any,
311
+ verbose_code_parts: list[str],
312
+ ) -> None: ...
313
+ def add_dual_level_match_guard(
314
+ self,
315
+ level: int,
316
+ verbose_code_parts: list[str],
317
+ ) -> None: ...
318
+ def add_float_is_nan_guard(
319
+ self,
320
+ verbose_code_parts: list[str],
321
+ ) -> None: ...
322
+ def add_complex_is_nan_guard(
323
+ self,
324
+ verbose_code_parts: list[str],
325
+ ) -> None: ...
326
+ def add_tuple_iterator_length_guard(
327
+ self,
328
+ length: int,
329
+ type_id: int,
330
+ verbose_code_parts: list[str],
331
+ ) -> None: ...
332
+ def add_range_iterator_match_guard(
333
+ self,
334
+ start: int,
335
+ stop: int,
336
+ step: int,
337
+ type_id: int,
338
+ verbose_code_parts: list[str],
339
+ ) -> None: ...
340
+ def add_default_device_guard(
341
+ self,
342
+ verbose_code_parts: list[str],
343
+ ) -> None: ...
344
+ def mark_tag_safe(self) -> None: ...
345
+ def mark_tag_safe_root(self) -> None: ...
346
+
347
+ class RootGuardManager(GuardManager):
348
+ def get_epilogue_lambda_guards(self) -> list[LeafGuard]: ...
349
+ def add_epilogue_lambda_guard(
350
+ self,
351
+ guard: LeafGuard,
352
+ verbose_code_parts: list[str],
353
+ ) -> None: ...
354
+ def clone_manager(
355
+ self, clone_filter_fn: Callable[[GuardManager], bool]
356
+ ) -> RootGuardManager: ...
357
+ def attach_compile_id(self, compile_id: str) -> None: ...
358
+
359
+ class DictGuardManager(GuardManager):
360
+ def get_key_manager(
361
+ self,
362
+ index: int,
363
+ source: str,
364
+ example_value: Any,
365
+ guard_manager_enum: GuardManagerType,
366
+ ) -> GuardManager: ...
367
+ def get_value_manager(
368
+ self,
369
+ index: int,
370
+ source: str,
371
+ example_value: Any,
372
+ guard_manager_enum: GuardManagerType,
373
+ ) -> GuardManager: ...
374
+ def get_key_value_managers(
375
+ self,
376
+ ) -> dict[int, tuple[GuardManager, GuardManager]]: ...
377
+
378
+ # Guard accessor stubs
379
+ class GuardAccessor: ...
380
+ class DictGetItemGuardAccessor(GuardAccessor): ...
381
+ class GetGenericDictGuardAccessor(GuardAccessor): ...
382
+ class TypeDictGuardAccessor(GuardAccessor): ...
383
+ class TypeMROGuardAccessor(GuardAccessor): ...
384
+ class ClosureGuardAccessor(GuardAccessor): ...
385
+ class TupleGetItemGuardAccessor(GuardAccessor): ...
386
+ class TypeGuardAccessor(GuardAccessor): ...
387
+ class CodeGuardAccessor(GuardAccessor): ...
388
+ class FuncDefaultsGuardAccessor(GuardAccessor): ...
389
+ class FuncKwDefaultsGuardAccessor(GuardAccessor): ...
390
+
391
+ class GetAttrGuardAccessor(GuardAccessor):
392
+ def get_attr_name(self) -> str: ...
393
+
394
+ def install_object_aliasing_guard(
395
+ x: GuardManager,
396
+ y: GuardManager,
397
+ verbose_code_parts: list[str],
398
+ ) -> None: ...
399
+ def install_no_tensor_aliasing_guard(
400
+ guard_managers: list[GuardManager],
401
+ tensor_names: list[str],
402
+ verbose_code_parts: list[str],
403
+ ) -> None: ...
404
+ def install_storage_overlapping_guard(
405
+ overlapping_guard_managers: list[GuardManager],
406
+ non_overlapping_guard_managers: list[GuardManager],
407
+ verbose_code_parts: list[str],
408
+ ) -> None: ...
409
+ def install_symbolic_shape_guard(
410
+ guard_managers: list[GuardManager],
411
+ nargs_int: int,
412
+ nargs_float: int,
413
+ py_addr: int,
414
+ py_addr_keep_alive: Any,
415
+ verbose_code_parts: list[str],
416
+ ) -> None: ...
417
+ def profile_guard_manager(
418
+ guard_manager: GuardManager,
419
+ f_locals: dict[str, Any],
420
+ n_iters: int,
421
+ ) -> float: ...
422
+
423
+ class TensorGuards:
424
+ def __init__(
425
+ self,
426
+ *,
427
+ dynamic_dims_sizes: list[torch.SymInt | None] | None = None,
428
+ dynamic_dims_strides: list[torch.SymInt | None] | None = None,
429
+ ) -> None: ...
430
+ def check(self, *args: Any) -> bool: ...
431
+ def check_verbose(
432
+ self, *args: Any, tensor_check_names: Optional[list[str]] = None
433
+ ) -> bool | str: ...
434
+
435
+ def assert_size_stride(
436
+ item: torch.Tensor,
437
+ size: torch.types._size,
438
+ stride: torch.types._size,
439
+ op_name: str | None = None,
440
+ ) -> None: ...
441
+ def assert_alignment(
442
+ item: torch.Tensor,
443
+ alignment: int,
444
+ op_name: str | None = None,
445
+ ) -> None: ...
446
+ def check_obj_id(obj: object, expected: int) -> bool: ...
447
+ def check_type_id(obj: object, expected: int) -> bool: ...
448
+ def dict_version(d: dict[Any, Any]) -> int: ...
449
+ def compute_overlapping_tensors(
450
+ tensors: list[torch.Tensor], symbolic: bool = True
451
+ ) -> set[int]: ...
452
+ def set_is_in_mode_without_ignore_compile_internals(value: bool) -> None: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/__init__.pyi ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/export/pybind.cpp
2
+ class CppExportedProgram: ...
3
+
4
+ def deserialize_exported_program(
5
+ serialized_program: str,
6
+ ) -> CppExportedProgram: ...
7
+ def serialize_exported_program(
8
+ cpp_exported_program: CppExportedProgram,
9
+ ) -> str: ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_C/_export/pt2_archive_constants.pyi ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Defined in torch/csrc/export/pt2_archive_constants.h
2
+
3
+ ARCHIVE_ROOT_NAME: str = ...
4
+ ARCHIVE_FORMAT_PATH: str = ...
5
+ ARCHIVE_FORMAT_VALUE: str = ...
6
+ ARCHIVE_VERSION_PATH: str = ...
7
+ ARCHIVE_VERSION_VALUE: str = ...
8
+ MODELS_DIR: str = ...
9
+ MODELS_FILENAME_FORMAT: str = ...
10
+ AOTINDUCTOR_DIR: str = ...
11
+ MTIA_DIR: str = ...
12
+ WEIGHTS_DIR: str = ...
13
+ WEIGHTS_CONFIG_FILENAME_FORMAT: str = ...
14
+ WEIGHT_FILENAME_PREFIX: str = ...
15
+ CONSTANTS_DIR: str = ...
16
+ CONSTANTS_CONFIG_FILENAME_FORMAT: str = ...
17
+ TENSOR_CONSTANT_FILENAME_PREFIX: str = ...
18
+ CUSTOM_OBJ_FILENAME_PREFIX: str = ...
19
+ SAMPLE_INPUTS_DIR: str = ...
20
+ SAMPLE_INPUTS_FILENAME_FORMAT: str = ...
21
+ EXECUTORCH_DIR: str = ...
22
+ EXTRA_DIR: str = ...
23
+ MODULE_INFO_PATH: str = ...
24
+ XL_MODEL_WEIGHTS_DIR: str = ...
25
+ XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ...
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (208 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/autograd.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/_custom_op/__pycache__/impl.cpython-312.pyc ADDED
Binary file (32.4 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (431 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/autocast_mode.cpython-312.pyc ADDED
Binary file (24.1 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/amp/__pycache__/grad_scaler.cpython-312.pyc ADDED
Binary file (32.3 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (862 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # We are exposing all subpackages to the end-user.
2
+ # Because of possible inter-dependency, we want to avoid
3
+ # the cyclic imports, thus implementing lazy version
4
+ # as per https://peps.python.org/pep-0562/
5
+
6
+ from typing import TYPE_CHECKING as _TYPE_CHECKING
7
+
8
+
9
+ if _TYPE_CHECKING:
10
+ from types import ModuleType
11
+
12
+ from torch.ao.nn import ( # noqa: TC004
13
+ intrinsic as intrinsic,
14
+ qat as qat,
15
+ quantizable as quantizable,
16
+ quantized as quantized,
17
+ sparse as sparse,
18
+ )
19
+
20
+
21
+ __all__ = [
22
+ "intrinsic",
23
+ "qat",
24
+ "quantizable",
25
+ "quantized",
26
+ "sparse",
27
+ ]
28
+
29
+
30
+ def __getattr__(name: str) -> "ModuleType":
31
+ if name in __all__:
32
+ import importlib
33
+
34
+ return importlib.import_module("." + name, __name__)
35
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (894 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ from .modules import * # noqa: F403
4
+ from .modules.fused import _FusedModule # noqa: F403
5
+
6
+
7
+ # # Subpackages
8
+ # from . import qat # noqa: F403
9
+ # from . import quantized # noqa: F403
10
+
11
+ __all__ = [
12
+ "ConvBn1d",
13
+ "ConvBn2d",
14
+ "ConvBn3d",
15
+ "ConvBnReLU1d",
16
+ "ConvBnReLU2d",
17
+ "ConvBnReLU3d",
18
+ "ConvReLU1d",
19
+ "ConvReLU2d",
20
+ "ConvReLU3d",
21
+ "LinearReLU",
22
+ "BNReLU2d",
23
+ "BNReLU3d",
24
+ "LinearBn1d",
25
+ "LinearLeakyReLU",
26
+ "LinearTanh",
27
+ "ConvAdd2d",
28
+ "ConvAddReLU2d",
29
+ ]
30
+
31
+
32
+ # We are exposing all subpackages to the end-user.
33
+ # Because of possible inter-dependency, we want to avoid
34
+ # the cyclic imports, thus implementing lazy version
35
+ # as per https://peps.python.org/pep-0562/
36
+ def __getattr__(name: str) -> types.ModuleType:
37
+ if name in __all__:
38
+ import importlib
39
+
40
+ return importlib.import_module("." + name, __name__)
41
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.01 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .fused import ( # noqa: F401
2
+ _FusedModule,
3
+ BNReLU2d,
4
+ BNReLU3d,
5
+ ConvAdd2d,
6
+ ConvAddReLU2d,
7
+ ConvBn1d,
8
+ ConvBn2d,
9
+ ConvBn3d,
10
+ ConvBnReLU1d,
11
+ ConvBnReLU2d,
12
+ ConvBnReLU3d,
13
+ ConvReLU1d,
14
+ ConvReLU2d,
15
+ ConvReLU3d,
16
+ LinearBn1d,
17
+ LinearLeakyReLU,
18
+ LinearReLU,
19
+ LinearTanh,
20
+ )
21
+
22
+
23
+ __all__ = [
24
+ "ConvBn1d",
25
+ "ConvBn2d",
26
+ "ConvBn3d",
27
+ "ConvBnReLU1d",
28
+ "ConvBnReLU2d",
29
+ "ConvBnReLU3d",
30
+ "ConvReLU1d",
31
+ "ConvReLU2d",
32
+ "ConvReLU3d",
33
+ "LinearReLU",
34
+ "BNReLU2d",
35
+ "BNReLU3d",
36
+ "LinearBn1d",
37
+ "LinearLeakyReLU",
38
+ "LinearTanh",
39
+ "ConvAdd2d",
40
+ "ConvAddReLU2d",
41
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (759 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-312.pyc ADDED
Binary file (13.2 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/modules/fused.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ from torch.nn import (
4
+ BatchNorm1d,
5
+ BatchNorm2d,
6
+ BatchNorm3d,
7
+ Conv1d,
8
+ Conv2d,
9
+ Conv3d,
10
+ Linear,
11
+ ReLU,
12
+ )
13
+ from torch.nn.utils.parametrize import type_before_parametrizations
14
+
15
+
16
+ __all__ = [
17
+ "ConvReLU1d",
18
+ "ConvReLU2d",
19
+ "ConvReLU3d",
20
+ "LinearReLU",
21
+ "ConvBn1d",
22
+ "ConvBn2d",
23
+ "ConvBnReLU1d",
24
+ "ConvBnReLU2d",
25
+ "ConvBn3d",
26
+ "ConvBnReLU3d",
27
+ "BNReLU2d",
28
+ "BNReLU3d",
29
+ "LinearBn1d",
30
+ "LinearLeakyReLU",
31
+ "LinearTanh",
32
+ "ConvAdd2d",
33
+ "ConvAddReLU2d",
34
+ ]
35
+
36
+
37
# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    """Marker base class shared by the fused sequential containers below.

    It adds no behavior on top of ``torch.nn.Sequential``.
    """

    pass
40
+
41
+
42
class ConvReLU1d(_FusedModule):
    r"""Sequential container which calls the Conv1d and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv1d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, relu):
        conv_type = type_before_parametrizations(conv)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv1d and relu_type == ReLU, (
            f"Incorrect types for input modules: {conv_type}, {relu_type}"
        )
        super().__init__(conv, relu)
55
+
56
+
57
class ConvReLU2d(_FusedModule):
    r"""Sequential container which calls the Conv2d and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv2d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, relu):
        conv_type = type_before_parametrizations(conv)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv2d and relu_type == ReLU, (
            f"Incorrect types for input modules: {conv_type}, {relu_type}"
        )
        super().__init__(conv, relu)
70
+
71
+
72
class ConvReLU3d(_FusedModule):
    r"""Sequential container which calls the Conv3d and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv3d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, relu):
        conv_type = type_before_parametrizations(conv)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv3d and relu_type == ReLU, (
            f"Incorrect types for input modules: {conv_type}, {relu_type}"
        )
        super().__init__(conv, relu)
85
+
86
+
87
class LinearReLU(_FusedModule):
    r"""Sequential container which calls the Linear and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose ``type_before_parametrizations`` is ``Linear``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, linear, relu):
        linear_type = type_before_parametrizations(linear)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert linear_type == Linear and relu_type == ReLU, (
            f"Incorrect types for input modules: {linear_type}, {relu_type}"
        )
        super().__init__(linear, relu)
100
+
101
+
102
class ConvBn1d(_FusedModule):
    r"""Sequential container which calls the Conv 1d and Batch Norm 1d modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv1d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm1d``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, bn):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv1d and bn_type == BatchNorm1d, (
            f"Incorrect types for input modules: {conv_type}, {bn_type}"
        )
        super().__init__(conv, bn)
115
+
116
+
117
class ConvBn2d(_FusedModule):
    r"""Sequential container which calls the Conv 2d and Batch Norm 2d modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv2d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm2d``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, bn):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv2d and bn_type == BatchNorm2d, (
            f"Incorrect types for input modules: {conv_type}, {bn_type}"
        )
        super().__init__(conv, bn)
130
+
131
+
132
class ConvBnReLU1d(_FusedModule):
    r"""Sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv1d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm1d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if any module has a different type.
    """

    def __init__(self, conv, bn, relu):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert (
            conv_type == Conv1d and bn_type == BatchNorm1d and relu_type == ReLU
        ), f"Incorrect types for input modules: {conv_type}, {bn_type}, {relu_type}"
        super().__init__(conv, bn, relu)
147
+
148
+
149
class ConvBnReLU2d(_FusedModule):
    r"""Sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv2d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm2d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if any module has a different type.
    """

    def __init__(self, conv, bn, relu):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert (
            conv_type == Conv2d and bn_type == BatchNorm2d and relu_type == ReLU
        ), f"Incorrect types for input modules: {conv_type}, {bn_type}, {relu_type}"
        super().__init__(conv, bn, relu)
164
+
165
+
166
class ConvBn3d(_FusedModule):
    r"""Sequential container which calls the Conv 3d and Batch Norm 3d modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv3d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm3d``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, conv, bn):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        # Types are compared for equality, so subclasses are not accepted.
        assert conv_type == Conv3d and bn_type == BatchNorm3d, (
            f"Incorrect types for input modules: {conv_type}, {bn_type}"
        )
        super().__init__(conv, bn)
179
+
180
+
181
class ConvBnReLU3d(_FusedModule):
    r"""Sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose ``type_before_parametrizations`` is ``Conv3d``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm3d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if any module has a different type.
    """

    def __init__(self, conv, bn, relu):
        conv_type = type_before_parametrizations(conv)
        bn_type = type_before_parametrizations(bn)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert (
            conv_type == Conv3d and bn_type == BatchNorm3d and relu_type == ReLU
        ), f"Incorrect types for input modules: {conv_type}, {bn_type}, {relu_type}"
        super().__init__(conv, bn, relu)
196
+
197
+
198
class BNReLU2d(_FusedModule):
    r"""Sequential container which calls the BatchNorm 2d and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose ``type_before_parametrizations`` is ``BatchNorm2d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, batch_norm, relu):
        bn_type = type_before_parametrizations(batch_norm)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert bn_type == BatchNorm2d and relu_type == ReLU, (
            f"Incorrect types for input modules: {bn_type}, {relu_type}"
        )
        super().__init__(batch_norm, relu)
211
+
212
+
213
class BNReLU3d(_FusedModule):
    r"""Sequential container which calls the BatchNorm 3d and ReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose ``type_before_parametrizations`` is ``BatchNorm3d``.
        relu: module whose ``type_before_parametrizations`` is ``ReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, batch_norm, relu):
        bn_type = type_before_parametrizations(batch_norm)
        relu_type = type_before_parametrizations(relu)
        # Types are compared for equality, so subclasses are not accepted.
        assert bn_type == BatchNorm3d and relu_type == ReLU, (
            f"Incorrect types for input modules: {bn_type}, {relu_type}"
        )
        super().__init__(batch_norm, relu)
226
+
227
+
228
class LinearBn1d(_FusedModule):
    r"""Sequential container which calls the Linear and BatchNorm1d modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose ``type_before_parametrizations`` is ``Linear``.
        bn: module whose ``type_before_parametrizations`` is ``BatchNorm1d``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, linear, bn):
        linear_type = type_before_parametrizations(linear)
        bn_type = type_before_parametrizations(bn)
        # Types are compared for equality, so subclasses are not accepted.
        assert linear_type == Linear and bn_type == BatchNorm1d, (
            f"Incorrect types for input modules: {linear_type}, {bn_type}"
        )
        super().__init__(linear, bn)
241
+
242
+
243
class LinearLeakyReLU(_FusedModule):
    r"""Sequential container which calls the Linear and LeakyReLU modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose ``type_before_parametrizations`` is ``Linear``.
        leaky_relu: module whose ``type_before_parametrizations`` is
            ``torch.nn.LeakyReLU``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, linear, leaky_relu):
        # Use the same helper as the other fused containers in this file so a
        # parametrized module is checked against its pre-parametrization type.
        linear_type = type_before_parametrizations(linear)
        act_type = type_before_parametrizations(leaky_relu)
        assert linear_type is Linear and act_type is torch.nn.LeakyReLU, (
            f"Incorrect types for input modules: {linear_type}, {act_type}"
        )
        super().__init__(linear, leaky_relu)
252
+
253
+
254
class LinearTanh(_FusedModule):
    r"""Sequential container which calls the Linear and Tanh modules.

    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose ``type_before_parametrizations`` is ``Linear``.
        tanh: module whose ``type_before_parametrizations`` is ``torch.nn.Tanh``.

    Raises:
        AssertionError: if either module has a different type.
    """

    def __init__(self, linear, tanh):
        # Use the same helper as the other fused containers in this file so a
        # parametrized module is checked against its pre-parametrization type.
        linear_type = type_before_parametrizations(linear)
        act_type = type_before_parametrizations(tanh)
        assert linear_type is Linear and act_type is torch.nn.Tanh, (
            f"Incorrect types for input modules: {linear_type}, {act_type}"
        )
        super().__init__(linear, tanh)
263
+
264
+
265
class ConvAdd2d(_FusedModule):
    r"""Sequential container holding a Conv2d module plus an extra add callable.

    During quantization this will be replaced with the corresponding fused module.
    """

    def __init__(self, conv, add):
        # Only the convolution is handed to the Sequential constructor; the
        # add callable is stored as a regular attribute.
        super().__init__(conv)
        self.add = add

    def forward(self, x1, x2):  # type: ignore[override]
        r"""Run the convolution on ``x1`` and combine the result with ``x2`` via ``add``."""
        conv_out = self[0](x1)
        return self.add(conv_out, x2)
276
+
277
+
278
class ConvAddReLU2d(_FusedModule):
    r"""Sequential container holding a Conv2d module plus add and relu callables.

    During quantization this will be replaced with the corresponding fused module.
    """

    def __init__(self, conv, add, relu):
        # Only the convolution is handed to the Sequential constructor; add
        # and relu are stored as attributes, in this order.
        super().__init__(conv)
        self.add = add
        self.relu = relu

    def forward(self, x1, x2):  # type: ignore[override]
        r"""Run the convolution on ``x1``, combine with ``x2`` via ``add``, then apply ``relu``."""
        conv_out = self[0](x1)
        summed = self.add(conv_out, x2)
        return self.relu(summed)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (246 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .conv_fused import (
2
+ ConvBn1d,
3
+ ConvBn2d,
4
+ ConvBn3d,
5
+ ConvBnReLU1d,
6
+ ConvBnReLU2d,
7
+ ConvBnReLU3d,
8
+ ConvReLU1d,
9
+ ConvReLU2d,
10
+ ConvReLU3d,
11
+ freeze_bn_stats,
12
+ update_bn_stats,
13
+ )
14
+ from .linear_fused import LinearBn1d
15
+ from .linear_relu import LinearReLU
16
+
17
+
18
+ __all__ = [
19
+ "LinearReLU",
20
+ "LinearBn1d",
21
+ "ConvReLU1d",
22
+ "ConvReLU2d",
23
+ "ConvReLU3d",
24
+ "ConvBn1d",
25
+ "ConvBn2d",
26
+ "ConvBn3d",
27
+ "ConvBnReLU1d",
28
+ "ConvBnReLU2d",
29
+ "ConvBnReLU3d",
30
+ "update_bn_stats",
31
+ "freeze_bn_stats",
32
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (698 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-312.pyc ADDED
Binary file (33.3 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-312.pyc ADDED
Binary file (8.46 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/conv_fused.py ADDED
@@ -0,0 +1,958 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ from typing import ClassVar
4
+
5
+ import torch
6
+ import torch.ao.nn.intrinsic as nni
7
+ import torch.ao.nn.qat as nnqat
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn import init
11
+ from torch.nn.modules.utils import _pair, _single, _triple
12
+ from torch.nn.parameter import Parameter
13
+ from torch.nn.utils import fuse_conv_bn_weights
14
+
15
+
16
+ __all__ = [
17
+ "ConvBn1d",
18
+ "ConvBnReLU1d",
19
+ "ConvReLU1d",
20
+ "ConvBn2d",
21
+ "ConvBnReLU2d",
22
+ "ConvReLU2d",
23
+ "ConvBn3d",
24
+ "ConvBnReLU3d",
25
+ "ConvReLU3d",
26
+ "update_bn_stats",
27
+ "freeze_bn_stats",
28
+ ]
29
+ _BN_CLASS_MAP = {
30
+ 1: nn.BatchNorm1d,
31
+ 2: nn.BatchNorm2d,
32
+ 3: nn.BatchNorm3d,
33
+ }
34
+
35
+
36
class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
    """Shared implementation of the fused conv + batch-norm QAT modules.

    The module owns the conv weight/bias, a ``bn`` batch-norm submodule and a
    ``weight_fake_quant`` module built from ``qconfig.weight()``. The forward
    pass fake-quantizes the conv weight after scaling it by
    ``bn.weight / running_std``.
    """

    # Serialization version; see the version history above
    # ``_load_from_state_dict``.
    _version = 2
    # Float module type accepted by ``from_float``; assigned by subclasses.
    _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]]

    def __init__(
        self,
        # ConvNd args
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        bias,
        padding_mode,
        # BatchNormNd args
        # num_features: out_channels
        eps=1e-05,
        momentum=0.1,
        # affine: True
        # track_running_stats: True
        # Args for this module
        freeze_bn=False,
        qconfig=None,
        dim=2,
    ):
        """Build the conv, batch-norm and weight fake-quant pieces.

        Args:
            in_channels .. padding_mode: forwarded to ``_ConvNd.__init__``
                (except ``bias``, which is handled here).
            eps, momentum: forwarded to the batch-norm constructor.
            freeze_bn: if true, batch-norm statistics start frozen.
            qconfig: must be truthy; ``qconfig.weight()`` creates the weight
                fake-quant module.
            dim: key into ``_BN_CLASS_MAP`` selecting the batch-norm class.

        Raises:
            AssertionError: if ``qconfig`` is falsy.
        """
        # The base conv is always created without a bias (``False`` below);
        # the bias parameter is registered by this class further down.
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            False,
            padding_mode,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        # Outside of training mode the batch-norm is always treated as frozen.
        self.freeze_bn = freeze_bn if self.training else True
        # The two positional ``True`` values correspond to the
        # "affine: True" / "track_running_stats: True" notes in the signature.
        self.bn = _BN_CLASS_MAP[dim](out_channels, eps, momentum, True, True)
        self.weight_fake_quant = self.qconfig.weight()
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_bn_parameters()

        # this needs to be called after reset_bn_parameters,
        # as they modify the same state
        if self.training:
            if freeze_bn:
                self.freeze_bn_stats()
            else:
                self.update_bn_stats()
        else:
            self.freeze_bn_stats()

        # Selects ``_forward_slow`` instead of ``_forward_approximate`` in
        # ``_forward`` when set to True.
        self._enable_slow_path_for_better_numerical_stability = False

    def reset_running_stats(self):
        """Reset the running statistics of the batch-norm submodule."""
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        """Re-initialize batch-norm state and, if present, the conv bias."""
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # note: below is actually for conv, not BN
        if self.bias is not None:
            # Bias is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def update_bn_stats(self):
        """Unfreeze batch-norm: clear ``freeze_bn`` and put ``bn`` in training mode.

        Returns:
            ``self``.
        """
        self.freeze_bn = False
        self.bn.training = True
        return self

    def freeze_bn_stats(self):
        """Freeze batch-norm: set ``freeze_bn`` and put ``bn`` in eval mode.

        Returns:
            ``self``.
        """
        self.freeze_bn = True
        self.bn.training = False
        return self

    def _forward(self, input):
        """Dispatch to the slow or the approximate fused forward implementation."""
        if self._enable_slow_path_for_better_numerical_stability:
            return self._forward_slow(input)
        return self._forward_approximate(input)

    def _forward_approximate(self, input):
        """Approximated method to fuse conv and bn. It requires only one forward pass.
        conv_orig = conv / scale_factor where scale_factor = bn.weight / running_std
        """
        assert self.bn.running_var is not None
        running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
        scale_factor = self.bn.weight / running_std
        # weight_shape reshapes a per-channel vector along dim 0 (of the weight),
        # bias_shape along dim 1 (of the conv output); all other dims are 1.
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1
        scaled_weight = self.weight_fake_quant(
            self.weight * scale_factor.reshape(weight_shape)
        )
        # using zero bias here since the bias for original conv
        # will be added later
        if self.bias is not None:
            zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
        else:
            zero_bias = torch.zeros(
                self.out_channels, device=scaled_weight.device, dtype=input.dtype
            )
        conv = self._conv_forward(input, scaled_weight, zero_bias)
        # Undo the weight scaling so the batch-norm sees the unscaled conv output.
        conv_orig = conv / scale_factor.reshape(bias_shape)
        if self.bias is not None:
            conv_orig = conv_orig + self.bias.reshape(bias_shape)
        conv = self.bn(conv_orig)
        return conv

    def _forward_slow(self, input):
        """
        A more accurate but slow method to compute conv bn fusion, following https://arxiv.org/pdf/1806.08342.pdf
        It requires two forward passes but handles the case bn.weight == 0

        Conv: Y = WX + B_c
        Conv without bias: Y0 = WX = Y - B_c, Y = Y0 + B_c

        Batch statistics:
          mean_Y = Y.mean()
                 = Y0.mean() + B_c
          var_Y = (Y - mean_Y)^2.mean()
                = (Y0 - Y0.mean())^2.mean()
        BN (r: bn.weight, beta: bn.bias):
          Z = r * (Y - mean_Y) / sqrt(var_Y + eps) + beta
            = r * (Y0 - Y0.mean()) / sqrt(var_Y + eps) + beta

        Fused Conv BN training (std_Y = sqrt(var_Y + eps)):
          Z = (r * W / std_Y) * X + r * (B_c - mean_Y) / std_Y + beta
            = (r * W / std_Y) * X - r * Y0.mean() / std_Y + beta

        Fused Conv BN inference (running_std = sqrt(running_var + eps)):
          Z = (r * W / running_std) * X - r * (running_mean - B_c) / running_std + beta

        QAT with fused conv bn:
          Z_train = fake_quant(r * W / running_std) * X * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
                  = conv(X, fake_quant(r * W / running_std)) * (running_std / std_Y) - r * Y0.mean() / std_Y + beta
          Z_inference = conv(X, fake_quant(r * W / running_std)) - r * (running_mean - B_c) / running_std + beta
        """

        assert self.bn.running_var is not None
        assert self.bn.running_mean is not None

        # using zero bias here since the bias for original conv
        # will be added later
        zero_bias = torch.zeros(
            self.out_channels, device=self.weight.device, dtype=input.dtype
        )

        # Same broadcasting shapes as in ``_forward_approximate``.
        weight_shape = [1] * len(self.weight.shape)
        weight_shape[0] = -1
        bias_shape = [1] * len(self.weight.shape)
        bias_shape[1] = -1

        if self.bn.training:
            # needed to compute batch mean/std
            conv_out = self._conv_forward(input, self.weight, zero_bias)
            # update bn statistics
            with torch.no_grad():
                conv_out_bias = (
                    conv_out
                    if self.bias is None
                    else conv_out + self.bias.reshape(bias_shape)
                )
                # Called for its effect on the bn module; the output is discarded.
                self.bn(conv_out_bias)

            # fused conv + bn without bias using bn running statistics
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            scale_factor = self.bn.weight / running_std
            scaled_weight = self.weight_fake_quant(
                self.weight * scale_factor.reshape(weight_shape)
            )
            # fused conv without bias for inference: (r * W / running_std) * X
            conv_bn = self._conv_forward(input, scaled_weight, zero_bias)

            # Average over every dim except dim 1.
            avg_dims = [0] + list(range(2, len(self.weight.shape)))
            batch_mean = conv_out.mean(avg_dims)
            batch_var = torch.square(conv_out - batch_mean.reshape(bias_shape)).mean(
                avg_dims
            )
            batch_std = torch.sqrt(batch_var + self.bn.eps)

            # scale to use batch std in training mode
            # conv(X, r * W / std_Y) = conv(X, r * W / running_std) * (running_std / std_Y)
            unscale_factor = running_std / batch_std
            conv_bn *= unscale_factor.reshape(bias_shape)

            fused_mean = batch_mean
            fused_std = batch_std
        else:
            # fused conv + bn without bias using bn running statistics
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
            scale_factor = self.bn.weight / running_std
            scaled_weight = self.weight_fake_quant(
                self.weight * scale_factor.reshape(weight_shape)
            )
            # fused conv without bias for inference: (r * W / running_std) * X
            conv_bn = self._conv_forward(input, scaled_weight, zero_bias)

            fused_mean = self.bn.running_mean - (
                self.bias if self.bias is not None else 0
            )
            fused_std = running_std

        # fused bias = beta - r * mean / std
        fused_bias = self.bn.bias - self.bn.weight * fused_mean / fused_std
        conv_bn += fused_bias.reshape(bias_shape)

        # HACK to let conv bias participate in loss to avoid DDP error (parameters
        #   were not used in producing loss)
        if self.bias is not None:
            conv_bn += (self.bias - self.bias).reshape(bias_shape)

        return conv_bn

    def forward(self, input):
        """Run the fused conv + batch-norm forward pass on ``input``."""
        return self._forward(input)

    def train(self, mode=True):
        """
        Batchnorm's training behavior is using the self.training flag. Prevent
        changing it if BN is frozen. This makes sure that calling `model.train()`
        on a model with a frozen BN will behave properly.
        """
        self.training = mode
        # Children are only switched when BN is not frozen.
        if not self.freeze_bn:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load state, renaming version-1 batch-norm keys to the version-2 layout.

        For version ``None``/1 checkpoints the v1 keys are moved in place inside
        ``state_dict`` to their ``bn.*`` names before delegating to the parent
        implementation.
        """
        version = local_metadata.get("version", None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module for v2
            v2_to_v1_names = {
                "bn.weight": "gamma",
                "bn.bias": "beta",
                "bn.running_mean": "running_mean",
                "bn.running_var": "running_var",
                "bn.num_batches_tracked": "num_batches_tracked",
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif prefix + v2_name in state_dict:
                    # there was a brief period where forward compatibility
                    # for this module was broken (between
                    # https://github.com/pytorch/pytorch/pull/38478
                    # and https://github.com/pytorch/pytorch/pull/38820)
                    # and modules emitted the v2 state_dict format while
                    # specifying that version == 1. This patches the forward
                    # compatibility issue by allowing the v2 style entries to
                    # be used.
                    pass
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args: `mod` a float module, either produced by torch.ao.quantization utilities
        or directly from user

        ``use_precomputed_fake_quant`` is accepted but not read in this method.
        The returned module reuses the parameter and buffer objects of ``mod``'s
        conv and batch-norm rather than copying them.

        Raises:
            AssertionError: if ``mod`` is not exactly ``cls._FLOAT_MODULE`` or has
                no valid ``qconfig``.
        """
        # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound
        # has no __name__ (code is fine though)
        assert type(mod) is cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        qconfig = mod.qconfig
        conv, bn = mod[0], mod[1]  # type: ignore[index]
        # Arguments are positional; the literal ``False`` sits between
        # ``bn.momentum`` and ``qconfig``.
        qat_convbn = cls(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            bn.eps,
            bn.momentum,
            False,
            qconfig,
        )
        qat_convbn.weight = conv.weight
        qat_convbn.bias = conv.bias
        qat_convbn.bn.weight = bn.weight
        qat_convbn.bn.bias = bn.bias
        qat_convbn.bn.running_mean = bn.running_mean
        qat_convbn.bn.running_var = bn.running_var
        # mypy error: Cannot determine type of 'num_batches_tracked'
        qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked
        return qat_convbn

    def to_float(self):
        """Convert back to a float module.

        Builds a ``_FLOAT_CONV_MODULE`` from this module's conv settings and
        detached weight/bias. If ``_FLOAT_BN_MODULE`` is set, the batch-norm is
        folded into the conv with ``fuse_conv_bn_weights``. If
        ``_FLOAT_RELU_MODULE`` is set, the conv and a new ReLU are wrapped in
        ``_FUSED_FLOAT_MODULE``; otherwise the conv itself is returned. The
        result is put in the same train/eval mode as ``self``.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())

        if cls._FLOAT_BN_MODULE:  # type: ignore[attr-defined]
            # fuse bn into conv
            assert self.bn.running_var is not None and self.bn.running_mean is not None
            conv.weight, conv.bias = fuse_conv_bn_weights(
                conv.weight,
                conv.bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        if cls._FLOAT_RELU_MODULE:  # type: ignore[attr-defined]
            modules = []
            modules.append(conv)
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            conv_relu = cls._FUSED_FLOAT_MODULE(*modules)  # type: ignore[attr-defined]
            conv_relu.train(self.training)
            return conv_relu
        else:
            conv.train(self.training)
            return conv
+
434
+
435
+ class ConvBn1d(_ConvBnNd, nn.Conv1d):
436
+ r"""
437
+ A ConvBn1d module is a module fused from Conv1d and BatchNorm1d,
438
+ attached with FakeQuantize modules for weight,
439
+ used in quantization aware training.
440
+
441
+ We combined the interface of :class:`torch.nn.Conv1d` and
442
+ :class:`torch.nn.BatchNorm1d`.
443
+
444
+ Similar to :class:`torch.nn.Conv1d`, with FakeQuantize modules initialized
445
+ to default.
446
+
447
+ Attributes:
448
+ freeze_bn:
449
+ weight_fake_quant: fake quant module for weight
450
+
451
+ """
452
+
453
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d
454
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None
455
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment]
456
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
457
+
458
+ def __init__(
459
+ self,
460
+ # Conv1d args
461
+ in_channels,
462
+ out_channels,
463
+ kernel_size,
464
+ stride=1,
465
+ padding=0,
466
+ dilation=1,
467
+ groups=1,
468
+ bias=None,
469
+ padding_mode="zeros",
470
+ # BatchNorm1d args
471
+ # num_features: out_channels
472
+ eps=1e-05,
473
+ momentum=0.1,
474
+ # affine: True
475
+ # track_running_stats: True
476
+ # Args for this module
477
+ freeze_bn=False,
478
+ qconfig=None,
479
+ ):
480
+ kernel_size = _single(kernel_size)
481
+ stride = _single(stride)
482
+ padding = _single(padding)
483
+ dilation = _single(dilation)
484
+ _ConvBnNd.__init__(
485
+ self,
486
+ in_channels,
487
+ out_channels,
488
+ kernel_size,
489
+ stride,
490
+ padding,
491
+ dilation,
492
+ False,
493
+ _single(0),
494
+ groups,
495
+ bias,
496
+ padding_mode,
497
+ eps,
498
+ momentum,
499
+ freeze_bn,
500
+ qconfig,
501
+ dim=1,
502
+ )
503
+
504
+
505
+ class ConvBnReLU1d(ConvBn1d):
506
+ r"""
507
+ A ConvBnReLU1d module is a module fused from Conv1d, BatchNorm1d and ReLU,
508
+ attached with FakeQuantize modules for weight,
509
+ used in quantization aware training.
510
+
511
+ We combined the interface of :class:`torch.nn.Conv1d` and
512
+ :class:`torch.nn.BatchNorm1d` and :class:`torch.nn.ReLU`.
513
+
514
+ Similar to `torch.nn.Conv1d`, with FakeQuantize modules initialized to
515
+ default.
516
+
517
+ Attributes:
518
+ weight_fake_quant: fake quant module for weight
519
+
520
+ """
521
+
522
+ # base class defines _FLOAT_MODULE as "ConvBn1d"
523
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvBnReLU1d
524
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
525
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm1d]] = nn.BatchNorm1d
526
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU
527
+ # module class after fusing bn into conv
528
+ _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d
529
+
530
+ def forward(self, input):
531
+ r"""Performs forward pass through fused Conv1d, BatchNorm1d, and ReLU."""
532
+ return F.relu(self._forward(input))
533
+
534
+ @classmethod
535
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
536
+ r"""Creates a QAT module from a floating point module."""
537
+ return super().from_float(mod, use_precomputed_fake_quant)
538
+
539
+
540
+ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
541
+ r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
542
+ FakeQuantize modules for weight for
543
+ quantization aware training.
544
+
545
+ We combined the interface of :class:`~torch.nn.Conv1d` and
546
+ :class:`~torch.nn.ReLU`.
547
+
548
+ Attributes:
549
+ weight_fake_quant: fake quant module for weight
550
+
551
+ """
552
+
553
+ _FLOAT_MODULE: ClassVar[type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment]
554
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
555
+ _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None
556
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU
557
+
558
+ def __init__(
559
+ self,
560
+ in_channels,
561
+ out_channels,
562
+ kernel_size,
563
+ stride=1,
564
+ padding=0,
565
+ dilation=1,
566
+ groups=1,
567
+ bias=True,
568
+ padding_mode="zeros",
569
+ qconfig=None,
570
+ ):
571
+ super().__init__(
572
+ in_channels,
573
+ out_channels,
574
+ kernel_size,
575
+ stride=stride,
576
+ padding=padding,
577
+ dilation=dilation,
578
+ groups=groups,
579
+ bias=bias,
580
+ # pyrefly: ignore [bad-argument-type]
581
+ padding_mode=padding_mode,
582
+ qconfig=qconfig,
583
+ )
584
+ assert qconfig, "qconfig must be provided for QAT module"
585
+ self.qconfig = qconfig
586
+ self.weight_fake_quant = self.qconfig.weight()
587
+
588
+ def forward(self, input):
589
+ r"""Performs forward pass through fused Conv1d and ReLU."""
590
+ return F.relu(
591
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
592
+ )
593
+
594
+ @classmethod
595
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
596
+ r"""Creates a QAT module from a floating point module."""
597
+ return super().from_float(
598
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
599
+ )
600
+
601
+
602
+ class ConvBn2d(_ConvBnNd, nn.Conv2d):
603
+ r"""
604
+ A ConvBn2d module is a module fused from Conv2d and BatchNorm2d,
605
+ attached with FakeQuantize modules for weight,
606
+ used in quantization aware training.
607
+
608
+ We combined the interface of :class:`torch.nn.Conv2d` and
609
+ :class:`torch.nn.BatchNorm2d`.
610
+
611
+ Similar to :class:`torch.nn.Conv2d`, with FakeQuantize modules initialized
612
+ to default.
613
+
614
+ Attributes:
615
+ freeze_bn:
616
+ weight_fake_quant: fake quant module for weight
617
+
618
+ """
619
+
620
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment]
621
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
622
+ _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = nn.BatchNorm2d
623
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None
624
+
625
+ def __init__(
626
+ self,
627
+ # ConvNd args
628
+ in_channels,
629
+ out_channels,
630
+ kernel_size,
631
+ stride=1,
632
+ padding=0,
633
+ dilation=1,
634
+ groups=1,
635
+ bias=None,
636
+ padding_mode="zeros",
637
+ # BatchNorm2d args
638
+ # num_features: out_channels
639
+ eps=1e-05,
640
+ momentum=0.1,
641
+ # affine: True
642
+ # track_running_stats: True
643
+ # Args for this module
644
+ freeze_bn=False,
645
+ qconfig=None,
646
+ ):
647
+ kernel_size = _pair(kernel_size)
648
+ stride = _pair(stride)
649
+ padding = _pair(padding)
650
+ dilation = _pair(dilation)
651
+ _ConvBnNd.__init__(
652
+ self,
653
+ in_channels,
654
+ out_channels,
655
+ kernel_size,
656
+ stride,
657
+ padding,
658
+ dilation,
659
+ False,
660
+ _pair(0),
661
+ groups,
662
+ bias,
663
+ padding_mode,
664
+ eps,
665
+ momentum,
666
+ freeze_bn,
667
+ qconfig,
668
+ dim=2,
669
+ )
670
+
671
+
672
+ class ConvBnReLU2d(ConvBn2d):
673
+ r"""
674
+ A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU,
675
+ attached with FakeQuantize modules for weight,
676
+ used in quantization aware training.
677
+
678
+ We combined the interface of :class:`torch.nn.Conv2d` and
679
+ :class:`torch.nn.BatchNorm2d` and :class:`torch.nn.ReLU`.
680
+
681
+ Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
682
+ default.
683
+
684
+ Attributes:
685
+ weight_fake_quant: fake quant module for weight
686
+
687
+ """
688
+
689
+ # base class defines _FLOAT_MODULE as "ConvBn2d"
690
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment]
691
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
692
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm2d]] = nn.BatchNorm2d
693
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU
694
+ # module class after fusing bn into conv
695
+ _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d
696
+
697
+ def forward(self, input):
698
+ r"""Performs forward pass through fused Conv2d, BatchNorm2d, and ReLU."""
699
+ return F.relu(self._forward(input))
700
+
701
+ @classmethod
702
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
703
+ r"""Creates a QAT module from a floating point module."""
704
+ return super().from_float(mod, use_precomputed_fake_quant)
705
+
706
+
707
+ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
708
+ r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
709
+ FakeQuantize modules for weight for
710
+ quantization aware training.
711
+
712
+ We combined the interface of :class:`~torch.nn.Conv2d` and
713
+ :class:`~torch.nn.ReLU`.
714
+
715
+ Attributes:
716
+ weight_fake_quant: fake quant module for weight
717
+
718
+ """
719
+
720
+ _FLOAT_MODULE: ClassVar[type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment]
721
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
722
+ _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None
723
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU
724
+
725
+ def __init__(
726
+ self,
727
+ in_channels,
728
+ out_channels,
729
+ kernel_size,
730
+ stride=1,
731
+ padding=0,
732
+ dilation=1,
733
+ groups=1,
734
+ bias=True,
735
+ padding_mode="zeros",
736
+ qconfig=None,
737
+ ):
738
+ super().__init__(
739
+ in_channels,
740
+ out_channels,
741
+ kernel_size,
742
+ stride=stride,
743
+ padding=padding,
744
+ dilation=dilation,
745
+ groups=groups,
746
+ bias=bias,
747
+ # pyrefly: ignore [bad-argument-type]
748
+ padding_mode=padding_mode,
749
+ qconfig=qconfig,
750
+ )
751
+ assert qconfig, "qconfig must be provided for QAT module"
752
+ self.qconfig = qconfig
753
+ self.weight_fake_quant = self.qconfig.weight()
754
+
755
+ def forward(self, input):
756
+ r"""Performs forward pass through fused Conv2d and ReLU."""
757
+ return F.relu(
758
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
759
+ )
760
+
761
+ @classmethod
762
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
763
+ r"""Creates a QAT module from a floating point module."""
764
+ return super().from_float(
765
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
766
+ )
767
+
768
+
769
+ class ConvBn3d(_ConvBnNd, nn.Conv3d):
770
+ r"""
771
+ A ConvBn3d module is a module fused from Conv3d and BatchNorm3d,
772
+ attached with FakeQuantize modules for weight,
773
+ used in quantization aware training.
774
+
775
+ We combined the interface of :class:`torch.nn.Conv3d` and
776
+ :class:`torch.nn.BatchNorm3d`.
777
+
778
+ Similar to :class:`torch.nn.Conv3d`, with FakeQuantize modules initialized
779
+ to default.
780
+
781
+ Attributes:
782
+ freeze_bn:
783
+ weight_fake_quant: fake quant module for weight
784
+
785
+ """
786
+
787
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment]
788
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
789
+ _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = nn.BatchNorm3d
790
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = None
791
+
792
+ def __init__(
793
+ self,
794
+ # ConvNd args
795
+ in_channels,
796
+ out_channels,
797
+ kernel_size,
798
+ stride=1,
799
+ padding=0,
800
+ dilation=1,
801
+ groups=1,
802
+ bias=None,
803
+ padding_mode="zeros",
804
+ # BatchNorm3d args
805
+ # num_features: out_channels
806
+ eps=1e-05,
807
+ momentum=0.1,
808
+ # affine: True
809
+ # track_running_stats: True
810
+ # Args for this module
811
+ freeze_bn=False,
812
+ qconfig=None,
813
+ ):
814
+ kernel_size = _triple(kernel_size)
815
+ stride = _triple(stride)
816
+ padding = _triple(padding)
817
+ dilation = _triple(dilation)
818
+ _ConvBnNd.__init__(
819
+ self,
820
+ in_channels,
821
+ out_channels,
822
+ kernel_size,
823
+ stride,
824
+ padding,
825
+ dilation,
826
+ False,
827
+ _triple(0),
828
+ groups,
829
+ bias,
830
+ padding_mode,
831
+ eps,
832
+ momentum,
833
+ freeze_bn,
834
+ qconfig,
835
+ dim=3,
836
+ )
837
+
838
+
839
+ class ConvBnReLU3d(ConvBn3d):
840
+ r"""
841
+ A ConvBnReLU3d module is a module fused from Conv3d, BatchNorm3d and ReLU,
842
+ attached with FakeQuantize modules for weight,
843
+ used in quantization aware training.
844
+
845
+ We combined the interface of :class:`torch.nn.Conv3d` and
846
+ :class:`torch.nn.BatchNorm3d` and :class:`torch.nn.ReLU`.
847
+
848
+ Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
849
+ default.
850
+
851
+ Attributes:
852
+ weight_fake_quant: fake quant module for weight
853
+
854
+ """
855
+
856
+ _FLOAT_MODULE: ClassVar[type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment]
857
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
858
+ _FLOAT_BN_MODULE: ClassVar[type[nn.BatchNorm3d]] = nn.BatchNorm3d
859
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.ReLU] | None] = nn.ReLU
860
+ # module class after fusing bn into conv
861
+ _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d
862
+
863
+ def forward(self, input):
864
+ r"""Performs forward pass through fused Conv3d, BatchNorm3d, and ReLU."""
865
+ return F.relu(ConvBn3d._forward(self, input))
866
+
867
+ @classmethod
868
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
869
+ r"""Creates a QAT module from a floating point module."""
870
+ return super().from_float(
871
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
872
+ )
873
+
874
+
875
+ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
876
+ r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
877
+ FakeQuantize modules for weight for
878
+ quantization aware training.
879
+
880
+ We combined the interface of :class:`~torch.nn.Conv3d` and
881
+ :class:`~torch.nn.ReLU`.
882
+
883
+ Attributes:
884
+ weight_fake_quant: fake quant module for weight
885
+
886
+ """
887
+
888
+ _FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment]
889
+ _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
890
+ _FLOAT_BN_MODULE: ClassVar[type[nn.Module] | None] = None
891
+ _FLOAT_RELU_MODULE: ClassVar[type[nn.Module] | None] = nn.ReLU
892
+
893
+ def __init__(
894
+ self,
895
+ in_channels,
896
+ out_channels,
897
+ kernel_size,
898
+ stride=1,
899
+ padding=0,
900
+ dilation=1,
901
+ groups=1,
902
+ bias=True,
903
+ padding_mode="zeros",
904
+ qconfig=None,
905
+ ):
906
+ super().__init__(
907
+ in_channels,
908
+ out_channels,
909
+ kernel_size,
910
+ stride=stride,
911
+ padding=padding,
912
+ dilation=dilation,
913
+ groups=groups,
914
+ bias=bias,
915
+ # pyrefly: ignore [bad-argument-type]
916
+ padding_mode=padding_mode,
917
+ qconfig=qconfig,
918
+ )
919
+ assert qconfig, "qconfig must be provided for QAT module"
920
+ self.qconfig = qconfig
921
+ self.weight_fake_quant = self.qconfig.weight()
922
+
923
+ def forward(self, input):
924
+ r"""Performs forward pass through fused Conv3d and ReLU."""
925
+ return F.relu(
926
+ self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
927
+ )
928
+
929
+ @classmethod
930
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
931
+ r"""Creates a QAT module from a floating point module."""
932
+ return super().from_float(
933
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
934
+ )
935
+
936
+
937
+ def update_bn_stats(mod):
938
+ if type(mod) in {
939
+ ConvBnReLU1d,
940
+ ConvBnReLU2d,
941
+ ConvBnReLU3d,
942
+ ConvBn1d,
943
+ ConvBn2d,
944
+ ConvBn3d,
945
+ }:
946
+ mod.update_bn_stats()
947
+
948
+
949
+ def freeze_bn_stats(mod):
950
+ if type(mod) in {
951
+ ConvBnReLU1d,
952
+ ConvBnReLU2d,
953
+ ConvBnReLU3d,
954
+ ConvBn1d,
955
+ ConvBn2d,
956
+ ConvBn3d,
957
+ }:
958
+ mod.freeze_bn_stats()
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_fused.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.utils.fusion import fuse_linear_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "LinearBn1d",
13
+ ]
14
+
15
+
16
+ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
17
+ r"""
18
+ A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
19
+ with FakeQuantize modules for weight, used in quantization aware training.
20
+
21
+ We combined the interface of :class:`torch.nn.Linear` and
22
+ :class:`torch.nn.BatchNorm1d`.
23
+
24
+ Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
25
+ to default.
26
+
27
+ Attributes:
28
+ freeze_bn:
29
+ weight_fake_quant: fake quant module for weight
30
+
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ # Linear args
36
+ in_features,
37
+ out_features,
38
+ bias=True,
39
+ # BatchNorm1d args
40
+ # num_features: out_features
41
+ eps=1e-05,
42
+ momentum=0.1,
43
+ # affine: True
44
+ # track_running_stats: True
45
+ # Args for this module
46
+ freeze_bn=False,
47
+ qconfig=None,
48
+ ):
49
+ nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
50
+ assert qconfig, "qconfig must be provided for QAT module"
51
+ self.qconfig = qconfig
52
+ self.freeze_bn = freeze_bn if self.training else True
53
+ self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
54
+ self.weight_fake_quant = self.qconfig.weight()
55
+ if bias:
56
+ self.bias = Parameter(torch.empty(out_features))
57
+ else:
58
+ self.register_parameter("bias", None)
59
+ self.reset_bn_parameters()
60
+
61
+ # this needs to be called after reset_bn_parameters,
62
+ # as they modify the same state
63
+ if self.training:
64
+ if freeze_bn:
65
+ self.freeze_bn_stats()
66
+ else:
67
+ self.update_bn_stats()
68
+ else:
69
+ self.freeze_bn_stats()
70
+
71
+ def reset_running_stats(self):
72
+ self.bn.reset_running_stats()
73
+
74
+ def reset_bn_parameters(self):
75
+ self.bn.reset_running_stats()
76
+ init.uniform_(self.bn.weight)
77
+ init.zeros_(self.bn.bias)
78
+
79
+ def update_bn_stats(self):
80
+ self.freeze_bn = False
81
+ self.bn.training = True
82
+ return self
83
+
84
+ def freeze_bn_stats(self):
85
+ self.freeze_bn = True
86
+ self.bn.training = False
87
+ return self
88
+
89
+ def forward(self, input):
90
+ assert self.bn.running_var is not None
91
+
92
+ # Scale the linear weights by BN's running statistics to reduce
93
+ # weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
94
+ # for motivation.
95
+ #
96
+ # Instead of
97
+ #
98
+ # x1 = F.linear(x0, fq(w), b)
99
+ # x2 = self.bn(x1)
100
+ #
101
+ # We have
102
+ #
103
+ # # scale the weight by previous batch's running statistics
104
+ # scale_factor = bn.w / bn.running_std_from_prev_batch
105
+ # # do the linear transformation without bias
106
+ # x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
107
+ # # reverse the scaling and add original bias
108
+ # x1_orig = x1_scaled / scale_factor + b
109
+ # x2 = self.bn(x1_orig)
110
+
111
+ running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
112
+ scale_factor = self.bn.weight / running_std
113
+ weight_shape = [1] * len(self.weight.shape)
114
+ weight_shape[0] = -1
115
+ bias_shape = [1] * len(self.weight.shape)
116
+ bias_shape[1] = -1
117
+ scaled_weight = self.weight_fake_quant(
118
+ self.weight * scale_factor.reshape(weight_shape)
119
+ )
120
+ if self.bias is not None:
121
+ zero_bias = torch.zeros_like(self.bias)
122
+ else:
123
+ zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
124
+ linear_out = F.linear(input, scaled_weight, zero_bias)
125
+ linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
126
+ if self.bias is not None:
127
+ linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
128
+ bn_out = self.bn(linear_out_orig)
129
+ return bn_out
130
+
131
+ def train(self, mode=True):
132
+ """
133
+ Batchnorm's training behavior is using the self.training flag. Prevent
134
+ changing it if BN is frozen. This makes sure that calling `model.train()`
135
+ on a model with a frozen BN will behave properly.
136
+ """
137
+ self.training = mode
138
+ if not self.freeze_bn:
139
+ for module in self.children():
140
+ module.train(mode)
141
+ return self
142
+
143
+ @classmethod
144
+ def from_float(cls, mod, use_precomputed_fake_quant=False):
145
+ r"""Create a qat module from a float module or qparams_dict
146
+
147
+ Args:
148
+ mod: A float module, either produced by torch.ao.quantization
149
+ utilities or directly from the user.
150
+ """
151
+ assert type(mod) is nni.LinearBn1d, (
152
+ "qat."
153
+ + cls.__name__
154
+ + ".from_float only works for "
155
+ + nni.LinearBn1d.__name__
156
+ )
157
+ assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
158
+ assert mod.qconfig, "Input float module must have a valid config"
159
+ qconfig = mod.qconfig
160
+ linear, bn = mod[0], mod[1]
161
+ qat_linearbn = cls(
162
+ linear.in_features,
163
+ linear.out_features,
164
+ linear.bias is not None,
165
+ bn.eps,
166
+ bn.momentum,
167
+ False,
168
+ qconfig,
169
+ )
170
+ qat_linearbn.weight = linear.weight # type: ignore[assignment]
171
+ qat_linearbn.bias = linear.bias # type: ignore[assignment]
172
+ qat_linearbn.bn.weight = bn.weight # type: ignore[assignment]
173
+ qat_linearbn.bn.bias = bn.bias # type: ignore[assignment]
174
+ qat_linearbn.bn.running_mean = bn.running_mean # type: ignore[assignment]
175
+ qat_linearbn.bn.running_var = bn.running_var # type: ignore[assignment]
176
+ qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked # type: ignore[assignment]
177
+ return qat_linearbn
178
+
179
+ def to_float(self):
180
+ linear = torch.nn.Linear(self.in_features, self.out_features)
181
+ assert self.bn.running_var is not None and self.bn.running_mean is not None
182
+ linear.weight, linear.bias = fuse_linear_bn_weights(
183
+ self.weight,
184
+ self.bias,
185
+ self.bn.running_mean,
186
+ self.bn.running_var,
187
+ self.bn.eps,
188
+ self.bn.weight,
189
+ self.bn.bias,
190
+ )
191
+ return linear
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/qat/modules/linear_relu.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+ import torch.ao.nn.intrinsic as nni
7
+ import torch.ao.nn.qat as nnqat
8
+ import torch.nn.functional as F
9
+ from torch.ao.nn.intrinsic.modules.fused import _FusedModule
10
+
11
+
12
+ if TYPE_CHECKING:
13
+ from torch.ao.quantization.qconfig import QConfigAny
14
+
15
+
16
+ __all__ = ["LinearReLU"]
17
+
18
+
19
+ class LinearReLU(nnqat.Linear, _FusedModule):
20
+ r"""
21
+ A LinearReLU module fused from Linear and ReLU modules, attached with
22
+ FakeQuantize modules for weight, used in
23
+ quantization aware training.
24
+
25
+ We adopt the same interface as :class:`torch.nn.Linear`.
26
+
27
+ Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
28
+ default.
29
+
30
+ Attributes:
31
+ weight_fake_quant: fake quant module for weight
32
+
33
+ Examples::
34
+
35
+ >>> # xdoctest: +SKIP
36
+ >>> m = nn.qat.LinearReLU(20, 30)
37
+ >>> input = torch.randn(128, 20)
38
+ >>> output = m(input)
39
+ >>> print(output.size())
40
+ torch.Size([128, 30])
41
+ """
42
+
43
+ # pyrefly: ignore [bad-override]
44
+ _FLOAT_MODULE = nni.LinearReLU
45
+
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ out_features: int,
50
+ bias: bool = True,
51
+ qconfig: QConfigAny = None,
52
+ ) -> None:
53
+ super().__init__(in_features, out_features, bias, qconfig)
54
+
55
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
56
+ return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
57
+
58
+ @classmethod
59
+ def from_float(
60
+ cls,
61
+ mod: torch.nn.Module,
62
+ use_precomputed_fake_quant: bool = False,
63
+ ) -> LinearReLU:
64
+ return super().from_float(mod, use_precomputed_fake_quant) # type: ignore[no-untyped-call,no-any-return]
65
+
66
+ def to_float(self) -> nni.LinearReLU:
67
+ linear = torch.nn.Linear(
68
+ self.in_features, self.out_features, self.bias is not None
69
+ )
70
+ linear.weight = torch.nn.Parameter(self.weight.detach())
71
+ if self.bias is not None:
72
+ linear.bias = torch.nn.Parameter(self.bias.detach())
73
+ relu = torch.nn.ReLU()
74
+ return torch.ao.nn.intrinsic.LinearReLU(linear, relu) # type: ignore[no-untyped-call]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .modules import * # noqa: F403
2
+
3
+
4
+ __all__ = [
5
+ "BNReLU2d",
6
+ "BNReLU3d",
7
+ "ConvReLU1d",
8
+ "ConvReLU2d",
9
+ "ConvReLU3d",
10
+ "LinearReLU",
11
+ "LinearLeakyReLU",
12
+ "LinearTanh",
13
+ "ConvAdd2d",
14
+ "ConvAddReLU2d",
15
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (401 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (260 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .linear_relu import LinearReLU
2
+
3
+
4
+ __all__ = [
5
+ "LinearReLU",
6
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (320 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (3.78 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from typing_extensions import Self
3
+
4
+ import torch
5
+ import torch.ao.nn.intrinsic as nni
6
+ import torch.ao.nn.quantized.dynamic as nnqd
7
+
8
+
9
+ __all__ = ["LinearReLU"]
10
+
11
+
12
+ class LinearReLU(nnqd.Linear):
13
+ r"""
14
+ A LinearReLU module fused from Linear and ReLU modules that can be used
15
+ for dynamic quantization.
16
+ Supports both FP16 and INT8 quantization.
17
+
18
+ We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.
19
+
20
+ Attributes:
21
+ Same as torch.ao.nn.quantized.dynamic.Linear
22
+
23
+ Examples::
24
+
25
+ >>> # xdoctest: +SKIP
26
+ >>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
27
+ >>> input = torch.randn(128, 20)
28
+ >>> output = m(input)
29
+ >>> print(output.size())
30
+ torch.Size([128, 30])
31
+ """
32
+
33
+ # pyrefly: ignore [bad-override]
34
+ _FLOAT_MODULE = nni.LinearReLU
35
+
36
+ def __init__(
37
+ self,
38
+ in_features: int,
39
+ out_features: int,
40
+ bias: bool = True,
41
+ dtype: torch.dtype = torch.qint8,
42
+ ) -> None:
43
+ super().__init__(in_features, out_features, bias, dtype)
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ if self._packed_params.dtype == torch.qint8:
47
+ # TODO check if we should set reduce_range = True by default here
48
+ Y = torch.ops.quantized.linear_relu_dynamic(
49
+ x, self._packed_params._packed_params, reduce_range=True
50
+ )
51
+ elif self._packed_params.dtype == torch.float16:
52
+ Y = torch.ops.quantized.linear_relu_dynamic_fp16(
53
+ x, self._packed_params._packed_params
54
+ )
55
+ else:
56
+ raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
57
+ return Y.to(x.dtype)
58
+
59
+ def _get_name(self) -> str:
60
+ return "DynamicQuantizedLinearReLU"
61
+
62
+ @classmethod
63
+ def from_float(
64
+ cls, mod: torch.nn.Module, use_precomputed_fake_quant: bool = False
65
+ ) -> Self:
66
+ return super().from_float(
67
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
68
+ )
69
+
70
+ @classmethod
71
+ def from_reference(cls, ref_qlinear_relu: Any) -> Self: # type: ignore[override]
72
+ return super().from_reference(ref_qlinear_relu[0])
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .bn_relu import BNReLU2d, BNReLU3d
2
+ from .conv_add import ConvAdd2d, ConvAddReLU2d
3
+ from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
4
+ from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh
5
+
6
+
7
+ __all__ = [
8
+ "LinearReLU",
9
+ "ConvReLU1d",
10
+ "ConvReLU2d",
11
+ "ConvReLU3d",
12
+ "BNReLU2d",
13
+ "BNReLU3d",
14
+ "LinearLeakyReLU",
15
+ "LinearTanh",
16
+ "ConvAdd2d",
17
+ "ConvAddReLU2d",
18
+ ]
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (623 Bytes). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-312.pyc ADDED
Binary file (4.79 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-312.pyc ADDED
Binary file (5.52 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-312.pyc ADDED
Binary file (9.91 kB). View file
 
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/bn_relu.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+
8
+
9
+ __all__ = ["BNReLU2d", "BNReLU3d"]
10
+
11
+
12
+ class BNReLU2d(nnq.BatchNorm2d):
13
+ r"""
14
+ A BNReLU2d module is a fused module of BatchNorm2d and ReLU
15
+
16
+ We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.
17
+
18
+ Attributes:
19
+ Same as torch.ao.nn.quantized.BatchNorm2d
20
+
21
+ """
22
+
23
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d
24
+
25
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
26
+ super().__init__(
27
+ num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
28
+ )
29
+
30
+ def forward(self, input):
31
+ r"""Applies fused BatchNorm2d and ReLU."""
32
+ # Temporarily using len(shape) instead of ndim due to JIT issue
33
+ # https://github.com/pytorch/pytorch/issues/23890
34
+ if len(input.shape) != 4:
35
+ raise ValueError("Input shape must be `(N, C, H, W)`!")
36
+ return torch.ops.quantized.batch_norm2d_relu(
37
+ input,
38
+ self.weight,
39
+ self.bias,
40
+ self.running_mean,
41
+ self.running_var,
42
+ self.eps,
43
+ self.scale,
44
+ self.zero_point,
45
+ )
46
+
47
+ def _get_name(self):
48
+ return "QuantizedBNReLU2d"
49
+
50
+ @classmethod
51
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
52
+ r"""Creates a quantized module from a float module."""
53
+ # TODO: Add qat support for BNReLU2d
54
+ return super().from_float(
55
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
56
+ )
57
+
58
+ @classmethod
59
+ def from_reference(cls, bn_relu, output_scale, output_zero_point):
60
+ r"""Creates a quantized module from a reference module."""
61
+ return super().from_reference(bn_relu[0], output_scale, output_zero_point)
62
+
63
+
64
+ class BNReLU3d(nnq.BatchNorm3d):
65
+ r"""
66
+ A BNReLU3d module is a fused module of BatchNorm3d and ReLU
67
+
68
+ We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.
69
+
70
+ Attributes:
71
+ Same as torch.ao.nn.quantized.BatchNorm3d
72
+
73
+ """
74
+
75
+ _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d
76
+
77
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
78
+ super().__init__(
79
+ num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
80
+ )
81
+
82
+ def forward(self, input):
83
+ r"""Applies fused BatchNorm3d and ReLU."""
84
+ # Temporarily using len(shape) instead of ndim due to JIT issue
85
+ # https://github.com/pytorch/pytorch/issues/23890
86
+ if len(input.shape) != 5:
87
+ raise ValueError("Input shape must be `(N, C, D, H, W)`!")
88
+ return torch.ops.quantized.batch_norm3d_relu(
89
+ input,
90
+ self.weight,
91
+ self.bias,
92
+ self.running_mean,
93
+ self.running_var,
94
+ self.eps,
95
+ self.scale,
96
+ self.zero_point,
97
+ )
98
+
99
+ def _get_name(self):
100
+ return "QuantizedBNReLU3d"
101
+
102
+ @classmethod
103
+ def from_float(cls, mod, use_precomputed_fake_quant=False): # type: ignore[override]
104
+ r"""Creates a quantized module from a float module."""
105
+ # TODO: Add qat support for BNReLU3d
106
+ return super().from_float(
107
+ mod, use_precomputed_fake_quant=use_precomputed_fake_quant
108
+ )
109
+
110
+ @classmethod
111
+ def from_reference(cls, bn_relu, output_scale, output_zero_point):
112
+ r"""Creates a quantized module from a reference module."""
113
+ return super().from_reference(bn_relu[0], output_scale, output_zero_point)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_add.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic
4
+ import torch.ao.nn.intrinsic.qat
5
+ import torch.ao.nn.quantized as nnq
6
+ import torch.nn.functional as F
7
+
8
+
9
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
10
+
11
+
12
class ConvAdd2d(nnq.Conv2d):
    r"""
    Quantized module that fuses Conv2d with an Add.

    The interface mirrors :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Optional settings are gathered once and forwarded by keyword to nnq.Conv2d.
        conv_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "padding_mode": padding_mode,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, input, extra_input):  # type: ignore[override]
        r"""Runs the fused quantized Conv2d + add op on ``input`` and ``extra_input``.

        Raises:
            ValueError: if ``input`` does not have exactly four dimensions.
        """
        # len(shape) is used instead of ndim because of a JIT issue:
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        needs_explicit_pad = self.padding_mode != "zeros"
        if needs_explicit_pad:
            # Any padding mode other than "zeros" is applied here via F.pad.
            input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        fused_op = torch.ops.quantized.conv2d_add
        return fused_op(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedConvAdd2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        r"""Builds the quantized module from a float module via the base class."""
        precomputed = use_precomputed_fake_quant
        return super().from_float(mod, use_precomputed_fake_quant=precomputed)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Builds the quantized module from a reference module.

        Only the first element of ``ref_qconv`` is handed to the base class.
        """
        conv = ref_qconv[0]
        return super().from_reference(conv, output_scale, output_zero_point)
82
+
83
+
84
class ConvAddReLU2d(nnq.Conv2d):
    r"""
    Quantized module that fuses Conv2d, Add and ReLU.

    The interface mirrors :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Optional settings are gathered once and forwarded by keyword to nnq.Conv2d.
        conv_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "padding_mode": padding_mode,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, input, extra_input):  # type: ignore[override]
        r"""Runs the fused quantized Conv2d + add + ReLU op.

        Raises:
            ValueError: if ``input`` does not have exactly four dimensions.
        """
        # len(shape) is used instead of ndim because of a JIT issue:
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        needs_explicit_pad = self.padding_mode != "zeros"
        if needs_explicit_pad:
            # Any padding mode other than "zeros" is applied here via F.pad.
            input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        fused_op = torch.ops.quantized.conv2d_add_relu
        return fused_op(
            input, extra_input, self._packed_params, self.scale, self.zero_point
        )

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedConvAddReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        r"""Builds the quantized module from a float module via the base class."""
        precomputed = use_precomputed_fake_quant
        return super().from_float(mod, use_precomputed_fake_quant=precomputed)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Builds the quantized module from a reference module.

        Only the first element of ``ref_qconv`` is handed to the base class.
        """
        conv = ref_qconv[0]
        return super().from_reference(conv, output_scale, output_zero_point)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/conv_relu.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ import torch.ao.nn.intrinsic
5
+ import torch.ao.nn.intrinsic.qat
6
+ import torch.ao.nn.quantized as nnq
7
+ import torch.nn.functional as F
8
+ from torch.nn.utils import fuse_conv_bn_weights
9
+
10
+
11
+ __all__ = [
12
+ "ConvReLU1d",
13
+ "ConvReLU2d",
14
+ "ConvReLU3d",
15
+ ]
16
+
17
+ _reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
18
+
19
+
20
+ # TODO: factor out the common parts to ConvNd
21
class ConvReLU1d(nnq.Conv1d):
    r"""
    Quantized module that fuses Conv1d with a trailing ReLU.

    The interface mirrors :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Optional settings are gathered once and forwarded by keyword to nnq.Conv1d.
        conv_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "padding_mode": padding_mode,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, input):
        r"""Runs the fused quantized Conv1d + ReLU op.

        Raises:
            ValueError: if ``input`` does not have exactly three dimensions.
        """
        # len(shape) is used instead of ndim because of a JIT issue:
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        needs_explicit_pad = self.padding_mode != "zeros"
        if needs_explicit_pad:
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            input = F.pad(
                input,
                _reverse_repeat_padding(self.padding[:1]),
                mode=self.padding_mode,
            )
        fused_op = torch.ops.quantized.conv1d_relu
        return fused_op(input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedConvReLU1d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        r"""Builds the quantized module from a float module.

        When ``mod`` is exactly a QAT ``ConvBnReLU1d``, its weight and bias are
        first replaced (on ``mod`` itself) by the output of
        ``fuse_conv_bn_weights`` applied to the conv and batch-norm tensors.
        """
        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            bn = mod.bn
            assert bn.running_var is not None and bn.running_mean is not None
            fused = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
            mod.weight, mod.bias = fused
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Builds the quantized module from a reference module.

        Rejects an unfused ``ConvBnReLU1d``; otherwise only the first element
        of ``ref_qconv`` is handed to the base class.
        """
        bn_already_fused = type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU1d
        assert bn_already_fused, (
            "BatchNorm1d should be fused into Conv1d before converting to reference module"
        )
        conv = ref_qconv[0]
        return super().from_reference(conv, output_scale, output_zero_point)
105
+
106
+
107
class ConvReLU2d(nnq.Conv2d):
    r"""
    Quantized module that fuses Conv2d with a trailing ReLU.

    The interface mirrors :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # Optional settings are gathered once and forwarded by keyword to nnq.Conv2d.
        conv_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "padding_mode": padding_mode,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, input):
        r"""Runs the fused quantized Conv2d + ReLU op.

        Raises:
            ValueError: if ``input`` does not have exactly four dimensions.
        """
        # len(shape) is used instead of ndim because of a JIT issue:
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        needs_explicit_pad = self.padding_mode != "zeros"
        if needs_explicit_pad:
            # Any padding mode other than "zeros" is applied here via F.pad.
            input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        fused_op = torch.ops.quantized.conv2d_relu
        return fused_op(input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedConvReLU2d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        r"""Builds the quantized module from a float module.

        When ``mod`` is exactly a QAT ``ConvBnReLU2d``, its weight and bias are
        first replaced (on ``mod`` itself) by the output of
        ``fuse_conv_bn_weights`` applied to the conv and batch-norm tensors.
        """
        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            bn = mod.bn
            assert bn.running_var is not None and bn.running_mean is not None
            fused = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
            mod.weight, mod.bias = fused
        precomputed = use_precomputed_fake_quant
        return super().from_float(mod, use_precomputed_fake_quant=precomputed)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Builds the quantized module from a reference module.

        Rejects an unfused ``ConvBnReLU2d``; otherwise only the first element
        of ``ref_qconv`` is handed to the base class.
        """
        bn_already_fused = type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU2d
        assert bn_already_fused, (
            "BatchNorm2d should be fused into Conv2d before converting to reference module"
        )
        conv = ref_qconv[0]
        return super().from_reference(conv, output_scale, output_zero_point)
191
+
192
+
193
class ConvReLU3d(nnq.Conv3d):
    r"""
    Quantized module that fuses Conv3d with a trailing ReLU.

    The interface mirrors :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes: Same as torch.ao.nn.quantized.Conv3d

    """

    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d  # type: ignore[assignment]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        device=None,
        dtype=None,
    ):
        # "reflect" padding is rejected before the base class is initialised.
        assert padding_mode != "reflect", "Conv3d does not support reflection padding"
        # Optional settings are gathered once and forwarded by keyword to nnq.Conv3d.
        conv_options = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias": bias,
            "padding_mode": padding_mode,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, input):
        r"""Runs the fused quantized Conv3d + ReLU op.

        Raises:
            ValueError: if ``input`` does not have exactly five dimensions.
        """
        # len(shape) is used instead of ndim because of a JIT issue:
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        needs_explicit_pad = self.padding_mode != "zeros"
        if needs_explicit_pad:
            # Any padding mode other than "zeros" is applied here via F.pad.
            input = F.pad(
                input,
                _reverse_repeat_padding(self.padding),
                mode=self.padding_mode,
            )
        fused_op = torch.ops.quantized.conv3d_relu
        return fused_op(input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedConvReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        r"""Builds the quantized module from a float module.

        When ``mod`` is exactly a QAT ``ConvBnReLU3d``, its weight and bias are
        first replaced (on ``mod`` itself) by the output of
        ``fuse_conv_bn_weights`` applied to the conv and batch-norm tensors.
        """
        if type(mod) is torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            bn = mod.bn
            assert bn.running_var is not None and bn.running_mean is not None
            fused = fuse_conv_bn_weights(
                mod.weight,
                mod.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                bn.weight,
                bn.bias,
            )
            mod.weight, mod.bias = fused
        precomputed = use_precomputed_fake_quant
        return super().from_float(mod, use_precomputed_fake_quant=precomputed)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Builds the quantized module from a reference module.

        Rejects an unfused ``ConvBnReLU3d``; otherwise only the first element
        of ``ref_qconv`` is handed to the base class.
        """
        bn_already_fused = type(ref_qconv) is not torch.ao.nn.intrinsic.ConvBnReLU3d
        assert bn_already_fused, (
            "BatchNorm3d should be fused into Conv3d before converting to reference module"
        )
        conv = ref_qconv[0]
        return super().from_reference(conv, output_scale, output_zero_point)
URSA/.venv_ursa/lib/python3.12/site-packages/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ import torch.ao.nn.intrinsic as nni
4
+ import torch.ao.nn.quantized as nnq
5
+ from torch.ao.nn.quantized.modules.utils import _quantize_weight
6
+
7
+
8
+ __all__ = [
9
+ "LinearReLU",
10
+ "LinearLeakyReLU",
11
+ "LinearTanh",
12
+ ]
13
+
14
+
15
class LinearReLU(nnq.Linear):
    r"""
    Quantized Linear layer fused with a ReLU activation.

    The interface mirrors :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearReLU(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearReLU  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        # Arguments are passed positionally, unchanged, to nnq.Linear.
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the fused quantized linear + ReLU op on ``x``."""
        packed_weights = self._packed_params._packed_params
        return torch.ops.quantized.linear_relu(
            x, packed_weights, self.scale, self.zero_point
        )

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedLinearReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Builds the quantized module from a float module via the base class."""
        return super().from_float(mod, use_precomputed_fake_quant)

    @classmethod
    def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
        """Builds the quantized module from a reference module.

        Only the first element of ``ref_linear_relu`` is handed to the base class.
        """
        ref_linear = ref_linear_relu[0]
        return super().from_reference(ref_linear, output_scale, output_zero_point)
56
+
57
+
58
class LinearLeakyReLU(nnq.Linear):
    r"""
    For onednn backend only
    Quantized Linear layer fused with a LeakyReLU activation.
    The interface mirrors :class:`torch.ao.nn.quantized.Linear`.
    Attributes:
        Same as torch.ao.nn.quantized.Linear
        + negative_slope
    Examples::
        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearLeakyReLU  # type: ignore[assignment]

    def __init__(
        self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
    ):
        super().__init__(in_features, out_features, bias, dtype)
        # Stored so forward() can hand it to the fused op.
        self.negative_slope = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the fused quantized linear + LeakyReLU op on ``x``."""
        packed_weights = self._packed_params._packed_params
        return torch.ops.quantized.linear_leaky_relu(
            x, packed_weights, self.scale, self.zero_point, self.negative_slope
        )

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedLinearLeakyReLU"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Builds the quantized module from a float ``nni.LinearLeakyReLU``.

        ``mod`` must be exactly ``nni.LinearLeakyReLU`` and carry a ``qconfig``;
        the weight observer must report ``torch.qint8``. Note that
        ``use_precomputed_fake_quant`` is accepted but not consulted here.
        """
        assert type(mod) is nni.LinearLeakyReLU, (
            "Input float module should be LinearLeakyReLU"
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        output_observer = mod.activation_post_process
        float_linear, activation = mod[0], mod[1]
        # A fresh weight observer is created and run over the float weight.
        weight_observer = float_linear.qconfig.weight()  # type: ignore[union-attr, operator]
        weight_observer(float_linear.weight)
        dtype = weight_observer.dtype
        act_scale, act_zp = output_observer.calculate_qparams()  # type: ignore[union-attr,operator]
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(float_linear.weight.float(), weight_observer)
        quantized = cls(
            float_linear.in_features,
            float_linear.out_features,
            activation.negative_slope,
            dtype=dtype,
        )
        quantized.set_weight_bias(qweight, float_linear.bias)  # type: ignore[arg-type]
        quantized.scale = float(act_scale)
        quantized.zero_point = int(act_zp)
        return quantized

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        """Builds the quantized module from a reference module.

        ``ref_mod[0]`` supplies the linear layer and ``ref_mod[1]`` the
        activation whose ``negative_slope`` is copied over.
        """
        ref_linear, activation = ref_mod[0], ref_mod[1]
        quantized = cls(
            ref_linear.in_features,
            ref_linear.out_features,
            activation.negative_slope,
        )
        quantized.set_weight_bias(ref_linear.get_quantized_weight(), ref_linear.bias)
        quantized.scale = float(output_scale)
        quantized.zero_point = int(output_zero_point)
        return quantized
130
+
131
+
132
class LinearTanh(nnq.Linear):
    r"""
    Quantized Linear layer fused with a Tanh activation.

    The interface mirrors :class:`torch.ao.nn.quantized.Linear`.

    Attributes:
        Same as torch.ao.nn.quantized.Linear

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.intrinsic.LinearTanh(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    _FLOAT_MODULE = nni.LinearTanh  # type: ignore[assignment]

    def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
        # Arguments are passed positionally, unchanged, to nnq.Linear.
        super().__init__(in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the fused quantized linear + tanh op on ``x``."""
        packed_weights = self._packed_params._packed_params
        return torch.ops.quantized.linear_tanh(
            x, packed_weights, self.scale, self.zero_point
        )

    def _get_name(self):
        # Display name reported for this module.
        return "QuantizedLinearTanh"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        """Builds the quantized module from a float ``nni.LinearTanh``.

        ``mod`` must be exactly ``nni.LinearTanh`` and carry a ``qconfig``;
        the weight observer must report ``torch.qint8``. Note that
        ``use_precomputed_fake_quant`` is accepted but not consulted here.
        """
        assert type(mod) is nni.LinearTanh, "Input float module should be LinearTanh"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        output_observer = mod.activation_post_process
        float_linear = mod[0]
        # A fresh weight observer is created and run over the float weight.
        weight_observer = float_linear.qconfig.weight()  # type: ignore[union-attr,operator]
        weight_observer(float_linear.weight)
        dtype = weight_observer.dtype
        act_scale, act_zp = output_observer.calculate_qparams()  # type: ignore[union-attr,operator]
        assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
        qweight = _quantize_weight(float_linear.weight.float(), weight_observer)
        quantized = cls(
            float_linear.in_features, float_linear.out_features, dtype=dtype
        )
        quantized.set_weight_bias(qweight, float_linear.bias)  # type: ignore[arg-type]
        quantized.scale = float(act_scale)
        quantized.zero_point = int(act_zp)
        return quantized

    @classmethod
    def from_reference(cls, ref_mod, output_scale, output_zero_point):
        """Builds the quantized module from a reference module.

        ``ref_mod[0]`` supplies the linear layer whose quantized weight and
        bias are copied over.
        """
        ref_linear = ref_mod[0]
        quantized = cls(ref_linear.in_features, ref_linear.out_features)
        quantized.set_weight_bias(ref_linear.get_quantized_weight(), ref_linear.bias)
        quantized.scale = float(output_scale)
        quantized.zero_point = int(output_zero_point)
        return quantized