diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a8739cc350e2ab349b7c59204b14c3fb0272b1b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/autograd.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3265f3c3922fc5a7d5897eac2a6c8b343cbcc731 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/functional.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55d750730123854437b7214c96f4d8b04d33afc3 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/__pycache__/impl.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..26ef5b307bd52b7c764270ee02365e780fd32349 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/functional.py @@ -0,0 +1,187 @@ +import weakref + +import torch +import torch.utils._pytree as pytree +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._ops import OpOverload +from torch.library import Library +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + OperatorName, + OptionalType, + SchemaKind, +) + +from .autograd import autograd_not_implemented + + +def register_functional_op( + lib: Library, + new_op_name: str, + mutable_op: OpOverload, +) -> None: + """Given a mutable operator, registers the functional variant. + + This API also correctly links the functional variant with the mutable + operator for the purposes of functionalization. + + All of the new registrations are performed on the ``lib`` passed in. + + Arguments: + lib (Library): Should be a torch.library.Library object that has + the same namespace as ``mutable_op``'s namespace. + lib will be used to register the new functional op as well + as a functionalization kernel for the ``mutable_op`` + If you don't have a library handy, use + ``torch.library.Library(ns, 'FRAGMENT')`` to construct one. + new_op_name (str): The name of the functional operator (without the + namespace). If no namespace, the new functional variant will be + accessible under ``torch.ops.{lib.ns}.new_op_name``. + mutable_op (OpOverload): The mutable custom operator. Note + that you may need to add a `.default` to it, like + `torch.ops.aten.abs_.default`. 
+ + """ + validate(mutable_op) + schema = functional_schema(new_op_name, mutable_op) + lib.define(schema) + + functional_impl = construct_functional_impl(mutable_op) + lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd') + + functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default + + # There's no easy way for us to generate the autograd kernel, so we + # use autograd_not_implemented. Also, this makes it so that the user + # is unable to register an autograd formula themselves. This shouldn't + # be a problem if the user doesn't use the functional op direclty + # in their program, but we may need to revist this in the future. + lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd') + + f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op) + + lib.impl(mutable_op, f_kernel, 'Functionalize') + + +def construct_functional_impl(mutable_op): + def functional_impl(*args): + # Strategy: + # - clone args that would have been mutated + # - run mutable_op + # - return the cloned args as additional outputs + new_args = [] + extra_rets = [] + for is_write, arg in zip(mutable_args(mutable_op), args): + if is_write: + cloned = arg.clone() if arg is not None else None + new_args.append(cloned) + extra_rets.append(cloned) + else: + new_args.append(arg) + result = mutable_op(*new_args) + if result is None: + return tuple(extra_rets) + if isinstance(result, tuple): + return (*result, *extra_rets) + return (result, *extra_rets) + return functional_impl + + +def construct_functionalization_kernel(mutable_op, functional_op): + def kernel(*args): + # There's nothing to be functionalized! 
+ # We can still end up here because DispatchKey::Functionalize is a mode key + if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args): + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + return mutable_op(*args) + + # NB: This differs from the codegen -- codegen handles cases where there + # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper. + # This only really matters for XLA (mixed CPU-XLA tensors) and + # running functionalization without the PT2 stack (which guarantees to us that + # all tensors are FunctionalTensorWrapper). + if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args): + raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper") + + unwrapped_args = [] + for arg in args: + if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg): + torch._sync(arg) + unwrapped = torch._from_functional_tensor(arg) + unwrapped_args.append(unwrapped) + else: + unwrapped_args.append(arg) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + output = functional_op(*unwrapped_args) + + num_actual_output = len(mutable_op._schema.returns) + actual_output = pytree.tree_map( + torch._to_functional_tensor, output[:num_actual_output]) + + new_values_to_propagate = output[num_actual_output:] + inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args) + if is_write] + assert len(new_values_to_propagate) == len(inputs_to_replace) + for new_value, arg in zip(new_values_to_propagate, inputs_to_replace): + if (arg is None and new_value is None) or (arg is not None and new_value is not None): + continue + torch._C._propagate_xla_data(arg, new_value) + torch._C._replace_(arg, new_value) + torch._C._commit_update(arg) + torch._sync(arg) + + if len(actual_output) == 1: + return actual_output[0] + elif len(actual_output) == 0: + return None + return actual_output + + return kernel + + +def 
validate(mutable_op: OpOverload): + if not isinstance(mutable_op, OpOverload): + raise TypeError( + f"register_functional_op(mutable_op): expected mutable_op to be instance of " + f"OpOverload but got {type(mutable_op)}") + + # There are generally three types of "in-place" or "mutable" ops. + # Each of them have their own conventions: + # - inplace (first input modified in-place and returned as only output) + # - out= (some args modified in-place and returned as outputs) + # - mutable (some args modified in-place but none of those returned as outputs) + # In theory we can support all three, but we'll just support the last + # option right now for simplicity. + schema = FunctionSchema.parse(str(mutable_op._schema)) + if not schema.kind() == SchemaKind.mutable: + raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)") + for ret in schema.returns: + # construct_functionalization_kernel assumes this for simplicity + if ret.annotation is not None: + raise NotImplementedError( + "NYI: register_functional_op(op) where op returns a mutated or aliased value. " + "Please file an issue (and as a workaround, modify your operator to " + "not return the mutated value or aliases)") + for arg in schema.arguments.flat_all: + # construct_functionalization_kernel assumes this for simplicity + if arg.type.is_tensor_like() and ( + arg.type != BaseType(BaseTy.Tensor) + and arg.type != OptionalType(BaseType(BaseTy.Tensor)) + ): + raise NotImplementedError( + "NYI: register_functional_op(op) where op has a List[Tensor] input." 
+ "Please file an issue.") + + +def functional_schema(new_op_name, op: OpOverload): + schema = FunctionSchema.parse(str(op._schema)) + schema = schema.signature().with_name(OperatorName.parse(new_op_name)) + return str(schema) + + +def mutable_args(op: OpOverload): + return tuple(False if arg.alias_info is None else arg.alias_info.is_write + for arg in op._schema.arguments) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1ca2428255aa9fe3892328f6ab95cc5f5b7568 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/_conversions.py @@ -0,0 +1,118 @@ +import torch +import torch._prims_common as utils + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +from torch._prims_common import TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes + +# Data conversion references. +# +# Note: this module breaks the usual _refs to torch naming scheme where +# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not +# part of _refs/__init__.py to avoid name clashes with Python builtin types +# (like int). 
+ +__all__ = [ + # dtypes + "bfloat16", + "bool", + "byte", + "cdouble", + "cfloat", + "chalf", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + # misc + "complex", + "polar", +] + + +def _make_conversion_method(name: str, dtype: torch.dtype): + def fn( + self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format + ) -> TensorLikeType: + return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload] + + fn.__name__ = name + return fn + + +bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) + +bool = _make_conversion_method("bool", torch.bool) + +byte = _make_conversion_method("byte", torch.uint8) + +cdouble = _make_conversion_method("cdouble", torch.cdouble) + +cfloat = _make_conversion_method("cfloat", torch.cfloat) + +chalf = _make_conversion_method("chalf", torch.complex32) + +char = _make_conversion_method("char", torch.int8) + +double = _make_conversion_method("double", torch.double) + +float = _make_conversion_method("float", torch.float) + +half = _make_conversion_method("half", torch.half) + +int = _make_conversion_method("int", torch.int) + +long = _make_conversion_method("long", torch.long) + +short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch._ops.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + torch._check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + torch._check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result + + +@register_decomposition(torch._ops.ops.aten.polar) +# Note: polar has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: + result = torch.complex(abs, angle) + result.real = abs * torch.cos(angle) + result.imag = abs * torch.sin(angle) + return result diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b328343f5a033797ffabd50cd24331be679d27e8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3646b96be90be51e86070360b23212ed450186a6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1,3 @@ +from typing import List + +__all__: List[str] = [] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bff83aa2b19e4b02d8b86eeac2ea4ad1bcd6dee1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963fb508458e6af5703b6a18a3a441262085173e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32df8897db6c4d7e99f6de18a3c00bf19be50ed Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/special/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9fb032a2bb3d5e4452b48d0a870615c186f365 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantizable/modules/__init__.py @@ -0,0 +1,9 @@ +from .activation import MultiheadAttention +from .rnn import LSTM +from .rnn import LSTMCell + +__all__ = [ + 'LSTM', + 'LSTMCell', + 'MultiheadAttention', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c292546a171782ad94d189f7d15a1599d257ac1e Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d308a8f38ad2ee8f12052dd3bbe1f96fc97ecac Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..bfef31268cff11f937b806701d6e6667b1c3622f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/modules/batchnorm.py @@ -0,0 +1,106 @@ +import torch +import torch.ao.nn.intrinsic as nni + +__all__ = [ + "BatchNorm2d", + "BatchNorm3d" +] + +class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(num_features, eps, momentum, True, True, **factory_kwargs) + self.register_buffer('scale', torch.tensor(1.0, **factory_kwargs)) + self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs)) + + @staticmethod + def from_float(cls, mod): + activation_post_process = mod.activation_post_process + if type(mod) == cls._NNI_BN_RELU_MODULE: + mod = mod[0] + scale, zero_point = activation_post_process.calculate_qparams() + new_mod = cls(mod.num_features, mod.eps) + 
new_mod.weight = mod.weight + new_mod.bias = mod.bias + new_mod.running_mean = mod.running_mean + new_mod.running_var = mod.running_var + new_mod.scale = scale + new_mod.zero_point = zero_point + return new_mod + + @classmethod + def from_reference(cls, bn, output_scale, output_zero_point): + qbn = cls( + bn.num_features, + bn.eps, + bn.momentum, + device=bn.weight.device, + dtype=bn.weight.dtype + ) + qbn.weight = bn.weight + qbn.bias = bn.bias + qbn.running_mean = bn.running_mean + qbn.running_var = bn.running_var + qbn.scale = output_scale + qbn.zero_point = output_zero_point + return qbn + +class BatchNorm2d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`. + """ + + _NNI_BN_RELU_MODULE = nni.BNReLU2d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return 'QuantizedBatchNorm2d' + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 4: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm2d( + input, self.weight, self.bias, self.running_mean, + self.running_var, self.eps, self.scale, self.zero_point) + + @classmethod + def from_float(cls, mod): + return _BatchNorm.from_float(cls, mod) + +class BatchNorm3d(_BatchNorm): + r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`. 
+ """ + + _NNI_BN_RELU_MODULE = nni.BNReLU3d + + def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(num_features, eps, momentum, **factory_kwargs) + + def _get_name(self): + return 'QuantizedBatchNorm3d' + + def _check_input_dim(self, input): + # Temporarily using len(shape) instead of ndim due to JIT issue + # https://github.com/pytorch/pytorch/issues/23890 + if len(input.shape) != 5: + raise ValueError("Input shape must be `(N, C, H, W)`!") + + def forward(self, input: torch.Tensor) -> torch.Tensor: + # disabling this since this is not symbolically traceable + # self._check_input_dim(input) + return torch.ops.quantized.batch_norm3d( + input, self.weight, self.bias, self.running_mean, + self.running_var, self.eps, self.scale, self.zero_point) + + @classmethod + def from_float(cls, mod): + return _BatchNorm.from_float(cls, mod) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b0b2151385dfa7c12dd21dbac8089666e56496 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..378fe0eb6eeeab9c69f06bef6cd71213f5b7fe34 --- /dev/null +++ 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/quantized/reference/modules/linear.py @@ -0,0 +1,57 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any +from .utils import ReferenceQuantizedModule + +__all__ = ['Linear'] + +class Linear(nn.Linear, ReferenceQuantizedModule): + """ A reference quantized linear module that fits into the FX + Graph Mode Quantization workflow + activation will be floating point Tensor, we will store floating + point weight as well in the module, but in forward we'll quantize + and dequantize the weight before running the floating point functional + linear operator. + """ + _IS_REFERENCE = True + + def __init__( + self, + in_features: int, + out_features: int, + bias_: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_qparams: Optional[Dict[str, Any]] = None): + super().__init__(in_features, out_features, bias_, device, dtype) + self._init_weight_qparams(weight_qparams, device) + + def _get_name(self): + return "QuantizedLinear(Reference)" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + we have: + w(float) -- quant - dequant \ + x(float) ------------- F.linear --- + + In the full model, we will see + w(float) -- quant - *dequant \ + x -- quant --- *dequant -- *F.linear --- *quant - dequant + and the backend should be able to fuse the ops with `*` into a quantized linear + """ + weight_quant_dequant = self.get_weight() + result = F.linear(x, weight_quant_dequant, self.bias) + return result + + @classmethod + def from_float(cls, float_linear, weight_qparams): + qref_linear = Linear( + float_linear.in_features, float_linear.out_features, + float_linear.bias is not None, device=float_linear.weight.device, + dtype=float_linear.weight.dtype, weight_qparams=weight_qparams) + qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach()) + if float_linear.bias is not None: + 
qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach()) + return qref_linear diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61422835a93ba34d897c26fc226d79edc5d96c8d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f33fe61468f16a43b2afbaba12dbf37e14d4ae30 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f57119c515e1ac00d313a3d7fb9ec0951cf60e0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3eeefe98638a9308cb696ca405d79a172dfd36dd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b8ee5c810ad0ef9ad4becc1c6d8eb45f30b493 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__init__.py @@ -0,0 +1,189 @@ +# flake8: noqa: F403 + +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules # noqa: F403 +from .fuse_modules import fuse_modules_qat # noqa: F403 +from .fuser_method_mappings import * # noqa: F403 +from .observer import * # noqa: F403 +from .qconfig import * # noqa: F403 +from .qconfig_mapping import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantization_mappings import * # type: ignore[no-redef] +from .quantize import * # noqa: F403 +from .quantize_jit import * # noqa: F403 +from .stubs import * # noqa: F403 +from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval +from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train +from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval +from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle # noqa: F401 +from typing import Union, List, Callable, Tuple, Optional +from torch import Tensor +import torch + +ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] +ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" + +__all__ = [ + "DeQuantStub", + "FakeQuantize", + "FakeQuantizeBase", + "FixedQParamsFakeQuantize", + "FixedQParamsObserver", + 
"FusedMovingAvgObsFakeQuantize", + "HistogramObserver", + "MatchAllNode", + "MinMaxObserver", + "MovingAverageMinMaxObserver", + "MovingAveragePerChannelMinMaxObserver", + "NoopObserver", + "ObserverBase", + "ObserverOrFakeQuantize", + "Pattern", + "PerChannelMinMaxObserver", + "PlaceholderObserver", + "QConfig", + "QConfigAny", + "QConfigDynamic", + "QConfigMapping", + "QuantStub", + "QuantType", + "QuantWrapper", + "RecordingObserver", + "ReuseInputObserver", + "UniformQuantizationObserverBase", + "add_quant_dequant", + "convert", + "convert_dynamic_jit", + "convert_jit", + "default_affine_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_observer", + "default_debug_observer", + "default_dynamic_fake_quant", + "default_dynamic_quant_observer", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_eval_fn", + "default_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_fixed_qparams_range_0to1_observer", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_neg1to1_observer", + "default_float_qparams_observer", + "default_float_qparams_observer_4bit", + "default_fused_act_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "default_fused_wt_fake_quant", + "default_histogram_fake_quant", + "default_histogram_observer", + "default_observer", + "default_per_channel_weight_fake_quant", + "default_per_channel_weight_observer", + "default_placeholder_observer", + "default_reuse_input_observer", + "default_symmetric_fixed_qparams_fake_quant", + "default_symmetric_fixed_qparams_observer", + "default_weight_fake_quant", + "default_weight_observer", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "fuse_conv_bn", + "fuse_conv_bn_jit", + "fuse_conv_bn_relu", + "fuse_convtranspose_bn", + "fuse_linear_bn", + "fuse_modules", + "fuse_modules_qat", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", + 
"fused_wt_fake_quant_range_neg_127_to_127", + "get_combined_dict", + "get_default_compare_output_module_list", + "get_default_custom_config_dict", + "get_default_dynamic_quant_module_mappings", + "get_default_dynamic_sparse_quant_module_mappings", + "get_default_float_to_quantized_operator_mappings", + "get_default_qat_module_mappings", + "get_default_qat_qconfig", + "get_default_qat_qconfig_dict", + "get_default_qat_qconfig_mapping", + "get_default_qconfig", + "get_default_qconfig_dict", + "get_default_qconfig_mapping", + "get_default_qconfig_propagation_list", + "get_default_static_quant_module_mappings", + "get_default_static_quant_reference_module_mappings", + "get_default_static_sparse_quant_module_mappings", + "get_dynamic_quant_module_class", + "get_embedding_qat_module_mappings", + "get_embedding_static_quant_module_mappings", + "get_fuser_method", + "get_fuser_method_new", + "get_observer_state_dict", + "get_quantized_operator", + "get_static_quant_module_class", + "load_observer_state_dict", + "move_exported_model_to_eval", + "move_exported_model_to_train", + "allow_exported_model_train_eval", + "no_observer_set", + "per_channel_weight_observer_range_neg_127_to_127", + "prepare", + "prepare_dynamic_jit", + "prepare_jit", + "prepare_qat", + "propagate_qconfig_", + "qconfig_equals", + "quantize", + "quantize_dynamic", + "quantize_dynamic_jit", + "quantize_jit", + "quantize_qat", + "script_qconfig", + "script_qconfig_dict", + "swap_module", + "weight_observer_range_neg_127_to_127", + "generate_numeric_debug_handle", +] + +def default_eval_fn(model, calib_data): + r"""Define the default evaluation function. 
+ + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, target in calib_data: + model(data) + +class _DerivedObserverOrFakeQuantize(ObserverBase): + r"""This observer is used to describe an observer whose quantization parameters + are derived from other observers + """ + + def __init__( + self, + dtype: torch.dtype, + obs_or_fqs: List[ObserverOrFakeQuantize], + derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]], + quant_min: Optional[int]=None, + quant_max: Optional[int]=None, + qscheme: Optional[torch.qscheme]=None, + ch_axis: Optional[int] = None + ): + super().__init__(dtype) + self.obs_or_fqs = obs_or_fqs + self.derive_qparams_fn = derive_qparams_fn + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + self.ch_axis = ch_axis + + from .utils import is_per_channel + if is_per_channel(self.qscheme): + assert self.ch_axis is not None, "Must provide a valid ch_axis if qscheme is per channel" + + def forward(self, x: Tensor) -> Tensor: + return x + + def calculate_qparams(self): + return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ee5150bcf97c886a0b719d0506b6179322112c2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantization_mappings.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6430a057cfbef209681d2476f4131c63a368d47 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/quantize_fx.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..83cc81bb6b002c5df4110f52c0b6cb9a8e04e3c5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/_correct_bias.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import torch.ao.nn.quantized as nnq + +import torch.ao.quantization +import torch.ao.ns._numeric_suite as ns + +__all__ = [ + "get_module", + "parent_child_names", + "get_param", + "MeanShadowLogger", + "bias_correction", +] + +_supported_modules = {nn.Linear, nn.Conv2d} +_supported_modules_quantized = {nnq.Linear, nnq.Conv2d} + +def get_module(model, name): + """Given name of submodule, this function grabs the submodule from given model.""" + return dict(model.named_modules())[name] + +def parent_child_names(name): + """Split full name of submodule into parent submodule's full name and submodule's name.""" + split_name = name.rsplit('.', 1) + if len(split_name) == 1: + return '', split_name[0] + else: + return split_name[0], split_name[1] + +def get_param(module, attr): + """Get the parameter given a module and attribute. 
+ + Sometimes the weights/bias attribute gives you the raw tensor, but sometimes + gives a function that will give you the raw tensor, this function takes care of that logic + """ + param = getattr(module, attr, None) + if callable(param): + return param() + else: + return param + +class MeanShadowLogger(ns.Logger): + """Mean Logger for a Shadow module. + + A logger for a Shadow module whose purpose is to record the rolling mean + of the data passed to the floating point and quantized models + """ + + def __init__(self): + """Set up initial values for float and quantized stats, count, float sum, and quant sum.""" + super().__init__() + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + + def forward(self, x, y): + """Compute the average of quantized and floating-point data from modules. + + The inputs x,y are output data from the quantized and floating-point modules. + x is for the quantized module, y is for the floating point module + """ + if x.is_quantized: + x = x.dequantize() + + self.count += 1 + if self.stats["quantized"] is None: + self.stats["quantized"] = x + self.quant_sum = x + else: + self.quant_sum += x + self.stats["quantized"] = self.quant_sum / self.count + + if self.stats["float"] is None: + self.stats["float"] = y + self.float_sum = y + else: + self.float_sum += y + self.stats["float"] = self.float_sum / self.count + + def clear(self): + self.stats["float"] = None + self.stats["quantized"] = None + self.count = 0 + self.float_sum = None + self.quant_sum = None + +def bias_correction(float_model, quantized_model, img_data, target_modules=_supported_modules_quantized, neval_batches=None): + """Perform bias correction on a module. + + Using numeric suite shadow module, the expected output of the floating point and quantized modules + is recorded. 
Using that data the bias of supported modules is shifted to compensate for the drift caused + by quantization + Paper reference: https://arxiv.org/pdf/1906.04721.pdf (Section 4.2) + + Args: + float_model: a trained model that serves as a reference to what bias correction should aim for + quantized_model: quantized form of float_model that bias correction is to applied to + img_data: calibration data to estimate the expected output (used to find quantization error) + target_modules: specifies what submodules in quantized_model need bias correction (can be extended to + unquantized submodules) + neval_batches: a cap to the number of batches you want to be used for estimating the expected output + """ + ns.prepare_model_with_stubs(float_model, quantized_model, _supported_modules, MeanShadowLogger) + + uncorrected_modules = {} + for name, submodule in quantized_model.named_modules(): + if type(submodule) in target_modules: + uncorrected_modules[name] = submodule + + for uncorrected_module in uncorrected_modules: + quantized_submodule = get_module(quantized_model, uncorrected_module) + bias = get_param(quantized_submodule, 'bias') + if bias is not None: + + count = 0 + for data in img_data: + quantized_model(data[0]) + count += 1 + if count == neval_batches: + break + ob_dict = ns.get_logger_dict(quantized_model) + parent_name, _ = parent_child_names(uncorrected_module) + + float_data = ob_dict[parent_name + '.stats']['float'] + quant_data = ob_dict[parent_name + '.stats']['quantized'] + + # math for expected_error + quantization_error = quant_data - float_data + dims = list(range(quantization_error.dim())) + # Note: we don't want to take the mean over the output channel dimension + dims.remove(1) + expected_error = torch.mean(quantization_error, dims) + + updated_bias = bias.data - expected_error + + bias.data = updated_bias + + # Resets the data contained in the loggers + for name, submodule in quantized_model.named_modules(): + if isinstance(submodule, 
MeanShadowLogger): + submodule.clear() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..285774741edc6610ed2bcc9947a8e0c573d79751 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d413d0eb2a0f1eff99c0972dd8c2d95af502cd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dab5a57d614c7467cee5535e00faa87b4d00fe5f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25fa1f823a8725a1328e5fd12393d54dd299e8b8 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/__pycache__/onednn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/observation_type.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py new file mode 100644 index 0000000000000000000000000000000000000000..772a25c65574481d70186e9d968039756b2fa0ae --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/qnnpack.py @@ -0,0 +1,160 @@ +import torch +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints + +__all__ = [ + "get_qnnpack_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +qnnpack_weighted_op_quint8_dtype_config = DTypeConfig( + 
input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +qnnpack_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +qnnpack_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +qnnpack_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +qnnpack_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + +# xnnpack compatible dtype configs + +# We restrict scale values to be 2 ** -12 to ensure the +# requantization scale never falls below the xnnpack lower +# threshold. Additionally, for qint8 weight, we restrict +# the quantization values to [-127, +127], excluding -128. +# For more detail, refer to the description of +# `default_symmetric_qnnpack_qconfig`. 
+ +# TODO: add additional restriction on qscheme to ensure it +# is either per_tensor_symmetric or per_channel_symmetric + +qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + scale_min_lower_bound=2 ** -12, +) + +qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints( + dtype=torch.qint8, + quant_min_lower_bound=-127, + quant_max_upper_bound=127, + scale_min_lower_bound=2 ** -12, +) + +qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12, + bias_dtype=torch.float, +) + +qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig( + input_dtype=qnnpack_act_qint8_scale_min_2_neg_12, + output_dtype=qnnpack_act_qint8_scale_min_2_neg_12, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + +def get_qnnpack_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native QNNPACK backend. 
+ """ + conv_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + ] + linear_dtype_configs = [ + qnnpack_weighted_op_qint8_symmetric_dtype_config, + qnnpack_weighted_op_quint8_dtype_config, + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + default_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + fixed_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + qnnpack_default_op_qint8_symmetric_dtype_config, + qnnpack_default_op_quint8_dtype_config, + ] + rnn_op_dtype_configs = [ + qnnpack_default_dynamic_int8_dtype_config, + qnnpack_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + qnnpack_weight_only_quint8_dtype_config, + qnnpack_weight_only_quint4x2_dtype_config, + ] + return BackendConfig("qnnpack") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5f761508bbb9e95392bfe07d494f7fba61303d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/tensorrt.py @@ -0,0 +1,81 @@ +import torch +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType +) +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_linear_configs, + _get_conv_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) + +__all__ = [ + "get_tensorrt_backend_config", + "get_tensorrt_backend_config_dict", +] + +def get_tensorrt_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for the TensorRT backend. + NOTE: Current api will change in the future, it's just to unblock experimentation for + new backends, please don't use it right now. 
+ TODO: add a README when it's more stable + """ + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + ) + + addmm_config = BackendPatternConfig(torch.addmm) \ + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_op_qint8_dtype_config) \ + ._set_input_type_to_index({ + "bias": 0, + "input": 1, + "weight": 2, + }) + cat_config = BackendPatternConfig(torch.cat) \ + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .add_dtype_config(non_weighted_op_qint8_dtype_config) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + tensor_info_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return BackendConfig("tensorrt") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_config(addmm_config) \ + .set_backend_pattern_config(cat_config) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) + +def get_tensorrt_backend_config_dict(): + """ + Return the `BackendConfig` for the TensorRT backend in dictionary form. 
+ """ + return get_tensorrt_backend_config().to_dict() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2e738227407907ef942786937eb082f41d9e02ef --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/utils.py @@ -0,0 +1,279 @@ +from typing import Dict, Any, List, Callable, Union, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .backend_config import ( + BackendConfig, + BackendPatternConfig, + DTypeConfig, +) +from ..utils import Pattern +from ..fuser_method_mappings import ( + _reverse2, + _reverse3, +) + +__all__ = [ + "get_pattern_to_dtype_configs", + "get_qat_module_classes", + "get_fused_module_classes", + "get_pattern_to_input_type_to_index", + "get_root_module_to_quantized_reference_module", + "get_fuser_method_mapping", + "get_module_to_qat_module", + "get_fusion_pattern_to_root_node_getter", + "get_fusion_pattern_to_extra_inputs_getter", + "remove_boolean_dispatch_from_name", + "pattern_to_human_readable", + "entry_to_pretty_str", +] + +def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]: + pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_dtype_configs[pattern] = config.dtype_configs + return pattern_to_dtype_configs + +def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]: + qat_module_classes = [] + for config in backend_config.configs: + if config.qat_module is not None: + qat_module_classes.append(config.qat_module) + return tuple(set(qat_module_classes)) + +def get_fused_module_classes(backend_config: BackendConfig) -> 
Tuple[type, ...]: + fused_module_classes = [] + for config in backend_config.configs: + if config.fused_module is not None: + fused_module_classes.append(config.fused_module) + return tuple(set(fused_module_classes)) + +def get_pattern_to_input_type_to_index(backend_config: BackendConfig) -> Dict[Pattern, Dict[str, int]]: + pattern_to_input_type_to_index: Dict[Pattern, Dict[str, int]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + pattern_to_input_type_to_index[pattern] = config._input_type_to_index + return pattern_to_input_type_to_index + +def get_root_module_to_quantized_reference_module( + backend_config: BackendConfig) -> Dict[Type[torch.nn.Module], Type[torch.nn.Module]]: + mapping: Dict[Type[torch.nn.Module], Type[torch.nn.Module]] = {} + for config in backend_config.configs: + if config.root_module is not None and config.reference_quantized_module is not None: + mapping[config.root_module] = config.reference_quantized_module + return mapping + +def get_fuser_method_mapping(backend_config: BackendConfig) -> Dict[Pattern, Union[nn.Sequential, Callable]]: + fuser_method_mapping : Dict[Pattern, Union[nn.Sequential, Callable]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.fuser_method is not None: + # Note: both the fuser method and the pattern are specified in forward order in the + # BackendConfig, but the internal pattern matching code uses the reversed nested tuple + # format, so we need to convert both to the internal format + fuser_method = _get_fuser_method_in_reversed_nested_tuple_format(config) + fuser_method_mapping[pattern] = fuser_method + return fuser_method_mapping + +def get_module_to_qat_module(backend_config: BackendConfig) -> Dict[Pattern, Type[torch.nn.Module]]: + module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config.qat_module is not None: + 
module_to_qat_module[pattern] = config.qat_module + return module_to_qat_module + +def get_fusion_pattern_to_root_node_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]: + """ Get a map from fusion pattern to a function that returns the root node + from the fusion pattern, e.g. the most common one is: + def get_root_node(node_pattern): + while not isinstance(node_pattern[-1], Node): + node_pattern = node_pattern[-1] + return node_pattern[-1] + This can work for all patterns whose root node is the "last node" in the pattern, + e.g. (torch.add, MatchAllNode, (torch.ReLU, torch.Conv2d)) + """ + root_node_getter_mapping: Dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._root_node_getter is not None: + root_node_getter_mapping[pattern] = config._root_node_getter + return root_node_getter_mapping + +def get_fusion_pattern_to_extra_inputs_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]: + """ Get a map from fusion pattern to a function that returns extra input nodes + from the fusion pattern, in the order required by the root node. This is optional, + if not specified, we will not copy over any extra inputs for the root node. 
+ Example: + # Let's say we have the pattern (torch.add, MatchAllNode, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) + # and root node is torch.nn.Conv2d, and the node in MatchAllNode would be an extra + # argument to the fused module, we can unpack the pattern and return the node at + # MatchAllNode here + # we can implement extra_inputs_getter as follows: + def extra_inputs_getter(pattern) -> List[Any]: + add, extra_input, conv_pattern = pattern + return [extra_input] + """ + extra_inputs_getter_mapping: Dict[Pattern, Callable] = {} + for pattern, config in backend_config._pattern_complex_format_to_config.items(): + if config._extra_inputs_getter is not None: + extra_inputs_getter_mapping[pattern] = config._extra_inputs_getter + return extra_inputs_getter_mapping + +def remove_boolean_dispatch_from_name(p) -> Any: + """ + Some ops have a default string representation such as + '.fn at 0x7ff1106bf280>', + this function replaces them with the hardcoded function names. + """ + if p is F.fractional_max_pool2d: + return "torch.nn.functional.fractional_max_pool2d" + elif p is F.fractional_max_pool3d: + return "torch.nn.functional.fractional_max_pool3d" + elif p is F.max_pool1d: + return "torch.nn.functional.max_pool1d" + elif p is F.max_pool2d: + return "torch.nn.functional.max_pool2d" + elif p is F.max_pool3d: + return "torch.nn.functional.max_pool3d" + elif p is F.adaptive_max_pool1d: + return "torch.nn.functional.adaptive_max_pool1d" + elif p is F.adaptive_max_pool2d: + return "torch.nn.functional.adaptive_max_pool2d" + elif p is F.adaptive_max_pool3d: + return "torch.nn.functional.adaptive_max_pool3d" + assert "boolean_dispatch" not in str(p), \ + f"{p} does not have a human readable representation in " + \ + "quantization documentation" + return p + +def pattern_to_human_readable(p) -> Any: + if isinstance(p, tuple): + # nested patterns, recurse + return tuple(pattern_to_human_readable(inner_p) for inner_p in p) + elif isinstance(p, str): + # method names are already 
human readable + return p + else: + p = remove_boolean_dispatch_from_name(p) + return p + +# TODO(future PR): move backend_config_dict to use dataclass and move this logic to +# the corresponding __str__ function +def entry_to_pretty_str(entry) -> str: + """ + Given a backend_config_dict entry, returns a string with the human readable + representation of it. + """ + s = "{\n" + + # always output the pattern first + if "pattern" in entry: + pattern_str = pattern_to_human_readable(entry["pattern"]) + + s += f" 'pattern': {pattern_str},\n" + + # custom output for dtype_configs to make it look nice + if "dtype_configs" in entry: + s += " 'dtype_configs': [\n" + for dtype_config in entry["dtype_configs"]: + s += " {\n" + for k, v in dtype_config.items(): + s += f" '{k}': {v},\n" + s += " },\n" + s += " ],\n" + + # custom output for num_tensor_args_to_observation_type to make it look nice + if "num_tensor_args_to_observation_type" in entry: + s += " 'num_tensor_args_to_observation_type': {\n" + for k, v in entry["num_tensor_args_to_observation_type"].items(): + s += f" {k}: {v},\n" + s += " },\n" + + # output all the other fields + custom_handled_fields = [ + "pattern", + "dtype_configs", + "num_tensor_args_to_observation_type", + ] + for field_name in entry: + if field_name in custom_handled_fields: + continue + s += f" '{field_name}': {entry[field_name]},\n" + + s += "}" + return s + +def _get_pattern_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Pattern: + """ + Return the pattern specified in the given config in the reversed nested tuple format + used internally in the quantization pattern matching code. + + If the pattern is not a tuple, or the pattern is already specified in the reversed + nested tuple format, return the pattern as is. Otherwise: + + For 2-tuples (a, b), return (b, a). + For 3-tuples (a, b, c), return (c, (b, a)). 
+ + For example: + * Given nn.Linear, return nn.Linear + * Given (nn.Linear, nn.ReLU), return (nn.ReLU, nn.Linear) + * Given (nn.Conv2d, nn.BatchNorm2d, nn.ReLU), return + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)) + + For context, the reason why this is needed is the user-facing BackendConfig + API accepts the flat 2-or-3-tuple format in forward order. While this simple + format handles the vast majority of use cases, it does not handle the more + complex ones, and so the internal pattern matching code for quantization uses + the following, more general reversed nested tuple format instead: + + operator = module_type | functional | torch op | native op | MatchAllNode + Pattern = (operator, Pattern, Pattern, ...) | operator + + In the future, we expect to replace the above complex format with the one used + by the subgraph rewriter in torch.fx, so we don't have to maintain our own + complex pattern matching code. Then we won't need this helper function anymore. + """ + if config._pattern_complex_format is not None: + return config._pattern_complex_format + if config.pattern is None: + raise ValueError("Either 'pattern' or 'pattern_complex_format' must be specified") + if not isinstance(config.pattern, tuple): + return config.pattern + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + (a, b) = config.pattern + return (b, a) + elif len(config.pattern) == 3: + (a, b, c) = config.pattern + return (c, (b, a)) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) + +def _get_fuser_method_in_reversed_nested_tuple_format(config: BackendPatternConfig) -> Callable: + """ + Return the fuser method specified in the given config in the reversed nested + tuple format used internally in the quantization pattern matching code. + + If pattern is specified in the reversed nested tuple format, we assume the + fuser method is also specified in this format and simply return it as is. 
+ Otherwise, we convert the fuser method as follows: + + * Given f(is_qat, conv, relu), return f'(is_qat, relu, conv) + * Given f(is_qat, conv, bn, relu), return f'(is_qat, relu, bn_conv), + where bn_conv is a 2-tuple (bn, conv) + + The first argument of a fuser method is always `is_qat` and is not affected + in the conversion. We currently only support functions with 3 or 4 arguments. + """ + assert config.fuser_method is not None + if config._pattern_complex_format is not None: + return config.fuser_method + if not isinstance(config.pattern, tuple): + raise ValueError("Expected pattern to be a tuple, got: ", config.pattern) + + # Pattern is specified in the simple tuple format, need to convert + if len(config.pattern) == 2: + return _reverse2(config.fuser_method) + elif len(config.pattern) == 3: + return _reverse3(config.fuser_method) + else: + raise ValueError("Expected a tuple with 2 or 3 elements, got: ", config.pattern) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f165958f2791d3e6e2f63eceecdcd9e6f6d50c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/backend_config/x86.py @@ -0,0 +1,113 @@ +import torch +from ._common_operator_config_utils import ( + _get_binary_op_configs, + _get_bn_configs, + _get_cat_config, + _get_conv_configs, + _get_default_op_configs, + _get_embedding_op_configs, + _get_fixed_qparams_op_configs, + _get_linear_configs, + _get_rnn_op_configs, + _get_share_qparams_op_configs, + _get_tensor_info_op_configs, +) +from .backend_config import BackendConfig, DTypeConfig + +__all__ = [ + "get_x86_backend_config", +] + +# =================== +# | DTYPE CONFIGS | +# =================== + +# X86 aligns with FBGEMM for now + 
+x86_weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +x86_default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +x86_default_op_fp16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float16, + weight_dtype=torch.float16, + bias_dtype=torch.float16, +) + +x86_default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + is_dynamic=True, +) + +x86_weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +x86_weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + + +# ===================== +# | BACKEND CONFIGS | +# ===================== + +def get_x86_backend_config() -> BackendConfig: + """ + Return the `BackendConfig` for PyTorch's native x86 backend. 
+ """ + conv_dtype_configs = [x86_weighted_op_int8_dtype_config] + linear_dtype_configs = [ + x86_weighted_op_int8_dtype_config, + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + default_op_dtype_configs = [x86_default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [x86_weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [x86_default_op_quint8_dtype_config] + tensor_info_op_dtype_configs = [x86_default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + x86_default_dynamic_int8_dtype_config, + x86_default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + x86_weight_only_quint8_dtype_config, + x86_weight_only_quint4x2_dtype_config, + ] + return BackendConfig("x86") \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py new file mode 100644 index 
0000000000000000000000000000000000000000..e0a169956fd39100fcff0c89b082f171192de838 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fake_quantize.py @@ -0,0 +1,546 @@ +"""Implements modules used to perform fake quantization.""" + +import torch +from torch.nn import Module +from torch.ao.quantization.observer import ( + MovingAverageMinMaxObserver, + HistogramObserver, + MovingAveragePerChannelMinMaxObserver, + FixedQParamsObserver, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + _with_args, +) +import re +from abc import ABC, abstractmethod +from typing import Any, Tuple + +__all__ = [ + "FakeQuantizeBase", + "FakeQuantize", + "FixedQParamsFakeQuantize", + "FusedMovingAvgObsFakeQuantize", + "disable_fake_quant", + "disable_observer", + "enable_fake_quant", + "enable_observer", + "default_fake_quant", + "default_weight_fake_quant", + "default_dynamic_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_symmetric_fixed_qparams_fake_quant", + "default_affine_fixed_qparams_fake_quant", + "default_per_channel_weight_fake_quant", + "default_embedding_fake_quant", + "default_embedding_fake_quant_4bit", + "default_histogram_fake_quant", + "default_fused_act_fake_quant", + "default_fused_wt_fake_quant", + "default_fused_per_channel_wt_fake_quant", + "fused_wt_fake_quant_range_neg_127_to_127", + "fused_per_channel_wt_fake_quant_range_neg_127_to_127", +] + +def _is_per_channel(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams] + +def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + +def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + +def 
_is_float_qparams(qscheme: 'torch.qscheme') -> bool:
    # Float qparams are only used by the per-channel affine-float scheme.
    return qscheme in [torch.per_channel_affine_float_qparams, ]

class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Base fake quantize module
    Any fake quantize implementation should derive from this class.

    Concrete fake quantize module should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.

    """

    # 1-element uint8 buffers used as on/off switches (see __init__ for why
    # they are uint8 buffers rather than plain bools).
    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        """Set fake_quant_enabled and observer_enabled."""
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        # Write 0/1 into the uint8 buffer registered in __init__.
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        # Write 0/1 into the uint8 buffer registered in __init__.
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        """Return a constructor-like partial of ``cls`` with ``kwargs`` pre-bound."""
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public v private requirements
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor

class FakeQuantize(FakeQuantizeBase):
r"""Simulate the quantize and dequantize operations in training time. + + The output of this module is given by:: + + x_out = ( + clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point + ) * scale + + * :attr:`is_dynamic` indicates whether the fake quantie is a placeholder for dynamic quantization + operators (choose_qparams -> q -> dq) or static quantization operators (q -> dq) + + * :attr:`scale` defines the scale factor used for quantization. + + * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to + + * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that + statistics can still be updated. + + * :attr:`observer_enabled` controls statistics collection on tensors + + * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, + allowable values are torch.qint8 and torch.quint8. + + Args: + + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. + observer_kwargs (optional): Arguments for the observer module + + Attributes: + activation_post_process (Module): User provided module that collects statistics on the input tensor and + provides a method to calculate scale and zero-point. 
+ + """ + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, is_dynamic=False, **observer_kwargs): + super().__init__() + # Populate quant_min/quant_max to observer_kwargs if valid + if quant_min is not None and quant_max is not None: + assert quant_min <= quant_max, \ + 'quant_min must be less than or equal to quant_max' + dtype = observer_kwargs.get("dtype", torch.quint8) + if hasattr(observer, "p"): + # In case observer is _PartialWrapper, dtype can be stored in + # observer.p.keywords["dtype"] + dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get( + "dtype", dtype + ) + assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound' + assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound' + observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) + observer_kwargs["is_dynamic"] = is_dynamic + self.activation_post_process = observer(**observer_kwargs) + # TODO: keeping self.quant_min/max for BC; remove after a couple releases + # Users should use self.activation_post_process.quant_min + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + self.is_dynamic = self.activation_post_process.is_dynamic + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + assert _is_per_channel(self.qscheme) or \ + _is_per_tensor(self.qscheme), \ + 'Only per channel and per tensor quantization are supported in fake quantize' + \ + ' got 
qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        """Delegate scale/zero_point computation to the attached observer."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        # Observation and fake-quantization are independently switchable via
        # the uint8 buffers defined on FakeQuantizeBase.
        if self.observer_enabled[0] == 1:
            # Observe a detached copy so statistics collection does not
            # participate in autograd.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            if self.scale.shape != _scale.shape:
                # Resize the buffers to match the observer's qparams shape
                # (e.g. one value per channel for per-channel schemes).
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.activation_post_process.quant_min, self.activation_post_process.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        """Summarize enablement flags, qparams, and scheme for ``repr``."""
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.activation_post_process.quant_min, self.activation_post_process.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Resize scale/zero_point buffers so the parent can copy values in."""
        # Removing this function throws an error that the size of the loaded tensor does not match the original size
        # i.e., These buffers start out with numel 0 and become numel 1 once they have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript module we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                # Buffer expected but absent from the checkpoint: report it
                # the same way Module._load_from_state_dict would.
                missing_keys.append(key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize in training time.

    Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
+ """ + + # TODO: rename observer to observer_ctr + def __init__(self, observer): + super().__init__(observer=observer) + assert type(self.activation_post_process) == FixedQParamsObserver, \ + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + self._observer_ctr = observer + self.scale = self.activation_post_process.scale + self.zero_point = self.activation_post_process.zero_point + assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \ + ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): + return self.scale, self.zero_point + + @torch.jit.export + def extra_repr(self): + """Define a string representation of the object's attributes.""" + return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \ + 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.scale, self.zero_point, self.dtype, + self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme) + + +class FusedMovingAvgObsFakeQuantize(FakeQuantize): + r"""Define a fused module to observe the tensor. + + Fused module that is used to observe the input tensor (compute min/max), compute + scale/zero_point and fake_quantize the tensor. + This module uses calculation similar MovingAverageMinMaxObserver for the inputs, + to compute the min/max values in order to compute the scale/zero_point. + The qscheme input in the observer is used to differentiate between symmetric/affine + quantization scheme. + + The output of this module is given by + x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale + + Similar to :class:`~torch.ao.quantization.FakeQuantize`, and accepts the same attributes as the + base class. 
+ + """ + + def __init__( + self, + observer: Any = MovingAverageMinMaxObserver, + quant_min: int = 0, + quant_max: int = 255, + **observer_kwargs: Any + ) -> None: + super().__init__(observer, quant_min, quant_max, **observer_kwargs) + assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)), \ + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) + self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) + self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme) + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.activation_post_process.calculate_qparams() + + @torch.jit.export + def extra_repr(self) -> str: + return ( + "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, " + "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format( + self.fake_quant_enabled, + self.observer_enabled, + self.scale, + self.zero_point, + self.dtype, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.qscheme, + self.activation_post_process.reduce_range, + ) + ) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + return torch.fused_moving_avg_obs_fake_quant( + X, + self.observer_enabled, + self.fake_quant_enabled, + self.activation_post_process.min_val, + self.activation_post_process.max_val, + self.scale, + self.zero_point, + self.activation_post_process.averaging_constant, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, + self.ch_axis, + self.is_per_channel, + self.is_symmetric_quant, + ) + +default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, + dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) +""" +Default fake_quant for 
activations. +""" + +default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, + dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False) +""" +Default fake_quant for weights. +Observer is memoryless since averaging_constant is 1. +""" + +default_dynamic_fake_quant = FakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, is_dynamic=True, + dtype=torch.quint8, averaging_constant=1) +""" +Default dynamic fake_quant for activations. +""" + +default_fixed_qparams_range_neg1to1_fake_quant = ( + FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer) +) +default_fixed_qparams_range_0to1_fake_quant = ( + FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer) +) +# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases +default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant +default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant + +default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + ch_axis=0) +""" +Default fake_quant for per-channel weights. +Observer is memoryless since averaging_constant is 1. +""" +default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + dtype=torch.quint8, + quant_min=0, + quant_max=255, + ch_axis=0, + averaging_constant=1) +""" +Default fake_quant for embeddings. +Observer is memoryless since averaging_constant is 1. 
+""" + +default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, + qscheme=torch.per_channel_affine_float_qparams, + ch_axis=0, + dtype=torch.quint4x2, + averaging_constant=1) + +default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True) +""" +Fake_quant for activations using a histogram.. +""" + + +default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8,) + +""" +Fused version of `default_fake_quant`, with improved performance. +""" + + +default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric) +""" +Fused version of `default_weight_fake_quant`, with improved performance. +""" + +default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-128, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric) +""" +Fused version of `default_per_channel_weight_fake_quant`, with improved performance. +""" + +fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_tensor_symmetric, + eps=2 ** -12) +""" +Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. 
+""" + +fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \ + FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, + quant_min=-127, + quant_max=127, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + eps=2 ** -12) + +""" +Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128. +""" + + +def _is_fake_quant_script_module(mod): + """Return true if given mod is an instance of FakeQuantize script module.""" + if isinstance(mod, torch.jit.RecursiveScriptModule): + # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize' + suffix = mod._c.qualified_name.split('.', 1)[1] + name = re.sub(r'\.___torch_mangle_\d+', '', suffix) + return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \ + name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize' + return False + +def disable_fake_quant(mod): + """Disable fake quantization for the module. + + Disable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_fake_quant() + +def enable_fake_quant(mod): + """Enable fake quantization for the module. + + Enable fake quantization for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_fake_quant) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_fake_quant() + +def disable_observer(mod): + """Disable observation for this module. + + Disable observation for this module, if applicable. 
Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.disable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.disable_observer() + +def enable_observer(mod): + """Enable observation for this module. + + Enable observation for this module, if applicable. Example usage:: + + # model is any PyTorch model + model.apply(torch.ao.quantization.enable_observer) + + """ + if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod): + mod.enable_observer() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea5ff406d799ffa1ed68158d09639d089957949 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/_model_report/model_report.py @@ -0,0 +1,606 @@ +from typing import Any, Dict, Set, Tuple, Callable +from collections import OrderedDict +import torch +from torch.ao.quantization.fx._model_report.detector import ( + DetectorBase, + DETECTOR_OBS_ARGS_KEY, + DETECTOR_OBS_TO_INSERT_KEY, + DETECTOR_IS_POST_OBS_KEY, + DETECTOR_TARGET_NODE_KEY, + DetectorQConfigInfo +) +from torch.ao.quantization.fx._model_report.model_report_visualizer import ModelReportVisualizer +from torch.ao.quantization.fx.graph_module import GraphModule +from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.qconfig_mapping import QConfigMapping, QConfig +from torch.ao.quantization.fx._equalize import EqualizationQConfig + +class ModelReport: + r""" + The ModelReport class aims to provide users an easy way to diagnose issues that they run into + with their models. 
The class works with all traceable GraphModules to help diagnose issues,
    though the requirements on the type of model depend on the specific report the user
    is trying to generate. With respect to the reports, the ModelReport class is initialized with
    a set of Detector classes, each of which generates reports on quantization configuration
    issues a user might have.

    Currently supports generating reports on:
    - Suggestions for per-channel vs. per-tensor quantization (nn.Module)
    - Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
    - Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
    - Suggestions for outlier detection for all layers (Graph Modules)

    The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
    where needed for each detector to gather the information it needs, and then after calibration, the ModelReport
    class compiles the report generated by each Detector class into a single report to return to the user. It also
    has the capability to remove all the observers it inserted as well.

    * :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule

    * :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
        Make sure that these are all unique types of detectors [do not have more than 1 of the same class]

    * :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
+ This set is generated by calling the get_detector_name() of each detector + + * :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest + The purpose of this is to keep track of what observers were inserted for each detector, so that they + can be removed at the end if desired + + * :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not + This is to ensure we only insert observers once with the ModelReport instance + + * :attr:`_removed_observers` A boolean to track if we have removed observers already + The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport + instance. This also allows the functionality where we can generate the report multiple times + as long as we haven't removed the observers yet. + + Note: + This class was initially designed to work with the Fx Graph Mode workflow in mind. However, + full functionality is available as long as there is a traceable GraphModule that is being used. + One method to get a traceable GraphModule without going through the Fx workflow is to use + the QuantizationTracer class. + + General Flow for Fx workflow: + 1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model + 2.) Prepare your model with prepare_fx + 3.) Call model_report.prepare_detailed_calibration to add relevant observers + 4.) Callibrate your model with data + 5.) Call model_report.generate_report on your model to generate report and optionally remove added observers + Optional + 6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance + 7.) To help in parsing report information and debugging, view report info as a: + - Table + - Histogram + - Line plot + 8.) 
Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions + + Example (with QuantizationTracer): + >>> # xdoctest: +SKIP + >>> # get the necessary qconfig + >>> config = PrepareCustomConfig() + >>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False) + + >>> # initialize our model and get GraphModule + >>> model = SomeModel() + >>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) + >>> graph_module = GraphModule(model, tracer.trace(model)) + + >>> # get our set of detectors and ModelReport instance + >>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)]) + >>> tracer_reporter = ModelReport(graph_module, tracer_detector_set) + + >>> # now we insert the observers and callibrate the model + >>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration() + >>> for i in range(num_callibration_batches): + >>> example_input = get_callibration_input() + >>> tracer_model_with_observers(example_input) + + >>> # finally we generate the reports and optionally remove the observers we inserted + >>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True) + + >>> # Optional: we can generate the qconfig mapping based on the suggestions + >>> qconfigs = model_report.generate_qconfig_mapping() + + >>> # Optional: we can generate the equalization mapping based on the suggestions + >>> qconfigs = model_report.generate_equalization_mapping() + + >>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired + >>> model_report_visualizer = tracer_reporter.generate_visualizer() + + """ + + def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]): + + if len(desired_report_detectors) == 0: + raise ValueError("Should include at least 1 desired report") + + # keep track of the model we wish to generate report for + self._model: 
GraphModule = model + + # keep the reports private so they can't be modified + self._desired_report_detectors = desired_report_detectors + self._desired_detector_names = {detector.get_detector_name() for detector in desired_report_detectors} + + # keep a mapping of desired reports to observers of interest + # this is to get the readings, and to remove them, can create a large set + # this set can then be used to traverse the graph and remove added observers + self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {} + + # initialize each report to have empty set of observers of interest + for desired_report in self._desired_detector_names: + self._detector_name_to_observer_fqns[desired_report] = set() + + # flags to ensure that we can only prepare and remove observers once + self._prepared_flag = False + self._removed_observers = False + + # store the reports that we generated for visualization purposes + # initially empty since no reports generated + self._generated_reports: Dict[str, Dict] = {} + + def get_desired_reports_names(self) -> Set[str]: + """ Returns a copy of the desired reports for viewing """ + return self._desired_detector_names.copy() + + def get_observers_of_interest(self) -> Dict[str, Set[str]]: + """ Returns a copy of the observers of interest for viewing """ + return self._detector_name_to_observer_fqns.copy() + + def prepare_detailed_calibration(self) -> GraphModule: + r""" + Takes in a graph model and inserts the following observers: + - ModelReportObserver + + Each observer is inserted based on the desired_reports into the relevant locations + + Right now, each report in self._desired_detector_names has independent insertions + However, if a module already has a Observer of the same type, the insertion will not occur + This is because all of the same type of Observer collect same information, so redundant + + Returns the same GraphModule with the observers inserted + """ + + # if already prepared once, cannot prepare again + if 
self._prepared_flag: + raise ValueError("Already ran preparing detailed callibration. Run the report generation next after callibration.") + + # loop through each detector, find where placements should be, and keep track + insert_observers_fqns: Dict[str, Any] = {} + + for detector in self._desired_report_detectors: + # determine observer points for each detector + obs_fqn_to_info = detector.determine_observer_insert_points(self._model) + # map each insert point to the observer to use + insert_observers_fqns.update(obs_fqn_to_info) + # update the set of observers this report cares about + self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys()) + + # now insert all the observers at their desired locations + for observer_fqn in insert_observers_fqns: + target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY] + insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY] + insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY] + observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY] + self._insert_observer_around_module( + observer_fqn, target_node, insert_obs, observer_args, insert_post + ) + + self._prepared_flag = True + + return self._model + + def _insert_observer_around_module( + self, + obs_fqn: str, + target_node: torch.fx.node.Node, + obs_to_insert: ObserverBase, + observer_args: Tuple, + insert_post: bool + ): + r""" + Helper function that inserts the observer into both the graph structure and the module of the model + + Args + node_fqn (str): The fully qualified name of the observer we want to insert + target_node (torch.fx.node.Node): The node in model we are inserting observers around + obs_to_insert (ObserverBase): The observer we are inserting around target_node + observer_args (Tuple): The arguments we want to pass into the observer + insert_post (bool): whether this is meant to be a post observer for this node + """ + # if we are 
inserting post, then our target node is the next node + if insert_post: + target_node = target_node.next + + with self._model.graph.inserting_before(target_node): + self._model.add_submodule(obs_fqn, obs_to_insert) + self._model.graph.create_node(op="call_module", target=obs_fqn, args=observer_args) + + # recompile model after inserts are made + self._model.recompile() + + def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: + r""" + Takes in a node fqn and returns the node based on the fqn + + Args + node_fqn (str): The fully qualified name of the node we want to find in model + + Returns the Node object of the given node_fqn otherwise returns None + """ + node_to_return = None + for node in self._model.graph.nodes: + # if the target matches the fqn, it's the node we are looking for + if node.target == node_fqn: + node_to_return = node + break + + if node_to_return is None: + raise ValueError("The node_fqn is was not found within the module.") + + # assert for MyPy + assert isinstance(node_to_return, torch.fx.node.Node) + + return node_to_return + + def generate_model_report( + self, remove_inserted_observers: bool + ) -> Dict[str, Tuple[str, Dict]]: + r""" + Generates all the requested reports. 
+ + Note: + You should have callibrated the model with relevant data before calling this + + The reports generated are specified by the desired_reports specified in desired_reports + + Can optionally remove all the observers inserted by the ModelReport instance + + Args: + remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance + + Returns a mapping of each desired report name to a tuple with: + The textual summary of that report information + A dictionary containing relevant statistics or information for that report + + Note: + Throws exception if we try to generate report on model we already removed observers from + Throws exception if we try to generate report without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate report yet + if not self._prepared_flag: + raise Exception("Cannot generate report without preparing model for callibration") + + # if we already removed the observers, we cannot generate report + if self._removed_observers: + raise Exception("Cannot generate report on model you already removed observers from") + + # keep track of all the reports of interest and their outputs + reports_of_interest = {} + + for detector in self._desired_report_detectors: + # generate the individual report for the detector + report_output = detector.generate_detector_report(self._model) + reports_of_interest[detector.get_detector_name()] = report_output + + # if user wishes to remove inserted observers, go ahead and remove + if remove_inserted_observers: + self._removed_observers = True + # get the set of all Observers inserted by this instance of ModelReport + all_observers_of_interest: Set[str] = set() + for desired_report in self._detector_name_to_observer_fqns: + observers_of_interest = self._detector_name_to_observer_fqns[desired_report] + all_observers_of_interest.update(observers_of_interest) + + # go through all_observers_of_interest and remove them from 
the graph and model + for observer_fqn in all_observers_of_interest: + # remove the observer from the model + self._model.delete_submodule(observer_fqn) + + # remove the observer from the graph structure + node_obj = self._get_node_from_fqn(observer_fqn) + + if node_obj: + self._model.graph.erase_node(node_obj) + else: + raise ValueError("Node no longer exists in GraphModule structure") + + # remember to recompile the model + self._model.recompile() + + # save the generated reports for visualization purposes + saved_reports: Dict[str, Dict] = { + report_name : report_tuple[1] for report_name, report_tuple in reports_of_interest.items() + } + + self._generated_reports = saved_reports + + # return the reports of interest + return reports_of_interest + + def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool: + r""" + Takes in two dictionaries and ensures that any common keys between the two have the same + values. + + Args: + info_dict_a (Dict): First dictionary we wish to compare + info_dict_b (Dict): Second dictionary we wish to compare + + Returns True if all shared keys have same values, false otherwise + """ + # get the set of keys for both + dict_a_keys: Set = set(info_dict_a.keys()) + dict_b_keys: Set = set(info_dict_b.keys()) + + # get the insersection keys and check if same value for both dicts + intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys) + + for key in intersecting_keys: + dict_a_val = info_dict_a[key] + dict_b_val = info_dict_b[key] + + # if it's a tensor we have to handle separately + if type(dict_a_val) == torch.Tensor: + # if dict_b_val not tensor, automatically false + if type(dict_b_val) != torch.Tensor or sum(dict_a_val != dict_b_val) != 0: + return False + else: + # for non-tensor vals + if dict_a_val != dict_b_val: + return False + + # if no non matching shared keys found, return true + return True + + def _reformat_reports_for_visualizer(self) -> OrderedDict: + r""" + Takes the generated reports and 
reformats them into the format that is desired by the + ModelReportVisualizer + + Returns an OrderedDict mapping module_fqns to their features + """ + # we want to reorder and reformat the information so it is ordered in terms of order + # found in the model + + # first create new dict with all modules as keys and features under respective module + module_fqns_to_features: Dict[str, Dict] = {} + + for report_name in self._generated_reports: + # get mod -> feature dict and go through + module_info = self._generated_reports[report_name] + + for module_fqn in module_info: + # check if already in our accumulation dict + if module_fqn in module_fqns_to_features: + # we merge all the features together + new_info: Dict = module_info[module_fqn] + present_info: Dict = module_fqns_to_features[module_fqn] + + # merge them together into the new unioned dict + # same features keys -> same info, so okay if override + + # do safety check to make sure shared keys have same info + if self._is_same_info_for_same_key(new_info, present_info): + module_fqns_to_features[module_fqn] = {**new_info, **present_info} + else: + error_str = "You have the same key with different values across detectors. " + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." 
+ raise ValueError(error_str) + else: + # we just set it + module_fqns_to_features[module_fqn] = module_info[module_fqn] + + # our ordered dict so that modules can be ordered in order of how they appear in model + features_by_module: OrderedDict[str, Dict] = OrderedDict() + + # we loop through modules in graph in order + for fqn, module in self._model.named_modules(): + # find that fqn in fqns_to_features + if fqn in module_fqns_to_features: + # add it to our ordered dict + features_by_module[fqn] = module_fqns_to_features[fqn] + + # return the ordered dict of info we created + return features_by_module + + def generate_visualizer(self) -> ModelReportVisualizer: + r""" + Generates a ModelReportVisualizer instance using the reports generated + by the generate_model_report() method. + + Returns the generated ModelReportVisualizer instance initialized + + Note: + Throws exception if attempt to get visualizers without generating report + """ + # check if user has generated reports at least once + if len(self._generated_reports) == 0: + raise Exception("Unable to generate visualizers without first generating reports") + + # get the ordered dict mapping modules to their full set of collected features / stats + module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer() + + # create and return ModelReportVisualizer instance + visualizer: ModelReportVisualizer = ModelReportVisualizer(module_fqns_to_features) + + return visualizer + + def _generate_qconfig_mapping_helper( + self, + detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo], + generation_function: Callable + ) -> QConfigMapping: + r""" + This helper takes in the compiled detector qconfig info that + has been compiled together and merges it into a QConfigMapping + """ + # keep track of the qconfigmapping + qconfig_mapping = QConfigMapping() + + # loop through each module / fqn and attempt to create QConfigMapping + for fqn, module in self._model.named_modules(): + # if we have a qconfig 
info for this module + if fqn in detector_qconfig_info_combined: + qconfig_info_compiled = detector_qconfig_info_combined[fqn] + + # now generate the qconfig and add it to the mapping + generated_qconfig = generation_function(qconfig_info_compiled, module) + + # add to our config + qconfig_mapping.set_module_name(fqn, generated_qconfig) + + # return compiled mapping + return qconfig_mapping + + def _update_detector_quantizaiton_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + combined_info.is_activation_dynamic = combined_info.is_activation_dynamic or new_info.is_activation_dynamic + combined_info.is_weight_per_channel = combined_info.is_weight_per_channel or new_info.is_weight_per_channel + + def _update_detector_equalization_qconfig_info(self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo): + r""" + Takes in the old and new information and updates the combined information. + + Args: + combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in + new_info (DetectorQConfigInfo): The DetectorQConfigInfo with the information we are trying to merge the new info + into it + """ + is_equalization_recommended = combined_info.is_equalization_recommended or new_info.is_equalization_recommended + combined_info.is_equalization_recommended = is_equalization_recommended + + def _generate_module_fqn_to_detector_info_mapping( + self, + update_qconfig_info_function: Callable + ) -> Dict[str, DetectorQConfigInfo]: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API. 
The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Args: + update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo + and updates the one that is being compiled + + Returns a Dict mapping module_fqns to DetectorQConfigInfo objects + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for callibration + """ + # if we haven't prepped model for callibration, then we shouldn't generate mapping yet + if not self._prepared_flag: + raise Exception("Cannot generate report without preparing model for callibration") + + # if we already removed the observers, we cannot mapping + if self._removed_observers: + raise Exception("Cannot generate report on model you already removed observers from") + + # keep track of qconfig info for each module across detectors + detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {} + + for detector in self._desired_report_detectors: + # get the info from the detector + detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(self._model) + + # we go through the modules + for module_fqn in detector_info: + # see if we already have info on it + if module_fqn in detector_qconfig_info_combined: + # we combine the current options with what is there + current_options = detector_qconfig_info_combined[module_fqn] + detector_options = detector_info[module_fqn] + + update_qconfig_info_function(current_options, detector_options) + else: + # we just use this for now + detector_qconfig_info_combined[module_fqn] = detector_info[module_fqn] + + return detector_qconfig_info_combined + + def generate_qconfig_mapping(self) -> QConfigMapping: + r""" + Generates a 
QConfigMapping based on the suggestions of the + ModelReport API. The generated mapping encompasses all the + different types of feedback from the different detectors + all into one place. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the quantization configuration + + Note: + Throws exception if we try to generate mapping on model we already removed observers from + Throws exception if we try to generate mapping without preparing for callibration + """ + # get the mapping info + detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_quantizaiton_qconfig_info + ) + + # we will do a bit of processing and remove fqns that don't have input weight recommended + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, + self._quantization_config_generator + ) + + # return the generated mapping + return mapping + + def _quantization_config_generator(self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module) -> QConfig: + r""" + Returns the quantization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_quantization_qconfig(module) + + def _equalization_config_generator( + self, + detector_qconfig_info: DetectorQConfigInfo, + module: torch.nn.Module + ) -> EqualizationQConfig: + r""" + We ignore the module argument here, and only focus on thedetector_qconfig_info + + Returns the equalization configuration generated by the DetectorQConfigInfo object + """ + return detector_qconfig_info.generate_equalization_qconfig() + + def generate_equalization_mapping(self) -> QConfigMapping: + r""" + Generates a QConfigMapping based on the suggestions of the + ModelReport API for equalization. 
The generated mapping encompasses all the + different types of feedback from the input-weight equalization detector. + + These configs are based on the suggestions provided by the ModelReport API + and can only be generated once the reports have been generated. + + Returns a QConfigMapping for the equalization configuration + """ + # get the mapping info + detector_qconfig_info_combined = self._generate_module_fqn_to_detector_info_mapping( + self._update_detector_equalization_qconfig_info + ) + + # now we generate the QConfig for each of the options + mapping: QConfigMapping = self._generate_qconfig_mapping_helper( + detector_qconfig_info_combined, + self._equalization_config_generator + ) + + # return the generated mapping + return mapping diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..aba802f01c6498c78e7c72e32df3c5717a1f738c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/fx/prepare.py @@ -0,0 +1,1880 @@ +import copy +import torch +import warnings +from torch.fx import ( + GraphModule, +) +from torch.fx.graph import ( + Graph, + Node, +) +from torch.fx.node import Argument + +from ..quantize import ( + propagate_qconfig_, +) +from ..observer import ( + _is_activation_post_process, + _PartialWrapper, +) +from ..qconfig import ( + _is_reuse_input_qconfig, + QConfigAny, +) +from ..qconfig_mapping import ( + QConfigMapping, +) +from .qconfig_mapping_utils import ( + _generate_node_name_to_qconfig, + _update_qconfig_for_fusion, + _get_flattened_qconfig_dict, + _update_qconfig_for_qat, +) + +from .quantize_handler import ( + _default_root_node_getter, + _get_pattern_to_quantize_handlers, + QuantizeHandler, +) + +from torch.ao.quantization import ( + ObserverBase, + 
FixedQParamsObserver, + FixedQParamsFakeQuantize, + _DerivedObserverOrFakeQuantize, +) + +from torch.ao.quantization.utils import ( + Pattern, + NodePattern, +) + +from ._equalize import ( + is_equalization_observer, + node_supports_equalization, +) + +from .pattern_utils import ( + _sorted_patterns_dict, +) + +from .match_utils import ( + _MatchResultWithQConfig, + _find_matches, +) + +from .utils import ( + _insert_dequant_stubs_for_custom_module_lstm_output, + _is_custom_module_lstm, + _maybe_get_custom_module_lstm_from_node_arg, + _qconfig_satisfies_dtype_config_constraints, + get_custom_module_class_keys, + all_node_args_have_no_tensors, + assert_and_get_unique_device, + get_non_observable_arg_indexes_and_types, + get_new_attr_name_with_prefix, + node_arg_is_weight, + node_arg_is_bias, + NON_QUANTIZABLE_WEIGHT_OPS, + ObservedGraphModuleAttrs, +) + +from torch.ao.quantization import ( + PlaceholderObserver +) +from torch.ao.quantization.quantize import ( + convert +) + +from ..utils import ( + _parent_name, + get_qconfig_dtypes, + get_swapped_custom_module_class, +) + +from ..backend_config.utils import ( + get_pattern_to_dtype_configs, + get_module_to_qat_module, + get_fusion_pattern_to_root_node_getter, +) +from ..backend_config import ( + BackendConfig, + DTypeConfig, + get_native_backend_config, +) +from .custom_config import ( + PrepareCustomConfig, + StandaloneModuleConfigEntry, +) +from torch.ao.quantization.quantizer import ( + EdgeOrNode, + QuantizationSpec, + QuantizationSpecBase, + FixedQParamsQuantizationSpec, + SharedQuantizationSpec, + DerivedQuantizationSpec, +) +from torch.ao.quantization import ObserverOrFakeQuantize + +from torch._subclasses import FakeTensor + +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from dataclasses import asdict + +__all__ = [ + "insert_observers_for_model", + "prepare", + "propagate_dtypes_for_known_nodes", +] + + +# list of dtypes to not add observers to +_DO_NOT_OBS_DTYPE_LIST = [int, 
float, torch.bool, None] +_OBS_DTYPE_LIST = [ + torch.quint8, + torch.qint8, + torch.qint32, + torch.float16, + torch.uint8, + torch.int8, + torch.int16, + torch.int32 +] + +_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float) + +# note: the following default target dtype info dicts are temporary, +# should be moved to the new programmable API class soon +_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation +} + +_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = { + "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation, + "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation +} + + +def _get_observer_kwargs(quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]): + kwargs_dict = asdict(quant_spec) + return copy.deepcopy(kwargs_dict) + +def _get_qspec_for_arg( + arg: Node, + input_qspec_map: Dict[Node, QuantizationSpecBase], + named_modules: Dict[str, torch.nn.Module] +) -> Optional[QuantizationSpecBase]: + while _is_activation_post_process_node(arg, named_modules): + arg = arg.args[0] # type: ignore[assignment] + return input_qspec_map.get(arg, None) + +def _create_obs_or_fq_from_qspec( + quantization_spec: Optional[QuantizationSpecBase], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + """ Create observer or fake quantize objects based on quantization spec + + Args: + quantization_spec: used to store parameters to create the observer or fake quantizer + obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant + instance, it may be reused for different edge/output depending on configuration + """ + if quantization_spec is None: + return None + if 
isinstance(quantization_spec, SharedQuantizationSpec): + edge_or_node = quantization_spec.edge_or_node + assert edge_or_node in obs_or_fq_map, \ + "please make sure only refer to edge or node that has " \ + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + return obs_or_fq_map[edge_or_node] + elif isinstance(quantization_spec, DerivedQuantizationSpec): + # can't use asdict, so not calling get_observer_kwargs here + kwargs = { + "dtype": quantization_spec.dtype, + "derive_qparams_fn": quantization_spec.derive_qparams_fn, + "quant_min": quantization_spec.quant_min, + "quant_max": quantization_spec.quant_max, + "qscheme": quantization_spec.qscheme, + "ch_axis": quantization_spec.ch_axis, + } + edge_or_nodes = quantization_spec.derived_from + obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes] + kwargs["obs_or_fqs"] = obs_or_fqs + return _DerivedObserverOrFakeQuantize.with_args(**kwargs)() + elif isinstance(quantization_spec, FixedQParamsQuantizationSpec): + kwargs = _get_observer_kwargs(quantization_spec) + observer_ctr = FixedQParamsObserver.with_args(**kwargs) + if is_qat: + return FixedQParamsFakeQuantize.with_args(observer=observer_ctr) + else: + return observer_ctr() + + assert isinstance(quantization_spec, QuantizationSpec) + observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr + kwargs = _get_observer_kwargs(quantization_spec) + kwargs.pop("observer_or_fake_quant_ctr") + # we will remove is_dynamic from QuantizationSpec because + # it seems that dynamic range quantization + obs_or_fq_class = observer_or_fake_quant_ctr + if isinstance(observer_or_fake_quant_ctr, _PartialWrapper): + obs_or_fq_class = observer_or_fake_quant_ctr.p.func # type: ignore[union-attr, assignment] + if "PerChannel" not in obs_or_fq_class.__name__: # type: ignore[operator, union-attr] + kwargs.pop("ch_axis") + return observer_or_fake_quant_ctr.with_args(**kwargs)() + +def _needs_obs_or_fq( + prev_output_dtype: Any, + 
prev_output_is_dynamic: bool, + cur_target_dtype: Any, + cur_target_is_dynamic: bool, + reuse_input_obs_or_fq: bool, + is_zeroth_arg: bool = False) -> bool: + """ + note: we will treat "not specified" as torch.float for now + utility function that checks if we should insert an observer or fake quant node + base on the requested dtype for the nodes from user + + is_zeroth_arg: we only dynamically quantize the first arg of the node right now + this should be removed when we enable configuring dynamic quantization + for a specific argument, this can be removed if we deprecate fx graph mode + quantization + + """ + + # need to insert placeholder observer for dynamic quantization so that it can + # be converted to choose_qparams -> q -> dq in convert step + if cur_target_is_dynamic: + assert cur_target_dtype in _OBS_DTYPE_LIST, \ + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST + return is_zeroth_arg + if reuse_input_obs_or_fq: + return False + # non dynamic quantization + if cur_target_dtype in _OBS_DTYPE_LIST: + return prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] and cur_target_dtype != prev_output_dtype + + # lots of error checking are skipped here for now + return False + +def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool: + return isinstance(node, torch.fx.Node) and node.op == "call_module" and \ + _is_activation_post_process(named_modules[str(node.target)]) + +def _get_dtype_and_is_dynamic(obs_or_fq: Optional[ObserverOrFakeQuantize]) -> Tuple[Optional[torch.dtype], bool]: + """ Given a constructor for observer or fake quant module, returns + a Tuple of dtype and is_dynamic + """ + # TODO: instead of instantiating the instance, we can use inspect to get the default args + if obs_or_fq is None: + return None, False + else: + return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False) # type: ignore[return-value] + +def 
_is_input_arg_dtype_supported_by_backend( + arg: Argument, + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, + backend_config: BackendConfig, +) -> bool: + """ Check if the configured qconfig for the argument + is supported by the backend or not + """ + if isinstance(arg, (list, tuple)): + return all(_is_input_arg_dtype_supported_by_backend( + a, node, qconfig, + dtype_config, backend_config) for a in arg) + if not isinstance(arg, Node): + return True + # TODO: support check for standalone module + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + if is_activation: + input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr") + input_act_obs_or_fq = input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None + qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq) + # TODO(future PR): remove the cast to bool below after figuring + # out why backend_config has is_dynamic set to None in some cases. 
+ return (dtype_config.input_dtype is None) or ( + dtype_config.input_dtype == qconfig_dtype and + bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and + _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints) + ) + elif is_weight: + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None) + weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None + qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq) + backend_config_weight_dtype = dtype_config.weight_dtype + dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False) + return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints) + else: # bias + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None) + bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None + qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq) + backend_config_bias_dtype = dtype_config.bias_dtype + return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype + +def _is_output_dtype_supported_by_backend( + node: Node, + qconfig: QConfigAny, + dtype_config: DTypeConfig, +) -> bool: + """ Check if the configured qconfig for the output + is supported by the backend or not + """ + # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well + backend_config_output_dtype = dtype_config.output_dtype + # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend + # from input activation check can be reused here + 
qconfig_output_dtype = None + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR) + output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) + # TODO: this is a hack because we can only specify one activation_obs_or_fq for + # qconfig (qconfig.activation), and we are only supporting dynamically quantized + # linear op which has fp32 output dtype, this should be removed if we generalize + # the structure of qconfig in the future + if qconfig_output_is_dynamic: + qconfig_output_dtype = torch.float32 + dtype_matches = qconfig_output_dtype == backend_config_output_dtype + qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints( + qconfig, dtype_config.output_dtype_with_constraints) + return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints) + +def _is_observer_in_same_graph( + node: Node, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat, +): + """ Check if observer in same graph + when the node output is not fp32 and input is 'placeholder' + the input is assumed to be quantized, so it is observed + in a different place rather than not observed. 
+ """ + node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules, obs_or_fq_map, is_qat) + if len(node.args) > 0 and isinstance(node.args[0], Node): + if node_output_dtype in [torch.quint8, torch.uint8] and node.args[0].op == 'placeholder': + return False + return True + +def _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern: Optional[Pattern], + matched_node_pattern: Optional[List[Node]], + qconfig: QConfigAny, + backend_config: BackendConfig, +) -> bool: + """ Check if the dtype configuration of a pattern is supported by + the backend or not, and whether the qconfig satisfies constraints + specified in the corresponding dtype config. + """ + if backend_config is None or pattern is None: + return True + assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 + pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) + dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + + root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter) + root_node = root_node_getter(matched_node_pattern) + input_node = root_node + output_node = matched_node_pattern[0] + for dtype_config in dtype_configs: + # check if arg dtype are supported + supported = True + for arg in list(input_node.args) + list(input_node.kwargs.values()): + supported = supported and _is_input_arg_dtype_supported_by_backend( + arg, input_node, qconfig, dtype_config, backend_config) + # check if output dtype is supported + supported = supported and _is_output_dtype_supported_by_backend( + output_node, qconfig, dtype_config) + if supported: + return True + return False + +def _get_standalone_module_configs( + node: Node, + named_modules: Dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig, + parent_qconfig: QConfigAny, + parent_backend_config: Optional[BackendConfig], +) -> 
Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]: + """ + Returns the standalone module QConfigMapping and PrepareCustomConfig + for `node`, assuming that the module pointed to by `node` is + a standalone modules. + """ + module_name = str(node.target) + module_type = type(named_modules[module_name]) # type: ignore[index] + # name config has precedence over type config + config_entry = StandaloneModuleConfigEntry(None, (), None, None) + config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry) + config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry) + # fallback to use parent module's qconfig if user didn't specify qconfig dict + qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig) + example_inputs = config_entry.example_inputs + prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig() + backend_config = config_entry.backend_config or parent_backend_config + return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config) + +def _qat_swap_modules( + root: torch.nn.Module, + module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None: + convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False) + +def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]): + if isinstance(matched_node_pattern, Node): + s.add(matched_node_pattern.name) + elif isinstance(matched_node_pattern, (list, tuple)): + for maybe_node in matched_node_pattern: + _add_matched_node_name_to_set(maybe_node, s) + +def _insert_obs_or_fq( + node: Node, + obs_or_fq: ObserverOrFakeQuantize, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, +) -> Node: + """ + Attaches `obs_or_fq` to `model`, and creates a node which calls + `obs_or_fq` on the output of `node`. 
+ + obs_or_fq: an instance of Observer or FakeQuantize module + """ + model_device = assert_and_get_unique_device(model) + if model_device: + obs_or_fq.to(model_device) + # add obs_or_fq module as attribute + if is_equalization_observer(obs_or_fq): + prefix = node.name + '_equalization_process_' + else: + prefix = 'activation_post_process_' + get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix) + obs_or_fq_name = get_new_obs_or_fq_name(model) + setattr(model, obs_or_fq_name, obs_or_fq) + named_modules[obs_or_fq_name] = obs_or_fq + with graph.inserting_after(node): + new_obs = graph.create_node( + 'call_module', obs_or_fq_name, (node,), {}) + return new_obs + +def _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern: NodePattern, + last_node: Node, + qconfig: QConfigAny, + qhandler: Optional[QuantizeHandler], + backend_config: BackendConfig, + named_modules: Dict[str, torch.nn.Module], + cache_for_no_tensor_check: Dict[Node, bool], + processed_nodes: Set[Node], +) -> None: + """ Sets the target_dtype_info for each node in matched_node_pattern + Note: processed_nodes is used to ensure we only process each node once + """ + if isinstance(matched_node_pattern, (list, tuple)): + for node_pattern in matched_node_pattern: + _set_target_dtype_info_for_matched_node_pattern( + node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes + ) + + # set target_dtype_info if matched_node_pattern is a Node + # other types of matched object, e.g. int, float literals, are ignored + elif isinstance(matched_node_pattern, Node): + # for pyre + assert isinstance(matched_node_pattern, Node) + node = matched_node_pattern + if node in processed_nodes: + return + processed_nodes.add(node) + + if qconfig is None: + return + # TODO: refactor the following code in terms of apply a qconfig to a pattern + # e.g. 
for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1) + # we set the input_obs_or_fq_ctr for the arguments of op1 to based on qconfig.input_act, + # and set output_obs_or_fq_ctr based on qconfig.output_act + # this also requires we extend the structure of QConfig to support more fine + # grained configurations + target_dtype_info: Dict[str, Any] = ( + _get_target_activation_dtype_for_node( + node, + qconfig, + qhandler, + named_modules, + backend_config, + cache_for_no_tensor_check, + ) + ) + node.meta["target_dtype_info"] = target_dtype_info + +def _get_target_activation_dtype_for_node( + node: Node, + qconfig: QConfigAny, + qhandler: Optional[QuantizeHandler], + named_modules: Dict[str, torch.nn.Module], + backend_config: BackendConfig, + cache_for_no_tensor_check: Dict[Node, bool], +) -> Dict[str, Any]: + """ + For each op attribute in the op's input activation, output activation, + weight, bias - returns the settings of dtype and is_dynamic we expect + for the `quantize` call in the reference model representation, or None + if there is no `quantize` call needed. + + For example, if we have a node corresponding to `op0` in + + x0 -> op0 -> x1 + + And we want a reference quantized representation to be + + x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1 + + Then this function will return + + { + "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False), + } + + TODO(future PR, if needed): explicitly spell out the non-Tensor + dtypes. 
+ """ + args_have_no_tensors = \ + all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check) + if args_have_no_tensors: + return { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + # get qconfig to determine the eventual dtype of this node + if qconfig is not None: + act_dtype, weight_dtype, input_act_is_dynamic = \ + get_qconfig_dtypes(qconfig) + + # Currently `QConfig` only has one `activation` field. + # For static quantization, it is reused for both input + # and output activation. For dynamic quantization, this + # field is currently only used for the input activation, + # with the output activation being in fp32. + # In the future this may change as we add more fields + # to the `QConfig` object. + output_act_dtype = act_dtype \ + if (not input_act_is_dynamic) else torch.float + + bias_dtype = torch.float16 \ + if ( + act_dtype == torch.float16 + and weight_dtype == torch.float16 + and (not input_act_is_dynamic) + ) else torch.float + + is_general_tensor_value_op = \ + (qhandler is not None and qhandler.is_general_tensor_value_op()) + + _is_standalone_module = ( + qhandler is not None and qhandler.is_standalone_module() + ) + + weight_index = None + if isinstance(node, Node) and node.op == "call_function" and \ + node.target in backend_config._pattern_complex_format_to_config: + weight_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("weight") + + bias_index = None + if isinstance(node, Node) and node.op == "call_function" and \ + node.target in backend_config._pattern_complex_format_to_config: + bias_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("bias") + + return { + "input_act_obs_or_fq_ctr": qconfig.activation, + "weight_obs_or_fq_ctr": qconfig.weight, + "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype), + "weight_index": weight_index, + "bias_index": bias_index, + "output_act_obs_or_fq_ctr": 
qconfig.activation, + "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig), + "input_output_share_observers": is_general_tensor_value_op, + "_is_standalone_module": _is_standalone_module, + } + return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) + +def _get_output_act_obs_or_fq( + arg: Node, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> ObserverOrFakeQuantize: + """ Get the constructor for observer or fake quant object for + the argument in the original graph as the output of previous node, + skipping inserted observers + + We are assuming that the observers are inserted correctly, and the dtype for + argument in quantized graph will match what is specified by the qconfig + """ + assert isinstance(arg, Node) + if "quantization_annotation" in arg.meta: + return _create_obs_or_fq_from_qspec(arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat) + + # Custom module LSTM output is a tuple that we broke down into the internal nodes in order + # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`). + # Since we modified the graph in this case, we must trace back from the args through + # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would + # not be able to accurately detect whether this node is a consumer of custom module LSTM. 
+ custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules) + output_act_obs_or_fq_ctr = None + if custom_module_lstm_node is not None: + output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] + output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + elif _is_activation_post_process_node(arg, named_modules): + observed_arg = arg.args[0] + assert isinstance(observed_arg, Node), "Currently we only support observing Node" + if "quantization_annotation" in observed_arg.meta: + output_act_obs_or_fq = \ + _create_obs_or_fq_from_qspec( + observed_arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat) + else: + assert "target_dtype_info" in observed_arg.meta + output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] + output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + else: + if "target_dtype_info" in arg.meta: + output_act_obs_or_fq_ctr = \ + arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR) + else: + output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR + output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + + return output_act_obs_or_fq + +def _get_arg_target_dtype_as_output( + arg: Node, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[torch.dtype]: + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat) + arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + return arg_as_output_target_dtype + +def _get_arg_as_input_act_obs_or_fq( + arg: Node, + node: Node, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[ObserverOrFakeQuantize]: + """ 
Get the observer or fake quant constructor for the Argument `arg`, as input + to Node `node` + """ + assert isinstance(arg, Node) + # "input_qspec_map" is the more general design we'll use for pt2e path + # it is a map from input argument node to observer or fake quant constructor, for example + # for the following graph: + # x -> conv -> output + # + # we may annotate conv node like the following: + # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...) + # + if "quantization_annotation" in node.meta: + input_qspec_map = node.meta["quantization_annotation"].input_qspec_map + input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules) + if input_arg_qspec is None: + input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR() + else: + input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(input_arg_qspec, obs_or_fq_map, is_qat) + return input_arg_obs_or_fq + + # we can remove the following path in the future if fx graph mode quantization is + # no longer used + is_weight = node_arg_is_weight(node, arg) + is_bias = node_arg_is_bias(node, arg) + is_activation = not is_weight and not is_bias + obs_or_fq_ctr = None + if is_activation: + obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR) + elif is_weight: + if node.target not in NON_QUANTIZABLE_WEIGHT_OPS: + obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR) + else: + obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR) + return obs_or_fq_ctr() if obs_or_fq_ctr else None + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Union[Node, Any], + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + qhandler: Optional[QuantizeHandler], + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: Dict[EdgeOrNode, 
ObserverOrFakeQuantize], + is_qat: bool, + backend_config: Optional[BackendConfig] = None, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. + """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, inner_arg, qconfig, model, named_modules, + graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + assert isinstance(arg, Node) + # default (no observer) + new_arg = arg + + is_standalone_module = qhandler is not None and qhandler.is_standalone_module() + # TODO: move this to a separate function + if not is_standalone_module: + # Note: qconfig can be None in this branch this we are getting act/fq from + # node.meta now + # regular flow for most nodes, except standalone modules + + if "quantization_annotation" in node.meta: + reuse_input_obs_or_fq = node.meta["quantization_annotation"]._reuse_input_obs_or_fq + else: + assert "target_dtype_info" in node.meta + # TODO: we are assuming "target_dtype_info" exists here, maybe + # a default value also need to be provided here + target_dtype_info = node.meta["target_dtype_info"] + # for nodes that doesn't have `reuse_input_obs_or_fq` configured, + # we'll default to False, this makes configuring this field optional for users + reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False) + arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat) + arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq) + + arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, 
is_qat) + arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq) + + + needs_obs_or_fq = _needs_obs_or_fq( + arg_as_output_target_dtype, + arg_as_output_target_is_dynamic, + arg_as_input_target_dtype, + arg_as_input_target_is_dynamic, + reuse_input_obs_or_fq, + is_zeroth_arg=len(node.args) > 0 and arg is node.args[0], + ) + + else: + assert qconfig is not None + # custom flow for standalone modules + _, _, sm_prepare_custom_config, _ = \ + _get_standalone_module_configs( + node, named_modules, prepare_custom_config, qconfig, backend_config) + sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes + + # for args, this is set to the index of the current arg + # for kwargs, this is left at None + cur_input_idx = None + for arg_idx, arg_to_check in enumerate(node.args): + if arg_to_check is arg: + cur_input_idx = arg_idx + break + + if cur_input_idx is None: + needs_obs_or_fq = False + else: + arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules, obs_or_fq_map, is_qat) + arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \ + else torch.float + needs_obs_or_fq = ( + (arg_as_output_target_dtype != arg_as_input_target_dtype) and + (arg_as_input_target_dtype != torch.float) + ) + + act_post_process_ctr = qconfig.activation + arg_as_input_act_obs_or_fq = act_post_process_ctr() if act_post_process_ctr else None + + if needs_obs_or_fq: + + existing_obs_node = None + + # Before using the new observer, check if an observer + # of the correct type already exists. If it does, use it. + # This prevents duplicate observer insertions if a node is + # used by multiple nodes. 
+ # TODO: this is looking into how the value is used in the future + # we should remove this + # removing this means we insert one observer for each use, even if they + # have the same dtype, we can have an extra pass that removes the extra observers + for maybe_obs_node in arg.users.keys(): + if maybe_obs_node.op == 'call_module': + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] + if ( + type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and + maybe_obs_mod.dtype == arg_as_input_target_dtype # type: ignore[possibly-undefined] + ): + arg_as_input_act_obs_or_fq = maybe_obs_mod # type: ignore[assignment] + existing_obs_node = maybe_obs_node + break + + assert arg_as_input_act_obs_or_fq is not None + obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq + if existing_obs_node is None: + new_obs_node = _insert_obs_or_fq( + arg, arg_as_input_act_obs_or_fq, model, named_modules, graph) + # override this arg to be the observed arg + new_arg = new_obs_node + else: + new_arg = existing_obs_node + + return new_arg + + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + qhandler: Optional[QuantizeHandler], + prepare_custom_config: PrepareCustomConfig, + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, + backend_config: Optional[BackendConfig] = None +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + Note: backend_config only needed for standalone_module node + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. 
+ new_args = [] + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, arg, qconfig, model, named_modules, graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config) + new_args.append(new_arg) + + new_kwargs = {} + for k, kwarg in node.kwargs.items(): + new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, kwarg, qconfig, model, named_modules, graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config) + new_kwargs[k] = new_kwarg + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + node.kwargs = new_kwargs + +def _maybe_insert_input_equalization_observers_for_node( + node: Node, + equalization_qconfig: Any, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + is_branch: bool, +) -> None: + """ + If `node` needs to be equalized, find the input/weight observers it needs in + `equalization_qconfig`, creates them, and inserts it into `graph`. + + If `node` does not need an equalization observer, returns None. + """ + if equalization_qconfig is None or not node_supports_equalization(node, named_modules): + return + + if is_branch: + warnings.warn( + f"Cannot equalize {node} because it is part of a branch." 
+ ) + return + + new_args = [] + for arg in node.args: + if not isinstance(arg, Node) or node_arg_is_bias(node, arg): + new_args.append(arg) + continue + + is_weight = node_arg_is_weight(node, arg) + + act_eq_process_ctr = equalization_qconfig.weight if is_weight else \ + equalization_qconfig.input_activation + + new_eq_obs_mod = act_eq_process_ctr() + new_eq_obs_node = _insert_obs_or_fq( + arg, new_eq_obs_mod, model, named_modules, graph) + + new_args.append(new_eq_obs_node) + + # assign the new args and kwargs to the node, inplace + node.args = tuple(new_args) + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[Node]: + """ + If `node` needs an output observer, creates it, inserts it into `graph` + and returns it. + + If `node` does not need an output observer, returns None. + + Note: inserting dynamic quantization ops for output is not supported in fx graph mode + quantization code path right now + """ + assert node.op != 'output', 'observer insertion for outputs is handled elsewhere' + + is_standalone_module = False + if "quantization_annotation" in node.meta: + output_act_obs_or_fq = _create_obs_or_fq_from_qspec( + node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat + ) + else: + assert "target_dtype_info" in node.meta + is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False) + output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr") + output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None + target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq) + # uncomment after we support reuse_input_obs_or_fq properly by having separate + # implemntations for this key instead of reusing the input_output_share_observers + # code + # 
reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) + # for now we set this to False since reuse_input_obs_or_fq for + # the output of a node is implementation in the same code path as observer sharing, + # we should refactor this part to make it clearer in the future + # and we would be able to read this from config directly + reuse_input_obs_or_fq = False + + # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False + # because the prev_output is the output of an fp32 op, althought technically + # we should get the dtype of the output from node.meta["val"] in the future + # if we deprecate fx graph mode quantization + needs_obs_or_fq = _needs_obs_or_fq(torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq) + # currently the activation in QConfig(activation=...,) is for both input + # and output, and when the activation is configured to be dynamic quantization + # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means + # the input should by dynamically quantized, but output should not be quantized + # + # there is no way we can specify different observer/fq for input and output + # activation through QConfig today, this limitation is lifted in the + # quantizer/annotation API in pytorch 2.0 export quantization code path, + # but since this code is reused, annotating output to be dynamically quantized + # would not work either for that. 
+ # we can change QConfig to support input/output activation if we want + # to remove the following check, or if we can deprecate fx graph mode quantization + if target_is_dynamic: + needs_obs_or_fq = False + + # we never insert observers to output of standalone module, we assume + # if needed, they are inserted inside the standalone module + needs_obs_or_fq = needs_obs_or_fq and \ + (not is_standalone_module) + + if needs_obs_or_fq: + obs_or_fq_map[node] = output_act_obs_or_fq + return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph) + else: + return None + +def _maybe_insert_observers_before_graph_output( + graph_output_node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If the output needs to be quantized and there are any nodes + in the output which are not already observed, inserts observers + for those nodes. + """ + + def _recursive_maybe_replace_node_with_obs( + maybe_node: Argument, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + ) -> Argument: + """ + Navigate an arbitrary data structure of lists, tuples, dicts. + For each container type, recurse on all inputs. Once any Node + is found, insert an observer if needed and do not recurse further. + + For example, given a structure of + + {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}} + + we recurse down to bar1 and bar3, observe them if necessary, + and if we inserted an observer then replace the original node + with its observer. + + Returns the data structure with all nodes needing observation being + replaced by their observers. 
+ """ + if isinstance(maybe_node, Node): + # check dtype of this node + arg_as_output_target_dtype = _get_arg_target_dtype_as_output(maybe_node, named_modules, obs_or_fq_map, is_qat) + observer_mod = None + arg_as_input_target_dtype = torch.float + if "target_dtype_info" in maybe_node.meta: + observer_cls = maybe_node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", None) + if observer_cls is not None: + observer_mod = observer_cls() + arg_as_input_target_dtype = observer_mod.dtype + # TODO: this does not handle dynamic quantization yet + need_obs = ( + arg_as_output_target_dtype != arg_as_input_target_dtype and + arg_as_input_target_dtype != torch.float + ) + if need_obs: + assert observer_mod is not None + # insert observer + observer_node = _insert_obs_or_fq( + maybe_node, observer_mod, model, named_modules, graph) + return observer_node + else: + return maybe_node + elif isinstance(maybe_node, (list, tuple)): + results = [] + for inner_node in maybe_node: + results.append(_recursive_maybe_replace_node_with_obs( + inner_node, model, named_modules, graph)) + if isinstance(maybe_node, list): + return results + else: + return tuple(results) + elif isinstance(maybe_node, dict): + results_dict = {} + for k, inner_v in maybe_node.items(): + results_dict[k] = _recursive_maybe_replace_node_with_obs( + inner_v, model, named_modules, graph) + return results_dict + elif maybe_node is None: + return None + else: + raise Exception("Unhandled type for returned node:", maybe_node) + + new_args = [] + for old_arg in graph_output_node.args: + new_args.append( + _recursive_maybe_replace_node_with_obs( + old_arg, model, named_modules, graph)) + + graph_output_node.args = tuple(new_args) # type: ignore[assignment] + + +def _maybe_propagate_dtype_for_node( + node: Node, + target_dtype: Union[torch.dtype, type], + node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], +) -> None: + """ + Assigns `target_dtype` to `node`, setting `is_dynamic` to False. 
If `node` + is a general tensor shape op, also call this function recursively on + the first argument, to propagate the dtype to the caller. + """ + node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None + node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None + # if this is a copy node, propagate to first arg + root_node, _, pattern, qhandler, qconfig = node_name_to_match_result_with_qconfig.get( + node.name, (None, None, None, None, None)) + # TODO: probably need to remove `is_general_tensor_value_op` + if qhandler is not None and qhandler.is_general_tensor_value_op(): + prev_node = node.args[0] + if isinstance(prev_node, Node): + _maybe_propagate_dtype_for_node( + prev_node, target_dtype, node_name_to_match_result_with_qconfig) + +def propagate_dtypes_for_known_nodes( + graph: Graph, + node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], +) -> None: + """ + Currently we assume that inputs to the graph are either `torch.float` or + `torch.quint8`, which is not always correct. For ops such as + `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a + `BoolTensor`. Propagate this information throughout the graph. + + Note: not all dtypes in the graph will be correct after this pass, but a + higher percentage of them will be correct. Hopefully in the future we can + replace this with a better way to reason about dtypes of tensors. 
+ """ + for node in graph.nodes: + non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node) + + for arg_type in non_observable_arg_dict: + non_observable_indices = non_observable_arg_dict[arg_type](node) + + for index in non_observable_indices: + arg = node.args[index] + + # when an argument is a tuple, it does not show up as another node so we need to go through + # all elements of the tuple manually + if isinstance(arg, (tuple, list)): + arg_list = list(arg) + else: + arg_list = [arg] + + for cur_arg in arg_list: + # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated + if isinstance(cur_arg, torch.fx.node.Node): + _maybe_propagate_dtype_for_node( + cur_arg, arg_type, node_name_to_match_result_with_qconfig) + +def _maybe_make_input_output_share_observers( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], +) -> bool: + """ + Ensures that we share an observer + for all input arguments as well as the output argument. In detail, given + a graph of + + x0 -> obs0 -> op -> x2 + / + x1 -> obs1 / + + where node obs0 points to observer instance observer0, + obs1 points to observer1 and obs2 points to observer2, we make nodes obs1 + and ob2 point to observer0. 
+ Returns: whether the operation succeeded or not + """ + first_arg = None + # find the first non-Tensor arg + for i in range(len(node.args)): + if isinstance(node.args[i], (Node, list, tuple)): + first_arg = node.args[i] + break + + # if there is no non-Tensor arg, return directly + if first_arg is None: + return False + + if isinstance(first_arg, (list, tuple)): + first_arg_arg = first_arg[0] + elif isinstance(first_arg, Node): + first_arg_arg = first_arg + else: + return False + + # if we have a graph such as + # observed_node -> non_observed_node -> cat + # we need to navigate up to the first observer + iteration_guard = 0 + while not _is_activation_post_process_node(first_arg_arg, named_modules): + if not isinstance(first_arg_arg, Node): + return False + # did not find an activation_post_process for the op + if first_arg_arg.op == "placeholder": + return False + # trace back the args until we found the first Tensor/Node + trace_back_node = None + for i in range(len(first_arg_arg.args)): + trace_back_node = first_arg_arg.args[i] + if isinstance(trace_back_node, Node): + break + if trace_back_node is None: + return False + first_arg_arg = trace_back_node + + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError('Unable to find observer of previous node') + + assert isinstance(first_arg_arg, Node) + target_to_use = first_arg_arg.target + assert isinstance(target_to_use, str) + obs_mod_to_use = named_modules[target_to_use] + + if isinstance(first_arg, (list, tuple)): + # set all other input observer nodes to use that module + for input_idx, input_arg in enumerate(first_arg): + if input_idx == 0: + continue + iteration_guard = 0 + while not _is_activation_post_process_node(input_arg, named_modules): + # failed to trace back since no input arg for the current node + if len(input_arg.args) < 1: + return False + input_arg = input_arg.args[0] + iteration_guard += 1 + if iteration_guard > 10000: + raise AssertionError('Unable to find observer of 
previous node') + + parent_name, name = _parent_name(input_arg.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # set the output observer node to use that module + for output_obs_node in node.users.keys(): + assert _is_activation_post_process_node(output_obs_node, named_modules) + parent_name, name = _parent_name(output_obs_node.target) + setattr(named_modules[parent_name], name, obs_mod_to_use) + + # TODO(future PR): delete the orphaned observer modules + return True + +def _remove_output_observer( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module]): + items = list(node.users.items()) + for output_obs_node, _ in items: + assert _is_activation_post_process_node(output_obs_node, named_modules) + output_obs_node.replace_all_uses_with(node) + model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] + +def _swap_custom_module_to_observed( + node: Node, + qconfig: QConfigAny, + named_modules: Dict[str, torch.nn.Module], + prepare_custom_config: PrepareCustomConfig): + custom_module = named_modules[node.target] # type: ignore[index] + custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping + observed_custom_module_class = \ + get_swapped_custom_module_class( + custom_module, custom_module_class_mapping, qconfig) + observed_custom_module = \ + observed_custom_module_class.from_float(custom_module) + parent_name, name = _parent_name(node.target) + setattr(named_modules[parent_name], name, observed_custom_module) + +def insert_observers_for_model( + model: GraphModule, + node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig], + node_name_to_qconfig: Dict[str, QConfigAny], + prepare_custom_config: PrepareCustomConfig, + equalization_config_map: Dict[str, Any], + backend_config: BackendConfig, + observed_node_names: Set[str], + is_qat: bool, +) -> Optional[Node]: + """ + Inserts observers, using the following high level algorithm: + + For each node in the 
graph: + 1. determine the target dtype of this node in the quantized graph, and save + it for future steps + 2. determine the target dtype or all args and kwargs of this node + 3. if any arg or kwarg's target dtype does not match the current node's + dtype, insert an observer + 4. if the current node needs an output observer, insert it + + For example: + + - starting graph: + x0 -> linear -> x1 + + - observed graph after processing x0: + x0(fp32) + + - observed graph after processing linear: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) + + - observed graph after processing x1: + x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1 + + After a node is processed, the naive observer placement is guaranteed to be + complete for that node and all of its predecessors. There can be future + passes which optimize the graph by deduplicating observers, etc. + """ + + # node.meta["target_dtype_info"] stores the target dtype information + # that's derived from qconfig for the Node, for example, if we have + # a conv2d node that has a qconfig + # qconfig = QConfig(activation=..., weight=...) 
+ # # information for input and bias node omitted + # # for getattr node + # # weight = getattr(self, 'weight') + # weight.meta["target_dtype_info"] = { + # 'output_act_obs_or_fq_ctr': qconfig.weight, + # } + # # for conv2d node + # # conv2d = call_function[target=torch.nn.functional.conv2d]( + # # args=(input, weight, bias)) + # conv2d.meta["target_dtype_info"] = { + # 'input_act_obs_or_fq_ctr': qconfig.activation + # 'weight_obs_or_fq_ctr': qconfig.weight, + # 'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32), + # 'output_act_obs_or_fq_ctr': qconfig.activation, + # } + # + cache_for_no_tensor_check: Dict[Node, bool] = {} + + # first, populate the dtype map based only on qconfig and qhandler + # this assumes: + # graph inputs are fp32 by default, and int8 where overriden + # other nodes output dtype is specified by the qconfig + named_modules = dict(model.named_modules(remove_duplicate=False)) + + input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes + processed_nodes: Set[Node] = set() + # initialize target_dtype_info + for node in model.graph.nodes: + node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO) + + inputs_seen_counter = 0 + outputs_seen_counter = 0 + placeholder_node_to_input_index: Dict[Node, int] = {} + # TODO: we probably don't need this counter since each graph will only have + # one output node? 
+ output_node_to_output_index: Dict[Node, int] = {} + for node in model.graph.nodes: + if node.op == "placeholder": + placeholder_node_to_input_index[node] = inputs_seen_counter + inputs_seen_counter += 1 + if node.op == "output": + output_node_to_output_index[node] = outputs_seen_counter + outputs_seen_counter += 1 + + # Step 1, set the observer or fake quantize module constructor for each node in the + # matched_node_pattern + + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig + assert qhandler is not None + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + qconfig, + qhandler, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes + ) + + # Step 2. Special cases for some operators, we might be able to remove them + # in the future if we know dtype information of each node better + + # Step 2.1. some settings are not based on patterns, we need to process each node + # instead + for node in model.graph.nodes: + if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs: + # users are not supposed to call calculate_qparams on PlaceholderObserver, and + # this is OK because we are using this as a way to encode the dtypes of input + # tensor, we won't actually insert these observers in the graph and won't + # actually call calculate_qparams + node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO) + elif node.op in ("call_module", "call_method", "call_function"): + args_have_no_tensors = \ + all_node_args_have_no_tensors( + node, named_modules, cache_for_no_tensor_check) + if args_have_no_tensors: + node.meta["target_dtype_info"] = { + "input_act_obs_or_fq_ctr": None, + "output_act_obs_or_fq_ctr": None, + } + elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs: + # TODO(future PR): 
update the output_quantized_idxs API to match + # arbitrary data structures. There is always a single output, and + # that output can have arbitrary nesting of values. List[int] is + # not the right data type for this. + + # TODO(future PR): support more dtypes in model outputs, if necessary + node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO) + + # Step 2.2, for nodes with known input dtypes, propagate them throughout the + # graph. For example, if there is a call such as + # x1 = x0.masked_fill(mask, 1) + # we propagate the type of mask to be torch.bool + propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig) + + # Step 3, check if the requested target_dtype_info is supported by backend or not + # if not, we'll reset the target_dtye_info to use the default (float Tensor) + + # reset the counters and set of processed_nodes + processed_nodes: Set[Node] = set() + for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values(): + last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig + is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config) + assert qhandler is not None + + # get output_act_dtype so that we don't also reset the special typed nodes + # TODO: we might want to handle these more uniformly with the default path + # this can be improved if we can use node.meta["val"] + output_act_or_fq_ctr = node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] + output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None + output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq) + if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]: + # restore target_dtype_info to default if it is not supported by backend + _set_target_dtype_info_for_matched_node_pattern( + matched_node_pattern, + last_node, + 
torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig, + None, + backend_config, + named_modules, + cache_for_no_tensor_check, + processed_nodes + ) + + # After this point, the current node and all of its arguments + # have a target_dtype_info assigned. Now, we insert observers for inputs + # of this node (if needed for this node), and the output of this node + # (if needed for this node). + + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # Avoid duplicates custom module swaps for multiple nodes with same target. + custom_module_names_already_swapped: Set[str] = set() + + # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index + # reset inputs/outputs counters + inputs_seen_counter = 0 + outputs_seen_counter = 0 + results_node = None + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + + # TODO: change this to insert obs/fq by pattern instead of by node + for node in nodes_before_observation: + + if node.op == 'placeholder': + # if a graph input is in fp32, it does not need observation + # if a graph input is in int8, we assume the observation happens + # outside of the graph, and no additional observation is needed + pass + + elif node.op in ('call_module', 'call_method', 'call_function', 'output'): + # check for matches + last_node, matched_node_pattern, pattern, qhandler, qconfig = ( + node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None)) # type: ignore[assignment] + ) + equalization_qconfig = equalization_config_map.get(node.name, None) + + this_node_dtype_info = node.meta["target_dtype_info"] + if "val" in node.meta: + output_is_a_tensor = ( + this_node_dtype_info is not None and + isinstance(node.meta["val"], FakeTensor) + ) + else: + output_is_a_tensor = this_node_dtype_info is not None + + skip_inserting_observers = ( + (qconfig is 
None) or + not output_is_a_tensor + ) and ( + not node.op == 'output' + ) + + # TODO: take a closer look to see if we can remove this check + # right now it is here because of `observed_node_names`, we are using + # it as an indicator for swapping the modules to reference modules in + # convert + is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend( + pattern, matched_node_pattern, qconfig, backend_config) + + if not skip_inserting_observers and is_supported_by_backend: + named_modules = dict(model.named_modules(remove_duplicate=False)) + if node.op != 'output': + assert matched_node_pattern is not None + # add matched nodes to the observed node name set + _add_matched_node_name_to_set(matched_node_pattern, observed_node_names) + + # This is currently only used for equalization. + # Checks if the current node is in a branch in which the two + # first layers are both being quantized. + # + # ex. conv2 + # / + # x -> conv1 + # + # If this is the case, we will not apply equalization to the + # initial two layers. 
+ is_quantized_branch = False + if ( + len(node.args) > 0 and + isinstance(node.args[0], Node) and + len(node.args[0].users) > 1 + ): + for user in node.args[0].users: + # Checks if there exists another user being quantized + is_user_quantized = ( + node_name_to_qconfig.get(user.name, None) is not None or + (user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase)) + ) + if user != node and is_user_quantized: + is_quantized_branch = True + + pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) + root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter) + root_node = root_node_getter(matched_node_pattern) + is_input_node_of_the_pattern = node is root_node + if is_input_node_of_the_pattern: + # this modifies node inplace + _maybe_insert_input_observers_for_node( + node, qconfig, model, named_modules, model.graph, + qhandler, + prepare_custom_config, + obs_or_fq_map, + is_qat, + backend_config) + + # insert equalization input observers if needed + _maybe_insert_input_equalization_observers_for_node( + node, equalization_qconfig, model, named_modules, model.graph, + is_quantized_branch) + + is_last_node_of_pattern = node is last_node + input_output_share_observers = node.meta["target_dtype_info"].get("input_output_share_observers", False) + reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False) + + if is_last_node_of_pattern: + if _is_custom_module_lstm(node, named_modules, qconfig, qhandler): + # Currently custom module outputs are assumed to be already quantized, + # so we need to insert a DeQuantStub after the output. For custom module + # LSTM specifically, the outputs are also a nested tuple, so we must first + # break down the tuple to insert DeQuantStubs after the internal nodes. 
+ + # TODO: This currently diverges from how custom modules are handled today, + # where we insert observers after the output instead of DeQuantStubs, and + # replace these observers with "dequantize" nodes during convert. Conceptually, + # these output observers are the same as DeQuantStubs. In the future, we + # should resolve this inconsistency by inserting DeQuantStubs for all custom + # modules, not just for LSTM. + _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph) + if node.target not in custom_module_names_already_swapped: + custom_module_names_already_swapped.add(node.target) + _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config) + else: + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat) + + if maybe_output_obs_node is not None: + # Update users of original node to use the output observer + # instead. 
For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + + _is_observer_in_same_graph_ = _is_observer_in_same_graph( + node, named_modules, obs_or_fq_map, is_qat) + + # for ops whose inputs and outputs share observer/fqs, we modify the graph + # to make all inputs and outputs use the first input's + # observer/fq + if (input_output_share_observers and _is_observer_in_same_graph_) or \ + reuse_input_obs_or_fq: + if not _maybe_make_input_output_share_observers(node, model, named_modules): + _remove_output_observer(node, model, named_modules) + + if qhandler is not None and qhandler.is_custom_module(): + if node.target not in custom_module_names_already_swapped: + custom_module_names_already_swapped.add(node.target) + _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config) + + else: # output + _maybe_insert_observers_before_graph_output(node, model, named_modules, model.graph, obs_or_fq_map, is_qat) + + # + # After this point, the current node has input and output observers + # that it needs for itself inserted. + # + + # increment the counters, so future inputs and outputs are assigned + # correct dtypes + if node.op == 'placeholder': + inputs_seen_counter += 1 + elif node.op == 'output': + outputs_seen_counter += 1 + results_node = node + + return results_node + +def _run_prepare_fx_on_standalone_modules( + model: torch.nn.Module, + is_qat: bool, + named_modules: Dict[str, torch.nn.Module], + node_name_to_match_result_with_qconfig: Any, + prepare_custom_config: PrepareCustomConfig, + backend_config: BackendConfig, +) -> None: + """ + Runs prepare_fx on each standalone module. 
Note: this does + not modify the graph, it just replaces the unobserved modules with + their observed versions. + """ + for (root_node, _, pattern, qhandler, qconfig) in node_name_to_match_result_with_qconfig.values(): + if qhandler is None: + continue + elif not qhandler.is_standalone_module(): + continue + + sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config, \ + sm_backend_config = _get_standalone_module_configs( + root_node, named_modules, prepare_custom_config, qconfig, backend_config) + + standalone_module = named_modules[root_node.target] + prepare = \ + torch.ao.quantization.quantize_fx._prepare_standalone_module_fx # type: ignore[attr-defined] + observed_standalone_module = \ + prepare( + standalone_module, + sm_qconfig_mapping, + is_qat, + example_inputs=sm_example_inputs, + prepare_custom_config=sm_prepare_custom_config, + backend_config=sm_backend_config) + parent_name, name = _parent_name(root_node.target) + setattr(named_modules[parent_name], name, observed_standalone_module) + named_modules[root_node.target] = observed_standalone_module + +def _save_state( + observed: GraphModule, + node_name_to_qconfig: Dict[str, QConfigAny], + node_name_to_scope: Dict[str, Tuple[str, type]], + prepare_custom_config: PrepareCustomConfig, + equalization_node_name_to_qconfig: Dict[str, Any], + qconfig_mapping: QConfigMapping, + is_qat: bool, + observed_node_names: Set[str], +) -> None: + observed.meta["_observed_graph_module_attrs"] = ( + ObservedGraphModuleAttrs( + node_name_to_qconfig=node_name_to_qconfig, + node_name_to_scope=node_name_to_scope, + prepare_custom_config=prepare_custom_config, + equalization_node_name_to_qconfig=equalization_node_name_to_qconfig, + qconfig_mapping=qconfig_mapping, + is_qat=is_qat, + observed_node_names=observed_node_names, + ) + ) + +def prepare( + model: GraphModule, + qconfig_mapping: Union[QConfigMapping, Dict[str, Any]], + is_qat: bool, + node_name_to_scope: Dict[str, Tuple[str, type]], + example_inputs: Tuple[Any, 
...], + prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None, + _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None, + backend_config: Union[BackendConfig, Dict[str, Any], None] = None, + is_standalone_module: bool = False) -> GraphModule: + """ standalone_module means it a submodule that is not inlined in + parent module, and will be quantized separately as one unit. + + How the standalone module is observed is specified by `input_quantized_idxs` and + `output_quantized_idxs` in the prepare_custom_config for the standalone module + Args: + node_name_to_scope: mapping from node name to the scope of the module which contains the node. + The scope is a tuple of fully qualified path of the module and the type of the module + Returns: + model(GraphModule): prepared standalone module + attributes related to standalone module + in model.meta["_observed_graph_module_attrs"]: + is_observed_standalone_module (bool): boolean value that shows whether the + current model is a observed standalone module or not + standalone_module_input_quantized_idxs(List[Int]): a list of + indexes for the graph input that is expected to be quantized, + same as input_quantized_idxs configuration provided + for the standalone module + standalone_module_output_quantized_idxs(List[Int]): a list of + indexs for the graph output that is quantized + same as input_quantized_idxs configuration provided + for the standalone module + """ + if prepare_custom_config is None: + prepare_custom_config = PrepareCustomConfig() + if _equalization_config is None: + _equalization_config = QConfigMapping() + + if isinstance(qconfig_mapping, Dict): + warnings.warn( + "Passing a QConfig dictionary to prepare is deprecated and will not be supported " + "in a future version. 
Please pass in a QConfigMapping instead.") + qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) + + if isinstance(_equalization_config, Dict): + warnings.warn( + "Passing a QConfig dictionary to prepare for equalization is deprecated and will not " + "be supported in a future version. Please pass in a QConfigMapping instead.") + _equalization_config = QConfigMapping.from_dict(_equalization_config) + + if isinstance(prepare_custom_config, Dict): + warnings.warn( + "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a PrepareCustomConfig instead.") + prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config) + + if isinstance(backend_config, Dict): + warnings.warn( + "Passing a backend_config_dict to prepare is deprecated and will not be supported " + "in a future version. Please pass in a BackendConfig instead.") + backend_config = BackendConfig.from_dict(backend_config) + + assert isinstance(qconfig_mapping, QConfigMapping) + assert isinstance(_equalization_config, QConfigMapping) + qconfig_mapping = copy.deepcopy(qconfig_mapping) + _equalization_config = copy.deepcopy(_equalization_config) + + # mapping from a tuple of nodes in reverse order to uninitialized + # QuantizeHandler subclass. 
For example, + # { + # # match a single node + # (: + # ), + # # match multiple nodes in reverse order + # ((, ): + # ), + # } + + pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {} + if backend_config is None: + backend_config = get_native_backend_config() + pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config) + pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler) + + root_node_getter_mapping = \ + get_fusion_pattern_to_root_node_getter(backend_config) + + _update_qconfig_for_fusion(model, qconfig_mapping) + _update_qconfig_for_fusion(model, _equalization_config) + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) + # TODO: support regex as well + propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) + + if is_qat: + module_to_qat_module = get_module_to_qat_module(backend_config) + _qat_swap_modules(model, module_to_qat_module) + _update_qconfig_for_qat(qconfig_mapping, backend_config) + + # mapping from fully qualified module name to module instance + # for example, + # { + # '': Model(...), + # 'linear': Linear(...), + # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...), + # } + named_modules = dict(model.named_modules(remove_duplicate=False)) + + # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches + equalization_node_name_to_qconfig = _generate_node_name_to_qconfig( + model, named_modules, model.graph, _equalization_config, node_name_to_scope) + node_name_to_qconfig = _generate_node_name_to_qconfig(model, named_modules, model.graph, qconfig_mapping, node_name_to_scope) + + # match the patterns that will get quantized + standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) + standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys()) + + custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping) + 
matches_without_qconfig = _find_matches( + model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping, + standalone_module_names, standalone_module_classes, custom_module_classes) + + # map qconfig instances to matches + node_name_to_match_result_with_qconfig = {} + for node_name, match_without_qconfig in matches_without_qconfig.items(): + match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name]) + node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig + + _run_prepare_fx_on_standalone_modules( + model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config) + + # record names for the set of observed node, so that in convert step + # we know whether we need to convert a floating point module to reference + # quantized module or not + observed_node_names: Set[str] = set() + + result_node = insert_observers_for_model( + model, + node_name_to_match_result_with_qconfig, + node_name_to_qconfig, + prepare_custom_config, + equalization_node_name_to_qconfig, + backend_config, + observed_node_names, + is_qat, + ) + model = GraphModule(model, model.graph) + + _save_state(model, node_name_to_qconfig, node_name_to_scope, + prepare_custom_config, equalization_node_name_to_qconfig, + qconfig_mapping, is_qat, observed_node_names) + + if is_standalone_module: + assert result_node is not None + assert isinstance(result_node.args[0], Node), \ + "standalone module only supports returning simple value currently"\ + "(not tuple, dict etc.)" + # these inputs are observed in parent + # converting List[int] to Tensor since module attribute is + # Union[Tensor, Module] + input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes + output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes + observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] + # inplace modification + 
observed_graph_module_attrs.is_observed_standalone_module = True + observed_graph_module_attrs.standalone_module_input_quantized_idxs = \ + input_quantized_idxs + observed_graph_module_attrs.standalone_module_output_quantized_idxs = \ + output_quantized_idxs + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c615464d8a967d4fff09769e9e17bd067d740c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..911ed46ad9d8bc92f5922d56115c4b342b948402 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/prepare.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec4b00aeedf5ad7fbd1fa50a08e9e97eb9386db Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ad8350e401249616dfbdfe01223b54b9aef9bb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..48c7d7247b99c1ea8a666fbee8aa8db41f4e0b2a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -0,0 +1,83 @@ +import logging +import operator + +import torch + +from torch.ao.quantization.pt2e.utils import ( + _filter_sym_size_users, + _is_valid_annotation, +) + +from torch.fx.node import map_arg +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["DuplicateDQPass"] + +_QUANTIZE_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] + +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +def _maybe_duplicate_dq( + gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node +): + annotation = user.meta.get("quantization_annotation", None) + if not 
_is_valid_annotation(annotation): + return + with gm.graph.inserting_after(dq_node): + new_node = gm.graph.node_copy(dq_node) + + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == dq_node: + return new_node + else: + return n + + new_args = map_arg(user.args, maybe_replace_node) + new_kwargs = map_arg(user.kwargs, maybe_replace_node) + user.args = new_args + user.kwargs = new_kwargs + + +class DuplicateDQPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in _DEQUANTIZE_OPS: + dq_users = _filter_sym_size_users(node) + if len(dq_users) <= 1: + continue + # Do not duplicate dq for dynamic quantization + # Pattern: choose_qparam - getitem - q - dq + q_node = node.args[0] + if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS: + getitem_node = q_node.args[1] + if ( + isinstance(getitem_node, torch.fx.node.Node) + and getitem_node.op == "call_function" + and getitem_node.target == operator.getitem + ): + choose_qparam_node = getitem_node.args[0] + if ( + isinstance(choose_qparam_node, torch.fx.node.Node) + and choose_qparam_node.op == "call_function" + and choose_qparam_node.target + == torch.ops.quantized_decomposed.choose_qparams.tensor + ): + continue + for user in dq_users: + _maybe_duplicate_dq(graph_module, node, user) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d73319df019b1248c247a4dce5c7673c429d7866 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/export_utils.py @@ -0,0 +1,211 @@ +import types + 
+import torch +import torch.nn.functional as F + + +__all__ = [ + "model_is_exported", + "_WrapperModule", +] + + +class _WrapperModule(torch.nn.Module): + """Class to wrap a callable in an :class:`torch.nn.Module`. Use this if you + are trying to export a callable. + """ + + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + """Simple forward that just calls the ``fn`` provided to :meth:`WrapperModule.__init__`.""" + return self.fn(*args, **kwargs) + + +def model_is_exported(m: torch.nn.Module) -> bool: + """ + Return True if the `torch.nn.Module` was exported, False otherwise + (e.g. if the model was FX symbolically traced or not traced at all). + """ + return isinstance(m, torch.fx.GraphModule) and any( + "val" in n.meta for n in m.graph.nodes + ) + + +def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch dropout patterns in the model between train and eval modes. + + Dropout has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the dropout behavior between the two modes, so here we need to rewrite the aten + dropout patterns manually to achieve the same effect. + + See https://github.com/pytorch/pytorch/issues/103681. 
+ """ + # Avoid circular dependencies + from .utils import get_aten_graph_module + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + for inplace in [False, True]: + + def dropout_train(x): + return F.dropout(x, p=0.5, training=True, inplace=inplace) + + def dropout_eval(x): + return F.dropout(x, p=0.5, training=False, inplace=inplace) + + example_inputs = (torch.randn(1),) + if train_to_eval: + match_pattern = get_aten_graph_module( + _WrapperModule(dropout_train), example_inputs + ) + replacement_pattern = get_aten_graph_module( + _WrapperModule(dropout_eval), example_inputs + ) + else: + match_pattern = get_aten_graph_module( + _WrapperModule(dropout_eval), example_inputs + ) + replacement_pattern = get_aten_graph_module( + _WrapperModule(dropout_train), example_inputs + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool): + """ + Switch batchnorm patterns in the model between train and eval modes. + + Batchnorm has different behavior in train vs eval mode. For exported models, + however, calling `model.train()` or `model.eval()` does not automatically switch + the batchnorm behavior between the two modes, so here we need to rewrite the aten + batchnorm patterns manually to achieve the same effect. + """ + # TODO(Leslie): This function still fails to support custom momentum and eps value. + # Enable this support in future updates. 
+ + # Avoid circular dependencies + from .utils import get_aten_graph_module + + # Needed to ensure subgraph matches are self-contained + m.graph.eliminate_dead_code() + m.recompile() + + def bn_train( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True + ) + + def bn_eval( + x: torch.Tensor, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + bn_running_mean: torch.Tensor, + bn_running_var: torch.Tensor, + ): + return F.batch_norm( + x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False + ) + + example_inputs = ( + torch.randn(1, 1, 3, 3), # x + torch.randn(1), # bn_weight + torch.randn(1), # bn_bias + torch.randn(1), # bn_running_mean + torch.randn(1), # bn_running_var + ) + if train_to_eval: + match_pattern = get_aten_graph_module(_WrapperModule(bn_train), example_inputs) + replacement_pattern = get_aten_graph_module( + _WrapperModule(bn_eval), example_inputs + ) + else: + match_pattern = get_aten_graph_module(_WrapperModule(bn_eval), example_inputs) + replacement_pattern = get_aten_graph_module( + _WrapperModule(bn_train), example_inputs + ) + + from torch.fx.subgraph_rewriter import replace_pattern_with_filters + + replace_pattern_with_filters( + m, + match_pattern, + replacement_pattern, + match_filters=[], + ignore_literals=True, + ) + m.recompile() + + +# TODO: expose these under this namespace? +def _move_exported_model_to_eval(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to eval mode. + + This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing inference on the model. 
+ """ + _replace_dropout(model, train_to_eval=True) + _replace_batchnorm(model, train_to_eval=True) + return model + + +def _move_exported_model_to_train(model: torch.fx.GraphModule): + """ + Move an exported GraphModule to train mode. + + This is equivalent to model.train() but only for certain special ops like dropout, batchnorm. + QAT users should call this before performing training on the model. + """ + _replace_dropout(model, train_to_eval=False) + _replace_batchnorm(model, train_to_eval=False) + return model + + +def _allow_exported_model_train_eval(model: torch.fx.GraphModule): + """ + Allow users to call `model.train()` and `model.eval()` on an exported model, + but with the effect of changing behavior between the two modes limited to special + ops only, which are currently dropout and batchnorm. + + Note: This does not achieve the same effect as what `model.train()` and `model.eval()` + does in eager models, but only provides an approximation. In particular, user code + branching on `training` flag will not function correctly in general because the branch + is already specialized at export time. Additionally, other ops beyond dropout and batchnorm + that have different train/eval behavior will also not be converted properly. 
+ """ + + def _train(self, mode: bool = True): + if mode: + _move_exported_model_to_train(self) + else: + _move_exported_model_to_eval(self) + + def _eval(self): + _move_exported_model_to_eval(self) + + model.train = types.MethodType(_train, model) # type: ignore[method-assign] + model.eval = types.MethodType(_eval, model) # type: ignore[method-assign] + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..ac161e3f5fbb674e25690c5fe86d2436496d429c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/prepare.py @@ -0,0 +1,489 @@ +import torch +from torch._subclasses import FakeTensor +from torch.ao.quantization.fx.prepare import ( + _insert_obs_or_fq, + _save_state, + _is_activation_post_process_node, + _create_obs_or_fq_from_qspec, +) +from torch.fx import ( + GraphModule, + Graph, + Node, +) +from torch.fx.node import Argument + +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.qconfig import QConfigAny +from torch.ao.quantization.fx.custom_config import PrepareCustomConfig +from typing import Dict, Tuple, Union, Any, Optional +from torch.ao.quantization.quantizer import ( + EdgeOrNode, + SharedQuantizationSpec, + QuantizationSpecBase, +) +from torch.ao.quantization import ObserverOrFakeQuantize + +# TODO: make pt2e folder private? 
+__all__ = [ + "prepare", +] + + +def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode: + """Find the root node for the sharing tree + Args: + edge_or_node: edge/node that we want to find the root + shared_with_map: each edge/node points to the parent, the root node will points to itself + + Returns: + root edge/node + """ + parent = shared_with_map[edge_or_node] + if parent == edge_or_node: + return edge_or_node + root = _find_root_edge_or_node(parent, shared_with_map) + # path compression + shared_with_map[edge_or_node] = root + return root + +def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None: + """Merge the subtree for `child` with `parent`, the order is important here + """ + root_parent = _find_root_edge_or_node(parent, shared_with_map) + root_child = _find_root_edge_or_node(child, shared_with_map) + # union the two trees by pointing the root of child to root of parent + shared_with_map[root_child] = root_parent + +def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]): + """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec` + configuration and established the relationship between `edge_or_node` with the edge/node that it + is pointing to, we'll use this information in the end to get the group id + """ + if isinstance(qspec, SharedQuantizationSpec): + parent = qspec.edge_or_node + # we point from edge_or_node to the node that it is sharing_with, e.g. 
+ # qspec for a = SharedQuantizationSpec(b) means `a` points to `b` + _union(parent, child, shared_with_map) + +def _unwrap_shared_qspec( + qspec: QuantizationSpecBase, + edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase], + shared_with_map: Dict[EdgeOrNode, EdgeOrNode] +) -> QuantizationSpecBase: + """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec) + if qspec is SharedQuantizationSpec + (1). tries to find the root edge or node for the node that the qspec points to + (2). recursively find the root qspec based on the qspec for the root node + """ + if isinstance(qspec, SharedQuantizationSpec): + sharing_with = qspec.edge_or_node + root = _find_root_edge_or_node(sharing_with, shared_with_map) + qspec = edge_or_node_to_qspec[root] + return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + return qspec + +def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): + return ( + hasattr(qspec_a, "dtype") and + hasattr(qspec_b, "dtype") and + qspec_a.dtype == qspec_b.dtype + ) + +def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): + return ( + hasattr(qspec_a, "is_dynamic") and + hasattr(qspec_b, "is_dynamic") and + qspec_a.is_dynamic == qspec_b.is_dynamic + ) + +def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]: + """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes + """ + edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {} + for n in model.graph.nodes: + if hasattr(n, "meta") and "quantization_annotation" in n.meta: + qa = n.meta["quantization_annotation"] + for input_to_n, qspec in qa.input_qspec_map.items(): + input_edge = (input_to_n, n) + edge_or_node_to_qspec[input_edge] = qspec + if qa.output_qspec is not None: + output_node = n + qspec = qa.output_qspec + edge_or_node_to_qspec[output_node] = qspec + return edge_or_node_to_qspec + +def 
_union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map): + """Union input edge with another edge or node, used in implicit sharing to point the current input + edge to other user edges of the producer node, or the output of producer node since these are + referring to the same Tensor + """ + root_qspec = None + if edge_or_node in edge_or_node_to_qspec: + qspec = edge_or_node_to_qspec[edge_or_node] + root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + # TODO: add assertions for types of root qspecs + if ( + root_qspec is not None and + _has_same_dtype(root_qspec, input_edge_root_qspec) and + _has_same_is_dynamic(root_qspec, input_edge_root_qspec) + ): + # the input arg to the node should reuse the existing output observer for arg + # since dtype is the same (we may want to extend this to be a more strict check + # in the future) + # so we point from `input_edge` to `arg` (output of the argument) + _union(edge_or_node, input_edge, shared_with_map) + + +def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]: + """Map from edge/node to the group ID, generated from quantization annotations, + edge/node with the same group ID should use the same observer/fake_quant instance + + This is applying SharedQuantizationSpec configuration and map each edge/node to a group + There is another implicit sharing that's built in the quantization, when we have the following: + * op1 -> op2 + * output of op1: int8_qspec + * (op1 -> op2) input edge: int8_qspec + we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor. 
+ + Figuring out the correct group ID for all edge/node is a standard union find problem: + https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/ + + Args: + edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations + Returns: + edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that + belongs to the same group should have the same id + + Example: + op2 -> cat1 -> cat2 + op1 / / + op3 + edge_or_node_to_qspec: { + op1: int8_qspec, + op2: int8_qspec, + (op1, cat1): int8_qspc, + (op2, cat1): SharedQuantizationSpec((op1, cat1)), + cat1: SharedQuantizationSpec((op1, cat1)), + (op3, cat2): int8_qspec, + (cat1, cat2): SharedQuantizationSpec((op3, cat2)), + cat2: SharedQuantizationSpec((op3, cat2)), + } + + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + edge_or_node_to_group_id: { + op1: 1, + op2: 1, + (op1, cat1): 1, + (op2, cat1): 1, + cat1: 1, + (op3, cat2): 1, + (cat1, cat2): 1, + cat2: 1, + } + # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which + # connects the two sharing group around cat1 and cat2 op due to transitive sharing + """ + # means the observer of key should be shared with observer with value, by default it will + # be shared with itself + shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + if isinstance(edge_or_node, torch.fx.Node): + output_node = edge_or_node + _update_shared_with(output_node, qspec, shared_with_map) + else: + input_edge = edge_or_node + input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) + + assert isinstance(input_edge, tuple) + arg, n = input_edge + if n.meta["quantization_annotation"].allow_implicit_sharing: + # NOTE: the order is important here, we first share with other users and then share with 
previous + # output because the reverse order could cause circular dependency + # e.g node1 -> node2 + # \ -> node3 + # when processing (node1, node2), if we first point (node1, node2) to node1 + # Step 1. shared_map = {(node1, node2): node1} + # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) , + # which means shared_map = {(node1, node2): node1, node1: (node1, node3)} + # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3) + # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll + # have a circular dependency + # the following order works around this issue, but this does not allow arbitrary configuration + # of sharing so it might break in a different case in the future, when it breaks + # quantizer writer can check the notes here to debug the issue + + # sharing with other users of the producer node + # (arg, user) + if not isinstance(arg, Node) or not isinstance(n, Node): + raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}") + for user in arg.users: + if user is n: + continue + arg_to_user_edge = (arg, user) + _union_input_edge_with( + input_edge, + input_edge_root_qspec, + arg_to_user_edge, + edge_or_node_to_qspec, + shared_with_map + ) + + # sharing with output of producer node + _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map) + + _update_shared_with(input_edge, qspec, shared_with_map) + + # now that we get the sharing relations between all edges and nodes, we can assingn group ids + cur_group_id = 0 + edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {} + for edge_or_node in shared_with_map.keys(): + root = _find_root_edge_or_node(edge_or_node, shared_with_map) + if root not in edge_or_node_to_group_id: + edge_or_node_to_group_id[root] = cur_group_id + cur_group_id += 1 + edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root] + + 
return edge_or_node_to_group_id + +def _get_obs_or_fq_map( + edge_or_node_to_group_id: Dict[EdgeOrNode, int], + edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase], + is_qat: bool +) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]: + """Generates the EdgeOrNode to observer/fake_quant instances + Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant + instances + """ + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {} + group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {} + for edge_or_node, qspec in edge_or_node_to_qspec.items(): + group_id = edge_or_node_to_group_id[edge_or_node] + if group_id not in group_id_to_obs_or_fq: + # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify + # the implementation for _create_obs_or_fq_from_qspec + group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat) + obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id] + return obs_or_fq_map + +def _maybe_insert_input_observer_for_arg_or_kwarg( + node: Union[Node, Any], + arg: Argument, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Argument: + """ + Given a `node` and an `arg`, inserts an input observer between + `node` and `arg` if necessary. 
+ """ + # for ops such as torch.cat([x0, x1]), + # traverse through the list + if isinstance(arg, (list, tuple)): + new_arg_to_return = [] + for inner_arg in arg: + new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat, + ) + new_arg_to_return.append(new_inner_arg) + return type(arg)(new_arg_to_return) + + if not isinstance(arg, Node): + return arg + assert isinstance(arg, Node) + # default (no observer) + new_arg = arg + + # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes + original_arg = arg + while _is_activation_post_process_node(original_arg, named_modules): + original_arg = original_arg.args[0] # type: ignore[assignment] + assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}" + + input_edge = (original_arg, node) + if input_edge not in obs_or_fq_map: + return new_arg + # input_edge needs to be observed + input_edge_obs_or_fq = obs_or_fq_map[input_edge] + if input_edge_obs_or_fq is None: + return new_arg + + arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None) + # the arg is observed as the output and is using the same instance as the input_edge + # we'll reuse the inserted observer/fake_quant + if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq): + return new_arg + + # otherwise, we'll insert a new observer/fake_quant node + + existing_obs_node = None + # skip inserting new observers if the same observer instance is inserted before for another user + # Example: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + # + # instead of inserting new observers we will have: + # conv1 -> obs1 -> existing_obs -> conv2 + # \ -> conv3 + for maybe_obs_node in arg.users.keys(): + if not _is_activation_post_process_node(maybe_obs_node, named_modules): + continue + maybe_obs_mod = named_modules[maybe_obs_node.target] # type: 
ignore[index] + if id(maybe_obs_mod) == id(input_edge_obs_or_fq): + return maybe_obs_node + + new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph) + return new_arg + +def _maybe_insert_input_observers_for_node( + node: Node, + qconfig: QConfigAny, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> None: + """ + If needed, inserts observers to the input args and kwargs of `node`. + Note: modifies `node` inplace. + + For example, if cur_node needs an observer after prev_node, we change from + + prev_node -> cur_node + + To + + prev_node -> obs -> cur_node + + """ + # Look through every input arg. If that arg's target dtype does not + # match the current node's target dtype, insert an observer. + new_args = [] + # map from old arg to new arg, used for updating the numeric debug handle map + remap = {} + for arg in node.args: + new_arg = _maybe_insert_input_observer_for_arg_or_kwarg( + node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat, + ) + new_args.append(new_arg) + remap[arg] = new_arg + + if "numeric_debug_handle" in node.meta: + + def remap_fn(x): + return remap.get(x, x) + + numeric_debug_handle = node.meta["numeric_debug_handle"] + node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()} + + # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg + # that persist in exported graph. This is just a work around for these. 
+ assert ( + node.target == torch.ops.aten.clone.default or + node.target == torch.ops.aten.zeros_like.default or + len(node.kwargs) == 0 + ), " expecting kwargs for aten op IR to be empty" + + # assign the new args to the node, inplace + node.args = tuple(new_args) + +def _maybe_insert_output_observer_for_node( + node: Node, + model: torch.nn.Module, + named_modules: Dict[str, torch.nn.Module], + graph: Graph, + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +) -> Optional[Node]: + if node in obs_or_fq_map: + output_act_obs_or_fq = obs_or_fq_map[node] + return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph) + return None + +def _maybe_insert_input_and_output_observers_for_node( + node: Node, + model: torch.fx.GraphModule, + obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize], + is_qat: bool, +): + this_node_quantization_annotation = node.meta["quantization_annotation"] if "quantization_annotation" in node.meta else None + if this_node_quantization_annotation is None: + return + + named_modules = dict(model.named_modules(remove_duplicate=False)) + _maybe_insert_input_observers_for_node( + node, + None, # qconfig + model, + named_modules, + obs_or_fq_map, + is_qat, + ) + + output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) + if not output_is_a_tensor: + return + + # this returns the new observer node if it was needed + maybe_output_obs_node = _maybe_insert_output_observer_for_node( + node, model, named_modules, model.graph, obs_or_fq_map, is_qat) + + if maybe_output_obs_node is None: + return + # Update users of original node to use the output observer + # instead. 
For example, change + # + # next_node + # / + # cur_node -> obs + # + # to + # + # next_node + # / + # cur_node -> obs + # + # We need to save orig users before updating uses because + # the list of users will change as we update uses + orig_users = list(node.users.keys()) + for user_node in orig_users: + if user_node is maybe_output_obs_node: + continue + user_node.replace_input_with(node, maybe_output_obs_node) + +def prepare( + model: GraphModule, + node_name_to_scope: Dict[str, Tuple[str, type]], + is_qat: bool, +) -> GraphModule: + # Since we are mutating the graph as we go, we iterate over the original + # nodes before observer insertion, instead of model.graph.nodes. + nodes_before_observation = list(model.graph.nodes) + + # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance + # all edge/nodes that belongs to the same group will use the same instance + # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant + # instance + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat) + + for node in nodes_before_observation: + # TODO: simplify logic for inserting observers + _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat) + + model = GraphModule(model, model.graph) + + _save_state( + model, + {}, # node_name_to_qconfig + node_name_to_scope, + PrepareCustomConfig(), + {}, # equalization_node_name_to_qconfig + QConfigMapping(), + is_qat, + set() # observed_node_names + ) + return model diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..9ddac64c04fa4bbc6a781540cbce9c6416ba0b52 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/pt2e/representation/__init__.py @@ -0,0 +1,5 @@ +from .rewrite import reference_representation_rewrite + +__all__ = [ + "reference_representation_rewrite", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf4b41c724a80455d5a39447550c4c98d6614fa --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/qconfig_mapping.py @@ -0,0 +1,350 @@ +from __future__ import annotations +from collections import OrderedDict +from typing import Any, Callable, Dict, Tuple, Union, List + +import torch + +from .fake_quantize import ( + default_weight_fake_quant, + FixedQParamsFakeQuantize, +) +from .observer import ( + _PartialWrapper, + default_fixed_qparams_range_0to1_observer, + default_fixed_qparams_range_neg1to1_observer, + default_placeholder_observer, + default_weight_observer, +) +from .qconfig import ( + default_reuse_input_qconfig, + default_symmetric_qnnpack_qconfig, + default_symmetric_qnnpack_qat_qconfig, + get_default_qconfig, + get_default_qat_qconfig, + QConfig, + QConfigAny, + default_quint8_weight_qconfig +) + + +__all__ = [ + "get_default_qconfig_mapping", + "get_default_qat_qconfig_mapping", + "QConfigMapping", +] + + +# TODO: replace all usages with these constants +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" + +# TODO: derive this map from the BackendConfig +_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, 
str], _PartialWrapper] = { + torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, + torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, + "hardsigmoid": default_fixed_qparams_range_0to1_observer, + "hardsigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer, + torch.sigmoid: default_fixed_qparams_range_0to1_observer, + "sigmoid": default_fixed_qparams_range_0to1_observer, + "sigmoid_": default_fixed_qparams_range_0to1_observer, + torch.nn.Softmax: default_fixed_qparams_range_0to1_observer, + torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer, + torch.tanh: default_fixed_qparams_range_neg1to1_observer, + "tanh": default_fixed_qparams_range_neg1to1_observer, + "tanh_": default_fixed_qparams_range_neg1to1_observer, +} + + +def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping: + """ + Return the default QConfigMapping for the given quantization type and backend. + """ + if is_qat: + qconfig = get_default_qat_qconfig(backend, version) + else: + qconfig = get_default_qconfig(backend, version) + default_weight = default_weight_fake_quant if is_qat else default_weight_observer + + # default_per_channel_weight_observer is not currently compatible with fbgemm backend + # so we have to modify the weight observer to default_weight_observer or another + # per tensor supported observer. 
+ # see https://github.com/pytorch/pytorch/issues/47535 + if backend in ("fbgemm", "x86"): + qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight) + else: + qconfig_transpose = qconfig + + # currently layernorm only supports float weights + # we have to add this because otherwise there will be a extra quantize-dequantize pair + qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer) + + qconfig_mapping = QConfigMapping() \ + .set_global(qconfig) \ + .set_object_type("reshape", default_reuse_input_qconfig) \ + .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \ + .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \ + .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \ + .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \ + .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \ + .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \ + .set_object_type(torch.nn.functional.layer_norm, qconfig_layernorm) \ + .set_object_type(torch.nn.LayerNorm, qconfig_layernorm) \ + .set_object_type(torch.nn.PReLU, default_quint8_weight_qconfig) \ + + # Use special observers for ops with fixed qparams + fixed_qparams_observer_to_qconfig: Dict[Any, QConfigAny] = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[observer] + else: + if is_qat: + activation = FixedQParamsFakeQuantize.with_args(observer=observer) + else: + activation = observer + fixed_qparams_qconfig = QConfig(activation=activation, weight=default_weight) + fixed_qparams_observer_to_qconfig[observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, fixed_qparams_qconfig) + + # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. 
+ # Need to be able to support fusion of ops with different qconfigs + + return qconfig_mapping + +def get_default_qconfig_mapping(backend="x86", version=0) -> QConfigMapping: + """ + Return the default QConfigMapping for post training quantization. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + # TODO: add assert for backend choices + return _get_default_qconfig_mapping(False, backend, version) + +def get_default_qat_qconfig_mapping(backend="x86", version=1) -> QConfigMapping: + """ + Return the default QConfigMapping for quantization aware training. + + Args: + * ``backend`` (str) : the quantization backend for the default qconfig mapping, should be + one of ["x86" (default), "fbgemm", "qnnpack", "onednn"] + * ``version`` (int) : the version for the default qconfig mapping + """ + return _get_default_qconfig_mapping(True, backend, version) + +def _get_symmetric_qnnpack_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qconfig + return _get_default_qconfig_mapping_with_default_qconfig(False, "qnnpack", default_qconfig) + +def _get_symmetric_qnnpack_qat_qconfig_mapping() -> QConfigMapping: + """ + Return a QConfigMapping that uses `torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig` + as the default QConfig. + """ + default_qconfig = default_symmetric_qnnpack_qat_qconfig + return _get_default_qconfig_mapping_with_default_qconfig(True, "qnnpack", default_qconfig) + +def _get_default_qconfig_mapping_with_default_qconfig( + is_qat: bool, + backend: str, + default_qconfig: QConfig, +) -> QConfigMapping: + """ + Return a QConfigMapping that uses the provided qconfig as the default QConfig. 
+ """ + if is_qat: + qconfig_mapping = get_default_qat_qconfig_mapping(backend) + else: + qconfig_mapping = get_default_qconfig_mapping(backend) + qconfig_mapping.set_global(default_qconfig) + for pattern in qconfig_mapping.object_type_qconfigs.keys(): + if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: + qconfig_mapping.set_object_type(pattern, default_qconfig) + return qconfig_mapping + +_QCONFIG_STYLE_ORDER: List[str] = [ + "global_qconfig", + "object_type_qconfigs", + "module_name_regex_qconfigs", + "module_name_qconfigs", + "module_name_object_type_order_qconfigs", +] + +class QConfigMapping: + """ + Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. + + The user can specify QConfigs using the following methods (in increasing match priority): + + ``set_global`` : sets the global (default) QConfig + + ``set_object_type`` : sets the QConfig for a given module type, function, or method name + + ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string + + ``set_module_name`` : sets the QConfig for modules matching the given module name + + ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination + of the given module name, object type, and the index at which the module appears + + Example usage:: + + qconfig_mapping = QConfigMapping() + .set_global(global_qconfig) + .set_object_type(torch.nn.Linear, qconfig1) + .set_object_type(torch.nn.ReLU, qconfig1) + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*", qconfig2) + .set_module_name("module1", qconfig1) + .set_module_name("module2", qconfig2) + .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3) + + """ + + def __init__(self): + # In increasing match priority: + self.global_qconfig: QConfigAny = None + self.object_type_qconfigs: OrderedDict[Union[Callable, str], QConfigAny] = OrderedDict() + self.module_name_regex_qconfigs: OrderedDict[str, QConfigAny] = 
OrderedDict() + self.module_name_qconfigs: OrderedDict[str, QConfigAny] = OrderedDict() + self.module_name_object_type_order_qconfigs: OrderedDict[Tuple[str, Callable, int], QConfigAny] =\ + OrderedDict() + + def set_global(self, global_qconfig: QConfigAny) -> QConfigMapping: + """ + Set the global (default) QConfig. + """ + self.global_qconfig = global_qconfig + return self + + def set_object_type(self, object_type: Union[Callable, str], qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for a given module type, function, or method name. + If the QConfig for an existing object type was already set, the new QConfig will override the old one. + """ + self.object_type_qconfigs[object_type] = qconfig + return self + + def set_module_name_regex(self, module_name_regex: str, qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching the given regex string. + + Regexes will be matched in the order in which they are registered through this method. + Thus, the caller should register more specific patterns first, e.g.:: + + qconfig_mapping = QConfigMapping() + .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) + .set_module_name_regex("foo.*bar.*", qconfig2) + .set_module_name_regex("foo.*", qconfig3) + + In this example, "foo.bar.conv0" would match qconfig1, "foo.bar.linear" would match qconfig2, + and "foo.baz.relu" would match qconfig3. + + If the QConfig for an existing module name regex was already set, the new QConfig will override the + old one while preserving the order in which the regexes were originally registered. + """ + self.module_name_regex_qconfigs[module_name_regex] = qconfig + return self + + def set_module_name(self, module_name: str, qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching the given module name. + If the QConfig for an existing module name was already set, the new QConfig will override the old one. 
+ """ + self.module_name_qconfigs[module_name] = qconfig + return self + + def set_module_name_object_type_order( + self, + module_name: str, + object_type: Callable, + index: int, + qconfig: QConfigAny) -> QConfigMapping: + """ + Set the QConfig for modules matching a combination of the given module name, object type, + and the index at which the module appears. + + If the QConfig for an existing (module name, object type, index) was already set, the new QConfig + will override the old one. + """ + self.module_name_object_type_order_qconfigs[(module_name, object_type, index)] = qconfig + return self + + def __repr__(self) -> str: + output = self.__class__.__name__ + " (" + for style_name in _QCONFIG_STYLE_ORDER: + output += f"\n {style_name}" + qconfigs = getattr(self, style_name) + if isinstance(qconfigs, OrderedDict) and len(qconfigs) > 0: + for key, qconfig in qconfigs.items(): + output += f"\n {key}: {qconfig}" + else: + output += f"\n {qconfigs}" + return output + "\n)" + + # TODO: remove this + def to_dict(self) -> Dict[str, Any]: + """ + Convert this ``QConfigMapping`` to a dictionary with the following keys: + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are lists of tuples. 
+ """ + return { + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() + ], + } + + # TODO: remove this + @classmethod + def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping: + """ + Create a ``QConfigMapping`` from a dictionary with the following keys (all optional): + + "" (for global QConfig) + + "object_type" + + "module_name_regex" + + "module_name" + + "module_name_object_type_order" + + The values of this dictionary are expected to be lists of tuples. + """ + conf = cls() + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): + conf.set_object_type(object_type, qconfig) + for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []): + conf.set_module_name_regex(module_name_regex, qconfig) + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): + conf.set_module_name(module_name, qconfig) + for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): + conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) + return conf diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38222522e30c0a5913dd48b749421309b41414d6 Binary files /dev/null and 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2c95078e75ef20f50407b6f7a1ca93843fc3295 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae09059a915679097b41e778c40e7dea85da9bd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1708e0af482f19a76071c3549d456033b310c30 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-311.pyc 
b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295b6f085a000d8e840f19cd31c48510a9142efd Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_tree_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_tree_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b36cf95589403f59085e8dbe01f186f239b732f4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/_tree_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/custom_obj.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/custom_obj.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b704f4450c1eb6dbde497e8e990d3768c4abaddb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/custom_obj.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..846395c054c28ef60ee31e942ae26923f3b9719b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-311.pyc differ diff --git 
a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/unflatten.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/unflatten.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32fa0aea4c1ba4bf39dab2f5012b7fea1a512c51 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/unflatten.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_effect_tokens_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_effect_tokens_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f2d01ea6be2b3943e9b84ad49a24f2ac0517f2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_effect_tokens_pass.py @@ -0,0 +1,126 @@ +import operator +from typing import List + +import torch +from torch._higher_order_ops.effects import with_effects +from .exported_program import ExportedProgram +from .graph_signature import ( + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) + + +def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: + """ + Removes the existance of tokens from the exported program, including: + - Removes the input and output tokens + - Replaces with_effects(token, func, args) with just func(args) + + This function does an inplace modification on the given ExportedProgram. 
+ """ + num_tokens: int = 0 + input_token_names: List[str] = [] + new_input_specs: List[InputSpec] = [] + for inp in ep.graph_signature.input_specs: + if inp.kind == InputKind.TOKEN: + num_tokens += 1 + assert isinstance(inp.arg, TensorArgument) + input_token_names.append(inp.arg.name) + else: + new_input_specs.append(inp) + + num_out_tokens: int = 0 + new_output_specs: List[str] = [] + output_token_names: List[OutputSpec] = [] + for out in ep.graph_signature.output_specs: + if out.kind == OutputKind.TOKEN: + num_out_tokens += 1 + output_token_names.append(out.arg.name) + else: + new_output_specs.append(out) + + assert num_tokens == num_out_tokens + + output_node = None + with_effect_nodes: List[torch.fx.Node] = [] + for node in ep.graph.nodes: + if node.op == "output": + output_node = node + break + + if not (node.op == "call_function" and node.target is with_effects): + continue + + with_effect_nodes.append(node) + + # Remove tokens from outputs + assert output_node is not None + output_args = output_node.args[0] + assert len(output_args) >= num_tokens + out_token_nodes = output_args[:num_tokens] + output_node.args = (tuple(output_args[num_tokens:]),) + for out_token in out_token_nodes: + assert out_token.name in output_token_names + ep.graph.erase_node(out_token) + + # Replace with_effects(token, func, args) with just func(args) + for node in reversed(with_effect_nodes): + func = node.args[1] + assert isinstance(func, torch._ops.OpOverload) + + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function(func, node.args[2:]) + for k, v in node.meta.items(): + new_node.meta[k] = v + + node.replace_all_uses_with(new_node) + + # Update user getitem nodes + for user in list(new_node.users.keys()): + assert user.target == operator.getitem + # getitem(with_effects, 0) == token + if user.args[1] == 0: + ep.graph.erase_node(user) + + if len(func._schema.returns) == 1: + # If the function has 1 return then it will just directly return the + # result -- we 
don't need a getitem. So we can replace all the + # getitem(with_effects, 1) with just the note itself. + for user in list(new_node.users.keys()): + assert user.args[1] == 1 + user.replace_all_uses_with(new_node) + + new_node.meta["val"] = node.meta["val"][1] + elif len(func._schema.returns) > 1: + # If the function has more than 1 return then since we got rid of + # the 1st return value (the token), we need to bump all the other + # getitem calls by 1 down + for user in list(new_node.users.keys()): + assert user.args[1] >= 1 + user.args = (user.args[0], user.args[1] - 1) + + new_node.meta["val"] = node.meta["val"][1:] + else: + assert len(func._schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + ep.graph.erase_node(node) + + # Remove tokens from inputs + placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] + assert len(placeholders) >= num_tokens + inp_token_nodes = placeholders[:num_tokens] + for inp_token in inp_token_nodes: + assert inp_token.name in input_token_names + ep.graph.erase_node(inp_token) + + # Update graph signature + ep.graph_signature.input_specs = new_input_specs + ep.graph_signature.output_specs = new_output_specs + + ep.graph.eliminate_dead_code() + return ep diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_unlift.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_unlift.py new file mode 100644 index 0000000000000000000000000000000000000000..bede2986d4eab0a4996c13f6f6f3d2f68f55c9dc --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_unlift.py @@ -0,0 +1,314 @@ +import copy +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.utils._pytree as pytree +from torch._export.utils import _check_input_constraints_for_graph +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx.graph import 
_PyTreeCodeGen, _PyTreeInfo +from ._remove_effect_tokens_pass import _remove_effect_tokens + +from .exported_program import ( + ExportedProgram, + ExportGraphSignature, + InputKind, + OutputKind, +) + + +@torch._dynamo.disable +def _check_input_constraints_pre_hook(self, *args, **kwargs): + flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args) + + if received_spec != self._in_spec: + raise ValueError( # noqa: TRY200 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{self._in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + return _check_input_constraints_for_graph( + [node for node in self.graph.nodes if node.op == "placeholder"], + flat_args_with_path, + self.range_constraints, + ) + + +def _unlift_inputs_as_getattr( + gm: torch.fx.GraphModule, + lifted_inputs: List[Optional[str]], +) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]: + """ + Unlift inputs referring to params/buffers/constants as getattr nodes in the + graph + """ + unlifted_name_to_node = {} + input_name_to_node = {} + + placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(lifted_inputs) == len(placeholder_nodes) + for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs): + if lifted_node is None: + input_name_to_node[input_node.name] = input_node + + else: + with gm.graph.inserting_after(input_node): + getattr_node = gm.graph.get_attr(lifted_node) + input_node.replace_all_uses_with(getattr_node) + metadata = input_node.meta + gm.graph.erase_node(input_node) + getattr_node.meta = metadata + unlifted_name_to_node[lifted_node] = getattr_node + + return unlifted_name_to_node, input_name_to_node + + +def _insert_copy_for_mutations( + gm: torch.fx.GraphModule, + mutated_outputs: List[Optional[str]], + unlifted_name_to_node: Dict[str, torch.fx.Node], + input_name_to_node: Dict[str, torch.fx.Node], +) -> None: + """ + Find the all the buffers and inputs that 
were mutated and insert copy_ + operators to reflect mutations. + """ + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None + outputs = pytree.tree_flatten(output_node.args)[0] + assert len(outputs) == len(mutated_outputs) + + user_output_nodes = [] + for return_node, mutated_node_name in zip(outputs, mutated_outputs): + if mutated_node_name is None: + user_output_nodes.append(return_node) + continue + + if mutated_node_name in unlifted_name_to_node: + mutated_node = unlifted_name_to_node[mutated_node_name] + elif mutated_node_name in input_name_to_node: + mutated_node = input_name_to_node[mutated_node_name] + else: + raise RuntimeError( + f"Could not find {mutated_node_name} in either buffer or input nodes" + ) + + with gm.graph.inserting_before(output_node): + _ = gm.graph.call_function( + torch.ops.aten.copy_.default, (mutated_node, return_node) + ) + + with gm.graph.inserting_before(output_node): + # Only return user outputs + new_output = gm.graph.output(tuple(user_output_nodes)) + output_node.replace_all_uses_with(new_output) + gm.graph.erase_node(output_node) + + +def _get_codegen( + in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], +) -> _PyTreeCodeGen: + """ + Create the codegen for the graph module based on the in/out specs + """ + if ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) + else: + names = [f"arg_{i}" for i in range(in_spec.num_children)] + + return _PyTreeCodeGen( + _PyTreeInfo( + names, + in_spec, + out_spec, + ) + ) + + +def _unlift( + gm: torch.fx.GraphModule, + lifted_inputs: List[Optional[str]], + mutated_outputs: List[Optional[str]], 
+ in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], + state_dict: Dict[str, Any], + constants: Dict[str, Any], +): + """ + Args: + lifted_inputs: A list matching the graph module's input nodes. For + an input node that is referring to a lifted parameter/buffer, this + list will contain the fqn the corresponding attribute. Otherwise, this + list will contain None. This is used to unlift the lifted parameters as + get_attr nodes. + + mutated_outputs: A list matching the graph module's output nodes. For + an output node that is referring to a mutated buffer or user input, this + list will contain the name of the corresponding buffer or user input + that needs to be mutated. Otherwise, this list will contain None. This + is used to re-insert an inplace copy_ operator to copy the mutated + values back to the original node. + """ + unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr( + gm, lifted_inputs + ) + _insert_copy_for_mutations( + gm, mutated_outputs, unlifted_name_to_node, input_name_to_node + ) + gm.graph._codegen = _get_codegen(in_spec, out_spec) + gm.graph.lint() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _register_attrs_to_new_gm( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state_dict: Dict[str, Any], + constants: Dict[str, Any], +) -> None: + non_persistent_buffers = set(graph_signature.non_persistent_buffers) + for name in graph_signature.buffers: + if name in non_persistent_buffers: + persistent = False + value = constants[name] + else: + persistent = True + value = state_dict[name] + _assign_attr( + value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent + ) + for name in graph_signature.parameters: + value = state_dict[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + for name in chain( + graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants + ): + value = constants[name] + _assign_attr( 
+ value, + new_gm, + name, + attr_kind=_AttrKind.CONSTANT, + ) + + +class _StatefulGraphModuleFactory(type): + """ + Metaclass that ensures a private constructor for _StatefulGraphModule + """ + + def __call__(cls, *args, **kwargs): + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor. " + ) + + def _create(cls, root, graph, range_constraints=None): + return super().__call__( + root, + graph, + range_constraints=range_constraints, + ) + + +class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory): + def __init__(self, root, graph, range_constraints=None): + super().__init__(root, graph) + # Need to fix up non-persistent buffers. + self.range_constraints = range_constraints or [] + + +def _create_stateful_graph_module( + plain_graph_module: torch.fx.GraphModule, + range_constraints, + # TODO(suo) this should not be optional, but is since we still ahve + # capture_pre_autograd_graph grr + graph_signature: Optional[ExportGraphSignature] = None, +): + stateful_gm = _StatefulGraphModule._create( + plain_graph_module, + plain_graph_module.graph, + range_constraints=range_constraints, + ) + stateful_gm.register_forward_pre_hook( + _check_input_constraints_pre_hook, with_kwargs=True + ) + + if graph_signature is None: + return stateful_gm + # Fix up non-persistent buffers. torch.fx does not distinguish between + # persistent and non-persistent buffers, so we must restore that distinction + # here. 
+ for buffer in graph_signature.non_persistent_buffers: + _assign_attr( + plain_graph_module.get_buffer(buffer), + stateful_gm, + buffer, + attr_kind=_AttrKind.BUFFER, + persistent=False, + ) + + return stateful_gm + + +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: + ep = _remove_effect_tokens(ep) + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) + + lifted_inputs: List[Optional[str]] = [ + in_spec.target + if in_spec.kind + in ( + InputKind.BUFFER, + InputKind.CONSTANT_TENSOR, + InputKind.PARAMETER, + InputKind.CUSTOM_OBJ, + ) + else None + for in_spec in ep.graph_signature.input_specs + ] + + mutated_outputs: List[Optional[str]] = [ + out_spec.target + if out_spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) + else None + for out_spec in ep.graph_signature.output_specs + ] + + new_gm = _unlift( + new_gm, + lifted_inputs, + mutated_outputs, + ep.call_spec.in_spec, + ep.call_spec.out_spec, + ep.state_dict, + ep.constants, + ) + unlift_gm = _create_stateful_graph_module( + new_gm, ep.range_constraints, ep.graph_signature + ) + unlift_gm.meta.update(ep.graph_module.meta) + return unlift_gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/custom_obj.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/custom_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7f2080a4ee705a2621386c9b69a089d507544a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/custom_obj.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +__all__ = ["ScriptObjectMeta"] + + +@dataclass +class ScriptObjectMeta: + """ + Metadata which is stored on nodes representing ScriptObjects. + """ + + # Key into constants table to retrieve the real ScriptObject. 
+ constant_name: str + + class_fqn: str diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/exported_program.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1d2740a437015875d2d2b85895f7ed00e67f26 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/exported_program.py @@ -0,0 +1,745 @@ +import copy +import dataclasses +import functools +import types +import warnings +from collections import namedtuple +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +from torch.fx.immutable_collections import immutable_dict, immutable_list + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. 
+ + import sympy + + from torch.utils._sympy.value_ranges import ValueRanges + +import torch +import torch.utils._pytree as pytree +from torch.export._tree_utils import is_equivalent, reorder_kwargs +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode + +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager + +from .graph_signature import ( # noqa: F401 + _sig_to_specs, + ArgumentSpec, + ConstantArgument, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymIntArgument, + TensorArgument, +) + + +__all__ = [ + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", +] + + +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +@dataclasses.dataclass +class ModuleCallSignature: + inputs: List[ArgumentSpec] + outputs: List[ArgumentSpec] + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + + +@dataclasses.dataclass +class ModuleCallEntry: + fqn: str + signature: Optional[ModuleCallSignature] = None + + +def _disable_prexisiting_fake_mode(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with maybe_disable_fake_tensor_mode(): + return fn(*args, **kwargs) + + return wrapper + + +def _fx_collection_equivalence_fn( + spec1_type: Optional[type], + spec1_context: pytree.Context, + spec2_type: Optional[type], + spec2_context: pytree.Context, +) -> bool: + """Treat containers and their immutable variants as the same type. Otherwise + compare as normal. 
+ """ + if spec1_type is None or spec2_type is None: + return spec1_type is spec2_type and spec1_context == spec2_context + + if issubclass(spec1_type, (dict, immutable_dict)) and issubclass( + spec2_type, (dict, immutable_dict) + ): + return spec1_context == spec2_context + + if issubclass(spec1_type, (list, immutable_list)) and issubclass( + spec2_type, (list, immutable_list) + ): + return spec1_context == spec2_context + + return spec1_type is spec2_type and spec1_context == spec2_context + + +class ExportedProgram: + """ + Package of a program from :func:`export`. It contains + an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing + tensor values of all lifted parameters and buffers, and various metadata. + + You can call an ExportedProgram like the original callable traced by + :func:`export` with the same calling convention. + + To perform transformations on the graph, use ``.module`` property to access + an :class:`torch.fx.GraphModule`. You can then use + `FX transformation `_ + to rewrite the graph. Afterwards, you can simply use :func:`export` + again to construct a correct ExportedProgram. + """ + + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, + state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]], + range_constraints: "Dict[sympy.Symbol, Any]", + module_call_graph: List[ModuleCallEntry], + example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None, + verifier: Optional[Type[Any]] = None, # TODO Change typing hint to Verifier. + tensor_constants: Optional[ + Dict[str, torch.Tensor] + ] = None, # TODO: deprecate this + constants: Optional[ + Dict[str, Union[torch.Tensor, torch._C.ScriptObject]] + ] = None, + ): + # Remove codegen related things from the graph. It should just be a flat graph. 
+ graph._codegen = torch.fx.graph.CodeGen() + self._graph_module = _create_graph_module_for_export(root, graph) + if isinstance(root, torch.fx.GraphModule): + self._graph_module.meta.update(root.meta) + + self._graph_signature: ExportGraphSignature = graph_signature + self._state_dict: Dict[str, Any] = state_dict + self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints + assert module_call_graph is not None + self._module_call_graph: List[ModuleCallEntry] = module_call_graph + self._example_inputs = example_inputs + + self._constants = tensor_constants or constants or {} + assert self._constants is not None + + from torch._export.verifier import Verifier + + if verifier is None: + verifier = Verifier + assert issubclass(verifier, Verifier) + self._verifier = verifier + # Validate should be always the last step of the constructor. + self.verifier().check(self) + + @property + @compatibility(is_backward_compatible=False) + def graph_module(self): + return self._graph_module + + @property + @compatibility(is_backward_compatible=False) + def graph(self): + return self.graph_module.graph + + @property + @compatibility(is_backward_compatible=False) + def graph_signature(self): + return self._graph_signature + + @property + @compatibility(is_backward_compatible=False) + def state_dict(self): + return self._state_dict + + @compatibility(is_backward_compatible=False) + def parameters(self) -> Iterator[torch.nn.Parameter]: + """ + Returns an iterator over original module's parameters. + """ + for _, param in self.named_parameters(): + yield param + + @compatibility(is_backward_compatible=False) + def named_parameters(self) -> Iterator[Tuple[str, torch.nn.Parameter]]: + """ + Returns an iterator over original module parameters, yielding + both the name of the parameter as well as the parameter itself. 
+ """ + for param_name in self.graph_signature.parameters: + yield param_name, self.state_dict[param_name] + + @compatibility(is_backward_compatible=False) + def buffers(self) -> Iterator[torch.Tensor]: + """ + Returns an iterator over original module buffers. + """ + for _, buf in self.named_buffers(): + yield buf + + @compatibility(is_backward_compatible=False) + def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]: + """ + Returns an iterator over original module buffers, yielding + both the name of the buffer as well as the buffer itself. + """ + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + for buffer_name in self.graph_signature.buffers: + if buffer_name in non_persistent_buffers: + yield buffer_name, self.constants[buffer_name] + else: + yield buffer_name, self.state_dict[buffer_name] + + @property + @compatibility(is_backward_compatible=False) + def range_constraints(self): + return self._range_constraints + + @property + @compatibility(is_backward_compatible=False) + def module_call_graph(self): + return self._module_call_graph + + @property + @compatibility(is_backward_compatible=False) + def example_inputs(self): + return self._example_inputs + + @property + @compatibility(is_backward_compatible=False) + def call_spec(self): + CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) + + if len(self.module_call_graph) == 0: + return CallSpec(in_spec=None, out_spec=None) + assert self.module_call_graph[0].fqn == "" + return CallSpec( + in_spec=self.module_call_graph[0].signature.in_spec, + out_spec=self.module_call_graph[0].signature.out_spec, + ) + + @property + @compatibility(is_backward_compatible=False) + def verifier(self) -> Any: + return self._verifier + + @property + @compatibility(is_backward_compatible=False) + def dialect(self) -> str: + return self._verifier.dialect + + @property + @compatibility(is_backward_compatible=False) + def tensor_constants(self): + return self._constants + + @property + 
@compatibility(is_backward_compatible=False) + def constants(self): + return self._constants + + def _get_flat_args_with_check(self, args, kwargs): + """Flatten args, kwargs using pytree, then, check specs. + + Args: + args: List[Any] original args passed to __call__ + kwargs: Dict[str, Any] original kwargs passed to __call + + Returns: + A tuple of (flat_args, received_spec) + flat_args is flattend args / kwargs + received_spec is the pytree spec produced while flattening the + tuple (args, kwargs) + """ + in_spec = self.call_spec.in_spec + if in_spec is not None: + kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs) + ) # type: ignore[possibly-undefined] + self._check_input_constraints(flat_args_with_path) + flat_args = tuple(x[1] for x in flat_args_with_path) + return flat_args, received_spec + + def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any: + """Transform args, kwargs of __call__ to args for graph_module. + + self.graph_module takes stuff from state dict as inputs. + The invariant is for ep: ExportedProgram is + ep(args, kwargs) == + ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs))) + """ + + in_spec = self.call_spec.in_spec + flat_args, received_spec = self._get_flat_args_with_check(args, kwargs) + if in_spec is not None and not is_equivalent( + received_spec, in_spec, _fx_collection_equivalence_fn + ): + raise ValueError( + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + additional_inputs = [] + for input_ in self.graph_signature.input_specs: + if input_.kind == InputKind.USER_INPUT: + continue + elif input_.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + if input_.persistent is False: + # This is a non-persistent buffer, grab it from our + # constants instead of the state dict. 
+ additional_inputs.append(self.constants[input_.target]) + else: + additional_inputs.append(self.state_dict[input_.target]) + elif input_.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + additional_inputs.append(self.constants[input_.target]) + additional_inputs = tuple(additional_inputs) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + return additional_inputs + flat_args + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError( + "Unable to call ExportedProgram directly. " + "You should use `exported_program.module()` instead." + ) + + def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): + """Process potential mutations to the input. + + Because self.graph_module is functional, so mutations has to be written + back after execution of graph_module. + """ + import torch._export.error as error + + flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs) + if self.call_spec.out_spec is not None: + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] + + # Exclude dependency token from final result. 
+ assertion_dep_token = self.graph_signature.assertion_dep_token + if assertion_dep_token is not None: + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + res = res[:assertion_dep_token_index] + + res = res[num_mutated:] + try: + res = pytree.tree_unflatten(res, self.call_spec.out_spec) + except Exception: + _, received_spec = pytree.tree_flatten(res) + raise error.InternalError( # noqa: TRY200 + "Trying to flatten user outputs with exported output tree spec: \n" + f"{self.call_spec.out_spec}\n" + "but actually got outputs with tree spec of: \n" + f"{received_spec}" + ) + finally: + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + flat_args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res + + def __str__(self) -> str: + graph_module = self.graph_module.print_readable(print_output=False).replace( + "\n", "\n " + ) + string = ( + "ExportedProgram:\n" + f" {graph_module}\n" + f"Graph signature: {self.graph_signature}\n" + f"Range constraints: {self.range_constraints}\n" + ) + return string + + def module(self) -> torch.nn.Module: + """ + Returns a self contained GraphModule with all the parameters/buffers inlined. 
+ """ + from ._unlift import _unlift_exported_program_lifted_states + + module = _unlift_exported_program_lifted_states(self) + + def _train(self, mode: bool = True): + raise NotImplementedError("Calling train() is not supported yet.") + + def _eval(self, mode: bool = True): + raise NotImplementedError("Calling eval() is not supported yet.") + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + @_disable_prexisiting_fake_mode + def run_decompositions( + self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None + ) -> "ExportedProgram": + """ + Run a set of decompositions on the exported program and returns a new + exported program. By default we will run the Core ATen decompositions to + get operators in the + `Core ATen Operator Set `_. + + For now, we do not decompose joint graphs. + """ + from torch._decomp import core_aten_decompositions + from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _AddRuntimeAssertionsForInlineConstraintsPass, + ) + from torch._export.passes.lift_constants_pass import ( + ConstantAttrMap, + lift_constants_pass, + ) + from torch._export.passes.replace_sym_size_ops_pass import ( + _replace_sym_size_ops_pass, + ) + from torch._functorch.aot_autograd import aot_export_module + + def _get_placeholders(gm): + placeholders = [] + for node in gm.graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return placeholders + + decomp_table = decomp_table or core_aten_decompositions() + + old_placeholders = _get_placeholders(self.graph_module) + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(self.graph_module, name) + # TODO(zhxhchen17) Return the new graph_signature directly. 
+ gm, graph_signature = aot_export_module( + self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False + ) + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = _get_placeholders(gm) + new_outputs = list(gm.graph.nodes)[-1].args[0] + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(self.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(self.graph_signature.input_specs) + ] + output_specs = [ + OutputSpec( + spec.kind, + update_arg(spec.arg, new_outputs[i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + for i, spec in enumerate(self.graph_signature.output_specs) + ] + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. 
+ for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + gm.meta.update(self.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints(gm) + + constants = lift_constants_pass(gm, new_graph_signature, ConstantAttrMap()) + for k, v in constants.items(): + assert k not in self.constants + self.constants[k] = v + + _replace_sym_size_ops_pass(gm) + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=self.state_dict, + range_constraints=new_range_constraints, + module_call_graph=copy.deepcopy(self.module_call_graph), + example_inputs=self.example_inputs, + verifier=self.verifier, + constants=self.constants, + ) + + if len(new_range_constraints) > 0: + exported_program = exported_program._transform_do_not_use( + _AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints) + ) + + return exported_program + + def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": + pm = PassManager(list(passes)) + res = pm(self.graph_module) + transformed_gm = res.graph_module if res is not None else self.graph_module + assert transformed_gm is not None + + if transformed_gm is self.graph_module and not res.modified: + return self + + # TODO(zhxchen17) Remove this. + def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, + ) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. 
+ """ + new_input_specs = [] + for i, node in enumerate(new_gm.graph.nodes): + if node.op != "placeholder": + break + + assert i < len( + old_signature.input_specs + ), "Number of inputs changed after transformation" + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance( + old_input_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + old_input_spec.persistent, + ) + ) + + output_node = list(new_gm.graph.nodes)[-1] + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len( + old_signature.output_specs + ), "Number of outputs changed after transformation" + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance( + old_output_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature + + transformed_ep = ExportedProgram( + root=transformed_gm, + graph=transformed_gm.graph, + graph_signature=_get_updated_graph_signature( + self.graph_signature, transformed_gm + ), + state_dict=self.state_dict, + range_constraints=_get_updated_range_constraints(transformed_gm), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + verifier=self.verifier, + constants=self.constants, + ) + transformed_ep.graph_module.meta.update(self.graph_module.meta) + transformed_ep.graph_module.meta.update(res.graph_module.meta) + return transformed_ep + + def _check_input_constraints(self, flat_args_with_path): + from torch._export.utils import _check_input_constraints_for_graph + + placeholders = [p for 
p in self.graph.nodes if p.op == "placeholder"] + input_placeholders = [ + p + for p, s in zip(placeholders, self.graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ] + _check_input_constraints_for_graph( + input_placeholders, flat_args_with_path, self.range_constraints + ) + + def _validate(self): + self.verifier().check(self) + + # TODO(zhxchen17) Formalize this. + def _update( + self, graph_module, graph_signature, state_dict=None + ) -> "ExportedProgram": + return ExportedProgram( + root=graph_module, + graph=graph_module.graph, + graph_signature=graph_signature, + state_dict=state_dict or self.state_dict, + range_constraints=copy.deepcopy(self.range_constraints), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + verifier=self.verifier, + tensor_constants=self.tensor_constants, + ) + + +def _get_updated_range_constraints( + gm: torch.fx.GraphModule, +) -> "Dict[sympy.Symbol, Any]": + def get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + shape_env = get_shape_env(gm) + if shape_env is None: + return {} + range_constraints = { + k: v + for k, v in shape_env.var_to_range.items() + if k not in shape_env.replacements + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. 
[2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements: + range_constraints[k] = v + return range_constraints + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter." + ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/graph_signature.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/graph_signature.py new file mode 100644 index 0000000000000000000000000000000000000000..9829fece8d785d2664f8be13c4f831e364047ad1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/graph_signature.py @@ -0,0 +1,504 @@ +import dataclasses +from enum import auto, Enum +from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union + + +__all__ = [ + "ConstantArgument", + "CustomObjArgument", + "ExportBackwardSignature", + "ExportGraphSignature", + "InputKind", + "InputSpec", + "OutputKind", + "OutputSpec", + "SymIntArgument", + "TensorArgument", +] + + +@dataclasses.dataclass +class TensorArgument: + name: str + + +@dataclasses.dataclass +class SymIntArgument: + name: str + + +@dataclasses.dataclass +class CustomObjArgument: + name: str + class_fqn: str + + 
+@dataclasses.dataclass +class ConstantArgument: + value: Union[int, float, bool, None] + + +ArgumentSpec = Union[ + TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument +] + + +class InputKind(Enum): + USER_INPUT = auto() + PARAMETER = auto() + BUFFER = auto() + CONSTANT_TENSOR = auto() + CUSTOM_OBJ = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class InputSpec: + kind: InputKind + arg: ArgumentSpec + target: Optional[str] + persistent: Optional[bool] = None + + def __post_init__(self): + if self.kind == InputKind.BUFFER: + assert ( + self.persistent is not None + ), "Failed to specify persistent flag on BUFFER." + assert isinstance( + self.arg, + (TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument), + ), f"got {type(self.arg)}" + + +class OutputKind(Enum): + USER_OUTPUT = auto() + LOSS_OUTPUT = auto() + BUFFER_MUTATION = auto() + GRADIENT_TO_PARAMETER = auto() + GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class OutputSpec: + kind: OutputKind + arg: ArgumentSpec + target: Optional[str] + + def __post_init__(self): + assert isinstance(self.arg, (TensorArgument, SymIntArgument, ConstantArgument)) + + +def _sig_to_specs( + *, + user_inputs: Set[str], + inputs_to_parameters: Mapping[str, str], + inputs_to_buffers: Mapping[str, str], + user_outputs: Set[str], + buffer_mutations: Mapping[str, str], + user_input_mutations: Mapping[str, str], + grad_params: Mapping[str, str], + grad_user_inputs: Mapping[str, str], + loss_output: Optional[str], + inputs: List[ArgumentSpec], + outputs: List[ArgumentSpec], + input_tokens: List[str], + output_tokens: List[str], +) -> Tuple[List[InputSpec], List[OutputSpec]]: + def to_input_spec(inp: ArgumentSpec) -> InputSpec: + if not isinstance(inp, TensorArgument): + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + name = inp.name + if name in user_inputs: + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, 
target=None) + elif name in inputs_to_parameters: + return InputSpec( + kind=InputKind.PARAMETER, + arg=inp, + target=inputs_to_parameters[name], + ) + elif name in inputs_to_buffers: + return InputSpec( + kind=InputKind.BUFFER, + arg=inp, + target=inputs_to_buffers[name], + # Mark as True for now; we will fix this up to distinguish + # persistent from non-persistent later in tracing. + # See: rewrite_non_persistent_buffers() + # TODO(suo): this is horrible. + persistent=True, + ) + elif name in input_tokens: + return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None) + else: + raise AssertionError(f"Unknown tensor input kind: {name}") + + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: + if not isinstance(o, TensorArgument): + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + name = o.name + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], + ) + elif name in output_tokens: + return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") + else: + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") + + input_specs = [to_input_spec(inp) for inp in inputs] + output_specs = 
[to_output_spec(idx, o) for idx, o in enumerate(outputs)] + return input_specs, output_specs + + +@dataclasses.dataclass +class ExportBackwardSignature: + gradients_to_parameters: Dict[str, str] + gradients_to_user_inputs: Dict[str, str] + loss_output: str + + +@dataclasses.dataclass +class ExportGraphSignature: + """ + :class:`ExportGraphSignature` models the input/output signature of Export Graph, + which is a fx.Graph with stronger invariants gurantees. + + Export Graph is functional and does not access "states" like parameters + or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` + gurantees that parameters, buffers, and constant tensors are lifted out of + the graph as inputs. Similarly, any mutations to buffers are not included + in the graph either, instead the updated values of mutated buffers are + modeled as additional outputs of Export Graph. + + The ordering of all inputs and outputs are:: + + Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] + Outputs = [*mutated_inputs, *flattened_user_outputs] + + e.g. 
If following module is exported:: + + class CustomModule(nn.Module): + def __init__(self): + super(CustomModule, self).__init__() + + # Define a parameter + self.my_parameter = nn.Parameter(torch.tensor(2.0)) + + # Define two buffers + self.register_buffer('my_buffer1', torch.tensor(3.0)) + self.register_buffer('my_buffer2', torch.tensor(4.0)) + + def forward(self, x1, x2): + # Use the parameter, buffers, and both inputs in the forward method + output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 + + # Mutate one of the buffers (e.g., increment it by 1) + self.my_buffer2.add_(1.0) # In-place addition + + return output + + Resulting Graph would be:: + + graph(): + %arg0_1 := placeholder[target=arg0_1] + %arg1_1 := placeholder[target=arg1_1] + %arg2_1 := placeholder[target=arg2_1] + %arg3_1 := placeholder[target=arg3_1] + %arg4_1 := placeholder[target=arg4_1] + %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) + %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) + %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) + %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) + %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) + return (add_tensor_2, add_tensor_1) + + Resulting ExportGraphSignature would be:: + + ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='my_parameter'), + InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), + InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), + InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None), + InputSpec(kind=, arg=TensorArgument(name='arg4_1'), target=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='add_2'), 
target='my_buffer2'), + OutputSpec(kind=, arg=TensorArgument(name='add_1'), target=None) + ] + ) + """ + + input_specs: List[InputSpec] + output_specs: List[OutputSpec] + + # A list of parameters uniquely identified by mangled fully qualified name + @property + def parameters(self) -> Collection[str]: + # TODO Make this tuple. + return [ + s.target + for s in self.input_specs + if s.kind == InputKind.PARAMETER + if isinstance(s.target, str) + ] + + # A list of buffers uniquely identified by mangled fully qualified name + @property + def buffers(self) -> Collection[str]: + # TODO Make this tuple. + return [ + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if isinstance(s.target, str) + ] + + @property + def non_persistent_buffers(self) -> Collection[str]: + return [ + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if s.persistent is False + if isinstance(s.target, str) + ] + + # A list of lifted constant tensors + @property + def lifted_tensor_constants(self) -> Collection[str]: + # TODO Make this tuple. + return [ + s.target + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + if isinstance(s.target, str) + ] + + @property + def lifted_custom_objs(self) -> Collection[str]: + # TODO Make this tuple. 
+ return [ + s.target + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + if isinstance(s.target, str) + ] + + # Graph node names of pytree-flattened inputs of original program + @property + def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_inputs: List[Union[int, float, bool, None, str]] = [] + for s in self.input_specs: + if s.kind != InputKind.USER_INPUT: + continue + + if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)): + user_inputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_inputs.append(s.arg.value) + else: + raise RuntimeError(f"{s.arg} is not a valid user inputs") + return tuple(user_inputs) + + # Graph node names of pytree-flattened outputs of original program + @property + def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_outputs: List[Union[int, float, bool, None, str]] = [] + for s in self.output_specs: + if s.kind != OutputKind.USER_OUTPUT: + continue + + if isinstance(s.arg, (TensorArgument, SymIntArgument)): + user_outputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_outputs.append(s.arg.value) + else: + raise RuntimeError(f"{s.arg} is not a valid user output") + return tuple(user_outputs) + + # A dictionary mapping graph input node names to parameters. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted parameter. + @property + def inputs_to_parameters(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.input_specs + if s.kind == InputKind.PARAMETER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + + # A dictionary mapping graph input node names to buffers. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted buffer. 
+ @property + def inputs_to_buffers(self) -> Mapping[str, str]: + return { + s.arg.name: s.target # type: ignore[union-attr, misc] + for s in self.input_specs + if s.kind == InputKind.BUFFER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + + # A dictionary mapping graph output node names to buffers that are mutated in the + # original program. Buffers that are not mutated will not be found in this dictionary. + @property + def buffers_to_mutate(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.output_specs + if s.kind == OutputKind.BUFFER_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + + # A dictionary mapping graph input node names to lifted tensor constants. 
+ @property + def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + } + + @property + def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]: + return { + s.arg.name: s.target + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + and isinstance(s.arg, CustomObjArgument) + and isinstance(s.target, str) + } + + @property + def backward_signature(self) -> Optional[ExportBackwardSignature]: + loss_output = None + gradients_to_parameters: Dict[str, str] = {} + gradients_to_user_inputs: Dict[str, str] = {} + for spec in self.output_specs: + if spec.kind == OutputKind.LOSS_OUTPUT: + assert loss_output is None + assert isinstance(spec.arg, TensorArgument) + loss_output = spec.arg.name + elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_parameters[spec.arg.name] = spec.target + elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_user_inputs[spec.arg.name] = spec.target + + if loss_output is None: + return None + + return ExportBackwardSignature( + loss_output=loss_output, + gradients_to_parameters=gradients_to_parameters, + gradients_to_user_inputs=gradients_to_user_inputs, + ) + + # Map from assertion dependency token index to assertion dep token output + # name in output. The shape of output after aot_autograd will be like: + # (updated_inputs, user_outputs, dep_token). 
+ @property + def assertion_dep_token(self) -> Optional[Mapping[int, str]]: + return None + + @property + def input_tokens(self) -> List[str]: + input_tokens = [] + for s in self.input_specs: + if s.kind == InputKind.TOKEN: + assert isinstance(s.arg, TensorArgument) + input_tokens.append(s.arg.name) + return input_tokens + + @property + def output_tokens(self) -> List[str]: + output_tokens = [] + for s in self.output_specs: + if s.kind == OutputKind.TOKEN: + assert isinstance(s.arg, TensorArgument) + output_tokens.append(s.arg.name) + return output_tokens + + def __post_init__(self) -> None: + assertion_dep_token = self.assertion_dep_token + if assertion_dep_token is None: + return + assert len(assertion_dep_token) == 1 + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + assert ( + len(self.user_outputs) + len(self.buffers_to_mutate) + == assertion_dep_token_index + ) + + def replace_all_uses(self, old: str, new: str): + """ + Replace all uses of the old name with new name in the signature. 
+ """ + assert isinstance(old, str) + assert isinstance(new, str) + arg_types = (TensorArgument, SymIntArgument, CustomObjArgument) + for o in self.output_specs: + if isinstance(o.arg, arg_types): + if o.arg.name == old: + o.arg.name = new + for i in self.input_specs: + if isinstance(i.arg, arg_types): + if i.arg.name == old: + i.arg.name = new + + def get_replace_hook(self): + def _(old, new, user): + if user.op in ("output", "input"): + self.replace_all_uses(old.name, new) + + return _ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/unflatten.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..b16febd462194334d6f03ab7a7ffc97e5e4ac723 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/unflatten.py @@ -0,0 +1,860 @@ +import abc +import copy +import operator +from copy import deepcopy +from enum import Enum +from itertools import chain +from typing import Any, cast, Dict, List, Optional, Union + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, + ModuleCallSignature, + SymIntArgument, + TensorArgument, +) +from torch.fx._symbolic_trace import is_fx_tracing +from torch.utils._pytree import GetAttrKey, SequenceKey + +__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"] + + +class _AttrKind(Enum): + PARAMETER = "parameter" + BUFFER = "buffer" + CONSTANT = "constant" + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr( + from_obj: Union[torch.Tensor, torch.ScriptObject], + to_module: torch.nn.Module, + target: str, + attr_kind: _AttrKind, + 
persistent: bool = True, +): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + if attr_kind == _AttrKind.PARAMETER: + assert isinstance(from_obj, torch.nn.Parameter) + to_module.register_parameter(field, from_obj) + elif attr_kind == _AttrKind.BUFFER: + assert isinstance(from_obj, torch.Tensor) + to_module.register_buffer(field, from_obj, persistent=persistent) + elif attr_kind == _AttrKind.CONSTANT: + assert isinstance(from_obj, (torch.Tensor, torch.ScriptObject)) + setattr(to_module, field, from_obj) + + +class InterpreterModule(torch.nn.Module): + """A module that uses torch.fx.Interpreter to execute instead of the usual + codegen that GraphModule uses. This provides better stack trace information + and makes it easier to debug execution. + """ + + def __init__( + self, + graph: torch.fx.Graph, + ): + super().__init__() + self.graph = graph + self.graph.owning_module = self + + def forward(self, *args, **kwargs): + assert self.graph_module is not None, "Didn't finalize this InterpreterModule" + if torch.compiler.is_dynamo_compiling(): + # Dynamo cannot trace through torch.fx.Interpreter, so fall back to + # GraphModule codegen in this instance. + return self.graph_module(*args, **kwargs) + else: + if kwargs: + # Handle **kwargs. FX only natively supports positional + # arguments (through placeholders). So in order to pass in + # kwargs, we must correspond the names of the placeholders with + # the keys in the kwarg dict. + arg_list = list(args) + kwarg_names = self.arg_names[len(arg_list) :] + for kwarg_name in kwarg_names: + if kwarg_name in kwargs: + arg_list.append(kwargs[kwarg_name]) + + # Assert that the kwargs passed in exactly match the positional + # arguments specified by the GraphModule. This should be + # guaranteed by the unflattening process. 
+ assert len(kwarg_names) == len(kwargs) + assert len(arg_list) == len(self.arg_names) + args = tuple(arg_list) + + return torch.fx.Interpreter(self, graph=self.graph).run( + *args, enable_io_processing=False + ) + + def finalize(self): + # We need to "finalize" because GraphModule populates its own state_dict + # based on the get_attrs observed in the graph. So we need to fully + # construct the graph and call _sink_params before generating this + # GraphModule. + + # need to set `graph_module` directly on the dict to avoid it getting + # registered as a submodule. + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + # Cache arg names for kwarg handling (see forward()) + self.arg_names = [] + for node in self.graph.nodes: + if node.op == "placeholder": + self.arg_names.append(node.target) + + +class FlatArgsAdapter(abc.ABC): + """ + Adapts input arguments with ``input_spec`` to align ``target_spec``. + """ + + @abc.abstractmethod + def adapt( + self, + target_spec: pytree.TreeSpec, + input_spec: pytree.TreeSpec, + input_args: List[Any], + ) -> List[Any]: + """NOTE: This adapter may mutate given ``input_args_with_path``.""" + ... + + +class UnflattenedModule(torch.nn.Module): + def __init__( + self, + export_module: ExportedProgram, + flat_args_adapter: Optional[FlatArgsAdapter] = None, + ): + super().__init__() + if export_module.graph_signature.backward_signature is not None: + raise ValueError("Unflattening on JointExportModule NYI") + + export_graph = deepcopy(export_module.graph) + self.graph_signature = deepcopy(export_module.graph_signature) + self.graph = torch.fx.Graph() + self.module_call_graph = deepcopy(export_module.module_call_graph) + self.flat_args_adapter = flat_args_adapter + # Flag to indicate whether args have been adapted. 
+ self.adapted = False + + _inplace_buffer_mutations(export_graph, self.graph_signature) + _outline_submodules(export_graph, self) + + self.range_constraints = export_module.range_constraints + self.equality_constraints: List = [] + + state_dict = export_module.state_dict + for name in self.graph_signature.parameters: + cloned = torch.nn.Parameter(state_dict[name].clone()) + _assign_attr( + cloned, + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + for name in self.graph_signature.buffers: + if name in non_persistent_buffers: + persistent = False + cloned = export_module.constants[name].clone() + else: + persistent = True + cloned = state_dict[name].clone() + + _assign_attr( + cloned, + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=persistent, + ) + + for fqn in chain( + self.graph_signature.lifted_tensor_constants, + self.graph_signature.lifted_custom_objs, + ): + constant = export_module.constants[fqn] + if isinstance(constant, torch.Tensor): + constant = constant.clone() + _assign_attr( + constant, + self, + fqn, + attr_kind=_AttrKind.CONSTANT, + ) + + inputs_to_state: Dict[str, str] = { + **self.graph_signature.inputs_to_parameters, + **self.graph_signature.inputs_to_buffers, + **self.graph_signature.inputs_to_lifted_tensor_constants, + **self.graph_signature.inputs_to_lifted_custom_objs, + } + + _sink_params(self, inputs_to_state, []) + # Check all input nodes has been processed. + for module in self.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op != "placeholder": + continue + assert node.name not in inputs_to_state + + # Cache so we don't have to compute this every time. + # NOTE: this needs to be kept in sync with the placeholders in + # self.graph, but currently we have no way to guarantee that. 
+ self.input_placeholders = [ + node for node in self.graph.nodes if node.op == "placeholder" + ] + self.check_input_constraints = True + assert self.module_call_graph[0].fqn == "" + + def forward(self, *args, **kwargs): + signature = self.module_call_graph[0].signature + + reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) + + flat_args_with_path, in_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + flat_args = [x[1] for x in flat_args_with_path] + if is_fx_tracing(): + return_val = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + # For scalar return value, fx.Graph wraps in a tuple + if isinstance(return_val, tuple) and len(return_val) == 1: + return return_val[0] + return return_val + + if in_spec != signature.in_spec: + if not self.adapted: + print( + "Input treespec does not match with exported module's: \n" + f"Input treespec: {in_spec}. ", + f"Exported module treespec: {signature.in_spec}", + ) + if self.flat_args_adapter is None: + raise TypeError( + "There is no flat args adapter sepcified. " + "Are you sure you are calling this with the right arguments? " + ) + else: + if not self.adapted: + print("Adapting flat arg to match exported module's treespec") + flat_args = self.flat_args_adapter.adapt( + target_spec=signature.in_spec, + input_spec=in_spec, + input_args=flat_args, + ) + self.adapted = True + if len(flat_args) != signature.in_spec.num_leaves: + raise TypeError( + f"Flat args adaption failed, number of args mismatch " + f"Adatped: {len(flat_args)} \n" + f"Exported module: {signature.in_spec.num_leaves}" + ) + + if self.check_input_constraints: + # Import here to avoid an unfortunate circular dependency. + # TODO(suo): untangle this. + from torch._export.utils import _check_input_constraints_for_graph + + if self.adapted is True: + # TODO(suo): The FlatArgsAdapter returns a list of flat args, + # which we don't have keypaths for. 
For now, just create a dummy + # keypath to associate with the arg. + new_flat_args_with_path = [ # type: ignore[var-annotated] + ((SequenceKey(idx=0), GetAttrKey(name="")), arg) + for arg in flat_args + ] + else: + new_flat_args_with_path = flat_args_with_path # type: ignore[assignment] + + _check_input_constraints_for_graph( + self.input_placeholders, new_flat_args_with_path, self.range_constraints + ) + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + return pytree.tree_unflatten(tree_out, signature.out_spec) + + +def unflatten( + module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None +) -> UnflattenedModule: + """Unflatten an ExportedProgram, producing a module with the same module + hierarchy as the original eager module. This can be useful if you are trying + to use :mod:`torch.export` with another system that expects a module + hierachy instead of the flat graph that :mod:`torch.export` usually produces. + + .. note:: The args/kwargs of unflattened modules will not necessarily match + the eager module, so doing a module swap (e.g. :code:`self.submod = + new_mod`) will not necessarily work. If you need to swap a module out, you + need to set the :code:`preserve_module_call_signature` parameter of + :func:`torch.export.export`. + + Args: + module (ExportedProgram): The ExportedProgram to unflatten. + flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. + + Returns: + An instance of :class:`UnflattenedModule`, which has the same module + hierarchy as the original eager module pre-export. + """ + return UnflattenedModule(module, flat_args_adapter) + + +def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None: + """Transform buffer mutations from their functionalized form into a copy_ + node in the graph. 
+ + Functionalization represents buffer mutation by passing the buffer as an input and output. So for example, the eager code: + def forward(self, x): + self.buffer += x + return x * x + + Will become a graph that looks like: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + mul = aten.mul(x, x) + return (mutated_buffer, mul) + + We want to inplace this into something that looks like the original eager code: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + buffer.copy_(mutated_buffer) + mul = aten.mul(x, x) + return (mul,) + """ + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + + mutation_node_to_buffer = graph_signature.buffers_to_mutate + mutations = return_args[: len(mutation_node_to_buffer)] + buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()} + input_name_to_node = { + node.name: node for node in graph.nodes if node.op == "placeholder" + } + + for mutation in mutations: + buffer_name = mutation_node_to_buffer[mutation.name] + input_name = buffers_to_inputs[buffer_name] + input_node = input_name_to_node[input_name] + + with graph.inserting_after(mutation): + new_node = graph.create_node( + "call_function", torch.ops.aten.copy_, (input_node, mutation) + ) + for k, v in mutation.meta.items(): + new_node.meta[k] = v + # Replace all uses of the previously functional mutation with our copy_ output. + mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) + + # Remove the mutated buffer from the graph outputs, since we don't need to + # thread it through anymore. We don't need to handle the inputs, which will + # be handled by _sink_params. 
+ user_outputs = tuple( + return_args[len(mutation_node_to_buffer) :], + ) + output_node.args = ((user_outputs),) + + +def _is_prefix(candidate, target): + """Check whether `candidate` is a prefix of `target`.""" + return len(candidate) < len(target) and target[: len(candidate)] == candidate + + +def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: + if parent_fqn == "": + # Handle the root module correctly. + return child_fqn + + parent_split = parent_fqn.split(".") + child_split = child_fqn.split(".") + + assert ( + child_split[: len(parent_split)] == parent_split + ), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'" + return ".".join(child_split[len(parent_split) :]) + + +def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): + def graph_dump(graph: torch.fx.Graph) -> str: + ret = [] + nodes_idx: Dict[int, int] = {} + + def arg_dump(arg) -> str: + if isinstance(arg, torch.fx.Node): + return "%" + str(nodes_idx[id(arg)]) + return str(arg) + + for i, node in enumerate(graph.nodes): + args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] + args_dump += [ + f"{key}={value}" + for key, value in pytree.tree_map(arg_dump, node.kwargs).items() + ] + target = node.target if node.op == "call_function" else "" + ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") + nodes_idx[id(node)] = i + return "\n".join(ret) + + assert graph_dump(x.graph) == graph_dump(y.graph) + + +def _add_spec(gm: torch.nn.Module, spec) -> str: + i = 0 + while hasattr(gm, f"_spec_{i}"): + i += 1 + name = f"_spec_{i}" + setattr(gm, name, spec) + return name + + +def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) + + +def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = 
gm.graph.get_attr(name) + return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) + + +def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module): + *prefix, field = target.split(".") + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, module_to_add) + + +class _ModuleFrame: + def __init__( + self, + flat_graph, + nodes, + seen_nodes, + seen_modules, + parent, + module_stack, + module_id, + module_call_graph: Dict[str, ModuleCallSignature], + module: Optional[torch.nn.Module] = None, + ): + self.flat_graph = flat_graph + self.nodes = nodes + self.seen_nodes = seen_nodes + self.seen_modules = seen_modules + self.parent = parent + self.module_stack = module_stack + self.module_id = module_id + + self.module_call_graph = module_call_graph + self.verbose = False + + self.fqn = self.module_stack[-1] + if module is not None: + self.module = module + else: + self.module = InterpreterModule(torch.fx.Graph()) + if self.module_id in self.seen_modules: + self.cached_graph_module = self.seen_modules[self.module_id] + else: + self.cached_graph_module = None + self.seen_modules[self.module_id] = self.module + + self.graph = self.module.graph + + # Mapping of nodes in the flat graph to nodes in this graph. 
+ self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {} + self.node_to_placeholder = {} + + self.parent_call_module: Optional[torch.fx.Node] = None + if parent is not None: + accessor = _compute_accessor(parent.fqn, self.fqn) + _add_submodule( + parent.module, + accessor, + self.module + if self.cached_graph_module is None + else self.cached_graph_module, + ) + self.parent_call_module = parent.graph.call_module(accessor) + + signature = module_call_graph.get(self.fqn) + if signature is not None and self.parent is not None: + assert signature.in_spec.num_children == 2 + args_spec = signature.in_spec.children_specs[0] + kwargs_spec = signature.in_spec.children_specs[1] + assert args_spec.context is None + assert kwargs_spec.context is not None + + with self.graph.inserting_after(None): + arg_nodes = [] + for idx in range(args_spec.num_children): + arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}")) + kwarg_nodes = {} + for name in kwargs_spec.context: + kwarg_nodes[name] = self.graph.placeholder(name) + flat_args = _generate_flatten( + self.module, + (tuple(arg_nodes), kwarg_nodes), + signature.in_spec, + ) + for idx, arg in enumerate(signature.inputs): + flat_arg_node = self.graph.create_node( + op="call_function", + target=operator.getitem, + args=(flat_args, idx), + name=arg.name + if not isinstance(arg, ConstantArgument) + else f"_constant_{idx}", + ) + if isinstance(arg, ConstantArgument): + continue + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node + + with self.parent.graph.inserting_before(self.parent_call_module): + input_nodes: List[Optional[torch.fx.Node]] = [] + for input in signature.inputs: + if isinstance(input, ConstantArgument) and input.value is None: + input_nodes.append(None) + else: + assert isinstance(input, (TensorArgument, SymIntArgument)) + input_nodes.append( + self.parent.remap_input(self.seen_nodes[input.name]) + ) + + inputs_node = 
_generate_unflatten( + self.parent.module, + input_nodes, + signature.in_spec, + ) + + args_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 0) + ) + kwargs_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 1) + ) + arg_nodes = [ + self.parent.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + kwarg_nodes = { + k: self.parent.graph.call_function( + operator.getitem, (kwargs_node, k) + ) + for k in kwargs_spec.context + } + assert self.parent_call_module is not None + self.parent_call_module.args = tuple(arg_nodes) + self.parent_call_module.kwargs = kwarg_nodes + + def add_placeholder(self, x): + assert x.graph is self.flat_graph + # x is not in subgraph, create a new placeholder for subgraph + with self.graph.inserting_before(None): + placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for + # the placeholder node + placeholder_node.meta = copy.copy(x.meta) + self.node_to_placeholder[x] = placeholder_node + + def remap_input(self, x): + assert x.graph is self.flat_graph + if x in self.node_map: + return self.node_map[x] + if x not in self.node_to_placeholder: + self.add_placeholder(x) + if self.parent_call_module is not None: + # Important to *prepend* the output to match how we are + # inserting placeholder nodes. 
+ self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) + return self.node_to_placeholder[x] + + def finalize_outputs(self): + orig_outputs = [] + + signature = self.module_call_graph.get(self.fqn) + if signature is not None and self.parent is not None: + for output in signature.outputs: + if isinstance(output, (TensorArgument, SymIntArgument)): + orig_outputs.append(self.seen_nodes[output.name]) + else: + raise RuntimeError( + f"Unsupported data type for output node: {output}" + ) + + tree_out_node = _generate_unflatten( + self.module, + tuple( + self.node_map[self.seen_nodes[output.name]] + for output in orig_outputs + ), + signature.out_spec, + ) + parent_out: Optional[torch.fx.Node] = _generate_flatten( + self.parent.module, self.parent_call_module, signature.out_spec + ) + graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node + else: + graph_outputs = [] + # Iterate through nodes we have copied into self.graph. + for orig_node in self.node_map.keys(): + for user_node in orig_node.users: + if user_node.name not in self.seen_nodes: + # external user node, need to expose as an output + orig_outputs.append(orig_node) + graph_outputs.append(self.node_map[orig_node]) + break + + parent_out = self.parent_call_module + if len(graph_outputs) == 1: + graph_outputs = graph_outputs[0] + + assert isinstance(graph_outputs, (list, torch.fx.Node)) + + self.graph.output(graph_outputs) + + # Rewrite outputs in parent module + if parent_out is None: + return + + parent_out.meta["val"] = ( + graph_outputs.meta.get("val") + if isinstance(graph_outputs, torch.fx.Node) + else [o.meta.get("val") for o in graph_outputs] + ) + + if len(orig_outputs) == 1 and signature is None: + self.parent.node_map[orig_outputs[0]] = parent_out + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. 
+ proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] + proxy_out.meta["val"] = orig_output.meta.get("val") + self.parent.node_map[orig_output] = proxy_out + + if self.cached_graph_module is not None: + _verify_graph_equivalence(self.cached_graph_module, self.module) + + def copy_node(self, node): + self.print("copying", node.format_node()) + self.node_map[node] = self.graph.node_copy(node, self.remap_input) + self.seen_nodes[node.name] = node + + def run_outer(self): + i = 0 + for node in self.flat_graph.nodes: + self.print(i, node.meta.get("nn_module_stack"), node.format_node()) + i += 1 + + # Copy all graph inputs + node_idx: int = 0 + node = self.nodes[node_idx] + while node.op == "placeholder": + self.copy_node(node) + node_idx += 1 + node = self.nodes[node_idx] + + self.run_from(node_idx) + + # Copy graph outputs + for node in self.flat_graph.nodes: + if node.op == "output": + self.copy_node(node) + + def print(self, *args, **kwargs): + if self.verbose: + print(*args, **kwargs) + + def run_from(self, node_idx): + module_idx = 0 + # Walk through the graph, building up a new graph with the right submodules + while node_idx < len(self.nodes): + node = self.nodes[node_idx] + assert node.op != "placeholder" + + self.print() + self.print("STEP", node_idx, node.format_node()) + self.print(self.module_stack) + if node.op == "output": + if len(self.module_stack) == 1: + # We want the output node of the original graph to be handled + # specially by the outermost stack frame (in run_outer). So + # skip finalization here. + return node_idx + + # We've reached the end of the graph. Wrap up all the existing stack frames. 
+ self.finalize_outputs() + return node_idx + + node_module_stack = ( + [path for path, ty in node.meta["nn_module_stack"].values()] + if "nn_module_stack" in node.meta + else self.module_stack + ) + if node_module_stack[: len(self.module_stack)] != self.module_stack: + # This means that the current module is done executing and the + # current node is the beginning of a new module. + # + # In this case, we should finalize this module and return without + # incrementing the node counter. + self.finalize_outputs() + self.print("outlining", self.fqn) + self.print(self.graph) + return node_idx + + assert node_module_stack is not None + + if _is_prefix(self.module_stack, node_module_stack): + # This means that the current node represents the execution of a new + # module. + next_module = node_module_stack[len(self.module_stack)] + self.print("Creating new stack frame for", next_module) + # Run a nested version of module outliner from the current node + # counter. Once it is complete, continue from that point. + node_idx = _ModuleFrame( + self.flat_graph, + self.nodes, + self.seen_nodes, + self.seen_modules, + self, + self.module_stack + [next_module], + list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], + self.module_call_graph, + ).run_from(node_idx) + module_idx += 1 + continue + + # The only remaining possibility is that we are in the right stack + # frame. Copy the node into this frame's graph and increment the node counter. 
+ assert node_module_stack == self.module_stack + self.copy_node(node) + node_idx += 1 + + +def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): + seen_nodes: Dict[str, torch.fx.Node] = {} + seen_modules: Dict[int, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + None, + [""], + "", + { + entry.fqn: entry.signature + for entry in root_module.module_call_graph + if entry.signature + }, + module=root_module, + ).run_outer() + + +def _sink_params( + module: torch.nn.Module, + inputs_to_state: Dict[str, str], + scope: List[str], +): + """Sink params, buffers, and constants from graph inputs into get_attr nodes. + + Exported modules are purely functional, so they pass their parameters and + buffers in as inputs to the graph. + + To replicate eager's semantics, we need to get them from the module state + via get_attr instead. + + module: GraphModule, potentially containining nested submodules. + inputs_to_state: mapping graph input names to the corresponding key in the state_dict. + scope: tracks where we are in the module hierarchy, so that we can emit the + right `getattr(self, "foo.bar")` calls, etc. + """ + # We need to use _modules here instead of named_children(), because we + # explicitly want duplicate modules to show up in the traversal. 
+ for name, submodule in module._modules.items(): + _sink_params(cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]) + + if not hasattr(module, "graph"): + # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) + return + + graph = module.graph + inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) + the_last_input = inputs[-1] + + # Also remove from call_module nodes + call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) + for node in call_module_nodes: + node.args = tuple(filter(lambda n: n.name not in inputs_to_state, node.args)) + + for node in inputs: + if node.name not in inputs_to_state: + continue + + if len(node.users) > 0: + state_name = inputs_to_state[node.name].split(".") + # If there's a mismatch beteewn scope name and state name, then there must be multuple scopes + # pointing to the same state name, meaning some modules are shared. In such case, we can simply + # skip updating the current node because another later iteration will take care of this input + # node when the unique match between scope and state name occurs. + # To make sure this always happen, we should enforce the invariant that no placeholder node + # in the unflattened graph appears in inputs_to_state dict, which means all the extra input + # nodes have been handled. 
+ if state_name[: len(scope)] != scope: + continue + attr_path = state_name[len(scope) :] + state_attr = _recursive_getattr(module, attr_path) + assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) + + # Make sure the newly created get_attr node is placed after the last placeholder node + with graph.inserting_after(the_last_input): + new_node = graph.create_node("get_attr", ".".join(attr_path)) + + node.replace_all_uses_with(new_node, propagate_meta=True) + graph.erase_node(node) + if isinstance(module, InterpreterModule): + module.finalize() + + +def _recursive_getattr(obj, attr_path): + for attr in attr_path: + obj = getattr(obj, attr) + + return obj diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/linalg/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d33c527d66df6ee28cbf738bb5bf88d008498855 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/linalg/__init__.py @@ -0,0 +1,2848 @@ +import sys + +import torch +from torch._C import _add_docstr, _linalg # type: ignore[attr-defined] + +LinAlgError = torch._C._LinAlgError # type: ignore[attr-defined] + +Tensor = torch.Tensor + +common_notes = { + "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""", + "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.", + "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.", + "sync_note_has_ex": ("When inputs are on a CUDA device, this function synchronizes that device with the CPU. 
" + "For a version of this function that does not synchronize, see :func:`{}`.") +} + + +# Note: This not only adds doc strings for functions in the linalg namespace, but +# also connects the torch.linalg Python namespace to the torch._C._linalg builtins. + +cross = _add_docstr(_linalg.linalg_cross, r""" +linalg.cross(input, other, *, dim=-1, out=None) -> Tensor + + +Computes the cross product of two 3-dimensional vectors. + +Supports input of float, double, cfloat and cdouble dtypes. Also supports batches +of vectors, for which it computes the product along the dimension :attr:`dim`. +It broadcasts over the batch dimensions. + +Args: + input (Tensor): the first input tensor. + other (Tensor): the second input tensor. + dim (int, optional): the dimension along which to take the cross-product. Default: `-1`. + +Keyword args: + out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`. + +Example: + >>> a = torch.randn(4, 3) + >>> a + tensor([[-0.3956, 1.1455, 1.6895], + [-0.5849, 1.3672, 0.3599], + [-1.1626, 0.7180, -0.0521], + [-0.1339, 0.9902, -2.0225]]) + >>> b = torch.randn(4, 3) + >>> b + tensor([[-0.0257, -1.4725, -1.2251], + [-1.1479, -0.7005, -1.9757], + [-1.3904, 0.3726, -1.1836], + [-0.9688, -0.7153, 0.2159]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.0844, -0.5281, 0.6120], + [-2.4490, -1.5687, 1.9792], + [-0.8304, -1.3037, 0.5650], + [-1.2329, 1.9883, 1.0551]]) + >>> a = torch.randn(1, 3) # a is broadcast to match shape of b + >>> a + tensor([[-0.9941, -0.5132, 0.5681]]) + >>> torch.linalg.cross(a, b) + tensor([[ 1.4653, -1.2325, 1.4507], + [ 1.4119, -2.6163, 0.1073], + [ 0.3957, -1.9666, -1.0840], + [ 0.2956, -0.3357, 0.2139]]) +""") + +cholesky = _add_docstr(_linalg.linalg_cholesky, r""" +linalg.cholesky(A, *, upper=False, out=None) -> Tensor + +Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. 
+ +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **Cholesky decomposition** of a complex Hermitian or real symmetric positive-definite matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = LL^{\text{H}}\mathrlap{\qquad L \in \mathbb{K}^{n \times n}} + +where :math:`L` is a lower triangular matrix with real positive diagonal (even in the complex case) and +:math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, and the transpose when :math:`L` is real-valued. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.cholesky_ex")} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.cholesky_ex` for a version of this operation that + skips the (slow) error checking by default and instead returns the debug + information. This makes it a faster way to check if a matrix is + positive-definite. + + :func:`torch.linalg.eigh` for a different decomposition of a Hermitian matrix. + The eigenvalue decomposition gives more information about the matrix but it + slower to compute than the Cholesky decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian positive-definite matrices. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. + The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix or any matrix in a batched :attr:`A` is not Hermitian + (resp. symmetric) positive-definite. 
If :attr:`A` is a batch of matrices, + the error message will include the batch index of the first matrix that fails + to meet this condition. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.T.conj() + torch.eye(2) # creates a Hermitian positive-definite matrix + >>> A + tensor([[2.5266+0.0000j, 1.9586-2.0626j], + [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) + >>> L = torch.linalg.cholesky(A) + >>> L + tensor([[1.5895+0.0000j, 0.0000+0.0000j], + [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) + >>> torch.dist(L @ L.T.conj(), A) + tensor(4.4692e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A @ A.mT + torch.eye(2) # batch of symmetric positive-definite matrices + >>> L = torch.linalg.cholesky(A) + >>> torch.dist(L @ L.mT, A) + tensor(5.8747e-16, dtype=torch.float64) +""") + +cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r""" +linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the Cholesky decomposition of a complex Hermitian or real +symmetric positive-definite matrix. + +This function skips the (slow) error checking and error message construction +of :func:`torch.linalg.cholesky`, instead directly returning the LAPACK +error codes as part of a named tuple ``(L, info)``. This makes this function +a faster way to check if a matrix is positive-definite, and it provides an +opportunity to handle decomposition errors more gracefully or performantly +than :func:`torch.linalg.cholesky` does. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`A` is not a Hermitian positive-definite matrix, or if it's a batch of matrices +and one or more of them is not a Hermitian positive-definite matrix, +then ``info`` stores a positive integer for the corresponding matrix. 
+The positive integer indicates the order of the leading minor that is not positive-definite, +and the decomposition could not be completed. +``info`` filled with zeros indicates that the decomposition was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +.. seealso:: + :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + +Keyword args: + upper (bool, optional): whether to return an upper triangular matrix. + The tensor returned with upper=True is the conjugate transpose of the tensor + returned with upper=False. + check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A @ A.t().conj() # creates a Hermitian positive-definite matrix + >>> L, info = torch.linalg.cholesky_ex(A) + >>> A + tensor([[ 2.3792+0.0000j, -0.9023+0.9831j], + [-0.9023-0.9831j, 0.8757+0.0000j]], dtype=torch.complex128) + >>> L + tensor([[ 1.5425+0.0000j, 0.0000+0.0000j], + [-0.5850-0.6374j, 0.3567+0.0000j]], dtype=torch.complex128) + >>> info + tensor(0, dtype=torch.int32) + +""") + +inv = _add_docstr(_linalg.linalg_inv, r""" +linalg.inv(A, *, out=None) -> Tensor + +Computes the inverse of a square matrix if it exists. +Throws a `RuntimeError` if the matrix is not invertible. 
+ +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +for a matrix :math:`A \in \mathbb{K}^{n \times n}`, +its **inverse matrix** :math:`A^{-1} \in \mathbb{K}^{n \times n}` (if it exists) is defined as + +.. math:: + + A^{-1}A = AA^{-1} = \mathrm{I}_n + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case, +the inverse is unique. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices +then the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.inv_ex")} +""" + r""" + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + the inverse, as:: + + linalg.solve(A, B) == linalg.inv(A) @ B # When B is a matrix + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing the inverse explicitly. + +.. seealso:: + + :func:`torch.linalg.pinv` computes the pseudoinverse (Moore-Penrose inverse) of matrices + of any shape. + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of invertible matrices. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible. 
+
+Examples::
+
+    >>> A = torch.randn(4, 4)
+    >>> Ainv = torch.linalg.inv(A)
+    >>> torch.dist(A @ Ainv, torch.eye(4))
+    tensor(1.1921e-07)
+
+    >>> A = torch.randn(2, 3, 4, 4)  # Batch of matrices
+    >>> Ainv = torch.linalg.inv(A)
+    >>> torch.dist(A @ Ainv, torch.eye(4))
+    tensor(1.9073e-06)
+
+    >>> A = torch.randn(4, 4, dtype=torch.complex128)  # Complex matrix
+    >>> Ainv = torch.linalg.inv(A)
+    >>> torch.dist(A @ Ainv, torch.eye(4))
+    tensor(7.5107e-16, dtype=torch.float64)
+
+.. _invertible:
+    https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem
+""")
+
+solve_ex = _add_docstr(_linalg.linalg_solve_ex, r"""
+linalg.solve_ex(A, B, *, left=True, check_errors=False, out=None) -> (Tensor, Tensor)
+
+A version of :func:`~solve` that does not perform error checks unless :attr:`check_errors`\ `= True`.
+It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_ex"]}
+
+.. warning:: {common_notes["experimental_warning"]}
+""" + r"""
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions.
+    B (Tensor): right-hand side tensor of shape `(*, n, k)` where `*` is zero or more batch dimensions.
+
+Keyword args:
+    left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`.
+    check_errors (bool, optional): controls whether to check the content of ``infos`` and raise
+        an error if it is non-zero. Default: `False`.
+    out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(result, info)`.
+
+Examples::
+
+    >>> A = torch.randn(3, 3)
+    >>> B = torch.randn(3, 3)
+    >>> X, info = torch.linalg.solve_ex(A, B)
+    >>> torch.dist(X, torch.linalg.solve(A, B))
+    tensor(0.)
+    >>> info
+    tensor(0, dtype=torch.int32)
+
+..
_LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") + +inv_ex = _add_docstr(_linalg.linalg_inv_ex, r""" +linalg.inv_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor) + +Computes the inverse of a square matrix if it is invertible. + +Returns a namedtuple ``(inverse, info)``. ``inverse`` contains the result of +inverting :attr:`A` and ``info`` stores the LAPACK error codes. + +If :attr:`A` is not an invertible matrix, or if it's a batch of matrices +and one or more of them is not an invertible matrix, +then ``info`` stores a positive integer for the corresponding matrix. +The positive integer indicates the diagonal element of the LU decomposition of +the input matrix that is exactly zero. +``info`` filled with zeros indicates that the inversion was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.inv` is a NumPy compatible variant that always checks for errors. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of square matrices. + check_errors (bool, optional): controls whether to check the content of ``info``. Default: `False`. + +Keyword args: + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> Ainv, info = torch.linalg.inv_ex(A) + >>> torch.dist(torch.linalg.inv(A), Ainv) + tensor(0.) 
+ >>> info + tensor(0, dtype=torch.int32) + +""") + +det = _add_docstr(_linalg.linalg_det, r""" +linalg.det(A, *, out=None) -> Tensor + +Computes the determinant of a square matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.slogdet` computes the sign and natural logarithm of the absolute + value of the determinant of square matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> torch.linalg.det(A) + tensor(0.0934) + + >>> A = torch.randn(3, 2, 2) + >>> torch.linalg.det(A) + tensor([1.1990, 0.4099, 0.7386]) +""") + +slogdet = _add_docstr(_linalg.linalg_slogdet, r""" +linalg.slogdet(A, *, out=None) -> (Tensor, Tensor) + +Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix. + +For complex :attr:`A`, it returns the sign and the natural logarithm of the modulus of the +determinant, that is, a logarithmic polar decomposition of the determinant. + +The determinant can be recovered as `sign * exp(logabsdet)`. +When a matrix has a determinant of zero, it returns `(0, -inf)`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.det` computes the determinant of square matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(sign, logabsdet)`. + + `sign` will have the same dtype as :attr:`A`. 
+ + `logabsdet` will always be real-valued, even when :attr:`A` is complex. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A + tensor([[ 0.0032, -0.2239, -1.1219], + [-0.6690, 0.1161, 0.4053], + [-1.6218, -0.9273, -0.0082]]) + >>> torch.linalg.det(A) + tensor(-0.7576) + >>> torch.logdet(A) + tensor(nan) + >>> torch.linalg.slogdet(A) + torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) +""") + +eig = _add_docstr(_linalg.linalg_eig, r""" +linalg.eig(A, *, out=None) -> (Tensor, Tensor) + +Computes the eigenvalue decomposition of a square matrix if it exists. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalue decomposition** of a square matrix +:math:`A \in \mathbb{K}^{n \times n}` (if it exists) is defined as + +.. math:: + + A = V \operatorname{diag}(\Lambda) V^{-1}\mathrlap{\qquad V \in \mathbb{C}^{n \times n}, \Lambda \in \mathbb{C}^n} + +This decomposition exists if and only if :math:`A` is `diagonalizable`_. +This is the case when all its eigenvalues are different. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. note:: The eigenvalues and eigenvectors of a real matrix may be complex. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. warning:: This function assumes that :attr:`A` is `diagonalizable`_ (for example, when all the + eigenvalues are different). If it is not diagonalizable, the returned + eigenvalues will be correct but :math:`A \neq V \operatorname{diag}(\Lambda)V^{-1}`. + +.. warning:: The returned eigenvectors are normalized to have norm `1`. + Even then, the eigenvectors of a matrix are not unique, nor are they continuous with respect to + :attr:`A`. Due to this lack of uniqueness, different hardware and software may compute + different eigenvectors. 
+
+             This non-uniqueness is caused by the fact that multiplying an eigenvector by
+             :math:`e^{i \phi}, \phi \in \mathbb{R}` produces another set of valid eigenvectors
+             of the matrix. For this reason, the loss function shall not depend on the phase of the
+             eigenvectors, as this quantity is not well-defined.
+             This is checked when computing the gradients of this function. As such,
+             when inputs are on a CUDA device, the computation of the gradients
+             of this function synchronizes that device with the CPU.
+
+
+.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when
+             :attr:`A` has distinct eigenvalues.
+             Furthermore, if the distance between any two eigenvalues is close to zero,
+             the gradient will be numerically unstable, as it depends on the eigenvalues
+             :math:`\lambda_i` through the computation of
+             :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`.
+
+.. seealso::
+
+        :func:`torch.linalg.eigvals` computes only the eigenvalues.
+        Unlike :func:`torch.linalg.eig`, the gradients of :func:`~eigvals` are always
+        numerically stable.
+
+        :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition
+        for Hermitian and symmetric matrices.
+
+        :func:`torch.linalg.svd` for a function that computes another type of spectral
+        decomposition that works on matrices of any shape.
+
+        :func:`torch.linalg.qr` for another (much faster) decomposition that works on matrices of
+        any shape.
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
+                consisting of diagonalizable matrices.
+
+Keyword args:
+    out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`.
+
+Returns:
+    A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`V` above.
+
+    `eigenvalues` and `eigenvectors` will always be complex-valued, even when :attr:`A` is real. The eigenvectors
+    will be given by the columns of `eigenvectors`.
+ +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A + tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], + [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) + >>> L, V = torch.linalg.eig(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + >>> V + tensor([[ 0.9218+0.0000j, 0.1882-0.2220j], + [-0.0270-0.3867j, 0.9567+0.0000j]], dtype=torch.complex128) + >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) + tensor(7.7119e-16, dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> L, V = torch.linalg.eig(A) + >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) + tensor(3.2841e-16, dtype=torch.float64) + +.. _diagonalizable: + https://en.wikipedia.org/wiki/Diagonalizable_matrix#Definition +""") + +eigvals = _add_docstr(_linalg.linalg_eigvals, r""" +linalg.eigvals(A, *, out=None) -> Tensor + +Computes the eigenvalues of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalues** of a square matrix :math:`A \in \mathbb{K}^{n \times n}` are defined +as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by + +.. math:: + + p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{C}} + +where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. + + The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.eig` computes the full eigenvalue decomposition. 
+ +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A complex-valued tensor containing the eigenvalues even when :attr:`A` is real. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> L = torch.linalg.eigvals(A) + >>> L + tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) + + >>> torch.dist(L, torch.linalg.eig(A).eigenvalues) + tensor(2.4576e-07) +""") + +eigh = _add_docstr(_linalg.linalg_eigh, r""" +linalg.eigh(A, UPLO='L', *, out=None) -> (Tensor, Tensor) + +Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **eigenvalue decomposition** of a complex Hermitian or real symmetric matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. math:: + + A = Q \operatorname{diag}(\Lambda) Q^{\text{H}}\mathrlap{\qquad Q \in \mathbb{K}^{n \times n}, \Lambda \in \mathbb{R}^n} + +where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex, and the transpose when :math:`Q` is real-valued. +:math:`Q` is orthogonal in the real case and unitary in the complex case. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: + +- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. +- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. + +The eigenvalues are returned in ascending order. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. 
+ +.. warning:: The eigenvectors of a symmetric matrix are not unique, nor are they continuous with + respect to :attr:`A`. Due to this lack of uniqueness, different hardware and + software may compute different eigenvectors. + + This non-uniqueness is caused by the fact that multiplying an eigenvector by + `-1` in the real case or by :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex + case produces another set of valid eigenvectors of the matrix. + For this reason, the loss function shall not depend on the phase of the eigenvectors, as + this quantity is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when + :attr:`A` has distinct eigenvalues. + Furthermore, if the distance between any two eigenvalues is close to zero, + the gradient will be numerically unstable, as it depends on the eigenvalues + :math:`\lambda_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. + +.. warning:: User may see pytorch crashes if running `eigh` on CUDA devices with CUDA versions before 12.1 update 1 + with large ill-conditioned matrices as inputs. + Refer to :ref:`Linear Algebra Numerical Stability` for more details. + If this is the case, user may (1) tune their matrix inputs to be less ill-conditioned, + or (2) use :func:`torch.backends.cuda.preferred_linalg_library` to + try other supported backends. + +.. seealso:: + + :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix. + Unlike :func:`torch.linalg.eigh`, the gradients of :func:`~eigvalsh` are always + numerically stable. + + :func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix. 
+ The Cholesky decomposition gives less information about the matrix but is much faster + to compute than the eigenvalue decomposition. + + :func:`torch.linalg.eig` for a (slower) function that computes the eigenvalue decomposition + of a not necessarily Hermitian square matrix. + + :func:`torch.linalg.svd` for a (slower) function that computes the more general SVD + decomposition of matrices of any shape. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on general + matrices. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part + of :attr:`A` in the computations. Default: `'L'`. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`Q` above. + + `eigenvalues` will always be real-valued, even when :attr:`A` is complex. + It will also be ordered in ascending order. + + `eigenvectors` will have the same dtype as :attr:`A` and will contain the eigenvectors as its columns. 
+
+Examples::
+
+    >>> A = torch.randn(2, 2, dtype=torch.complex128)
+    >>> A = A + A.T.conj()  # creates a Hermitian matrix
+    >>> A
+    tensor([[2.9228+0.0000j, 0.2029-0.0862j],
+            [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128)
+    >>> L, Q = torch.linalg.eigh(A)
+    >>> L
+    tensor([0.3277, 2.9415], dtype=torch.float64)
+    >>> Q
+    tensor([[-0.0846+0.0000j, -0.9964+0.0000j],
+            [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128)
+    >>> torch.dist(Q @ torch.diag(L.cdouble()) @ Q.T.conj(), A)
+    tensor(6.1062e-16, dtype=torch.float64)
+
+    >>> A = torch.randn(3, 2, 2, dtype=torch.float64)
+    >>> A = A + A.mT  # creates a batch of symmetric matrices
+    >>> L, Q = torch.linalg.eigh(A)
+    >>> torch.dist(Q @ torch.diag_embed(L) @ Q.mH, A)
+    tensor(1.5423e-15, dtype=torch.float64)
+""")
+
+eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r"""
+linalg.eigvalsh(A, UPLO='L', *, out=None) -> Tensor
+
+Computes the eigenvalues of a complex Hermitian or real symmetric matrix.
+
+Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
+the **eigenvalues** of a complex Hermitian or real symmetric matrix :math:`A \in \mathbb{K}^{n \times n}`
+are defined as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by
+
+.. math::
+
+    p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{R}}
+
+where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
+The eigenvalues of a real symmetric or complex Hermitian matrix are always real.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+The eigenvalues are returned in ascending order.
+
+:attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead:
+
+- If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation.
+- If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.eigh` computes the full eigenvalue decomposition. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part + of :attr:`A` in the computations. Default: `'L'`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor containing the eigenvalues even when :attr:`A` is complex. + The eigenvalues are returned in ascending order. + +Examples:: + + >>> A = torch.randn(2, 2, dtype=torch.complex128) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> A + tensor([[2.9228+0.0000j, 0.2029-0.0862j], + [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) + >>> torch.linalg.eigvalsh(A) + tensor([0.3277, 2.9415], dtype=torch.float64) + + >>> A = torch.randn(3, 2, 2, dtype=torch.float64) + >>> A = A + A.mT # creates a batch of symmetric matrices + >>> torch.linalg.eigvalsh(A) + tensor([[ 2.5797, 3.4629], + [-4.1605, 1.3780], + [-3.1113, 2.7381]], dtype=torch.float64) +""") + +householder_product = _add_docstr(_linalg.linalg_householder_product, r""" +householder_product(A, tau, *, out=None) -> Tensor + +Computes the first `n` columns of a product of Householder matrices. + +Let :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, and +let :math:`V \in \mathbb{K}^{m \times n}` be a matrix with columns :math:`v_i \in \mathbb{K}^m` +for :math:`i=1,\ldots,m` with :math:`m \geq n`. Denote by :math:`w_i` the vector resulting from +zeroing out the first :math:`i-1` components of :math:`v_i` and setting to `1` the :math:`i`-th. 
+For a vector :math:`\tau \in \mathbb{K}^k` with :math:`k \leq n`, this function computes the +first :math:`n` columns of the matrix + +.. math:: + + H_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i w_i w_i^{\text{H}} + +where :math:`\mathrm{I}_m` is the `m`-dimensional identity matrix and :math:`w^{\text{H}}` is the +conjugate transpose when :math:`w` is complex, and the transpose when :math:`w` is real-valued. +The output matrix is the same size as the input matrix :attr:`A`. + +See `Representation of Orthogonal or Unitary Matrices`_ for further details. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.geqrf` can be used together with this function to form the `Q` from the + :func:`~qr` decomposition. + + :func:`torch.ormqr` is a related function that computes the matrix multiplication + of a product of Householder matrices with another matrix. + However, that function is not supported by autograd. + +.. warning:: + Gradient computations are only well-defined if :math:`tau_i \neq \frac{1}{||v_i||^2}`. + If this condition is not met, no error will be thrown, but the gradient produced may contain `NaN`. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + tau (Tensor): tensor of shape `(*, k)` where `*` is zero or more batch dimensions. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if :attr:`A` doesn't satisfy the requirement `m >= n`, + or :attr:`tau` doesn't satisfy the requirement `n >= k`. + +Examples:: + + >>> A = torch.randn(2, 2) + >>> h, tau = torch.geqrf(A) + >>> Q = torch.linalg.householder_product(h, tau) + >>> torch.dist(Q, torch.linalg.qr(A).Q) + tensor(0.) 
+
+    >>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
+    >>> tau = torch.randn(3, 1, dtype=torch.complex128)
+    >>> Q = torch.linalg.householder_product(h, tau)
+    >>> Q
+    tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
+             [-0.6853+0.7953j,  2.0790+0.5620j]],
+
+            [[ 1.4581+1.6989j, -1.5360+0.1193j],
+             [ 1.3877-0.6691j,  1.3512+1.3024j]],
+
+            [[ 1.4766+0.5783j,  0.0361+0.6587j],
+             [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)
+
+.. _Representation of Orthogonal or Unitary Matrices:
+    https://www.netlib.org/lapack/lug/node128.html
+""")
+
+ldl_factor = _add_docstr(_linalg.linalg_ldl_factor, r"""
+linalg.ldl_factor(A, *, hermitian=False, out=None) -> (Tensor, Tensor)
+
+Computes a compact representation of the LDL factorization of a Hermitian or symmetric (possibly indefinite) matrix.
+
+When :attr:`A` is complex valued it can be Hermitian (:attr:`hermitian`\ `= True`)
+or symmetric (:attr:`hermitian`\ `= False`).
+
+The factorization is of the form :math:`A = L D L^T`.
+If :attr:`hermitian` is `True` then the transpose operation is the conjugate transpose.
+
+:math:`L` (or :math:`U`) and :math:`D` are stored in compact form in ``LD``.
+They follow the format specified by `LAPACK's sytrf`_ function.
+These tensors may be used in :func:`torch.linalg.ldl_solve` to solve linear systems.
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then
+the output has the same batch dimensions.
+
+""" + fr"""
+.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.ldl_factor_ex")}
+""" + r"""
+
+Args:
+    A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
+                consisting of symmetric or Hermitian matrices.
+
+Keyword args:
+    hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric.
+        For real-valued matrices, this switch has no effect. Default: `False`.
+ out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LD, pivots)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.mT # make symmetric + >>> A + tensor([[7.2079, 4.2414, 1.9428], + [4.2414, 3.4554, 0.3264], + [1.9428, 0.3264, 1.3823]]) + >>> LD, pivots = torch.linalg.ldl_factor(A) + >>> LD + tensor([[ 7.2079, 0.0000, 0.0000], + [ 0.5884, 0.9595, 0.0000], + [ 0.2695, -0.8513, 0.1633]]) + >>> pivots + tensor([1, 2, 3], dtype=torch.int32) + +.. _LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""") + +ldl_factor_ex = _add_docstr(_linalg.linalg_ldl_factor_ex, r""" +linalg.ldl_factor_ex(A, *, hermitian=False, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~ldl_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's sytrf`_. +``info`` stores integer error codes from the backend library. +A positive integer indicates the diagonal element of :math:`D` that is zero. +Division by 0 will occur if the result is used for solving a system of linear equations. +``info`` filled with zeros indicates that the factorization was successful. +If ``check_errors=True`` and ``info`` contains positive integers, then a `RuntimeError` is thrown. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions + consisting of symmetric or Hermitian matrices. + +Keyword args: + hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. + For real-valued matrices, this switch has no effect. Default: `False`. 
+ check_errors (bool, optional): controls whether to check the content of ``info`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LD, pivots, info)`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> A = A @ A.mT # make symmetric + >>> A + tensor([[7.2079, 4.2414, 1.9428], + [4.2414, 3.4554, 0.3264], + [1.9428, 0.3264, 1.3823]]) + >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) + >>> LD + tensor([[ 7.2079, 0.0000, 0.0000], + [ 0.5884, 0.9595, 0.0000], + [ 0.2695, -0.8513, 0.1633]]) + >>> pivots + tensor([1, 2, 3], dtype=torch.int32) + >>> info + tensor(0, dtype=torch.int32) + +.. _LAPACK's sytrf: + https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html +""") + +ldl_solve = _add_docstr(_linalg.linalg_ldl_solve, r""" +linalg.ldl_solve(LD, pivots, B, *, hermitian=False, out=None) -> Tensor + +Computes the solution of a system of linear equations using the LDL factorization. + +:attr:`LD` and :attr:`pivots` are the compact representation of the LDL factorization and +are expected to be computed by :func:`torch.linalg.ldl_factor_ex`. +:attr:`hermitian` argument to this function should be the same +as the corresponding arguments in :func:`torch.linalg.ldl_factor_ex`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + LD (Tensor): the `n \times n` matrix or the batch of such matrices of size + `(*, n, n)` where `*` is one or more batch dimensions. + pivots (Tensor): the pivots corresponding to the LDL factorization of :attr:`LD`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. 
+
+Keyword args:
+    hermitian (bool, optional): whether to consider the decomposed matrix to be Hermitian or symmetric.
+        For real-valued matrices, this switch has no effect. Default: `False`.
+    out (tuple, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`.
+        Ignored if `None`. Default: `None`.
+
+Examples::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> A = A @ A.mT # make symmetric
+    >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A)
+    >>> B = torch.randn(2, 3, 4)
+    >>> X = torch.linalg.ldl_solve(LD, pivots, B)
+    >>> torch.linalg.norm(A @ X - B)
+    tensor(0.0001)
+""")
+
+lstsq = _add_docstr(_linalg.linalg_lstsq, r"""
+torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
+
+Computes a solution to the least squares problem of a system of linear equations.
+
+Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`,
+the **least squares problem** for a linear system :math:`AX = B` with
+:math:`A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}` is defined as
+
+.. math::
+
+    \min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F
+
+where :math:`\|-\|_F` denotes the Frobenius norm.
+
+Supports inputs of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices, and if the inputs are batches of matrices then
+the output has the same batch dimensions.
+
+:attr:`driver` chooses the backend function that will be used.
+For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd'`, `'gelss'`.
+To choose the best driver on CPU consider:
+
+- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss.
+
+  - For a general matrix: `'gelsy'` (QR with pivoting) (default)
+  - If :attr:`A` is full-rank: `'gels'` (QR)
+
+- If :attr:`A` is not well-conditioned.
+
+  - `'gelsd'` (tridiagonal reduction and SVD)
+  - But if you run into memory issues: `'gelss'` (full SVD).
+ +For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank. + +See also the `full description of these drivers`_ + +:attr:`rcond` is used to determine the effective rank of the matrices in :attr:`A` +when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`). +In this case, if :math:`\sigma_i` are the singular values of `A` in decreasing order, +:math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{rcond} \cdot \sigma_1`. +If :attr:`rcond`\ `= None` (default), :attr:`rcond` is set to the machine precision of the dtype of :attr:`A` times `max(m, n)`. + +This function returns the solution to the problem and some extra information in a named tuple of +four tensors `(solution, residuals, rank, singular_values)`. For inputs :attr:`A`, :attr:`B` +of shape `(*, m, n)`, `(*, m, k)` respectively, it contains + +- `solution`: the least squares solution. It has shape `(*, n, k)`. +- `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`. + It has shape equal to the batch dimensions of :attr:`A`. + It is computed when `m > n` and every matrix in :attr:`A` is full-rank, + otherwise, it is an empty tensor. + If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank, + then an empty tensor is returned. This behavior may change in a future PyTorch release. +- `rank`: tensor of ranks of the matrices in :attr:`A`. + It has shape equal to the batch dimensions of :attr:`A`. + It is computed when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. +- `singular_values`: tensor of singular values of the matrices in :attr:`A`. + It has shape `(*, min(m, n))`. + It is computed when :attr:`driver` is one of (`'gelsd'`, `'gelss'`), + otherwise it is an empty tensor. + +.. 
note:: + This function computes `X = \ `:attr:`A`\ `.pinverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. warning:: + The default value of :attr:`rcond` may change in a future PyTorch release. + It is therefore recommended to use a fixed value to avoid potential + breaking changes. + +Args: + A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions. + rcond (float, optional): used to determine the effective rank of :attr:`A`. + If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine + precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`. + +Keyword args: + driver (str, optional): name of the LAPACK/MAGMA method to be used. + If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs. + Default: `None`. + +Returns: + A named tuple `(solution, residuals, rank, singular_values)`. + +Examples:: + + >>> A = torch.randn(1,3,3) + >>> A + tensor([[[-1.0838, 0.0225, 0.2275], + [ 0.2438, 0.3844, 0.5499], + [ 0.1175, -0.9102, 2.0870]]]) + >>> B = torch.randn(2,3,3) + >>> B + tensor([[[-0.6772, 0.7758, 0.5109], + [-1.4382, 1.3769, 1.1818], + [-0.3450, 0.0806, 0.3967]], + [[-1.3994, -0.1521, -0.1473], + [ 1.9194, 1.0458, 0.6705], + [-1.1802, -0.9796, 1.4086]]]) + >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) + >>> torch.dist(X, torch.linalg.pinv(A) @ B) + tensor(1.5152e-06) + + >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values + >>> torch.dist(S, torch.linalg.svdvals(A)) + tensor(2.3842e-07) + + >>> A[:, 0].zero_() # Decrease the rank of A + >>> rank = torch.linalg.lstsq(A, B).rank + >>> rank + tensor([2]) + +.. _condition number: + https://pytorch.org/docs/master/linalg.html#torch.linalg.cond +.. 
_full description of these drivers: + https://www.netlib.org/lapack/lug/node27.html +""") + +matrix_power = _add_docstr(_linalg.linalg_matrix_power, r""" +matrix_power(A, n, *, out=None) -> Tensor + +Computes the `n`-th power of a square matrix for an integer `n`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`n`\ `= 0`, it returns the identity matrix (or batch) of the same shape +as :attr:`A`. If :attr:`n` is negative, it returns the inverse of each matrix +(if invertible) raised to the power of `abs(n)`. + +.. note:: + Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by + a negative power as, if :attr:`n`\ `> 0`:: + + torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n) @ B + + It is always preferred to use :func:`~solve` when possible, as it is faster and more + numerically stable than computing :math:`A^{-n}` explicitly. + +.. seealso:: + + :func:`torch.linalg.solve` computes :attr:`A`\ `.inverse() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, m)` where `*` is zero or more batch dimensions. + n (int): the exponent. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if :attr:`n`\ `< 0` and the matrix :attr:`A` or any matrix in the + batch of matrices :attr:`A` is not invertible. 
+ +Examples:: + + >>> A = torch.randn(3, 3) + >>> torch.linalg.matrix_power(A, 0) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]) + >>> torch.linalg.matrix_power(A, 3) + tensor([[ 1.0756, 0.4980, 0.0100], + [-1.6617, 1.4994, -1.9980], + [-0.4509, 0.2731, 0.8001]]) + >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) + tensor([[[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]], + [[ 0.2640, 0.4571, -0.5511], + [-1.0163, 0.3491, -1.5292], + [-0.4899, 0.0822, 0.2773]]]) +""") + +matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r""" +linalg.matrix_rank(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the numerical rank of a matrix. + +The matrix rank is computed as the number of singular values +(or eigenvalues in absolute value when :attr:`hermitian`\ `= True`) +that are greater than :math:`\max(\text{atol}, \sigma_1 * \text{rtol})` threshold, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svdvals`. + +.. 
note::
+    This function has NumPy compatible variant `linalg.matrix_rank(A, tol, hermitian=False)`.
+    However, use of the positional argument :attr:`tol` is deprecated in favor of :attr:`atol` and :attr:`rtol`.
+
+""" + fr"""
+.. note:: The matrix rank is computed using a singular value decomposition
+    :func:`torch.linalg.svdvals` if :attr:`hermitian`\ `= False` (default) and the eigenvalue
+    decomposition :func:`torch.linalg.eigvalsh` when :attr:`hermitian`\ `= True`.
+    {common_notes["sync_note"]}
+""" + r"""
+
+Args:
+    A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
+    tol (float, Tensor, optional): [NumPy Compat] Alias for :attr:`atol`. Default: `None`.
+
+Keyword args:
+    atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero.
+        Default: `None`.
+    rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`.
+        Default: `None`.
+    hermitian (bool, optional): indicates whether :attr:`A` is Hermitian if complex
+        or symmetric if real. Default: `False`.
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+ +Examples:: + + >>> A = torch.eye(10) + >>> torch.linalg.matrix_rank(A) + tensor(10) + >>> B = torch.eye(10) + >>> B[0, 0] = 0 + >>> torch.linalg.matrix_rank(B) + tensor(9) + + >>> A = torch.randn(4, 3, 2) + >>> torch.linalg.matrix_rank(A) + tensor([2, 2, 2, 2]) + + >>> A = torch.randn(2, 4, 2, 3) + >>> torch.linalg.matrix_rank(A) + tensor([[2, 2, 2, 2], + [2, 2, 2, 2]]) + + >>> A = torch.randn(2, 4, 3, 3, dtype=torch.complex64) + >>> torch.linalg.matrix_rank(A) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, hermitian=True) + tensor([[3, 3, 3, 3], + [3, 3, 3, 3]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0) + tensor([[3, 2, 2, 2], + [1, 2, 1, 2]]) + >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0, hermitian=True) + tensor([[2, 2, 2, 1], + [1, 2, 2, 2]]) +""") + +norm = _add_docstr(_linalg.linalg_norm, r""" +linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor + +Computes a vector or matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. + +Whether this function computes a vector or matrix norm is determined as follows: + +- If :attr:`dim` is an `int`, the vector norm will be computed. +- If :attr:`dim` is a `2`-`tuple`, the matrix norm will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord`\ `= None`, + :attr:`A` will be flattened to 1D and the `2`-norm of the resulting vector will be computed. +- If :attr:`dim`\ `= None` and :attr:`ord` `!= None`, :attr:`A` must be 1D or 2D. + +:attr:`ord` defines the norm that is computed. 
The following norms are supported: + +====================== ========================= ======================================================== +:attr:`ord` norm for matrices norm for vectors +====================== ========================= ======================================================== +`None` (default) Frobenius norm `2`-norm (see below) +`'fro'` Frobenius norm -- not supported -- +`'nuc'` nuclear norm -- not supported -- +`inf` `max(sum(abs(x), dim=1))` `max(abs(x))` +`-inf` `min(sum(abs(x), dim=1))` `min(abs(x))` +`0` -- not supported -- `sum(x != 0)` +`1` `max(sum(abs(x), dim=0))` as below +`-1` `min(sum(abs(x), dim=0))` as below +`2` largest singular value as below +`-2` smallest singular value as below +other `int` or `float` -- not supported -- `sum(abs(x)^{ord})^{(1 / ord)}` +====================== ========================= ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +.. seealso:: + + :func:`torch.linalg.vector_norm` computes a vector norm. + + :func:`torch.linalg.matrix_norm` computes a matrix norm. + + The above functions are often clearer and more flexible than using :func:`torch.linalg.norm`. + For example, `torch.linalg.norm(A, ord=1, dim=(0, 1))` always + computes a matrix norm, but with `torch.linalg.vector_norm(A, ord=1, dim=(0, 1))` it is possible + to compute a vector norm over the two dimensions. + +Args: + A (Tensor): tensor of shape `(*, n)` or `(*, m, n)` where `*` is zero or more batch dimensions + ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None` + dim (int, Tuple[int], optional): dimensions over which to compute + the vector or matrix norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. 
Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> B = a.reshape((3, 3)) + >>> B + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + + >>> LA.norm(a) + tensor(7.7460) + >>> LA.norm(B) + tensor(7.7460) + >>> LA.norm(B, 'fro') + tensor(7.7460) + >>> LA.norm(a, float('inf')) + tensor(4.) + >>> LA.norm(B, float('inf')) + tensor(9.) + >>> LA.norm(a, -float('inf')) + tensor(0.) + >>> LA.norm(B, -float('inf')) + tensor(2.) + + >>> LA.norm(a, 1) + tensor(20.) + >>> LA.norm(B, 1) + tensor(7.) + >>> LA.norm(a, -1) + tensor(0.) + >>> LA.norm(B, -1) + tensor(6.) + >>> LA.norm(a, 2) + tensor(7.7460) + >>> LA.norm(B, 2) + tensor(7.3485) + + >>> LA.norm(a, -2) + tensor(0.) + >>> LA.norm(B.double(), -2) + tensor(1.8570e-16, dtype=torch.float64) + >>> LA.norm(a, 3) + tensor(5.8480) + >>> LA.norm(a, -3) + tensor(0.) + +Using the :attr:`dim` argument to compute vector norms:: + + >>> c = torch.tensor([[1., 2., 3.], + ... 
[-1, 1, 4]]) + >>> LA.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> LA.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> LA.norm(c, ord=1, dim=1) + tensor([6., 6.]) + +Using the :attr:`dim` argument to compute matrix norms:: + + >>> A = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(A, dim=(1,2)) + tensor([ 3.7417, 11.2250]) + >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) + (tensor(3.7417), tensor(11.2250)) +""") + +vector_norm = _add_docstr(_linalg.linalg_vector_norm, r""" +linalg.vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a vector norm. + +If :attr:`x` is complex valued, it computes the norm of :attr:`x`\ `.abs()` + +Supports input of float, double, cfloat and cdouble dtypes. + +This function does not necessarily treat multidimensional :attr:`x` as a batch of +vectors, instead: + +- If :attr:`dim`\ `= None`, :attr:`x` will be flattened before the norm is computed. +- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions + and the other dimensions will be treated as batch dimensions. + +This behavior is for consistency with :func:`torch.linalg.norm`. + +:attr:`ord` defines the vector norm that is computed. The following norms are supported: + +====================== =============================== +:attr:`ord` vector norm +====================== =============================== +`2` (default) `2`-norm (see below) +`inf` `max(abs(x))` +`-inf` `min(abs(x))` +`0` `sum(x != 0)` +other `int` or `float` `sum(abs(x)^{ord})^{(1 / ord)}` +====================== =============================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +:attr:`dtype` may be used to perform the computation in a more precise dtype. +It is semantically equivalent to calling ``linalg.vector_norm(x.to(dtype))`` +but it is faster in some cases. + +.. seealso:: + + :func:`torch.linalg.matrix_norm` computes a matrix norm. 
+
+Args:
+    x (Tensor): tensor, flattened by default, but this behavior can be
+        controlled using :attr:`dim`.
+    ord (int, float, inf, -inf, optional): order of norm. Default: `2`
+    dim (int, Tuple[int], optional): dimensions over which to compute
+        the norm. See above for the behavior when :attr:`dim`\ `= None`.
+        Default: `None`
+    keepdim (bool, optional): If set to `True`, the reduced dimensions are retained
+        in the result as dimensions with size one. Default: `False`
+
+Keyword args:
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+    dtype (:class:`torch.dtype`, optional): type used to perform the accumulation and the return.
+        If specified, :attr:`x` is cast to :attr:`dtype` before performing the operation,
+        and the returned tensor's type will be :attr:`dtype` if real and of its real counterpart if complex.
+        :attr:`dtype` may be complex if :attr:`x` is complex, otherwise it must be real.
+        :attr:`x` should be convertible without narrowing to :attr:`dtype`. Default: None
+
+Returns:
+    A real-valued tensor, even when :attr:`x` is complex.
+
+Examples::
+
+    >>> from torch import linalg as LA
+    >>> a = torch.arange(9, dtype=torch.float) - 4
+    >>> a
+    tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
+    >>> B = a.reshape((3, 3))
+    >>> B
+    tensor([[-4., -3., -2.],
+            [-1.,  0.,  1.],
+            [ 2.,  3.,  4.]])
+    >>> LA.vector_norm(a, ord=3.5)
+    tensor(5.4345)
+    >>> LA.vector_norm(B, ord=3.5)
+    tensor(5.4345)
+""")
+
+matrix_norm = _add_docstr(_linalg.linalg_matrix_norm, r"""
+linalg.matrix_norm(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor
+
+Computes a matrix norm.
+
+If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()`
+
+Supports input of float, double, cfloat and cdouble dtypes.
+Also supports batches of matrices: the norm will be computed over the
+dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will
+be treated as batch dimensions.
The output will have the same batch dimensions. + +:attr:`ord` defines the matrix norm that is computed. The following norms are supported: + +====================== ======================================================== +:attr:`ord` matrix norm +====================== ======================================================== +`'fro'` (default) Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +====================== ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +Args: + A (Tensor): tensor with two or more dimensions. By default its + shape is interpreted as `(*, m, n)` where `*` is zero or more + batch dimensions, but this behavior can be controlled using :attr:`dim`. + ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` + dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> A = torch.arange(9, dtype=torch.float).reshape(3, 3) + >>> A + tensor([[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]) + >>> LA.matrix_norm(A) + tensor(14.2829) + >>> LA.matrix_norm(A, ord=-1) + tensor(9.) 
+ >>> B = A.expand(2, -1, -1) + >>> B + tensor([[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], + + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]) + >>> LA.matrix_norm(B) + tensor([14.2829, 14.2829]) + >>> LA.matrix_norm(B, dim=(0, 2)) + tensor([ 3.1623, 10.0000, 17.2627]) +""") + +matmul = _add_docstr(_linalg.linalg_matmul, r""" +linalg.matmul(input, other, *, out=None) -> Tensor + +Alias for :func:`torch.matmul` +""") + +diagonal = _add_docstr(_linalg.linalg_diagonal, r""" +linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor + +Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`. +""") + +multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" +linalg.multi_dot(tensors, *, out=None) + +Efficiently multiplies two or more matrices by reordering the multiplications so that +the fewest arithmetic operations are performed. + +Supports inputs of float, double, cfloat and cdouble dtypes. +This function does not support batched inputs. + +Every tensor in :attr:`tensors` must be 2D, except for the first and last which +may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector +of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated +as a column vector of shape `(n, 1)`. + +If the first and last tensors are matrices, the output will be a matrix. +However, if either is a 1D vector, then the output will be a 1D vector. + +Differences with `numpy.linalg.multi_dot`: + +- Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D + whereas NumPy allows them to be nD + +.. warning:: This function does not broadcast. + +.. note:: This function is implemented by chaining :func:`torch.mm` calls after + computing the optimal matrix multiplication order. + +.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is + `a * b * c`. 
Given matrices `A`, `B`, `C` with shapes `(10, 100)`, + `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different + multiplication orders as follows: + + .. math:: + + \begin{align*} + \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ + \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 + \end{align*} + + In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. + +Args: + tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last + tensors may be 1D or 2D. Every other tensor must be 2D. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> from torch.linalg import multi_dot + + >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) + tensor(8) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) + tensor([8]) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) + tensor([[8]]) + + >>> A = torch.arange(2 * 3).view(2, 3) + >>> B = torch.arange(3 * 2).view(3, 2) + >>> C = torch.arange(2 * 2).view(2, 2) + >>> multi_dot((A, B, C)) + tensor([[ 26, 49], + [ 80, 148]]) +""") + +svd = _add_docstr(_linalg.linalg_svd, r""" +linalg.svd(A, full_matrices=True, *, driver=None, out=None) -> (Tensor, Tensor, Tensor) + +Computes the singular value decomposition (SVD) of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full SVD** of a matrix +:math:`A \in \mathbb{K}^{m \times n}`, if `k = min(m,n)`, is defined as + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times m}, S \in \mathbb{R}^k, V \in \mathbb{K}^{n \times n}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{m \times n}`, +:math:`V^{\text{H}}` is the conjugate transpose when :math:`V` is complex, and the transpose when :math:`V` is real-valued. 
+The matrices :math:`U`, :math:`V` (and thus :math:`V^{\text{H}}`) are orthogonal in the real case, and unitary in the complex case. + +When `m > n` (resp. `m < n`) we can drop the last `m - n` (resp. `n - m`) columns of `U` (resp. `V`) to form the **reduced SVD**: + +.. math:: + + A = U \operatorname{diag}(S) V^{\text{H}} + \mathrlap{\qquad U \in \mathbb{K}^{m \times k}, S \in \mathbb{R}^k, V \in \mathbb{K}^{k \times n}} + +where :math:`\operatorname{diag}(S) \in \mathbb{K}^{k \times k}`. +In this case, :math:`U` and :math:`V` also have orthonormal columns. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The returned decomposition is a named tuple `(U, S, Vh)` +which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + +The singular values are returned in descending order. + +The parameter :attr:`full_matrices` chooses between the full (default) and reduced SVD. + +The :attr:`driver` kwarg may be used in CUDA with a cuSOLVER backend to choose the algorithm used to compute the SVD. +The choice of a driver is a trade-off between accuracy and speed. + +- If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. + + - For a general matrix: `'gesvdj'` (Jacobi method) + - If :attr:`A` is tall or wide (`m >> n` or `m << n`): `'gesvda'` (Approximate method) + +- If :attr:`A` is not well-conditioned or precision is relevant: `'gesvd'` (QR based) + +By default (:attr:`driver`\ `= None`), we call `'gesvdj'` and, if it fails, we fallback to `'gesvd'`. + +Differences with `numpy.linalg.svd`: + +- Unlike `numpy.linalg.svd`, this function always returns a tuple of three tensors + and it doesn't support `compute_uv` argument. + Please use :func:`torch.linalg.svdvals`, which computes only the singular values, + instead of `compute_uv=False`. + +.. 
note:: When :attr:`full_matrices`\ `= True`, the gradients with respect to `U[..., :, min(m, n):]` + and `Vh[..., min(m, n):, :]` will be ignored, as those vectors can be arbitrary bases + of the corresponding subspaces. + +.. warning:: The returned tensors `U` and `V` are not unique, nor are they continuous with + respect to :attr:`A`. + Due to this lack of uniqueness, different hardware and software may compute + different singular vectors. + + This non-uniqueness is caused by the fact that multiplying any pair of singular + vectors :math:`u_k, v_k` by `-1` in the real case or by + :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex case produces another two + valid singular vectors of the matrix. + For this reason, the loss function shall not depend on this :math:`e^{i \phi}` quantity, + as it is not well-defined. + This is checked for complex inputs when computing the gradients of this function. As such, + when inputs are complex and are on a CUDA device, the computation of the gradients + of this function synchronizes that device with the CPU. + +.. warning:: Gradients computed using `U` or `Vh` will only be finite when + :attr:`A` does not have repeated singular values. If :attr:`A` is rectangular, + additionally, zero must also not be one of its singular values. + Furthermore, if the distance between any two singular values is close to zero, + the gradient will be numerically unstable, as it depends on the singular values + :math:`\sigma_i` through the computation of + :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. + In the rectangular case, the gradient will also be numerically unstable when + :attr:`A` has small singular values, as it also depends on the computation of + :math:`\frac{1}{\sigma_i}`. + +.. seealso:: + + :func:`torch.linalg.svdvals` computes only the singular values. + Unlike :func:`torch.linalg.svd`, the gradients of :func:`~svdvals` are always + numerically stable. 
+ + :func:`torch.linalg.eig` for a function that computes another type of spectral + decomposition of a matrix. The eigendecomposition works just on square matrices. + + :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition + for Hermitian and symmetric matrices. + + :func:`torch.linalg.qr` for another (much faster) decomposition that works on general + matrices. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + full_matrices (bool, optional): controls whether to compute the full or reduced + SVD, and consequently, + the shape of the returned tensors + `U` and `Vh`. Default: `True`. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Default: `None`. + out (tuple, optional): output tuple of three tensors. Ignored if `None`. + +Returns: + A named tuple `(U, S, Vh)` which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. + + `S` will always be real-valued, even when :attr:`A` is complex. + It will also be ordered in descending order. + + `U` and `Vh` will have the same dtype as :attr:`A`. The left / right singular vectors will be given by + the columns of `U` and the rows of `Vh` respectively. + +Examples:: + + >>> A = torch.randn(5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> U, S, Vh = torch.linalg.svd(A) + >>> U.shape, S.shape, Vh.shape + (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) + >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) + tensor(1.0486e-06) + + >>> A = torch.randn(7, 5, 3) + >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) + >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) + tensor(3.0957e-06) + +.. 
_condition number: + https://pytorch.org/docs/master/linalg.html#torch.linalg.cond +.. _the resulting vectors will span the same subspace: + https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD +""") + +svdvals = _add_docstr(_linalg.linalg_svdvals, r""" +linalg.svdvals(A, *, driver=None, out=None) -> Tensor + +Computes the singular values of a matrix. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The singular values are returned in descending order. + +.. note:: This function is equivalent to NumPy's `linalg.svd(A, compute_uv=False)`. + +""" + fr""" +.. note:: {common_notes["sync_note"]} +""" + r""" + +.. seealso:: + + :func:`torch.linalg.svd` computes the full singular value decomposition. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. + Check :func:`torch.linalg.svd` for details. + Default: `None`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> A = torch.randn(5, 3) + >>> S = torch.linalg.svdvals(A) + >>> S + tensor([2.5139, 2.1087, 1.1066]) + + >>> torch.dist(S, torch.linalg.svd(A, full_matrices=False).S) + tensor(2.4576e-07) +""") + +cond = _add_docstr(_linalg.linalg_cond, r""" +linalg.cond(A, p=None, *, out=None) -> Tensor + +Computes the condition number of a matrix with respect to a matrix norm. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **condition number** :math:`\kappa` of a matrix +:math:`A \in \mathbb{K}^{n \times n}` is defined as + +.. 
math:: + + \kappa(A) = \|A\|_p\|A^{-1}\|_p + +The condition number of :attr:`A` measures the numerical stability of the linear system `AX = B` +with respect to a matrix norm. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +:attr:`p` defines the matrix norm that is computed. The following norms are supported: + +========= ================================= +:attr:`p` matrix norm +========= ================================= +`None` `2`-norm (largest singular value) +`'fro'` Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +========= ================================= + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +For :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`, this function uses +:func:`torch.linalg.norm` and :func:`torch.linalg.inv`. +As such, in this case, the matrix (or every matrix in the batch) :attr:`A` has to be square +and invertible. + +For :attr:`p` in `(2, -2)`, this function can be computed in terms of the singular values +:math:`\sigma_1 \geq \ldots \geq \sigma_n` + +.. math:: + + \kappa_2(A) = \frac{\sigma_1}{\sigma_n}\qquad \kappa_{-2}(A) = \frac{\sigma_n}{\sigma_1} + +In these cases, it is computed using :func:`torch.linalg.svdvals`. For these norms, the matrix +(or every matrix in the batch) :attr:`A` may have any shape. + +.. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`. + +.. seealso:: + + :func:`torch.linalg.solve` for a function that solves linear systems of square matrices. + + :func:`torch.linalg.lstsq` for a function that solves linear systems of general matrices. 
+ +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions + for :attr:`p` in `(2, -2)`, and of shape `(*, n, n)` where every matrix + is invertible for :attr:`p` in `('fro', 'nuc', inf, -inf, 1, -1)`. + p (int, inf, -inf, 'fro', 'nuc', optional): + the type of the matrix norm to use in the computations (see above). Default: `None` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Raises: + RuntimeError: + if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)` + and the :attr:`A` matrix or any matrix in the batch :attr:`A` is not square + or invertible. + +Examples:: + + >>> A = torch.randn(3, 4, 4, dtype=torch.complex64) + >>> torch.linalg.cond(A) + >>> A = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) + >>> torch.linalg.cond(A) + tensor([1.4142]) + >>> torch.linalg.cond(A, 'fro') + tensor(3.1623) + >>> torch.linalg.cond(A, 'nuc') + tensor(9.2426) + >>> torch.linalg.cond(A, float('inf')) + tensor(2.) + >>> torch.linalg.cond(A, float('-inf')) + tensor(1.) + >>> torch.linalg.cond(A, 1) + tensor(2.) + >>> torch.linalg.cond(A, -1) + tensor(1.) + >>> torch.linalg.cond(A, 2) + tensor([1.4142]) + >>> torch.linalg.cond(A, -2) + tensor([0.7071]) + + >>> A = torch.randn(2, 3, 3) + >>> torch.linalg.cond(A) + tensor([[9.5917], + [3.2538]]) + >>> A = torch.randn(2, 3, 3, dtype=torch.complex64) + >>> torch.linalg.cond(A) + tensor([[4.6245], + [4.5671]]) +""") + +pinv = _add_docstr(_linalg.linalg_pinv, r""" +linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor + +Computes the pseudoinverse (Moore-Penrose inverse) of a matrix. + +The pseudoinverse may be `defined algebraically`_ +but it is more computationally convenient to understand it `through the SVD`_ + +Supports input of float, double, cfloat and cdouble dtypes. 
+Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or +symmetric if real, but this is not checked internally. Instead, just the lower +triangular part of the matrix is used in the computations. + +The singular values (or the norm of the eigenvalues when :attr:`hermitian`\ `= True`) +that are below :math:`\max(\text{atol}, \sigma_1 \cdot \text{rtol})` threshold are +treated as zero and discarded in the computation, +where :math:`\sigma_1` is the largest singular value (or eigenvalue). + +If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, +the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` +and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). +If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then +:attr:`rtol` is set to zero. + +If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that +of the singular values of :attr:`A` as returned by :func:`torch.linalg.svd`. + +.. note:: This function uses :func:`torch.linalg.svd` if :attr:`hermitian`\ `= False` and + :func:`torch.linalg.eigh` if :attr:`hermitian`\ `= True`. + For CUDA inputs, this function synchronizes that device with the CPU. + +.. note:: + Consider using :func:`torch.linalg.lstsq` if possible for multiplying a matrix on the left by + the pseudoinverse, as:: + + torch.linalg.lstsq(A, B).solution == A.pinv() @ B + + It is always preferred to use :func:`~lstsq` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. note:: + This function has NumPy compatible variant `linalg.pinv(A, rcond, hermitian=False)`. + However, use of the positional argument :attr:`rcond` is deprecated in favor of :attr:`rtol`. + +.. 
warning:: + This function uses internally :func:`torch.linalg.svd` (or :func:`torch.linalg.eigh` + when :attr:`hermitian`\ `= True`), so its derivative has the same problems as those of these + functions. See the warnings in :func:`torch.linalg.svd` and :func:`torch.linalg.eigh` for + more details. + +.. seealso:: + + :func:`torch.linalg.inv` computes the inverse of a square matrix. + + :func:`torch.linalg.lstsq` computes :attr:`A`\ `.pinv() @ \ `:attr:`B` with a + numerically stable algorithm. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + rcond (float, Tensor, optional): [NumPy Compat]. Alias for :attr:`rtol`. Default: `None`. + +Keyword args: + atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. + Default: `None`. + rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. + Default: `None`. + hermitian(bool, optional): indicates whether :attr:`A` is Hermitian if complex + or symmetric if real. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 5) + >>> A + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.linalg.pinv(A) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + + >>> A = torch.randn(2, 6, 3) + >>> Apinv = torch.linalg.pinv(A) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(8.5633e-07) + + >>> A = torch.randn(3, 3, dtype=torch.complex64) + >>> A = A + A.T.conj() # creates a Hermitian matrix + >>> Apinv = torch.linalg.pinv(A, hermitian=True) + >>> torch.dist(Apinv @ A, torch.eye(3)) + tensor(1.0830e-06) + +.. 
_defined algebraically: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Existence_and_uniqueness +.. _through the SVD: + https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Singular_value_decomposition_(SVD) +""") + +matrix_exp = _add_docstr(_linalg.linalg_matrix_exp, r""" +linalg.matrix_exp(A) -> Tensor + +Computes the matrix exponential of a square matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the **matrix exponential** of :math:`A \in \mathbb{K}^{n \times n}`, which is defined as + +.. math:: + \mathrm{matrix\_exp}(A) = \sum_{k=0}^\infty \frac{1}{k!}A^k \in \mathbb{K}^{n \times n}. + +If the matrix :math:`A` has eigenvalues :math:`\lambda_i \in \mathbb{C}`, +the matrix :math:`\mathrm{matrix\_exp}(A)` has eigenvalues :math:`e^{\lambda_i} \in \mathbb{C}`. + +Supports input of bfloat16, float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + +Example:: + + >>> A = torch.empty(2, 2, 2) + >>> A[0, :, :] = torch.eye(2, 2) + >>> A[1, :, :] = 2 * torch.eye(2, 2) + >>> A + tensor([[[1., 0.], + [0., 1.]], + + [[2., 0.], + [0., 2.]]]) + >>> torch.linalg.matrix_exp(A) + tensor([[[2.7183, 0.0000], + [0.0000, 2.7183]], + + [[7.3891, 0.0000], + [0.0000, 7.3891]]]) + + >>> import math + >>> A = torch.tensor([[0, math.pi/3], [-math.pi/3, 0]]) # A is skew-symmetric + >>> torch.linalg.matrix_exp(A) # matrix_exp(A) = [[cos(pi/3), sin(pi/3)], [-sin(pi/3), cos(pi/3)]] + tensor([[ 0.5000, 0.8660], + [-0.8660, 0.5000]]) +""") + + +solve = _add_docstr(_linalg.linalg_solve, r""" +linalg.solve(A, B, *, left=True, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution. 
+ +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +This system of linear equations has one solution if and only if :math:`A` is `invertible`_. +This function assumes that :math:`A` is invertible. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +Letting `*` be zero or more batch dimensions, + +- If :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(*, n)` (a batch of vectors) or shape + `(*, n, k)` (a batch of matrices or "multiple right-hand sides"), this function returns `X` of shape + `(*, n)` or `(*, n, k)` respectively. +- Otherwise, if :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(n,)` or `(n, k)`, :attr:`B` + is broadcasted to have shape `(*, n)` or `(*, n, k)` respectively. + This function then returns the solution of the resulting batch of systems of linear equations. + +.. note:: + This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and + more numerically stable way than performing the computations separately. + +.. note:: + It is possible to compute the solution of the system :math:`XA = B` by passing the inputs + :attr:`A` and :attr:`B` transposed and transposing the output returned by this function. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.solve_ex")} +""" + r""" + +.. 
seealso:: + + :func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n)` or `(*, n, k)` or `(n,)` or `(n, k)` + according to the rules described above + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) + >>> torch.allclose(A @ x, b) + True + >>> A = torch.randn(2, 3, 3) + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3) + >>> b = torch.randn(3, 1) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) + >>> x.shape + torch.Size([2, 3, 1]) + >>> torch.allclose(A @ x, b) + True + >>> b = torch.randn(3) + >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) + >>> x.shape + torch.Size([2, 3]) + >>> Ax = A @ x.unsqueeze(-1) + >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +solve_triangular = _add_docstr(_linalg.linalg_solve_triangular, r""" +linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor + +Computes the solution of a triangular system of linear equations with a unique solution. 
+ +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** +associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal +(that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`, +which is defined as + +.. math:: AX = B + +The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular. + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that +solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A` +will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed. + +If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed. + +The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that +are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix +has very small eigenvalues. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` computes the solution of a general square system of linear + equations with a unique solution. + +Args: + A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + upper (bool): whether :attr:`A` is an upper or lower triangular matrix. + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. 
Default: `True`. + unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be + all equal to `1`. Default: `False`. + out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. + Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3).triu_() + >>> B = torch.randn(3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=True) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 3, 3).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False) + >>> torch.allclose(A @ X, B) + True + + >>> A = torch.randn(2, 4, 4).tril_() + >>> B = torch.randn(2, 3, 4) + >>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False) + >>> torch.allclose(X @ A, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu_factor = _add_docstr(_linalg.linalg_lu_factor, r""" +linalg.lu_factor(A, *, bool pivot=True, out=None) -> (Tensor, Tensor) + +Computes a compact representation of the LU factorization with partial pivoting of a matrix. + +This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`. +If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve` +to solve system of linear equations that share the matrix :attr:`A`. + +The returned decomposition is represented as a named tuple `(LU, pivots)`. +The ``LU`` matrix has the same shape as the input matrix ``A``. Its upper and lower triangular +parts encode the non-constant elements of ``L`` and ``U`` of the LU decomposition of ``A``. + +The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents +that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row. + +On CUDA, one may use :attr:`pivot`\ `= False`. 
In this case, this function returns the LU +decomposition without pivoting if it exists. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +""" + fr""" +.. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")} +""" + r""" +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + + Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +.. seealso:: + + :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this + function provided the input matrix was square and invertible. + + :func:`torch.lu_unpack` unpacks the tensors returned by :func:`~lu_factor` into the three + matrices `P, L, U` that form the decomposition. + + :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly + non-square matrix. It is a composition of :func:`~lu_factor` and :func:`torch.lu_unpack`. + + :func:`torch.linalg.solve` solves a system of linear equations. It is a composition + of :func:`~lu_factor` and :func:`~lu_solve`. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. 
+ +Returns: + A named tuple `(LU, pivots)`. + +Raises: + RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` + is not invertible. + +Examples:: + + >>> A = torch.randn(2, 3, 3) + >>> B1 = torch.randn(2, 3, 4) + >>> B2 = torch.randn(2, 3, 7) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> X1 = torch.linalg.lu_solve(LU, pivots, B1) + >>> X2 = torch.linalg.lu_solve(LU, pivots, B2) + >>> torch.allclose(A @ X1, B1) + True + >>> torch.allclose(A @ X2, B2) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu_factor_ex = _add_docstr(_linalg.linalg_lu_factor_ex, r""" +linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) + +This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. +It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. + +""" + fr""" +.. note:: {common_notes["sync_note_ex"]} + +.. warning:: {common_notes["experimental_warning"]} +""" + r""" + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + +Keyword args: + pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU + decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. + check_errors (bool, optional): controls whether to check the content of ``infos`` and raise + an error if it is non-zero. Default: `False`. + out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(LU, pivots, info)`. + +.. 
_LAPACK's getrf: + https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html +""") + +lu_solve = _add_docstr(_linalg.linalg_lu_solve, r""" +linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) -> Tensor + +Computes the solution of a square system of linear equations with a unique solution given an LU decomposition. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to +:math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as + +.. math:: AX = B + +where :math:`A` is given factorized as returned by :func:`~lu_factor`. + +If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +If :attr:`adjoint`\ `= True` (and :attr:`left`\ `= True`), given an LU factorization of :math:`A` +this function function returns the :math:`X \in \mathbb{K}^{n \times k}` that solves the system + +.. math:: + + A^{\text{H}}X = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} + +where :math:`A^{\text{H}}` is the conjugate transpose when :math:`A` is complex, and the +transpose when :math:`A` is real-valued. The :attr:`left`\ `= False` case is analogous. + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if the inputs are batches of matrices then +the output has the same batch dimensions. + +Args: + LU (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. 
+ pivots (Tensor): tensor of shape `(*, n)` (or `(*, k)` if :attr:`left`\ `= True`) + where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. + B (Tensor): right-hand side tensor of shape `(*, n, k)`. + +Keyword args: + left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. + adjoint (bool, optional): whether to solve the system :math:`AX=B` or :math:`A^{\text{H}}X = B`. Default: `False`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Examples:: + + >>> A = torch.randn(3, 3) + >>> LU, pivots = torch.linalg.lu_factor(A) + >>> B = torch.randn(3, 2) + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 3, 2) # Broadcasting rules apply: A is broadcasted + >>> X = torch.linalg.lu_solve(LU, pivots, B) + >>> torch.allclose(A @ X, B) + True + + >>> B = torch.randn(3, 5, 3) + >>> X = torch.linalg.lu_solve(LU, pivots, B, left=False) + >>> torch.allclose(X @ A, B) + True + + >>> B = torch.randn(3, 3, 4) # Now solve for A^T + >>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True) + >>> torch.allclose(A.mT @ X, B) + True + +.. _invertible: + https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem +""") + +lu = _add_docstr(_linalg.linalg_lu, r""" +lu(A, *, pivot=True, out=None) -> (Tensor, Tensor, Tensor) + +Computes the LU decomposition with partial pivoting of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **LU decomposition with partial pivoting** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. math:: + + A = PLU\mathrlap{\qquad P \in \mathbb{K}^{m \times m}, L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +where `k = min(m,n)`, :math:`P` is a `permutation matrix`_, :math:`L` is lower triangular with ones on the diagonal +and :math:`U` is upper triangular. 
+ +If :attr:`pivot`\ `= False` and :attr:`A` is on GPU, then the **LU decomposition without pivoting** is computed + +.. math:: + + A = LU\mathrlap{\qquad L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} + +When :attr:`pivot`\ `= False`, the returned matrix :attr:`P` will be empty. +The LU decomposition without pivoting `may not exist`_ if any of the principal minors of :attr:`A` is singular. +In this case, the output matrix may contain `inf` or `NaN`. + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +.. seealso:: + + :func:`torch.linalg.solve` solves a system of linear equations using the LU decomposition + with partial pivoting. + +.. warning:: The LU decomposition is almost never unique, as often there are different permutation + matrices that can yield different LU decompositions. + As such, different platforms, like SciPy, or inputs on different devices, + may produce different valid decompositions. + +.. warning:: Gradient computations are only supported if the input matrix is full-rank. + If this condition is not met, no error will be thrown, but the gradient + may not be finite. + This is because the LU decomposition with pivoting is not differentiable at these points. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting or + no pivoting. Default: `True`. + +Keyword args: + out (tuple, optional): output tuple of three tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(P, L, U)`. 
+ +Examples:: + + >>> A = torch.randn(3, 2) + >>> P, L, U = torch.linalg.lu(A) + >>> P + tensor([[0., 1., 0.], + [0., 0., 1.], + [1., 0., 0.]]) + >>> L + tensor([[1.0000, 0.0000], + [0.5007, 1.0000], + [0.0633, 0.9755]]) + >>> U + tensor([[0.3771, 0.0489], + [0.0000, 0.9644]]) + >>> torch.dist(A, P @ L @ U) + tensor(5.9605e-08) + + >>> A = torch.randn(2, 5, 7, device="cuda") + >>> P, L, U = torch.linalg.lu(A, pivot=False) + >>> P + tensor([], device='cuda:0') + >>> torch.dist(A, L @ U) + tensor(1.0376e-06, device='cuda:0') + +.. _permutation matrix: + https://en.wikipedia.org/wiki/Permutation_matrix +.. _may not exist: + https://en.wikipedia.org/wiki/LU_decomposition#Definitions +""") + +tensorinv = _add_docstr(_linalg.linalg_tensorinv, r""" +linalg.tensorinv(A, ind=2, *, out=None) -> Tensor + +Computes the multiplicative inverse of :func:`torch.tensordot`. + +If `m` is the product of the first :attr:`ind` dimensions of :attr:`A` and `n` is the product of +the rest of the dimensions, this function expects `m` and `n` to be equal. +If this is the case, it computes a tensor `X` such that +`tensordot(\ `:attr:`A`\ `, X, \ `:attr:`ind`\ `)` is the identity matrix in dimension `m`. +`X` will have the shape of :attr:`A` but with the first :attr:`ind` dimensions pushed back to the end + +.. code:: text + + X.shape == A.shape[ind:] + A.shape[:ind] + +Supports input of float, double, cfloat and cdouble dtypes. + +.. note:: When :attr:`A` is a `2`-dimensional tensor and :attr:`ind`\ `= 1`, + this function computes the (multiplicative) inverse of :attr:`A` + (see :func:`torch.linalg.inv`). + +.. 
note:: + Consider using :func:`torch.linalg.tensorsolve` if possible for multiplying a tensor on the left + by the tensor inverse, as:: + + linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim] + + It is always preferred to use :func:`~tensorsolve` when possible, as it is faster and more + numerically stable than computing the pseudoinverse explicitly. + +.. seealso:: + + :func:`torch.linalg.tensorsolve` computes + `torch.tensordot(tensorinv(\ `:attr:`A`\ `), \ `:attr:`B`\ `)`. + +Args: + A (Tensor): tensor to invert. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`ind`\ `]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`ind`\ `:])`. + ind (int): index at which to compute the inverse of :func:`torch.tensordot`. Default: `2`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A` is not invertible or the product of the first + :attr:`ind` dimensions is not equal to the product of the rest. + +Examples:: + + >>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) + >>> Ainv = torch.linalg.tensorinv(A, ind=2) + >>> Ainv.shape + torch.Size([8, 3, 4, 6]) + >>> B = torch.randn(4, 6) + >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) + True + + >>> A = torch.randn(4, 4) + >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) + >>> Ainv = torch.linalg.inv(A) + >>> torch.allclose(Atensorinv, Ainv) + True +""") + +tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" +linalg.tensorsolve(A, B, dims=None, *, out=None) -> Tensor + +Computes the solution `X` to the system `torch.tensordot(A, X) = B`. + +If `m` is the product of the first :attr:`B`\ `.ndim` dimensions of :attr:`A` and +`n` is the product of the rest of the dimensions, this function expects `m` and `n` to be equal. + +The returned tensor `x` satisfies +`tensordot(\ `:attr:`A`\ `, x, dims=x.ndim) == \ `:attr:`B`. 
+`x` has shape :attr:`A`\ `[B.ndim:]`. + +If :attr:`dims` is specified, :attr:`A` will be reshaped as + +.. code:: text + + A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0)) + +Supports inputs of float, double, cfloat and cdouble dtypes. + +.. seealso:: + + :func:`torch.linalg.tensorinv` computes the multiplicative inverse of + :func:`torch.tensordot`. + +Args: + A (Tensor): tensor to solve for. Its shape must satisfy + `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]) == + prod(\ `:attr:`A`\ `.shape[\ `:attr:`B`\ `.ndim:])`. + B (Tensor): tensor of shape :attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]`. + dims (Tuple[int], optional): dimensions of :attr:`A` to be moved. + If `None`, no dimensions are moved. Default: `None`. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + +Raises: + RuntimeError: if the reshaped :attr:`A`\ `.view(m, m)` with `m` as above is not + invertible or the product of the first :attr:`ind` dimensions is not equal + to the product of the rest of the dimensions. + +Examples:: + + >>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) + >>> B = torch.randn(2 * 3, 4) + >>> X = torch.linalg.tensorsolve(A, B) + >>> X.shape + torch.Size([2, 3, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) + True + + >>> A = torch.randn(6, 4, 4, 3, 2) + >>> B = torch.randn(4, 3, 2) + >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) + >>> X.shape + torch.Size([6, 4]) + >>> A = A.permute(1, 3, 4, 0, 2) + >>> A.shape[B.ndim:] + torch.Size([6, 4]) + >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) + True +""") + +qr = _add_docstr(_linalg.linalg_qr, r""" +qr(A, mode='reduced', *, out=None) -> (Tensor, Tensor) + +Computes the QR decomposition of a matrix. + +Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, +the **full QR decomposition** of a matrix +:math:`A \in \mathbb{K}^{m \times n}` is defined as + +.. 
math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times m}, R \in \mathbb{K}^{m \times n}} + +where :math:`Q` is orthogonal in the real case and unitary in the complex case, +and :math:`R` is upper triangular with real diagonal (even in the complex case). + +When `m > n` (tall matrix), as `R` is upper triangular, its last `m - n` rows are zero. +In this case, we can drop the last `m - n` columns of `Q` to form the +**reduced QR decomposition**: + +.. math:: + + A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times n}, R \in \mathbb{K}^{n \times n}} + +The reduced QR decomposition agrees with the full QR decomposition when `n >= m` (wide matrix). + +Supports input of float, double, cfloat and cdouble dtypes. +Also supports batches of matrices, and if :attr:`A` is a batch of matrices then +the output has the same batch dimensions. + +The parameter :attr:`mode` chooses between the full and reduced QR decomposition. +If :attr:`A` has shape `(*, m, n)`, denoting `k = min(m, n)` + +- :attr:`mode`\ `= 'reduced'` (default): Returns `(Q, R)` of shapes `(*, m, k)`, `(*, k, n)` respectively. + It is always differentiable. +- :attr:`mode`\ `= 'complete'`: Returns `(Q, R)` of shapes `(*, m, m)`, `(*, m, n)` respectively. + It is differentiable for `m <= n`. +- :attr:`mode`\ `= 'r'`: Computes only the reduced `R`. Returns `(Q, R)` with `Q` empty and `R` of shape `(*, k, n)`. + It is never differentiable. + +Differences with `numpy.linalg.qr`: + +- :attr:`mode`\ `= 'raw'` is not implemented. +- Unlike `numpy.linalg.qr`, this function always returns a tuple of two tensors. + When :attr:`mode`\ `= 'r'`, the `Q` tensor is an empty tensor. + +.. warning:: The elements in the diagonal of `R` are not necessarily positive. + As such, the returned QR decomposition is only unique up to the sign of the diagonal of `R`. + Therefore, different platforms, like NumPy, or inputs on different devices, + may produce different valid decompositions. + +.. 
warning:: The QR decomposition is only well-defined if the first `k = min(m, n)` columns + of every matrix in :attr:`A` are linearly independent. + If this condition is not met, no error will be thrown, but the QR produced + may be incorrect and its autodiff may fail or produce incorrect results. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + mode (str, optional): one of `'reduced'`, `'complete'`, `'r'`. + Controls the shape of the returned tensors. Default: `'reduced'`. + +Keyword args: + out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. + +Returns: + A named tuple `(Q, R)`. + +Examples:: + + >>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = torch.linalg.qr(A) + >>> Q + tensor([[-0.8571, 0.3943, 0.3314], + [-0.4286, -0.9029, -0.0343], + [ 0.2857, -0.1714, 0.9429]]) + >>> R + tensor([[ -14.0000, -21.0000, 14.0000], + [ 0.0000, -175.0000, 70.0000], + [ 0.0000, 0.0000, -35.0000]]) + >>> (Q @ R).round() + tensor([[ 12., -51., 4.], + [ 6., 167., -68.], + [ -4., 24., -41.]]) + >>> (Q.T @ Q).round() + tensor([[ 1., 0., 0.], + [ 0., 1., -0.], + [ 0., -0., 1.]]) + >>> Q2, R2 = torch.linalg.qr(A, mode='r') + >>> Q2 + tensor([]) + >>> torch.equal(R, R2) + True + >>> A = torch.randn(3, 4, 5) + >>> Q, R = torch.linalg.qr(A, mode='complete') + >>> torch.dist(Q @ R, A) + tensor(1.6099e-06) + >>> torch.dist(Q.mT @ Q, torch.eye(4)) + tensor(6.2158e-07) +""") + +vander = _add_docstr(_linalg.linalg_vander, r""" +vander(x, N=None) -> Tensor + +Generates a Vandermonde matrix. + +Returns the Vandermonde matrix :math:`V` + +.. math:: + + V = \begin{pmatrix} + 1 & x_1 & x_1^2 & \dots & x_1^{N-1}\\ + 1 & x_2 & x_2^2 & \dots & x_2^{N-1}\\ + 1 & x_3 & x_3^2 & \dots & x_3^{N-1}\\ + \vdots & \vdots & \vdots & \ddots &\vdots \\ + 1 & x_n & x_n^2 & \dots & x_n^{N-1} + \end{pmatrix}. + +for `N > 1`. 
+If :attr:`N`\ `= None`, then `N = x.size(-1)` so that the output is a square matrix. + +Supports inputs of float, double, cfloat, cdouble, and integral dtypes. +Also supports batches of vectors, and if :attr:`x` is a batch of vectors then +the output has the same batch dimensions. + +Differences with `numpy.vander`: + +- Unlike `numpy.vander`, this function returns the powers of :attr:`x` in ascending order. + To get them in the reverse order call ``linalg.vander(x, N).flip(-1)``. + +Args: + x (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions + consisting of vectors. + +Keyword args: + N (int, optional): Number of columns in the output. Default: `x.size(-1)` + +Example:: + + >>> x = torch.tensor([1, 2, 3, 5]) + >>> linalg.vander(x) + tensor([[ 1, 1, 1, 1], + [ 1, 2, 4, 8], + [ 1, 3, 9, 27], + [ 1, 5, 25, 125]]) + >>> linalg.vander(x, N=3) + tensor([[ 1, 1, 1], + [ 1, 2, 4], + [ 1, 3, 9], + [ 1, 5, 25]]) +""") + +vecdot = _add_docstr(_linalg.linalg_vecdot, r""" +linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor + +Computes the dot product of two batches of vectors along a dimension. + +In symbols, this function computes + +.. math:: + + \sum_{i=1}^n \overline{x_i}y_i. + +over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex +vectors, and it is the identity for real vectors. + +Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes. +It also supports broadcasting. + +Args: + x (Tensor): first batch of vectors of shape `(*, n)`. + y (Tensor): second batch of vectors of shape `(*, n)`. + +Keyword args: + dim (int): Dimension along which to compute the dot product. Default: `-1`. + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ +Examples:: + + >>> v1 = torch.randn(3, 2) + >>> v2 = torch.randn(3, 2) + >>> linalg.vecdot(v1, v2) + tensor([ 0.3223, 0.2815, -0.1944]) + >>> torch.vdot(v1[0], v2[0]) + tensor(0.3223) +""") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d317b7c09f20a40ac26acf0c4474c43d4cafda9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/__init__.py @@ -0,0 +1,53 @@ +from .modules import * # noqa: F403 +from .parameter import ( + Parameter as Parameter, + UninitializedParameter as UninitializedParameter, + UninitializedBuffer as UninitializedBuffer, +) +from .parallel import DataParallel as DataParallel +from . import init +from . import functional +from . import utils +from . import attention + + +def factory_kwargs(kwargs): + r"""Return a canonicalized dict of factory kwargs. + + Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed + to factory functions like torch.empty, or errors if unrecognized kwargs are present. + + This function makes it simple to write code like this:: + + class MyModule(nn.Module): + def __init__(self, **kwargs): + factory_kwargs = torch.nn.factory_kwargs(kwargs) + self.weight = Parameter(torch.empty(10, **factory_kwargs)) + + Why should you use this function instead of just passing `kwargs` along directly? + + 1. This function does error validation, so if there are unexpected kwargs we will + immediately report an error, instead of deferring it to the factory call + 2. This function supports a special `factory_kwargs` argument, which can be used to + explicitly specify a kwarg to be used for factory functions, in the event one of the + factory kwargs conflicts with an already existing argument in the signature (e.g. 
+ in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory + functions, as distinct from the dtype argument, by saying + ``f(dtype1, factory_kwargs={"dtype": dtype2})``) + """ + if kwargs is None: + return {} + simple_keys = {"device", "dtype", "memory_format"} + expected_keys = simple_keys | {"factory_kwargs"} + if not kwargs.keys() <= expected_keys: + raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}") + + # guarantee no input kwargs is untouched + r = dict(kwargs.get("factory_kwargs", {})) + for k in simple_keys: + if k in kwargs: + if k in r: + raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs") + r[k] = kwargs[k] + + return r diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/_reduction.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2a8bb0a0e9eda779073176bcc209f326011600 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/_reduction.py @@ -0,0 +1,47 @@ +from typing import Optional +import warnings + +# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h + + +def get_enum(reduction: str) -> int: + if reduction == 'none': + ret = 0 + elif reduction == 'mean': + ret = 1 + elif reduction == 'elementwise_mean': + warnings.warn("reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.") + ret = 1 + elif reduction == 'sum': + ret = 2 + else: + ret = -1 # TODO: remove once JIT exceptions support control flow + raise ValueError(f"{reduction} is not a valid value for reduction") + return ret + +# In order to support previous versions, accept boolean size_average and reduce +# and convert them into the new constants for now + + +# We use these functions in torch/legacy as well, in which case we'll silence the warning +def legacy_get_string(size_average: Optional[bool], reduce: 
Optional[bool], emit_warning: bool = True) -> str: + warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead." + + if size_average is None: + size_average = True + if reduce is None: + reduce = True + + if size_average and reduce: + ret = 'mean' + elif reduce: + ret = 'sum' + else: + ret = 'none' + if emit_warning: + warnings.warn(warning.format(ret)) + return ret + + +def legacy_get_enum(size_average: Optional[bool], reduce: Optional[bool], emit_warning: bool = True) -> int: + return get_enum(legacy_get_string(size_average, reduce, emit_warning)) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.pyi b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.pyi new file mode 100644 index 0000000000000000000000000000000000000000..766ee58d70ead2b0bc143656f83feba293460c36 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.pyi @@ -0,0 +1,682 @@ +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + overload, + Sequence, + Tuple, + Union, +) + +from torch import Tensor +from torch.types import _dtype, _int, _size + +from .common_types import ( + _ratio_any_t, + _size_1_t, + _size_2_opt_t, + _size_2_t, + _size_3_opt_t, + _size_3_t, + _size_any_t, +) + +# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys. +# It is standards-track but not in `typing` yet. We leave this hear to be uncommented once the feature +# is wide-spread. 
+ +# from mypy_extensions import TypedDict + +# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int}) +# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int}) + +GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int] +GRID_SAMPLE_PADDING_MODES = Dict[str, int] + +# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning. +# +# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or _output_ratio, as appropriate. +# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence` +# type. There is no way to express the expected lengths of these lists in the current Python typing system. +# +# Functions created via `_add_docstr` in `functional.py` where merely typed as `Any` by `stubgen`, so those were +# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code +# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system +# to encode the type semantics of `_add_docstr`, should that system ever become widespread. +def fractional_max_pool2d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: Optional[_size] = ..., + output_ratio: Optional[_ratio_any_t] = ..., + return_indices: bool = ..., + _random_samples: Optional[Tensor] = ..., +) -> Tuple[Tensor, Tensor]: ... +def fractional_max_pool3d_with_indices( + input: Tensor, + kernel_size: _size, + output_size: Optional[_size] = ..., + output_ratio: Optional[_ratio_any_t] = ..., + return_indices: bool = ..., + _random_samples: Optional[Tensor] = ..., +) -> Tuple[Tensor, Tensor]: ... 
# Max pooling / unpooling / Lp pooling stubs. The `*_with_indices` variants
# are always declared as returning (output, indices); the overloads of the
# public `max_pool{1,2,3}d` names appear further below.
def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
def lp_pool1d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_1_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool2d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_2_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool3d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_3_t,
    stride: Union[Optional[_size], Optional[int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: _size,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Adaptive pooling, dropout-family, and simple activation stubs.
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: _size_2_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: _size_3_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ...
def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ...
def dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout1d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout2d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def dropout3d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def feature_alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
def threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = ...,
) -> Tensor: ...
def relu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def glu(input: Tensor, dim: int = ...) -> Tensor: ...
def hardtanh(
    input: Tensor,
    min_val: float = ...,
    max_val: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ...
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def selu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def leaky_relu(
    input: Tensor,
    negative_slope: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Activations with extra hyper-parameters, softmax family, embedding lookups,
# and batch normalization.
def rrelu(
    input: Tensor,
    lower: float = ...,
    upper: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# NOTE(review): tanhshrink/softsign/tanh had no return annotation (stubgen
# artifact); annotated as Tensor to match sibling activations such as
# sigmoid below — confirm against functional.py.
def tanhshrink(input: Any) -> Tensor: ...
def softsign(input: Any) -> Tensor: ...
def softmin(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def gumbel_softmax(
    logits: Tensor,
    tau: float = ...,
    hard: bool = ...,
    eps: float = ...,
    dim: int = ...,
) -> Tensor: ...
def log_softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def tanh(input: Any) -> Tensor: ...
def sigmoid(input: Any) -> Tensor: ...
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
def mish(input: Tensor, inplace: bool = False) -> Tensor: ...
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    sparse: bool = ...,
) -> Tensor: ...
def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    mode: str = ...,
    sparse: bool = ...,
    per_sample_weights: Optional[Tensor] = ...,
    include_last_offset: bool = ...,
    padding_idx: Optional[int] = ...,
) -> Tensor: ...
def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    training: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
# Normalization layers and the first group of loss-function stubs.
# Loss stubs carry the legacy `size_average`/`reduce` parameters alongside
# the newer `reduction` string, matching the runtime signatures.
def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = ...,
    running_var: Optional[Tensor] = ...,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    use_input_stats: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
def layer_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = ...,
    beta: float = ...,
    k: float = ...,
) -> Tensor: ...
def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = ...,
    reduction: str = ...,
    zero_infinity: bool = ...,
) -> Tensor: ...
def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = ...,
    full: bool = ...,
    size_average: Optional[bool] = ...,
    eps: float = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: Optional[bool] = ...,
    eps: Optional[float] = ...,
    reduction: Optional[str] = ...,
) -> Tensor: ...
def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    log_target: bool = ...,
) -> Tensor: ...
# Classification and regression loss stubs (cross-entropy family, L1/MSE,
# margin-based losses).
def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    label_smoothing: float = ...,
) -> Tensor: ...
def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    pos_weight: Optional[Tensor] = ...,
) -> Tensor: ...
def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    beta: float = ...,
) -> Tensor: ...
def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = ...,
    delta: float = ...,
) -> Tensor: ...
def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Remaining margin losses plus spatial resampling stubs. The deprecated
# `upsample*` helpers and `interpolate` are declared with Any-typed
# parameters, mirroring the untyped runtime implementations.
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = ...,
    margin: float = ...,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def upsample(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
): ...
def interpolate(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
    recompute_scale_factor: Optional[Any] = ...,
    antialias: bool = ...,
): ...
def upsample_nearest(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
): ...
def upsample_bilinear(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
): ...
def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = ...,
    padding_mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Triplet losses, vector normalization, unfold/fold, and private helpers
# used by the attention implementation below.
def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = ...,
    p: float = ...,
    eps: float = ...,
    swap: bool = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
    margin: float = ...,
    swap: bool = ...,
    reduction: str = ...,
) -> Tensor: ...
def normalize(
    input: Tensor,
    p: float = ...,
    dim: int = ...,
    eps: float = ...,
    out: Optional[Tensor] = ...,
) -> Tensor: ...
def assert_int_or_pair(
    arg: Any,
    arg_name: Any,
    message: Any,
) -> None: ...
def unfold(
    input: Tensor,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
def fold(
    input: Tensor,
    output_size: _size_any_t,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
# Private attention-mask helpers (underscore-prefixed: not public API).
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[_dtype],
    other_name: str,
    target_type: _dtype,
    check_other: bool = True,
) -> Optional[Tensor]: ...
def _none_or_dtype(input: Optional[Tensor]) -> Optional[_dtype]: ...
# Multi-head attention stub, followed by re-exports of the C-extension-backed
# functionals so that type checkers see them under torch.nn.functional.
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ...

# `x as x` re-export form marks these names as public for type checkers.
from .. import conv1d as conv1d
from .. import conv2d as conv2d
from .. import conv3d as conv3d
from .. import conv_transpose1d as conv_transpose1d
from .. import conv_transpose2d as conv_transpose2d
from .. import conv_transpose3d as conv_transpose3d
from .. import conv_tbc as conv_tbc
from .. import avg_pool1d as avg_pool1d
from .. import adaptive_avg_pool1d as adaptive_avg_pool1d
from .. import relu_ as relu_
from .. import selu_ as selu_
from .. import celu_ as celu_
from .. import prelu as prelu
from .. import rrelu_ as rrelu_
from .. import hardshrink as hardshrink
from .. import bilinear as bilinear
from .. import pixel_shuffle as pixel_shuffle
from .. import pixel_unshuffle as pixel_unshuffle
from .. import channel_shuffle as channel_shuffle
from .. import native_channel_shuffle as native_channel_shuffle
from .. import pairwise_distance as pairwise_distance
from .. import pdist as pdist
from .. import cosine_similarity as cosine_similarity
from .._C._nn import avg_pool2d as avg_pool2d
from .._C._nn import avg_pool3d as avg_pool3d
from .._C._nn import hardtanh_ as hardtanh_
from .._C._nn import elu_ as elu_
from .._C._nn import leaky_relu_ as leaky_relu_
from .._C._nn import gelu as gelu
from .._C._nn import softplus as softplus
from .._C._nn import softshrink as softshrink
from .._C._nn import linear as linear
from .._C._nn import pad as pad
from .._C._nn import one_hot as one_hot
from .._C._nn import scaled_dot_product_attention as scaled_dot_product_attention
from .._C._nn import log_sigmoid
# Public alias for the C-extension name.
logsigmoid = log_sigmoid

# Overload sets: `return_indices` (as a Literal) selects between a plain
# Tensor result and a (Tensor, Tensor) pair; the `/` and `*` variants cover
# positional and keyword call sites respectively.
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
# Fractional max pooling overloads: a Literal `return_indices` selects the
# return type, with separate positional (`/`) and keyword-only (`*`) variants.
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ...
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
# max_pool{1,2,3}d overloads: `return_indices: Literal[True]` switches the
# return type to (output, indices); positional (`/`) and keyword-only (`*`)
# call patterns are declared separately.
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
"""This file contains utilities for initializing neural network parameters."""
import math
import warnings

from torch import Tensor
import torch
from typing import Optional as _Optional

# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
    # In-place uniform fill on [a, b] without recording autograd history.
    with torch.no_grad():
        return tensor.uniform_(a, b, generator=generator)


def _no_grad_normal_(tensor, mean, std, generator=None):
    # In-place normal fill N(mean, std^2) without recording autograd history.
    with torch.no_grad():
        return tensor.normal_(mean, std, generator=generator)


def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def _no_grad_fill_(tensor, val):
    # In-place constant fill without recording autograd history.
    with torch.no_grad():
        return tensor.fill_(val)


def _no_grad_zero_(tensor):
    # In-place zeroing without recording autograd history.
    with torch.no_grad():
        return tensor.zero_()


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Raises:
        ValueError: if ``nonlinearity`` is unknown, or ``param`` is not a
            real number for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, (int, float)) and not isinstance(param, bool):
            # bool is a subclass of int, so it must be rejected explicitly.
            # (Equivalent to the old precedence-dependent check, but explicit.)
            negative_slope = param
        else:
            raise ValueError(f"negative_slope {param} not a valid number")
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")


def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
        )
    return _no_grad_uniform_(tensor, a, b, generator)


def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
        )
    return _no_grad_normal_(tensor, mean, std, generator)


def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)


def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
    return _no_grad_fill_(tensor, val)


def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.)


def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)


def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor


def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Raises:
        ValueError: if ``tensor`` is not 3/4/5-dimensional, or dim 0 is not
            divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                # Place a 1 at the spatial center of each preserved channel.
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2] = 1
                else:  # Volumetric convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
                           tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor


def _calculate_fan_in_and_fan_out(tensor):
    # fan_in  = in_channels  * receptive-field size (product of dims 2..n)
    # fan_out = out_channels * receptive-field size
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def xavier_uniform_(
    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a, generator)


def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    return _no_grad_normal_(tensor, 0., std, generator)


def _calculate_correct_fan(tensor, mode):
    # Select fan_in or fan_out depending on `mode` (case-insensitive).
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)


def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    # Delegate to __torch_function__ overrides for consistency with
    # uniform_/normal_/constant_/kaiming_uniform_ (no-op for plain Tensors).
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_normal_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std, generator=generator)


def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor


def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            # Pass the generator here too, so a fixed seed reproduces both the
            # values AND the zeroed positions (previously only the values).
            row_indices = torch.randperm(rows, generator=generator)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor


# for backward compatibility
def _make_deprecate(meth):
    # Wrap an in-place init function `foo_` as a deprecated alias `foo` that
    # warns once per call and forwards all arguments unchanged.
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn(f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}.", stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = fr"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details."""
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
100644 index 0000000000000000000000000000000000000000..9cf0fbac86a81fd68a50ba85a89ecd22445ff961 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a732a41de3367b19e384d0387a1c3706f0ade161 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/modules/__pycache__/fused.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd556d13b1d35c90afae495df1a91ca76512b386 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..2f70dc038b5c4f8a8aa8b5c900314b9f409b5d89 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -0,0 +1,37 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. 
+ +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +__all__ = [ + # Modules + 'ConvBn1d', + 'ConvBnReLU1d', + 'ConvReLU1d', + 'ConvBn2d', + 'ConvBnReLU2d', + 'ConvReLU2d', + 'ConvBn3d', + 'ConvBnReLU3d', + 'ConvReLU3d', + # Utilities + 'freeze_bn_stats', + 'update_bn_stats', +] + +from torch.ao.nn.intrinsic.qat import ConvBn1d +from torch.ao.nn.intrinsic.qat import ConvBnReLU1d +from torch.ao.nn.intrinsic.qat import ConvReLU1d +from torch.ao.nn.intrinsic.qat import ConvBn2d +from torch.ao.nn.intrinsic.qat import ConvBnReLU2d +from torch.ao.nn.intrinsic.qat import ConvReLU2d +from torch.ao.nn.intrinsic.qat import ConvBn3d +from torch.ao.nn.intrinsic.qat import ConvBnReLU3d +from torch.ao.nn.intrinsic.qat import ConvReLU3d +from torch.ao.nn.intrinsic.qat import freeze_bn_stats +from torch.ao.nn.intrinsic.qat import update_bn_stats diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..6e372a5c1d3f6d73bc128aeede9537e9e5de41b7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_fused.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. 
+""" + +__all__ = [ + 'LinearBn1d', +] + +from torch.ao.nn.intrinsic.qat import LinearBn1d diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bc5f7d60dd9330fd111260b2b2552492e596dc4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfc54835aea926c6a64d83efe795725eebda4fda Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..43d4f1cf40008b0dd77e6fe6055753b57f75c041 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.py @@ -0,0 +1,223 @@ +import torch +from torch._C import _disabled_torch_function_impl +from collections import OrderedDict + +# Metaclass to combine _TensorMeta and the instance check override for Parameter. +class _ParameterMeta(torch._C._TensorMeta): + # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. 
+ def __instancecheck__(self, instance): + return super().__instancecheck__(instance) or ( + isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False)) + + +class Parameter(torch.Tensor, metaclass=_ParameterMeta): + r"""A kind of Tensor that is to be considered a module parameter. + + Parameters are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator. + Assigning a Tensor doesn't have such effect. This is because one might + want to cache some temporary state, like last hidden state of the RNN, in + the model. If there was no such class as :class:`Parameter`, these + temporaries would get registered too. + + Args: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. Note that + the torch.no_grad() context does NOT affect the default behavior of + Parameter creation--the Parameter will still have `requires_grad=True` in + :class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more + details. Default: `True` + """ + + def __new__(cls, data=None, requires_grad=True): + if data is None: + data = torch.empty(0) + if type(data) is torch.Tensor or type(data) is Parameter: + # For ease of BC maintenance, keep this path for standard Tensor. + # Eventually (tm), we should change the behavior for standard Tensor to match. + return torch.Tensor._make_subclass(cls, data, requires_grad) + + # Path for custom tensors: set a flag on the instance to indicate parameter-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data): + raise RuntimeError(f"Creating a Parameter from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. 
To use the type as a " + "Parameter, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation.") + t._is_param = True + return t + + # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types + # are still considered that custom tensor type and these methods will not be called for them. + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad) + memo[id(self)] = result + return result + + def __repr__(self): + return 'Parameter containing:\n' + super().__repr__() + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + # See Note [Don't serialize hooks] + hooks = OrderedDict() + if not state: + return ( + torch._utils._rebuild_parameter, + (self.data, self.requires_grad, hooks) + ) + + return ( + torch._utils._rebuild_parameter_with_state, + (self.data, self.requires_grad, hooks, state) + ) + + __torch_function__ = _disabled_torch_function_impl + + +class UninitializedTensorMixin: + _allowed_methods = [ + torch.Tensor.__hash__, + torch.Tensor.size, + torch.Tensor.copy_, + torch.Tensor.is_complex, + torch.Tensor.is_floating_point, + torch.Tensor.half, + torch.Tensor.float, + torch.Tensor.double, + torch.Tensor.char, + torch.Tensor.short, + torch.Tensor.int, + torch.Tensor.long, + torch.Tensor.cuda, + torch.Tensor.cpu, + torch.Tensor.to, + torch.Tensor.get_device, + torch._has_compatible_shallow_copy_type, + ] + + def materialize(self, shape, device=None, dtype=None): + r"""Create a Parameter or Tensor with the same properties of the uninitialized one. + + Given a shape, it materializes a parameter in the same device + and with the same `dtype` as the current one or the specified ones in the + arguments. + + Args: + shape : (tuple): the shape for the materialized tensor. 
+ device (:class:`torch.device`): the desired device of the parameters + and buffers in this module. Optional. + dtype (:class:`torch.dtype`): the desired floating point type of + the floating point parameters and buffers in this module. Optional. + """ + if device is None: + device = self.data.device + if dtype is None: + dtype = self.data.dtype + self.data = torch.empty(shape, device=device, dtype=dtype) + self.__class__ = self.cls_to_become + + @property + def shape(self): + raise RuntimeError( + 'Can\'t access the shape of an uninitialized parameter or buffer. ' + 'This error usually happens in `load_state_dict` when trying to load ' + 'an uninitialized parameter into an initialized one. ' + 'Call `forward` to initialize the parameters before accessing their attributes.') + + def share_memory_(self): + raise RuntimeError( + 'Can\'t share memory on an uninitialized parameter or buffer. ' + 'Call `forward` to initialize the parameters before calling ' + '`module.share_memory()`.') + + def __repr__(self): + return f'<{self.__class__.__name__}>' + + def __reduce_ex__(self, proto): + # See Note [Don't serialize hooks] + return ( + self.__class__, + (self.requires_grad,) + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # method-wrapper is to detect access to Tensor properties that are + # wrapped in descriptors + if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper': + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + raise ValueError( + f'Attempted to use an uninitialized parameter in {func}. ' + 'This error happens when you are using a `LazyModule` or ' + f'explicitly manipulating `torch.nn.parameter.{cls.__name__}` ' + 'objects. 
When using LazyModules Call `forward` with a dummy batch ' + 'to initialize the parameters before calling torch functions') + + +def is_lazy(param): + return isinstance(param, UninitializedTensorMixin) + + +class UninitializedParameter(UninitializedTensorMixin, Parameter): + r"""A parameter that is not initialized. + + Uninitialized Parameters are a a special case of :class:`torch.nn.Parameter` + where the shape of the data is still unknown. + + Unlike a :class:`torch.nn.Parameter`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.nn.Parameter`. + + The default device or dtype to use when the parameter is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = Parameter + + def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + data = torch.empty(0, **factory_kwargs) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.requires_grad, self.data.device, self.data.dtype) + memo[id(self)] = result + return result + +class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): + r"""A buffer that is not initialized. + + Uninitialized Buffer is a a special case of :class:`torch.Tensor` + where the shape of the data is still unknown. + + Unlike a :class:`torch.Tensor`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. 
The only operations that can be performed on a uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.Tensor`. + + The default device or dtype to use when the buffer is materialized can be set + during construction using e.g. ``device='cuda'``. + """ + + cls_to_become = torch.Tensor + + def __new__(cls, requires_grad=False, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + data = torch.empty(0, **factory_kwargs) + return torch.Tensor._make_subclass(cls, data, requires_grad)