koichi12 commited on
Commit
9aa23eb
·
verified ·
1 Parent(s): 82ed4ab

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc +3 -0
  3. tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 +3 -0
  4. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py +16 -0
  5. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py +25 -0
  6. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py +13 -0
  7. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py +21 -0
  8. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc +0 -0
  9. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py +1 -0
  10. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  11. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc +0 -0
  12. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py +164 -0
  13. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc +0 -0
  14. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc +0 -0
  15. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py +1131 -0
  16. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py +237 -0
  17. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/tracer.py +45 -0
  18. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc +0 -0
  19. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py +52 -0
  20. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/_reduction.cpython-311.pyc +0 -0
  21. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc +0 -0
  22. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc +0 -0
  23. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py +57 -0
  24. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-311.pyc +0 -0
  25. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/_functions.py +288 -0
  26. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/upsampling.py +264 -0
  27. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
  28. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__init__.py +3 -0
  29. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/linear.py +10 -0
  30. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py +14 -0
  31. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py +1 -0
  32. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  33. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc +0 -0
  34. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py +11 -0
  35. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py +40 -0
  36. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc +0 -0
  37. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc +0 -0
  38. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc +0 -0
  39. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc +0 -0
  40. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py +1 -0
  41. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
  42. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  43. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py +18 -0
  44. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py +10 -0
  45. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py +22 -0
  46. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc +0 -0
  47. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc +0 -0
  48. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc +0 -0
  49. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc +0 -0
  50. tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -72,3 +72,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/F
72
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
73
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
74
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
72
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
73
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
74
  tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
75
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
76
+ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e838b6e228975d93cd0fdc5e05254915109c27d231652f0851ab23f1b207b5f
3
+ size 233093
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a592a5b2f359a9077550ee1fdadd58eb2cf9cc0bfab8fe397a374fb949da143
3
+ size 1618440
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch._C._lazy
2
+
3
+
4
+ def get_force_fallback():
5
+ """Get the config used to force LTC fallback"""
6
+ return torch._C._lazy._get_force_fallback()
7
+
8
+
9
+ def set_force_fallback(configval):
10
+ """Set the config used to force LTC fallback"""
11
+ torch._C._lazy._set_force_fallback(configval)
12
+
13
+
14
+ def set_reuse_ir(val: bool):
15
+ """Set the config to reuse IR nodes for faster tracing"""
16
+ torch._C._lazy._set_reuse_ir(val)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from typing import Any, Dict
3
+
4
+ import torch._C._lazy
5
+
6
+
7
+ class DeviceContext:
8
+ _CONTEXTS: Dict[str, Any] = dict()
9
+ _CONTEXTS_LOCK = threading.Lock()
10
+
11
+ def __init__(self, device):
12
+ self.device = device
13
+
14
+
15
+ def get_device_context(device=None):
16
+ if device is None:
17
+ device = torch._C._lazy._get_default_device_type()
18
+ else:
19
+ device = str(device)
20
+ with DeviceContext._CONTEXTS_LOCK:
21
+ devctx = DeviceContext._CONTEXTS.get(device, None)
22
+ if devctx is None:
23
+ devctx = DeviceContext(device)
24
+ DeviceContext._CONTEXTS[device] = devctx
25
+ return devctx
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch._C._lazy
2
+
3
+
4
+ def dump(dot_file_name: str):
5
+ """Dump TrieCache in the dot format"""
6
+ return torch._C._lazy._dump_ir_cache(dot_file_name)
7
+
8
+
9
+ def reset():
10
+ """Clear TrieCache. This is needed in testing to avoid
11
+ node reusing between different tests.
12
+ """
13
+ return torch._C._lazy._clear_ir_cache()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch._C._lazy
2
+
3
+
4
+ def reset():
5
+ """Resets all metric counters."""
6
+ torch._C._lazy._reset_metrics()
7
+
8
+
9
+ def counter_names():
10
+ """Retrieves all the currently active counter names."""
11
+ return torch._C._lazy._counter_names()
12
+
13
+
14
+ def counter_value(name: str):
15
+ """Return the value of the counter with the speficied name"""
16
+ return torch._C._lazy._counter_value(name)
17
+
18
+
19
+ def metrics_report():
20
+ """Return the combined (lazy core and backend) metric report"""
21
+ return torch._C._lazy._metrics_report()
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (756 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (717 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.parameter import Parameter
3
+ from typing import List
4
+
5
+ __all__: List[str] = []
6
+
7
+ class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
8
+ r"""Generalized extension of the FakeQuantize module in fake_quantize.py.
9
+
10
+ This is an extension of the FakeQuantize module in fake_quantize.py, which
11
+ supports more generalized lower-bit quantization and support learning of the scale
12
+ and zero point parameters through backpropagation. For literature references,
13
+ please see the class _LearnableFakeQuantizePerTensorOp.
14
+
15
+ In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
16
+ module also includes the following attributes to support quantization parameter learning.
17
+
18
+ * :attr:`channel_len` defines the length of the channel when initializing scale and zero point
19
+ for the per channel case.
20
+
21
+ * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
22
+ normalized by the constant, which is proportional to the square root of the number of
23
+ elements in the tensor. The related literature justifying the use of this particular constant
24
+ can be found here: https://openreview.net/pdf?id=rkgO66VKDS.
25
+
26
+ * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.
27
+
28
+ * :attr:`static_enabled` defines the flag for using observer's static estimation for
29
+ scale and zero point.
30
+
31
+ * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
32
+ """
33
+ def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
34
+ use_grad_scaling=False, **observer_kwargs):
35
+ super().__init__()
36
+ assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
37
+ self.quant_min = quant_min
38
+ self.quant_max = quant_max
39
+ # also pass quant_min and quant_max to observer
40
+ observer_kwargs["quant_min"] = quant_min
41
+ observer_kwargs["quant_max"] = quant_max
42
+ self.use_grad_scaling = use_grad_scaling
43
+ if channel_len == -1:
44
+ self.scale = Parameter(torch.tensor([scale]))
45
+ self.zero_point = Parameter(torch.tensor([zero_point]))
46
+ else:
47
+ assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
48
+ self.scale = Parameter(torch.tensor([scale] * channel_len))
49
+ self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
50
+
51
+ self.activation_post_process = observer(**observer_kwargs)
52
+ assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
53
+ 'quant_min out of bound'
54
+ assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
55
+ 'quant_max out of bound'
56
+ self.dtype = self.activation_post_process.dtype
57
+ self.qscheme = self.activation_post_process.qscheme
58
+ self.ch_axis = self.activation_post_process.ch_axis \
59
+ if hasattr(self.activation_post_process, 'ch_axis') else -1
60
+ self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
61
+ self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
62
+ self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))
63
+
64
+ bitrange = torch.tensor(quant_max - quant_min + 1).double()
65
+ self.bitwidth = int(torch.log2(bitrange).item())
66
+ self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
67
+
68
+ @torch.jit.export
69
+ def enable_param_learning(self):
70
+ r"""Enable parameter learning over static observer estimates.
71
+
72
+ Enables learning of quantization parameters and
73
+ disables static observer estimates. Forward path returns fake quantized X.
74
+ """
75
+ self.toggle_qparam_learning(enabled=True) \
76
+ .toggle_fake_quant(enabled=True) \
77
+ .toggle_observer_update(enabled=False)
78
+ return self
79
+
80
+ @torch.jit.export
81
+ def enable_static_estimate(self):
82
+ """Enable static estimates of quantization parameters.
83
+
84
+ Enables static observer estimates and disables learning of
85
+ quantization parameters. Forward path returns fake quantized X.
86
+ """
87
+ self.toggle_qparam_learning(enabled=False) \
88
+ .toggle_fake_quant(enabled=True) \
89
+ .toggle_observer_update(enabled=True)
90
+
91
+ @torch.jit.export
92
+ def enable_static_observation(self):
93
+ """Enable accumulation of data without updating quantization parameters.
94
+
95
+ Enables static observer accumulating data from input but doesn't
96
+ update the quantization parameters. Forward path returns the original X.
97
+ """
98
+ self.toggle_qparam_learning(enabled=False) \
99
+ .toggle_fake_quant(enabled=False) \
100
+ .toggle_observer_update(enabled=True)
101
+
102
+ @torch.jit.export
103
+ def toggle_observer_update(self, enabled=True):
104
+ self.static_enabled[0] = int(enabled) # type: ignore[operator]
105
+ return self
106
+
107
+ @torch.jit.export
108
+ def enable_observer(self, enabled=True):
109
+ self.toggle_observer_update(enabled)
110
+
111
+ @torch.jit.export
112
+ def toggle_qparam_learning(self, enabled=True):
113
+ self.learning_enabled[0] = int(enabled) # type: ignore[operator]
114
+ self.scale.requires_grad = enabled
115
+ self.zero_point.requires_grad = enabled
116
+ return self
117
+
118
+ @torch.jit.export
119
+ def toggle_fake_quant(self, enabled=True):
120
+ self.fake_quant_enabled[0] = int(enabled)
121
+ return self
122
+
123
+ @torch.jit.export
124
+ def observe_quant_params(self):
125
+ print(f'_LearnableFakeQuantize Scale: {self.scale.detach()}')
126
+ print(f'_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}')
127
+
128
+ @torch.jit.export
129
+ def calculate_qparams(self):
130
+ self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
131
+ scale = self.scale.detach()
132
+ zero_point = self.zero_point.detach().round().clamp(self.quant_min, self.quant_max).long()
133
+ return scale, zero_point
134
+
135
+ def forward(self, X):
136
+ if self.static_enabled[0] == 1: # type: ignore[index]
137
+ self.activation_post_process(X.detach())
138
+ _scale, _zero_point = self.activation_post_process.calculate_qparams()
139
+ _scale = _scale.to(self.scale.device)
140
+ _zero_point = _zero_point.to(self.zero_point.device)
141
+ self.scale.data.copy_(_scale)
142
+ self.zero_point.data.copy_(_zero_point)
143
+ else:
144
+ self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
145
+
146
+ if self.fake_quant_enabled[0] == 1:
147
+ if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
148
+ self.zero_point.data.zero_()
149
+
150
+ if self.use_grad_scaling:
151
+ grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
152
+ else:
153
+ grad_factor = 1.0
154
+ if self.qscheme in (
155
+ torch.per_channel_symmetric, torch.per_channel_affine):
156
+ X = torch._fake_quantize_learnable_per_channel_affine(
157
+ X, self.scale, self.zero_point, self.ch_axis,
158
+ self.quant_min, self.quant_max, grad_factor)
159
+ else:
160
+ X = torch._fake_quantize_learnable_per_tensor_affine(
161
+ X, self.scale, self.zero_point,
162
+ self.quant_min, self.quant_max, grad_factor)
163
+
164
+ return X
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc ADDED
Binary file (6.74 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc ADDED
Binary file (4.62 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py ADDED
@@ -0,0 +1,1131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+
3
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
4
+ from torch.ao.quantization.quant_type import QuantType
5
+ import torch
6
+ import copy
7
+ import warnings
8
+ from torch.fx import (
9
+ GraphModule,
10
+ )
11
+ from torch.fx.graph import (
12
+ Graph,
13
+ Node,
14
+ Argument,
15
+ )
16
+ from ..utils import (
17
+ activation_is_statically_quantized,
18
+ weight_is_quantized,
19
+ get_qparam_dict,
20
+ _parent_name,
21
+ get_swapped_custom_module_class,
22
+ )
23
+ from ..qconfig import (
24
+ QConfigAny,
25
+ qconfig_equals
26
+ )
27
+ from ..qconfig_mapping import QConfigMapping
28
+ from .qconfig_mapping_utils import (
29
+ _generate_node_name_to_qconfig,
30
+ _compare_prepare_convert_qconfig_mappings,
31
+ _update_qconfig_for_fusion,
32
+ _is_qconfig_supported_by_dtype_configs,
33
+ _update_qconfig_for_qat,
34
+ )
35
+ from torch.ao.quantization.backend_config.utils import (
36
+ get_root_module_to_quantized_reference_module,
37
+ get_pattern_to_dtype_configs,
38
+ get_fused_module_classes,
39
+ get_qat_module_classes,
40
+ )
41
+ from torch.ao.quantization.backend_config import (
42
+ BackendConfig,
43
+ get_native_backend_config,
44
+ )
45
+ from torch.ao.quantization.observer import _is_activation_post_process
46
+ from .graph_module import (
47
+ _is_observed_module,
48
+ _is_observed_standalone_module,
49
+ )
50
+ from ._equalize import update_obs_for_equalization, convert_eq_obs
51
+ from torch.nn.utils.parametrize import type_before_parametrizations
52
+ from .utils import (
53
+ _get_module,
54
+ _is_custom_module_lstm,
55
+ _is_custom_module_mha,
56
+ assert_and_get_unique_device,
57
+ get_custom_module_class_keys,
58
+ create_getattr_from_value,
59
+ collect_producer_nodes,
60
+ graph_module_from_producer_nodes,
61
+ node_arg_is_weight,
62
+ )
63
+ from torch.ao.quantization.utils import (
64
+ is_per_channel,
65
+ to_underlying_dtype,
66
+ )
67
+ from torch.ao.quantization.quantize import (
68
+ _remove_qconfig,
69
+ )
70
+ from torch.ao.quantization.stubs import DeQuantStub
71
+ from .custom_config import (
72
+ ConvertCustomConfig,
73
+ PrepareCustomConfig,
74
+ )
75
+ from .lower_to_fbgemm import lower_to_fbgemm
76
+ # importing the lib so that the quantized_decomposed ops are registered
77
+ from ._decomposed import quantized_decomposed_lib # noqa: F401
78
+ import operator
79
+
80
+ __all__ = [
81
+ "convert",
82
+ "convert_custom_module",
83
+ "convert_standalone_module",
84
+ "convert_weighted_module",
85
+ ]
86
+
87
+ _QSCHEME_TO_CHOOSE_QPARAMS_OP = {
88
+ torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
89
+ torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
90
+ }
91
+
92
+ def _replace_observer_with_quantize_dequantize_node_decomposed(
93
+ model: torch.fx.GraphModule,
94
+ node: Node,
95
+ modules: Dict[str, torch.nn.Module],
96
+ node_name_to_scope: Dict[str, Tuple[str, type]],
97
+ node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
98
+ """ Replace activation_post_process module call node with quantize and
99
+ dequantize node working with decomposed Tensor
100
+
101
+ Before:
102
+ ... -> observer_0(x) -> ...
103
+ After:
104
+ ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
105
+ torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
106
+
107
+ or quantize_per_channel and dequantize_per_channel
108
+ """
109
+ graph = model.graph
110
+ assert modules is not None
111
+ assert isinstance(node.target, str)
112
+ module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
113
+ activation_post_process = modules[node.target]
114
+ if hasattr(activation_post_process, "convert"):
115
+ activation_post_process.convert(model, node)
116
+ return
117
+ # skip replacing observers to quant/dequant nodes if the qconfigs of all
118
+ # consumers and producers of this observer are None
119
+ skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
120
+ list(node.args) + list(node.users.keys()))
121
+ if skip_replacement or not _is_conversion_supported(activation_post_process):
122
+ # didn't find corresponding quantize op and info for the activation_post_process
123
+ # so we just remove the observer
124
+ with graph.inserting_before(node):
125
+ node.replace_all_uses_with(node.args[0])
126
+ graph.erase_node(node)
127
+ return
128
+
129
+ # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
130
+
131
+ # 1. extract the information from activation_post_process module for generating
132
+ # the quantize and dequantize operator
133
+ dtype = activation_post_process.dtype # type: ignore[attr-defined]
134
+
135
+ is_dynamic = False
136
+ if hasattr(activation_post_process, "is_dynamic"):
137
+ is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
138
+
139
+ if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
140
+ (not is_dynamic):
141
+ # TODO: probably should cleanup this condition check, it's hard
142
+ # to reason about this if and the following elif
143
+
144
+ # uint8/int8/int32 static quantization branch
145
+
146
+ # 1. extract information for inserting q/dq node from activation_post_process
147
+ node_type = "call_function"
148
+ quantize_op : Optional[Callable] = None
149
+ scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
150
+ if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
151
+ ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
152
+ quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
153
+ dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
154
+ quant_min = activation_post_process.quant_min
155
+ quant_max = activation_post_process.quant_max
156
+ dtype_ = to_underlying_dtype(dtype)
157
+ qparams = {
158
+ "_scale_": scale,
159
+ "_zero_point_": zero_point,
160
+ "_axis_": ch_axis,
161
+ "_quant_min_": quant_min,
162
+ "_quant_max_": quant_max,
163
+ "_dtype_": dtype_
164
+ }
165
+ else:
166
+ quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
167
+ dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
168
+ scale = float(scale)
169
+ zero_point = int(zero_point)
170
+ quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
171
+ quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
172
+ dtype_ = to_underlying_dtype(dtype)
173
+ qparams = {
174
+ "_scale_": scale,
175
+ "_zero_point_": zero_point,
176
+ "_quant_min_": quant_min,
177
+ "_quant_max_": quant_max,
178
+ "_dtype_": dtype_
179
+ }
180
+
181
+ # 2. replace activation_post_process node with quantize and dequantize
182
+ with graph.inserting_before(node):
183
+ input_node = node.args[0]
184
+ quantize_op_inputs = [input_node]
185
+ for key, value_or_node in qparams.items():
186
+ # TODO: we can add the information of whether a value needs to
187
+ # be registered as an attribute in qparams dict itself
188
+ if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
189
+ # For scale and zero_point values we register them as buffers in the root module.
190
+ # However, note that when the values are not tensors, as in the case of
191
+ # per_tensor quantization, they will be treated as literals.
192
+ # However, registering them as a node seems to cause issue with dynamo
193
+ # tracing where it may consider tensor overload as opposed to default.
194
+ # With extra check of scale and zero_point being scalar, it makes
195
+ # sure that the default overload can be used.
196
+ # TODO: maybe need more complex attr name here
197
+ qparam_node = create_getattr_from_value(
198
+ model, graph, module_path + prefix + key, value_or_node)
199
+ quantize_op_inputs.append(qparam_node)
200
+ else:
201
+ # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
202
+ quantize_op_inputs.append(value_or_node)
203
+
204
+ quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
205
+ # use the same qparams from quantize op
206
+ dq_inputs = [quantized_node] + quantize_op_inputs[1:]
207
+ dequantized_node = graph.call_function(
208
+ dequantize_op,
209
+ tuple(dq_inputs),
210
+ {}
211
+ )
212
+
213
+ def remap_fn(x):
214
+ return dequantized_node if x is node else x
215
+
216
+ # remap numeric_debug_handle
217
+ for user_node in node.users:
218
+ if "numeric_debug_handle" in user_node.meta:
219
+ numeric_debug_handle = user_node.meta["numeric_debug_handle"]
220
+ user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
221
+ node.replace_all_uses_with(dequantized_node)
222
+ graph.erase_node(node)
223
+ elif is_dynamic:
224
+
225
+ # uint8/int8/fp16 dynamic quantization
226
+
227
+ # 1. extract information for inserting q/dq node from activation_post_process
228
+ node_type = "call_function"
229
+ quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
230
+ # we only use choose_qparams for is_decomposed now,
231
+ # but we should probably align the non-decomposed path with this as well,
232
+ # and that can be done after we remove reduce_range flag
233
+ # 1. extract qparams from activation_post_process module
234
+ dtype_ = to_underlying_dtype(dtype)
235
+ assert dtype_ in [torch.uint8, torch.int8], \
236
+ "only uint8 and int8 are supported in reference flow for " \
237
+ "dynamic quantization right now"
238
+ quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
239
+ quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
240
+ qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined]
241
+ eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined]
242
+ # note: scale and zero_point are missing for quantize_per_tensor op
243
+ # we'll need to get this from choose_qparams op, which we'll add after
244
+ # this step
245
+ qparams = {
246
+ "_quant_min_": quant_min,
247
+ "_quant_max_": quant_max,
248
+ "_eps_": eps,
249
+ "_dtype_": dtype_
250
+ }
251
+
252
+ choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
253
+ # 2. insert choose_qparams op and update the qparams list
254
+ with graph.inserting_before(node):
255
+ input_node = node.args[0]
256
+ choose_qparams_op_inputs = [node.args[0]]
257
+ for key, value in qparams.items():
258
+ # we have quant_min, quant_max and dtype, all should be stored
259
+ # as literals
260
+ choose_qparams_op_inputs.append(value)
261
+ choose_qparams_node = graph.create_node(
262
+ "call_function",
263
+ choose_qparams_op,
264
+ tuple(choose_qparams_op_inputs),
265
+ {}
266
+ )
267
+ # choose_qparms returns (scale, zero_point)
268
+ scale_node = graph.create_node(
269
+ "call_function",
270
+ operator.getitem,
271
+ (choose_qparams_node, 0),
272
+ {}
273
+ )
274
+ zero_point_node = graph.create_node(
275
+ "call_function",
276
+ operator.getitem,
277
+ (choose_qparams_node, 1),
278
+ {}
279
+ )
280
+ quant_min = qparams["_quant_min_"]
281
+ quant_max = qparams["_quant_max_"]
282
+ dtype = qparams["_dtype_"]
283
+ qparams = {
284
+ "_scale_": scale_node,
285
+ "_zero_point_": zero_point_node,
286
+ "_quant_min_": quant_min,
287
+ "_quant_max_": quant_max,
288
+ "_dtype_": dtype
289
+ }
290
+
291
+ # 3. replace activation_post_process node to quantize and dequantize node
292
+ with graph.inserting_before(node):
293
+ input_node = node.args[0]
294
+ quantize_op_inputs = [input_node]
295
+ for key, value_or_node in qparams.items():
296
+ # TODO: we can add the information of whether a value needs to
297
+ # be registered as an attribute in qparams dict itself
298
+ if key in ['_scale_', '_zero_point_']:
299
+ # in this case we have a node in the graph since it's dynamically
300
+ # computed from the input, with choose_qparams op
301
+ qparam_node = value_or_node
302
+ quantize_op_inputs.append(qparam_node)
303
+ else:
304
+ # for qparams that are not scale/zero_point (like axis, dtype) we
305
+ # store them as literals in the graph.
306
+ quantize_op_inputs.append(value_or_node)
307
+
308
+ quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
309
+ # use the same qparams from quantize op
310
+ dq_inputs = [quantized_node] + quantize_op_inputs[1:]
311
+ # need to use the tensor variant of this op, since scale and zero_point
312
+ # from choose_qparam are Tensors, instead of float/int, this is to
313
+ # prevent these nodes being traced away by downstream systems
314
+ dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
315
+ dequantized_node = graph.call_function(
316
+ dequantize_op,
317
+ tuple(dq_inputs),
318
+ {}
319
+ )
320
+
321
+ def remap_fn(x):
322
+ return dequantized_node if x is node else x
323
+
324
+ # remap numeric_debug_handle
325
+ for user_node in node.users:
326
+ if "numeric_debug_handle" in user_node.meta:
327
+ numeric_debug_handle = user_node.meta["numeric_debug_handle"]
328
+ user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
329
+ node.replace_all_uses_with(dequantized_node)
330
+ graph.erase_node(node)
331
+ elif dtype == torch.float16:
332
+ raise NotImplementedError("decomposed to float16 op not implemented yet")
333
+
334
+ # should not reach since we have checks in the beginning to make sure the
335
+ # activation_post_process is supported
336
+
337
def _replace_observer_with_quantize_dequantize_node(
        model: torch.fx.GraphModule,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    Args:
        model: the GraphModule whose graph is rewritten in place
        node: the call_module node that invokes the observer
        modules: name -> submodule mapping for `model`
        node_name_to_scope: node name -> (module path, module type); used to
            derive the attribute name under which scale/zero_point are stored
        node_name_to_qconfig: node name -> qconfig; used to detect nodes whose
            producers/consumers all opted out of quantization
    """
    assert modules is not None
    assert isinstance(node.target, str)
    graph = model.graph
    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
    activation_post_process = modules[node.target]
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
                           list(node.args) + list(node.users.keys()))
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
            (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for generating
        # the quantize and dequantize operator
        node_type = "call_function"
        quantize_op : Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
            quantize_op = torch.quantize_per_channel
        else:
            # per-tensor qparams are plain Python scalars
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ['_scale_', '_zero_point_']:
                    # For scale and zero_point values we register them as buffers in the root module.
                    # TODO: maybe need more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node)
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:

        # uint8/int8/fp16 dynamic quantization branch

        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        # NOTE(review): reduce_range is derived from the currently-selected
        # backend engine rather than from the observer itself (see TODO above)
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}

        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # dynamic quant qparams (dtype, reduce_range) are stored as
                # literals in the graph
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        # fp16 "quantization" is a plain dtype cast via Tensor.to
        node_type = "call_method"
        quantize_op = "to"  # type: ignore[assignment]
        qparams = {"_dtype_": dtype}
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported
458
+
459
+ # this is a temporary hack for custom module, we may want to implement
460
+ # this properly after the custom module class design is finalized
461
+ # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
462
+ # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
463
+ # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
464
def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
    """Drop the observer/DeQuantStub call `node` that follows a custom module
    and insert a plain ``dequantize`` after the custom module call instead."""
    custom_module_call = node.args[0]
    assert isinstance(custom_module_call, Node), \
        f"Expecting the for call custom module node to be a Node, but got {custom_module_call}"
    node.replace_all_uses_with(custom_module_call)
    graph.erase_node(node)
    _insert_dequantize_node(custom_module_call, graph)
471
+
472
+ def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
473
+ dtype = activation_post_process.dtype # type: ignore[attr-defined]
474
+
475
+ is_dynamic = False
476
+ if hasattr(activation_post_process, "is_dynamic"):
477
+ is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
478
+
479
+ return (
480
+ (dtype in [
481
+ torch.quint8,
482
+ torch.qint8,
483
+ torch.qint32,
484
+ torch.uint8,
485
+ torch.int8,
486
+ torch.int16,
487
+ torch.int32
488
+ ] and (not is_dynamic)) or # type: ignore[return-value]
489
+ is_dynamic or
490
+ dtype == torch.float16
491
+ )
492
+
493
+ def _has_none_qconfig(node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]) -> bool:
494
+ """ Check if a node has a qconfig of None, i.e. user requested to not quantize
495
+ the node
496
+ """
497
+ return isinstance(node, Node) and node.name in node_name_to_qconfig and node_name_to_qconfig[node.name] is None
498
+
499
def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
    """Run the weight observers of dynamic / weight-only quantized ops.

    These observers are not exercised during calibration, so here we extract
    the subgraph that produces each weight argument and execute it, letting
    the observer record the weight statistics before convert proceeds.
    """
    for fx_node in observed.graph.nodes:
        if fx_node.op != "call_function":
            continue
        for arg in fx_node.args:
            # only arguments that act as the op's weight are of interest
            if not arg or not node_arg_is_weight(fx_node, arg):
                continue
            producer_nodes = collect_producer_nodes(arg)
            if producer_nodes is None:
                continue
            weight_subgraph = graph_module_from_producer_nodes(
                observed, producer_nodes)
            # executing the subgraph runs the weight observer
            weight_subgraph()
519
+
520
+ def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
521
+ """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
522
+ we'll recursively remove the dequantize Node
523
+ """
524
+ if isinstance(arg, Node) and \
525
+ arg.op == "call_method" and \
526
+ arg.target == "dequantize":
527
+ quantize_node = arg.args[0]
528
+ # we only replace the specific use since dequantize could be used by other nodes
529
+ # as well
530
+ node.replace_input_with(arg, quantize_node)
531
+ elif isinstance(arg, (list, tuple)):
532
+ for arg_element in arg:
533
+ _maybe_recursive_remove_dequantize(arg_element, node, graph)
534
+ elif isinstance(arg, dict):
535
+ for arg_element in arg.values():
536
+ _maybe_recursive_remove_dequantize(arg_element, node, graph)
537
+ else:
538
+ warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
539
+
540
+ def _get_module_path_and_prefix(
541
+ obs_node: Node,
542
+ node_name_to_scope: Dict[str, Tuple[str, type]],
543
+ node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
544
+ """ Given and observer node, get the `Scope` or the fully qualified name for
545
+ the submodule containing the observed node, also return a prefix of "_input"
546
+ when the observed node is an input of a F.linear op, and not the output of another
547
+ quantized op.
548
+ TODO: this logic is hacky, we should think about how to remove it or make it more
549
+ general
550
+ """
551
+ observed_node = obs_node.args[0]
552
+ # an observer can be inserted for both input of the next operator or output of the previous
553
+ # operator (they can be the same)
554
+ # this flag identifies if the observer is inserted only because the observed node is
555
+ # the input of the next operator
556
+ assert isinstance(observed_node, Node), \
557
+ f"Expecting observed node to be a Node, but got {observed_node}"
558
+ is_input_observer_only = node_name_to_qconfig[observed_node.name] is None \
559
+ if observed_node.name in node_name_to_qconfig else None
560
+ if is_input_observer_only:
561
+ # if the quantize function is at the input of op, then we find the first user of the observer_node
562
+ # to get the path. If a linear call_function is in the user list, we return the first instance
563
+ # of linear node to get the FQN.
564
+ users = list(obs_node.users)
565
+ first_linear_use_or_first_use = users[0] if users else None
566
+ linear_node = None
567
+ for n in users:
568
+ if n.op == "call_function" and n.target == torch.nn.functional.linear:
569
+ linear_node = n
570
+ break
571
+ if linear_node:
572
+ first_linear_use_or_first_use = linear_node
573
+ prefix = "_input"
574
+ else:
575
+ # if the quantize function is at the output of the op, we use the observer input node to get the path
576
+ first_linear_use_or_first_use = observed_node
577
+ prefix = ""
578
+
579
+ if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
580
+ module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
581
+ else:
582
+ # TODO: it's not used, so actually we can skip quantization
583
+ # but this requires changing return type of quantize_node
584
+ # we can fix it later if needed
585
+ module_path = ""
586
+ return module_path, prefix
587
+
588
+ def _insert_dequantize_node(
589
+ node: Node,
590
+ graph: Graph) -> None:
591
+ """ Inserts dequantize node for `node` in `graph`
592
+ """
593
+ with graph.inserting_after(node):
594
+ dequantize_node = graph.call_method("dequantize", (node,))
595
+ for user_node in dict(node.users):
596
+ if user_node is not dequantize_node:
597
+ user_node.replace_input_with(node, dequantize_node)
598
+
599
def _maybe_get_observer_for_node(
        node: Node,
        modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
    """Return the observer module attached to ``node``'s output, if any.

    Scans the users of ``node`` for a call_module that resolves to an
    activation_post_process instance; returns None when none is found.
    """
    for user in node.users:
        if user.op != 'call_module':
            continue
        candidate = modules[str(user.target)]
        if _is_activation_post_process(candidate):
            return candidate
    return None
613
+
614
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config: Optional[BackendConfig]) -> None:
    """ Converts an observed standalone module to a quantized standalone module
    by calling the fx convert api, currently using the same `is_reference` flag
    as parent, but we may change this behavior in the future (e.g. separating
    quantization and lowering for standalone module as well)

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model
        - model: original model
        - is_reference: a flag from parent provided by user to decide if we want to
          produce a reference model or a fbgemm/qnnpack model
        - backend_config: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module : GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        .meta["_observed_graph_module_attrs"].standalone_module_input_quantized_idxs
    # remove the dequantize nodes for inputs that the submodule expects
    # in quantized form
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        .meta["_observed_graph_module_attrs"].standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        # Fix: the message used to be split across two statements, leaving the
        # second half as a dangling no-op string literal, so assertion errors
        # only showed "Currently only quantized".
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        _insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config=backend_config)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
675
+
676
def convert_weighted_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        observed_node_names: Set[str],
        node_name_to_qconfig: Dict[str, QConfigAny],
        backend_config: BackendConfig,
        is_decomposed: bool = False,
        is_reference: bool = False,
) -> None:
    """ Convert a weighted module to reference quantized module in the model
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - observed_node_names: names for the set of observed fx node, we can skip
        this conversion if the node is not observed
      - node_name_to_qconfig: node name -> qconfig; used to skip nodes whose
        qconfig is None
      - backend_config: provides the QAT module classes, supported dtype configs
        and the root-module -> reference-quantized-module mapping
      - is_decomposed: whether the decomposed (PT2) quantize ops are in use;
        recorded in the weight qparams dict and used to skip the extra weight
        observer call for decomposed-reference QAT
      - is_reference: whether a reference quantized model is being produced
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(
            original_module,
            qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if qconfig is None or _has_none_qconfig(node, node_name_to_qconfig) or not is_observed:
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        # convention: the root (weighted) module is the first entry of a fused module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells have two weights (input-hidden and hidden-hidden), each
        # observed with its own fresh observer instance
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update({
            "weight_ih": weight_qparams_ih,
            "weight_hh": weight_qparams_hh,
        })
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        #     and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        #     calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        # swap the root module inside the fused container in place
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
810
+
811
+ def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
812
+ """
813
+ Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
814
+
815
+ Before: quantize - dequantize - custom_module
816
+ After: quantize - custom_module
817
+ \\ - dequantize
818
+ """
819
+ # expecting the input node for a custom module node to be a Node
820
+ assert isinstance(prev_node, Node), \
821
+ f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
822
+ if prev_node.op == "call_method" and prev_node.target == "dequantize":
823
+ node.replace_input_with(prev_node, prev_node.args[0])
824
+ # Remove the dequantize node if it doesn't have other users
825
+ if len(prev_node.users) == 0:
826
+ graph.erase_node(prev_node)
827
+
828
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
        statically_quantized_custom_module_nodes: Set[Node]) -> None:
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with reference model design
        as well, but there has been some discussions around the interface, so we can do
        it later.
    """
    observed_custom_module = modules[str(node.target)]
    # Fix: removed a dead `maybe_obs = _maybe_get_observer_for_node(...)` local;
    # the observer is (re)fetched below only where it is actually needed.
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2 and
                isinstance(node.args[1], tuple) and
                len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
            # absorb the following observer into the module conversion
            activation_post_process = _maybe_get_observer_for_node(node, modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
910
+
911
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_decomposed: bool = False) -> GraphModule:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    Args:
       * `is_standalone_module`: when this flag is True, it means we are quantizing
       a submodule that is not inlined in parent module, and will be quantized
       separately as one unit.

       * `is_decomposed`: a boolean flag to indicate whether we want to use the
       quantize operator for decomposed quantized tensor
       (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
       quantized tensor (torch.quantize_per_tensor)

    Returns:
         a quantized standalone module, whether input/output is quantized is
         specified by prepare_custom_config, with
         input_quantized_idxs, output_quantized_idxs, please
         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    # Accept legacy dict-style configs but normalize them to the config objects,
    # warning the caller that the dict form is deprecated.
    if isinstance(convert_custom_config, Dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
    # deepcopy so that mutations below (e.g. QAT updates) don't leak back to the caller
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    assert _is_observed_module(model), \
        'incoming model must be produced by prepare_fx'
    # State recorded by prepare_fx that drives the conversion below.
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: Dict[str, Tuple[str, type]] = observed_graph_module_attrs.node_name_to_scope
    prepare_custom_config: PrepareCustomConfig = observed_graph_module_attrs.prepare_custom_config
    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
        # check the convert_node_name_to_qconfig generated and ensure that
        # all the values either match what was set in prepare node_name_to_qconfig
        # or are set to None in the convert_node_name_to_qconfig.
        for k, v in node_name_to_qconfig.items():
            assert k in convert_node_name_to_qconfig, f'Expected key {k} in convert node_name_to_qconfig'
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), \
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " \
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    # NOTE: a previous revision built a `graph_inputs` list of placeholder names
    # here; it was never read, so the dead code has been removed.

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Result are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
                else:
                    if is_decomposed:
                        _replace_observer_with_quantize_dequantize_node_decomposed(
                            model, node, modules, node_name_to_scope,
                            node_name_to_qconfig)
                    else:
                        _replace_observer_with_quantize_dequantize_node(
                            model, node, modules, node_name_to_scope,
                            node_name_to_qconfig)
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config)
            # below this point `type_before_parametrizations` is used
            # instead of `type` to handle situations with fx quant + sparsity
            elif type_before_parametrizations(mod) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type_before_parametrizations(mod) in fused_module_classes and \
                        type_before_parametrizations(mod[0]) not in root_module_classes:  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, node_name_to_qconfig, backend_config,
                    is_decomposed, is_reference)
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)

    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from torch.fx.graph import (
4
+ Graph,
5
+ Node,
6
+ )
7
+ from torch.ao.quantization.utils import Pattern
8
+ from .quantize_handler import (
9
+ QuantizeHandler,
10
+ )
11
+ from ..qconfig import (
12
+ QConfigAny,
13
+ )
14
+ from ..utils import (
15
+ MatchAllNode
16
+ )
17
+ from .graph_module import (
18
+ _is_observed_standalone_module,
19
+ )
20
+ from torch.nn.utils.parametrize import type_before_parametrizations
21
+ from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
22
+
23
+
24
+ __all__: List[str] = []
25
+
26
+ # TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
27
+ # would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
28
+ _MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
29
+
30
+ _MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
31
+ QConfigAny]
32
+
33
+ # Note: The order of patterns is important! match function will take whatever is matched first, so we'll
34
+ # need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
35
+ # decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
36
+ # we'll start from the last node of the graph and traverse back.
37
+ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
38
+ """ Matches a node in fx against a pattern
39
+ """
40
+ if isinstance(pattern, tuple):
41
+ self_match, *arg_matches = pattern
42
+ if self_match is getattr:
43
+ assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
44
+ arg_matches = []
45
+ else:
46
+ self_match = pattern
47
+ arg_matches = []
48
+
49
+ if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
50
+ return True
51
+
52
+ if node == pattern:
53
+ return True
54
+
55
+ if not isinstance(node, Node) or len(node.users) > max_uses:
56
+ return False
57
+
58
+ if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
59
+ if node.op != 'call_module':
60
+ return False
61
+ if not type_before_parametrizations(modules[node.target]) == self_match:
62
+ return False
63
+ elif callable(self_match):
64
+ if node.op != 'call_function' or node.target is not self_match:
65
+ return False
66
+ elif node.target is getattr:
67
+ if node.args[1] != pattern[1]:
68
+ return False
69
+ elif isinstance(self_match, str):
70
+ if node.op != 'call_method' or node.target != self_match:
71
+ return False
72
+ elif node.target != self_match:
73
+ return False
74
+
75
+ if not arg_matches:
76
+ return True
77
+
78
+ if len(arg_matches) != len(node.args):
79
+ return False
80
+
81
+ return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
82
+
83
def _find_matches(
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        patterns: Dict[Pattern, QuantizeHandler],
        root_node_getter_mapping: Dict[Pattern, Callable],
        standalone_module_names: Optional[List[str]] = None,
        standalone_module_classes: Optional[List[Type]] = None,
        custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
        for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
        uninitialized QuantizeHandler subclass.

    Outputs a map of
      node_name ->
        (node, matched_values, matched_pattern, QuantizeHandler instance,
         qconfig)

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>, QConfig(...)),
      ...
    }
    """
    custom_module_classes = [] if custom_module_classes is None else custom_module_classes
    standalone_module_classes = [] if standalone_module_classes is None else standalone_module_classes
    standalone_module_names = [] if standalone_module_names is None else standalone_module_names

    match_map: Dict[str, _MatchResult] = {}
    all_matched: Set[str] = set()

    def _record_pattern_nodes(
            last_node,
            match_map,
            node_pattern,
            matched_node_pattern,
            pattern,
            match_value):
        # Walk `node_pattern` recursively and register every Node found in it
        # under its own name, all pointing at the same match entry.
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node, matched_node_pattern, pattern, match_value)
        elif isinstance(node_pattern, Iterable):
            for sub_pattern in node_pattern:
                _record_pattern_nodes(
                    last_node, match_map, sub_pattern, matched_node_pattern,
                    pattern, match_value)

    # TODO: 1. merge with fuse matcher 2. document the code
    def _collect_match(
            pattern,
            node,
            last_node,
            matched_node_pattern,
            match_map):
        # Leaf pattern: the node itself is the match.
        if not isinstance(pattern, tuple):
            matched_node_pattern.append(node)
            return
        op_match, *arg_patterns = pattern
        is_single_arg = len(arg_patterns) == 1
        current_node_pattern: List[Node] = []
        _collect_match(
            op_match,
            node,
            last_node,
            matched_node_pattern,
            match_map)
        # getattr patterns carry an attribute name, not argument matchers,
        # so node.args is not traversed for them.
        if pattern[0] is not getattr:
            for sub_pattern, arg in zip(arg_patterns, node.args):
                _collect_match(
                    sub_pattern,
                    arg,
                    node,
                    current_node_pattern,
                    match_map)
        if len(current_node_pattern) > 1:
            # current_node_pattern is the node pattern we get from matching
            # the subpattern with arguments of the node
            # we use is_single_arg to recover the original structure of the pattern
            # if the original pattern has a single argument, we will have
            # (original_op, (original_arg, ...))
            # otherwise, we'll have a list of arguments
            # (original_op, arg0, arg1, arg2, ...)
            if is_single_arg:
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                matched_node_pattern.extend(list(current_node_pattern))
        else:
            matched_node_pattern.append(current_node_pattern[0])

    # Traverse from the outputs backwards; the first pattern that matches a
    # node wins, so fused patterns must be registered before their pieces.
    for node in reversed(graph.nodes):
        if node.name in match_map or node.name in all_matched:
            continue
        for pattern, quantize_handler_cls in patterns.items():
            root_node_getter = root_node_getter_mapping.get(pattern, None)
            if _is_match(modules, node, pattern) and node.name not in match_map:
                matched_node_pattern: List[Node] = []
                _collect_match(
                    pattern,
                    node,
                    node,
                    matched_node_pattern,
                    match_map)
                quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                    matched_node_pattern,
                    modules,
                    root_node_getter)
                last_node = node
                # record the match for all nodes in the pattern
                _record_pattern_nodes(
                    last_node,
                    match_map,
                    # we need to record all nodes in the matched pattern in the match_map
                    matched_node_pattern,
                    # this is a part of the value corresponding to the node
                    matched_node_pattern,
                    pattern,
                    quantize_handler)
                break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
                type(modules[node.target]) in custom_module_classes:
            match_map[node.name] = (
                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))

    def _is_standalone(node_target: str, modules: Dict[str, torch.nn.Module]):
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match
    for node in graph.nodes:
        if node.op == 'call_module' and \
                (_is_standalone(node.target, modules) or
                 _is_observed_standalone_module(modules[node.target])):
            # add node to matched nodes
            match_map[node.name] = (
                node, node, None,
                QuantizeHandler(node, modules, is_standalone_module=True))

    return match_map
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/tracer.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx._symbolic_trace import Tracer
3
+ from torch.fx.proxy import Scope
4
+ from torch.ao.nn.intrinsic import _FusedModule
5
+ from typing import List, Callable
6
+
7
+ __all__ = [
8
+ "QuantizationTracer",
9
+ ]
10
+
11
class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
    """Convenience wrapper around torch.fx.proxy.ScopeContextManager.

    Builds the target Scope from a module instance plus its qualified path,
    instead of requiring the caller to construct the Scope themselves.
    """

    def __init__(
        self,
        scope: Scope,
        current_module: torch.nn.Module,
        current_module_path: str
    ):
        target_scope = Scope(current_module_path, type(current_module))
        super().__init__(scope, target_scope)
19
+
20
+
21
class QuantizationTracer(Tracer):
    """FX tracer used by graph mode quantization.

    Treats torch.nn / torch.ao.nn modules (except Sequential), fused modules,
    and any explicitly skipped names/classes as leaves, so they appear as
    `call_module` nodes instead of being traced through.
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialized the module_type of top level module to None
        # we are assuming people won't configure the model with the type of top level
        # module here, since people can use "" for global config
        # We can change this if there is a use case that configures
        # qconfig using top level module type
        self.scope = Scope("", None)
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Fused modules are always leaves so the fusion survives tracing.
        if isinstance(m, _FusedModule):
            return True
        # Honor the caller-provided skip lists.
        if module_qualified_name in self.skipped_module_names:
            return True
        if type(m) in self.skipped_module_classes:
            return True
        # Built-in torch modules are leaves, except containers like Sequential
        # which must be traced into.
        comes_from_torch_nn = (
            m.__module__.startswith("torch.nn")
            or m.__module__.startswith("torch.ao.nn")
        )
        return comes_from_torch_nn and not isinstance(m, torch.nn.Sequential)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc ADDED
Binary file (3.71 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing.pool
2
+ import multiprocessing.util as util
3
+
4
+ from .queue import SimpleQueue
5
+
6
+
7
def clean_worker(*args, **kwargs):
    """Run the stock multiprocessing pool worker, then force a GC pass.

    Regular multiprocessing workers don't fully clean up after themselves,
    so garbage collection is triggered explicitly once the worker loop
    returns, ensuring all destructors run before the process exits.
    """
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    gc.collect()
15
+
16
+
17
class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        # Replace the stock pipes with tensor-aware SimpleQueues; the
        # _quick_put/_quick_get shortcuts mirror the base class contract.
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use
        after reaping workers which have exited.
        """
        missing = self._processes - len(self._pool)
        for _ in range(missing):
            # Same argument layout the stock Pool uses, but targeting
            # clean_worker so each worker garbage-collects before exiting.
            worker_args = [
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            ]
            # Older Python versions don't have _wrap_exception.
            if hasattr(self, "_wrap_exception"):
                worker_args.append(self._wrap_exception)
            proc = self.Process(target=clean_worker, args=tuple(worker_args))
            self._pool.append(proc)
            proc.name = proc.name.replace("Process", "PoolWorker")
            proc.daemon = True
            proc.start()
            util.debug("added worker")
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/_reduction.cpython-311.pyc ADDED
Binary file (1.92 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (3.66 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Defines utilities for interacting with scaled_dot_product_attention"""
2
+ import math
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+
7
+ __all__: List[str] = []
8
+
9
+
10
+ def _input_requires_grad(*tensors: torch.Tensor) -> bool:
11
+ """Returns True if any of the tensors requires grad"""
12
+ return any(t.requires_grad for t in tensors)
13
+
14
+
15
+ def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
16
+ """Handles the unpad of the last dimension"""
17
+ if inpt_tensor.size(-1) != og_size:
18
+ return inpt_tensor[..., :og_size]
19
+ return inpt_tensor
20
+
21
+
22
+ def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
23
+ """
24
+ For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output
25
+ by the original head size and not the padded.
26
+ """
27
+ if scale is not None:
28
+ return scale
29
+ return 1.0 / math.sqrt(head_dim_size)
30
+
31
+
32
+ def _validate_sdpa_input(
33
+ query: torch.Tensor,
34
+ key: torch.Tensor,
35
+ value: torch.Tensor,
36
+ attn_mask: Optional[torch.Tensor] = None,
37
+ dropout_p=0.0,
38
+ is_causal=False,
39
+ scale=None,
40
+ ):
41
+ if query.dtype != key.dtype or query.dtype != value.dtype:
42
+ raise ValueError(
43
+ f"Expected query, key, and value to have the same dtype, "
44
+ f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
45
+ f"and value.dtype: {value.dtype} instead."
46
+ )
47
+ if query.device != key.device or query.device != value.device:
48
+ raise ValueError(
49
+ f"Expected query, key, and value to have the same device type, "
50
+ f"but got query.device: {query.device}, key.device: {key.device}, "
51
+ f"and value.device: {value.device} instead."
52
+ )
53
+ if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
54
+ raise ValueError(
55
+ f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
56
+ f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
57
+ )
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-311.pyc ADDED
Binary file (401 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/_functions.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ from torch.autograd.function import Function
5
+
6
+ class SyncBatchNorm(Function):
7
+
8
    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        # NOTE: `self` here is the autograd context (ctx) — this is a
        # torch.autograd.Function, not an nn.Module method.
        #
        # Computes synchronized batch norm across `world_size` ranks:
        # each rank contributes its local (mean, invstd, count), the stats are
        # all-gathered, combined into global statistics, and the local input
        # is normalized with those global stats.

        # Kernels below expect contiguous input (channels-last layouts are
        # accepted as-is).
        if not (
            input.is_contiguous(memory_format=torch.channels_last) or
            input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        # Number of values per channel on this rank; with a single value and
        # no peers the variance is undefined, so reject it.
        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}')

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing global mean
            # & invstd, but they still needs to participate the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use allgather instead of allreduce because count could be different across
        # ranks, simple all reduce op can not give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        # The Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        counts = count_all.view(-1)
        # match dtype of running stats so the fused kernel accepts the counts
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        # stash what backward() needs; counts as int32 for the backward kernel
        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
111
+
112
+ @staticmethod
113
+ def backward(self, grad_output):
114
+ if not (
115
+ grad_output.is_contiguous(memory_format=torch.channels_last) or
116
+ grad_output.is_contiguous(memory_format=torch.channels_last_3d)
117
+ ):
118
+ grad_output = grad_output.contiguous()
119
+ saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
120
+ grad_input = grad_weight = grad_bias = None
121
+ process_group = self.process_group
122
+
123
+ if saved_input.numel() > 0:
124
+ # calculate local stats as well as grad_weight / grad_bias
125
+ sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
126
+ grad_output,
127
+ saved_input,
128
+ mean,
129
+ invstd,
130
+ weight,
131
+ self.needs_input_grad[0],
132
+ self.needs_input_grad[1],
133
+ self.needs_input_grad[2]
134
+ )
135
+
136
+ if self.needs_input_grad[0]:
137
+ # synchronizing stats used to calculate input gradient.
138
+ num_channels = sum_dy.shape[0]
139
+ combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
140
+ torch.distributed.all_reduce(
141
+ combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
142
+ sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
143
+
144
+ # backward pass for gradient calculation
145
+ if weight is not None and weight.dtype != mean.dtype:
146
+ weight = weight.to(mean.dtype)
147
+ grad_input = torch.batch_norm_backward_elemt(
148
+ grad_output,
149
+ saved_input,
150
+ mean,
151
+ invstd,
152
+ weight,
153
+ sum_dy,
154
+ sum_dy_xmu,
155
+ count_tensor
156
+ )
157
+ # synchronizing of grad_weight / grad_bias is not needed as distributed
158
+ # training would handle all reduce.
159
+ if weight is None or not self.needs_input_grad[1]:
160
+ grad_weight = None
161
+
162
+ if weight is None or not self.needs_input_grad[2]:
163
+ grad_bias = None
164
+ else:
165
+ # This process got an empty input tensor in the forward pass.
166
+ # Although this process can directly set grad_input as an empty
167
+ # tensor of zeros, it still needs to participate in the collective
168
+ # communication to unblock its peers, as other peer processes might
169
+ # have received non-empty inputs.
170
+ num_channels = saved_input.shape[1]
171
+ if self.needs_input_grad[0]:
172
+ # launch all_reduce to unblock other peer processes
173
+ combined = torch.zeros(
174
+ 2 * num_channels,
175
+ dtype=saved_input.dtype,
176
+ device=saved_input.device
177
+ )
178
+ torch.distributed.all_reduce(
179
+ combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
180
+
181
+ # Leave grad_input, grad_weight and grad_bias as None, which will be
182
+ # interpreted by the autograd engine as Tensors full of zeros.
183
+
184
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
185
+
186
+ class CrossMapLRN2d(Function):
187
+
188
+ @staticmethod
189
+ def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
190
+ ctx.size = size
191
+ ctx.alpha = alpha
192
+ ctx.beta = beta
193
+ ctx.k = k
194
+ ctx.scale = None
195
+
196
+ if input.dim() != 4:
197
+ raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.")
198
+
199
+ ctx.scale = ctx.scale or input.new()
200
+ output = input.new()
201
+
202
+ batch_size = input.size(0)
203
+ channels = input.size(1)
204
+ input_height = input.size(2)
205
+ input_width = input.size(3)
206
+
207
+ output.resize_as_(input)
208
+ ctx.scale.resize_as_(input)
209
+
210
+ # use output storage as temporary buffer
211
+ input_square = output
212
+ torch.pow(input, 2, out=input_square)
213
+
214
+ pre_pad = int((ctx.size - 1) / 2 + 1)
215
+ pre_pad_crop = min(pre_pad, channels)
216
+
217
+ scale_first = ctx.scale.select(1, 0)
218
+ scale_first.zero_()
219
+ # compute first feature map normalization
220
+ for c in range(pre_pad_crop):
221
+ scale_first.add_(input_square.select(1, c))
222
+
223
+ # reuse computations for next feature maps normalization
224
+ # by adding the next feature map and removing the previous
225
+ for c in range(1, channels):
226
+ scale_previous = ctx.scale.select(1, c - 1)
227
+ scale_current = ctx.scale.select(1, c)
228
+ scale_current.copy_(scale_previous)
229
+ if c < channels - pre_pad + 1:
230
+ square_next = input_square.select(1, c + pre_pad - 1)
231
+ scale_current.add_(square_next, alpha=1)
232
+
233
+ if c > pre_pad:
234
+ square_previous = input_square.select(1, c - pre_pad)
235
+ scale_current.add_(square_previous, alpha=-1)
236
+
237
+ ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
238
+
239
+ torch.pow(ctx.scale, -ctx.beta, out=output)
240
+ output.mul_(input)
241
+
242
+ ctx.save_for_backward(input, output)
243
+ return output
244
+
245
+ @staticmethod
246
+ def backward(ctx, grad_output):
247
+ input, output = ctx.saved_tensors
248
+ grad_input = grad_output.new()
249
+
250
+ batch_size = input.size(0)
251
+ channels = input.size(1)
252
+ input_height = input.size(2)
253
+ input_width = input.size(3)
254
+
255
+ paddded_ratio = input.new(channels + ctx.size - 1, input_height,
256
+ input_width)
257
+ accum_ratio = input.new(input_height, input_width)
258
+
259
+ cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
260
+ inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
261
+
262
+ grad_input.resize_as_(input)
263
+ torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
264
+
265
+ paddded_ratio.zero_()
266
+ padded_ratio_center = paddded_ratio.narrow(0, inversePrePad,
267
+ channels)
268
+ for n in range(batch_size):
269
+ torch.mul(grad_output[n], output[n], out=padded_ratio_center)
270
+ padded_ratio_center.div_(ctx.scale[n])
271
+ torch.sum(
272
+ paddded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
273
+ for c in range(channels):
274
+ accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
275
+ grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
276
+ accum_ratio.add_(paddded_ratio[c], alpha=-1)
277
+
278
+ return grad_input, None, None, None, None
279
+
280
+ class BackwardHookFunction(torch.autograd.Function):
281
+ @staticmethod
282
+ def forward(ctx, *args):
283
+ ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
284
+ return args
285
+
286
+ @staticmethod
287
+ def backward(ctx, *args):
288
+ return args
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/upsampling.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .module import Module
2
+ from .. import functional as F
3
+
4
+ from torch import Tensor
5
+ from typing import Optional
6
+ from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t
7
+
8
+ __all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d']
9
+
10
+
11
+ class Upsample(Module):
12
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
13
+
14
+ The input data is assumed to be of the form
15
+ `minibatch x channels x [optional depth] x [optional height] x width`.
16
+ Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
17
+
18
+ The algorithms available for upsampling are nearest neighbor and linear,
19
+ bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
20
+ respectively.
21
+
22
+ One can either give a :attr:`scale_factor` or the target output :attr:`size` to
23
+ calculate the output size. (You cannot give both, as it is ambiguous)
24
+
25
+ Args:
26
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
27
+ output spatial sizes
28
+ scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
29
+ multiplier for spatial size. Has to match input size if it is a tuple.
30
+ mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
31
+ ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
32
+ Default: ``'nearest'``
33
+ align_corners (bool, optional): if ``True``, the corner pixels of the input
34
+ and output tensors are aligned, and thus preserving the values at
35
+ those pixels. This only has effect when :attr:`mode` is
36
+ ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
37
+ Default: ``False``
38
+ recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
39
+ interpolation calculation. If `recompute_scale_factor` is ``True``, then
40
+ `scale_factor` must be passed in and `scale_factor` is used to compute the
41
+ output `size`. The computed output `size` will be used to infer new scales for
42
+ the interpolation. Note that when `scale_factor` is floating-point, it may differ
43
+ from the recomputed `scale_factor` due to rounding and precision issues.
44
+ If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
45
+ be used directly for interpolation.
46
+
47
+ Shape:
48
+ - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
49
+ - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
50
+ or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
51
+
52
+ .. math::
53
+ D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor
54
+
55
+ .. math::
56
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
57
+
58
+ .. math::
59
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
60
+
61
+ .. warning::
62
+ With ``align_corners = True``, the linearly interpolating modes
63
+ (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
64
+ align the output and input pixels, and thus the output values can depend
65
+ on the input size. This was the default behavior for these modes up to
66
+ version 0.3.1. Since then, the default behavior is
67
+ ``align_corners = False``. See below for concrete examples on how this
68
+ affects the outputs.
69
+
70
+ .. note::
71
+ If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.
72
+
73
+ Examples::
74
+
75
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
76
+ >>> input
77
+ tensor([[[[1., 2.],
78
+ [3., 4.]]]])
79
+
80
+ >>> m = nn.Upsample(scale_factor=2, mode='nearest')
81
+ >>> m(input)
82
+ tensor([[[[1., 1., 2., 2.],
83
+ [1., 1., 2., 2.],
84
+ [3., 3., 4., 4.],
85
+ [3., 3., 4., 4.]]]])
86
+
87
+ >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
88
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
89
+ >>> m(input)
90
+ tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
91
+ [1.5000, 1.7500, 2.2500, 2.5000],
92
+ [2.5000, 2.7500, 3.2500, 3.5000],
93
+ [3.0000, 3.2500, 3.7500, 4.0000]]]])
94
+
95
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
96
+ >>> m(input)
97
+ tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
98
+ [1.6667, 2.0000, 2.3333, 2.6667],
99
+ [2.3333, 2.6667, 3.0000, 3.3333],
100
+ [3.0000, 3.3333, 3.6667, 4.0000]]]])
101
+
102
+ >>> # Try scaling the same data in a larger tensor
103
+ >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
104
+ >>> input_3x3[:, :, :2, :2].copy_(input)
105
+ tensor([[[[1., 2.],
106
+ [3., 4.]]]])
107
+ >>> input_3x3
108
+ tensor([[[[1., 2., 0.],
109
+ [3., 4., 0.],
110
+ [0., 0., 0.]]]])
111
+
112
+ >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session")
113
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
114
+ >>> # Notice that values in top left corner are the same with the small input (except at boundary)
115
+ >>> m(input_3x3)
116
+ tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
117
+ [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
118
+ [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
119
+ [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
120
+ [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
121
+ [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
122
+
123
+ >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
124
+ >>> # Notice that values in top left corner are now changed
125
+ >>> m(input_3x3)
126
+ tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
127
+ [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
128
+ [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
129
+ [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
130
+ [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
131
+ [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
132
+ """
133
+
134
+ __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor']
135
+ name: str
136
+ size: Optional[_size_any_t]
137
+ scale_factor: Optional[_ratio_any_t]
138
+ mode: str
139
+ align_corners: Optional[bool]
140
+ recompute_scale_factor: Optional[bool]
141
+
142
+ def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
143
+ mode: str = 'nearest', align_corners: Optional[bool] = None,
144
+ recompute_scale_factor: Optional[bool] = None) -> None:
145
+ super().__init__()
146
+ self.name = type(self).__name__
147
+ self.size = size
148
+ if isinstance(scale_factor, tuple):
149
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
150
+ else:
151
+ self.scale_factor = float(scale_factor) if scale_factor else None
152
+ self.mode = mode
153
+ self.align_corners = align_corners
154
+ self.recompute_scale_factor = recompute_scale_factor
155
+
156
+ def forward(self, input: Tensor) -> Tensor:
157
+ return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
158
+ recompute_scale_factor=self.recompute_scale_factor)
159
+
160
+ def __setstate__(self, state):
161
+ if 'recompute_scale_factor' not in state:
162
+ state['recompute_scale_factor'] = True
163
+
164
+ super().__setstate__(state)
165
+
166
+ def extra_repr(self) -> str:
167
+ if self.scale_factor is not None:
168
+ info = 'scale_factor=' + repr(self.scale_factor)
169
+ else:
170
+ info = 'size=' + repr(self.size)
171
+ info += ', mode=' + repr(self.mode)
172
+ return info
173
+
174
+
175
+ class UpsamplingNearest2d(Upsample):
176
+ r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
177
+
178
+ To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
179
+ as it's constructor argument.
180
+
181
+ When :attr:`size` is given, it is the output size of the image `(h, w)`.
182
+
183
+ Args:
184
+ size (int or Tuple[int, int], optional): output spatial sizes
185
+ scale_factor (float or Tuple[float, float], optional): multiplier for
186
+ spatial size.
187
+
188
+ .. warning::
189
+ This class is deprecated in favor of :func:`~nn.functional.interpolate`.
190
+
191
+ Shape:
192
+ - Input: :math:`(N, C, H_{in}, W_{in})`
193
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
194
+
195
+ .. math::
196
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
197
+
198
+ .. math::
199
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
200
+
201
+ Examples::
202
+
203
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
204
+ >>> input
205
+ tensor([[[[1., 2.],
206
+ [3., 4.]]]])
207
+
208
+ >>> m = nn.UpsamplingNearest2d(scale_factor=2)
209
+ >>> m(input)
210
+ tensor([[[[1., 1., 2., 2.],
211
+ [1., 1., 2., 2.],
212
+ [3., 3., 4., 4.],
213
+ [3., 3., 4., 4.]]]])
214
+ """
215
+
216
+ def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
217
+ super().__init__(size, scale_factor, mode='nearest')
218
+
219
+
220
+ class UpsamplingBilinear2d(Upsample):
221
+ r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels.
222
+
223
+ To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
224
+ as it's constructor argument.
225
+
226
+ When :attr:`size` is given, it is the output size of the image `(h, w)`.
227
+
228
+ Args:
229
+ size (int or Tuple[int, int], optional): output spatial sizes
230
+ scale_factor (float or Tuple[float, float], optional): multiplier for
231
+ spatial size.
232
+
233
+ .. warning::
234
+ This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
235
+ equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
236
+
237
+ Shape:
238
+ - Input: :math:`(N, C, H_{in}, W_{in})`
239
+ - Output: :math:`(N, C, H_{out}, W_{out})` where
240
+
241
+ .. math::
242
+ H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
243
+
244
+ .. math::
245
+ W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
246
+
247
+ Examples::
248
+
249
+ >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
250
+ >>> input
251
+ tensor([[[[1., 2.],
252
+ [3., 4.]]]])
253
+
254
+ >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
255
+ >>> m = nn.UpsamplingBilinear2d(scale_factor=2)
256
+ >>> m(input)
257
+ tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
258
+ [1.6667, 2.0000, 2.3333, 2.6667],
259
+ [2.3333, 2.6667, 3.0000, 3.3333],
260
+ [3.0000, 3.3333, 3.6667, 4.0000]]]])
261
+ """
262
+
263
+ def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
264
+ super().__init__(size, scale_factor, mode='bilinear', align_corners=True)
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (401 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .linear import Linear
2
+
3
+ __all__ = ["Linear"]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/linear.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.qat.dynamic.modules.linear import Linear
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""QAT Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/qat`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/qat/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ __all__ = ['Embedding', 'EmbeddingBag']
12
+
13
+ from torch.ao.nn.qat.modules.embedding_ops import Embedding
14
+ from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (504 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc ADDED
Binary file (697 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantizable Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantizable`, and
5
+ is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantizable/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.quantizable.modules.rnn import LSTM
11
+ from torch.ao.nn.quantizable.modules.rnn import LSTMCell
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import dynamic # noqa: F403
2
+ from . import functional # noqa: F403
3
+ from . import modules # noqa: F403
4
+ from .modules import * # noqa: F403
5
+ from .modules import MaxPool2d
6
+
7
+ __all__ = [
8
+ 'BatchNorm2d',
9
+ 'BatchNorm3d',
10
+ 'Conv1d',
11
+ 'Conv2d',
12
+ 'Conv3d',
13
+ 'ConvTranspose1d',
14
+ 'ConvTranspose2d',
15
+ 'ConvTranspose3d',
16
+ 'DeQuantize',
17
+ 'Dropout',
18
+ 'ELU',
19
+ 'Embedding',
20
+ 'EmbeddingBag',
21
+ 'GroupNorm',
22
+ 'Hardswish',
23
+ 'InstanceNorm1d',
24
+ 'InstanceNorm2d',
25
+ 'InstanceNorm3d',
26
+ 'LayerNorm',
27
+ 'LeakyReLU',
28
+ 'Linear',
29
+ 'LSTM',
30
+ 'MultiheadAttention',
31
+ 'PReLU',
32
+ 'Quantize',
33
+ 'ReLU6',
34
+ 'Sigmoid',
35
+ 'Softmax',
36
+ # Wrapper modules
37
+ 'FloatFunctional',
38
+ 'FXFloatFunctional',
39
+ 'QFunctional',
40
+ ]
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (712 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc ADDED
Binary file (939 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc ADDED
Binary file (766 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from torch.ao.nn.quantized.dynamic import * # noqa: F403
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (281 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.42 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
12
+
13
+ from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d
14
+ from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d
15
+ from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d
16
+ from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d
17
+ from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d
18
+ from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+ from torch.ao.nn.quantized.dynamic.modules.linear import Linear
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F401
2
+ r"""Quantized Dynamic Modules.
3
+
4
+ This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
5
+ and is kept here for compatibility while the migration process is ongoing.
6
+ If you are adding a new entry/functionality, please, add it to the
7
+ appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
8
+ while adding an import statement here.
9
+ """
10
+
11
+ __all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
12
+ 'GRUCell']
13
+
14
+ from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias
15
+ from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter
16
+ from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase
17
+ from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM
18
+ from torch.ao.nn.quantized.dynamic.modules.rnn import GRU
19
+ from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase
20
+ from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell
21
+ from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell
22
+ from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc ADDED
Binary file (1 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (1.04 kB). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc ADDED
Binary file (696 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (760 Bytes). View file
 
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (912 Bytes). View file