Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py +16 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py +25 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py +13 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py +164 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py +1131 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py +237 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/tracer.py +45 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py +52 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/_reduction.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py +57 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/_functions.py +288 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/upsampling.py +264 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__init__.py +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/linear.py +10 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py +11 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py +40 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py +18 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py +10 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py +22 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc +0 -0
.gitattributes
CHANGED
|
@@ -72,3 +72,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/F
|
|
| 72 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
|
| 73 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 74 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 72 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text
|
| 73 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 74 |
tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e838b6e228975d93cd0fdc5e05254915109c27d231652f0851ab23f1b207b5f
|
| 3 |
+
size 233093
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1a592a5b2f359a9077550ee1fdadd58eb2cf9cc0bfab8fe397a374fb949da143
|
| 3 |
+
size 1618440
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_force_fallback():
|
| 5 |
+
"""Get the config used to force LTC fallback"""
|
| 6 |
+
return torch._C._lazy._get_force_fallback()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def set_force_fallback(configval):
|
| 10 |
+
"""Set the config used to force LTC fallback"""
|
| 11 |
+
torch._C._lazy._set_force_fallback(configval)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def set_reuse_ir(val: bool):
|
| 15 |
+
"""Set the config to reuse IR nodes for faster tracing"""
|
| 16 |
+
torch._C._lazy._set_reuse_ir(val)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch._C._lazy
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DeviceContext:
|
| 8 |
+
_CONTEXTS: Dict[str, Any] = dict()
|
| 9 |
+
_CONTEXTS_LOCK = threading.Lock()
|
| 10 |
+
|
| 11 |
+
def __init__(self, device):
|
| 12 |
+
self.device = device
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_device_context(device=None):
|
| 16 |
+
if device is None:
|
| 17 |
+
device = torch._C._lazy._get_default_device_type()
|
| 18 |
+
else:
|
| 19 |
+
device = str(device)
|
| 20 |
+
with DeviceContext._CONTEXTS_LOCK:
|
| 21 |
+
devctx = DeviceContext._CONTEXTS.get(device, None)
|
| 22 |
+
if devctx is None:
|
| 23 |
+
devctx = DeviceContext(device)
|
| 24 |
+
DeviceContext._CONTEXTS[device] = devctx
|
| 25 |
+
return devctx
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def dump(dot_file_name: str):
|
| 5 |
+
"""Dump TrieCache in the dot format"""
|
| 6 |
+
return torch._C._lazy._dump_ir_cache(dot_file_name)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def reset():
|
| 10 |
+
"""Clear TrieCache. This is needed in testing to avoid
|
| 11 |
+
node reusing between different tests.
|
| 12 |
+
"""
|
| 13 |
+
return torch._C._lazy._clear_ir_cache()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def reset():
|
| 5 |
+
"""Resets all metric counters."""
|
| 6 |
+
torch._C._lazy._reset_metrics()
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def counter_names():
|
| 10 |
+
"""Retrieves all the currently active counter names."""
|
| 11 |
+
return torch._C._lazy._counter_names()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def counter_value(name: str):
|
| 15 |
+
"""Return the value of the counter with the speficied name"""
|
| 16 |
+
return torch._C._lazy._counter_value(name)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def metrics_report():
|
| 20 |
+
"""Return the combined (lazy core and backend) metric report"""
|
| 21 |
+
return torch._C._lazy._metrics_report()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (756 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (717 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn.parameter import Parameter
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
__all__: List[str] = []
|
| 6 |
+
|
| 7 |
+
class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
|
| 8 |
+
r"""Generalized extension of the FakeQuantize module in fake_quantize.py.
|
| 9 |
+
|
| 10 |
+
This is an extension of the FakeQuantize module in fake_quantize.py, which
|
| 11 |
+
supports more generalized lower-bit quantization and support learning of the scale
|
| 12 |
+
and zero point parameters through backpropagation. For literature references,
|
| 13 |
+
please see the class _LearnableFakeQuantizePerTensorOp.
|
| 14 |
+
|
| 15 |
+
In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
|
| 16 |
+
module also includes the following attributes to support quantization parameter learning.
|
| 17 |
+
|
| 18 |
+
* :attr:`channel_len` defines the length of the channel when initializing scale and zero point
|
| 19 |
+
for the per channel case.
|
| 20 |
+
|
| 21 |
+
* :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
|
| 22 |
+
normalized by the constant, which is proportional to the square root of the number of
|
| 23 |
+
elements in the tensor. The related literature justifying the use of this particular constant
|
| 24 |
+
can be found here: https://openreview.net/pdf?id=rkgO66VKDS.
|
| 25 |
+
|
| 26 |
+
* :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.
|
| 27 |
+
|
| 28 |
+
* :attr:`static_enabled` defines the flag for using observer's static estimation for
|
| 29 |
+
scale and zero point.
|
| 30 |
+
|
| 31 |
+
* :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
|
| 34 |
+
use_grad_scaling=False, **observer_kwargs):
|
| 35 |
+
super().__init__()
|
| 36 |
+
assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
|
| 37 |
+
self.quant_min = quant_min
|
| 38 |
+
self.quant_max = quant_max
|
| 39 |
+
# also pass quant_min and quant_max to observer
|
| 40 |
+
observer_kwargs["quant_min"] = quant_min
|
| 41 |
+
observer_kwargs["quant_max"] = quant_max
|
| 42 |
+
self.use_grad_scaling = use_grad_scaling
|
| 43 |
+
if channel_len == -1:
|
| 44 |
+
self.scale = Parameter(torch.tensor([scale]))
|
| 45 |
+
self.zero_point = Parameter(torch.tensor([zero_point]))
|
| 46 |
+
else:
|
| 47 |
+
assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
|
| 48 |
+
self.scale = Parameter(torch.tensor([scale] * channel_len))
|
| 49 |
+
self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
|
| 50 |
+
|
| 51 |
+
self.activation_post_process = observer(**observer_kwargs)
|
| 52 |
+
assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
|
| 53 |
+
'quant_min out of bound'
|
| 54 |
+
assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
|
| 55 |
+
'quant_max out of bound'
|
| 56 |
+
self.dtype = self.activation_post_process.dtype
|
| 57 |
+
self.qscheme = self.activation_post_process.qscheme
|
| 58 |
+
self.ch_axis = self.activation_post_process.ch_axis \
|
| 59 |
+
if hasattr(self.activation_post_process, 'ch_axis') else -1
|
| 60 |
+
self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
|
| 61 |
+
self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
|
| 62 |
+
self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))
|
| 63 |
+
|
| 64 |
+
bitrange = torch.tensor(quant_max - quant_min + 1).double()
|
| 65 |
+
self.bitwidth = int(torch.log2(bitrange).item())
|
| 66 |
+
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
|
| 67 |
+
|
| 68 |
+
@torch.jit.export
|
| 69 |
+
def enable_param_learning(self):
|
| 70 |
+
r"""Enable parameter learning over static observer estimates.
|
| 71 |
+
|
| 72 |
+
Enables learning of quantization parameters and
|
| 73 |
+
disables static observer estimates. Forward path returns fake quantized X.
|
| 74 |
+
"""
|
| 75 |
+
self.toggle_qparam_learning(enabled=True) \
|
| 76 |
+
.toggle_fake_quant(enabled=True) \
|
| 77 |
+
.toggle_observer_update(enabled=False)
|
| 78 |
+
return self
|
| 79 |
+
|
| 80 |
+
@torch.jit.export
|
| 81 |
+
def enable_static_estimate(self):
|
| 82 |
+
"""Enable static estimates of quantization parameters.
|
| 83 |
+
|
| 84 |
+
Enables static observer estimates and disables learning of
|
| 85 |
+
quantization parameters. Forward path returns fake quantized X.
|
| 86 |
+
"""
|
| 87 |
+
self.toggle_qparam_learning(enabled=False) \
|
| 88 |
+
.toggle_fake_quant(enabled=True) \
|
| 89 |
+
.toggle_observer_update(enabled=True)
|
| 90 |
+
|
| 91 |
+
@torch.jit.export
|
| 92 |
+
def enable_static_observation(self):
|
| 93 |
+
"""Enable accumulation of data without updating quantization parameters.
|
| 94 |
+
|
| 95 |
+
Enables static observer accumulating data from input but doesn't
|
| 96 |
+
update the quantization parameters. Forward path returns the original X.
|
| 97 |
+
"""
|
| 98 |
+
self.toggle_qparam_learning(enabled=False) \
|
| 99 |
+
.toggle_fake_quant(enabled=False) \
|
| 100 |
+
.toggle_observer_update(enabled=True)
|
| 101 |
+
|
| 102 |
+
@torch.jit.export
|
| 103 |
+
def toggle_observer_update(self, enabled=True):
|
| 104 |
+
self.static_enabled[0] = int(enabled) # type: ignore[operator]
|
| 105 |
+
return self
|
| 106 |
+
|
| 107 |
+
@torch.jit.export
|
| 108 |
+
def enable_observer(self, enabled=True):
|
| 109 |
+
self.toggle_observer_update(enabled)
|
| 110 |
+
|
| 111 |
+
@torch.jit.export
|
| 112 |
+
def toggle_qparam_learning(self, enabled=True):
|
| 113 |
+
self.learning_enabled[0] = int(enabled) # type: ignore[operator]
|
| 114 |
+
self.scale.requires_grad = enabled
|
| 115 |
+
self.zero_point.requires_grad = enabled
|
| 116 |
+
return self
|
| 117 |
+
|
| 118 |
+
@torch.jit.export
|
| 119 |
+
def toggle_fake_quant(self, enabled=True):
|
| 120 |
+
self.fake_quant_enabled[0] = int(enabled)
|
| 121 |
+
return self
|
| 122 |
+
|
| 123 |
+
@torch.jit.export
|
| 124 |
+
def observe_quant_params(self):
|
| 125 |
+
print(f'_LearnableFakeQuantize Scale: {self.scale.detach()}')
|
| 126 |
+
print(f'_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}')
|
| 127 |
+
|
| 128 |
+
@torch.jit.export
|
| 129 |
+
def calculate_qparams(self):
|
| 130 |
+
self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
|
| 131 |
+
scale = self.scale.detach()
|
| 132 |
+
zero_point = self.zero_point.detach().round().clamp(self.quant_min, self.quant_max).long()
|
| 133 |
+
return scale, zero_point
|
| 134 |
+
|
| 135 |
+
def forward(self, X):
|
| 136 |
+
if self.static_enabled[0] == 1: # type: ignore[index]
|
| 137 |
+
self.activation_post_process(X.detach())
|
| 138 |
+
_scale, _zero_point = self.activation_post_process.calculate_qparams()
|
| 139 |
+
_scale = _scale.to(self.scale.device)
|
| 140 |
+
_zero_point = _zero_point.to(self.zero_point.device)
|
| 141 |
+
self.scale.data.copy_(_scale)
|
| 142 |
+
self.zero_point.data.copy_(_zero_point)
|
| 143 |
+
else:
|
| 144 |
+
self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
|
| 145 |
+
|
| 146 |
+
if self.fake_quant_enabled[0] == 1:
|
| 147 |
+
if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
|
| 148 |
+
self.zero_point.data.zero_()
|
| 149 |
+
|
| 150 |
+
if self.use_grad_scaling:
|
| 151 |
+
grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
|
| 152 |
+
else:
|
| 153 |
+
grad_factor = 1.0
|
| 154 |
+
if self.qscheme in (
|
| 155 |
+
torch.per_channel_symmetric, torch.per_channel_affine):
|
| 156 |
+
X = torch._fake_quantize_learnable_per_channel_affine(
|
| 157 |
+
X, self.scale, self.zero_point, self.ch_axis,
|
| 158 |
+
self.quant_min, self.quant_max, grad_factor)
|
| 159 |
+
else:
|
| 160 |
+
X = torch._fake_quantize_learnable_per_tensor_affine(
|
| 161 |
+
X, self.scale, self.zero_point,
|
| 162 |
+
self.quant_min, self.quant_max, grad_factor)
|
| 163 |
+
|
| 164 |
+
return X
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc
ADDED
|
Binary file (6.74 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc
ADDED
|
Binary file (4.62 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py
ADDED
|
@@ -0,0 +1,1131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
|
| 4 |
+
from torch.ao.quantization.quant_type import QuantType
|
| 5 |
+
import torch
|
| 6 |
+
import copy
|
| 7 |
+
import warnings
|
| 8 |
+
from torch.fx import (
|
| 9 |
+
GraphModule,
|
| 10 |
+
)
|
| 11 |
+
from torch.fx.graph import (
|
| 12 |
+
Graph,
|
| 13 |
+
Node,
|
| 14 |
+
Argument,
|
| 15 |
+
)
|
| 16 |
+
from ..utils import (
|
| 17 |
+
activation_is_statically_quantized,
|
| 18 |
+
weight_is_quantized,
|
| 19 |
+
get_qparam_dict,
|
| 20 |
+
_parent_name,
|
| 21 |
+
get_swapped_custom_module_class,
|
| 22 |
+
)
|
| 23 |
+
from ..qconfig import (
|
| 24 |
+
QConfigAny,
|
| 25 |
+
qconfig_equals
|
| 26 |
+
)
|
| 27 |
+
from ..qconfig_mapping import QConfigMapping
|
| 28 |
+
from .qconfig_mapping_utils import (
|
| 29 |
+
_generate_node_name_to_qconfig,
|
| 30 |
+
_compare_prepare_convert_qconfig_mappings,
|
| 31 |
+
_update_qconfig_for_fusion,
|
| 32 |
+
_is_qconfig_supported_by_dtype_configs,
|
| 33 |
+
_update_qconfig_for_qat,
|
| 34 |
+
)
|
| 35 |
+
from torch.ao.quantization.backend_config.utils import (
|
| 36 |
+
get_root_module_to_quantized_reference_module,
|
| 37 |
+
get_pattern_to_dtype_configs,
|
| 38 |
+
get_fused_module_classes,
|
| 39 |
+
get_qat_module_classes,
|
| 40 |
+
)
|
| 41 |
+
from torch.ao.quantization.backend_config import (
|
| 42 |
+
BackendConfig,
|
| 43 |
+
get_native_backend_config,
|
| 44 |
+
)
|
| 45 |
+
from torch.ao.quantization.observer import _is_activation_post_process
|
| 46 |
+
from .graph_module import (
|
| 47 |
+
_is_observed_module,
|
| 48 |
+
_is_observed_standalone_module,
|
| 49 |
+
)
|
| 50 |
+
from ._equalize import update_obs_for_equalization, convert_eq_obs
|
| 51 |
+
from torch.nn.utils.parametrize import type_before_parametrizations
|
| 52 |
+
from .utils import (
|
| 53 |
+
_get_module,
|
| 54 |
+
_is_custom_module_lstm,
|
| 55 |
+
_is_custom_module_mha,
|
| 56 |
+
assert_and_get_unique_device,
|
| 57 |
+
get_custom_module_class_keys,
|
| 58 |
+
create_getattr_from_value,
|
| 59 |
+
collect_producer_nodes,
|
| 60 |
+
graph_module_from_producer_nodes,
|
| 61 |
+
node_arg_is_weight,
|
| 62 |
+
)
|
| 63 |
+
from torch.ao.quantization.utils import (
|
| 64 |
+
is_per_channel,
|
| 65 |
+
to_underlying_dtype,
|
| 66 |
+
)
|
| 67 |
+
from torch.ao.quantization.quantize import (
|
| 68 |
+
_remove_qconfig,
|
| 69 |
+
)
|
| 70 |
+
from torch.ao.quantization.stubs import DeQuantStub
|
| 71 |
+
from .custom_config import (
|
| 72 |
+
ConvertCustomConfig,
|
| 73 |
+
PrepareCustomConfig,
|
| 74 |
+
)
|
| 75 |
+
from .lower_to_fbgemm import lower_to_fbgemm
|
| 76 |
+
# importing the lib so that the quantized_decomposed ops are registered
|
| 77 |
+
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
| 78 |
+
import operator
|
| 79 |
+
|
| 80 |
+
# Public API of this module; all other helpers are private (underscore-prefixed).
__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]

# Maps a qscheme to the corresponding decomposed `choose_qparams` op, used by
# the dynamic-quantization path to compute scale/zero_point at runtime.
# Only per-tensor schemes are supported here.
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
}
|
| 91 |
+
|
| 92 |
+
def _replace_observer_with_quantize_dequantize_node_decomposed(
        model: torch.fx.GraphModule,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node working with decomposed Tensor

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
    torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...

    or quantize_per_channel and dequantize_per_channel

    Args:
        model: the GraphModule being converted; its graph is mutated in place
        node: the call_module node for the observer / fake-quantize instance
        modules: name -> submodule mapping for `model`
        node_name_to_scope: node name -> (module path, module type)
        node_name_to_qconfig: node name -> QConfig decided at prepare time
    """
    graph = model.graph
    assert modules is not None
    assert isinstance(node.target, str)
    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
    activation_post_process = modules[node.target]
    # if the observer implements its own convert(), delegate entirely to it
    if hasattr(activation_post_process, "convert"):
        activation_post_process.convert(model, node)
        return
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
                           list(node.args) + list(node.users.keys()))
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node

    # 1. extract the information from activation_post_process module for generating
    # the quantize and dequantize operator
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[assignment]

    if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
            (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op : Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
            quant_min = activation_post_process.quant_min
            quant_max = activation_post_process.quant_max
            # decomposed ops take the plain integer dtype (e.g. uint8), not quint8
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": ch_axis,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_
            }
        else:
            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
            scale = float(scale)
            zero_point = int(zero_point)
            quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
            quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
            dtype_ = to_underlying_dtype(dtype)
            qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype_
            }

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
                    # For scale and zero_point values we register them as buffers in the root module.
                    # However, note that when the values are not tensors, as in the case of
                    # per_tensor quantization, they will be treated as literals.
                    # However, registering them as a node seems to cause issue with dynamo
                    # tracing where it may consider tensor overload as opposed to default.
                    # With extra check of scale and zero_point being scalar, it makes
                    # sure that the default overload can be used.
                    # TODO: maybe need more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node)
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                {}
            )

            def remap_fn(x):
                return dequantized_node if x is node else x

            # remap numeric_debug_handle: keys that referenced the erased
            # observer node must now point at the dequantize node
            for user_node in node.users:
                if "numeric_debug_handle" in user_node.meta:
                    numeric_debug_handle = user_node.meta["numeric_debug_handle"]
                    user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:

        # uint8/int8/fp16 dynamic quantization

        # 1. extract information for inserting q/dq node from activation_post_process
        node_type = "call_function"
        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
        # we only use choose_qparams for is_decomposed now,
        # but we should probably align the non-decomposed path with this as well,
        # and that can be done after we remove reduce_range flag
        # 1. extract qparams from activation_post_process module
        dtype_ = to_underlying_dtype(dtype)
        assert dtype_ in [torch.uint8, torch.int8], \
            "only uint8 and int8 are supported in reference flow for " \
            "dynamic quantization right now"
        quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
        quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
        qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine)  # type: ignore[attr-defined]
        eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps)  # type: ignore[attr-defined]
        # note: scale and zero_point are missing for quantize_per_tensor op
        # we'll need to get this from choose_qparams op, which we'll add after
        # this step
        qparams = {
            "_quant_min_": quant_min,
            "_quant_max_": quant_max,
            "_eps_": eps,
            "_dtype_": dtype_
        }

        choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
        # 2. insert choose_qparams op and update the qparams list
        with graph.inserting_before(node):
            input_node = node.args[0]
            choose_qparams_op_inputs = [node.args[0]]
            for key, value in qparams.items():
                # we have quant_min, quant_max and dtype, all should be stored
                # as literals
                choose_qparams_op_inputs.append(value)
            choose_qparams_node = graph.create_node(
                "call_function",
                choose_qparams_op,
                tuple(choose_qparams_op_inputs),
                {}
            )
            # choose_qparms returns (scale, zero_point)
            scale_node = graph.create_node(
                "call_function",
                operator.getitem,
                (choose_qparams_node, 0),
                {}
            )
            zero_point_node = graph.create_node(
                "call_function",
                operator.getitem,
                (choose_qparams_node, 1),
                {}
            )
            quant_min = qparams["_quant_min_"]
            quant_max = qparams["_quant_max_"]
            dtype = qparams["_dtype_"]
            # rebuild qparams: scale/zero_point are now graph nodes computed at runtime
            qparams = {
                "_scale_": scale_node,
                "_zero_point_": zero_point_node,
                "_quant_min_": quant_min,
                "_quant_max_": quant_max,
                "_dtype_": dtype
            }

        # 3. replace activation_post_process node to quantize and dequantize node
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ['_scale_', '_zero_point_']:
                    # in this case we have a node in the graph since it's dynamically
                    # computed from the input, with choose_qparams op
                    qparam_node = value_or_node
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we
                    # store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            # use the same qparams from quantize op
            dq_inputs = [quantized_node] + quantize_op_inputs[1:]
            # need to use the tensor variant of this op, since scale and zero_point
            # from choose_qparam are Tensors, instead of float/int, this is to
            # prevent these nodes being traced away by downstream systems
            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
            dequantized_node = graph.call_function(
                dequantize_op,
                tuple(dq_inputs),
                {}
            )

            def remap_fn(x):
                return dequantized_node if x is node else x

            # remap numeric_debug_handle
            for user_node in node.users:
                if "numeric_debug_handle" in user_node.meta:
                    numeric_debug_handle = user_node.meta["numeric_debug_handle"]
                    user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        raise NotImplementedError("decomposed to float16 op not implemented yet")

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported
|
| 336 |
+
|
| 337 |
+
def _replace_observer_with_quantize_dequantize_node(
        model: torch.fx.GraphModule,
        node: Node,
        modules: Dict[str, torch.nn.Module],
        node_name_to_scope: Dict[str, Tuple[str, type]],
        node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
    """ Replace activation_post_process module call node with quantize and
    dequantize node

    Before:
    ... -> observer_0(x) -> ...
    After:
    ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...

    Args:
        model: the GraphModule being converted; its graph is mutated in place
        node: the call_module node for the observer / fake-quantize instance
        modules: name -> submodule mapping for `model`
        node_name_to_scope: node name -> (module path, module type)
        node_name_to_qconfig: node name -> QConfig decided at prepare time
    """
    assert modules is not None
    assert isinstance(node.target, str)
    graph = model.graph
    module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
    activation_post_process = modules[node.target]
    # skip replacing observers to quant/dequant nodes if the qconfigs of all
    # consumers and producers of this observer are None
    skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
                           list(node.args) + list(node.users.keys()))
    if skip_replacement or not _is_conversion_supported(activation_post_process):
        # didn't find corresponding quantize op and info for the activation_post_process
        # so we just remove the observer
        with graph.inserting_before(node):
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
        return

    # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
    dtype = activation_post_process.dtype  # type: ignore[attr-defined]

    is_dynamic = False
    if hasattr(activation_post_process, "is_dynamic"):
        is_dynamic = activation_post_process.is_dynamic  # type: ignore[attr-defined, assignment]

    if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
            (not is_dynamic):
        # TODO: probably should cleanup this condition check, it's hard
        # to reason about this if and the following elif

        # uint8/int8/int32 static quantization branch

        # 1. extract the information from activation_post_process module for generating
        # the quantize and dequantize operator
        node_type = "call_function"
        quantize_op : Optional[Callable] = None
        scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
        if is_per_channel(activation_post_process.qscheme):  # type: ignore[attr-defined]
            ch_axis = int(activation_post_process.ch_axis)  # type: ignore[attr-defined, arg-type]
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
            quantize_op = torch.quantize_per_channel
        else:
            # per-tensor qparams are plain scalars, stored as literals below
            scale = float(scale)
            zero_point = int(zero_point)
            qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
            quantize_op = torch.quantize_per_tensor

        # 2. replace activation_post_process node with quantize and dequantize
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value_or_node in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                if key in ['_scale_', '_zero_point_']:
                    # For scale and zero_point values we register them as buffers in the root module.
                    # TODO: maybe need more complex attr name here
                    qparam_node = create_getattr_from_value(
                        model, graph, module_path + prefix + key, value_or_node)
                    quantize_op_inputs.append(qparam_node)
                else:
                    # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                    quantize_op_inputs.append(value_or_node)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif is_dynamic:

        # uint8/int8/fp16 dynamic quantization branch

        node_type = "call_function"
        quantize_op = torch.quantize_per_tensor_dynamic
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
        qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}

        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # dynamic path: all qparams are literals (no buffers needed)
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)
    elif dtype == torch.float16:
        # fp16 "quantization" is just a cast: x.to(torch.float16).dequantize()
        node_type = "call_method"
        quantize_op = "to"  # type: ignore[assignment]
        qparams = {"_dtype_": dtype}
        with graph.inserting_before(node):
            input_node = node.args[0]
            quantize_op_inputs = [input_node]
            for key, value in qparams.items():
                # TODO: we can add the information of whether a value needs to
                # be registered as an attribute in qparams dict itself
                quantize_op_inputs.append(value)

            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
            dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
            node.replace_all_uses_with(dequantized_node)
            graph.erase_node(node)

    # should not reach since we have checks in the beginning to make sure the
    # activation_post_process is supported
|
| 458 |
+
|
| 459 |
+
# this is a temporary hack for custom module, we may want to implement
|
| 460 |
+
# this properly after the custom module class design is finalized
|
| 461 |
+
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
|
| 462 |
+
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
|
| 463 |
+
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
|
| 464 |
+
def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
|
| 465 |
+
call_custom_module_node = node.args[0]
|
| 466 |
+
assert isinstance(call_custom_module_node, Node), \
|
| 467 |
+
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
|
| 468 |
+
node.replace_all_uses_with(call_custom_module_node)
|
| 469 |
+
graph.erase_node(node)
|
| 470 |
+
_insert_dequantize_node(call_custom_module_node, graph)
|
| 471 |
+
|
| 472 |
+
def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
|
| 473 |
+
dtype = activation_post_process.dtype # type: ignore[attr-defined]
|
| 474 |
+
|
| 475 |
+
is_dynamic = False
|
| 476 |
+
if hasattr(activation_post_process, "is_dynamic"):
|
| 477 |
+
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
|
| 478 |
+
|
| 479 |
+
return (
|
| 480 |
+
(dtype in [
|
| 481 |
+
torch.quint8,
|
| 482 |
+
torch.qint8,
|
| 483 |
+
torch.qint32,
|
| 484 |
+
torch.uint8,
|
| 485 |
+
torch.int8,
|
| 486 |
+
torch.int16,
|
| 487 |
+
torch.int32
|
| 488 |
+
] and (not is_dynamic)) or # type: ignore[return-value]
|
| 489 |
+
is_dynamic or
|
| 490 |
+
dtype == torch.float16
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
def _has_none_qconfig(node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]) -> bool:
|
| 494 |
+
""" Check if a node has a qconfig of None, i.e. user requested to not quantize
|
| 495 |
+
the node
|
| 496 |
+
"""
|
| 497 |
+
return isinstance(node, Node) and node.name in node_name_to_qconfig and node_name_to_qconfig[node.name] is None
|
| 498 |
+
|
| 499 |
+
def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
|
| 500 |
+
""" Extract the subgraph that produces the weight for dynamic quant
|
| 501 |
+
or weight only quant node and run the subgraph to observe the weight.
|
| 502 |
+
Note that the observers of dynamic quant or weight only quant ops are
|
| 503 |
+
run during the convert step.
|
| 504 |
+
"""
|
| 505 |
+
for node in observed.graph.nodes:
|
| 506 |
+
if node.op != "call_function":
|
| 507 |
+
continue
|
| 508 |
+
for node_arg in node.args:
|
| 509 |
+
# node_arg is weight
|
| 510 |
+
if node_arg and node_arg_is_weight(node, node_arg):
|
| 511 |
+
weight_observer_nodes = collect_producer_nodes(node_arg)
|
| 512 |
+
if weight_observer_nodes is None:
|
| 513 |
+
continue
|
| 514 |
+
weight_observer_module = \
|
| 515 |
+
graph_module_from_producer_nodes(
|
| 516 |
+
observed, weight_observer_nodes)
|
| 517 |
+
# run the weight observer
|
| 518 |
+
weight_observer_module()
|
| 519 |
+
|
| 520 |
+
def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
|
| 521 |
+
""" If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
|
| 522 |
+
we'll recursively remove the dequantize Node
|
| 523 |
+
"""
|
| 524 |
+
if isinstance(arg, Node) and \
|
| 525 |
+
arg.op == "call_method" and \
|
| 526 |
+
arg.target == "dequantize":
|
| 527 |
+
quantize_node = arg.args[0]
|
| 528 |
+
# we only replace the specific use since dequantize could be used by other nodes
|
| 529 |
+
# as well
|
| 530 |
+
node.replace_input_with(arg, quantize_node)
|
| 531 |
+
elif isinstance(arg, (list, tuple)):
|
| 532 |
+
for arg_element in arg:
|
| 533 |
+
_maybe_recursive_remove_dequantize(arg_element, node, graph)
|
| 534 |
+
elif isinstance(arg, dict):
|
| 535 |
+
for arg_element in arg.values():
|
| 536 |
+
_maybe_recursive_remove_dequantize(arg_element, node, graph)
|
| 537 |
+
else:
|
| 538 |
+
warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
|
| 539 |
+
|
| 540 |
+
def _get_module_path_and_prefix(
|
| 541 |
+
obs_node: Node,
|
| 542 |
+
node_name_to_scope: Dict[str, Tuple[str, type]],
|
| 543 |
+
node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
|
| 544 |
+
""" Given and observer node, get the `Scope` or the fully qualified name for
|
| 545 |
+
the submodule containing the observed node, also return a prefix of "_input"
|
| 546 |
+
when the observed node is an input of a F.linear op, and not the output of another
|
| 547 |
+
quantized op.
|
| 548 |
+
TODO: this logic is hacky, we should think about how to remove it or make it more
|
| 549 |
+
general
|
| 550 |
+
"""
|
| 551 |
+
observed_node = obs_node.args[0]
|
| 552 |
+
# an observer can be inserted for both input of the next operator or output of the previous
|
| 553 |
+
# operator (they can be the same)
|
| 554 |
+
# this flag identifies if the observer is inserted only because the observed node is
|
| 555 |
+
# the input of the next operator
|
| 556 |
+
assert isinstance(observed_node, Node), \
|
| 557 |
+
f"Expecting observed node to be a Node, but got {observed_node}"
|
| 558 |
+
is_input_observer_only = node_name_to_qconfig[observed_node.name] is None \
|
| 559 |
+
if observed_node.name in node_name_to_qconfig else None
|
| 560 |
+
if is_input_observer_only:
|
| 561 |
+
# if the quantize function is at the input of op, then we find the first user of the observer_node
|
| 562 |
+
# to get the path. If a linear call_function is in the user list, we return the first instance
|
| 563 |
+
# of linear node to get the FQN.
|
| 564 |
+
users = list(obs_node.users)
|
| 565 |
+
first_linear_use_or_first_use = users[0] if users else None
|
| 566 |
+
linear_node = None
|
| 567 |
+
for n in users:
|
| 568 |
+
if n.op == "call_function" and n.target == torch.nn.functional.linear:
|
| 569 |
+
linear_node = n
|
| 570 |
+
break
|
| 571 |
+
if linear_node:
|
| 572 |
+
first_linear_use_or_first_use = linear_node
|
| 573 |
+
prefix = "_input"
|
| 574 |
+
else:
|
| 575 |
+
# if the quantize function is at the output of the op, we use the observer input node to get the path
|
| 576 |
+
first_linear_use_or_first_use = observed_node
|
| 577 |
+
prefix = ""
|
| 578 |
+
|
| 579 |
+
if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
|
| 580 |
+
module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
|
| 581 |
+
else:
|
| 582 |
+
# TODO: it's not used, so actually we can skip quantization
|
| 583 |
+
# but this requires changing return type of quantize_node
|
| 584 |
+
# we can fix it later if needed
|
| 585 |
+
module_path = ""
|
| 586 |
+
return module_path, prefix
|
| 587 |
+
|
| 588 |
+
def _insert_dequantize_node(
|
| 589 |
+
node: Node,
|
| 590 |
+
graph: Graph) -> None:
|
| 591 |
+
""" Inserts dequantize node for `node` in `graph`
|
| 592 |
+
"""
|
| 593 |
+
with graph.inserting_after(node):
|
| 594 |
+
dequantize_node = graph.call_method("dequantize", (node,))
|
| 595 |
+
for user_node in dict(node.users):
|
| 596 |
+
if user_node is not dequantize_node:
|
| 597 |
+
user_node.replace_input_with(node, dequantize_node)
|
| 598 |
+
|
| 599 |
+
def _maybe_get_observer_for_node(
|
| 600 |
+
node: Node,
|
| 601 |
+
modules: Dict[str, torch.nn.Module]
|
| 602 |
+
) -> Optional[torch.nn.Module]:
|
| 603 |
+
"""
|
| 604 |
+
If the node is observed, return the observer
|
| 605 |
+
instance. Otherwise, return None.
|
| 606 |
+
"""
|
| 607 |
+
for maybe_obs_node in node.users.keys():
|
| 608 |
+
if maybe_obs_node.op == 'call_module':
|
| 609 |
+
maybe_obs = modules[str(maybe_obs_node.target)]
|
| 610 |
+
if _is_activation_post_process(maybe_obs):
|
| 611 |
+
return maybe_obs
|
| 612 |
+
return None
|
| 613 |
+
|
| 614 |
+
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config: Optional[BackendConfig]) -> None:
    """ Converts a observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - model: original model
      - is_reference: a flag from parent provided by user to decide if we want to
        produce a reference model or a fbgemm/qnnpack model
      - backend_config: backend configuration of the target backend of quantization
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module : GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    sm_input_quantized_idxs = \
        observed_standalone_module \
        .meta["_observed_graph_module_attrs"].standalone_module_input_quantized_idxs
    # remove the dequantize nodes for inputs: the standalone module consumes
    # those inputs in quantized form, so feed it the quantize node directly
    args = list(node.args)
    for idx in range(len(args)):
        if idx in sm_input_quantized_idxs:
            arg = args[idx]
            if arg.op == "call_method" and arg.target == "dequantize":  # type: ignore[union-attr]
                quantize_node = arg.args[0]  # type: ignore[union-attr]
                node.replace_input_with(arg, quantize_node)
                # the dequantize may still feed other consumers; only erase it
                # once it is fully unused
                if len(arg.users) == 0:  # type: ignore[union-attr]
                    model.graph.erase_node(arg)
    # add dequantize node for output
    sm_output_quantized_idxs = \
        observed_standalone_module \
        .meta["_observed_graph_module_attrs"].standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        # BUGFIX: the assertion message used to be split across two statements,
        # which left "output idxs = [0] is supported" as a dead string literal
        # and truncated the message shown on failure.
        assert sm_output_quantized_idxs[0] == 0, (
            "Currently only quantized output idxs = [0] is supported"
        )

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        _insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config=backend_config)
    parent_name, name = _parent_name(node.target)
    # update the modules dict so later passes see the converted submodule
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
|
| 675 |
+
|
| 676 |
+
def convert_weighted_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        observed_node_names: Set[str],
        node_name_to_qconfig: Dict[str, QConfigAny],
        backend_config: BackendConfig,
        is_decomposed: bool = False,
        is_reference: bool = False,
) -> None:
    """ Convert a weighted module to reference quantized module in the model
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
      - node: The call_module node of the observed standalone module
      - modules: named_module of original model
      - observed_node_names: names for the set of observed fx node, we can skip
        this conversion if the node is not observed
      - node_name_to_qconfig: node name -> qconfig; a None qconfig means the node
        is skipped
      - backend_config: backend configuration used to look up QAT module classes,
        supported dtype configs, and the float->reference-quantized module mapping
      - is_decomposed: recorded into the weight qparams dict so the reference
        module knows which quantize op flavor to use
      - is_reference: together with is_decomposed and QAT status, controls whether
        the weight observer/fake_quant is re-run here (see long comment below)
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(
            original_module,
            qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if qconfig is None or _has_none_qconfig(node, node_name_to_qconfig) or not is_observed:
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    # (for fused modules, index 0 is the root weighted module by convention)
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells carry two weights (input-hidden and hidden-hidden), each
        # observed with its own freshly-created observer from the qconfig.
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update({
            "weight_ih": weight_qparams_ih,
            "weight_hh": weight_qparams_hh,
        })
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                # only run the observer for int8 weights; qparams are recorded
                # either way (for non-qint8 the observer is left uncalibrated)
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        #     and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        #     calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    # splice the reference module back in: either inside the fused container,
    # or directly on the parent module
    if fused_module is not None:
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
|
| 810 |
+
|
| 811 |
+
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
|
| 812 |
+
"""
|
| 813 |
+
Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
|
| 814 |
+
|
| 815 |
+
Before: quantize - dequantize - custom_module
|
| 816 |
+
After: quantize - custom_module
|
| 817 |
+
\\ - dequantize
|
| 818 |
+
"""
|
| 819 |
+
# expecting the input node for a custom module node to be a Node
|
| 820 |
+
assert isinstance(prev_node, Node), \
|
| 821 |
+
f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
|
| 822 |
+
if prev_node.op == "call_method" and prev_node.target == "dequantize":
|
| 823 |
+
node.replace_input_with(prev_node, prev_node.args[0])
|
| 824 |
+
# Remove the dequantize node if it doesn't have other users
|
| 825 |
+
if len(prev_node.users) == 0:
|
| 826 |
+
graph.erase_node(prev_node)
|
| 827 |
+
|
| 828 |
+
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
        statically_quantized_custom_module_nodes: Set[Node]) -> None:
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`
    For static quantization, we'll also remove the previous `dequantize` node and
    attach the observer node for output to the module, the observer for the node
    will be converted to a dequantize node instead of quantize-dequantize pairs
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in nn.quantized namespace,
    i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed standalone module
      - graph: The graph containing the node
      - modules: named_module of original model
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized, this will be used later when converting
        observers to quant/dequant node pairs, if the observed node is a statically
        quantized custom module nodes, we'll convert the observer to a dequantize node,
        this is to keep the interface the same as the default quantized module.
        TODO: maybe we want to redesign this part to align with reference model design
        as well, but there has been some discussions around the interface, so we can do
        it later.
    """
    observed_custom_module = modules[str(node.target)]
    # NOTE(review): this lookup result is never used below — looks like dead code,
    # but removing it is out of scope for a documentation pass.
    maybe_obs = _maybe_get_observer_for_node(node, modules)
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2 and
                isinstance(node.args[1], tuple) and
                len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
            # absorb the following observer into the module conversion
            activation_post_process = _maybe_get_observer_for_node(node, modules)
            assert activation_post_process is not None
            observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
|
| 910 |
+
|
| 911 |
+
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_decomposed: bool = False) -> GraphModule:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    Args:
       * `is_standalone_module`: when this flag is True, it means we are quantizing
       a submodule that is not inlined in parent module, and will be quantized
       separately as one unit.

       * `is_decomposed`: a boolean flag to indicate whether we want to use the
       quantize operator for decomposed quantized tensor
       (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
       quantized tensor (torch.quantize_per_tensor)

    Returns:
         a quantized standalone module, whether input/output is quantized is
         specified by prepare_custom_config, with
         input_quantized_idxs, output_quantized_idxs, please
         see docs for :func:`~torch.ao.quantization.prepare_fx` for details
    """
    # --- normalize config arguments (dict forms are deprecated) ---
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, Dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
    # deep-copy so mutations below (e.g. QAT/fusion updates) don't leak to the caller
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    # --- pull out the metadata that prepare_fx stored on the model ---
    assert _is_observed_module(model), \
        'incoming model must be produced by prepare_fx'
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: Dict[str, Tuple[str, type]] = observed_graph_module_attrs.node_name_to_scope
    prepare_custom_config: PrepareCustomConfig = observed_graph_module_attrs.prepare_custom_config
    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
        # check the convert_node_name_to_qconfig generated and ensure that
        # all the values either match what was set in prepare node_name_to_qconfig
        # or are set to None in the convert_node_name_to_qconfig.
        for k, v in node_name_to_qconfig.items():
            assert k in convert_node_name_to_qconfig, f'Expected key {k} in convert node_name_to_qconfig'
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), \
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " \
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    # --- main pass: rewrite each node according to its kind ---
    # iterate over a snapshot of the node list since the loop body mutates the graph
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Result are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    # statically quantized custom modules already produce quantized
                    # output, so the observer becomes just a dequantize
                    _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
                else:
                    if is_decomposed:
                        _replace_observer_with_quantize_dequantize_node_decomposed(
                            model, node, modules, node_name_to_scope,
                            node_name_to_qconfig)
                    else:
                        _replace_observer_with_quantize_dequantize_node(
                            model, node, modules, node_name_to_scope,
                            node_name_to_qconfig)
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config)
            # below this point `type_before_parametrizations` is used
            # instead of `type` to handle situations with fx quant + sparsity
            elif type_before_parametrizations(mod) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type_before_parametrizations(mod) in fused_module_classes and \
                        type_before_parametrizations(mod[0]) not in root_module_classes:  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, node_name_to_qconfig, backend_config,
                    is_decomposed, is_reference)
            elif type_before_parametrizations(mod) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)

    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
from torch.fx.graph import (
|
| 4 |
+
Graph,
|
| 5 |
+
Node,
|
| 6 |
+
)
|
| 7 |
+
from torch.ao.quantization.utils import Pattern
|
| 8 |
+
from .quantize_handler import (
|
| 9 |
+
QuantizeHandler,
|
| 10 |
+
)
|
| 11 |
+
from ..qconfig import (
|
| 12 |
+
QConfigAny,
|
| 13 |
+
)
|
| 14 |
+
from ..utils import (
|
| 15 |
+
MatchAllNode
|
| 16 |
+
)
|
| 17 |
+
from .graph_module import (
|
| 18 |
+
_is_observed_standalone_module,
|
| 19 |
+
)
|
| 20 |
+
from torch.nn.utils.parametrize import type_before_parametrizations
|
| 21 |
+
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# This module exposes no public names; everything here is internal to fx quantization.
__all__: List[str] = []

# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
# (root_node, matched_node_pattern, matched_pattern, quantize_handler)
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

# Same as _MatchResult, with the node's resolved QConfig appended.
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                QConfigAny]
|
| 32 |
+
|
| 33 |
+
# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
|
| 34 |
+
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
|
| 35 |
+
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
|
| 36 |
+
# we'll start from the last node of the graph and traverse back.
|
| 37 |
+
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
|
| 38 |
+
""" Matches a node in fx against a pattern
|
| 39 |
+
"""
|
| 40 |
+
if isinstance(pattern, tuple):
|
| 41 |
+
self_match, *arg_matches = pattern
|
| 42 |
+
if self_match is getattr:
|
| 43 |
+
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
|
| 44 |
+
arg_matches = []
|
| 45 |
+
else:
|
| 46 |
+
self_match = pattern
|
| 47 |
+
arg_matches = []
|
| 48 |
+
|
| 49 |
+
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
|
| 50 |
+
return True
|
| 51 |
+
|
| 52 |
+
if node == pattern:
|
| 53 |
+
return True
|
| 54 |
+
|
| 55 |
+
if not isinstance(node, Node) or len(node.users) > max_uses:
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
|
| 59 |
+
if node.op != 'call_module':
|
| 60 |
+
return False
|
| 61 |
+
if not type_before_parametrizations(modules[node.target]) == self_match:
|
| 62 |
+
return False
|
| 63 |
+
elif callable(self_match):
|
| 64 |
+
if node.op != 'call_function' or node.target is not self_match:
|
| 65 |
+
return False
|
| 66 |
+
elif node.target is getattr:
|
| 67 |
+
if node.args[1] != pattern[1]:
|
| 68 |
+
return False
|
| 69 |
+
elif isinstance(self_match, str):
|
| 70 |
+
if node.op != 'call_method' or node.target != self_match:
|
| 71 |
+
return False
|
| 72 |
+
elif node.target != self_match:
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
if not arg_matches:
|
| 76 |
+
return True
|
| 77 |
+
|
| 78 |
+
if len(arg_matches) != len(node.args):
|
| 79 |
+
return False
|
| 80 |
+
|
| 81 |
+
return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
| 82 |
+
|
| 83 |
+
def _find_matches(
|
| 84 |
+
graph: Graph,
|
| 85 |
+
modules: Dict[str, torch.nn.Module],
|
| 86 |
+
patterns: Dict[Pattern, QuantizeHandler],
|
| 87 |
+
root_node_getter_mapping: Dict[Pattern, Callable],
|
| 88 |
+
standalone_module_names: Optional[List[str]] = None,
|
| 89 |
+
standalone_module_classes: Optional[List[Type]] = None,
|
| 90 |
+
custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
|
| 91 |
+
"""
|
| 92 |
+
Matches the nodes in the input graph to quantization patterns, and
|
| 93 |
+
outputs the information needed to quantize them in future steps.
|
| 94 |
+
|
| 95 |
+
Inputs:
|
| 96 |
+
- graph: an fx.Graph object
|
| 97 |
+
- modules: a mapping of fully qualified module name to instance,
|
| 98 |
+
for example, {'foo': ModuleFoo, ...}
|
| 99 |
+
- patterns: a mapping from a tuple of nodes in reverse order to
|
| 100 |
+
uninitialized QuantizeHandler subclass.
|
| 101 |
+
|
| 102 |
+
Outputs a map of
|
| 103 |
+
node_name ->
|
| 104 |
+
(node, matched_values, matched_pattern, QuantizeHandler instance,
|
| 105 |
+
qconfig)
|
| 106 |
+
|
| 107 |
+
For example, {
|
| 108 |
+
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
|
| 109 |
+
<CopyNodeQuantizeHandler instance>, QConfig(...)),
|
| 110 |
+
...
|
| 111 |
+
}
|
| 112 |
+
"""
|
| 113 |
+
if custom_module_classes is None:
|
| 114 |
+
custom_module_classes = []
|
| 115 |
+
|
| 116 |
+
if standalone_module_classes is None:
|
| 117 |
+
standalone_module_classes = []
|
| 118 |
+
|
| 119 |
+
if standalone_module_names is None:
|
| 120 |
+
standalone_module_names = []
|
| 121 |
+
|
| 122 |
+
match_map: Dict[str, _MatchResult] = {}
|
| 123 |
+
all_matched : Set[str] = set()
|
| 124 |
+
|
| 125 |
+
def _recursive_record_node_in_match_map(
|
| 126 |
+
last_node,
|
| 127 |
+
match_map,
|
| 128 |
+
node_pattern,
|
| 129 |
+
matched_node_pattern,
|
| 130 |
+
pattern,
|
| 131 |
+
match_value):
|
| 132 |
+
if isinstance(node_pattern, Node):
|
| 133 |
+
match_map[node_pattern.name] = (
|
| 134 |
+
last_node, matched_node_pattern, pattern, match_value)
|
| 135 |
+
elif not isinstance(node_pattern, Iterable):
|
| 136 |
+
return
|
| 137 |
+
else:
|
| 138 |
+
for n in node_pattern:
|
| 139 |
+
_recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
|
| 140 |
+
|
| 141 |
+
# TODO: 1. merge with fuse matcher 2. document the code
|
| 142 |
+
def record_match(
|
| 143 |
+
pattern,
|
| 144 |
+
node,
|
| 145 |
+
last_node,
|
| 146 |
+
matched_node_pattern,
|
| 147 |
+
match_map):
|
| 148 |
+
if isinstance(pattern, tuple):
|
| 149 |
+
s, *args = pattern
|
| 150 |
+
is_single_arg = len(args) == 1
|
| 151 |
+
current_node_pattern: List[Node] = []
|
| 152 |
+
record_match(
|
| 153 |
+
s,
|
| 154 |
+
node,
|
| 155 |
+
last_node,
|
| 156 |
+
matched_node_pattern,
|
| 157 |
+
match_map)
|
| 158 |
+
if pattern[0] is not getattr:
|
| 159 |
+
for subpattern, arg in zip(args, node.args):
|
| 160 |
+
record_match(
|
| 161 |
+
subpattern,
|
| 162 |
+
arg,
|
| 163 |
+
node,
|
| 164 |
+
current_node_pattern,
|
| 165 |
+
match_map)
|
| 166 |
+
if len(current_node_pattern) > 1:
|
| 167 |
+
# current_node_pattern is the node pattern we get from matching
|
| 168 |
+
# the subpattern with arguments of the node
|
| 169 |
+
# we use is_single_arg to recover the original structure of the pattern
|
| 170 |
+
# if the original pattern has a single argument, we will have
|
| 171 |
+
# (original_op, (original_arg, ...))
|
| 172 |
+
# otherwise, we'll have a list of arguments
|
| 173 |
+
# (original_op, arg0, arg1, arg2, ...)
|
| 174 |
+
if is_single_arg:
|
| 175 |
+
matched_node_pattern.append(tuple(current_node_pattern))
|
| 176 |
+
else:
|
| 177 |
+
matched_node_pattern.extend(list(current_node_pattern))
|
| 178 |
+
else:
|
| 179 |
+
matched_node_pattern.append(current_node_pattern[0])
|
| 180 |
+
else:
|
| 181 |
+
matched_node_pattern.append(node)
|
| 182 |
+
|
| 183 |
+
for node in reversed(graph.nodes):
|
| 184 |
+
if node.name not in match_map and node.name not in all_matched:
|
| 185 |
+
for pattern, quantize_handler_cls in patterns.items():
|
| 186 |
+
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
| 187 |
+
if _is_match(modules, node, pattern) and node.name not in match_map:
|
| 188 |
+
matched_node_pattern: List[Node] = []
|
| 189 |
+
record_match(
|
| 190 |
+
pattern,
|
| 191 |
+
node,
|
| 192 |
+
node,
|
| 193 |
+
matched_node_pattern,
|
| 194 |
+
match_map)
|
| 195 |
+
quantize_handler = quantize_handler_cls( # type: ignore[operator]
|
| 196 |
+
matched_node_pattern,
|
| 197 |
+
modules,
|
| 198 |
+
root_node_getter)
|
| 199 |
+
last_node = node
|
| 200 |
+
# record the match for all nodes in the pattern
|
| 201 |
+
_recursive_record_node_in_match_map(
|
| 202 |
+
last_node,
|
| 203 |
+
match_map,
|
| 204 |
+
# we need to record all nodes in the matched pattern in the match_map
|
| 205 |
+
matched_node_pattern,
|
| 206 |
+
# this is a part of the value corresponding to the node
|
| 207 |
+
matched_node_pattern,
|
| 208 |
+
pattern,
|
| 209 |
+
quantize_handler)
|
| 210 |
+
break
|
| 211 |
+
|
| 212 |
+
# add custom module instances to the match result
|
| 213 |
+
assert modules is not None
|
| 214 |
+
for node in graph.nodes:
|
| 215 |
+
if node.op == 'call_module' and \
|
| 216 |
+
type(modules[node.target]) in custom_module_classes:
|
| 217 |
+
match_map[node.name] = (
|
| 218 |
+
node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
|
| 219 |
+
|
| 220 |
+
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
|
| 221 |
+
assert modules is not None
|
| 222 |
+
return (
|
| 223 |
+
node_target in standalone_module_names or # type: ignore[operator]
|
| 224 |
+
type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# add standalone modules to the match
|
| 228 |
+
for node in graph.nodes:
|
| 229 |
+
if node.op == 'call_module' and \
|
| 230 |
+
(is_standalone_module(node.target, modules) or
|
| 231 |
+
_is_observed_standalone_module(modules[node.target])):
|
| 232 |
+
# add node to matched nodes
|
| 233 |
+
match_map[node.name] = (
|
| 234 |
+
node, node, None,
|
| 235 |
+
QuantizeHandler(node, modules, is_standalone_module=True))
|
| 236 |
+
|
| 237 |
+
return match_map
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/tracer.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx._symbolic_trace import Tracer
|
| 3 |
+
from torch.fx.proxy import Scope
|
| 4 |
+
from torch.ao.nn.intrinsic import _FusedModule
|
| 5 |
+
from typing import List, Callable
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"QuantizationTracer",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
+
class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
scope: Scope,
|
| 15 |
+
current_module: torch.nn.Module,
|
| 16 |
+
current_module_path: str
|
| 17 |
+
):
|
| 18 |
+
super().__init__(scope, Scope(current_module_path, type(current_module)))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class QuantizationTracer(Tracer):
|
| 22 |
+
def __init__(
|
| 23 |
+
self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.skipped_module_names = skipped_module_names
|
| 27 |
+
self.skipped_module_classes = skipped_module_classes
|
| 28 |
+
# NB: initialized the module_type of top level module to None
|
| 29 |
+
# we are assuming people won't configure the model with the type of top level
|
| 30 |
+
# module here, since people can use "" for global config
|
| 31 |
+
# We can change this if there is a use case that configures
|
| 32 |
+
# qconfig using top level module type
|
| 33 |
+
self.scope = Scope("", None)
|
| 34 |
+
self.record_stack_traces = True
|
| 35 |
+
|
| 36 |
+
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
| 37 |
+
return (
|
| 38 |
+
(
|
| 39 |
+
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
|
| 40 |
+
and not isinstance(m, torch.nn.Sequential)
|
| 41 |
+
)
|
| 42 |
+
or module_qualified_name in self.skipped_module_names
|
| 43 |
+
or type(m) in self.skipped_module_classes
|
| 44 |
+
or isinstance(m, _FusedModule)
|
| 45 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/queue.cpython-311.pyc
ADDED
|
Binary file (3.71 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing.pool
|
| 2 |
+
import multiprocessing.util as util
|
| 3 |
+
|
| 4 |
+
from .queue import SimpleQueue
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def clean_worker(*args, **kwargs):
|
| 8 |
+
import gc
|
| 9 |
+
|
| 10 |
+
multiprocessing.pool.worker(*args, **kwargs)
|
| 11 |
+
# Regular multiprocessing workers don't fully clean up after themselves,
|
| 12 |
+
# so we have to explicitly trigger garbage collection to make sure that all
|
| 13 |
+
# destructors are called...
|
| 14 |
+
gc.collect()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Pool(multiprocessing.pool.Pool):
|
| 18 |
+
"""Pool implementation which uses our version of SimpleQueue.
|
| 19 |
+
|
| 20 |
+
This lets us pass tensors in shared memory across processes instead of
|
| 21 |
+
serializing the underlying data.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def _setup_queues(self):
|
| 25 |
+
self._inqueue = SimpleQueue()
|
| 26 |
+
self._outqueue = SimpleQueue()
|
| 27 |
+
self._quick_put = self._inqueue._writer.send
|
| 28 |
+
self._quick_get = self._outqueue._reader.recv
|
| 29 |
+
|
| 30 |
+
def _repopulate_pool(self):
|
| 31 |
+
"""Increase the number of pool processes to the specified number.
|
| 32 |
+
|
| 33 |
+
Bring the number of pool processes up to the specified number, for use after
|
| 34 |
+
reaping workers which have exited.
|
| 35 |
+
"""
|
| 36 |
+
for i in range(self._processes - len(self._pool)):
|
| 37 |
+
# changed worker -> clean_worker
|
| 38 |
+
args = (
|
| 39 |
+
self._inqueue,
|
| 40 |
+
self._outqueue,
|
| 41 |
+
self._initializer,
|
| 42 |
+
self._initargs,
|
| 43 |
+
self._maxtasksperchild,
|
| 44 |
+
)
|
| 45 |
+
if hasattr(self, "_wrap_exception"):
|
| 46 |
+
args += (self._wrap_exception,)
|
| 47 |
+
w = self.Process(target=clean_worker, args=args)
|
| 48 |
+
self._pool.append(w)
|
| 49 |
+
w.name = w.name.replace("Process", "PoolWorker")
|
| 50 |
+
w.daemon = True
|
| 51 |
+
w.start()
|
| 52 |
+
util.debug("added worker")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/_reduction.cpython-311.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Defines utilities for interacting with scaled_dot_product_attention"""
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
__all__: List[str] = []
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _input_requires_grad(*tensors: torch.Tensor) -> bool:
|
| 11 |
+
"""Returns True if any of the tensors requires grad"""
|
| 12 |
+
return any(t.requires_grad for t in tensors)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
|
| 16 |
+
"""Handles the unpad of the last dimension"""
|
| 17 |
+
if inpt_tensor.size(-1) != og_size:
|
| 18 |
+
return inpt_tensor[..., :og_size]
|
| 19 |
+
return inpt_tensor
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
|
| 23 |
+
"""
|
| 24 |
+
For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output
|
| 25 |
+
by the original head size and not the padded.
|
| 26 |
+
"""
|
| 27 |
+
if scale is not None:
|
| 28 |
+
return scale
|
| 29 |
+
return 1.0 / math.sqrt(head_dim_size)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _validate_sdpa_input(
|
| 33 |
+
query: torch.Tensor,
|
| 34 |
+
key: torch.Tensor,
|
| 35 |
+
value: torch.Tensor,
|
| 36 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 37 |
+
dropout_p=0.0,
|
| 38 |
+
is_causal=False,
|
| 39 |
+
scale=None,
|
| 40 |
+
):
|
| 41 |
+
if query.dtype != key.dtype or query.dtype != value.dtype:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
f"Expected query, key, and value to have the same dtype, "
|
| 44 |
+
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
|
| 45 |
+
f"and value.dtype: {value.dtype} instead."
|
| 46 |
+
)
|
| 47 |
+
if query.device != key.device or query.device != value.device:
|
| 48 |
+
raise ValueError(
|
| 49 |
+
f"Expected query, key, and value to have the same device type, "
|
| 50 |
+
f"but got query.device: {query.device}, key.device: {key.device}, "
|
| 51 |
+
f"and value.device: {value.device} instead."
|
| 52 |
+
)
|
| 53 |
+
if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
|
| 54 |
+
raise ValueError(
|
| 55 |
+
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
|
| 56 |
+
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
|
| 57 |
+
)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-311.pyc
ADDED
|
Binary file (401 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/_functions.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.distributed as dist
|
| 3 |
+
|
| 4 |
+
from torch.autograd.function import Function
|
| 5 |
+
|
| 6 |
+
class SyncBatchNorm(Function):
|
| 7 |
+
|
| 8 |
+
@staticmethod
|
| 9 |
+
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
|
| 10 |
+
if not (
|
| 11 |
+
input.is_contiguous(memory_format=torch.channels_last) or
|
| 12 |
+
input.is_contiguous(memory_format=torch.channels_last_3d)
|
| 13 |
+
):
|
| 14 |
+
input = input.contiguous()
|
| 15 |
+
if weight is not None:
|
| 16 |
+
weight = weight.contiguous()
|
| 17 |
+
|
| 18 |
+
size = int(input.numel() // input.size(1))
|
| 19 |
+
if size == 1 and world_size < 2:
|
| 20 |
+
raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}')
|
| 21 |
+
|
| 22 |
+
num_channels = input.shape[1]
|
| 23 |
+
if input.numel() > 0:
|
| 24 |
+
# calculate mean/invstd for input.
|
| 25 |
+
mean, invstd = torch.batch_norm_stats(input, eps)
|
| 26 |
+
|
| 27 |
+
count = torch.full(
|
| 28 |
+
(1,),
|
| 29 |
+
input.numel() // input.size(1),
|
| 30 |
+
dtype=mean.dtype,
|
| 31 |
+
device=mean.device
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# C, C, 1 -> (2C + 1)
|
| 35 |
+
combined = torch.cat([mean, invstd, count], dim=0)
|
| 36 |
+
else:
|
| 37 |
+
# for empty input, set stats and the count to zero. The stats with
|
| 38 |
+
# zero count will be filtered out later when computing global mean
|
| 39 |
+
# & invstd, but they still needs to participate the all_gather
|
| 40 |
+
# collective communication to unblock other peer processes.
|
| 41 |
+
combined = torch.zeros(
|
| 42 |
+
2 * num_channels + 1,
|
| 43 |
+
dtype=input.dtype,
|
| 44 |
+
device=input.device
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Use allgather instead of allreduce because count could be different across
|
| 48 |
+
# ranks, simple all reduce op can not give correct results.
|
| 49 |
+
# batch_norm_gather_stats_with_counts calculates global mean & invstd based on
|
| 50 |
+
# all gathered mean, invstd and count.
|
| 51 |
+
# for nccl backend, use the optimized version of all gather.
|
| 52 |
+
# The Gloo backend does not support `all_gather_into_tensor`.
|
| 53 |
+
if process_group._get_backend_name() != "gloo":
|
| 54 |
+
# world_size * (2C + 1)
|
| 55 |
+
combined_size = combined.numel()
|
| 56 |
+
combined_flat = torch.empty(1,
|
| 57 |
+
combined_size * world_size,
|
| 58 |
+
dtype=combined.dtype,
|
| 59 |
+
device=combined.device)
|
| 60 |
+
dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
|
| 61 |
+
combined = torch.reshape(combined_flat, (world_size, combined_size))
|
| 62 |
+
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
|
| 63 |
+
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
|
| 64 |
+
else:
|
| 65 |
+
# world_size * (2C + 1)
|
| 66 |
+
combined_list = [
|
| 67 |
+
torch.empty_like(combined) for _ in range(world_size)
|
| 68 |
+
]
|
| 69 |
+
dist.all_gather(combined_list, combined, process_group, async_op=False)
|
| 70 |
+
combined = torch.stack(combined_list, dim=0)
|
| 71 |
+
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
|
| 72 |
+
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
|
| 73 |
+
|
| 74 |
+
if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
|
| 75 |
+
# The lines below force a synchronization between CUDA and CPU, because
|
| 76 |
+
# the shape of the result count_all depends on the values in mask tensor.
|
| 77 |
+
# Such synchronizations break CUDA Graph capturing.
|
| 78 |
+
# See https://github.com/pytorch/pytorch/issues/78549
|
| 79 |
+
# FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
|
| 80 |
+
# a better longer-term solution.
|
| 81 |
+
|
| 82 |
+
# remove stats from empty inputs
|
| 83 |
+
mask = count_all.squeeze(-1) >= 1
|
| 84 |
+
count_all = count_all[mask]
|
| 85 |
+
mean_all = mean_all[mask]
|
| 86 |
+
invstd_all = invstd_all[mask]
|
| 87 |
+
|
| 88 |
+
# calculate global mean & invstd
|
| 89 |
+
counts = count_all.view(-1)
|
| 90 |
+
if running_mean is not None and counts.dtype != running_mean.dtype:
|
| 91 |
+
counts = counts.to(running_mean.dtype)
|
| 92 |
+
mean, invstd = torch.batch_norm_gather_stats_with_counts(
|
| 93 |
+
input,
|
| 94 |
+
mean_all,
|
| 95 |
+
invstd_all,
|
| 96 |
+
running_mean,
|
| 97 |
+
running_var,
|
| 98 |
+
momentum,
|
| 99 |
+
eps,
|
| 100 |
+
counts,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
|
| 104 |
+
self.process_group = process_group
|
| 105 |
+
|
| 106 |
+
# apply element-wise normalization
|
| 107 |
+
if input.numel() > 0:
|
| 108 |
+
return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
|
| 109 |
+
else:
|
| 110 |
+
return torch.empty_like(input)
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def backward(self, grad_output):
|
| 114 |
+
if not (
|
| 115 |
+
grad_output.is_contiguous(memory_format=torch.channels_last) or
|
| 116 |
+
grad_output.is_contiguous(memory_format=torch.channels_last_3d)
|
| 117 |
+
):
|
| 118 |
+
grad_output = grad_output.contiguous()
|
| 119 |
+
saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
|
| 120 |
+
grad_input = grad_weight = grad_bias = None
|
| 121 |
+
process_group = self.process_group
|
| 122 |
+
|
| 123 |
+
if saved_input.numel() > 0:
|
| 124 |
+
# calculate local stats as well as grad_weight / grad_bias
|
| 125 |
+
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
|
| 126 |
+
grad_output,
|
| 127 |
+
saved_input,
|
| 128 |
+
mean,
|
| 129 |
+
invstd,
|
| 130 |
+
weight,
|
| 131 |
+
self.needs_input_grad[0],
|
| 132 |
+
self.needs_input_grad[1],
|
| 133 |
+
self.needs_input_grad[2]
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if self.needs_input_grad[0]:
|
| 137 |
+
# synchronizing stats used to calculate input gradient.
|
| 138 |
+
num_channels = sum_dy.shape[0]
|
| 139 |
+
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
|
| 140 |
+
torch.distributed.all_reduce(
|
| 141 |
+
combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
|
| 142 |
+
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
|
| 143 |
+
|
| 144 |
+
# backward pass for gradient calculation
|
| 145 |
+
if weight is not None and weight.dtype != mean.dtype:
|
| 146 |
+
weight = weight.to(mean.dtype)
|
| 147 |
+
grad_input = torch.batch_norm_backward_elemt(
|
| 148 |
+
grad_output,
|
| 149 |
+
saved_input,
|
| 150 |
+
mean,
|
| 151 |
+
invstd,
|
| 152 |
+
weight,
|
| 153 |
+
sum_dy,
|
| 154 |
+
sum_dy_xmu,
|
| 155 |
+
count_tensor
|
| 156 |
+
)
|
| 157 |
+
# synchronizing of grad_weight / grad_bias is not needed as distributed
|
| 158 |
+
# training would handle all reduce.
|
| 159 |
+
if weight is None or not self.needs_input_grad[1]:
|
| 160 |
+
grad_weight = None
|
| 161 |
+
|
| 162 |
+
if weight is None or not self.needs_input_grad[2]:
|
| 163 |
+
grad_bias = None
|
| 164 |
+
else:
|
| 165 |
+
# This process got an empty input tensor in the forward pass.
|
| 166 |
+
# Although this process can directly set grad_input as an empty
|
| 167 |
+
# tensor of zeros, it still needs to participate in the collective
|
| 168 |
+
# communication to unblock its peers, as other peer processes might
|
| 169 |
+
# have received non-empty inputs.
|
| 170 |
+
num_channels = saved_input.shape[1]
|
| 171 |
+
if self.needs_input_grad[0]:
|
| 172 |
+
# launch all_reduce to unblock other peer processes
|
| 173 |
+
combined = torch.zeros(
|
| 174 |
+
2 * num_channels,
|
| 175 |
+
dtype=saved_input.dtype,
|
| 176 |
+
device=saved_input.device
|
| 177 |
+
)
|
| 178 |
+
torch.distributed.all_reduce(
|
| 179 |
+
combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
|
| 180 |
+
|
| 181 |
+
# Leave grad_input, grad_weight and grad_bias as None, which will be
|
| 182 |
+
# interpreted by the autograd engine as Tensors full of zeros.
|
| 183 |
+
|
| 184 |
+
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
|
| 185 |
+
|
| 186 |
+
class CrossMapLRN2d(Function):
|
| 187 |
+
|
| 188 |
+
@staticmethod
|
| 189 |
+
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
|
| 190 |
+
ctx.size = size
|
| 191 |
+
ctx.alpha = alpha
|
| 192 |
+
ctx.beta = beta
|
| 193 |
+
ctx.k = k
|
| 194 |
+
ctx.scale = None
|
| 195 |
+
|
| 196 |
+
if input.dim() != 4:
|
| 197 |
+
raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.")
|
| 198 |
+
|
| 199 |
+
ctx.scale = ctx.scale or input.new()
|
| 200 |
+
output = input.new()
|
| 201 |
+
|
| 202 |
+
batch_size = input.size(0)
|
| 203 |
+
channels = input.size(1)
|
| 204 |
+
input_height = input.size(2)
|
| 205 |
+
input_width = input.size(3)
|
| 206 |
+
|
| 207 |
+
output.resize_as_(input)
|
| 208 |
+
ctx.scale.resize_as_(input)
|
| 209 |
+
|
| 210 |
+
# use output storage as temporary buffer
|
| 211 |
+
input_square = output
|
| 212 |
+
torch.pow(input, 2, out=input_square)
|
| 213 |
+
|
| 214 |
+
pre_pad = int((ctx.size - 1) / 2 + 1)
|
| 215 |
+
pre_pad_crop = min(pre_pad, channels)
|
| 216 |
+
|
| 217 |
+
scale_first = ctx.scale.select(1, 0)
|
| 218 |
+
scale_first.zero_()
|
| 219 |
+
# compute first feature map normalization
|
| 220 |
+
for c in range(pre_pad_crop):
|
| 221 |
+
scale_first.add_(input_square.select(1, c))
|
| 222 |
+
|
| 223 |
+
# reuse computations for next feature maps normalization
|
| 224 |
+
# by adding the next feature map and removing the previous
|
| 225 |
+
for c in range(1, channels):
|
| 226 |
+
scale_previous = ctx.scale.select(1, c - 1)
|
| 227 |
+
scale_current = ctx.scale.select(1, c)
|
| 228 |
+
scale_current.copy_(scale_previous)
|
| 229 |
+
if c < channels - pre_pad + 1:
|
| 230 |
+
square_next = input_square.select(1, c + pre_pad - 1)
|
| 231 |
+
scale_current.add_(square_next, alpha=1)
|
| 232 |
+
|
| 233 |
+
if c > pre_pad:
|
| 234 |
+
square_previous = input_square.select(1, c - pre_pad)
|
| 235 |
+
scale_current.add_(square_previous, alpha=-1)
|
| 236 |
+
|
| 237 |
+
ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
|
| 238 |
+
|
| 239 |
+
torch.pow(ctx.scale, -ctx.beta, out=output)
|
| 240 |
+
output.mul_(input)
|
| 241 |
+
|
| 242 |
+
ctx.save_for_backward(input, output)
|
| 243 |
+
return output
|
| 244 |
+
|
| 245 |
+
@staticmethod
|
| 246 |
+
def backward(ctx, grad_output):
|
| 247 |
+
input, output = ctx.saved_tensors
|
| 248 |
+
grad_input = grad_output.new()
|
| 249 |
+
|
| 250 |
+
batch_size = input.size(0)
|
| 251 |
+
channels = input.size(1)
|
| 252 |
+
input_height = input.size(2)
|
| 253 |
+
input_width = input.size(3)
|
| 254 |
+
|
| 255 |
+
paddded_ratio = input.new(channels + ctx.size - 1, input_height,
|
| 256 |
+
input_width)
|
| 257 |
+
accum_ratio = input.new(input_height, input_width)
|
| 258 |
+
|
| 259 |
+
cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
|
| 260 |
+
inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
|
| 261 |
+
|
| 262 |
+
grad_input.resize_as_(input)
|
| 263 |
+
torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
|
| 264 |
+
|
| 265 |
+
paddded_ratio.zero_()
|
| 266 |
+
padded_ratio_center = paddded_ratio.narrow(0, inversePrePad,
|
| 267 |
+
channels)
|
| 268 |
+
for n in range(batch_size):
|
| 269 |
+
torch.mul(grad_output[n], output[n], out=padded_ratio_center)
|
| 270 |
+
padded_ratio_center.div_(ctx.scale[n])
|
| 271 |
+
torch.sum(
|
| 272 |
+
paddded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
|
| 273 |
+
for c in range(channels):
|
| 274 |
+
accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
|
| 275 |
+
grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
|
| 276 |
+
accum_ratio.add_(paddded_ratio[c], alpha=-1)
|
| 277 |
+
|
| 278 |
+
return grad_input, None, None, None, None
|
| 279 |
+
|
| 280 |
+
class BackwardHookFunction(torch.autograd.Function):
|
| 281 |
+
@staticmethod
|
| 282 |
+
def forward(ctx, *args):
|
| 283 |
+
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
|
| 284 |
+
return args
|
| 285 |
+
|
| 286 |
+
@staticmethod
|
| 287 |
+
def backward(ctx, *args):
|
| 288 |
+
return args
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/upsampling.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module import Module
|
| 2 |
+
from .. import functional as F
|
| 3 |
+
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from ..common_types import _size_2_t, _ratio_2_t, _size_any_t, _ratio_any_t
|
| 7 |
+
|
| 8 |
+
__all__ = ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d']
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Upsample(Module):
|
| 12 |
+
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
| 13 |
+
|
| 14 |
+
The input data is assumed to be of the form
|
| 15 |
+
`minibatch x channels x [optional depth] x [optional height] x width`.
|
| 16 |
+
Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
|
| 17 |
+
|
| 18 |
+
The algorithms available for upsampling are nearest neighbor and linear,
|
| 19 |
+
bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
|
| 20 |
+
respectively.
|
| 21 |
+
|
| 22 |
+
One can either give a :attr:`scale_factor` or the target output :attr:`size` to
|
| 23 |
+
calculate the output size. (You cannot give both, as it is ambiguous)
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
|
| 27 |
+
output spatial sizes
|
| 28 |
+
scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
|
| 29 |
+
multiplier for spatial size. Has to match input size if it is a tuple.
|
| 30 |
+
mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
|
| 31 |
+
``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
|
| 32 |
+
Default: ``'nearest'``
|
| 33 |
+
align_corners (bool, optional): if ``True``, the corner pixels of the input
|
| 34 |
+
and output tensors are aligned, and thus preserving the values at
|
| 35 |
+
those pixels. This only has effect when :attr:`mode` is
|
| 36 |
+
``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
|
| 37 |
+
Default: ``False``
|
| 38 |
+
recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
|
| 39 |
+
interpolation calculation. If `recompute_scale_factor` is ``True``, then
|
| 40 |
+
`scale_factor` must be passed in and `scale_factor` is used to compute the
|
| 41 |
+
output `size`. The computed output `size` will be used to infer new scales for
|
| 42 |
+
the interpolation. Note that when `scale_factor` is floating-point, it may differ
|
| 43 |
+
from the recomputed `scale_factor` due to rounding and precision issues.
|
| 44 |
+
If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
|
| 45 |
+
be used directly for interpolation.
|
| 46 |
+
|
| 47 |
+
Shape:
|
| 48 |
+
- Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
|
| 49 |
+
- Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
|
| 50 |
+
or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
|
| 51 |
+
|
| 52 |
+
.. math::
|
| 53 |
+
D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor
|
| 54 |
+
|
| 55 |
+
.. math::
|
| 56 |
+
H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
|
| 57 |
+
|
| 58 |
+
.. math::
|
| 59 |
+
W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
|
| 60 |
+
|
| 61 |
+
.. warning::
|
| 62 |
+
With ``align_corners = True``, the linearly interpolating modes
|
| 63 |
+
(`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
|
| 64 |
+
align the output and input pixels, and thus the output values can depend
|
| 65 |
+
on the input size. This was the default behavior for these modes up to
|
| 66 |
+
version 0.3.1. Since then, the default behavior is
|
| 67 |
+
``align_corners = False``. See below for concrete examples on how this
|
| 68 |
+
affects the outputs.
|
| 69 |
+
|
| 70 |
+
.. note::
|
| 71 |
+
If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.
|
| 72 |
+
|
| 73 |
+
Examples::
|
| 74 |
+
|
| 75 |
+
>>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
|
| 76 |
+
>>> input
|
| 77 |
+
tensor([[[[1., 2.],
|
| 78 |
+
[3., 4.]]]])
|
| 79 |
+
|
| 80 |
+
>>> m = nn.Upsample(scale_factor=2, mode='nearest')
|
| 81 |
+
>>> m(input)
|
| 82 |
+
tensor([[[[1., 1., 2., 2.],
|
| 83 |
+
[1., 1., 2., 2.],
|
| 84 |
+
[3., 3., 4., 4.],
|
| 85 |
+
[3., 3., 4., 4.]]]])
|
| 86 |
+
|
| 87 |
+
>>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
|
| 88 |
+
>>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
|
| 89 |
+
>>> m(input)
|
| 90 |
+
tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
|
| 91 |
+
[1.5000, 1.7500, 2.2500, 2.5000],
|
| 92 |
+
[2.5000, 2.7500, 3.2500, 3.5000],
|
| 93 |
+
[3.0000, 3.2500, 3.7500, 4.0000]]]])
|
| 94 |
+
|
| 95 |
+
>>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 96 |
+
>>> m(input)
|
| 97 |
+
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
|
| 98 |
+
[1.6667, 2.0000, 2.3333, 2.6667],
|
| 99 |
+
[2.3333, 2.6667, 3.0000, 3.3333],
|
| 100 |
+
[3.0000, 3.3333, 3.6667, 4.0000]]]])
|
| 101 |
+
|
| 102 |
+
>>> # Try scaling the same data in a larger tensor
|
| 103 |
+
>>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
|
| 104 |
+
>>> input_3x3[:, :, :2, :2].copy_(input)
|
| 105 |
+
tensor([[[[1., 2.],
|
| 106 |
+
[3., 4.]]]])
|
| 107 |
+
>>> input_3x3
|
| 108 |
+
tensor([[[[1., 2., 0.],
|
| 109 |
+
[3., 4., 0.],
|
| 110 |
+
[0., 0., 0.]]]])
|
| 111 |
+
|
| 112 |
+
>>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session")
|
| 113 |
+
>>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
|
| 114 |
+
>>> # Notice that values in top left corner are the same with the small input (except at boundary)
|
| 115 |
+
>>> m(input_3x3)
|
| 116 |
+
tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
|
| 117 |
+
[1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
|
| 118 |
+
[2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
|
| 119 |
+
[2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
|
| 120 |
+
[0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
|
| 121 |
+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
|
| 122 |
+
|
| 123 |
+
>>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 124 |
+
>>> # Notice that values in top left corner are now changed
|
| 125 |
+
>>> m(input_3x3)
|
| 126 |
+
tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
|
| 127 |
+
[1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
|
| 128 |
+
[2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
|
| 129 |
+
[2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
|
| 130 |
+
[1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
|
| 131 |
+
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
__constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor']
|
| 135 |
+
name: str
|
| 136 |
+
size: Optional[_size_any_t]
|
| 137 |
+
scale_factor: Optional[_ratio_any_t]
|
| 138 |
+
mode: str
|
| 139 |
+
align_corners: Optional[bool]
|
| 140 |
+
recompute_scale_factor: Optional[bool]
|
| 141 |
+
|
| 142 |
+
def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
|
| 143 |
+
mode: str = 'nearest', align_corners: Optional[bool] = None,
|
| 144 |
+
recompute_scale_factor: Optional[bool] = None) -> None:
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.name = type(self).__name__
|
| 147 |
+
self.size = size
|
| 148 |
+
if isinstance(scale_factor, tuple):
|
| 149 |
+
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
| 150 |
+
else:
|
| 151 |
+
self.scale_factor = float(scale_factor) if scale_factor else None
|
| 152 |
+
self.mode = mode
|
| 153 |
+
self.align_corners = align_corners
|
| 154 |
+
self.recompute_scale_factor = recompute_scale_factor
|
| 155 |
+
|
| 156 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 157 |
+
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,
|
| 158 |
+
recompute_scale_factor=self.recompute_scale_factor)
|
| 159 |
+
|
| 160 |
+
def __setstate__(self, state):
|
| 161 |
+
if 'recompute_scale_factor' not in state:
|
| 162 |
+
state['recompute_scale_factor'] = True
|
| 163 |
+
|
| 164 |
+
super().__setstate__(state)
|
| 165 |
+
|
| 166 |
+
def extra_repr(self) -> str:
|
| 167 |
+
if self.scale_factor is not None:
|
| 168 |
+
info = 'scale_factor=' + repr(self.scale_factor)
|
| 169 |
+
else:
|
| 170 |
+
info = 'size=' + repr(self.size)
|
| 171 |
+
info += ', mode=' + repr(self.mode)
|
| 172 |
+
return info
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class UpsamplingNearest2d(Upsample):
|
| 176 |
+
r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
|
| 177 |
+
|
| 178 |
+
To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
|
| 179 |
+
as it's constructor argument.
|
| 180 |
+
|
| 181 |
+
When :attr:`size` is given, it is the output size of the image `(h, w)`.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
size (int or Tuple[int, int], optional): output spatial sizes
|
| 185 |
+
scale_factor (float or Tuple[float, float], optional): multiplier for
|
| 186 |
+
spatial size.
|
| 187 |
+
|
| 188 |
+
.. warning::
|
| 189 |
+
This class is deprecated in favor of :func:`~nn.functional.interpolate`.
|
| 190 |
+
|
| 191 |
+
Shape:
|
| 192 |
+
- Input: :math:`(N, C, H_{in}, W_{in})`
|
| 193 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` where
|
| 194 |
+
|
| 195 |
+
.. math::
|
| 196 |
+
H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
|
| 197 |
+
|
| 198 |
+
.. math::
|
| 199 |
+
W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
|
| 200 |
+
|
| 201 |
+
Examples::
|
| 202 |
+
|
| 203 |
+
>>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
|
| 204 |
+
>>> input
|
| 205 |
+
tensor([[[[1., 2.],
|
| 206 |
+
[3., 4.]]]])
|
| 207 |
+
|
| 208 |
+
>>> m = nn.UpsamplingNearest2d(scale_factor=2)
|
| 209 |
+
>>> m(input)
|
| 210 |
+
tensor([[[[1., 1., 2., 2.],
|
| 211 |
+
[1., 1., 2., 2.],
|
| 212 |
+
[3., 3., 4., 4.],
|
| 213 |
+
[3., 3., 4., 4.]]]])
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
|
| 217 |
+
super().__init__(size, scale_factor, mode='nearest')
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class UpsamplingBilinear2d(Upsample):
|
| 221 |
+
r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels.
|
| 222 |
+
|
| 223 |
+
To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
|
| 224 |
+
as it's constructor argument.
|
| 225 |
+
|
| 226 |
+
When :attr:`size` is given, it is the output size of the image `(h, w)`.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
size (int or Tuple[int, int], optional): output spatial sizes
|
| 230 |
+
scale_factor (float or Tuple[float, float], optional): multiplier for
|
| 231 |
+
spatial size.
|
| 232 |
+
|
| 233 |
+
.. warning::
|
| 234 |
+
This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
|
| 235 |
+
equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
|
| 236 |
+
|
| 237 |
+
Shape:
|
| 238 |
+
- Input: :math:`(N, C, H_{in}, W_{in})`
|
| 239 |
+
- Output: :math:`(N, C, H_{out}, W_{out})` where
|
| 240 |
+
|
| 241 |
+
.. math::
|
| 242 |
+
H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
|
| 243 |
+
|
| 244 |
+
.. math::
|
| 245 |
+
W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
|
| 246 |
+
|
| 247 |
+
Examples::
|
| 248 |
+
|
| 249 |
+
>>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
|
| 250 |
+
>>> input
|
| 251 |
+
tensor([[[[1., 2.],
|
| 252 |
+
[3., 4.]]]])
|
| 253 |
+
|
| 254 |
+
>>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
|
| 255 |
+
>>> m = nn.UpsamplingBilinear2d(scale_factor=2)
|
| 256 |
+
>>> m(input)
|
| 257 |
+
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
|
| 258 |
+
[1.6667, 2.0000, 2.3333, 2.6667],
|
| 259 |
+
[2.3333, 2.6667, 3.0000, 3.3333],
|
| 260 |
+
[3.0000, 3.3333, 3.6667, 4.0000]]]])
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
|
| 264 |
+
super().__init__(size, scale_factor, mode='bilinear', align_corners=True)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (401 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .linear import Linear
|
| 2 |
+
|
| 3 |
+
__all__ = ["Linear"]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/linear.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""QAT Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and
|
| 5 |
+
is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
from torch.ao.nn.qat.dynamic.modules.linear import Linear
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""QAT Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/qat`, and
|
| 5 |
+
is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/qat/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
__all__ = ['Embedding', 'EmbeddingBag']
|
| 12 |
+
|
| 13 |
+
from torch.ao.nn.qat.modules.embedding_ops import Embedding
|
| 14 |
+
from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (504 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc
ADDED
|
Binary file (697 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantizable Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/quantizable`, and
|
| 5 |
+
is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/quantizable/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
from torch.ao.nn.quantizable.modules.rnn import LSTM
|
| 11 |
+
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import dynamic # noqa: F403
|
| 2 |
+
from . import functional # noqa: F403
|
| 3 |
+
from . import modules # noqa: F403
|
| 4 |
+
from .modules import * # noqa: F403
|
| 5 |
+
from .modules import MaxPool2d
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'BatchNorm2d',
|
| 9 |
+
'BatchNorm3d',
|
| 10 |
+
'Conv1d',
|
| 11 |
+
'Conv2d',
|
| 12 |
+
'Conv3d',
|
| 13 |
+
'ConvTranspose1d',
|
| 14 |
+
'ConvTranspose2d',
|
| 15 |
+
'ConvTranspose3d',
|
| 16 |
+
'DeQuantize',
|
| 17 |
+
'Dropout',
|
| 18 |
+
'ELU',
|
| 19 |
+
'Embedding',
|
| 20 |
+
'EmbeddingBag',
|
| 21 |
+
'GroupNorm',
|
| 22 |
+
'Hardswish',
|
| 23 |
+
'InstanceNorm1d',
|
| 24 |
+
'InstanceNorm2d',
|
| 25 |
+
'InstanceNorm3d',
|
| 26 |
+
'LayerNorm',
|
| 27 |
+
'LeakyReLU',
|
| 28 |
+
'Linear',
|
| 29 |
+
'LSTM',
|
| 30 |
+
'MultiheadAttention',
|
| 31 |
+
'PReLU',
|
| 32 |
+
'Quantize',
|
| 33 |
+
'ReLU6',
|
| 34 |
+
'Sigmoid',
|
| 35 |
+
'Softmax',
|
| 36 |
+
# Wrapper modules
|
| 37 |
+
'FloatFunctional',
|
| 38 |
+
'FXFloatFunctional',
|
| 39 |
+
'QFunctional',
|
| 40 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc
ADDED
|
Binary file (712 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc
ADDED
|
Binary file (939 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc
ADDED
|
Binary file (766 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torch.ao.nn.quantized.dynamic import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (281 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.42 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Dynamic Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
|
| 5 |
+
and is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
|
| 12 |
+
|
| 13 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d
|
| 14 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d
|
| 15 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d
|
| 16 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d
|
| 17 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d
|
| 18 |
+
from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Dynamic Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
|
| 5 |
+
and is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
from torch.ao.nn.quantized.dynamic.modules.linear import Linear
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
|
| 2 |
+
r"""Quantized Dynamic Modules.
|
| 3 |
+
|
| 4 |
+
This file is in the process of migration to `torch/ao/nn/quantized/dynamic`,
|
| 5 |
+
and is kept here for compatibility while the migration process is ongoing.
|
| 6 |
+
If you are adding a new entry/functionality, please, add it to the
|
| 7 |
+
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
|
| 8 |
+
while adding an import statement here.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell',
|
| 12 |
+
'GRUCell']
|
| 13 |
+
|
| 14 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias
|
| 15 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter
|
| 16 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase
|
| 17 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM
|
| 18 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import GRU
|
| 19 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase
|
| 20 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell
|
| 21 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell
|
| 22 |
+
from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc
ADDED
|
Binary file (1 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc
ADDED
|
Binary file (1.04 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc
ADDED
|
Binary file (696 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc
ADDED
|
Binary file (760 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (912 Bytes). View file
|
|
|