diff --git a/.gitattributes b/.gitattributes index 45f721353c93ea457c3a59e85d51f5c414156235..eb80d51c0f40b475ff09def757c067bda3dc0fcc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -72,3 +72,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/F tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cuda_cupti/lib/libcheckpoint.so filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/Visitor.cpython-311-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text tuning-competition-baseline/.venv/lib/python3.11/site-packages/pip/_vendor/pkg_resources/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 filter=lfs diff=lfs merge=lfs -text +tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eff7d591579cbaf34ec18750651b5e825d90497 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/Cython/Compiler/__pycache__/ModuleNode.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e838b6e228975d93cd0fdc5e05254915109c27d231652f0851ab23f1b207b5f +size 233093 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 new file mode 100644 index 0000000000000000000000000000000000000000..9b25e67296d2dca6182c2ab6d6f2360fb60e663b --- 
/dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cufft/lib/libcufftw.so.10 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a592a5b2f359a9077550ee1fdadd58eb2cf9cc0bfab8fe397a374fb949da143 +size 1618440 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a4d1dd24f8dbf505995982bbb33b8d90d3de2e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/config.py @@ -0,0 +1,16 @@ +import torch._C._lazy + + +def get_force_fallback(): + """Get the config used to force LTC fallback""" + return torch._C._lazy._get_force_fallback() + + +def set_force_fallback(configval): + """Set the config used to force LTC fallback""" + torch._C._lazy._set_force_fallback(configval) + + +def set_reuse_ir(val: bool): + """Set the config to reuse IR nodes for faster tracing""" + torch._C._lazy._set_reuse_ir(val) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py new file mode 100644 index 0000000000000000000000000000000000000000..840c7f8e50d039c9b72f31b16e8d69f706920534 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py @@ -0,0 +1,25 @@ +import threading +from typing import Any, Dict + +import torch._C._lazy + + +class DeviceContext: + _CONTEXTS: Dict[str, Any] = dict() + _CONTEXTS_LOCK = threading.Lock() + + def __init__(self, device): + self.device = device + + +def get_device_context(device=None): + if device is None: + device = torch._C._lazy._get_default_device_type() + else: + device = str(device) + with DeviceContext._CONTEXTS_LOCK: + devctx = DeviceContext._CONTEXTS.get(device, 
None) + if devctx is None: + devctx = DeviceContext(device) + DeviceContext._CONTEXTS[device] = devctx + return devctx diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..4270684d29434747f53177e48a58fd8dc9c7c44b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,13 @@ +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. + """ + return torch._C._lazy._clear_ir_cache() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7db730556779a353a1bb9f4b2529464d4bfc95 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,21 @@ +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af9193dbba209f97d491a43fe29c86cfeb8a9c3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad08e79638654cb812a55d13999fbb676edc6499 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afff3d7aa9f701bf783f5e2f3985ed56585248cd Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..6827ae35533cd1a7ea651256b599c916efb0b8a0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_learnable_fake_quantize.py @@ -0,0 +1,164 @@ +import torch +from torch.nn.parameter import Parameter +from typing import List + +__all__: List[str] = [] + +class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase): + r"""Generalized extension of the FakeQuantize module in fake_quantize.py. + + This is an extension of the FakeQuantize module in fake_quantize.py, which + supports more generalized lower-bit quantization and support learning of the scale + and zero point parameters through backpropagation. For literature references, + please see the class _LearnableFakeQuantizePerTensorOp. + + In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize + module also includes the following attributes to support quantization parameter learning. + + * :attr:`channel_len` defines the length of the channel when initializing scale and zero point + for the per channel case. + + * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are + normalized by the constant, which is proportional to the square root of the number of + elements in the tensor. The related literature justifying the use of this particular constant + can be found here: https://openreview.net/pdf?id=rkgO66VKDS. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output. 
+ + * :attr:`static_enabled` defines the flag for using observer's static estimation for + scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point. + """ + def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1, + use_grad_scaling=False, **observer_kwargs): + super().__init__() + assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.' + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs["quant_min"] = quant_min + observer_kwargs["quant_max"] = quant_max + self.use_grad_scaling = use_grad_scaling + if channel_len == -1: + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer." + self.scale = Parameter(torch.tensor([scale] * channel_len)) + self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) + + self.activation_post_process = observer(**observer_kwargs) + assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ + 'quant_min out of bound' + assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ + 'quant_max out of bound' + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer('eps', 
torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + r"""Enable parameter learning over static observer estimates. + + Enables learning of quantization parameters and + disables static observer estimates. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enable static estimates of quantization parameters. + + Enables static observer estimates and disables learning of + quantization parameters. Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_static_observation(self): + """Enable accumulation of data without updating quantization parameters. + + Enables static observer accumulating data from input but doesn't + update the quantization parameters. Forward path returns the original X. 
+ """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=False) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + self.static_enabled[0] = int(enabled) # type: ignore[operator] + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + self.learning_enabled[0] = int(enabled) # type: ignore[operator] + self.scale.requires_grad = enabled + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + print(f'_LearnableFakeQuantize Scale: {self.scale.detach()}') + print(f'_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}') + + @torch.jit.export + def calculate_qparams(self): + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp(self.quant_min, self.quant_max).long() + return scale, zero_point + + def forward(self, X): + if self.static_enabled[0] == 1: # type: ignore[index] + self.activation_post_process(X.detach()) + _scale, _zero_point = self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator] + + if self.fake_quant_enabled[0] == 1: + if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric): + self.zero_point.data.zero_() + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5 + else: + grad_factor = 1.0 + if self.qscheme in ( + torch.per_channel_symmetric, 
torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, + self.quant_min, self.quant_max, grad_factor) + + return X diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fd6a83187698d44bf570379447fd892dbc37b72 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d587b599f9793310d007164151a01f26f169f03 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..023abff83404dc9b521c754976eb828c3f03d744 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/convert.py @@ -0,0 +1,1131 @@ +# mypy: ignore-errors + +from typing import 
Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable +from torch.ao.quantization.quant_type import QuantType +import torch +import copy +import warnings +from torch.fx import ( + GraphModule, +) +from torch.fx.graph import ( + Graph, + Node, + Argument, +) +from ..utils import ( + activation_is_statically_quantized, + weight_is_quantized, + get_qparam_dict, + _parent_name, + get_swapped_custom_module_class, +) +from ..qconfig import ( + QConfigAny, + qconfig_equals +) +from ..qconfig_mapping import QConfigMapping +from .qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, + _compare_prepare_convert_qconfig_mappings, + _update_qconfig_for_fusion, + _is_qconfig_supported_by_dtype_configs, + _update_qconfig_for_qat, +) +from torch.ao.quantization.backend_config.utils import ( + get_root_module_to_quantized_reference_module, + get_pattern_to_dtype_configs, + get_fused_module_classes, + get_qat_module_classes, +) +from torch.ao.quantization.backend_config import ( + BackendConfig, + get_native_backend_config, +) +from torch.ao.quantization.observer import _is_activation_post_process +from .graph_module import ( + _is_observed_module, + _is_observed_standalone_module, +) +from ._equalize import update_obs_for_equalization, convert_eq_obs +from torch.nn.utils.parametrize import type_before_parametrizations +from .utils import ( + _get_module, + _is_custom_module_lstm, + _is_custom_module_mha, + assert_and_get_unique_device, + get_custom_module_class_keys, + create_getattr_from_value, + collect_producer_nodes, + graph_module_from_producer_nodes, + node_arg_is_weight, +) +from torch.ao.quantization.utils import ( + is_per_channel, + to_underlying_dtype, +) +from torch.ao.quantization.quantize import ( + _remove_qconfig, +) +from torch.ao.quantization.stubs import DeQuantStub +from .custom_config import ( + ConvertCustomConfig, + PrepareCustomConfig, +) +from .lower_to_fbgemm import lower_to_fbgemm +# importing the lib so that the quantized_decomposed ops 
are registered +from ._decomposed import quantized_decomposed_lib # noqa: F401 +import operator + +__all__ = [ + "convert", + "convert_custom_module", + "convert_standalone_module", + "convert_weighted_module", +] + +_QSCHEME_TO_CHOOSE_QPARAMS_OP = { + torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor, + torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, +} + +def _replace_observer_with_quantize_dequantize_node_decomposed( + model: torch.fx.GraphModule, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node working with decomposed Tensor + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel + """ + graph = model.graph + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + activation_post_process = modules[node.target] + if hasattr(activation_post_process, "convert"): + activation_post_process.convert(model, node) + return + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the 
activation_post_process module call to quantize/dequantize node + + # 1. extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment] + + if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \ + (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op : Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + quant_min = activation_post_process.quant_min + quant_max = activation_post_process.quant_max + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_axis_": ch_axis, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = 
to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + + def remap_fn(x): + return dequantized_node if x is node else x + + # remap numeric_debug_handle + for user_node in node.users: + if "numeric_debug_handle" in user_node.meta: + numeric_debug_handle = user_node.meta["numeric_debug_handle"] + user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()} + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. 
extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + assert dtype_ in [torch.uint8, torch.int8], \ + "only uint8 and int8 are supported in reference flow for " \ + "dynamic quantization right now" + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] + eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_eps_": eps, + "_dtype_": dtype_ + } + + choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme] + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + for key, value in qparams.items(): + # we have quant_min, quant_max and dtype, all should be stored + # as literals + choose_qparams_op_inputs.append(value) + choose_qparams_node = graph.create_node( + "call_function", + choose_qparams_op, + tuple(choose_qparams_op_inputs), + {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 0), + {} + ) + zero_point_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 1), + {} + ) + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype + } + + # 3. 
replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. + quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + + def remap_fn(x): + return dequantized_node if x is node else x + + # remap numeric_debug_handle + for user_node in node.users: + if "numeric_debug_handle" in user_node.meta: + numeric_debug_handle = user_node.meta["numeric_debug_handle"] + user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()} + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + raise NotImplementedError("decomposed to float16 op not implemented yet") + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + +def _replace_observer_with_quantize_dequantize_node( + model: 
torch.fx.GraphModule, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + """ + assert modules is not None + assert isinstance(node.target, str) + graph = model.graph + module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find corresponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ + (not is_dynamic): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. 
extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op : Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type] + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype} + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif is_dynamic: + + # uint8/int8/fp16 dynamic quantization branch + + node_type = "call_function" + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" # type: ignore[assignment] + qparams = {"_dtype_": dtype} + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the beginning to make sure the + # activation_post_process is supported + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module 
class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. +def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None: + call_custom_module_node = node.args[0] + assert isinstance(call_custom_module_node, Node), \ + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + _insert_dequantize_node(call_custom_module_node, graph) + +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + + is_dynamic = False + if hasattr(activation_post_process, "is_dynamic"): + is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment] + + return ( + (dtype in [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.uint8, + torch.int8, + torch.int16, + torch.int32 + ] and (not is_dynamic)) or # type: ignore[return-value] + is_dynamic or + dtype == torch.float16 + ) + +def _has_none_qconfig(node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]) -> bool: + """ Check if a node has a qconfig of None, i.e. user requested to not quantize + the node + """ + return isinstance(node, Node) and node.name in node_name_to_qconfig and node_name_to_qconfig[node.name] is None + +def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None: + """ Extract the subgraph that produces the weight for dynamic quant + or weight only quant node and run the subgraph to observe the weight. + Note that the observers of dynamic quant or weight only quant ops are + run during the convert step. 
+ """ + for node in observed.graph.nodes: + if node.op != "call_function": + continue + for node_arg in node.args: + # node_arg is weight + if node_arg and node_arg_is_weight(node, node_arg): + weight_observer_nodes = collect_producer_nodes(node_arg) + if weight_observer_nodes is None: + continue + weight_observer_module = \ + graph_module_from_producer_nodes( + observed, weight_observer_nodes) + # run the weight observer + weight_observer_module() + +def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None: + """ If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node, + we'll recursively remove the dequantize Node + """ + if isinstance(arg, Node) and \ + arg.op == "call_method" and \ + arg.target == "dequantize": + quantize_node = arg.args[0] + # we only replace the specific use since dequantize could be used by other nodes + # as well + node.replace_input_with(arg, quantize_node) + elif isinstance(arg, (list, tuple)): + for arg_element in arg: + _maybe_recursive_remove_dequantize(arg_element, node, graph) + elif isinstance(arg, dict): + for arg_element in arg.values(): + _maybe_recursive_remove_dequantize(arg_element, node, graph) + else: + warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}") + +def _get_module_path_and_prefix( + obs_node: Node, + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]: + """ Given and observer node, get the `Scope` or the fully qualified name for + the submodule containing the observed node, also return a prefix of "_input" + when the observed node is an input of a F.linear op, and not the output of another + quantized op. 
+ TODO: this logic is hacky, we should think about how to remove it or make it more + general + """ + observed_node = obs_node.args[0] + # an observer can be inserted for both input of the next operator or output of the previous + # operator (they can be the same) + # this flag identifies if the observer is inserted only because the observed node is + # the input of the next operator + assert isinstance(observed_node, Node), \ + f"Expecting observed node to be a Node, but got {observed_node}" + is_input_observer_only = node_name_to_qconfig[observed_node.name] is None \ + if observed_node.name in node_name_to_qconfig else None + if is_input_observer_only: + # if the quantize function is at the input of op, then we find the first user of the observer_node + # to get the path. If a linear call_function is in the user list, we return the first instance + # of linear node to get the FQN. + users = list(obs_node.users) + first_linear_use_or_first_use = users[0] if users else None + linear_node = None + for n in users: + if n.op == "call_function" and n.target == torch.nn.functional.linear: + linear_node = n + break + if linear_node: + first_linear_use_or_first_use = linear_node + prefix = "_input" + else: + # if the quantize function is at the output of the op, we use the observer input node to get the path + first_linear_use_or_first_use = observed_node + prefix = "" + + if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope: + module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name] + else: + # TODO: it's not used, so actually we can skip quantization + # but this requires changing return type of quantize_node + # we can fix it later if needed + module_path = "" + return module_path, prefix + +def _insert_dequantize_node( + node: Node, + graph: Graph) -> None: + """ Inserts dequantize node for `node` in `graph` + """ + with graph.inserting_after(node): + dequantize_node = graph.call_method("dequantize", (node,)) + 
for user_node in dict(node.users): + if user_node is not dequantize_node: + user_node.replace_input_with(node, dequantize_node) + +def _maybe_get_observer_for_node( + node: Node, + modules: Dict[str, torch.nn.Module] +) -> Optional[torch.nn.Module]: + """ + If the node is observed, return the observer + instance. Otherwise, return None. + """ + for maybe_obs_node in node.users.keys(): + if maybe_obs_node.op == 'call_module': + maybe_obs = modules[str(maybe_obs_node.target)] + if _is_activation_post_process(maybe_obs): + return maybe_obs + return None + +def convert_standalone_module( + node: Node, + modules: Dict[str, torch.nn.Module], + model: torch.fx.GraphModule, + is_reference: bool, + backend_config: Optional[BackendConfig]) -> None: + """ Converts a observed standalone module to a quantized standalone module by calling + the fx convert api, currently using the same `is_reference` flag as parent, but we may + changing this behavior in the future (e.g. separating quantization and lowering for + standalone module as well) + + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - model: original model + - is_reference: a flag from parent provided by user to decide if we want to + produce a reference model or a fbgemm/qnnpack model + - backend_config: backend configuration of the target backend of quantization + """ + # TODO: remove is_reference flag + if is_reference: + convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx + else: + convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined] + # We know that observed standalone module is a GraphModule since + # it's produced by us + observed_standalone_module : GraphModule = modules[str(node.target)] # type: ignore[assignment] + sm_input_quantized_idxs = \ + observed_standalone_module \ + .meta["_observed_graph_module_attrs"].standalone_module_input_quantized_idxs + # remove the dequantize nodes for inputs 
+ args = list(node.args) + for idx in range(len(args)): + if idx in sm_input_quantized_idxs: + arg = args[idx] + if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr] + quantize_node = arg.args[0] # type: ignore[union-attr] + node.replace_input_with(arg, quantize_node) + if len(arg.users) == 0: # type: ignore[union-attr] + model.graph.erase_node(arg) + # add dequantize node for output + sm_output_quantized_idxs = \ + observed_standalone_module \ + .meta["_observed_graph_module_attrs"].standalone_module_output_quantized_idxs + if len(sm_output_quantized_idxs) > 0: + assert sm_output_quantized_idxs[0] == 0, "Currently only quantized" + "output idxs = [0] is supported" + + # if it's non-empty, then it means the output is kept in quantized form + # we'll just add a dequantize node after this node + _insert_dequantize_node(node, model.graph) + + # TODO: allow convert_custom_config to override backend_config + # for standalone module + quantized_standalone_module = convert_fn( + observed_standalone_module, + backend_config=backend_config) + parent_name, name = _parent_name(node.target) + # update the modules dict + setattr(modules[parent_name], name, quantized_standalone_module) + modules[str(node.target)] = quantized_standalone_module + +def convert_weighted_module( + node: Node, + modules: Dict[str, torch.nn.Module], + observed_node_names: Set[str], + node_name_to_qconfig: Dict[str, QConfigAny], + backend_config: BackendConfig, + is_decomposed: bool = False, + is_reference: bool = False, +) -> None: + """ Convert a weighted module to reference quantized module in the model + If the QConfig of a QAT module is not set, the module will still be converted to + a float module. 
+ + Args: + - node: The call_module node of the observed standalone module + - modules: named_module of original model + - observed_node_names: names for the set of observed fx node, we can skip + this conversion if the node is not observed + """ + original_module = modules[str(node.target)] + qconfig: QConfigAny = original_module.qconfig # type: ignore[assignment] + weight_post_process = None + qat_module_classes = get_qat_module_classes(backend_config) + + if isinstance( + original_module, + qat_module_classes): + # Converting qat module to a float module, we need to attach + # weight fake_quant to the module, weight fake_quant is assumed to be run during + # QAT so we don't need to run it again here + weight_post_process = original_module.weight_fake_quant + original_module = original_module.to_float() # type: ignore[operator] + # change qat module to float module + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, original_module) + + is_observed = node.name in observed_node_names + # If a qconfig is not defined for this node, then skip converting to a reference module + if qconfig is None or _has_none_qconfig(node, node_name_to_qconfig) or not is_observed: + return + + # skip converting to reference quantized module if the qconfig is not supported + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs = pattern_to_dtype_configs.get(type(original_module), []) + if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs): + return + + # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized + is_weight_quantized = weight_is_quantized(qconfig) + + # the condition for swapping the module to reference quantized module is: + # weights need to be quantized + if not is_weight_quantized: + return + + fused_module = None + float_module = original_module + # extract the individual float_module and fused module + if isinstance(original_module, 
torch.ao.nn.intrinsic._FusedModule): + fused_module = float_module + float_module = fused_module[0] # type: ignore[index] + + # TODO: move this to the reference quantized module + # weight_qparams or weight_qparams dict + wq_or_wq_dict = {"is_decomposed": is_decomposed} + if isinstance(float_module, torch.nn.RNNCellBase): + weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator] + weight_post_process_ih(float_module.weight_ih) + weight_post_process_hh(float_module.weight_hh) + weight_qparams_ih = get_qparam_dict(weight_post_process_ih) + weight_qparams_hh = get_qparam_dict(weight_post_process_hh) + wq_or_wq_dict.update({ + "weight_ih": weight_qparams_ih, + "weight_hh": weight_qparams_hh, + }) + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): + # format for wq_or_wq_dict (flattened attributes): + # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} + for wn in float_module._flat_weights_names: + if hasattr(float_module, wn) and wn.startswith("weight"): + weight = getattr(float_module, wn) + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + if weight_post_process.dtype == torch.qint8: # type: ignore[union-attr] + weight_post_process(weight) # type: ignore[operator, misc] + wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process) + else: + # weight_post_process is None means the original module is not a QAT module + # we need to get weight_post_process from qconfig in this case + is_ptq = weight_post_process is None + if is_ptq: + weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] + device = assert_and_get_unique_device(float_module) + if device: + weight_post_process.to(device) + + # Call weight observer/fake_quant at least once to ensure the scales and zero points + # have the right shapes. 
Note: there are two cases where we don't have to do this: + # + # (1) QAT: The model's forward method already calls the weight observer/fake_quant, + # and this typically happens during training, so we don't need to do it here. + # + # (2) Non-reference (lowered) case: The quantized module's from_float method already + # calls the weight observer/fake_quant, so we don't have to do it here. + # + # Currently we ignore both cases and call the weight observer/fake_quant here + # regardless, which is technically incorrect. For (1), this is mainly to preserve BC + # in test code, which may not always train before convert. In the future, we should + # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941. + # + # For PT2, however, we don't need to preserve BC here, so we can skip this hack + # for QAT. We identify this case as (is_decomposed + is_reference + is_qat). + # Note that we still need it for PTQ in the PT2 flow since the model's forward + # method doesn't call the weight observer. + is_qat = not is_ptq + if not (is_decomposed and is_reference and is_qat): + weight_post_process(float_module.weight) # type: ignore[operator] + + wq_or_wq_dict.update(get_qparam_dict(weight_post_process)) + + # We use the same reference module for all modes of quantization: static, dynamic, weight_only + # root_module_to_quantized_reference_module: module mapping from root (floating point) module class + # to quantized reference module class, e.g. 
nn.Conv2d to nn.quantized._reference.Conv2d + root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config) + ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None) + assert ( + ref_qmodule_cls is not None + ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] + if fused_module is not None: + fused_module[0] = ref_qmodule # type: ignore[operator] + else: + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, ref_qmodule) + +def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None: + """ + Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows: + + Before: quantize - dequantize - custom_module + After: quantize - custom_module + \\ - dequantize + """ + # expecting the input node for a custom module node to be a Node + assert isinstance(prev_node, Node), \ + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + if prev_node.op == "call_method" and prev_node.target == "dequantize": + node.replace_input_with(prev_node, prev_node.args[0]) + # Remove the dequantize node if it doesn't have other users + if len(prev_node.users) == 0: + graph.erase_node(prev_node) + +def convert_custom_module( + node: Node, + graph: Graph, + modules: Dict[str, torch.nn.Module], + custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]], + statically_quantized_custom_module_nodes: Set[Node]) -> None: + """ Converts an observed custom module to a quantized custom module based on + `custom_module_class_mapping` + For static quantization, we'll also remove the previous `dequantize` node and + attach the observer node for output to the module, the observer for the node + will be converted to a 
dequantize node instead of quantize-dequantize pairs + later in the graph. In the end we would have a quantized custom module that + has the same interface as a default quantized module in nn.quantized namespace, + i.e. quantized input and quantized output. + + Args: + - node: The call_module node of the observed standalone module + - graph: The graph containing the node + - modules: named_module of original model + - custom_module_class_mapping: mapping from observed custom module class to + quantized custom module class, used to swap custom modules + - statically_quantized_custom_module_nodes: we'll add the custom module node + if we find it is statically quantized, this will be used later when converting + observers to quant/dequant node pairs, if the observed node is a statically + quantized custom module nodes, we'll convert the observer to a dequantize node, + this is to keep the interface the same as the default quantized module. + TODO: maybe we want to redesign this part to align with reference model design + as well, but there has been some discussions around the interface, so we can do + it later. 
+ """ + observed_custom_module = modules[str(node.target)] + maybe_obs = _maybe_get_observer_for_node(node, modules) + qconfig = observed_custom_module.qconfig + if activation_is_statically_quantized(qconfig): + statically_quantized_custom_module_nodes.add(node) + if _is_custom_module_lstm(node, modules): + # The inputs are tuples in the form (input, (hidden0, hidden1)) + # Ensure all three input nodes are quantized + assert ( + len(node.args) == 2 and + isinstance(node.args[1], tuple) and + len(node.args[1]) == 2 + ) + (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] + assert isinstance(inputs, Node) + assert isinstance(hidden0, Node) + assert isinstance(hidden1, Node) + _remove_previous_dequantize_in_custom_module(node, inputs, graph) + _remove_previous_dequantize_in_custom_module(node, hidden0, graph) + _remove_previous_dequantize_in_custom_module(node, hidden1, graph) + elif _is_custom_module_mha(node, modules): + # Inputs are in the form (query, key, value) + # TODO: This is the first step in enabling the full fx custom module + # quantization path for MultiheadAttention, and only covers the inputs + # to the module. 
+ # Additional handling is yet to be implemented for the outputs, similar + # to LSTM custom module + assert len(node.args) == 3 + query, key, value = node.args + assert isinstance(query, Node) + assert isinstance(key, Node) + assert isinstance(value, Node) + _remove_previous_dequantize_in_custom_module(node, query, graph) + _remove_previous_dequantize_in_custom_module(node, key, graph) + _remove_previous_dequantize_in_custom_module(node, value, graph) + else: + # remove the previous dequant node to ensure the inputs are quantized + arg = node.args[0] + assert isinstance(arg, Node) + _remove_previous_dequantize_in_custom_module(node, arg, graph) + # absorb the following observer into the module conversion + activation_post_process = _maybe_get_observer_for_node(node, modules) + assert activation_post_process is not None + observed_custom_module.activation_post_process = activation_post_process + + # swap the observed custom module to quantized custom module + quantized_custom_module_class = get_swapped_custom_module_class( + observed_custom_module, custom_module_class_mapping, qconfig) + quantized_custom_module = \ + quantized_custom_module_class.from_observed(observed_custom_module) + parent_name, name = _parent_name(node.target) + setattr(modules[parent_name], name, quantized_custom_module) + +def convert( + model: GraphModule, is_reference: bool = False, + convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None, + is_standalone_module: bool = False, + _remove_qconfig_flag: bool = True, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, + is_decomposed: bool = False) -> GraphModule: + """ + We will convert an observed model (a module with observer calls) to a reference + quantized model, the rule is simple: + 1. 
for each observer module call in the graph, we'll convert it to calls to + quantize and dequantize functions based on the observer instance + 2. for weighted operations like linear/conv, we need to convert them to reference + quantized module, this requires us to know whether the dtype configured for the + weight is supported in the backend, this is done in prepare step and the result + is stored in observed_node_names, we can decide whether we need to swap the + module based on this set + + Args: + * `is_standalone_module`: when this flag is True, it means we are quantizing + a submodule that is not inlined in parent module, and will be quantized + separately as one unit. + + * `is_decomposed`: a boolean flag to indicate whether we want to use the + quantize operator for decomposed quantized tensor + (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone + quantized tensor (torch.quantize_per_tensor) + + Returns: + a quantized standalone module, whether input/output is quantized is + specified by prepare_custom_config, with + input_quantized_idxs, output_quantized_idxs, please + see docs for :func:`~torch.ao.quantization.prepare_fx` for details + """ + if convert_custom_config is None: + convert_custom_config = ConvertCustomConfig() + + if isinstance(convert_custom_config, Dict): + warnings.warn( + "Passing a convert_custom_config_dict to convert is deprecated and will not be supported " + "in a future version. Please pass in a ConvertCustomConfig instead.") + convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config) + + if isinstance(qconfig_mapping, Dict): + warnings.warn( + "Passing a QConfig dictionary to convert is deprecated and will not be supported " + "in a future version. 
Please pass in a QConfigMapping instead.") + qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None + qconfig_mapping = copy.deepcopy(qconfig_mapping) + assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) + + if isinstance(backend_config, Dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.") + backend_config = BackendConfig.from_dict(backend_config) + + if backend_config is None: + backend_config = get_native_backend_config() + + assert _is_observed_module(model), \ + 'incoming model must be produced by prepare_fx' + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + node_name_to_scope: Dict[str, Tuple[str, type]] = observed_graph_module_attrs.node_name_to_scope + prepare_custom_config: PrepareCustomConfig = observed_graph_module_attrs.prepare_custom_config + observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names + node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig # type: ignore[assignment] + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + # We use remove_duplicate=False here because torch.cat uses + # the same activation_post_process module instance but different names + modules = dict(model.named_modules(remove_duplicate=False)) + + # TODO refactor this code once we update the prepare logic to have additional information on + # which graph nodes have been observed and share that with convert to decide which observers to ignore. 
+ if qconfig_mapping: + prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping # type: ignore[assignment] + modules_copy = copy.deepcopy(modules) + + if observed_graph_module_attrs.is_qat: + _update_qconfig_for_qat(qconfig_mapping, backend_config) + _update_qconfig_for_fusion(model, qconfig_mapping) + + _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type] + convert_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope) + # check the convert_node_name_to_qconfig generated and ensure that + # all the values either match what was set in prepare node_name_to_qconfig + # or are set to None in the convert_node_name_to_qconfig. + for k, v in node_name_to_qconfig.items(): + assert k in convert_node_name_to_qconfig, f'Expected key {k} in convert node_name_to_qconfig' + if convert_node_name_to_qconfig[k] is not None: + assert qconfig_equals(v, convert_node_name_to_qconfig[k]), \ + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " \ + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + node_name_to_qconfig = convert_node_name_to_qconfig + + custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping) + custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping + + if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None: + # If we want to do equalization then do the following: + # Calculate the equalization scale, update the observers with the scaled + # inputs, and scale the weight + weight_eq_obs_dict = update_obs_for_equalization(model, modules) + convert_eq_obs(model, modules, weight_eq_obs_dict) + + # always run weight observers in the top level forward method + # for dynamic quant ops or weight only quant ops + _run_weight_observers(model, backend_config) + + graph_inputs: List[str] = 
[] + for node in model.graph.nodes: + if node.op == 'placeholder': + graph_inputs.append(node.name) + + # additional state to override inputs to be quantized, if specified + # by the user + placeholder_node_seen_cnt = 0 + input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes + + root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config) + # convert tuples so that it can work with isinstance(module, tuple_of_classes) + root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) + qat_module_classes = get_qat_module_classes(backend_config) + fused_module_classes = get_fused_module_classes(backend_config) + statically_quantized_custom_module_nodes: Set[Node] = set() + + for node in list(model.graph.nodes): + if node.op == 'placeholder': + cur_placeholder_node_idx = placeholder_node_seen_cnt + placeholder_node_seen_cnt += 1 + if cur_placeholder_node_idx in input_quantized_idxs: + # Inputs are assumed to be quantized if the user specified the + # input_quantized_idxs override. + # we need to dequantize the inputs since all operators took + # floating point inputs in reference quantized models + _insert_dequantize_node(node, model.graph) + elif node.op == "output": + # If the argument is empty we don't need to do anything + if len(output_quantized_idxs) == 0: + continue + # Result are kept quantized if the user specified the + # output_quantized_idxs override. 
+ # Remove the dequantize operator for the node in the end if any + return_node = node + output = node.args[0] + # outputs can be Node, list, tuple, dict, other cases are not supported yet + if isinstance(output, (list, tuple)): + for idx in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output[idx], return_node, model.graph) + elif isinstance(output, (Node, dict)): + # we treat dict as a single argument currently, but it can be extended + # to support {"key": dtype} after we change output_quantized_idxs to + # dict + if 0 in output_quantized_idxs: + _maybe_recursive_remove_dequantize(output, return_node, model.graph) + else: + warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}") + elif node.op == "call_module": + mod = _get_module(node, modules) + assert mod is not None + if _is_activation_post_process(mod): + observed_node = node.args[0] + if observed_node in statically_quantized_custom_module_nodes: + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + else: + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, node, modules, node_name_to_scope, + node_name_to_qconfig) + else: + _replace_observer_with_quantize_dequantize_node( + model, node, modules, node_name_to_scope, + node_name_to_qconfig) + elif isinstance(mod, DeQuantStub): + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + elif _is_observed_standalone_module(mod): + convert_standalone_module( + node, modules, model, is_reference, backend_config) + # below this point `type_before_parametrizations` is used + # instead of `type` to handle situations with fx quant + sparsity + elif type_before_parametrizations(mod) in set( + root_module_classes).union(qat_module_classes).union(fused_module_classes): + # extra check for fused module classes to make sure they are fused module classes + # of target modules + if type_before_parametrizations(mod) in fused_module_classes and \ + 
type_before_parametrizations(mod[0]) not in root_module_classes: # type: ignore[index] + continue + convert_weighted_module( + node, modules, observed_node_names, node_name_to_qconfig, backend_config, + is_decomposed, is_reference) + elif type_before_parametrizations(mod) in custom_module_classes: + convert_custom_module( + node, model.graph, modules, custom_module_class_mapping, + statically_quantized_custom_module_nodes) + + # remove deadcode after converting observers to quant/dequant ops + model.graph.eliminate_dead_code() + model = GraphModule(model, model.graph) + + # TODO: maybe move this to quantize_fx.py + if not is_reference: + model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope) + + # TODO: this looks hacky, we want to check why we need this and see if we can + # remove this + # removes qconfig and activation_post_process modules + if _remove_qconfig_flag: + _remove_qconfig(model) + model.delete_all_unused_submodules() + model.meta.pop("_observed_graph_module_attrs", None) + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf287db8c5245453afc795565f130ed64080674d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/match_utils.py @@ -0,0 +1,237 @@ +import sys +import torch +from torch.fx.graph import ( + Graph, + Node, +) +from torch.ao.quantization.utils import Pattern +from .quantize_handler import ( + QuantizeHandler, +) +from ..qconfig import ( + QConfigAny, +) +from ..utils import ( + MatchAllNode +) +from .graph_module import ( + _is_observed_standalone_module, +) +from torch.nn.utils.parametrize import type_before_parametrizations +from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable + + 
__all__: List[str] = []

# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                QConfigAny]

# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
    """ Matches a node in fx against a pattern

    ``pattern`` is either a single matcher (an ``nn.Module`` subclass, a
    callable, a ``call_method`` name string, or ``MatchAllNode``) or a tuple
    ``(op_matcher, *arg_patterns)`` whose tail is matched recursively against
    ``node.args``. Returns True when ``node`` (and, for tuple patterns, its
    arguments) matches.
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            # getattr patterns carry the attribute name, not arg sub-patterns,
            # so there is nothing to recurse into.
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    # MatchAllNode is a wildcard: matches anything unconditionally.
    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
        return True

    # A non-Node value (e.g. a constant arg) matches only if it equals the pattern.
    if node == pattern:
        return True

    # Fusion requires the intermediate node to have a bounded number of users
    # (max_uses=1 for argument positions below): otherwise it can't be fused away.
    if not isinstance(node, Node) or len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        # compare against the pre-parametrization type so parametrized modules
        # still match their original class
        if not type_before_parametrizations(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        elif node.target is getattr:
            # for getattr, the attribute name (pattern[1]) must also agree
            if node.args[1] != pattern[1]:
                return False
    elif isinstance(self_match, str):
        # string matcher means a method call: node.target is the method name
        if node.op != 'call_method' or node.target != self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    # recurse into arguments; each argument node may have at most one user
    return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))

def _find_matches(
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        patterns: Dict[Pattern, QuantizeHandler],
        root_node_getter_mapping: Dict[Pattern, Callable],
        standalone_module_names: Optional[List[str]] = None,
        standalone_module_classes: Optional[List[Type]] = None,
        custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
        for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
        uninitialized QuantizeHandler subclass.

    Outputs a map of
      node_name ->
        (node, matched_values, matched_pattern, QuantizeHandler instance,
         qconfig)

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <QuantizeHandler instance>, QConfig(...)),
      ...
    }
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, _MatchResult] = {}
    all_matched : Set[str] = set()

    # Register every Node contained (at any nesting depth) in node_pattern in
    # match_map, all sharing the same (last_node, pattern, handler) value.
    def _recursive_record_node_in_match_map(
            last_node,
            match_map,
            node_pattern,
            matched_node_pattern,
            pattern,
            match_value):
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node, matched_node_pattern, pattern, match_value)
        elif not isinstance(node_pattern, Iterable):
            # non-node, non-iterable leaves (e.g. constants) are not recorded
            return
        else:
            for n in node_pattern:
                _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)

    # TODO: 1. merge with fuse matcher 2. document the code
    def record_match(
            pattern,
            node,
            last_node,
            matched_node_pattern,
            match_map):
        if isinstance(pattern, tuple):
            s, *args = pattern
            is_single_arg = len(args) == 1
            current_node_pattern: List[Node] = []
            # record the op itself into the caller's pattern list
            record_match(
                s,
                node,
                last_node,
                matched_node_pattern,
                match_map)
            # NOTE(review): getattr patterns record only the op; assumes a
            # getattr pattern never reaches the current_node_pattern handling
            # below with an empty list — confirm against pattern registry.
            if pattern[0] is not getattr:
                for subpattern, arg in zip(args, node.args):
                    record_match(
                        subpattern,
                        arg,
                        node,
                        current_node_pattern,
                        match_map)
            if len(current_node_pattern) > 1:
                # current_node_pattern is the node pattern we get from matching
                # the subpattern with arguments of the node
                # we use is_single_arg to recover the original structure of the pattern
                # if the original pattern has a single argument, we will have
                # (original_op, (original_arg, ...))
                # otherwise, we'll have a list of arguments
                # (original_op, arg0, arg1, arg2, ...)
                if is_single_arg:
                    matched_node_pattern.append(tuple(current_node_pattern))
                else:
                    matched_node_pattern.extend(list(current_node_pattern))
            else:
                matched_node_pattern.append(current_node_pattern[0])
        else:
            matched_node_pattern.append(node)

    # traverse from the outputs backwards so fused patterns are anchored on
    # their last node (see the ordering note above _is_match)
    for node in reversed(graph.nodes):
        if node.name not in match_map and node.name not in all_matched:
            for pattern, quantize_handler_cls in patterns.items():
                root_node_getter = root_node_getter_mapping.get(pattern, None)
                if _is_match(modules, node, pattern) and node.name not in match_map:
                    matched_node_pattern: List[Node] = []
                    record_match(
                        pattern,
                        node,
                        node,
                        matched_node_pattern,
                        match_map)
                    quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                        matched_node_pattern,
                        modules,
                        root_node_getter)
                    last_node = node
                    # record the match for all nodes in the pattern
                    _recursive_record_node_in_match_map(
                        last_node,
                        match_map,
                        # we need to record all nodes in the matched pattern in the match_map
                        matched_node_pattern,
                        # this is a part of the value corresponding to the node
                        matched_node_pattern,
                        pattern,
                        quantize_handler)
                    # first matching pattern wins for this node
                    break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
           type(modules[node.target]) in custom_module_classes:
            match_map[node.name] = (
                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match
    for node in graph.nodes:
        if node.op == 'call_module' and \
           (is_standalone_module(node.target, modules) or
                _is_observed_standalone_module(modules[node.target])):
            # add node to matched nodes
            match_map[node.name] = (
                node, node, None,
                QuantizeHandler(node, modules, is_standalone_module=True))

    return match_map

# --- torch/ao/quantization/fx/tracer.py ---
import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.proxy import Scope
from torch.ao.nn.intrinsic import _FusedModule
from typing import List, Callable

__all__ = [
    "QuantizationTracer",
]

class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
    """Thin wrapper over fx's ScopeContextManager that builds the target Scope
    from a module instance and its qualified path."""

    def __init__(
        self,
        scope: Scope,
        current_module: torch.nn.Module,
        current_module_path: str
    ):
        super().__init__(scope, Scope(current_module_path, type(current_module)))
class QuantizationTracer(Tracer):
    """Symbolic tracer used by FX graph mode quantization.

    Behaves like :class:`torch.fx.Tracer` but treats additional modules as
    opaque leaves: anything from ``torch.nn``/``torch.ao.nn`` (except
    containers like ``nn.Sequential``), fused modules, and modules the caller
    explicitly skipped by qualified name or by type.
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialized the module_type of top level module to None
        # we are assuming people won't configure the model with the type of top level
        # module here, since people can use "" for global config
        # We can change this if there is a use case that configures
        # qconfig using top level module type
        self.scope = Scope("", None)
        # keep stack traces so observers can be attributed back to source lines
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True when ``m`` should not be traced into."""
        # fused modules are always opaque leaves
        if isinstance(m, _FusedModule):
            return True
        # modules explicitly skipped by qualified name or by exact type
        if module_qualified_name in self.skipped_module_names:
            return True
        if type(m) in self.skipped_module_classes:
            return True
        # standard torch.nn / torch.ao.nn modules are leaves, except containers
        # such as nn.Sequential which we trace through
        comes_from_torch_nn = m.__module__.startswith("torch.nn") or m.__module__.startswith(
            "torch.ao.nn"
        )
        return comes_from_torch_nn and not isinstance(m, torch.nn.Sequential)
def clean_worker(*args, **kwargs):
    """Run ``multiprocessing.pool.worker`` and force a GC pass on exit."""
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    # Regular multiprocessing workers don't fully clean up after themselves,
    # so we have to explicitly trigger garbage collection to make sure that all
    # destructors are called...
    gc.collect()


class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        # Override the stdlib pipes with SimpleQueue (from .queue), which
        # shares tensor storage across processes instead of pickling it.
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        # _quick_put/_quick_get are the fast-path callables the stdlib Pool
        # machinery expects to find after _setup_queues().
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use after
        reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            # _wrap_exception only exists on some CPython versions; append it
            # conditionally so the worker signature stays satisfied either way.
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            w = self.Process(target=clean_worker, args=args)
            self._pool.append(w)
            # mirror the naming the stdlib Pool uses for its workers
            w.name = w.name.replace("Process", "PoolWorker")
            w.daemon = True
            w.start()
            util.debug("added worker")
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12654abbcb81873181090a810d13a102f855d022 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__pycache__/grad.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60e7af8745d05fed4db4ce9bc00c58e1b3195b71 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/__pycache__/_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6662eb58f361f1d650bb5f217d7d72571d6652a1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/attention/_utils.py @@ -0,0 +1,57 @@ +"""Defines utilities for interacting with scaled_dot_product_attention""" +import math +from typing import List, Optional + +import torch + +__all__: List[str] = [] + + +def _input_requires_grad(*tensors: torch.Tensor) -> bool: + """Returns True if any of the tensors requires grad""" + return any(t.requires_grad for t in tensors) + + +def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor: + """Handles the unpad of the last dimension""" + if inpt_tensor.size(-1) != og_size: + return inpt_tensor[..., :og_size] + return inpt_tensor + + +def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float: + """ + For FlashAttention we pad the head 
dimension to be a multiple of 8 so we need to scale the output + by the original head size and not the padded. + """ + if scale is not None: + return scale + return 1.0 / math.sqrt(head_dim_size) + + +def _validate_sdpa_input( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + f"Expected query, key, and value to have the same dtype, " + f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, " + f"and value.dtype: {value.dtype} instead." + ) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to have the same device type, " + f"but got query.device: {query.device}, key.device: {key.device}, " + f"and value.device: {value.device} instead." + ) + if query.dim() < 2 or key.dim() < 2 or value.dim() < 2: + raise ValueError( + f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: " + f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead." 
import torch
import torch.distributed as dist

from torch.autograd.function import Function

class SyncBatchNorm(Function):
    """Autograd function implementing synchronized batch norm: per-process
    statistics are exchanged with collectives so every rank normalizes with
    global mean/invstd."""

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        # NB: `self` is the autograd context (ctx), kept under its original name.
        if not (
            input.is_contiguous(memory_format=torch.channels_last) or
            input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        # a single value per channel with no peers gives degenerate statistics
        if size == 1 and world_size < 2:
            raise ValueError(f'Expected more than 1 value per channel when training, got input size {size}')

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing global mean
            # & invstd, but they still needs to participate the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use allgather instead of allreduce because count could be different across
        # ranks, simple all reduce op can not give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        # The Gloo backend does not support `all_gather_into_tensor`.
        if process_group._get_backend_name() != "gloo":
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        counts = count_all.view(-1)
        # counts must match running_mean's dtype for the fused kernel below
        if running_mean is not None and counts.dtype != running_mean.dtype:
            counts = counts.to(running_mean.dtype)
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            counts,
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        # Mirrors forward: reduce per-rank partial gradients across the group
        # before computing the element-wise input gradient.
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last) or
            grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                if weight is not None and weight.dtype != mean.dtype:
                    weight = weight.to(mean.dtype)
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )
            # synchronizing of grad_weight / grad_bias is not needed as distributed
            # training would handle all reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        # one gradient per forward argument (the six non-tensor args get None)
        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
class CrossMapLRN2d(Function):
    """Autograd function for cross-channel local response normalization on 4D
    (N, C, H, W) input, computed with an incremental sliding-window sum over
    channels. NOTE: `output` doubles as a scratch buffer for the squared input
    in forward, so statement order here is load-bearing."""

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        if input.dim() != 4:
            raise ValueError(f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead.")

        # lazily allocate the scale buffer on the same type as input
        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = min(pre_pad, channels)

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        # scale = k + alpha/size * window_sum(x^2)
        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        # output = x * scale^(-beta); overwrites the squared-input scratch data
        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        # padded buffer so the channel sliding window never indexes out of range
        paddded_ratio = input.new(channels + ctx.size - 1, input_height,
                                  input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        # start from d(out)/d(in) of the scale^(-beta) factor
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        paddded_ratio.zero_()
        padded_ratio_center = paddded_ratio.narrow(0, inversePrePad,
                                                   channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                paddded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                # slide the window: add the entering channel, subtract the leaving one
                accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(paddded_ratio[c], alpha=-1)

        # gradients only w.r.t. input; size/alpha/beta/k are non-differentiable
        return grad_input, None, None, None, None

class BackwardHookFunction(torch.autograd.Function):
    """Identity function that marks non-requiring-grad args non-differentiable,
    used to anchor module backward hooks in the autograd graph."""

    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        # pass gradients straight through
        return args
class Upsample(Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.

    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.

    The algorithms available for upsampling are nearest neighbor and linear,
    bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
    respectively.

    One can either give a :attr:`scale_factor` or the target output :attr:`size` to
    calculate the output size. (You cannot give both, as it is ambiguous)

    Args:
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
            output spatial sizes
        scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
            multiplier for spatial size. Has to match input size if it is a tuple.
        mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
            ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
            Default: ``'nearest'``
        align_corners (bool, optional): if ``True``, the corner pixels of the input
            and output tensors are aligned, and thus preserving the values at
            those pixels. This only has effect when :attr:`mode` is
            ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
            Default: ``False``
        recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
            interpolation calculation. If `recompute_scale_factor` is ``True``, then
            `scale_factor` must be passed in and `scale_factor` is used to compute the
            output `size`. The computed output `size` will be used to infer new scales for
            the interpolation. Note that when `scale_factor` is floating-point, it may differ
            from the recomputed `scale_factor` due to rounding and precision issues.
            If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
            be used directly for interpolation.

    Shape:
        - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
          or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where

        .. math::
            D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor

        .. math::
            H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor

        .. math::
            W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor

    .. warning::
        With ``align_corners = True``, the linearly interpolating modes
        (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
        align the output and input pixels, and thus the output values can depend
        on the input size. This was the default behavior for these modes up to
        version 0.3.1. Since then, the default behavior is
        ``align_corners = False``. See below for concrete examples on how this
        affects the outputs.

    .. note::
        If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> input
        tensor([[[[1., 2.],
                  [3., 4.]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='nearest')
        >>> m(input)
        tensor([[[[1., 1., 2., 2.],
                  [1., 1., 2., 2.],
                  [3., 3., 4., 4.],
                  [3., 3., 4., 4.]]]])

        >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
        >>> m = nn.Upsample(scale_factor=2, mode='bilinear')  # align_corners=False
        >>> m(input)
        tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
                  [1.5000, 1.7500, 2.2500, 2.5000],
                  [2.5000, 2.7500, 3.2500, 3.5000],
                  [3.0000, 3.2500, 3.7500, 4.0000]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        >>> m(input)
        tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
                  [1.6667, 2.0000, 2.3333, 2.6667],
                  [2.3333, 2.6667, 3.0000, 3.3333],
                  [3.0000, 3.3333, 3.6667, 4.0000]]]])

        >>> # Try scaling the same data in a larger tensor
        >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
        >>> input_3x3[:, :, :2, :2].copy_(input)
        tensor([[[[1., 2.],
                  [3., 4.]]]])
        >>> input_3x3
        tensor([[[[1., 2., 0.],
                  [3., 4., 0.],
                  [0., 0., 0.]]]])

        >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session")
        >>> m = nn.Upsample(scale_factor=2, mode='bilinear')  # align_corners=False
        >>> # Notice that values in top left corner are the same with the small input (except at boundary)
        >>> m(input_3x3)
        tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
                  [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
                  [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
                  [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
                  [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
                  [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])

        >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        >>> # Notice that values in top left corner are now changed
        >>> m(input_3x3)
        tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
                  [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
                  [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
                  [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
                  [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
                  [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
    """

    __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name', 'recompute_scale_factor']
    name: str
    size: Optional[_size_any_t]
    scale_factor: Optional[_ratio_any_t]
    mode: str
    align_corners: Optional[bool]
    recompute_scale_factor: Optional[bool]

    def __init__(self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None,
                 mode: str = 'nearest', align_corners: Optional[bool] = None,
                 recompute_scale_factor: Optional[bool] = None) -> None:
        super().__init__()
        self.name = type(self).__name__
        self.size = size
        # normalize scale_factor: tuples become tuples of floats, scalars
        # become floats; falsy scalars (None, 0) are stored as None
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        elif scale_factor:
            self.scale_factor = float(scale_factor)
        else:
            self.scale_factor = None
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor

    def forward(self, input: Tensor) -> Tensor:
        # delegate all the real work to functional.interpolate
        return F.interpolate(input, size=self.size, scale_factor=self.scale_factor,
                             mode=self.mode, align_corners=self.align_corners,
                             recompute_scale_factor=self.recompute_scale_factor)

    def __setstate__(self, state):
        # older checkpoints predate recompute_scale_factor; restore the legacy
        # default for them
        state.setdefault('recompute_scale_factor', True)
        super().__setstate__(state)

    def extra_repr(self) -> str:
        if self.scale_factor is not None:
            target = f'scale_factor={self.scale_factor!r}'
        else:
            target = f'size={self.size!r}'
        return f'{target}, mode={self.mode!r}'
class UpsamplingNearest2d(Upsample):
    r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.

    Configure the target resolution with either :attr:`size` (the output
    ``(h, w)``) or :attr:`scale_factor` — exactly one of the two.

    Args:
        size (int or Tuple[int, int], optional): output spatial sizes
        scale_factor (float or Tuple[float, float], optional): multiplier for
            spatial size.

    .. warning::
        This class is deprecated in favor of :func:`~nn.functional.interpolate`.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where
          :math:`H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor` and
          :math:`W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor`

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> m = nn.UpsamplingNearest2d(scale_factor=2)
        >>> m(input)
        tensor([[[[1., 1., 2., 2.],
                  [1., 1., 2., 2.],
                  [3., 3., 4., 4.],
                  [3., 3., 4., 4.]]]])
    """

    def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
        # thin convenience wrapper: fixes mode to 'nearest'
        super().__init__(size=size, scale_factor=scale_factor, mode='nearest')


class UpsamplingBilinear2d(Upsample):
    r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels.

    Configure the target resolution with either :attr:`size` (the output
    ``(h, w)``) or :attr:`scale_factor` — exactly one of the two.

    Args:
        size (int or Tuple[int, int], optional): output spatial sizes
        scale_factor (float or Tuple[float, float], optional): multiplier for
            spatial size.

    .. warning::
        This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
        equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` where
          :math:`H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor` and
          :math:`W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor`

    Examples::

        >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
        >>> m = nn.UpsamplingBilinear2d(scale_factor=2)
        >>> m(input)
        tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
                  [1.6667, 2.0000, 2.3333, 2.6667],
                  [2.3333, 2.6667, 3.0000, 3.3333],
                  [3.0000, 3.3333, 3.6667, 4.0000]]]])
    """

    def __init__(self, size: Optional[_size_2_t] = None, scale_factor: Optional[_ratio_2_t] = None) -> None:
        # thin convenience wrapper: fixes mode to 'bilinear' with aligned corners
        super().__init__(size=size, scale_factor=scale_factor, mode='bilinear', align_corners=True)
index 0000000000000000000000000000000000000000..ea69fba158d3bfedfe49bffef4b1664117fc3246 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/dynamic/modules/linear.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat/dynamic`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/dynamic/modules`, +while adding an import statement here. +""" +from torch.ao.nn.qat.dynamic.modules.linear import Linear diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f957543a0de1e00a1ad93566a6454a110d085150 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/qat/modules/embedding_ops.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r"""QAT Modules. + +This file is in the process of migration to `torch/ao/nn/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/qat/modules`, +while adding an import statement here. 
+""" + +__all__ = ['Embedding', 'EmbeddingBag'] + +from torch.ao.nn.qat.modules.embedding_ops import Embedding +from torch.ao.nn.qat.modules.embedding_ops import EmbeddingBag diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e172794289bca8209811a398c24f3c7ad3ea7c4e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aac5dff105bd2f25c595139a0241b090a6b87d43 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py new file mode 100644 
index 0000000000000000000000000000000000000000..a767ae060f96d0d509dbd3411d33c87ba99bb4d9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/rnn.py @@ -0,0 +1,11 @@ +# flake8: noqa: F401 +r"""Quantizable Modules. + +This file is in the process of migration to `torch/ao/nn/quantizable`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantizable/modules`, +while adding an import statement here. +""" +from torch.ao.nn.quantizable.modules.rnn import LSTM +from torch.ao.nn.quantizable.modules.rnn import LSTMCell diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c53b961e9494353094150da627341a9e950e3f35 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__init__.py @@ -0,0 +1,40 @@ +from . import dynamic # noqa: F403 +from . import functional # noqa: F403 +from . 
import modules # noqa: F403 +from .modules import * # noqa: F403 +from .modules import MaxPool2d + +__all__ = [ + 'BatchNorm2d', + 'BatchNorm3d', + 'Conv1d', + 'Conv2d', + 'Conv3d', + 'ConvTranspose1d', + 'ConvTranspose2d', + 'ConvTranspose3d', + 'DeQuantize', + 'Dropout', + 'ELU', + 'Embedding', + 'EmbeddingBag', + 'GroupNorm', + 'Hardswish', + 'InstanceNorm1d', + 'InstanceNorm2d', + 'InstanceNorm3d', + 'LayerNorm', + 'LeakyReLU', + 'Linear', + 'LSTM', + 'MultiheadAttention', + 'PReLU', + 'Quantize', + 'ReLU6', + 'Sigmoid', + 'Softmax', + # Wrapper modules + 'FloatFunctional', + 'FXFloatFunctional', + 'QFunctional', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d984d2d68f405b35447dd892ed2416f33d1ed0e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59d5983a2adfd90c3d7ae365b5a074d209155e6a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3df40eccae839b6fc312b50b3be58c214d76715 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46292875e51d6037bb5bd5276c328fa099da2855 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b08cd1bc7149c5506db3a952fff488eb06749f5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from torch.ao.nn.quantized.dynamic import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6489168085d3543cc511de2f64e6e86d690e8c8c Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d6601b91a9ba2cbb57a769a9d0c790a92ac653b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc5ef66147c5a7d3495ea62c48cb35e38cfe1a5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/conv.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. 
+""" + +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] + +from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d +from torch.ao.nn.quantized.dynamic.modules.conv import Conv2d +from torch.ao.nn.quantized.dynamic.modules.conv import Conv3d +from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose1d +from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose2d +from torch.ao.nn.quantized.dynamic.modules.conv import ConvTranspose3d diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..592384dbdb34425cc713f06511f286bee2235b73 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/linear.py @@ -0,0 +1,10 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" +from torch.ao.nn.quantized.dynamic.modules.linear import Linear diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..4abef6573bed0033e6d6f5ed4438c2475b08ab43 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/rnn.py @@ -0,0 +1,22 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. 
+ +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic/modules`, +while adding an import statement here. +""" + +__all__ = ['pack_weight_bias', 'PackedParameter', 'RNNBase', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', + 'GRUCell'] + +from torch.ao.nn.quantized.dynamic.modules.rnn import pack_weight_bias +from torch.ao.nn.quantized.dynamic.modules.rnn import PackedParameter +from torch.ao.nn.quantized.dynamic.modules.rnn import RNNBase +from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM +from torch.ao.nn.quantized.dynamic.modules.rnn import GRU +from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCellBase +from torch.ao.nn.quantized.dynamic.modules.rnn import RNNCell +from torch.ao.nn.quantized.dynamic.modules.rnn import LSTMCell +from torch.ao.nn.quantized.dynamic.modules.rnn import GRUCell diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84236c390f53a05ea1e20f125e8dec34aa78ee9b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/activation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a9c8eeccdd8ec288a3197c23da692d3107bbfb9 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/conv.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d06288dd24511854cb4cff57ec75bebac4937fd2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/dropout.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3547fecaa37ac46361cb9ab13583c7436b22d730 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b81419a61c275d6700d3bd232e6ad60abe08e219 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/conv.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/conv.py new file mode 100644 index 
0000000000000000000000000000000000000000..63d9dc5d4c7de91d804131a9dcc5c744f013602a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/conv.py @@ -0,0 +1,21 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d'] + +from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding + +from torch.ao.nn.quantized.modules.conv import Conv1d +from torch.ao.nn.quantized.modules.conv import Conv2d +from torch.ao.nn.quantized.modules.conv import Conv3d + +from torch.ao.nn.quantized.modules.conv import ConvTranspose1d +from torch.ao.nn.quantized.modules.conv import ConvTranspose2d +from torch.ao.nn.quantized.modules.conv import ConvTranspose3d diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51c81a62b78f1b12ac5fe9a3a71239725b033f7c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/utils.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. 
+""" + +from torch.ao.nn.quantized.modules.utils import _ntuple_from_first +from torch.ao.nn.quantized.modules.utils import _pair_from_first +from torch.ao.nn.quantized.modules.utils import _quantize_weight +from torch.ao.nn.quantized.modules.utils import _hide_packed_params_repr +from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc61d26d24876a84ad1b78b8adee3aa73f9ed52 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py @@ -0,0 +1,32 @@ +from . import rnn +from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_ +from .weight_norm import weight_norm, remove_weight_norm +from .convert_parameters import parameters_to_vector, vector_to_parameters +from .spectral_norm import spectral_norm, remove_spectral_norm +from .fusion import fuse_conv_bn_eval, fuse_conv_bn_weights, fuse_linear_bn_eval, fuse_linear_bn_weights +from .memory_format import convert_conv2d_weight_memory_format, convert_conv3d_weight_memory_format +from . import parametrizations +from .init import skip_init +from . 
import stateless + +__all__ = [ + "clip_grad_norm", + "clip_grad_norm_", + "clip_grad_value_", + "convert_conv2d_weight_memory_format", + "convert_conv3d_weight_memory_format", + "fuse_conv_bn_eval", + "fuse_conv_bn_weights", + "fuse_linear_bn_eval", + "fuse_linear_bn_weights", + "parameters_to_vector", + "parametrizations", + "remove_spectral_norm", + "remove_weight_norm", + "rnn", + "skip_init", + "spectral_norm", + "stateless", + "vector_to_parameters", + "weight_norm", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77fad4ecde2c978d32d7deb36a50dcec6bf2c7e9 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1071121b7a25d1c5f6a12f8d1b6114bd985240d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d81120ffc51ea5ae55457a303d357f451282b709 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ee0b75f35d73955126e61eceb19d7e75ec466c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab5e713be44af900fe66e0e319458fc3c9fa962 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3019a0c71d7f0d1d0991297c679c2c5ac7a246a4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a5434c930b28c6f18e1972030a78d449c9854aa8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e179ec321fc25102ec66191dd26c8446fe461d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba810b3a9113a540aa700423baa0203509bd9fe Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c8a762b822787cbaa0cb03dfae2ba5809c5e57b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py 
new file mode 100644 index 0000000000000000000000000000000000000000..1b2a9b6e29f2f2f0157f97e8210d13751e0bcb8c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py @@ -0,0 +1,45 @@ +from typing import List, Callable +import importlib +import warnings + + +_MESSAGE_TEMPLATE = r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead." + +def lazy_deprecated_import(all: List[str], old_module: str, new_module: str) -> Callable: + r"""Import utility to lazily import deprecated packages / modules / functional. + + The old_module and new_module are also used in the deprecation warning defined + by the `_MESSAGE_TEMPLATE`. + + Args: + all: The list of the functions that are imported. Generally, the module's + __all__ list of the module. + old_module: Old module location + new_module: New module location / Migrated location + + Returns: + Callable to assign to the `__getattr__` + + Usage: + + # In the `torch/nn/quantized/functional.py` + from torch.nn.utils._deprecation_utils import lazy_deprecated_import + _MIGRATED_TO = "torch.ao.nn.quantized.functional" + __getattr__ = lazy_deprecated_import( + all=__all__, + old_module=__name__, + new_module=_MIGRATED_TO) + """ + warning_message = _MESSAGE_TEMPLATE.format( + old_location=old_module, + new_location=new_module) + + def getattr_dunder(name): + if name in all: + # We are using the "RuntimeWarning" to make sure it is not + # ignored by default. 
+ warnings.warn(warning_message, RuntimeWarning) + package = importlib.import_module(new_module) + return getattr(package, name) + raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.") + return getattr_dunder diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..102474614238efec588ea4dc69d1d568d4fc60bb --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__init__.py @@ -0,0 +1,9 @@ +from .conv_expanded_weights import ConvPerSampleGrad +from .embedding_expanded_weights import EmbeddingPerSampleGrad +from .group_norm_expanded_weights import GroupNormPerSampleGrad +from .instance_norm_expanded_weights import InstanceNormPerSampleGrad +from .layer_norm_expanded_weights import LayerNormPerSampleGrad +from .linear_expanded_weights import LinearPerSampleGrad +from .expanded_weights_impl import ExpandedWeight + +__all__ = ['ExpandedWeight'] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ea669c266f3e247a0c225e395b4cdd5b70cadcd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5c516d732016e46e35a87506e7f2782c03bd3dc Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a51ada861566edd07399a5b1c183d5f09f8904ca Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94349ce058921eb0d685da102ae34b4b475d09b6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..78c5aaa7156aab98e66e1b820810e94d2d1a937e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..c10ccb90ae92f1f57513de5c0ab7a56c26996298 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -0,0 +1,52 @@ +import torch +import torch.nn.functional as F + +from .conv_utils import conv_backward, conv_args_and_kwargs, conv_picker, conv_input_for_string_padding +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import forward_helper + +@implements_per_sample_grads(F.conv1d) +@implements_per_sample_grads(F.conv2d) +@implements_per_sample_grads(F.conv3d) +class ConvPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs) + orig_input = expanded_args[0] + was_same_padding = expanded_kwargs['padding'] == "same" + + if isinstance(expanded_kwargs['padding'], str): + # if padding is a string, we'll do the necessary padding (slowly) using F.pad + kernel_size = expanded_args[1].shape[2:] + padding, dilation = expanded_kwargs['padding'], expanded_kwargs['dilation'] + input = conv_input_for_string_padding(conv_fn, padding, expanded_args[0], dilation, kernel_size) + expanded_args = (input, expanded_args[1]) + # since we've already done the padding, don't need any 
more + expanded_kwargs['padding'] = 0 + + output = forward_helper(conv_fn, expanded_args, expanded_kwargs) + input, weight = expanded_args + batched_dim_size = conv_picker(conv_fn, 3, 4, 5) + if input.dim() != batched_dim_size: + raise RuntimeError(f"Expanded Weights only support convolution with batched input, got {conv_fn} with an" + f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}") + + ctx.conv_fn = conv_fn + + ctx.batch_size = orig_input.shape[0] + ctx.input_required_grad = orig_input.requires_grad + ctx.orig_input_shape = orig_input.shape + ctx.was_same_padding = was_same_padding + ctx.stride, ctx.padding = expanded_kwargs['stride'], expanded_kwargs['padding'] + ctx.dilation, ctx.groups = expanded_kwargs['dilation'], expanded_kwargs['groups'] + + if isinstance(weight, ExpandedWeight): + ctx.input = input + ctx.weight = weight + ctx.bias = expanded_kwargs['bias'] + + return output + + @staticmethod + def backward(ctx, grad_output): + return conv_backward(ctx.conv_fn, ctx, grad_output) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..249dbe59120434b856acb654bc6ba8bd65b926c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -0,0 +1,145 @@ +from typing import Optional + +import torch +from .expanded_weights_impl import ExpandedWeight + +def is_batch_first(expanded_args_and_kwargs): + batch_first = None + for arg in expanded_args_and_kwargs: + if not isinstance(arg, ExpandedWeight): + continue + + if not batch_first: + batch_first = arg.batch_first + elif arg.batch_first != batch_first: + raise RuntimeError("Got conflicting batch_first arguments in the same layer") 
+ return batch_first + +def standard_kwargs(kwarg_names, expanded_args): + r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs. + + Most `__torch_function__`s standardize the kwargs that they give, so this will separate + the args and kwargs they pass. Functions that don't are linear and convND. + """ + kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names):] + expanded_args_without_kwargs = expanded_args[:len(expanded_args) - len(kwarg_names)] + expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) + return expanded_args_without_kwargs, expanded_kwargs + +def forward_helper(func, expanded_args, expanded_kwargs): + r"""Compute the forward pass for a function that has expanded weight(s) passed to it. + + It will run the forward pass where all ExpandedWeights are their original + weight. It runs checks on the given arguments and detaches the outputs. + + .. note:: First argument in :attr:`expanded_args` must be the input with the batch + dimension as the first element of the shape + + .. note:: :attr:`func` must return a Tensor or tuple of Tensors + + Args: + func: The function to be called + expanded_args: Arguments to be passed to :attr:`func`. Will include arguments + that need to be unpacked because they are ExpandedWeights + expanded_kwargs: Keyword arguments to be passed to :attr:`func`. + Similar to :attr:`expanded_args`. + """ + unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args(func, expanded_args, expanded_kwargs) + return func(*unexpanded_args, **unexpanded_kwargs) + +def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): + # input must be the first argument passed + input = expanded_args[0] + if isinstance(input, ExpandedWeight): + raise RuntimeError("Expanded Weights do not support inputs that are also ExpandedWeights. 
" + f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}") + if not isinstance(input, torch.Tensor): + raise RuntimeError("Expanded Weights requires a Tensor as the first input to get the batch dimension, " + f"got {type(input).__name__} in function {func.__name__}") + if len(input.shape) == 0: + raise RuntimeError(f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}") + if input.shape[0] == 0: + raise RuntimeError("0 is not a valid batch size for Expanded Weights but got input tensor of " + f"{input} in function {func.__name__}") + for arg in expanded_args + tuple(expanded_kwargs.values()): + if not isinstance(arg, ExpandedWeight): + continue + batch_size = input.shape[0] if arg.batch_first else input.shape[1] + if (arg.allow_smaller_batches and batch_size > arg.batch_size) or \ + (not arg.allow_smaller_batches and arg.batch_size != batch_size): + raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got " + f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}") + + loss_reduction: Optional[str] = None + for arg in expanded_args + tuple(expanded_kwargs.values()): + if isinstance(arg, ExpandedWeight): + if loss_reduction is None: + loss_reduction = arg.loss_reduction + elif loss_reduction != arg.loss_reduction: + raise RuntimeError("Expected ExpandedWeights to all have the same loss_reduction argument but got one" + f"with {loss_reduction} and one with {arg.loss_reduction}") + + unexpanded_args = tuple(arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for arg in expanded_args) + unexpanded_kwargs = {name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg + for (name, arg) in expanded_kwargs.items()} + return unexpanded_args, unexpanded_kwargs + +def maybe_scale_by_batch_size(grad_sample, expanded_weight): + if expanded_weight.loss_reduction == "mean": + return grad_sample * expanded_weight.batch_size + else: + 
return grad_sample + +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): + unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) + if isinstance(maybe_expanded_weight, ExpandedWeight): + grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight) + + if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: + # this only passes the other checks if the arg allows smaller batch sizes + intermediate = torch.zeros(maybe_expanded_weight.batch_size, *grad_sample_contribution.shape[1:], + dtype=grad_sample_contribution.dtype, + device=grad_sample_contribution.device) + intermediate[:grad_sample_contribution.shape[0]] = grad_sample_contribution + grad_sample_contribution = intermediate + + if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: + unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution + else: + unpacked.grad_sample = grad_sample_contribution + +def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x): + if isinstance(maybe_expanded_weight, ExpandedWeight): + orig_weight = maybe_expanded_weight.orig_weight + return func(orig_weight) + elif isinstance(maybe_expanded_weight, torch.Tensor) and not maybe_expanded_weight.requires_grad: + return func(maybe_expanded_weight) + elif isinstance(maybe_expanded_weight, torch.Tensor): + raise RuntimeError("ExpandedWeights currently does not support a mixture of ExpandedWeight parameters " + "and normal Parameters. Please file and issue with pytorch/pytorch") + + + +def sum_over_all_but_batch_and_last_n( + tensor: torch.Tensor, n_dims: int +) -> torch.Tensor: + r""" + Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims. + + This function will ignore the first dimension and it will + not aggregate over the last n_dims dimensions. + Args: + tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``. 
+ n_dims: Number of dimensions to keep. + Example: + >>> tensor = torch.ones(1, 2, 3, 4, 5) + >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape + torch.Size([1, 4, 5]) + Returns: + A tensor of shape ``(B, ..., X[n_dims-1])`` + """ + if tensor.dim() == n_dims + 1: + return tensor + else: + dims = list(range(1, tensor.dim() - n_dims)) + return tensor.sum(dim=dims) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..fe29b1eafbe2c0be87a96f4e24d8c026b310b3d7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -0,0 +1,64 @@ +from functools import reduce +import operator +import torch +import torch.nn.functional as F +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import standard_kwargs, \ + forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor +from typing import List, Optional + +@implements_per_sample_grads(F.group_norm) +class GroupNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs) + input, num_groups = expanded_args + N = input.shape[0] + C = input.shape[1] + HxW = reduce(operator.mul, input.shape[2:], 1) + weight, bias, eps = expanded_kwargs['weight'], expanded_kwargs['bias'], expanded_kwargs['eps'] + output, mean, rstd = forward_helper(torch.native_group_norm, (input, weight, bias, N, C, HxW, num_groups, eps), {}) + ctx.input, ctx.num_groups = input, num_groups + ctx.weight, ctx.eps = weight, eps + ctx.mean, ctx.rstd = mean, rstd + if 
isinstance(bias, ExpandedWeight): + ctx.bias = bias + if input.requires_grad and isinstance(weight, ExpandedWeight): + ctx.weight = weight + return output + + @staticmethod + def backward(ctx, grad_output): + input, num_groups = ctx.input, ctx.num_groups + weight, bias, eps = ctx.weight, ctx.bias, ctx.eps + mean, rstd = ctx.mean, ctx.rstd + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + + if input.requires_grad: + weight_c = unpack_expanded_weight_or_tensor(weight, lambda t: t.contiguous()) + input_c = input.contiguous() + grad_output_c = grad_output.contiguous() if grad_output is not None else None + N = input.shape[0] + C = input.shape[1] + HxW = 1 + for s in input.shape[2:]: + HxW *= s + bw_fn = torch.ops.aten.native_group_norm_backward + results.append(bw_fn(grad_output_c, input_c, + mean, rstd, weight_c, N, C, HxW, num_groups, (True, False, False))[0]) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists(weight, + lambda _: torch.einsum("ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output)) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists(bias, lambda _: torch.einsum("ni...->ni", grad_output)) + return tuple(results) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ead2d4c08fb03aafec2469d86c672ebe9bb222 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -0,0 +1,59 @@ + +import torch +import torch.nn.functional as F +from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads +from .expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \ + standard_kwargs, sum_over_all_but_batch_and_last_n, unpack_expanded_weight_or_tensor +from typing import List, Optional + +@implements_per_sample_grads(F.layer_norm) +class LayerNormPerSampleGrad(torch.autograd.Function): + @staticmethod + def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): + expanded_args, expanded_kwargs = standard_kwargs(kwarg_names, expanded_args_and_kwargs) + input = expanded_args[0] + normalized_shape = expanded_args[1] + if len(input.shape) <= len(normalized_shape): + raise RuntimeError("Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient" + f"computations but got that normalized shape, {normalized_shape}, matched input shape.") + output, mean, rstd = forward_helper(torch.native_layer_norm, expanded_args, expanded_kwargs) + ctx.args = expanded_args + + if input.requires_grad or isinstance(expanded_kwargs['weight'], ExpandedWeight): + ctx.weight = expanded_kwargs['weight'] + if input.requires_grad or isinstance(expanded_kwargs['bias'], ExpandedWeight): + ctx.bias = expanded_kwargs['bias'] + ctx.eps = expanded_kwargs['eps'] + ctx.mean, ctx.rstd = mean, rstd + return output + + + @staticmethod + def backward(ctx, grad_output): + + def weight_per_sample_grad(weight): + return sum_over_all_but_batch_and_last_n(F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output, weight.dim()) + + input, normalized_shape = ctx.args + mean, rstd = ctx.mean, ctx.rstd + + results: List[Optional[torch.Tensor]] = [] + results.append(None) # for kwarg names + results.append(None) # for op reference + if input.requires_grad: + weight_ = 
unpack_expanded_weight_or_tensor(ctx.weight) + bias_ = unpack_expanded_weight_or_tensor(ctx.bias) + results.append(torch.ops.aten.native_layer_norm_backward( + grad_output, input, normalized_shape, mean, rstd, weight_, bias_, (True, False, False))[0]) + else: + results.append(None) + + # weight and bias don't compute batched gradients; no other arguments are differentiable + results = results + [None] * 4 + + # set grad_sample field for weight and bias with per sample gradients + if hasattr(ctx, "weight"): + set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad) + if hasattr(ctx, "bias"): + set_grad_sample_if_exists(ctx.bias, lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim())) + return tuple(results) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..6549a6f3e2c8db1c9f46ba5f6a28d641e8871f6f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py @@ -0,0 +1,151 @@ +import warnings +import functools +from typing import Union, Iterable, List, Dict, Tuple, Optional, cast + +import torch +from torch import Tensor +from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support, _device_has_foreach_support + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + +__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_'] + +def _no_grad(func): + """ + This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions + clip_grad_norm_ and clip_grad_value_ themselves. 
+ """ + def _no_grad_wrapper(*args, **kwargs): + with torch.no_grad(): + return func(*args, **kwargs) + functools.update_wrapper(_no_grad_wrapper, func) + return _no_grad_wrapper + +@_no_grad +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float): max norm of the gradients + norm_type (float): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) 
+ first_device = grads[0].device + grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \ + = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] + + norms: List[Tensor] = [] + for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment] + if ( + (foreach is None and _has_foreach_support(device_grads, device)) + or (foreach and _device_has_foreach_support(device)) + ): + norms.extend(torch._foreach_norm(device_grads, norm_type)) + elif foreach: + raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors') + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads]) + + total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. 
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment] + if ( + (foreach is None and _has_foreach_support(device_grads, device)) + or (foreach and _device_has_foreach_support(device)) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors') + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + return total_norm + + +def clip_grad_norm( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2., + error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor: + r"""Clip the gradient norm of an iterable of parameters. + + .. warning:: + This method is now deprecated in favor of + :func:`torch.nn.utils.clip_grad_norm_`. + """ + warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor " + "of torch.nn.utils.clip_grad_norm_.", stacklevel=2) + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach) + + +@_no_grad +def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None: + r"""Clip the gradients of an iterable of parameters at specified value. + + Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + clip_value (float): maximum allowed value of the gradients. + The gradients are clipped in the range + :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]` + foreach (bool): use the faster foreach-based implementation + If ``None``, use the foreach implementation for CUDA and CPU native tensors and + silently fall back to the slow implementation for other device types. 
+ Default: ``None`` + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + clip_value = float(clip_value) + + grads = [p.grad for p in parameters if p.grad is not None] + grouped_grads = _group_tensors_by_device_and_dtype([grads]) + + for ((device, _), ([grads], _)) in grouped_grads.items(): # type: ignore[assignment] + if ( + (foreach is None and _has_foreach_support(cast(List[Tensor], grads), device=device)) + or (foreach and _device_has_foreach_support(device)) + ): + torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value) + torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value) + elif foreach: + raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors') + else: + for grad in grads: + cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/prune.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/prune.py new file mode 100644 index 0000000000000000000000000000000000000000..0375106d69e02d872372d8ae61fb163950bba848 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/prune.py @@ -0,0 +1,1379 @@ +r"""Pruning methods.""" +import numbers +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Tuple + +import torch + + +class BasePruningMethod(ABC): + r"""Abstract base class for creation of new pruning techniques. + + Provides a skeleton for customization requiring the overriding of methods + such as :meth:`compute_mask` and :meth:`apply`. + """ + + _tensor_name: str + + def __call__(self, module, inputs): + r"""Multiply the mask into original tensor and store the result. + + Multiplies the mask (stored in ``module[name + '_mask']``) + into the original tensor (stored in ``module[name + '_orig']``) + and stores the result into ``module[name]`` by using :meth:`apply_mask`. 
+ + Args: + module (nn.Module): module containing the tensor to prune + inputs: not used. + """ + setattr(module, self._tensor_name, self.apply_mask(module)) + + @abstractmethod + def compute_mask(self, t, default_mask): + r"""Compute and returns a mask for the input tensor ``t``. + + Starting from a base ``default_mask`` (which should be a mask of ones + if the tensor has not been pruned yet), generate a random mask to + apply on top of the ``default_mask`` according to the specific pruning + method recipe. + + Args: + t (torch.Tensor): tensor representing the importance scores of the + parameter to prune. + default_mask (torch.Tensor): Base mask from previous pruning + iterations, that need to be respected after the new mask is + applied. Same dims as ``t``. + + Returns: + mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t`` + """ + pass + + def apply_mask(self, module): + r"""Simply handles the multiplication between the parameter being pruned and the generated mask. + + Fetches the mask and the original tensor from the module + and returns the pruned version of the tensor. + + Args: + module (nn.Module): module containing the tensor to prune + + Returns: + pruned_tensor (torch.Tensor): pruned version of the input tensor + """ + # to carry out the multiplication, the mask needs to have been computed, + # so the pruning method must know what tensor it's operating on + assert self._tensor_name is not None, f"Module {module} has to be pruned" # this gets set in apply() + mask = getattr(module, self._tensor_name + "_mask") + orig = getattr(module, self._tensor_name + "_orig") + pruned_tensor = mask.to(dtype=orig.dtype) * orig + return pruned_tensor + + @classmethod + def apply(cls, module, name, *args, importance_scores=None, **kwargs): + r"""Add pruning on the fly and reparametrization of a tensor. 
+ + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + args: arguments passed on to a subclass of + :class:`BasePruningMethod` + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as module parameter) used to compute mask for pruning. + The values in this tensor indicate the importance of the + corresponding elements in the parameter being pruned. + If unspecified or None, the parameter will be used in its place. + kwargs: keyword arguments passed on to a subclass of a + :class:`BasePruningMethod` + """ + + def _get_composite_method(cls, module, name, *args, **kwargs): + # Check if a pruning method has already been applied to + # `module[name]`. If so, store that in `old_method`. + old_method = None + found = 0 + # there should technically be only 1 hook with hook.name == name + # assert this using `found` + hooks_to_remove = [] + for k, hook in module._forward_pre_hooks.items(): + # if it exists, take existing thing, remove hook, then + # go through normal thing + if isinstance(hook, BasePruningMethod) and hook._tensor_name == name: + old_method = hook + hooks_to_remove.append(k) + found += 1 + assert ( + found <= 1 + ), f"Avoid adding multiple pruning hooks to the\ + same tensor {name} of module {module}. Use a PruningContainer." + + for k in hooks_to_remove: + del module._forward_pre_hooks[k] + + # Apply the new pruning method, either from scratch or on top of + # the previous one. 
+ method = cls(*args, **kwargs) # new pruning + # Have the pruning method remember what tensor it's been applied to + method._tensor_name = name + + # combine `methods` with `old_method`, if `old_method` exists + if old_method is not None: # meaning that there was a hook + # if the hook is already a pruning container, just add the + # new pruning method to the container + if isinstance(old_method, PruningContainer): + old_method.add_pruning_method(method) + method = old_method # rename old_method --> method + + # if the hook is simply a single pruning method, create a + # container, add the old pruning method and the new one + elif isinstance(old_method, BasePruningMethod): + container = PruningContainer(old_method) + # Have the pruning method remember the name of its tensor + # setattr(container, '_tensor_name', name) + container.add_pruning_method(method) + method = container # rename container --> method + return method + + method = _get_composite_method(cls, module, name, *args, **kwargs) + # at this point we have no forward_pre_hooks but we could have an + # active reparametrization of the tensor if another pruning method + # had been applied (in which case `method` would be a PruningContainer + # and not a simple pruning method). + + # Pruning is to be applied to the module's tensor named `name`, + # starting from the state it is found in prior to this iteration of + # pruning. The pruning mask is calculated based on importances scores. 
+ + orig = getattr(module, name) + if importance_scores is not None: + assert ( + importance_scores.shape == orig.shape + ), f"importance_scores should have the same shape as parameter {name} of {module}" + else: + importance_scores = orig + + # If this is the first time pruning is applied, take care of moving + # the original tensor to a new parameter called name + '_orig' and + # and deleting the original parameter + if not isinstance(method, PruningContainer): + # copy `module[name]` to `module[name + '_orig']` + module.register_parameter(name + "_orig", orig) + # temporarily delete `module[name]` + del module._parameters[name] + default_mask = torch.ones_like(orig) # temp + # If this is not the first time pruning is applied, all of the above + # has been done before in a previous pruning iteration, so we're good + # to go + else: + default_mask = ( + getattr(module, name + "_mask") + .detach() + .clone(memory_format=torch.contiguous_format) + ) + + # Use try/except because if anything goes wrong with the mask + # computation etc., you'd want to roll back. + try: + # get the final mask, computed according to the specific method + mask = method.compute_mask(importance_scores, default_mask=default_mask) + # reparameterize by saving mask to `module[name + '_mask']`... + module.register_buffer(name + "_mask", mask) + # ... and the new pruned tensor to `module[name]` + setattr(module, name, method.apply_mask(module)) + # associate the pruning method to the module via a hook to + # compute the function before every forward() (compile by run) + module.register_forward_pre_hook(method) + + except Exception as e: + if not isinstance(method, PruningContainer): + orig = getattr(module, name + "_orig") + module.register_parameter(name, orig) + del module._parameters[name + "_orig"] + raise e + + return method + + def prune(self, t, default_mask=None, importance_scores=None): + r"""Compute and returns a pruned version of input tensor ``t``. 
+ + According to the pruning rule specified in :meth:`compute_mask`. + + Args: + t (torch.Tensor): tensor to prune (of same dimensions as + ``default_mask``). + importance_scores (torch.Tensor): tensor of importance scores (of + same shape as ``t``) used to compute mask for pruning ``t``. + The values in this tensor indicate the importance of the + corresponding elements in the ``t`` that is being pruned. + If unspecified or None, the tensor ``t`` will be used in its place. + default_mask (torch.Tensor, optional): mask from previous pruning + iteration, if any. To be considered when determining what + portion of the tensor that pruning should act on. If None, + default to a mask of ones. + + Returns: + pruned version of tensor ``t``. + """ + if importance_scores is not None: + assert ( + importance_scores.shape == t.shape + ), "importance_scores should have the same shape as tensor t" + else: + importance_scores = t + default_mask = default_mask if default_mask is not None else torch.ones_like(t) + return t * self.compute_mask(importance_scores, default_mask=default_mask) + + def remove(self, module): + r"""Remove the pruning reparameterization from a module. + + The pruned parameter named ``name`` remains permanently pruned, + and the parameter named ``name+'_orig'`` is removed from the parameter list. + Similarly, the buffer named ``name+'_mask'`` is removed from the buffers. + + Note: + Pruning itself is NOT undone or reversed! 
+ """ + # before removing pruning from a tensor, it has to have been applied + assert ( + self._tensor_name is not None + ), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply() + + # to update module[name] to latest trained weights + weight = self.apply_mask(module) # masked weights + + # delete and reset + if hasattr(module, self._tensor_name): + delattr(module, self._tensor_name) + orig = module._parameters[self._tensor_name + "_orig"] + orig.data = weight.data + del module._parameters[self._tensor_name + "_orig"] + del module._buffers[self._tensor_name + "_mask"] + setattr(module, self._tensor_name, orig) + + +class PruningContainer(BasePruningMethod): + """Container holding a sequence of pruning methods for iterative pruning. + + Keeps track of the order in which pruning methods are applied and handles + combining successive pruning calls. + + Accepts as argument an instance of a BasePruningMethod or an iterable of + them. + """ + + def __init__(self, *args): + self._pruning_methods: Tuple[BasePruningMethod, ...] = tuple() + if not isinstance(args, Iterable): # only 1 item + self._tensor_name = args._tensor_name + self.add_pruning_method(args) + elif len(args) == 1: # only 1 item in a tuple + self._tensor_name = args[0]._tensor_name + self.add_pruning_method(args[0]) + else: # manual construction from list or other iterable (or no args) + for method in args: + self.add_pruning_method(method) + + def add_pruning_method(self, method): + r"""Add a child pruning ``method`` to the container. + + Args: + method (subclass of BasePruningMethod): child pruning method + to be added to the container. 
+ """ + # check that we're adding a pruning method to the container + if not isinstance(method, BasePruningMethod) and method is not None: + raise TypeError( + f"{type(method)} is not a BasePruningMethod subclass" + ) + elif method is not None and self._tensor_name != method._tensor_name: + raise ValueError( + "Can only add pruning methods acting on " + f"the parameter named '{self._tensor_name}' to PruningContainer {self}." + + f" Found '{method._tensor_name}'" + ) + # if all checks passed, add to _pruning_methods tuple + self._pruning_methods += (method,) # type: ignore[operator] + + def __len__(self): + return len(self._pruning_methods) + + def __iter__(self): + return iter(self._pruning_methods) + + def __getitem__(self, idx): + return self._pruning_methods[idx] + + def compute_mask(self, t, default_mask): + r"""Apply the latest ``method`` by computing the new partial masks and returning its combination with the ``default_mask``. + + The new partial mask should be computed on the entries or channels + that were not zeroed out by the ``default_mask``. + Which portions of the tensor ``t`` the new mask will be calculated from + depends on the ``PRUNING_TYPE`` (handled by the type handler): + + * for 'unstructured', the mask will be computed from the raveled + list of nonmasked entries; + + * for 'structured', the mask will be computed from the nonmasked + channels in the tensor; + + * for 'global', the mask will be computed across all entries. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as ``default_mask``). + default_mask (torch.Tensor): mask from previous pruning iteration. + + Returns: + mask (torch.Tensor): new mask that combines the effects + of the ``default_mask`` and the new mask from the current + pruning ``method`` (of same dimensions as ``default_mask`` and + ``t``). + """ + + def _combine_masks(method, t, mask): + r"""Combine the masks from all pruning methods and returns a new mask. 
+ + Args: + method (a BasePruningMethod subclass): pruning method + currently being applied. + t (torch.Tensor): tensor representing the parameter to prune + (of same dimensions as mask). + mask (torch.Tensor): mask from previous pruning iteration + + Returns: + new_mask (torch.Tensor): new mask that combines the effects + of the old mask and the new mask from the current + pruning method (of same dimensions as mask and t). + """ + new_mask = mask # start off from existing mask + new_mask = new_mask.to(dtype=t.dtype) + + # compute a slice of t onto which the new pruning method will operate + if method.PRUNING_TYPE == "unstructured": + # prune entries of t where the mask is 1 + slc = mask == 1 + + # for struct pruning, exclude channels that have already been + # entirely pruned + elif method.PRUNING_TYPE == "structured": + if not hasattr(method, "dim"): + raise AttributeError( + "Pruning methods of PRUNING_TYPE " + '"structured" need to have the attribute `dim` defined.' + ) + + # find the channels to keep by removing the ones that have been + # zeroed out already (i.e. where sum(entries) == 0) + n_dims = t.dim() # "is this a 2D tensor? 3D? ..." + dim = method.dim + # convert negative indexing + if dim < 0: + dim = n_dims + dim + # if dim is still negative after subtracting it from n_dims + if dim < 0: + raise IndexError( + f"Index is out of bounds for tensor with dimensions {n_dims}" + ) + # find channels along dim = dim that aren't already tots 0ed out + keep_channel = mask.sum(dim=[d for d in range(n_dims) if d != dim]) != 0 + # create slice to identify what to prune + slc = [slice(None)] * n_dims + slc[dim] = keep_channel + + elif method.PRUNING_TYPE == "global": + n_dims = len(t.shape) # "is this a 2D tensor? 3D? ..." 
+ slc = [slice(None)] * n_dims + + else: + raise ValueError( + f"Unrecognized PRUNING_TYPE {method.PRUNING_TYPE}" + ) + + # compute the new mask on the unpruned slice of the tensor t + partial_mask = method.compute_mask(t[slc], default_mask=mask[slc]) + new_mask[slc] = partial_mask.to(dtype=new_mask.dtype) + + return new_mask + + method = self._pruning_methods[-1] + mask = _combine_masks(method, t, default_mask) + return mask + + +class Identity(BasePruningMethod): + r"""Utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones.""" + + PRUNING_TYPE = "unstructured" + + def compute_mask(self, t, default_mask): + mask = default_mask + return mask + + @classmethod + def apply(cls, module, name): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + """ + return super().apply(module, name) + + +class RandomUnstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor at random. + + Args: + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. 
+ """ + + PRUNING_TYPE = "unstructured" + + def __init__(self, amount): + # Check range of validity of pruning amount + _validate_pruning_amount_init(amount) + self.amount = amount + + def compute_mask(self, t, default_mask): + # Check that the amount of units to prune is not > than the number of + # parameters in t + tensor_size = t.nelement() + # Compute number of units to prune: amount if int, + # else amount * tensor_size + nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size) + # This should raise an error if the number of units to prune is larger + # than the number of units in the tensor + _validate_pruning_amount(nparams_toprune, tensor_size) + + mask = default_mask.clone(memory_format=torch.contiguous_format) + + if nparams_toprune != 0: # k=0 not supported by torch.kthvalue + prob = torch.rand_like(t) + topk = torch.topk(prob.view(-1), k=nparams_toprune) + mask.view(-1)[topk.indices] = 0 + + return mask + + @classmethod + def apply(cls, module, name, amount): + r"""Add pruning on the fly and reparametrization of a tensor. + + Adds the forward pre-hook that enables pruning on the fly and + the reparametrization of a tensor in terms of the original tensor + and the pruning mask. + + Args: + module (nn.Module): module containing the tensor to prune + name (str): parameter name within ``module`` on which pruning + will act. + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + """ + return super().apply(module, name, amount=amount) + + +class L1Unstructured(BasePruningMethod): + r"""Prune (currently unpruned) units in a tensor by zeroing out the ones with the lowest L1-norm. + + Args: + amount (int or float): quantity of parameters to prune. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. 
            If ``int``, it represents the
            absolute number of parameters to prune.
    """

    PRUNING_TYPE = "unstructured"

    def __init__(self, amount):
        # Check range of validity of pruning amount
        _validate_pruning_amount_init(amount)
        self.amount = amount

    def compute_mask(self, t, default_mask):
        """Zero out the ``amount`` entries of ``t`` with smallest absolute value."""
        # Check that the amount of units to prune is not > than the number of
        # parameters in t
        tensor_size = t.nelement()
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        mask = default_mask.clone(memory_format=torch.contiguous_format)

        if nparams_toprune != 0:  # k=0 not supported by torch.kthvalue
            # largest=True --> top k; largest=False --> bottom k
            # Prune the smallest k
            topk = torch.topk(torch.abs(t).view(-1), k=nparams_toprune, largest=False)
            # topk will have .indices and .values
            mask.view(-1)[topk.indices] = 0

        return mask

    @classmethod
    def apply(cls, module, name, amount, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module, name, amount=amount, importance_scores=importance_scores
        )


class RandomStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor at random.

    Args:
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a random mask to
        apply on top of the ``default_mask`` by randomly zeroing out channels
        along the specified dim of the tensor.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e.
        # (cont.) more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)

        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, nchannels, nchannels_toprune):
            # generate a random number in [0, 1] to associate to each channel
            prob = torch.rand(nchannels)
            # generate mask for each channel by 0ing out the channels that
            # got assigned the k = nchannels_toprune lowest values in prob
            # NOTE(review): with `prob > threshold`, exact ties at the
            # threshold would prune extra channels; ties between random
            # floats are practically impossible, but confirm intent.
            threshold = torch.kthvalue(prob, k=nchannels_toprune).values
            channel_mask = prob > threshold

            mask = torch.zeros_like(t)
            slc = [slice(None)] * len(t.shape)
            slc[dim] = channel_mask
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            # apply the new structured mask on top of prior (potentially
            # unstructured) mask
            mask = make_mask(t, self.dim, tensor_size, nparams_toprune)
            mask *= default_mask.to(dtype=mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, amount, dim=-1):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            dim (int, optional): index of the dim along which we define
                channels to prune. Default: -1.
        """
        return super().apply(module, name, amount=amount, dim=dim)


class LnStructured(BasePruningMethod):
    r"""Prune entire (currently unpruned) channels in a tensor based on their L\ ``n``-norm.

    Args:
        amount (int or float): quantity of channels to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int, optional): index of the dim along which we define
            channels to prune. Default: -1.
    """

    PRUNING_TYPE = "structured"

    def __init__(self, amount, n, dim=-1):
        # Check range of validity of amount
        _validate_pruning_amount_init(amount)
        self.amount = amount
        self.n = n
        self.dim = dim

    def compute_mask(self, t, default_mask):
        r"""Compute and returns a mask for the input tensor ``t``.

        Starting from a base ``default_mask`` (which should be a mask of ones
        if the tensor has not been pruned yet), generate a mask to apply on
        top of the ``default_mask`` by zeroing out the channels along the
        specified dim with the lowest L\ ``n``-norm.

        Args:
            t (torch.Tensor): tensor representing the parameter to prune
            default_mask (torch.Tensor): Base mask from previous pruning
                iterations, that need to be respected after the new mask is
                applied. Same dims as ``t``.

        Returns:
            mask (torch.Tensor): mask to apply to ``t``, of same dims as ``t``

        Raises:
            IndexError: if ``self.dim >= len(t.shape)``
        """
        # Check that tensor has structure (i.e. more than 1 dimension) such
        # that the concept of "channels" makes sense
        _validate_structured_pruning(t)
        # Check that self.dim is a valid dim to index t, else raise IndexError
        _validate_pruning_dim(t, self.dim)

        # Check that the amount of channels to prune is not > than the number of
        # channels in t along the dim to prune
        tensor_size = t.shape[self.dim]
        # Compute number of units to prune: amount if int,
        # else amount * tensor_size
        nparams_toprune = _compute_nparams_toprune(self.amount, tensor_size)
        nparams_tokeep = tensor_size - nparams_toprune
        # This should raise an error if the number of units to prune is larger
        # than the number of units in the tensor
        _validate_pruning_amount(nparams_toprune, tensor_size)

        # Structured pruning prunes entire channels so we need to know the
        # L_n norm along each channel to then find the topk based on this
        # metric
        norm = _compute_norm(t, self.n, self.dim)
        # largest=True --> top k; largest=False --> bottom k
        # Keep the largest k channels along dim=self.dim
        topk = torch.topk(norm, k=nparams_tokeep, largest=True)
        # topk will have .indices and .values

        # Compute binary mask by initializing it to all 0s and then filling in
        # 1s wherever topk.indices indicates, along self.dim.
        # mask has the same shape as tensor t
        def make_mask(t, dim, indices):
            # init mask to 0
            mask = torch.zeros_like(t)
            # e.g.: slc = [None, None, None], if len(t.shape) = 3
            slc = [slice(None)] * len(t.shape)
            # replace a None at position=dim with indices
            # e.g.: slc = [None, None, [0, 2, 3]] if dim=2 & indices=[0,2,3]
            slc[dim] = indices
            # use slc to slice mask and replace all its entries with 1s
            # e.g.: mask[:, :, [0, 2, 3]] = 1
            mask[slc] = 1
            return mask

        if nparams_toprune == 0:  # k=0 not supported by torch.kthvalue
            mask = default_mask
        else:
            mask = make_mask(t, self.dim, topk.indices)
            mask *= default_mask.to(dtype=mask.dtype)

        return mask

    @classmethod
    def apply(cls, module, name, amount, n, dim, importance_scores=None):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
            amount (int or float): quantity of parameters to prune.
                If ``float``, should be between 0.0 and 1.0 and represent the
                fraction of parameters to prune. If ``int``, it represents the
                absolute number of parameters to prune.
            n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
                entries for argument ``p`` in :func:`torch.norm`.
            dim (int): index of the dim along which we define channels to
                prune.
            importance_scores (torch.Tensor): tensor of importance scores (of same
                shape as module parameter) used to compute mask for pruning.
                The values in this tensor indicate the importance of the corresponding
                elements in the parameter being pruned.
                If unspecified or None, the module parameter will be used in its place.
        """
        return super().apply(
            module,
            name,
            amount=amount,
            n=n,
            dim=dim,
            importance_scores=importance_scores,
        )


class CustomFromMask(BasePruningMethod):
    r"""Prune a tensor by applying a fixed, user-provided mask."""

    PRUNING_TYPE = "global"

    def __init__(self, mask):
        # the pre-computed mask to combine with any pre-existing mask
        self.mask = mask

    def compute_mask(self, t, default_mask):
        """Combine the stored mask with ``default_mask`` elementwise."""
        assert default_mask.shape == self.mask.shape
        mask = default_mask * self.mask.to(dtype=default_mask.dtype)
        return mask

    @classmethod
    def apply(cls, module, name, mask):
        r"""Add pruning on the fly and reparametrization of a tensor.

        Adds the forward pre-hook that enables pruning on the fly and
        the reparametrization of a tensor in terms of the original tensor
        and the pruning mask.

        Args:
            module (nn.Module): module containing the tensor to prune
            name (str): parameter name within ``module`` on which pruning
                will act.
        """
        return super().apply(module, name, mask=mask)


def identity(module, name):
    r"""Apply pruning reparametrization without pruning any units.

    Applies pruning reparametrization to the tensor corresponding to the
    parameter called ``name`` in ``module`` without actually pruning any
    units. Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Note:
        The mask is a tensor of ones.

    Args:
        module (nn.Module): module containing the tensor to prune.
        name (str): parameter name within ``module`` on which pruning
            will act.

    Returns:
        module (nn.Module): modified (i.e.
            pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.identity(nn.Linear(2, 3), 'bias')
        >>> print(m.bias_mask)
        tensor([1., 1., 1.])
    """
    Identity.apply(module, name)
    return module


def random_unstructured(module, name, amount):
    r"""Prune tensor by removing random (currently unpruned) units.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    RandomUnstructured.apply(module, name, amount)
    return module


def l1_unstructured(module, name, amount, importance_scores=None):
    r"""Prune tensor by removing units with the lowest L1-norm.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified `amount` of (currently unpruned) units with the
    lowest L1-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
        >>> m.state_dict().keys()
        odict_keys(['bias', 'weight_orig', 'weight_mask'])
    """
    L1Unstructured.apply(
        module, name, amount=amount, importance_scores=importance_scores
    )
    return module


def random_structured(module, name, amount, dim):
    r"""Prune tensor by removing random channels along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` selected at random.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        dim (int): index of the dim along which we define channels to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> # xdoctest: +SKIP
        >>> m = prune.random_structured(
        ...     nn.Linear(5, 3), 'weight', amount=3, dim=1
        ... )
        >>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
        >>> print(columns_pruned)
        3
    """
    RandomStructured.apply(module, name, amount, dim)
    return module


def ln_structured(module, name, amount, n, dim, importance_scores=None):
    r"""Prune tensor by removing channels with the lowest L\ ``n``-norm along the specified dimension.

    Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) channels
    along the specified ``dim`` with the lowest L\ ``n``-norm.
    Modifies module in place (and also return the modified module)
    by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.
        n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid
            entries for argument ``p`` in :func:`torch.norm`.
        dim (int): index of the dim along which we define channels to prune.
        importance_scores (torch.Tensor): tensor of importance scores (of same
            shape as module parameter) used to compute mask for pruning.
            The values in this tensor indicate the importance of the corresponding
            elements in the parameter being pruned.
            If unspecified or None, the module parameter will be used in its place.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.ln_structured(
        ...     nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
        ... )
    """
    LnStructured.apply(
        module, name, amount, n, dim, importance_scores=importance_scores
    )
    return module


def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
    r"""
    Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``.

    Modifies modules in place by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        parameters (Iterable of (module, name) tuples): parameters of
            the model to prune in a global fashion, i.e.
by aggregating all + weights prior to deciding which ones to prune. module must be of + type :class:`nn.Module`, and name must be a string. + pruning_method (function): a valid pruning function from this module, + or a custom one implemented by the user that satisfies the + implementation guidelines and has ``PRUNING_TYPE='unstructured'``. + importance_scores (dict): a dictionary mapping (module, name) tuples to + the corresponding parameter's importance scores tensor. The tensor + should be the same shape as the parameter, and is used for computing + mask for pruning. + If unspecified or None, the parameter will be used in place of its + importance scores. + kwargs: other keyword arguments such as: + amount (int or float): quantity of parameters to prune across the + specified parameters. + If ``float``, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If ``int``, it represents the + absolute number of parameters to prune. + + Raises: + TypeError: if ``PRUNING_TYPE != 'unstructured'`` + + Note: + Since global structured pruning doesn't make much sense unless the + norm is normalized by the size of the parameter, we now limit the + scope of global pruning to unstructured methods. + + Examples: + >>> from torch.nn.utils import prune + >>> from collections import OrderedDict + >>> net = nn.Sequential(OrderedDict([ + ... ('first', nn.Linear(10, 4)), + ... ('second', nn.Linear(4, 1)), + ... ])) + >>> parameters_to_prune = ( + ... (net.first, 'weight'), + ... (net.second, 'weight'), + ... ) + >>> prune.global_unstructured( + ... parameters_to_prune, + ... pruning_method=prune.L1Unstructured, + ... amount=10, + ... 
) + >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) + tensor(10) + + """ + # ensure parameters is a list or generator of tuples + if not isinstance(parameters, Iterable): + raise TypeError("global_unstructured(): parameters is not an Iterable") + + importance_scores = importance_scores if importance_scores is not None else {} + if not isinstance(importance_scores, dict): + raise TypeError("global_unstructured(): importance_scores must be of type dict") + + # flatten importance scores to consider them all at once in global pruning + relevant_importance_scores = torch.nn.utils.parameters_to_vector( + [ + importance_scores.get((module, name), getattr(module, name)) + for (module, name) in parameters + ] + ) + # similarly, flatten the masks (if they exist), or use a flattened vector + # of 1s of the same dimensions as t + default_mask = torch.nn.utils.parameters_to_vector( + [ + getattr(module, name + "_mask", torch.ones_like(getattr(module, name))) + for (module, name) in parameters + ] + ) + + # use the canonical pruning methods to compute the new mask, even if the + # parameter is now a flattened out version of `parameters` + container = PruningContainer() + container._tensor_name = "temp" # to make it match that of `method` + method = pruning_method(**kwargs) + method._tensor_name = "temp" # to make it match that of `container` + if method.PRUNING_TYPE != "unstructured": + raise TypeError( + 'Only "unstructured" PRUNING_TYPE supported for ' + f"the `pruning_method`. 
def custom_from_mask(module, name, mask):
    r"""Apply the user-supplied binary ``mask`` to the parameter ``name`` of ``module``.

    The module is modified in place (and also returned):

    1) a buffer named ``name + '_mask'`` holding the binary mask is registered;
    2) the parameter ``name`` is replaced by its pruned (masked) version, and
       the original unpruned tensor is kept under ``name + '_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.
        mask (Tensor): binary mask to be applied to the parameter.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = prune.custom_from_mask(
        ...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
        ... )
        >>> print(m.bias_mask)
        tensor([0., 1., 0.])

    """
    # All of the reparameterization work is delegated to the pruning-method class.
    CustomFromMask.apply(module, name, mask)
    return module
def remove(module, name):
    r"""Make pruning of parameter ``name`` in ``module`` permanent and drop the reparameterization.

    The pruned parameter named ``name`` stays permanently pruned; the
    parameter ``name + '_orig'`` is removed from the parameter list and the
    buffer ``name + '_mask'`` is removed from the buffers.

    Note:
        Pruning itself is NOT undone or reversed!

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
            will act.

    Examples:
        >>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
        >>> m = remove(m, name='weight')
    """
    for hook_id, hook in module._forward_pre_hooks.items():
        if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
            # Found the pruning hook acting on `name`: fold the mask into the
            # parameter and unregister the hook.
            hook.remove(module)
            del module._forward_pre_hooks[hook_id]
            return module

    raise ValueError(
        f"Parameter '{name}' of module {module} has to be pruned before pruning can be removed"
    )


def is_pruned(module):
    r"""Check if a module is pruned by looking for pruning pre-hooks.

    Check whether ``module`` is pruned by looking for ``forward_pre_hooks``
    in its (sub)modules that inherit from :class:`BasePruningMethod`.

    Args:
        module (nn.Module): object that is either pruned or unpruned

    Returns:
        binary answer to whether ``module`` is pruned.

    Examples:
        >>> from torch.nn.utils import prune
        >>> m = nn.Linear(5, 7)
        >>> print(prune.is_pruned(m))
        False
        >>> prune.random_unstructured(m, name='weight', amount=0.2)
        >>> print(prune.is_pruned(m))
        True
    """
    return any(
        isinstance(hook, BasePruningMethod)
        for _, submodule in module.named_modules()
        for hook in submodule._forward_pre_hooks.values()
    )
+ + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + + Raises: + ValueError: if amount is a float not in [0, 1], or if it's a negative + integer. + TypeError: if amount is neither a float nor an integer. + + Note: + This does not take into account the number of parameters in the + tensor to be pruned, which is known only at prune. + """ + if not isinstance(amount, numbers.Real): + raise TypeError( + f"Invalid type for amount: {amount}. Must be int or float." + ) + + if (isinstance(amount, numbers.Integral) and amount < 0) or ( + not isinstance(amount, numbers.Integral) # so it's a float + and (float(amount) > 1.0 or float(amount) < 0.0) + ): + raise ValueError( + f"amount={amount} should either be a float in the range [0, 1] or a non-negative integer" + ) + + +def _validate_pruning_amount(amount, tensor_size): + r"""Validate that the pruning amount is meaningful wrt to the size of the data. + + Validation helper to check that the amount of parameters to prune + is meaningful wrt to the size of the data (`tensor_size`). + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + """ + # TODO: consider removing this check and allowing users to specify + # a number of units to prune that is greater than the number of units + # left to prune. In this case, the tensor will just be fully pruned. 
+ + if isinstance(amount, numbers.Integral) and amount > tensor_size: + raise ValueError( + f"amount={amount} should be smaller than the number of parameters to prune={tensor_size}" + ) + + +def _validate_structured_pruning(t): + r"""Validate that the tensor to be pruned is at least 2-Dimensional. + + Validation helper to check that the tensor to be pruned is multi- + dimensional, such that the concept of "channels" is well-defined. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + + Raises: + ValueError: if the tensor `t` is not at least 2D. + """ + shape = t.shape + if len(shape) <= 1: + raise ValueError( + "Structured pruning can only be applied to " + "multidimensional tensors. Found tensor of shape " + f"{shape} with {len(shape)} dims" + ) + + +def _compute_nparams_toprune(amount, tensor_size): + r"""Convert the pruning amount from a percentage to absolute value. + + Since amount can be expressed either in absolute value or as a + percentage of the number of units/channels in a tensor, this utility + function converts the percentage to absolute value to standardize + the handling of pruning. + + Args: + amount (int or float): quantity of parameters to prune. + If float, should be between 0.0 and 1.0 and represent the + fraction of parameters to prune. If int, it represents the + absolute number of parameters to prune. + tensor_size (int): absolute number of parameters in the tensor + to prune. + + Returns: + int: the number of units to prune in the tensor + """ + # incorrect type already checked in _validate_pruning_amount_init + if isinstance(amount, numbers.Integral): + return amount + else: + return round(amount * tensor_size) + + +def _validate_pruning_dim(t, dim): + r"""Validate that the pruning dimension is within the bounds of the tensor dimension. 
+ + Args: + t (torch.Tensor): tensor representing the parameter to prune + dim (int): index of the dim along which we define channels to prune + """ + if dim >= t.dim(): + raise IndexError(f"Invalid index {dim} for tensor of size {t.shape}") + + +def _compute_norm(t, n, dim): + r"""Compute the L_n-norm of a tensor along all dimensions except for the specified dimension. + + The L_n-norm will be computed across all entries in tensor `t` along all dimension + except for the one identified by dim. + Example: if `t` is of shape, say, 3x2x4 and dim=2 (the last dim), + then norm will have Size [4], and each entry will represent the + `L_n`-norm computed using the 3x2=6 entries for each of the 4 channels. + + Args: + t (torch.Tensor): tensor representing the parameter to prune + n (int, float, inf, -inf, 'fro', 'nuc'): See documentation of valid + entries for argument p in torch.norm + dim (int): dim identifying the channels to prune + + Returns: + norm (torch.Tensor): L_n norm computed across all dimensions except + for `dim`. By construction, `norm.shape = t.shape[-1]`. + """ + # dims = all axes, except for the one identified by `dim` + dims = list(range(t.dim())) + # convert negative indexing + if dim < 0: + dim = dims[dim] + dims.remove(dim) + + norm = torch.norm(t, p=n, dim=dims) + return norm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..2a3ff1f1de9a90c2570e92cdcdcdd8b56730cad5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/rnn.py @@ -0,0 +1,517 @@ +import warnings +from typing import Iterable, List, NamedTuple, Tuple, Union + +import torch +from torch import Tensor +from ... 
class PackedSequence_(NamedTuple):
    # Raw 4-field tuple backing PackedSequence. TorchScript does not support
    # constructors on NamedTuples, so the data layout lives in this base and
    # PackedSequence adds behavior on top.
    data: torch.Tensor
    batch_sizes: torch.Tensor
    sorted_indices: Optional[torch.Tensor]
    unsorted_indices: Optional[torch.Tensor]


def bind(optional, fn):
    # Apply `fn` to `optional` unless it is None (i.e. map over an Optional).
    if optional is None:
        return None
    return fn(optional)


class PackedSequence(PackedSequence_):
    r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are meant
        to be instantiated by functions like :func:`pack_padded_sequence`.

        Batch sizes represent the number elements at each sequence step in
        the batch, not the varying sequence lengths passed to
        :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
        the :class:`PackedSequence` would contain data ``axbc`` with
        ``batch_sizes=[2,1,1]``.

    Attributes:
        data (Tensor): Tensor containing packed sequence
        batch_sizes (Tensor): Tensor of integers holding
            information about the batch size at each sequence step
        sorted_indices (Tensor, optional): Tensor of integers holding how this
            :class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how this
            to recover the original sequences with correct order.

    .. note::
        :attr:`data` can be on arbitrary device and of arbitrary dtype.
        :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
        tensors on the same device as :attr:`data`.

        However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.

        This invariant is maintained throughout :class:`PackedSequence` class,
        and all functions that construct a :class:`PackedSequence` in PyTorch
        (i.e., they only pass in tensors conforming to this constraint).

    """

    def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=None):
        # Argument normalization (including the legacy
        # ``PackedSequence((data, batch_sizes))`` calling form) is delegated
        # to _packed_sequence_init_args.
        return super().__new__(
            cls,
            *_packed_sequence_init_args(data, batch_sizes, sorted_indices,
                                        unsorted_indices))

    # NOTE [ device and dtype of a PackedSequence ]
    #
    # See the note above in doc string (starting with ":attr:`data` can be on
    # arbitrary device...").
    def pin_memory(self):
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        return type(self)(self.data.pin_memory(), self.batch_sizes,
                          bind(self.sorted_indices, lambda t: t.pin_memory()),
                          bind(self.unsorted_indices, lambda t: t.pin_memory()))

    def cuda(self, *args, **kwargs):
        # Tests to see if 'cuda' should be added to kwargs
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.is_cuda:
            return self.to(*args, **kwargs)
        return self.to(*args, device='cuda', **kwargs)

    def cpu(self, *args, **kwargs):
        # Mirror of cuda(): probe an empty tensor to find where the given
        # args/kwargs would place the data, and only force device='cpu' if needed.
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)

    # Dtype conversion shorthands; each delegates to self.to(dtype=...).
    def double(self):
        return self.to(dtype=torch.double)

    def float(self):
        return self.to(dtype=torch.float)

    def half(self):
        return self.to(dtype=torch.half)

    def long(self):
        return self.to(dtype=torch.long)

    def int(self):
        return self.to(dtype=torch.int)

    def short(self):
        return self.to(dtype=torch.short)

    def char(self):
        return self.to(dtype=torch.int8)

    def byte(self):
        return self.to(dtype=torch.uint8)

    def to(self, *args, **kwargs):
        r"""Perform dtype and/or device conversion on `self.data`.

        It has similar signature as :meth:`torch.Tensor.to`, except optional
        arguments like `non_blocking` and `copy` should be passed as kwargs,
        not args, or they will not apply to the index tensors.

        .. note::

            If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
            and :class:`torch.device`, then ``self`` is returned.
            Otherwise, returns a copy with the desired configuration.
        """
        # Why not convert `batch_sizes`?
        # See NOTE [ device and dtype of a PackedSequence ]
        data = self.data.to(*args, **kwargs)
        if data is self.data:
            # .to() was a no-op, so no new object is needed.
            return self
        else:
            # Does not forward device or dtype arg/kwargs, device is set from data.device
            kwargs = dict(filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items()))
            sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs))
            unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs))
            return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)

    @property
    def is_cuda(self):
        r"""Return true if `self.data` stored on a gpu."""
        return self.data.is_cuda

    def is_pinned(self):
        r"""Return true if `self.data` stored on in pinned memory."""
        return self.data.is_pinned()
+ + if unsorted_indices is None: + unsorted_indices = invert_permutation(sorted_indices) + + # support being called as `PackedSequence(data, batch_sizes, sorted_indices)` + if batch_sizes is not None: + # TODO: Re-enable this check (.type isn't supported in TorchScript) + if batch_sizes.device.type != 'cpu': + raise ValueError( + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn.utils.rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence") + return data, batch_sizes, sorted_indices, unsorted_indices + + # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)` + else: + assert isinstance(data, (list, tuple)) and len(data) == 2 + return data[0], data[1], sorted_indices, unsorted_indices + + +def _packed_sequence_init( + data: Tensor, + batch_sizes: Optional[Tensor] = None, + sorted_indices: Optional[Tensor] = None, + unsorted_indices: Optional[Tensor] = None, +) -> PackedSequence: + data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args( + data, batch_sizes, sorted_indices, unsorted_indices) + return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices) + + +def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]: + if permutation is None: + return None + output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format) + output.scatter_(0, permutation, + torch.arange(0, permutation.numel(), device=permutation.device)) + return output + + +def pack_padded_sequence( + input: Tensor, + lengths: Tensor, + batch_first: bool = False, + enforce_sorted: bool = True, +) -> PackedSequence: + r"""Packs a Tensor containing padded sequences of variable length. 
def pack_padded_sequence(
    input: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` has size ``T x B x *`` (or ``B x T x *`` when ``batch_first``
    is ``True``), where `T` is the length of the longest sequence (equal to
    ``lengths[0]``), ``B`` is the batch size, and ``*`` is any number of
    dimensions (including 0).

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        Any input with at least two dimensions is accepted; the packed data can
        be retrieved from the returned :class:`PackedSequence` via ``.data``.

    Args:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor or list(int)): list of sequence lengths of each batch
            element (must be on the CPU if provided as a tensor).
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    if not isinstance(lengths, torch.Tensor):
        # warn when tracing: Python-list lengths get baked in as constants
        if torch._C._get_tracing_state():
            warnings.warn('pack_padded_sequence has been called with a Python list of '
                          'sequence lengths. The tracer cannot track the data flow of Python '
                          'values, and it will treat them as constants, likely rendering '
                          'the trace incorrect for any other combination of lengths.',
                          stacklevel=2)
        lengths = torch.as_tensor(lengths, dtype=torch.int64, device='cpu')
    else:
        lengths = lengths.to(dtype=torch.int64)

    sorted_indices: Optional[Tensor] = None
    if not enforce_sorted:
        # sort the batch by decreasing length and permute the input to match
        lengths, sorted_indices = torch.sort(lengths, descending=True)
        sorted_indices = sorted_indices.to(input.device)
        input = input.index_select(0 if batch_first else 1, sorted_indices)

    data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
    return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
def pad_packed_sequence(
    sequence: PackedSequence,
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
    the data will be transposed into ``B x T x *`` format.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.

    Args:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
    max_seq_length = sequence.batch_sizes.size(0)
    if total_length is not None:
        if total_length < max_seq_length:
            raise ValueError("Expected total_length to be at least the length "
                             "of the longest sequence in input, but got "
                             f"total_length={total_length} and max sequence length being {max_seq_length}"
                             )
        max_seq_length = total_length

    padded, lens = _VF._pad_packed_sequence(
        sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length)

    unsorted_indices = sequence.unsorted_indices
    if unsorted_indices is None:
        return padded, lens

    # restore the caller's original batch order
    batch_dim = 0 if batch_first else 1
    return padded.index_select(batch_dim, unsorted_indices), lens[unsorted_indices.cpu()]
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
) -> Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``.

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *`` and ``batch_first`` is False, the output is
    of size ``T x B x *``.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        torch.Size([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise. Default: False.
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """
    if torch.jit.is_tracing() or torch.jit.is_scripting():
        # For JIT, we only support Union[Tensor, Tuple[Tensor]]
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.unbind(0)
    else:
        # JIT doesn't support `Iterable`
        if not isinstance(sequences, Iterable):
            raise RuntimeError('pad_sequence: Expected iterable for input sequences, but got arg of type: '
                               f'{type(sequences)}')
        # In JIT context this leads to,
        # RuntimeError: cannot statically infer the expected size of a list in this context
        sequences = tuple(sequences)

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.

    Example:
        >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
        >>> a = torch.ones(25, 300)
        >>> b = torch.ones(22, 300)
        >>> c = torch.ones(15, 300)
        >>> sequences = [a, b, c]
        >>> padded_sequences = pad_sequence(sequences)
        >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
        >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
        >>> torch.allclose(sequences[0], unpadded_sequences[0])
        True

    Args:
        padded_sequences (Tensor): padded sequences.
        lengths (Tensor): length of original (unpadded) sequences.
        batch_first (bool, optional): whether batch dimension first or not. Default: False.

    Returns:
        a list of :class:`Tensor` objects
    """
    if not batch_first:
        # Bring the batch dimension to the front. Use the non-in-place
        # `transpose` (a view, so still no copy): the previous `transpose_`
        # mutated the caller's tensor as a hidden side effect.
        padded_sequences = padded_sequences.transpose(0, 1)

    max_length = padded_sequences.shape[1]
    # position index used to build a per-sequence validity mask
    idx = torch.arange(max_length, device=lengths.device)

    unpadded_sequences = []
    for seq, length in zip(padded_sequences, lengths):
        mask = idx < length
        unpadded_sequences.append(seq[mask])

    return unpadded_sequences


def pack_sequence(sequences: List[Tensor], enforce_sorted: bool = True) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

    Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``.

    ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
    the length of a sequence and `*` is any number of trailing dimensions,
    including zero.

    For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
    is ``True``, the sequences should be sorted in the order of decreasing length.
    ``enforce_sorted = True`` is only necessary for ONNX export.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> pack_sequence([a, b, c])
        PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)

    Args:
        sequences (list[Tensor]): A list of sequences of decreasing length.
        enforce_sorted (bool, optional): if ``True``, checks that the input
            contains sequences sorted by length in a decreasing order. If
            ``False``, this condition is not checked. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
    # pad to a dense T x B x * tensor, then pack using the true lengths
    lengths = torch.as_tensor([v.size(0) for v in sequences])
    return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a PackedSequence object.

    Example:
        >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
        >>> a = torch.tensor([1, 2, 3])
        >>> b = torch.tensor([4, 5])
        >>> c = torch.tensor([6])
        >>> packed_sequences = pack_sequence([a, b, c])
        >>> unpack_sequence(packed_sequences)
        [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]

    Args:
        packed_sequences (PackedSequence): A PackedSequence object.

    Returns:
        a list of :class:`Tensor` objects
    """
    # pad back to a dense batch-first tensor, then strip the padding per length
    padded, lens = pad_packed_sequence(packed_sequences, batch_first=True)
    return unpad_sequence(padded, lens, batch_first=True)
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    # Name of the wrapped parameter (e.g. 'weight'); derived attributes use
    # the suffixes '_orig', '_u', '_v'.
    name: str
    # Dimension corresponding to the number of outputs; permuted to the front
    # before flattening the weight to a 2-D matrix.
    dim: int
    # Number of power-iteration steps performed per training forward pass.
    n_power_iterations: int
    # Epsilon for numerical stability when normalizing `u` and `v`.
    eps: float

    def __init__(self, name: str = 'weight', n_power_iterations: int = 1, dim: int = 0, eps: float = 1e-12) -> None:
        """Store the hook configuration.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             f'got n_power_iterations={n_power_iterations}')
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        """Flatten ``weight`` to a 2-D matrix with ``self.dim`` as the row axis."""
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
        """Return ``weight_orig / sigma``, optionally refining ``u``/``v`` first.

        ``sigma`` is the spectral-norm estimate ``u @ W @ v`` from the stored
        singular-vector approximations.
        """
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in `out=` kwarg has correct shape, it will
        #    just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backproping through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, engine will
        # complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone(memory_format=torch.contiguous_format)
                    v = v.clone(memory_format=torch.contiguous_format)

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module: Module) -> None:
        """Undo the reparameterization: re-register the (normalized) weight as
        a plain parameter and drop the `_orig`/`_u`/`_v` attributes."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module: Module, inputs: Any) -> None:
        # Forward-pre-hook entry point: recompute the normalized weight before
        # every forward. Power iteration runs only in training mode.
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat: torch.Tensor, u: torch.Tensor, target_sigma) -> torch.Tensor:
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        v = torch.linalg.multi_dot([weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module: Module, name: str, n_power_iterations: int, dim: int, eps: float) -> 'SpectralNorm':
        """Attach a ``SpectralNorm`` hook to ``module`` for parameter ``name``.

        Moves the parameter to ``name + '_orig'``, registers `u`/`v` buffers,
        and installs forward-pre / state-dict hooks.

        Raises:
            RuntimeError: if a spectral_norm hook for ``name`` already exists.
            ValueError: if the parameter is ``None`` or uninitialized (lazy).
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(f"Cannot register two spectral_norm hooks on the same parameter {name}")

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            raise ValueError(f'`SpectralNorm` cannot be applied as parameter `{name}` is None')
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying spectral normalization')

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook:
    """Migrates pre-version-1 spectral_norm state dicts on load by
    reconstructing the missing ``_v`` buffer (and dropping the stored ``W``)."""

    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        # fn: the SpectralNorm hook instance this pre-hook migrates state for.
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs) -> None:
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            weight_key = prefix + fn.name
            if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
                    weight_key not in state_dict:
                # Detect if it is the updated state dict and just missing metadata.
                # This could happen if the users are crafting a state dict themselves,
                # so we just pretend that this is the newest.
                return
            has_missing_keys = False
            for suffix in ('_orig', '', '_u'):
                key = weight_key + suffix
                if key not in state_dict:
                    has_missing_keys = True
                    if strict:
                        missing_keys.append(key)
            if has_missing_keys:
                return
            with torch.no_grad():
                weight_orig = state_dict[weight_key + '_orig']
                # Old-format dicts stored the normalized weight under the plain
                # key; remove it and recover sigma = W_orig / W elementwise.
                weight = state_dict.pop(weight_key)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[weight_key + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[weight_key + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook:
    """Records the spectral_norm serialization version in state-dict metadata."""

    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn) -> None:
        # fn: the SpectralNorm hook instance whose version is recorded.
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata) -> None:
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
        local_metadata['spectral_norm'][key] = self.fn._version


T_module = TypeVar('T_module', bound=Module)

def spectral_norm(module: T_module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> T_module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    .. note::
        This function has been reimplemented as
        :func:`torch.nn.utils.parametrizations.spectral_norm` using the new
        parametrization functionality in
        :func:`torch.nn.utils.parametrize.register_parametrization`. Please use
        the newer version. This function will be deprecated in a future version
        of PyTorch.

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        # ConvTranspose layers store weights as (in_channels, out_channels, ...),
        # so the "output" dimension is 1 rather than 0.
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Remove the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    # Find and remove the forward-pre hook; `for/else` raises if none matched.
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError(f"spectral_norm of '{name}' not found in {module}")

    # Also drop the paired state-dict save/load hooks (best-effort; no error
    # if absent).
    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break

    return module