diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..116a4612a45ef2267109bfc6dbec6e5d4c642c7f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py @@ -0,0 +1,274 @@ +import torch +import torch.utils._pytree as pytree +from collections import namedtuple +import functools + + +# NOTE [CustomOp autograd kernel indirection] +# We register `inner` as the autograd kernel for this custom_op. +# `inner` either calls the autograd formula registered by the user, +# or goes into an `autograd_not_implemented` kernel. +# +# The reason why this indirection exists is +# so that we can swap out the autograd kernel (the PyTorch dispatcher +# doesn't actually allow us to do this). By default, we want +# the `autograd_not_implemented` behavior, but then the user may come +# and register something that is actually a backward formula +def autograd_kernel_indirection(custom_op): + autograd_fallback = autograd_not_implemented(custom_op) + + def inner(*args, **kwargs): + if custom_op._has_impl('autograd'): + kernel = custom_op._get_impl('autograd').func + return kernel(*args, **kwargs) + # As explained in NOTE ["backward", "save_for_backward", and "autograd"], + # after the user gives us "backward" and "save_for_backward", we generate + # the "autograd" impl. If the user only provided one, then we tell + # the user they've done something wrong. + if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'): + missing = ( + 'save_for_backward' if custom_op._has_impl('backward') + else 'backward' + ) + found = 'save_for_backward' if missing == 'backward' else 'backward' + loc = custom_op._get_impl(found).location + raise RuntimeError( + f"We found a '{found}' registration for {custom_op} at " + f"{loc} but were unable to find a '{missing}' registration. " + f"To use the CustomOp API to register a backward formula, " + f"please provide us both a backward function and a " + f"'save for backward' function via `impl_backward` and " + f"`impl_save_for_backward` respectively.") + return autograd_fallback(*args, **kwargs) + return inner + + +# TODO(#101191): Use the actual C++ autograd not implemented fallback, +# or change the default autograd fallback to the autograd not implemented fallback. +def autograd_not_implemented(custom_op): + def kernel(*args, **kwargs): + if torch.is_grad_enabled() and pytree.tree_any( + lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) + ): + raise RuntimeError("Autograd has not been implemented for operator") + with torch._C._AutoDispatchBelowAutograd(): + return custom_op(*args, **kwargs) + return kernel + + +def mark_non_differentiable(ctx, output, output_differentiability): + # Output types are restricted to be: + # - Tensor + # - Tensor[] + # - int, bool, Scalar, float + # See _check_can_register_backward + if output_differentiability is not None: + if not isinstance(output, tuple): + tuple_output = (output,) + else: + tuple_output = output # type: ignore[assignment] + assert len(output_differentiability) == len(tuple_output) + non_differentiable_tensors = [] + for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)): + if isinstance(out, torch.Tensor): + if not differentiable: + non_differentiable_tensors.append(out) + continue + if isinstance(out, list): + if not differentiable: + non_differentiable_tensors.extend(out) + continue + if differentiable: + raise RuntimeError( + f"With output_differentiability={output_differentiability}. " + f"At idx {idx}, we received an object of type {type(out)} that " + f"is not a Tensor, so it cannot have be marked as differentiable in " + f"output_differentiability.") + if non_differentiable_tensors: + ctx.mark_non_differentiable(*non_differentiable_tensors) + + +def construct_autograd_kernel( + schema, + output_differentiability, + custom_op, + op_overload, + save_for_backward_fn, + backward_fn): + + def apply(*args): + flat_args, spec = pytree.tree_flatten(args) + out_spec = None + + def forward(ctx, *flat_args): + ctx.set_materialize_grads(True) + args = pytree.tree_unflatten(list(flat_args), spec) + with torch._C._AutoDispatchBelowAutograd(): + output = op_overload(*args) + + # We use the info about args to give better error messages in backward + args_info = namedtuple_args( + schema, pytree.tree_map(type, args)) + + save_for_backward_fn_inputs = namedtuple_args(schema, args) + to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) + + save_pytree_for_backward(ctx, (to_save, args_info)) + mark_non_differentiable(ctx, output, output_differentiability) + + nonlocal out_spec + flat_output, out_spec = pytree.tree_flatten(output) + return tuple(flat_output) + + def backward(ctx, *flat_grad_output): + assert out_spec is not None + grads = pytree.tree_unflatten(list(flat_grad_output), out_spec) + saved, args_info = unpack_saved(ctx) + # There is nothing on the ctx object for now, it is just there so + # that we can add additional things in the future. + inner_ctx = object() + if not isinstance(grads, tuple): + grads = (grads,) + grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) + + # Massage the grad_inputs_dict to a form acceptable by + # autograd.Function. + validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info) + return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) + + generated_cls = gen_autograd_function( + custom_op._opname + '_customop', forward, backward) + + flat_output = generated_cls.apply(*flat_args) + assert out_spec is not None + return pytree.tree_unflatten(list(flat_output), out_spec) + return apply + + +def gen_autograd_function(name, forward, backward): + generated_cls = type( + name, + (torch.autograd.Function,), + { + 'forward': staticmethod(forward), + 'backward': staticmethod(backward), + } + ) + return generated_cls + + +@functools.lru_cache +def namedtuple_args_cls(schema): + attribs = [arg.name for arg in schema.arguments.flat_all] + name = str(schema.name) + "_args" + # mypy doesn't support dynamic namedtuple name + tuple_cls = namedtuple(name, attribs) # type: ignore[misc] + return tuple_cls + + +def namedtuple_args(schema, args): + assert isinstance(args, tuple) + tuple_cls = namedtuple_args_cls(schema) + return tuple_cls(*args) + + +def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): + def error(what): + backward = forward_op._get_impl('backward') + raise RuntimeError( + f"In the backward function defined for {forward_op} at " + f"{backward.location} using the CustomOp API, {what}") + + if not isinstance(grad_inputs_dict, dict): + error(f"expected the output of the backward function to be a dict but " + f"got {type(grad_inputs_dict)}") + + expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all + if arg.type.is_tensor_like()} + actual_keys = grad_inputs_dict.keys() + if expected_keys != actual_keys: + error(f"expected the returned grad_input dict to have keys " + f"{expected_keys} but got {actual_keys}. The backward " + f"function must return a gradient (can be None) for each arg " + f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " + f"Args declared to be non-Tensor-like types should not appear " + f"in the grad_input dict") + + for name, grad in grad_inputs_dict.items(): + arg_info = getattr(args_info, name) + + if isinstance(arg_info, list): + if not isinstance(grad, (tuple, list)): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of gradients but got object of type " + f"{type(grad)}.") + if not len(grad) == len(arg_info): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of {len(arg_info)} gradients but got " + f"{len(grad)}") + for idx, (g, info) in enumerate(zip(grad, arg_info)): + if g is None: + continue + if not isinstance(g, torch.Tensor): + error(f"for input '{name}' expected the grad_input dict to " + f"hold a list of None or Tensor gradients but got " + f"object of {type(g)} at index {idx}") + if not issubclass(info, torch.Tensor): + error(f"for input '{name}', got a Tensor as the gradient " + f"for the {idx}-th value but expected None because " + f"the {idx}-th value was not a Tensor (it was " + f"type {arg_info}") + continue + + if grad is None: + continue + if not isinstance(grad, torch.Tensor): + error(f"got object of type {type(grad)} as the gradient for input " + f"'{name}', " + f"but expected the gradient to be either None or a Tensor") + if not issubclass(arg_info, torch.Tensor): + error(f"got a Tensor as the gradient for input '{name}' but " + f"expected None as the gradient because input '{name}' " + f"was not a Tensor (it was type {arg_info}).") + + +def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): + result = [] + for name, arg_info in args_info._asdict().items(): + if name not in grad_inputs_dict: + result.append(pytree.tree_map(lambda x: None, arg_info)) + continue + result.append(grad_inputs_dict[name]) + return tuple(pytree.tree_leaves(result)) + +# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. +# autograd.Function prefers that users use ctx.save_for_backward to +# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the +# ctx object. +def save_pytree_for_backward(ctx, stuff): + flat_stuff, spec = pytree.tree_flatten(stuff) + num_elts = len(flat_stuff) + tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if isinstance(thing, torch.Tensor)] + non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) + if not isinstance(thing, torch.Tensor)] + tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] + non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] + + ctx.spec = spec + ctx.num_elts = num_elts + ctx.save_for_backward(*tensors) + ctx.tensor_idxs = tensor_idxs + ctx.saved_non_tensors = non_tensors + ctx.non_tensor_idxs = non_tensor_idxs + + +# Inverse operation to save_pytree_for_backward +def unpack_saved(ctx): + flat_stuff = [None] * ctx.num_elts + for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): + flat_stuff[idx] = tensor + for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): + flat_stuff[idx] = non_tensor + stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) + return stuff diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa571489bd7f3e752872eee4c7832704dfece39f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py @@ -0,0 +1,6443 @@ +import builtins +import collections +import inspect +import itertools +import math +import operator +import warnings + +from collections.abc import Iterable +from enum import Enum +from functools import partial, reduce, singledispatch, wraps +from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union + +import torch + +import torch._prims as prims +import torch._prims_common as utils +from torch import sym_float, sym_int +from torch._prims_common import ( + DeviceLikeType, + Dim, + DimsSequenceType, + DimsType, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + FloatWithoutSymFloat, + IntLike, + is_weakly_lesser_type, + Number, + NumberType, + RealNumberType, + REDUCTION_OUTPUT_TYPE_KIND, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + TensorOrNumberLikeType, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) + +# Experimental module containing prototype Python references for existing +# PyTorch operations. + +__all__ = [ + # + # Elementwise Unary References + # + "abs", + "acos", + "acosh", + "asinh", + "asin", + "atan", + "atanh", + "bitwise_not", + # "cbrt", # No corresponding torch operation + "ceil", + "conj_physical", + "cos", + "cosh", + "count_nonzero", + "deg2rad", + "digamma", + "erf", + "erfinv", + "erfc", + "exp", + "expm1", + "exponential", + "exp2", + "fill", + "fill_", + "floor", + "frac", + "geometric", + "index_add", + "index_copy", + "index_copy_", + "index_select", + "index_fill", + "index_fill_", + "isfinite", + "isinf", + "isposinf", + "isneginf", + "isnan", + "isreal", + "i0", + "lerp", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "log_normal", + "log_softmax", + "mvlgamma", + "norm", + "normal", + "nan_to_num", + "neg", + "positive", + "rad2deg", + "reciprocal", + "round", # TODO: model kwargs + "sigmoid", + "sgn", + "sign", + "signbit", + "sin", + "sinc", + "sinh", + "softmax", + "sqrt", + "square", + "tan", + "tanh", + "trace", + "trunc", + # + # Elementwise Binary References + # + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "clamp_min", + "clamp_max", + "copysign", + "div", + "eq", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "ge", + "gt", + "heaviside", + "hypot", + "igamma", + "igammac", + "imag", + "isclose", + "lcm", + # 'ldexp', + "le", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logsumexp", + "lt", + # 'max', # implement with reductions + "maximum", + # 'min', # implement with reductions + "minimum", + "mul", + "ne", + "nextafter", + # 'polar', # abs, cos, sin + "pow", + "real", + "rpow", + "remainder", + "rsub", + "rtruediv", + "rfloordiv", + "sub", + "true_divide", + "trunc_divide", + "xlogy", + # + # Elementwise Ternary References + # + "addcdiv", + "addcmul", + "clamp", + # + # Conditional references + # + "masked_fill", + "masked_fill_", + "where", + # + # Data conversion and movement references + # + "clone", + "copy_to", # TODO: add OpInfo (or implement .to) + "item", + "to", + # + # Reduction ops + # + "all", + "amax", + "amin", + "any", + "cumsum", + "cumprod", + "mean", + "dot", + "vdot", + "std", + "std_mean", + "sum", + "sum_to_size", + "prod", + "var", + "var_mean", + # + # Linear algebra ops + # + "addr", + # + # View & Shape Ops + # + "alias", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "as_strided", + "as_strided_scatter", + "block_diag", + "broadcast_shapes", + "broadcast_tensors", + "broadcast_to", + "cat", + "chunk", + "column_stack", + "conj", + "constant_pad_nd", + "contiguous", + "diag_embed", + "diag", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "dsplit", + "dstack", + "expand", + "expand_as", + "flatten", + "flip", + "fliplr", + "flipud", + "hsplit", + "hstack", + "meshgrid", + "movedim", + "narrow", + "narrow_copy", + "native_group_norm", + "native_layer_norm", + "permute", + "ravel", + "repeat", + "reshape", + "reshape_as", + "roll", + "rot90", + "rsqrt", + "stack", + "swap_axes", # alias for transpose + "squeeze", + "t", + "T", + "take_along_dim", + "tensor_split", + "transpose", + "unfold", + "unfold_copy", + "unsqueeze", + "view", + "view_as", + "vsplit", + "vstack", + "view_as_complex", + "unflatten", + "unbind", + "triu", + "tril", + "triu_indices", + "tril_indices", + # + # Tensor Creation + # + "arange", + "cauchy", + "empty", + "empty_like", + "empty_permuted", + "empty_strided", + "eye", + "full", + "full_like", + "linspace", + "logspace", + "new_empty", + "new_empty_strided", + "new_full", + "new_ones", + "new_zeros", + "ones", + "ones_like", + "randn", + "scalar_tensor", + "zero", + "zeros", + "zeros_like", + # + # Test-related functions + # + "allclose", + "equal", + # + # Statistical operations + # + "bucketize", + # + # Misc + # + "is_complex", + "renorm", + "stft", + "istft", +] + +Tensor = torch.Tensor +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] +aten = torch._ops.ops.aten + +# Note that the docstrings for the public methods from this file are in +# torch/_torch_docs.py + + +def is_noncontiguous_supported(device): + if device is not None and device.type == "hpu": + return False + return True + + +def handle_noncontiguous_outputs(input_tlist, output): + device = None + from torch._subclasses.fake_tensor import FakeTensor + + for t in input_tlist: + if isinstance(t, FakeTensor): + device = t.fake_device + break + + if not is_noncontiguous_supported(device): + output = output.contiguous() + + return output + + +def _broadcast_shapes(*_shapes): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x + for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape = [ + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" + ) + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + if common_shape[idx] != shape[idx]: + raise RuntimeError( + f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}" + ) + + return common_shape + + +def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): + # Computes common shape + common_shape = _broadcast_shapes( + *(t.shape if isinstance(t, TensorLike) else None for t in args) + ) + + def __maybe_broadcast(x, shape): + if x is None: + return None + elif isinstance(x, Number): + return x + elif isinstance(x, TensorLike): + if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): + return x + + if not utils.same_shape(x.shape, common_shape): + return x.expand(common_shape) + + return x + else: + raise RuntimeError( + "Unexpected type when broadcasting: " + str(type(x)) + "!" + ) + + return tuple(__maybe_broadcast(x, common_shape) for x in args) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +# +# Elementwise unary references +# + +infer_aten_op = object() + + +# TODO: add type promotion support +def _make_elementwise_unary_reference( + type_promotion_kind, + *, + aten_op=infer_aten_op, + extra_meta=None, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op + + @wraps(prim) + @out_wrapper() + @elementwise_unary_scalar_wrapper + @elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=type_promotion_kind, + ) + def _ref(a: TensorLikeType) -> TensorLikeType: + if extra_meta is not None: + extra_meta(a) + + output = prim(a) + return handle_noncontiguous_outputs([a], output) + + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, prim.__name__) + if aten_op is not None: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +def _make_alias(fn, name): + """ + This function defines an alias of another function and sets its __name__ argument. + It also sets its __module__ argument to the module of the caller. + Note that when naïvely doing `alias = fn`, we have that `alias.__name__ == "fn"`, and + `alias.__module__ == fn.__module__`. + """ + + def _fn(*args, **kwargs): + return fn(*args, **kwargs) + + _fn.__name__ = name + _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"] # type: ignore[union-attr] + return _fn + + +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(aten, inplace_name))(_fn) + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... + from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT) +def abs(a): + return prims.abs(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acos(a): + return prims.acos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acosh(a): + return prims.acosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asin(a): + return prims.asin(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asinh(a): + return prims.asinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atan(a): + return prims.atan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atanh(a): + return prims.atanh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def bitwise_not(a): + return prims.bitwise_not(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def ceil(a): + return prims.ceil(a) + + +@register_decomposition(aten.is_complex) +def is_complex(input: TensorLikeType): + return utils.is_complex_dtype(input.dtype) + + +@register_decomposition(aten.conj_physical) +@out_wrapper() +def conj_physical(input: TensorLikeType): + if not utils.is_complex_dtype(input.dtype): + return input + return prims.conj_physical(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cos(a): + return prims.cos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cosh(a): + return prims.cosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def digamma(a): + return prims.digamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erf(a): + return prims.erf(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfinv(a): + return prims.erf_inv(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfc(a): + return prims.erfc(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp(a): + return prims.exp(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def expm1(a): + return prims.expm1(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp2(a): + return prims.exp2(a) + + +# Fill has its own implementation because it has a value parameter +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a,"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(value, Number) + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(value), python_type): + msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + + return prims.fill(a, value) + + +def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType: + r = prims.fill(a, value) + prims.copy_to(a, r) + return a + + +@register_decomposition(aten.zero) +@out_wrapper() +def zero(input: TensorLikeType) -> TensorLikeType: + return torch.zeros_like(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def floor(a): + return prims.floor(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def frac(x: TensorLikeType) -> TensorLikeType: + trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x)) + return torch.sub(x, trunc_x) + + +# imag does not use _make_elementwise_unary_reference because it does not support out +def imag(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + torch._check( + utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." + ) + return prims.imag(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isfinite(a: TensorLikeType) -> TensorLikeType: + if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype): + return prims.isfinite(a) + + return ones_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isinf(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a))) + if utils.is_float_dtype(a.dtype): + return torch.abs(a) == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isposinf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isneginf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("-inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isnan(a: TensorLikeType) -> TensorLikeType: + return prims.ne(a, a) + + +# alias +mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isreal(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.imag(a) == 0 + return torch.ones_like(a, dtype=torch.bool) + + +# TODO: if this is special maybe it should be defined there and imported here? +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0 +) +def i0(a): + return prims.bessel_i0(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def lgamma(a): + return prims.lgamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log(a): + return prims.log(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log1p(a): + return prims.log1p(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log2(a): + return prims.log2(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log10(a): + return prims.log10(a) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logsumexp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logsumexp( + self: TensorLikeType, dim: DimsType, keepdim: bool = False +) -> TensorLikeType: + if not isinstance(dim, Iterable): + dim = (dim,) + if self.numel() == 0: + return torch.sum(torch.exp(self), dim, keepdim).log() + maxes = torch.amax(self, dim, keepdim=True) + maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) + maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) + result = torch.sum(torch.exp(self - maxes), dim, keepdim) + return result.log().add(maxes_squeezed) + + +@register_decomposition(aten.nan_to_num) +@out_wrapper() +def nan_to_num( + a: TensorLikeType, + nan: Optional[NumberType] = 0.0, + posinf: Optional[NumberType] = None, + neginf: Optional[NumberType] = None, +) -> TensorLikeType: + assert isinstance(a, TensorLike) + + if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + return a.clone() + + if nan is None: + nan = 0.0 + + if posinf is None: + posinf = torch.finfo(a.dtype).max + + if neginf is None: + neginf = torch.finfo(a.dtype).min + + result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload] + result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload] + result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload] + return result + + +def _neg_meta(a: TensorLikeType): + torch._check( + a.dtype is not torch.bool, + lambda: ( + "Negation, the `-` operator, on a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` " + "operator instead." + ), + ) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta +) +def neg(a): + return prims.neg(a) + + +# positive does not use _make_elementwise_unary_reference because it does not support out +# CompositeImplicitAutograd - don't register decomp +def positive(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if a.dtype is torch.bool: + msg = "positive does not support bool tensors." + raise RuntimeError(msg) + return a + + +# real does not use _make_elementwise_unary_reference because it does not support out +def real(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if utils.is_complex_dtype(a.dtype): + return prims.real(a) + return a + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def reciprocal(a): + return prims.reciprocal(a) + + +@register_decomposition(aten.round) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType: + if decimals == 0: + return prims.round(a) + else: + ten_pow = 10**decimals + ten_neg_pow = 10 ** (-decimals) + return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rsqrt(a): + return prims.rsqrt(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sigmoid(a: TensorLikeType) -> TensorLikeType: + return true_divide(1, add(1, exp(neg(a)))) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def sgn(a): + if utils.is_complex_dtype(a.dtype): + a_abs = a.abs() + return torch.where(a_abs == 0, 0, a / a_abs) + else: + return a.sign() + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def sign(a): + return prims.sign(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def signbit(a): + return prims.signbit(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sin(a): + return prims.sin(a) + + +# Autograd note: This will give the right first derivative at zero (by chance), +# but not the right second derivative +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinc(a): + a = math.pi * a + return torch.where(a == 0, 1, torch.sin(a) / a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinh(a): + return prims.sinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sqrt(a): + return prims.sqrt(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, + aten_op=None, # CompositeImplicitAutograd, +) +def square(a: TensorLikeType) -> TensorLikeType: + return mul(a, a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tan(a): + return prims.tan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tanh(a): + return prims.tanh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def trunc(a): + return prims.trunc(a) + + +# TODO: register this as a real ref/decomposition once TorchInductor supports complex! +def view_as_complex(self: TensorLikeType) -> TensorLikeType: + input_dtype = self.dtype + torch._check( + utils.is_float_dtype(input_dtype), + lambda: f"view_as_complex is only supported for floating point" + f"tensors, but got a tensor of scalar type: {input_dtype}", + ) + sizes = self.size() + torch._check( + len(sizes) != 0, + lambda: "Input tensor must have one or more dimensions", + ) + torch._check( + sizes[-1] == 2, + lambda: "Tensor must have a last dimension of size 2", + ) + + old_strides = self.stride() + torch._check( + old_strides[-1] == 1, + lambda: "Tensor must have a last dimension with stride 1", + ) + dims = old_strides[:-1] + torch._check( + py_all(stride % 2 == 0 for stride in dims), + lambda: "Tensor must have a stride divisible by 2 for all but last dimension", + ) + torch._check( + self.storage_offset() % 2 == 0, + lambda: "Tensor must have a storage_offset divisible by 2", + ) + return prims.view_element_type( + self, utils.corresponding_complex_dtype(input_dtype) + ).squeeze(-1) + + +def _make_elementwise_binary_reference( + type_promotion_kind, + aten_op=infer_aten_op, + name=None, + has_out=True, + supports_lhs_python_scalar=True, + supports_rhs_python_scalar=True, + supports_two_python_scalars=False, + should_register_decomposition=True, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op, name + if name is None: + name = prim.__name__ + + @wraps(prim) + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + def _ref( + a: Union[Tensor, NumberType], + b: Union[Tensor, NumberType], + ) -> Tensor: + torch._check_value( + supports_lhs_python_scalar or not isinstance(a, Number), + lambda: f"{name}: Received a lhs Python scalar to an elementwise binary " + "operation that does not accept lhs scalars!", + ) + torch._check_value( + supports_rhs_python_scalar or not isinstance(b, Number), + lambda: f"{name}: Received a rhs Python scalar to an elementwise binary " + "operation that does not accept rhs scalars!", + ) + torch._check_value( + supports_two_python_scalars + or not (isinstance(a, Number) and isinstance(b, Number)), + lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", + ) + a, b = _maybe_broadcast(a, b) + output = prim(a, b) + return handle_noncontiguous_outputs([a, b], output) + + if has_out: + _ref = out_wrapper()(_ref) + + _ref.__name__ = name + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, name) + if aten_op is not None and should_register_decomposition: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +# Add has its own implementation because it has an alpha argument +@register_decomposition(aten.add) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def add( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: Optional[NumberType] = None, +): + """ + Reference implementation of torch.add + """ + + a, b = _maybe_broadcast(a, b) + + if alpha is not None: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if python_type != bool and not utils.is_weakly_lesser_type( + type(alpha), python_type + ): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + if isinstance(b, TensorLike): + b = prims.mul(b, alpha) + else: + b = b * alpha + + output = prims.add(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def atan2(a, b): + return prims.atan2(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_and(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_left(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_or(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_right_arithmetic(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_xor(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, +) +def copysign( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + if isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + msg = "Expected divisor (b) to be on the same device ({}) as dividend (a), but it is found on {}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + return where(signbit(b), neg(abs(a)), abs(a)) + + +# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + + +@register_decomposition(aten.div) +@out_wrapper() +def div( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + rounding_mode: Optional[str] = None, +): + """ + Reference implementation of torch.div + """ + if rounding_mode is None: + return true_divide(a, b) + elif rounding_mode == "trunc": + return trunc_divide(a, b) + elif rounding_mode == "floor": + return floor_divide(a, b) + else: + msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}." + raise ValueError(msg) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.eq(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) +def pow( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> TensorLikeType: + assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType) + + if isinstance(b, Number): + if b == 1.0: + return a.clone() # type: ignore[return-value,union-attr] + elif b == 2.0: + return a * a # type: ignore[return-value] + elif b == 0.5: + return torch.sqrt(a) # type: ignore[arg-type] + elif isinstance(a, Number): + if a == 1.0: + return torch.fill(b, True) + if a == 2.0 and ( + utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype) + ): + return torch.exp2(b) + + return prims.pow(a, b) + + +# Float power has its own implementation because it has unique type promotion. +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def float_power( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> Tensor: + if isinstance(a, Number) and isinstance(b, Number): + raise ValueError( + "Receive two Number inputs to an elementwise binary operation!" + ) + + # Handles type promotion + dtype = utils.get_higher_dtype(a, b) + assert dtype is not None + if utils.is_complex_dtype(dtype): + dtype = torch.complex128 + else: + dtype = torch.float64 + + # Float power has the following contiguous cast behavior to be + # consistent with its C++ impl + a = _maybe_convert_to_dtype(a, dtype) + b = _maybe_convert_to_dtype(b, dtype) + + a, b = _maybe_broadcast(a, b) + return pow(a, b) + + +# >>> a = torch.tensor(-0.2500, dtype=torch.float64) +# tensor(-0.250000000000000, dtype=torch.float64) +# +# >>> b = torch.tensor(-0.0010, dtype=torch.float64) +# tensor(-0.001000000000000, dtype=torch.float64) +# +# Note: In this case, casting float to double will expand the float mantissa with zeros, +# while creating a double generates a distinct mantissa. +# >>> torch.tensor(-0.001).to(dtype=torch.float64) +# tensor(-0.001000000047497, dtype=torch.float64) +# +# Floor Division +# The difference is caused because torch.remainder(a, b) = -0.001. +# +# >>> torch.floor(torch.true_divide(a, b)) +# tensor(250., dtype=torch.float64) +# +# >>> torch.div(a, b, rounding_mode='floor') +# tensor(249., dtype=torch.float64) +# +# Definition: a // b = (a - remainder(a, b)) / b +# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b) +# tensor(249., dtype=torch.float64) +# +# For reference, see CPython's implementation: +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, + should_register_decomposition=False, +) +def floor_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + # Wrap scalars because some references only accept tensor arguments. + if isinstance(a, Number) and isinstance(b, Number): + a = scalar_tensor(a) + b = scalar_tensor(b) + elif isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Number) and isinstance(b, Tensor): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + if a.device == torch.device("cpu"): + msg = "Expected divisor (b) to be on the same device ({}) as dividend (a), but it is found on {}!".format( + a.device, b.device + ) + raise RuntimeError(msg) + else: + b = prims.device_put(b, device=a.device) + + assert isinstance(a, Tensor) and isinstance(b, Tensor) + dtype = a.dtype + if utils.is_float_dtype(dtype): + return _floor_divide_float(a, b) + elif utils.is_integer_dtype(dtype): + return _floor_divide_integer(a, b) + else: + torch._check(False, lambda: f"{dtype} not supported for floor_divide") + + +def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: + a, b = _maybe_broadcast(a, b) + + if not a.dtype.is_signed: + return prims.div(a, b) + + # Convert truncation to flooring: + offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + + +def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: + mod = fmod(a, b) + div = true_divide(sub(a, mod), b) + + # Ensure that the remainder has the same sign as denominator + different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0)) + non_zero_remainder = ne(mod, 0) + mask = bitwise_and(non_zero_remainder, different_signed_inputs) + div = where(mask, sub(div, 1), div) + + # Map quotient to nearest integer value + floor_div = floor(div) + mask = gt(sub(div, floor_div), 0.5) + floor_div = where(mask, add(floor_div, 1), floor_div) + + basic_div = true_divide(a, b) + zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device) + + # If quotient is zero, copy signbit from true_divide quotient + floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div)) + + # If denominator is zero, then follow true_divide behavior + return where(ne(b, 0), floor_div, basic_div) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=True, +) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + + +@register_decomposition(aten.frexp) +@out_wrapper("mantissa", "exponent") +def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]: + return torch.return_types.frexp(prims.frexp(self)) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: + input_eq_zero = torch.eq(input, 0) + input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input)) + zeros_and_ones = torch.where(input_lt_zero, 0, 1) + output = torch.where(input_eq_zero, values, zeros_and_ones) + return output + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) + + +def _check_close_args( + name: str, + a: TensorLikeType, + b: TensorLikeType, + rtol: float, + atol: float, +) -> None: + torch._check_value( + a.dtype == b.dtype, + lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!", + ) + torch._check( + rtol >= 0, + lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!", + ) + torch._check( + atol >= 0, + lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!", + ) + + +# CompositeImplicitAutograd - don't register decomp +def isclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> TensorLikeType: + _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol) + + close = eq(a, b) + if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)): + close = logical_or(close, logical_and(isnan(a), isnan(b))) + + # Note: In case of zero tolerances the closeness inequality degenerates to an equality check. + # In this case, the short-circuit prevents false positives as detailed in the paragraph below. + if atol == 0 and rtol == 0: + return close + + # Note [closeness error computation] + # atol and rtol are provided as doubles, so the computation + # rtol * other will produce a float or complex tensor. + # When the difference (self - other) is compared to it then the + # tensor representing the difference will also be cast to float or complex. + # However, since (self - other) in uint8 is very likely to produce a + # negative value, this moves the cast forward so the difference is + # always computed in a float or complex type. + # If the values of the integer tensors cannot be exactly represented + # by the default scalar type then this may cause an incorrect result. + if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype): + a = prims.convert_element_type(a, torch.get_default_dtype()) + b = prims.convert_element_type(b, torch.get_default_dtype()) + + allowed_error = add(atol, abs(mul(b, rtol))) + actual_error = abs(sub(a, b)) + + # Computes finite closeness + result = logical_or( + close, logical_and(isfinite(actual_error), le(actual_error, allowed_error)) + ) + + return result + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): + dtype = a.dtype + # promoting to int32 to maintain 100% consistency with C++ and to + # prevent overflow in case of int8 and int16 + promote_to_int = dtype in (torch.int8, torch.int16) + if promote_to_int: + a = prims.convert_element_type(a, torch.int32) + b = prims.convert_element_type(b, torch.int32) + + g = torch.gcd(a, b) + # Avoid division by zero in case gcd(0, 0) == 0 + g = torch.where(g == 0, 1, g) + res = torch.abs(prims.div(a, g) * b) + return res if not promote_to_int else prims.convert_element_type(res, dtype) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + # Nb. this implementation does not distribute the gradients evenly when a == b + mask = torch.real(a) >= torch.real(b) + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and( + torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b) + ) + if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype): + # are you wondering what this bunch of codes are for? edge cases! + neg_min_mask = torch.real(min_) < 0 + inf_vals = torch.where( + neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_)) + ) + non_nan_vals = torch.where( + inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_)) + ) + # the type for full_like does not include tensor yet + nan_mask = torch.isnan(min_) + return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload] + else: + return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_))) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + torch._check( + not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)), + lambda: "logaddexp2 doesn't support complex dtypes", + ) + # Nb. this implementation does not distribute the gradients evenly when a == b + mask = a >= b + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and(torch.isinf(a), a == b) + inv_log_2 = 1.0 / math.log(2) + result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2 + return torch.where(inf_mask, a, result) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_and(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a & b + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def logical_not(a: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + return a == 0 + return ~a + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return bitwise_or(a, b) + + +# TODO: skip unnecessary conversion of long to float +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_xor(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a ^ b + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + + +# reverse sub +@register_decomposition(aten.rsub) +@out_wrapper() +def rsub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + alpha: NumberType = 1, +): + if isinstance(a, Number): + msg = "Received a Number for the first argument, but expected a Tensor" + raise ValueError(msg) + + return torch.sub(b, a, alpha=alpha) + + +# TODO: consider refactoring this with add impl +# sub has its own implementation because it has an alpha argument +@register_decomposition(aten.sub) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def sub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: NumberType = 1, +): + """ + Reference implementation of torch.sub + """ + + a, b = _maybe_broadcast(a, b) + + if alpha != 1: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + if isinstance(b, torch.Tensor): + b = prims.mul(b, alpha) + else: + # Carefully not to use prims.mul if b is a scalar / symint. + # prims.mul always returns a tensor, + # which will mess with type promotion. + b = b * alpha + + output = prims.sub(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) + + +@register_decomposition(aten.xlogy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. + if isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + dtype = utils.get_dtype(a) + if utils.is_integer_dtype(dtype): + return prims.div(a, b) + + return trunc(prims.div(a, b)) + + +# +# Elementwise Ternary References +# + + +@register_decomposition(aten.addcdiv) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def addcdiv( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcdiv + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 / tensor2 + + +@register_decomposition(aten.addcmul) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addcmul( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcmul + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 * tensor2 + + +@register_decomposition(aten.clamp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "min", "max"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def clamp( + a: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + # NOTE: grad behavior with implementation `where` is not consistent on `nan` + if min is None and max is None: + msg = "clamp called but both min and max are none!" + raise ValueError(msg) + if min is not None: + a_isnan = torch.isnan(a) + condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type] + # we should also propagate `nan` coming from boundaries. However, that's + # not necessary since `ge` would already `False` when either operands has + # a `nan`. So this line below is redundant + # `condition = bitwise_and(condition, bitwise_not(isnan(min)))` + a = torch.where(condition, a, min) # type: ignore[arg-type] + if max is not None: + a_isnan = torch.isnan(a) + # same as above, no need to adjust `nan` from `max` + condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type] + a = torch.where(condition, a, max) # type: ignore[arg-type] + + return a + + +@register_decomposition(aten.clamp_min) +@out_wrapper() +def clamp_min( + self: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, min=min) # type: ignore[arg-type] + + +@register_decomposition(aten.clamp_max) +@out_wrapper() +def clamp_max( + self: TensorLikeType, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, max=max) # type: ignore[arg-type] + + +# +# Conditional references +# + + +# https://pytorch.org/docs/stable/generated/torch.where.html +# TODO: implement alternate where +@register_decomposition(aten.where) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def where( + pred: Tensor, + a: Optional[TensorOrNumberLikeType] = None, + b: Optional[TensorOrNumberLikeType] = None, +): + """ """ + + if a is None or b is None: + raise NotImplementedError + + utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) + torch._check( + pred.dtype is torch.bool, + lambda: f"expected predicate to be bool, got {pred.dtype}", + ) + + pred, a, b = _maybe_broadcast(pred, a, b) + return prims.where(pred, a, b) + + +# +# Data Movement References +# +@register_decomposition(aten.clone) +@out_wrapper() +def clone( + a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + result = prims.clone(a, memory_format=memory_format) + return result + + +def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): + if not allow_cross_device and a.device != b.device: + msg = "Attempting to copy from device {} to device {}, but cross-device copies are not allowed!".format( + b.device, a.device + ) + raise RuntimeError(msg) + + return prims.copy_to(a, b) + + +@register_decomposition(aten.item) +def item(a: TensorLikeType) -> NumberType: + if a.numel() != 1: + msg = f"Can't convert a tensor with {a.numel()} elements to a number!" + raise ValueError(msg) + + # NOTE: explicit conversion is necessary for bool! + # See https://github.com/pytorch/pytorch/issues/78071 + number_type = utils.dtype_to_type(a.dtype) + return number_type(prims.item(a)) + + +# fast path when `to` returns an alias to input. This mimics the same function in aten +def _to_will_alias( + a: TensorLikeType, + device: Optional[DeviceLikeType] = None, + dtype: Optional[torch.dtype] = None, + copy: Optional[bool] = None, + layout: Optional[torch.layout] = None, + memory_format: Optional[torch.memory_format] = None, + pin_memory: Optional[bool] = False, + non_blocking: bool = False, # not using non_blocking +) -> bool: + return ( + not copy + and (device is None or a.device == device) + and (dtype is None or a.dtype == dtype) + and (layout is None or a.layout == layout) + # is_pinned issue #84925 + # and (pin_memory is None or pin_memory == a.is_pinned()) + and ( + memory_format is None + or memory_format == torch.preserve_format + or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) + ) + ) + + +@singledispatch +def _to_dispatch(*args, **kwargs): + raise NotImplementedError + + +@_to_dispatch.register +def _to_device( + device: torch.device, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "device": device, + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_device_str( + device: str, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "device": torch.device(device), + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_dtype( + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + kwargs = { + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_other( + other: Tensor, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> Dict[str, Any]: + device = other.device + dtype = other.dtype + layout = other.layout + # is_pinned issue #84925 + # pin_memory = other.is_pinned() + kwargs = { + "device": device, + "dtype": dtype, + "layout": layout, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +# remove to_kwargs that is already present in `a` +def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict): + options_to_check = ["dtype", "device", "layout", "memory_format"] + # "device" option could be passed a str instead torch.device + if "device" in to_kwargs and isinstance(to_kwargs["device"], str): + to_kwargs["device"] = torch.device(to_kwargs["device"]) + + for kw in options_to_check: + if kw in to_kwargs: + if ( + (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format) + or ( + kw == "device" + and to_kwargs[kw].type == a.device.type + and ( + not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index + ) + ) + or ( + getattr(a, kw, None) == to_kwargs[kw] + ) # this also handles {"memory_format": None} + ): + to_kwargs.pop(kw) + + +def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType: + # handled dispatch via positional arguments + if len(args) != 0: + kwargs = _to_dispatch(*args, **kwargs) + + # TODO: is_pinned is not currently supported in refs or fake_tensor + # https://github.com/pytorch/pytorch/issues/84925 + assert "pin_memory" not in kwargs + _canonicalize_to_arguments(a, kwargs) + + if _to_will_alias(a, **kwargs): + return a + + copy = kwargs.pop("copy") if "copy" in kwargs else False + non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False + + # short-circuit to `prims.convert_element_type` when `to` is just a dtype change + if ( + (copy or (kwargs.get("dtype", a.dtype) != a.dtype)) + and (not non_blocking) + and ("memory_format" not in kwargs) + and ("device" not in kwargs) + and ("layout" not in kwargs) + # is_pinned issue #84925 + # and ("pin_memory" not in kwargs) + ): + return prims.convert_element_type(a, kwargs.get("dtype", a.dtype)) + + result = torch.empty_like(a, **kwargs) + # TODO: non_blocking should be handled by `copy_to` + copy_to(result, a) + return result + + +# +# Reduction references +# + + +def _reduction( + a: TensorLikeType, + prim: Callable, + *, + has_identity: bool = True, + accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only + dims: Optional[DimsType] = None, + keepdims: bool = False, + dtype: Optional[torch.dtype] = None, # should be specified for ops that support it + out: Optional[Tensor] = None, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, +) -> TensorLikeType: # it is usually SAME, but I want + # ref writers to actually think about what to put here + assert isinstance(a, TensorLike) + if a.ndim > 64: + raise RuntimeError( + f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + + if out is not None: + assert isinstance(out, TensorLike) + if dtype is not None: + # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms + if dtype != out.dtype: + raise RuntimeError( + "dtype argument and out dtype must match in reduction" + ) + if not accepts_dim_tuple: + assert dims is None or isinstance(dims, Dim) + if isinstance(dims, Dim): + dims = (dims,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dims) + if not has_identity: + valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims) + if not valid_shape: + raise RuntimeError( + "reducing over zero-size dimension for reduction operation without identity" + ) + computation_dtype, result_dtype = utils.reduction_dtypes( + a, output_dtype_kind, dtype + ) + a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign] + result = prim(a, dims) + if keepdims: + output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)] + broadcast_dims = [i for i in range(a.ndim) if i not in dims] + result = prims.broadcast_in_dim(result, output_shape, broadcast_dims) + + if out is not None: + assert result_dtype is not None + if dtype is not None and result_dtype != out.dtype: + raise RuntimeError( + "Expected the dtype of reduction result and out to match" + ) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + + if result.dtype != result_dtype and result_dtype is not None: + result = prims.convert_element_type(result, result_dtype) + + return result + + +def _make_copy_from_view(fn): + """ + Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy) + """ + name = fn.__name__ + fn = out_wrapper()(fn) + + def _fn(*args, out=None, **kwargs): + result = fn(*args, out=out, **kwargs) + if out is None: + return result.clone(memory_format=torch.contiguous_format) + return result + + copy_name = f"{name}_copy" + _fn.__name__ = copy_name + _fn = register_decomposition(getattr(aten, copy_name))(_fn) + return _fn + + +# Saves Python all +py_all = all + + +@register_decomposition(aten.all) +@out_wrapper() +def all( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) + + if a.dtype == torch.uint8: + result = result.to(dtype=torch.uint8) + + return result + + +# Saves Python any +py_any = any + + +@register_decomposition(aten.any) +@out_wrapper() +def any( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + a_ = _maybe_convert_to_dtype(a, torch.bool) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + result = a_.clone() + else: + result = a_.sum(dim=dim, keepdim=keepdim).ne(False) + + # Preserves uint8 -- probably a legacy mask thing + if a.dtype is torch.uint8: + return prims.convert_element_type(result, torch.uint8) + + return result + + +@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out]) +def sum( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def sum_to_size( + a: Tensor, + *shape, +) -> Tensor: + shape = utils.extract_shape_from_varargs(shape, validate=False) + torch._check( + utils.is_expandable_to(shape, a.shape), + lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', + ) + # In ATen scalar tensors are sent through sum and the result is returned as + # type promoted + if utils.is_same_shape(shape, a.shape) and len(shape) > 0: + return prims.view_of(a) + leading_dims = a.ndim - len(shape) + reduce_dims = tuple(range(leading_dims)) + tuple( + i + for i in range(leading_dims, len(shape)) + if shape[i - leading_dims] == 1 and a.shape[i] != 1 + ) + return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None) + + +@register_decomposition(aten.prod) +def prod( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + keepdim: bool = False, + *, + dtype=None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.prod, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amin) +def amin( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amin, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amax) +def amax( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amax, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def _dim_var_dispatch(dim=None, unbiased=None): + # There's the following overload of torch.var: + # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + # We need to explicitly convert bool dims to unbiased arg + if unbiased is None and isinstance(dim, bool): + unbiased = dim + dim = None + return dim, unbiased + + +@register_decomposition(aten.var) +@out_wrapper() +def var( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + result = _reduction( + a, + partial(prims.var, correction=correction), + dims=dim, + keepdims=keepdim, + dtype=None, + out=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + +@register_decomposition(aten.std) +@out_wrapper() +def std( + a: TensorLikeType, + dim: Union[Optional[int], Optional[List[int]]] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var = torch.var(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return _maybe_convert_to_dtype(a_std, dtype) + + +@register_decomposition(aten.mean) +def mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype=None, + out=None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + orig_dtype = dtype + if dtype is None: + dtype = a.dtype + # can't use out wrapper because of this argument + torch._check( + out is None or out.dtype == dtype, + lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead", + ) + result = _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=None, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + ) + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: ( + f"mean(): could not infer output dtype. " + f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " + f"a floating point or complex dtype. Got: {dtype}" + ), + ) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] + nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) + result = true_divide(result, nelem) + result_dtype = a.dtype if dtype is None else dtype + result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign] + if out is not None: + assert isinstance(out, TensorLike) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + return result + + +@register_decomposition(aten.std_mean) +@out_wrapper("out0", "out1") +def std_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + *, + unbiased: Optional[bool] = None, + keepdim: bool = False, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + original_dtype = a.dtype + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return ( + _maybe_convert_to_dtype(a_std, dtype), + _maybe_convert_to_dtype(a_mean, original_dtype), + ) + + +@register_decomposition(aten.var_mean) +@out_wrapper("out0", "out1") +def var_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + v = var(a, dim, unbiased, keepdim, correction=correction) + m = mean(a, dim, keepdim) + return v, m + + +@register_decomposition(aten.addr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "vec1", "vec2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addr( + self: TensorLikeType, + vec1: TensorLikeType, + vec2: TensorLikeType, + *, + beta: NumberType = 1, + alpha: NumberType = 1, +) -> TensorLikeType: + torch._check( + vec1.ndim == 1, + lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", + ) + torch._check( + vec2.ndim == 1, + lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", + ) + self = self.expand(vec1.shape[0], vec2.shape[0]) + if utils.is_boolean_dtype(self.dtype): + # Integers are accepted for booleans + torch._check( + is_weakly_lesser_type(type(beta), int), + lambda: f"expected bool/int beta but got {type(beta)}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), int), + lambda: f"expected bool/int alpha but got {type(beta)}", + ) + if not beta: + return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False) + else: + return torch.logical_or( + self, + torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), + ) + else: + torch._check( + is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(beta)} to {self.dtype}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", + ) + if beta == 0: + # This means NaNs from self are dropped if beta is zero + return alpha * torch.outer(vec1, vec2) + else: + return beta * self + alpha * torch.outer(vec1, vec2) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_1d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_1d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) + return res if len(res) > 1 else res[0] + + +# Helper function with assert to avoid MyPy error +# of incompatible type passed to unsqueeze +def _unsqueeze_atleast( + at_least_fn: Callable, dim: int, arg: TensorLikeType +) -> TensorLikeType: + arg_ = at_least_fn(arg) + assert isinstance(arg_, TensorLike) + return unsqueeze(arg_, dim) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_2d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_2d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) + res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +# CompositeImplicitAutograd - don't register decomp +def atleast_3d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_3d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) + res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +def as_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = ( + storage_offset if storage_offset is not None else a.storage_offset() + ) + return prims.as_strided(a, size, stride, storage_offset_int) + + +@register_decomposition(aten.as_strided_scatter) +@out_wrapper() +def as_strided_scatter( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = 0 if storage_offset is None else storage_offset + return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) + + +def broadcast_shapes(*shapes) -> ShapeType: + return torch.Size(_broadcast_shapes(*shapes)) + + +@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta) +def broadcast_tensors(*tensors) -> List[TensorLikeType]: + if len(tensors) == 1 and not isinstance(tensors[0], Tensor): + tensors = tensors[0] + return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) + + +# CompositeImplicitAutograd - don't register decomp +def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: + start = len(size) - len(a.shape) + dims = tuple(range(start, len(a.shape) + start)) + return prims.broadcast_in_dim(a, size, dims) + + +@register_decomposition(aten.cat) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("tensors",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + format = f + assert format is not None + return format + + if len(tensors) == 0: + msg = "cat expects at least one tensor, but received zero!" + raise ValueError(msg) + + for tensor in tensors: + assert isinstance(tensor, TensorLike) + + utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # This is a bit tricky. Naively, you would expect to just pick one + # arbitrary tensor and check that all tensors match this tensor. However, + # there is legacy behavior which says that if you have a 1-D empty tensor + # (0,), this is permissible. So you can't assume that all the tensors + # have same dimensionality, and you can't assume that the first tensor is + # the correct stencil. + # + # We'll implement this in a few passes. First, we will try to infer the + # ndim of the cat output. If this ndim != 1, then we know that all ndim = + # 1 inputs must be empty, or are errors. If this ndim == 1, then life + # is easy (the legacy special case coincides with regular handling). + # + # NB: The regular implementation of cat just filters out empty inputs, + # but we do it slightly different here for better handling for unbacked + # SymInts + + example = None + for i, t in enumerate(tensors): + if example is None: + if t.ndim != 1: + example = t + else: + if t.ndim != 1: + torch._check( + t.ndim == example.ndim, + lambda: "Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for " + f"tensor number {i} in the list", + ) + + if example is None: + # example is None if everything is 1-D. If so, just arbitrarily pick + # the first one + example = tensors[0] + + shape = example.shape + filtered = [] + for tensor_idx, tensor in enumerate(tensors): + if len(shape) != len(tensor.shape): + assert tensor.ndim == 1 # we've already checked this above + # Don't suggest the legacy behavior in the error message + torch._check( + tensor.shape[0] == 0, + lambda: f"Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got 1-D for " + f"tensor number {tensor_idx} in the list", + ) + else: + # Remove inputs that are 1-D, zero size + if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0): + continue + # Don't bother checking size match, prims.cat will handle it + filtered.append(tensor) + + memory_format = cat_compute_output_memory_format(tensors) + + if len(filtered) == 0: + t = tensors[0] + + # TODO: fix this to work with meta tensors + try: + requires_grad = any(x.requires_grad for x in tensors) + except Exception: + requires_grad = False + + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + dim = utils.canonicalize_dim(filtered[0].ndim, dim) + utils.validate_idx(filtered[0].ndim, dim) + + return prims.cat(filtered, dim).clone(memory_format=memory_format) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def column_stack(tensors: TensorSequenceType) -> TensorLikeType: + aligned_tensors = tuple( + x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors + ) + return cat(aligned_tensors, 1) + + +def conj(input: TensorLikeType) -> TensorLikeType: + if not utils.is_complex_dtype(input.dtype): + return input + if input.is_sparse: + return torch.conj_physical(input) + return prims.conj(input) + + +# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp +@register_decomposition(aten.constant_pad_nd) +@out_wrapper() +def constant_pad_nd( + input: TensorLikeType, pad: List[int], value: NumberType = 0 +) -> TensorLikeType: + torch._check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + torch._check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + c_input = input + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] < 0: + c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) + + if pad[pad_idx + 1] < 0: + c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) + + # if none of the pads are positive we can just return the result + if builtins.all(p <= 0 for p in pad): + return c_input.clone() + + new_shape = list(input_sizes[:l_diff]) + + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + torch._check( + new_dim > 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + memory_format = utils.suggest_memory_format(input) + output = torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=memory_format, + ) + + if value == 0 and input.dtype == torch.bool: + value = False + # torch.fill isn't typed to allow complex values + output = torch.fill(output, value) # type: ignore[arg-type] + + c_output = output + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] > 0: + c_output = c_output.narrow( + i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] + ) + if pad[pad_idx + 1] > 0: + c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) + + prims.copy_to(c_output, c_input) + return output + + +def contiguous( + a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format +) -> Tensor: + torch._check( + memory_format != torch.preserve_format, + lambda: "preserve memory format is unsupported by the contiguous operator", + ) + + if utils.is_contiguous_for_memory_format(a, memory_format=memory_format): + return a + + return torch.clone(a, memory_format=memory_format) + + +@out_wrapper() +def dstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") + aligned_tensors = atleast_3d(*tensors) + return cat(aligned_tensors, 2) + + +@register_decomposition(aten.expand) +def expand(a: Tensor, *shape) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # NOTE: cannot use utils.extract_shape_from_varargs here + # because that also validates the shape, but the shape + # given to expand may be "invalid" + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = tuple(shape[0]) + + torch._check( + len(shape) >= len(a.shape), + lambda: "expand: the requested shape has too few dimensions!", + ) + + offset = len(shape) - len(a.shape) + shape_ = list(shape) + for idx, x in enumerate(a.shape): + offset_idx = idx + offset + requested_length = shape[offset_idx] + torch._check( + guard_size_oblivious(requested_length == x) + or guard_size_oblivious(x == 1) + or requested_length == -1, + lambda: f"expand: attempting to expand a dimension of length {x}!", + ) + + shape_[offset_idx] = requested_length if requested_length != -1 else x + + # At this point shape must be valid + utils.validate_shape(shape_) + + return prims.broadcast_in_dim( + a, shape_, tuple(range(offset, len(a.shape) + offset)) + ) + + +# CompositeImplicitAutograd - don't register decomp +def expand_as(a: Tensor, b: Tensor) -> Tensor: + return a.expand(b.shape) + + +def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]: + if chunks <= 0: + msg = f"Expected at least one chunk, but got {chunks}!" + raise ValueError(msg) + + dim = utils.canonicalize_dim(a.ndim, dim) + length = a.shape[dim] + chunk_size = math.ceil(length / chunks) + full_chunks = math.floor(length / chunk_size) + tail_chunk_size = length % chunk_size + + result = [] + for i in range(full_chunks): + result.append(narrow(a, dim, i * chunk_size, chunk_size)) + + if tail_chunk_size != 0: + result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size)) + + return tuple(result) + + +# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless +# a 0D tensor is flattened, in which case it's returned in 1D) +# CompositeImplicitAutograd - don't register decomp +def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: + start_dim = utils.canonicalize_dim(a.ndim, start_dim) + end_dim = utils.canonicalize_dim(a.ndim, end_dim) + + # Short-circuits on no-op + if start_dim == end_dim and a.ndim != 0: + return a + + # Tries to take a view + # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) + new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim) + if new_shape is not None: + return prims.collapse_view(a, start_dim, end_dim) + + # Makes a copy if it can't make a view + return prims.collapse(a, start_dim, end_dim) + + +@register_decomposition(aten.flip) +@out_wrapper() +def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + if not isinstance(dims, tuple) and not isinstance(dims, list): + raise ValueError("dims has to be a sequence of ints") + dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment] + utils.validate_no_repeating_dims(dims) + return prims.rev(a, dims) + + +# CompositeImplicitAutograd - don't register decomp +def fliplr(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 2: + raise RuntimeError("Input must be >= 2-d.") + + return flip(a, (1,)) + + +# CompositeImplicitAutograd - don't register decomp +def flipud(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 1: + raise RuntimeError("Input must be >= 1-d.") + + return flip(a, (0,)) + + +# CompositeImplicitAutograd - don't register decomp +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + torch._check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + torch._check(length >= 0, lambda: "narrow(): length must be non-negative.") + dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + torch._check_with( + IndexError, + -dim_length <= start and start <= dim_length, # type: ignore[arg-type] + lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})", + ) + if start < 0: + start = start + dim_length + torch._check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) + return prims.slice_in_dim(a, start, start + length, axis=dim) + + +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(narrow) + + +def _normalize( + a: Tensor, norm_dims: DimsType, eps: float +) -> Tuple[Tensor, Tensor, Tensor]: + """Computes mean and 1/std of a tensor along norm_dims. + + Used as a helper function for normalization layers. + + Args: + a (Tensor): input tensor + norm_dims (DimsType): dimensions to normalize over + eps (float): epsilon for numerical stability + + Returns: + out (Tensor): normalized tensor. + mean (Tensor): mean of the tensor along norm_dims. + rstd (Tensor): 1/std of the tensor along norm_dims. + """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) + computation_dtype = utils.get_computation_dtype(a.dtype) + a_acc = _maybe_convert_to_dtype(a, computation_dtype) + assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean + biased_var, mean = torch.var_mean( + a_acc, dim=norm_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + out = (a - mean) * rstd + return out, mean, rstd + + +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) + out = out.view(input.shape) + + broadcast_dims = [0] + list(range(2, input.ndim)) + unsqueeze_bias = None + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + unsqueeze_weight = None + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + + if unsqueeze_weight is not None: + out = out * unsqueeze_weight + if unsqueeze_bias is not None: + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = torch.squeeze(mean, reduction_dims) + rstd = torch.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + +@register_decomposition(aten.native_layer_norm) +@out_wrapper("out0", "out1", "out2") +def native_layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + normalized_ndim = len(normalized_shape) + torch._check( + normalized_ndim >= 1, + lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " + + "containing at least one element, but got normalized_shape = " + + str(normalized_shape), + ) + # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False + # while torch.Size([1, 2, 3]) == (1, 2, 3) is True + # therefore we use tuple(normalized_shape) + torch._check( + weight is None or weight.shape == tuple(normalized_shape), + lambda: "Expected weight to be of same shape as normalized_shape, but got " + + "weight of shape " + + str(weight.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + bias is None or bias.shape == tuple(normalized_shape), + lambda: "Expected bias to be of same shape as normalized_shape, but got " + + "bias of shape " + + str(bias.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + input.ndim >= normalized_ndim + and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), + lambda: "Given normalized_shape=" + + str(normalized_shape) + + ", expected input with shape " + + str(normalized_shape) + + ", but got input of size " + + str(input.shape), + ) + + input = input.contiguous() + if weight is not None: + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + + axis = input.ndim - normalized_ndim + reduction_dims = list(range(axis, input.ndim)) + out, mean, rstd = _normalize(input, reduction_dims, eps) + + if weight is None and bias is not None: + out = out + bias + elif weight is not None and bias is None: + out = out * weight + elif weight is not None and bias is not None: + out = out * weight + bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + if input.device.type == "cpu": + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + return (out, mean, rstd) + + +# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. +# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu +@register_decomposition(aten.permute) +def permute(a: TensorLikeType, *dims) -> TensorLikeType: + _permutation = utils.canonicalize_dims( + a.ndim, utils.extract_dims_from_varargs(dims) + ) + return prims.transpose(a, _permutation) + + +@register_decomposition(aten.renorm) +@out_wrapper() +def renorm( + input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType +) -> TensorLikeType: + torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued") + torch._check(p > 0, lambda: "renorm: non-positive norm not supported") + torch._check( + not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued" + ) + torch._check( + maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}" + ) + ndim = input.ndim + torch._check( + ndim > 1, + lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions", + ) + + dim = utils.canonicalize_dim(ndim, dim) + reduce_dims = list(range(ndim)) + del reduce_dims[dim] + + # For half and bfloat16, calculate norm in float precision then cast + # normalization factor to half + acc_type = utils.get_computation_dtype(input.dtype) + if acc_type != input.dtype: + norm = torch.linalg.vector_norm( + input, p, reduce_dims, keepdim=True, dtype=acc_type + ) + else: + norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True) + + eps = 1e-7 + norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + if acc_type != input.dtype: + norm_factor = prims.convert_element_type(norm_factor, input.dtype) + return (input * norm_factor).contiguous() + + +# CompositeImplicitAutograd - don't register decomp +@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd) +def stft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"stft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + if return_complex is None: + return_complex_ = input.is_complex() or ( + window is not None and utils.is_complex_dtype(window.dtype) + ) + torch._check( + return_complex_, + ( + "stft requires the return_complex parameter be given for real inputs, " + + "and will further require that return_complex=True in a future PyTorch release." + ), + ) + else: + return_complex_ = return_complex + + torch._check( + utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), + lambda: "stft expected a tensor of floating point or complex values", + ) + torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor") + + original_ndim = input.ndim + if original_ndim == 1: + input = input.unsqueeze(0) + + if center: + extra_dims = 3 - input.ndim + pad_amount = n_fft // 2 + extended_shape = [*itertools.repeat(1, extra_dims), *input.shape] + input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode) + input = input.view(input.size()[extra_dims:]) + + batch = input.size(0) + length = input.size(1) + torch._check( + 0 < n_fft <= length, + lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}", + ) + torch._check( + hop_length_ > 0, + lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}", + ) + torch._check( + 0 < win_length_ <= n_fft, + lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}", + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: ( + f"expected a 1D window tensor of size equal to win_length={win_length_}, " + + f"but got window with size {window.shape}" # type: ignore[union-attr] + ), + ) + + if win_length_ < n_fft: + if window is None: + window = torch.ones(win_length_, dtype=input.dtype, device=input.device) + left = (n_fft - win_length_) // 2 + window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) + + input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) + if window is not None: + input = input * window + + complex_fft = utils.is_complex_dtype(input.dtype) + onesided = onesided if onesided is not None else not complex_fft + norm = "ortho" if normalized else None + if onesided: + torch._check( + not complex_fft, + lambda: "Cannot have onesided output if window or input is complex", + ) + out = torch.fft.rfft(input, dim=-1, norm=norm) + else: + out = torch.fft.fft(input, dim=-1, norm=norm) + + out.transpose_(1, 2) + + if original_ndim == 1: + out = out.squeeze_(0) + + return out if return_complex_ else torch.view_as_real(out) + + +# CompositeImplicitAutograd - don't register decomp +@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def istft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex=False, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"istft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + torch._check( + utils.is_complex_dtype(input.dtype), + lambda: ( + "istft input and window must be on the same device but got self on " + + f"{input.device} and window on {window.device}" # type: ignore[union-attr] + ), + ) + n_frames = input.size(-1) + fft_size = input.size(-2) + + expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1) + torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty") + torch._check( + 2 <= input.ndim <= 3, + lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}", + ) + onesided_ = onesided if onesided is not None else fft_size != n_fft + + if onesided_: + torch._check( + n_fft // 2 + 1 == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}" + ), + ) + else: + torch._check( + n_fft == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft when onesided=False, but got {fft_size}", + ), + ) + + torch._check( + 0 < hop_length_ <= win_length_, + lambda: "istft expected 0 < hop_length <= win_length", + ) + torch._check( + 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft" + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: "Invalid window shape. window has to be 1D and length of `win_length`", + ) + + if window is None: + real_dtype = utils.corresponding_real_dtype(input.dtype) + window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device) + else: + window_ = window + + if win_length_ != n_fft: + left = (n_fft - win_length_) // 2 + window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0) + + original_ndim = input.ndim + if input.ndim == 2: + input = input.unsqueeze(0) + + input = input.transpose(1, 2) + norm = "ortho" if normalized else None + if return_complex: + torch._check( + not onesided_, + lambda: "cannot have onesided output if window or input is complex", + ) + input = torch.fft.ifft(input, dim=-1, norm=norm) + else: + torch._check( + window is None or not utils.is_complex_dtype(window.dtype), + lambda: "Complex windows are incompatible with return_complex=False", + ) + if not onesided_: + input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1) + input = torch.fft.irfft(input, dim=-1, norm=norm) + + assert input.size(2) == n_fft + + y_tmp = input * window_.view([1, 1, n_fft]) + y = aten.unfold_backward( + y_tmp, + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + window_envelop = aten.unfold_backward( + window_.pow(2).expand((1, n_frames, n_fft)), + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + + assert expected_output_signal_len == y.size(1) + assert expected_output_signal_len == window_envelop.size(1) + + start = n_fft // 2 if center else 0 + if length is not None: + end = start + length + elif center: + end = expected_output_signal_len - n_fft // 2 + else: + end = expected_output_signal_len + + length = max(0, end - start) + y = y.narrow(dim=1, start=start, length=length) + window_envelop = window_envelop.narrow(dim=1, start=start, length=length) + + window_envelop_lowest = window_envelop.abs().min().lt(1e-11) + torch._check( + not window_envelop_lowest.item(), + lambda: "window overlap add min less than 1e-11", + ) + + y = y / window_envelop + if original_ndim == 2: + y = y.squeeze(0) + + if end > expected_output_signal_len: + warnings.warn( + "The length of signal is shorter than the length parameter. Result is being " + + "padded with zeros in the tail. Please check your center and hop_length settings" + ) + y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0) + return y + + +# Get the new shape and stride after applying unfold to an input tensor +def _get_unfold_shape_stride( + a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int +): + a_ndim = len(a_shape) + dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True) + max_size = 1 if a_ndim == 0 else a_shape[dim] + last_stride = 1 if a_ndim == 0 else a_stride[dim] + + torch._check( + size <= max_size, + lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", + ) + + torch._check( + step > 0, + lambda: f"Step is {step} but must be > 0", + ) + + shape = list(a_shape) + strides = list(a_stride) + shape.append(size) + strides.append(last_stride) + if dim < a_ndim: + shape[dim] = (shape[dim] - size) // step + 1 + strides[dim] *= step + return shape, strides + + +@register_decomposition(aten.repeat) +@out_wrapper() +def repeat(a: Tensor, *repeat_shape) -> Tensor: + repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) + torch._check( + len(repeat_shape) >= len(a.shape), + lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", + ) + + if len(repeat_shape) == 0: + return torch.clone(a) + + num_new_dimensions = len(repeat_shape) - a.ndim + padded_shape = [1] * num_new_dimensions + for dim_size in a.shape: + padded_shape.append(dim_size) + + target_shape = tuple( + padded_size * repeat_size + for padded_size, repeat_size in zip(padded_shape, repeat_shape) + ) + + # return an empty tensor if one of the repeat_shape dimensions is zero + if 0 in repeat_shape: + return torch.empty( + target_shape, + dtype=a.dtype, + device=a.device, + requires_grad=a.requires_grad, + memory_format=utils.suggest_memory_format(a), + ) + + urtensor_shape = target_shape + urtensor_stride = utils.make_contiguous_strides_for(target_shape) + for dim, dim_size in enumerate(padded_shape): + # repeat each dimension by using unfold_copy operation + urtensor_shape, urtensor_stride = _get_unfold_shape_stride( + urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1) + ) + + # derive permute order by sorting urtensor strides + enumerated_stride = list(enumerate(urtensor_stride)) + enumerated_stride.sort(key=lambda item: item[1], reverse=True) + permute_order, sorted_stride = zip(*enumerated_stride) + + # add new and expand dimensions according to urtensor + repeat_xtensor = a.expand(urtensor_shape) + + # clone tensor to concretize expanded dimensions + cloned_result = torch.clone(repeat_xtensor) + + # transpose axis so strides are in sorted order + permuted_result = cloned_result.permute(permute_order) + + # reshape to get contiguous tensor with correct target shape + return permuted_result.reshape(target_shape) + + +def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Short-circuits if shape is the same + if guard_size_oblivious(sym_eq(tuple(a.shape), tuple(shape))): + return prims.view_of(a) + + # Special-cases tensors with no elements + if guard_size_oblivious(a.numel() == 0): + return as_strided(a, shape, utils.make_contiguous_strides_for(shape)) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + assert length == 1 + _a = unsqueeze(_a, -1) + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + assert length == 1 + _a = squeeze(_a, -1) + return _a + + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape + + # NOTE [Reshape Algorithm] + # This algorithm works by attempting to greedily construct the desired dimensions in + # the output shape, left to right. It does this by, conceptually, accumulating + # dimensions of the original tensor, also left to right, until the dimension + # can be constructed using prims.split_dim. + # The algorithm also has special handling for tail squeezes/unsqueezes, like + # if a reshape from (5, 5) to (5, 5, 1) or vice versa. + # + # This algorithm does not flatten the original tensor and then split dims as appropriate + # because that would create copies more often than this algorithm. flatten is the only + # operation below which can create a view or a copy, and while it prefers creating + # views it may sometimes create a copy if the tensor's strides do not permit a view. + # As a result, this algorithm tries to minimize flattening. + # + # Note that a better version of this algorithm may exist. Regions which could be + # flattened without creating a copy can be identified in advance, and that might + # allow fewer flatten calls or faster short-circuiting to make a copy. + idx = 0 + a_ = a + for length in shape: + # Handles tail unsqueezes + if idx >= a_.ndim: + assert length == 1 + last_dim = a_.ndim - 1 + # NOTE: using split_dim instead of unsqueeze may seem silly here, + # but it's necessary to get the strides correct + a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim]) + idx = idx + 1 + continue + + # Skips dimensions that are already the correct length + if guard_size_oblivious(length == a_.shape[idx]): + idx = idx + 1 + continue + + # Gathers enough original dimensions such that this new dimension can be created + # Note that this accumulation will terminate because we've verified a and the shape + # specify the same number of elements above + accum = a_.shape[idx] + end = idx + while guard_size_oblivious(accum % length != 0): + end = end + 1 + accum = accum * a_.shape[end] + if end != idx: + # NOTE: in this case multiple dimensions must be flatten to create the desired dimension + # This flattening is why reshape sometimes creates a copy -- because flattening + # may return a view of a copy + + # Checks if collapse can be a view and short-circuits to copying reshape if it can't + new_shape, new_strides = prims._collapse_view_helper(a_, idx, end) + if new_shape is None: + if allow_copy: + return prims.reshape(a, shape) + + msg = "Cannot view a tensor with shape {} and strides {} as a tensor with shape {}!".format( + a.shape, a.stride(), shape + ) + raise ValueError(msg) + + a_ = flatten(a_, idx, end) + + # Splits the (possibly flattened) dimension to create the desired dim length + if guard_size_oblivious(accum != length): + a_ = prims.split_dim(a_, idx, length) + + idx = idx + 1 + + # Squeezes tail + while idx < a_.ndim: + assert a_.shape[idx] == 1 + a_ = squeeze(a_, idx) + + return a_ + + +# CompositeImplicitAutograd - don't register decomp +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)) Function call +# torch.reshape doesn't support unpacked shapes +def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=True) + + +# CompositeImplicitAutograd - don't register decomp +def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.reshape(other.size()) + + +@register_decomposition(aten.roll) +@out_wrapper() +def roll( + a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple() +) -> TensorLikeType: + """Reference implementation of :func:`torch.roll`.""" + dims = utils.canonicalize_dims(a.ndim, dims) + # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1 + if not isinstance(shifts, Iterable): + shifts = (shifts,) + if not isinstance(dims, Iterable): + dims = (dims,) + + # Avoid modulo by zero + if a.numel() == 0: + # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors + return a.clone() + + if a.dim() == 0 and len(dims) > 0: + raise IndexError( + f"Dimension specified as {dims[0]} but tensor has no dimensions" + ) + + len_shifts = len(shifts) + len_dims = len(dims) + if len_shifts != 1 or len_dims != 1: + if len_shifts == 0: + raise RuntimeError("`shifts` required") + # Takes care of the case when dims is not specified (default) + # By default, the tensor is flattened before shifting, after which the original shape is restored + if len_dims == 0 and len_shifts == 1: + return torch.roll(torch.flatten(a), shifts, 0).view(a.shape) + if len_shifts != len_dims: + raise RuntimeError( + f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" + ) + assert len_dims > 1 + tail_shifts = shifts[1:] + tail_dims = dims[1:] + first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) + return torch.roll(first_dim_rolled, tail_shifts, tail_dims) + + # This path is taken when only one dimension is rolled + # For example to get `first_dim_rolled` above + dim = dims[0] + size = a.shape[dim] + start = (size - shifts[0]) % size + idx = torch.arange(size, device=a.device) + return a.index_select(dim, torch.fmod(start + idx, size)) + + +@register_decomposition(aten.rot90) +@out_wrapper() +def rot90( + a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1) +) -> TensorLikeType: + """Reference implementation of :func:`torch.rot90`.""" + if len(dims) != 2: + raise RuntimeError( + f"expected total rotation dims == 2, but got dims = {len(dims)}" + ) + if a.ndim < 2: + raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}") + + # Do this after the initial checks to be compatible with the behavior in + # core. + dims = utils.canonicalize_dims(a.ndim, dims) + + if dims[0] == dims[1]: + raise RuntimeError( + f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}" + ) + k = k % 4 # Rotation direction is from the second towards the first axis for k < 0 + if k == 1: + return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1]) + elif k == 2: + return torch.flip(a, dims) + elif k == 3: + return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) + else: + return clone(a, memory_format=torch.contiguous_format) + + +def _check_stack_inputs(tensors: TensorSequenceType) -> None: + entry_shape = tensors[0].shape + for i in range(1, len(tensors)): + assert tensors[i].shape == entry_shape, ( + f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0" + f"and {tensors[i].shape} at entry {i}" + ) + + +@register_decomposition(aten.stack) +@out_wrapper() +def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + assert len(tensors) > 0, "stack expects a non-empty TensorList" + wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) + # Refs need sparse support to check other condition + if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse: + _check_stack_inputs(tensors) + result_sizes = list(tensors[0].shape) + result_sizes.insert(wrapped_dim, len(tensors)) + out = torch.cat(tensors, wrapped_dim) + return out.view(result_sizes) + + # If dim == tensors[0].ndim, view cannot efficiently handle it + return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + if a.numel() == 0: + a_exp = exp(a_) + else: + a_max = amax(a_, dim, keepdim=True) + a_exp = exp(a_ - a_max) + return _maybe_convert_to_dtype( + true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype + ) # type: ignore[return-value] + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def hstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") + aligned_tensors = atleast_1d(*tensors) + if aligned_tensors[0].ndim == 1: + return cat(aligned_tensors, 0) + return cat(aligned_tensors, 1) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def vstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") + aligned_tensors = atleast_2d(*tensors) + return cat(aligned_tensors, 0) + + +# CompositeImplicitAutograd - don't register decomp +def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: + dim = utils.canonicalize_dim(a.ndim, dim) + torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") + return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) + + +@register_decomposition(aten.unbind) +def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: + dim = utils.canonicalize_dim(t.ndim, dim) + torch._check_index( + len(t.shape) > 0, + lambda: "Dimension specified as 0 but tensor has no dimensions", + ) + if t.shape[dim] == 0: + return tuple() + else: + return tuple( + torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) + ) + + +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return x.clone(memory_format=torch.contiguous_format).index_copy_( + dim, index, tensor + ) + + +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + y = x.unsqueeze(0) if x.ndim == 0 else x + idx = (slice(None),) * dim + (index,) + y[idx] = tensor + return x + + +@register_decomposition(aten.index_fill) +@out_wrapper() +def index_fill( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=False) + + +@register_decomposition(aten.index_fill_) +def index_fill_( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=True) + + +def _index_fill( + x: TensorLike, + dim: int, + index: TensorLike, + value: Union[NumberType, TensorLike], + *, + inplace: bool, +): + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if isinstance(value, TensorLike): + torch._check( + value.ndim == 0, + lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] + f"Got a tensor with {value.ndim} dimensions.", + ) # type: ignore[arg-type] + else: + value = torch.scalar_tensor( + value, dtype=x.dtype, layout=x.layout, device=x.device # type: ignore[arg-type] + ) + + # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them + zero_dim = x.ndim == 0 + y = x.unsqueeze(0) if zero_dim else x + # index_copy does not broadcast on value so we have to do it manually + shape = list(y.shape) + shape[dim] = index.numel() + value = value.expand(shape) + index_copy = Tensor.index_copy_ if inplace else torch.index_copy + out = index_copy(y, dim, index, value) # type: ignore[operator] + if inplace: + return x + else: + if zero_dim: + # The clone is necessary so that it returns a fresh tensor rather than a view + out = out.squeeze(0).clone() + # index_fill preserves the strides. index_copy always returns contiguous tensors + if out.stride() != x.stride(): + new_out = torch.empty_like(x) + new_out.copy_(out) + out = new_out + return out + + +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + # index_add always returns a new contiguous tensor + return x.clone(memory_format=torch.contiguous_format).index_add_( + dim, index, tensor, alpha=alpha # type: ignore[arg-type] + ) + + +@register_decomposition(aten.index_select) +@out_wrapper() +def index_select(x: TensorLike, dim: int, index: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if index.ndim == 0: + index = index.unsqueeze(0) + if x.ndim == 0: + # Treat scalars as elements of \R^1 + # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction + return torch.empty_like(x).index_copy(0, index, x.expand_as(index)) + + idx = (slice(None),) * dim + (index,) + return x[idx] + + +@register_decomposition(aten.squeeze.dims) +def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if dim is None: + dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) + return prims.squeeze(a, dims) if dims else prims.view_of(a) + + ndim = a.ndim + dim = utils.canonicalize_dims(ndim, dim) + dims = (dim,) if isinstance(dim, Dim) else dim + # Short-circuits if the tensor has no dimensions + if ndim == 0: + assert len(dims) == 0 or dims == (0,) + return prims.view_of(a) + + # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 + dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1)) + if len(dims) == 0: + return prims.view_of(a) + if len(dims) == 1: + return prims.squeeze(a, dims) + dims_list = list(dims) + dims_list = sorted(dims_list, reverse=True) + for i in dims_list: + a = squeeze(a, i) + return a + + +# Note: does not work with TensorMetas because of data-dependent control-flow +# CompositeImplicitAutograd - don't register decomp +def tensor_split( + a: TensorLikeType, + indices_or_sections: Union[Tensor, DimsType], + dim: int = 0, +) -> Tuple[TensorLikeType, ...]: + _dim = utils.canonicalize_dim(a.ndim, dim) + if a.ndim == 0: + msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!" + raise ValueError(msg) + + # If indices_or_sections is a tensor, it must be a CPU Long tensor + if isinstance(indices_or_sections, TensorLike): + if not indices_or_sections.device.type == "cpu": + msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {}".format( + indices_or_sections.device + ) + raise ValueError(msg) + if indices_or_sections.dtype != torch.long: + msg = "tensor_split: if indices_or_sections is a tensor it must have long dtype, " + f" but received one with dtype {indices_or_sections.dtype}" + raise ValueError(msg) + + # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length + if isinstance(indices_or_sections, IntLike) or ( + isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0 + ): + sections: int = ( + indices_or_sections # type: ignore[assignment] + if isinstance(indices_or_sections, Number) + else indices_or_sections.item() + ) + + if sections <= 0: + msg = f"tensor_split: number of sections must be greater than 0, but was {sections}" + raise ValueError(msg) + + splits = [] + dim_size = a.shape[_dim] + min_split_size = math.floor(dim_size / sections) + num_splits_one_extra = dim_size % sections + start_idx = 0 + for split_idx in range(sections): + split_size = ( + min_split_size + 1 + if (split_idx < num_splits_one_extra) + else min_split_size + ) + s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim) + splits.append(s) + start_idx = start_idx + split_size + + return tuple(splits) + # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits + else: + indices = indices_or_sections + if isinstance(indices_or_sections, TensorLike): + if indices_or_sections.ndim != 1: + msg = "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, " + f"but received a tensor with {indices_or_sections.ndim} dimensions" + raise ValueError(msg) + + indices = indices_or_sections.tolist() + + splits = [] + start_idx = 0 + for x in indices: + splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim)) + start_idx = x + splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim)) + return tuple(splits) + + +# CompositeImplicitAutograd - don't register decomp +def hsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> Tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 1, + lambda: ( + "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + dim = 0 if a.ndim == 1 else 1 + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[dim] % split_size == 0), + lambda: ( + "torch.hsplit attempted to split along dimension " + + str(dim) + + ", but the size of the dimension " + + str(a.shape[dim]) + + " is not divisible by the split_size " + + str(split_size) + + "!" + ), + ) + return tensor_split(a, split_size, dim) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "hsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, dim) + + +# CompositeImplicitAutograd - don't register decomp +def vsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> Tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 2, + lambda: ( + "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[0] % split_size == 0), + lambda: ( + f"torch.vsplit attempted to split along dimension 0" + f", but the size of the dimension " + f"{a.shape[0]}" + f" is not divisible by the split_size " + f"{split_size}" + f"!" + ), + ) + return tensor_split(a, split_size, 0) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "vsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, 0) + + +@register_decomposition(aten.diag.out) +@out_wrapper() +def diag( + self: TensorLikeType, + offset: int = 0, +) -> TensorLikeType: + ndim = self.dim() + torch._check( + ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" + ) + if ndim == 1: + return torch.diag_embed(self, offset) + else: + return torch.diagonal_copy(self, offset) + + +@register_decomposition(aten.diagonal_scatter) +@out_wrapper() +def diagonal_scatter( + input: TensorLikeType, + src: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + out = utils.clone_preserve_strides(input) + diag = out.diagonal(offset, dim1, dim2) + torch._check( + diag.shape == src.shape, + lambda: "expected src to have a size equal to the diagonal of the input." + f"Got {src.shape} for a diagonal of shape {diag.shape}", + ) + copy_to(diag, src) + return out + + +@register_decomposition(aten.diagonal) +def diagonal( + self: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.diagonal + """ + num_dims = self.dim() + dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + storage_offset = self.storage_offset() + + if offset >= 0: + diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0) + else: + diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0) + + if diag_size > 0: + if offset >= 0: + storage_offset += offset * self.stride()[dim2] + else: + storage_offset -= offset * self.stride()[dim1] + + sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)] + sizes.append(diag_size) + + strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)] + strides.append(self.stride()[dim1] + self.stride()[dim2]) + + result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset) + + return result + + +diagonal_copy = _make_copy_from_view(diagonal) + + +@register_decomposition(aten.diag_embed) +@out_wrapper() +def diag_embed( + t: TensorLikeType, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + """ + Reference implementation of torch.diag_embed + """ + # convert from negative dims + rank = t.ndim + 1 + dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) + dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) + + # as per the docs, exchanging dims is equivalent to changing the sign of + # offset + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + # as per the docs, the size of last dim is placed at dim1 and dim2 + last_dim = t.size(-1) + + if offset != 0: + # add padding to match the new size + t_shape = list(t.shape) + t_shape[-1] = builtins.abs(offset) + z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False) + pair = (z, t) if offset > 0 else (t, z) + t = torch.cat(pair, dim=-1) + # make sure the diagonal always has the same size + last_dim += builtins.abs(offset) + + # preserve original data, but place 1 at dim1 and move last dim to dim2 + t = t.unsqueeze(dim1).movedim(-1, dim2) + + # generate ranges shifting indices based on offset + a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64) + b_range = torch.arange( + offset, last_dim + offset, device=t.device, dtype=torch.int64 + ) + + # broadcast + cond = a_range == b_range.unsqueeze(-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] + cond = cond.reshape(cond_shape) + + # aten.diag_embed always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(cond, t).contiguous() + + +@register_decomposition(aten.block_diag) +@out_wrapper() +def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType: + """ + Reference implementation of torch.block_diag + """ + tensors_2d = [ + tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors + ] + + ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d) + device = tensors_2d[0].device + + result = [] + + col_start = 0 + for i, tensor in enumerate(tensors_2d): + torch._check( + tensor.dim() == 2, + lambda: "Input tensors must have 2 or fewer dimensions. " + f"Input {i} has {tensor.dim()} dimensions", + ) + torch._check( + tensor.device == device, + lambda: "Input tensors must all be on the same device. " + f"Input 0 is on device {device} and input {i} is on device {tensor.device}.", + ) + row, col = tensor.shape + left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype) + right = torch.zeros( + (row, ncols - col_start - col), device=device, dtype=tensor.dtype + ) + result += [torch.cat((left, tensor, right), dim=1)] + col_start += col + + return torch.cat(result, dim=0) + + +def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType: + """ + This is used as an input to PythonRefInfo. `torch.block_diag` + expects arguments splatted, but `aten.block_diag` expects only + one argument that is a list of Tensors. + """ + return _block_diag_iterable(tensors) + + +# CompositeImplicitAutograd - don't register decomp +def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: + if a.ndim < 3: + raise RuntimeError( + f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" + ) + if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): + raise RuntimeError( + "torch.dsplit attempted to split along dimension 2, " + + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" + ) + return tensor_split(a, sections, 2) + + +@register_decomposition(aten.t.default) +def t(a: TensorLikeType): + # TODO: Add sparse support + # if a.is_sparse: + # sparse_dim = a.sparse_dim() + # dense_dim = a.dense_dim() + # if not (sparse_dim <= 2 and dense_dim == 0): + # raise RuntimeError( + # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and" + # f"{dense_dim} dense dimensions" + # ) + if a.ndim > 2: + raise RuntimeError( + f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D" + ) + return torch.transpose(a, 0, 0 if a.ndim < 2 else 1) + + +# CompositeImplicitAutograd - don't register decomp +def T(a: TensorLikeType) -> TensorLikeType: + # n != 2 && n != 0 is deprecated in regular PyTorch. + torch._check( + a.ndim in (0, 2), + lambda: ( + "The use of `x.T` on tensors of dimension other than 0 or 2 " + "to reverse their shape is not supported." + ), + ) + return a.t() + + +@register_decomposition(aten.alias) +def alias(a: TensorLikeType) -> TensorLikeType: + return prims.view_of(a) + + +@register_decomposition(aten.transpose) +def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: + _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] + + if a.ndim <= 1 or dim0 == dim1: + return aten.alias.default(a) + + _permutation = list(range(0, a.ndim)) + _permutation[_dim0] = _dim1 + _permutation[_dim1] = _dim0 + return torch.permute(a, _permutation) + + +# Aliases for transpose +swap_axes = transpose + + +@register_decomposition(aten.unfold) +def unfold( + self: TensorLikeType, dimension: int, size: int, step: int +) -> TensorLikeType: + shape, strides = _get_unfold_shape_stride( + self.shape, self.stride(), dimension, size, step + ) + return self.as_strided(shape, strides) + + +@register_decomposition(aten.unfold_copy) +@out_wrapper() +def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): + return self.unfold(dimension, size, step).clone( + memory_format=torch.contiguous_format + ) + + +def _cumsumprod_common( + func, + init, + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # We implement all the kwargs of a reduction. ATen just handles dtype + # nb. This decomposition may not be as efficient as a backend-specific implementation + ndim = a.ndim + dim = utils.canonicalize_dim(ndim, dim) + if ndim == 0: + return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out) + a = a.unsqueeze(dim + 1) + rg = torch.arange(a.shape[dim], device=a.device) + mask = rg.unsqueeze(1) <= rg + for _ in range(ndim - dim - 1): + mask = mask.unsqueeze(-1) + masked_a = torch.where(mask, a, init) + return func(masked_a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumsum) +def cumsum( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumprod) +def cumprod( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out) + + +# Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(aten.unsqueeze) +def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: + # Note that unsqueeze canonicalizes with rank + 1 because it allows + # a new innermost dimension to be specified + ndim = a.ndim + 1 + dim = utils.canonicalize_dim(ndim, dim) + return prims.expand_dims(a, (dim,), ndim=ndim) + + +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view +# doesn't support unpacked shapes +# TODO: Turn this into a decomposition (currently fails on reshape meta tests) +@register_decomposition(aten.view.default) +def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=False) + + +# CompositeImplicitAutograd - don't register decomp +def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.view(other.size()) + + +# CompositeImplicitAutograd - don't register decomp +def ravel(a: TensorLikeType) -> TensorLikeType: + return reshape(a, (-1,)) + + +# CompositeImplicitAutograd - don't register decomp +# missing ref impl. for aten.gather +@out_wrapper() +def take_along_dim( + a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + torch._check( + a.ndim == indices.ndim, + lambda: ( + "torch.take_along_dim(): input and indices should have the same " + f"number of dimensions, but got {a.ndim} dimensions for input, and " + f"{indices.ndim} dimensions for indices" + ), + ) + + torch._check( + utils.is_integer_dtype(indices.dtype), + lambda: ( + "torch.take_along_dim(): dtype of indices should be int but got " + f"{indices.dtype} instead" + ), + ) + + if dim is None: + return torch.gather(a.view(-1), 0, indices.view(-1)) + else: + self_sizes = list(a.shape) + self_sizes[dim] = indices.size(dim) + broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size()) + indices_broadcast = broadcast_to(indices, broadcast_shape) + + indices_sizes = list(indices.shape) + indices_sizes[dim] = a.size(dim) + broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) + self_broadcast = broadcast_to(a, broadcast_shape) + + return torch.gather(self_broadcast, dim, indices_broadcast) + + +@out_wrapper() +def empty( + *shape, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLikeType: + torch._check( + memory_format != torch.preserve_format, + lambda: "torch.empty: the Preserve memory format is not supported", + ) + + shape = utils.extract_shape_from_varargs(shape) + + if memory_format == torch.contiguous_format: + strides = utils.make_contiguous_strides_for(shape) + elif memory_format == torch.channels_last_3d: + strides = utils.make_channels_last_3d_strides_for(shape) + else: # memory_format == torch.channels_last + torch._check( + memory_format == torch.channels_last, + lambda: f"torch.empty: received an unknown memory format {memory_format}!", + ) + strides = utils.make_channels_last_2d_strides_for(shape) + + return torch.empty_strided( + shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@out_wrapper() +def empty_permuted( + shape, + physical_layout, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + return prims.empty_permuted( + shape, + physical_layout, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_empty) +@out_wrapper() +def new_empty( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.new_empty_strided) +@out_wrapper() +def new_empty_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.Tensor.new_empty_strided + """ + + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty_strided( + size, + stride, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.zeros.default) +@out_wrapper() +def zeros( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + False if dtype == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_zeros) +@out_wrapper() +def new_zeros( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.ones.default) +@out_wrapper() +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + True if dtype == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_ones) +@out_wrapper() +def new_ones( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_full) +@out_wrapper() +def new_full( + a: TensorLikeType, + size: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +@register_decomposition(aten.empty_like) +@out_wrapper() +def empty_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + if memory_format != torch.preserve_format: + return torch.empty( + a.shape, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + logical_to_physical_perm = ( + utils.compute_elementwise_output_logical_to_physical_perm(a) + ) + # identity perm is [2, 1, 0] + return torch.empty_permuted( + a.shape, + logical_to_physical_perm, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition([aten.arange.start_step, aten.arange.start_out]) +@out_wrapper() +def arange( + start: NumberType = 0, + end: Optional[NumberType] = None, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + device = torch.device(utils.device_or_default(device)) + + assert not isinstance(start, complex) + assert not isinstance(end, complex) + assert not isinstance(step, complex) + + # Case: torch.arange(5) + if end is None: + end = start + start = 0 + torch._check(step != 0, lambda: "step must be nonzero") + if step > 0: + torch._check( + end >= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + elif step < 0: + torch._check( + end <= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + + def is_finite(x): + return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) + + torch._check( + is_finite(start) and is_finite(end), + lambda: f"unsupported range: {start} -> {end}", + ) + torch._check( + is_finite(step), + lambda: f"step must be finite but got {step}", + ) + + if dtype is None: + args = (start, end, step) + integer_args = builtins.all(isinstance(arg, IntLike) for arg in args) + dtype = torch.int64 if integer_args else torch.get_default_dtype() + + is_integer = utils.is_integer_dtype(dtype) + if is_integer: + xstart = sym_int(start) + xend = sym_int(end) + xstep = sym_int(step) + + # For int64 we truncate arguments to int before calculating length, but + # other integral dtypes we don't. Weird... but needed to match ATen shapes. + if dtype == torch.int64: + # Uses floordiv to avoid ceil in inductor. + sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined] + length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined] + else: + length = math.ceil((end - start) / step) + + if is_integer: + return prims.iota( + length, + start=xstart, # type: ignore[possibly-undefined] + step=xstep, # type: ignore[possibly-undefined] + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + computation_dtype = utils.get_acc_type(dtype, device) + index = prims.iota( + length, + start=0, + step=1, + dtype=torch.int64, + device=device, + requires_grad=False, + ) + index = _maybe_convert_to_dtype(index, computation_dtype) + result = start + step * index + result = _maybe_convert_to_dtype(result, dtype) + + if requires_grad: + result.requires_grad_(True) + return result + + +@register_decomposition(aten.lerp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("start", "end", "weight"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): + inputs = [start, end] + if isinstance(weight, Number): + weight = start.new_full((), weight) # type: ignore[arg-type] + else: + inputs.append(weight) + assert isinstance(weight, Tensor) # mypy + # We implement it this way for numerical stability. We assume (in the stability optimisation) + # that 0 <= weight <= 1. We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: + # If weight.abs() >= 0.5: + # return (1 - weight) * (start - end) + end + mask = weight.abs() >= 0.5 + coeff = torch.where(mask, weight - 1, weight) + base = torch.where(mask, end, start) + output = coeff * (end - start) + base + # make sure the decomposition output's stride is same as non-decomposition path. + stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) + if output.stride() != stride: + output = prims.copy_strided(output, stride) + + return handle_noncontiguous_outputs(inputs, output) + + +@register_decomposition(aten.linspace) +@out_wrapper() +def linspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, torch.float64) + if isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, torch.float64) + + if py_any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + factory_kwargs = { + "layout": layout, + "device": device, + "pin_memory": pin_memory, + "requires_grad": requires_grad, + } + if steps == 0: + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + if isinstance(start, TensorLikeType): + return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start) # type: ignore[arg-type] + else: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes + rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type] + + # Small types need to be computed in higher precision as this is, at heart, an associative scan + dtype_red = ( + torch.int64 + if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)) + else dtype + ) + computation_dtype, _ = utils.reduction_dtypes( + rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red + ) + cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype) + + # We implement torch.lerp without performing rg / (steps - 1) explicitly + # With this we get out[0] == start, out[-1] == end + step = (end - start) / (steps - 1) + out = torch.where( + rg < steps / 2, + start + step * cast_rg(rg), # type: ignore[arg-type,operator] + end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator] + ) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logspace) +@out_wrapper() +def logspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + base: NumberType = 10, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy doesn't have this cast + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, FloatLike): + start = sym_int(start) + elif isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, dtype) + if isinstance(end, FloatLike): + end = sym_int(end) + elif isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, dtype) + + if py_any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + dtype = default_complex_dtype + _dtype = None # torch.linspace will update the correct dtype + else: + _dtype = torch.float64 + + assert not isinstance(base, complex) # for mypy + if base < 0: + raise NotImplementedError + ret = torch.linspace( # type: ignore[misc] + start, # type: ignore[arg-type] + end, # type: ignore[arg-type] + steps, # type: ignore[arg-type] + dtype=_dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value] + + +@overload +def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): + pass + + +@overload +def meshgrid(*tensors: TensorLikeType, indexing: str): + pass + + +@register_decomposition(aten.meshgrid) +def meshgrid( + *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]], + indexing: str, +) -> List[TensorLikeType]: + # This ref simultaneously handles two overloads (see stubs above) + # The `indexing` argument is currently optional for torch.meshgrid, but we + # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 + if isinstance(tensors[0], (list, tuple)): + assert len(tensors) == 1 + tensors = tuple(tensors[0]) + + torch._check( + py_all(isinstance(a, TensorLike) for a in tensors), + lambda: "meshgrid expects its inputs to be tensors", + ) + + torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") + + for i in range(len(tensors) - 1): + torch._check( + tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same dtype", + ) + torch._check( + tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same device", + ) + + swap_first_and_second_tensors = False + if indexing == "xy": + swap_first_and_second_tensors = len(tensors) >= 2 + if swap_first_and_second_tensors: + tensors = (tensors[1], tensors[0], *tensors[2:]) + else: + torch._check( + indexing == "ij", + lambda: ( + 'torch.meshgrid: indexing must be one of "xy" or "ij", ' + f"but received: {indexing}" + ), + ) + + result_shape: List[int] = [] + for t in tensors: + assert isinstance(t, TensorLike) # mypy + torch._check( + t.ndim == 0 or t.ndim == 1, + lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", + ) + result_shape.append(t.numel()) + + grids: List[TensorLikeType] = [] + for i, t in enumerate(tensors): + assert isinstance(t, TensorLike) # mypy + if t.ndim == 0: + t = t.view((1,)) + grids.append(prims.broadcast_in_dim(t, result_shape, (i,))) + + if swap_first_and_second_tensors: + # Swap outputs if we originally swapped at the beginning + grids[0], grids[1] = grids[1], grids[0] + + return grids + + +# CompositeImplicitAutograd - don't register decomp +def movedim( + input: TensorLikeType, + source: Union[int, DimsSequenceType], + destination: Union[int, DimsSequenceType], +) -> TensorLikeType: + """ + Reference implementation of torch.movedim + """ + if type(source) is int: + source = (source,) + if type(destination) is int: + destination = (destination,) + + # Converts to list to produce a compatible error message with core PyTorch, + # which prints sequences in square brackets. + torch._check( + len(source) == len(destination), # type: ignore[arg-type] + lambda: ( + "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] + f"({list(source)} dims) should contain the same number " # type: ignore[arg-type] + f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type] + ), + ) + + rank = input.ndim + ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type] + ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type] + + sss = set(ss) + dss = set(ds) + + # See above on why this converts to list in error messages. + torch._check( + len(ss) == len(sss), + lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] + ) + torch._check( + len(ds) == len(dss), + lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] + ) + + m = dict(zip(ds, ss)) + dims = [] + si = 0 # source index + for di in range(rank): + # check if the destination index is in the mapping + s = m.get(di) + if s is not None: + # insert source index if found + dims.append(s) + else: + # insert source index sequentially, skipping indices from the mapping + while si in sss: + si += 1 + dims.append(si) + si += 1 + + result = torch.permute(input, tuple(dims)) + + return result + + +# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints +@register_decomposition(aten.empty_strided) +@out_wrapper() +def empty_strided( + shape: Union[ShapeType, Tuple[ShapeType]], + strides: StrideType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # Layout == strided, pin_memory is False + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + shape = utils.extract_shape_from_varargs(shape) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + return prims.empty_strided( + shape, + strides, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.eye) +@out_wrapper() +def eye( + n: int, + m: Optional[int] = None, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, # TODO: unused +) -> TensorLikeType: + """ + Reference implementation of torch.eye + """ + if m is None: + m = n + + torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") + torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") + + range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) + range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) + + cond = range_n.unsqueeze(-1) == range_m + if dtype is torch.bool: + return cond + else: + one = torch.ones( + (1,), + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=False, + ) + return torch.where(cond, one, 0) + # TODO: Use requires_grad. All refs taking the requires_grad kwarg must + # return a leaf tensor. + # result.requires_grad_(requires_grad) + + +@register_decomposition([aten.full.default, aten.full.out]) +@out_wrapper() +def full( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + + e = empty( + shape, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return torch.fill(e, fill_value) # type: ignore[arg-type] + + +def full_like( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + e = torch.empty_like( + a, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + return fill(e, fill_value) + + +@register_decomposition(aten.zeros_like) +@out_wrapper() +def zeros_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.ones_like) +@out_wrapper() +def ones_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.randn.default) +@out_wrapper() +def randn( + *shape, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_pin_memory(pin_memory) + + shape_ = utils.extract_shape_from_varargs(shape) + + dtype = utils.dtype_or_default(dtype) + device = utils.device_or_default(device) + + return prims.normal( + shape_, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def scalar_tensor( + a: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) + device = device if device is not None else torch.device("cpu") + return prims.scalar_tensor(a, dtype=dtype, device=device) + + +# +# Randomness References +# + + +def _uniform_helper( + shape: ShapeType, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + *, + dtype: torch.dtype, + device: DeviceLikeType, +) -> TensorLikeType: + utils.validate_shape(shape) + + assert isinstance(low, Number) + assert isinstance(high, Number) + low = sym_float(low) + high = sym_float(high) + + assert isinstance(dtype, torch.dtype) + device = utils.canonicalize_device(device) + + return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) + + +@register_decomposition(aten.masked_fill) +@out_wrapper() +def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType): + python_type = utils.dtype_to_type(a.dtype) + if isinstance(value, Number): + value_type = type(value) + else: + # NOTE: Could not use value = item(value) as it resulted in + # RuntimeError: Cannot cast FakeTensor(cpu) to number + value_ndim = value.ndim + torch._check( + value_ndim == 0, + lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", + ) + # `masked_fill` allows cpu scalar to be moved to cuda and xpu but not otherwise. + is_cpu_scalar = a.device.type in ["cuda", "xpu"] and value.device.type == "cpu" + torch._check( + is_cpu_scalar or value.device == a.device, + lambda: "Expected `value` to be on same device as `a`", + ) + value_type = utils.dtype_to_type(value.dtype) + + if value_type is complex: + # only downcasting from complex to lower type is not allowed. + # We allow casting `value` to lower type for other case + # Eg. float -> int. + # Ref: https://github.com/pytorch/pytorch/issues/79195 + torch._check( + utils.is_weakly_lesser_type(value_type, python_type), + lambda: f"could not convert to type {python_type} without overflow", + ) + + # Since `where` allows type-promotion, + # cast value to correct type before passing to `where` + value = _maybe_convert_to_dtype(value, a.dtype) + r = torch.where(mask, value, a) # type: ignore[arg-type] + + # aten.mask_fill always return a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return r.contiguous() + + +@register_decomposition(aten.masked_fill_) +def masked_fill_( + a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType +) -> TensorLikeType: + b = torch.masked_fill(a, mask, value) # type: ignore[arg-type] + a.copy_(b) + return a + + +# CompositeImplicitAutograd - don't register decomp +def allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + return bool( + torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + ) + + +def equal(a: TensorLikeType, b: TensorLikeType) -> bool: + utils.check_same_device(a, b, allow_cpu_scalar_tensors=False) + utils.check_same_dtype(a, b) + + # Shape check + if a.ndim != b.ndim: + return False + + for x, y in zip(a.shape, b.shape): + if x != y: + return False + + # Short-circuits if there are no elements to validate + if a.numel() == 0: + return True + + return item(all(eq(a, b))) # type: ignore[return-value] + + +@register_decomposition(aten.norm) +@out_wrapper(exact_dtype=True) +def norm( + input: TensorLikeType, + p: Optional[Union[float, str]] = "fro", + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # In these cases we compute the "Frobenius norm" + if ( + p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2) + ) or p is None: + p = 2 + if isinstance(dim, Dim): + dim = [dim] + if isinstance(p, str): + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if dim is None: + dim = tuple(range(input.ndim)) + return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype) + + +@register_decomposition(aten.trace) +@out_wrapper() +def trace(self: TensorLikeType) -> TensorLikeType: + torch._check( + self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" + ) + return torch.sum(torch.diag(self, 0)) + + +def _make_r_binary_op(base_op): + def rop( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + ) -> TensorLikeType: + return base_op(b, a) + + return rop + + +rtruediv = _make_r_binary_op(true_divide) +rfloordiv = _make_r_binary_op(floor_divide) +rpow = _make_r_binary_op(pow) + + +@register_decomposition(aten.triu) +@out_wrapper() +def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) >= diagonal + + # aten.triu always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +@register_decomposition(aten.tril) +@out_wrapper() +def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) <= diagonal + + # aten.tril always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h +# The components of the matrix that belong to the lower triangle with offset +# form a pentagon that can be broken down into a top trapezoid and a bottom +# rectangle. For the implementation of tril_indices, we need the sizes of +# both of these, as well as the length of the top side of the trapezoid. +def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) + m_last_row = max(0, min(col, row + offset)) + n_row_all = max(0, min(row, row + offset)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size = max(0, diff_row * col) + + return trapezoid_size, rectangle_size, m_first_row + + +def _trilu_checks( + name: str, + row: int, + col: int, + dtype: torch.dtype, + layout: torch.layout, + pin_memory: bool, +): + torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") + torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") + torch._check( + dtype in (torch.int32, torch.int64), + lambda: f"\"{name}\" not implemented for '{dtype}'", + ) + + +# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu +@register_decomposition(aten.tril_indices) +@out_wrapper() +def tril_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset) + row_offset = max(0, -offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # first we do the indices for top trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = m_first_row - 0.5 + row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1)) + col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + # then bottom rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset) + col_inds2 = xs2 % col + + return torch.stack( + (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2))) + ) + + +# Similar to _get_tril_sizes above, but here there is a top trapezoid and +# a bottom rectangle instead. Note that you can't reduce this to +# _get_tril_sizes(col, row, -offset) because that would correspond to +# decomposing into a left trapezoid and right rectangle. +def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = max(0, col - offset) if offset > 0 else col + + # Number of elements in top rectangle + rectangle_size = max(0, min(row, -offset) * col) + + # Number of elements in bottom trapezoid + trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1) + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + trapezoid_size = triu_size - rectangle_size + + return trapezoid_size, rectangle_size, m_first_row + + +@register_decomposition(aten.triu_indices) +@out_wrapper() +def triu_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset) + col_offset = max(0, offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # indices for top rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + col_inds2 = xs2 % col + + # bottom trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = -0.5 - m_first_row + row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1)) + col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + if col: + row_inds1 = row_inds1 + (rectangle_size // col) + col_inds1 = col_inds1 + col_offset + + return torch.stack( + (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1))) + ) + + +@register_decomposition(aten.bucketize) +@out_wrapper(exact_dtype=True) +def bucketize( + a: TensorLikeType, + boundaries: TensorLikeType, + *, + out_int32: bool = False, + right: bool = False, +): + torch._check( + boundaries.dim() == 1, + lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", + ) + + out_dtype = torch.int32 if out_int32 else torch.int64 + n_boundaries = boundaries.shape[-1] + if n_boundaries == 0: + return torch.zeros_like(a) + # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) + # each element of `a` belongs to. We use binary search to achieve logarithimic complexity, + # but each step of the search is done "in parallel" over all elements of `a` + # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end + start = torch.zeros(a.shape, device=a.device, dtype=torch.int64) + end = start + n_boundaries + # Max depth of the binary search + # Since we can't break out of the loop at different points for different elements of a, + # we just do the max amount of iterations that binary search requires and add condition + # tensor (cond_update below) to stop updating once the search terminates + + # For first iteration through loop we can skip some checks, we have separate implementation + mid = start + (end - start) // 2 + mid_val = boundaries[mid] + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where(cond_mid, start, mid + 1) + + if n_boundaries > 1: + cond_update = torch.ones_like(a, dtype=torch.bool) + niters = int(math.log2(n_boundaries)) + for _ in range(niters): + end = torch.where(cond_mid & cond_update, mid, end) + cond_update = start < end + # start might end up pointing to 1 past the end, we guard against that + mid = torch.where(cond_update, start + (end - start) // 2, 0) + mid_val = boundaries[mid] + # If right is true, the buckets are closed on the *left* + # (i.e., we are doing the equivalent of std::upper_bound in C++) + # Otherwise they are closed on the right (std::lower_bound) + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where((~cond_mid) & cond_update, mid + 1, start) + + return start.to(dtype=out_dtype) + + +@register_decomposition(aten.cauchy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def cauchy(self, median=0, sigma=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Cauchy distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + sigma > 0.0, + lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", + ) + return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5)) + + +@register_decomposition(aten.exponential) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def exponential(self, rate=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Exponential distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + rate > 0.0, + lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", + ) + return -1 / rate * torch.log1p(-torch.rand_like(self)) + + +@register_decomposition(aten.geometric) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def geometric(self, p, generator=None): + assert generator is None + # TODO: fix inductor rand_like for integer, bool dtypes + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"geometric not implemented for {self.dtype}", + ) + torch._check( + 0 < p and p < 1, + lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", + ) + return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1 + + +@register_decomposition(aten.log_normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def log_normal(self, mean=1, std=2, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"log_normal not implemented for {self.dtype}", + ) + torch._check( + 0 < std, + lambda: f"log_normal_ expects std > 0.0, but found std={std}", + ) + return torch.exp(std * torch.randn_like(self) + mean) + + +# TODO: add support for functionalization aten.normal_functional +# NOTE: the device and dtype will be ignored when shape is None +@register_decomposition(aten.normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=( + "mean", + "std", + ), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def normal( + mean=0, + std=1, + size=None, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + assert layout is None or layout == torch.strided + + if not isinstance(std, TensorLike): + torch._check( + std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}" + ) + + if size is None: + tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike)) + torch._check( + len(tensors) > 0, + lambda: "normal expects that either mean or std is a tensor, or size is defined", + ) + torch._check( + layout is None and pin_memory is None, + lambda: "Cannot pass layout, or pin_memory without size", + ) + + size = _broadcast_shapes(*(t.shape for t in tensors)) + dtype = tensors[0].dtype + device = tensors[0].device + else: + torch._check( + not isinstance(mean, TensorLike) and not isinstance(std, TensorLike), + lambda: "normal expects mean and std to be scalars when size is defined", + ) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + normal_samples = prims.normal( + size, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=False, + generator=generator, + ) + return std * normal_samples + mean + + +@register_decomposition(aten.normal_) +def normal_(self, mean=0, std=1, *, generator=None): + return normal(mean, std, self.shape, out=self, generator=generator) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rad2deg(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "rad2deg is not supported for complex tensors.", + ) + M_180_PI = 57.295779513082320876798154814105170332405472466564 + return self * M_180_PI + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def deg2rad(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "deg2rad is not supported for complex tensors.", + ) + M_PI_180 = 0.017453292519943295769236907684886127134428718885417 + return self * M_PI_180 + + +@register_decomposition(aten.count_nonzero) +@out_wrapper() +def count_nonzero(self, dim: Optional[DimsType] = None): + return (self != 0).sum(dim) + + +def _dot_check(self, other): + torch._check( + self.dim() == 1 and other.dim() == 1, + lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", + ) + + def numel_error(): + return ( + f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the" + f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" + ) + + torch._check(self.numel() == other.numel(), numel_error) + + +@register_decomposition(aten.dot) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def dot(self, other): + if self.is_complex(): + if self.is_conj(): + if other.is_conj(): + return torch.dot(self.conj(), other.conj()).conj() + else: + return torch.vdot(self.conj(), other) + elif other.is_conj(): + return torch.vdot(other.conj(), self) + + _dot_check(self, other) + return (self * other).sum() + + +@register_decomposition(aten.vdot) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vdot(self, other): + if not self.is_complex(): + return torch.dot(self, other) + + if self.is_conj(): + if other.is_conj(): + return torch.vdot(other.conj(), self.conj()) + else: + return torch.dot(self.conj(), other) + elif other.is_conj(): + return torch.dot(self, other.conj()).conj() + + _dot_check(self, other) + # The decomposition fails if you do self.conj()... not sure why + return (self.conj_physical() * other).sum() + + +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +add_ = _make_inplace(add) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +bitwise_and_ = _make_inplace(bitwise_and) +bitwise_left_shift_ = _make_inplace(bitwise_left_shift) +bitwise_not_ = _make_inplace(bitwise_not) +bitwise_or_ = _make_inplace(bitwise_or) +bitwise_right_shift_ = _make_inplace(bitwise_right_shift) +bitwise_xor_ = _make_inplace(bitwise_xor) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +cumprod_ = _make_inplace(cumprod) +deg2rad_ = _make_inplace(deg2rad) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = _make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +gcd_ = _make_inplace(gcd) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +i0_ = _make_inplace(i0) +lcm_ = _make_inplace(lcm) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_not_ = _make_inplace(logical_not) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mul_ = _make_inplace(mul) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +rad2deg_ = _make_inplace(rad2deg) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = _make_inplace(sqrt) +square_ = _make_inplace(square) +sub_ = _make_inplace(sub) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) +cauchy_ = _make_inplace(cauchy) +exponential_ = _make_inplace(exponential) +geometric_ = _make_inplace(geometric) +log_normal_ = _make_inplace(log_normal) +zero_ = _make_inplace(zero) + + +# xref: isStorage in torch/csrc/DynamicTypes.cpp +def _isStorage(obj): + return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage)) + + +# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp +def _compute_sizes(seq, scalar_type): + MAX_DIMS = 128 + is_storage = _isStorage(seq) + sizes = [] + # TODO: this is inaccurate, we actually test PySequence_Check + while isinstance(seq, (list, tuple)): + length = len(seq) + if is_storage: + length //= scalar_type.itemsize + sizes.append(length) + if len(sizes) > MAX_DIMS: + raise ValueError(f"too many dimensions '{type(seq).__name__}'") + if length == 0: + break + try: + handle = seq[0] + except Exception: + raise ValueError( # noqa: TRY200 + f"could not determine the shape of object type '{type(seq).__name__}'" + ) + seq = handle + + return sizes + + +# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp +def _infer_scalar_type(obj): + if isinstance(obj, FloatLike): + return torch.get_default_dtype() + if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful! + return torch.int64 + if isinstance(obj, bool): + return torch.bool + if isinstance(obj, complex): + default_dtype = torch.get_default_dtype() + if default_dtype is torch.float: + return torch.cfloat + elif default_dtype is torch.double: + return torch.cdouble + else: + raise RuntimeError("invalid default scalar type for complex") + if isinstance(obj, torch.Tensor): + return obj.dtype + if isinstance(obj, str): + raise TypeError(f"new(): invalid data type '{type(obj).__name__}'") + # TODO: this is inaccurate, we actually test PySequence_Check + if isinstance(obj, (list, tuple)): + scalarType = None + length = len(obj) + # match NumPy semantics, except use default tensor type instead of + # double. + if length == 0: + return torch.get_default_dtype() + for i in range(length): + cur_item = obj[i] + # TODO: test this + """ + if cur_item is obj: + raise TypeError("new(): self-referential lists are incompatible") + """ + item_scalarType = _infer_scalar_type(cur_item) # recurse! + if scalarType is not None: + scalarType = torch.promote_types(scalarType, item_scalarType) + else: + scalarType = item_scalarType + if scalarType is torch.cdouble: + # this won't change (unless we hit undefined, but that will + # fail later) + return scalarType + return scalarType + raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}") + + +# Analogous to recursive_store +# xref: recursive_store in torch/csrc/utils/tensor_new.cpp +def _recursive_build(scalarType: torch.dtype, obj: TensorOrNumberLikeType): + if isinstance(obj, Tensor) and obj.ndim <= 1: + obj = obj.item() + # fall through into next case + if isinstance(obj, Number): + return torch.scalar_tensor(obj, dtype=scalarType) + + seq = obj + return torch.stack([_recursive_build(scalarType, item) for item in seq]) + + +# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp +def _internal_new_from_data( + options, + scalar_type, + device_opt, + data, + copy_variables, + copy_numpy, + type_inference, + pin_memory=False, +): + if isinstance(data, torch.Tensor): + torch._check( + not pin_memory, lambda: "Can't pin tensor constructed from a variable" + ) + var = data + if copy_variables: + var = var.detach() + inferred_scalar_type = var.dtype if type_inference else scalar_type + device = device_opt if device_opt is not None else var.device + return var.to( + device=device, + dtype=inferred_scalar_type, + non_blocking=False, + copy=copy_variables, + ) + + # TODO + if hasattr(data, "__cuda_array_interface__"): + return NotImplemented + + # TODO: test for numpy input with PyArray_Check + + device = device_opt if device_opt is not None else options["device"] + inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type + + # NB: Don't need to avoid tracing, as we aren't going to do any manual + # pointer filling tricks + if _isStorage(data): + return NotImplemented + else: + if torch.device(device).type == "meta": + return NotImplemented + + # In the C implementation, we would directly start poking the memory + # of a freshly allocated CPU tensor. Here, we're going to do an + # alternate, heinously slow implementation: turn each individual + # scalar into a tensor, and then repeatedly cat them together + tensor = _recursive_build(inferred_scalar_type, data) + + tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False) + + # NB: lift_fresh is not needed, because we built the tensor from scalars + # guaranteeing a fresh tensor in this case + return tensor + + +# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp +def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False): + # TODO (or not): support names kwarg + if isinstance(data, torch.Tensor): + warnings.warn( + "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " + "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)" + ) + type_inference = dtype is None + new_tensor = _internal_new_from_data( + # device="cpu" because that's what you get with torch.tensor(2) no + # device by default + {"device": "cpu"}, # TODO: use torch.get_default_tensor_type + dtype if dtype is not None else torch.get_default_dtype(), + device, + data, + copy_variables=True, + copy_numpy=True, + type_inference=type_inference, + pin_memory=pin_memory, + ) + new_tensor.detach_() + new_tensor.requires_grad_(requires_grad) + return new_tensor + + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + +import torch._refs._conversions +import torch._refs.fft +import torch._refs.linalg +import torch._refs.nn.functional +import torch._refs.special diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f44be2b85471e8684150c02d5328a61e820f6e0c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6be2180e8b34cb1fdac215e6d885f9450c5103f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1230 @@ +import math +from functools import wraps +from typing import Callable, Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", + "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + "nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + "softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn): + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, inplace=False, **kwargs): + if inplace: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + return fn(a, *args, inplace=False, out=a, **kwargs) + else: + return fn(a, *args, inplace=False, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + lambd >= 0, + lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, −target * (input1 − input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classe, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. +def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, tuple() if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35661d0534dc2712cdf4c231deec60df5c440606 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0df31dfd2a1de31b7f10e3d870f6cc8a80757b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py @@ -0,0 +1,526 @@ +import torch +import torch.nn as nn +import torch.ao.nn.quantized as nnq +import torch.ao.nn.quantized.dynamic as nnqd +from torch.ao.quantization import prepare +from typing import Dict, List, Optional, Any, Union, Callable, Set + +from torch.ao.quantization.quantization_mappings import ( + get_default_compare_output_module_list, +) + +NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = { + nnqd.Linear, + nnq.Linear, + nnqd.LSTM, + nn.LSTM, +} + + +def _find_match( + str_list: Union[Dict[str, Any], List[str]], key_str: str, + postfix: str, +) -> Optional[str]: + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + return None + else: + return None + + +def compare_weights( + float_dict: Dict[str, Any], quantized_dict: Dict[str, Any] +) -> Dict[str, Dict[str, torch.Tensor]]: + r"""Compare the weights of the float module with its corresponding quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage:: + + wt_compare_dict = compare_weights( + float_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print( + key, + compute_error( + wt_compare_dict[key]['float'], + wt_compare_dict[key]['quantized'].dequantize() + ) + ) + + Args: + float_dict: state dict of the float model + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights") + weight_dict: Dict[str, Dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key][0] + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict + + +def _get_logger_dict_helper( + mod: nn.Module, target_dict: Dict[str, Any], + prefix: str = "", +) -> None: + r"""This is the helper function for get_logger_dict + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + target_dict: the dictionary used to save all logger stats + """ + + def get_prefix(prefix): + return prefix if prefix == "" else prefix + "." + + for name, child in mod.named_children(): + if isinstance(child, Logger): + target_dict[get_prefix(prefix) + "stats"] = child.stats + break + + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + _get_logger_dict_helper(child, target_dict, module_prefix) + + +def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]: + r"""Traverse the modules and save all logger stats into target dict. + This is mainly used for quantization accuracy debug. + + Type of loggers supported: + ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module, + OutputLogger: used to log the outputs of the modules + + Args: + mod: module we want to save all logger stats + prefix: prefix for the current module + + Return: + target_dict: the dictionary used to save all logger stats + + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict") + + target_dict: Dict[str, Dict] = {} + _get_logger_dict_helper(mod, target_dict, prefix) + return target_dict + + +class Logger(nn.Module): + r"""Base class for stats logging + """ + + def __init__(self): + super().__init__() + self.stats = {} + # We only insert observer if the op is quantized with static quantization, + # which is identified by activation_observer.dtype == quint8. This is needed + # when attaching Logger as observer for FX mode + self.dtype = torch.quint8 + + def forward(self, x): + """ + """ # blank docblock to make autodoc happy + pass + + +class ShadowLogger(Logger): + r"""Class used in Shadow module to record the outputs of the original and + shadow modules. + """ + + def __init__(self): + super().__init__() + self.stats["float"] = [] + self.stats["quantized"] = [] + + def forward(self, x, y): + """ + """ # blank docblock to make autodoc happy + if len(x) > 1: + x = x[0] + if len(y) > 1: + y = y[0] + self.stats["quantized"].append(x.detach()) + self.stats["float"].append(y.detach()) + + +class OutputLogger(Logger): + r"""Class used to log the outputs of the module + """ + + def __init__(self): + super().__init__() + self.stats["tensor_val"] = [] + + + def forward(self, x): + """ + """ # blank docblock to make autodoc happy + self.stats["tensor_val"].append(x) + return x + + +def _convert_tuple_to_list(t: Any) -> Any: + return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t + + +def _dequantize_tensor_list(t: Any) -> Any: + return ( + [_dequantize_tensor_list(x) for x in t] + if type(t) is list + else t.dequantize() + if t.is_quantized + else t + ) + + +class Shadow(nn.Module): + r"""Shadow module attaches the float module to its matching quantized module + as the shadow. Then it uses Logger module to process the outputs of both + modules. + + Args: + q_module: module quantized from float_module that we want to shadow + float_module: float module used to shadow q_module + logger_cls: type of logger used to process the outputs of q_module and + float_module. ShadowLogger or custom loggers can be used. + """ + + def __init__(self, q_module, float_module, logger_cls): + super().__init__() + self.orig_module = q_module + self.shadow_module = float_module + self.dequant = nnq.DeQuantize() + self.logger = logger_cls() + + def forward(self, *x) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + xl = _convert_tuple_to_list(x) + output = self.orig_module(*xl) + xl_float = _dequantize_tensor_list(xl) + shadow_output = self.shadow_module(*xl_float) + self.logger(output, shadow_output) + return output + + def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.add(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add(x, y) + self.logger(output, shadow_output) + return output + + def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.add_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.add_scalar(x, y) + self.logger(output, shadow_output) + return output + + def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.mul(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.mul(x, y) + self.logger(output, shadow_output) + return output + + def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.mul_scalar(x, y) + x = x.dequantize() + shadow_output = self.shadow_module.mul_scalar(x, y) + self.logger(output, shadow_output) + return output + + def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.cat(x, dim) + x = [y.dequantize() for y in x] + shadow_output = self.shadow_module.cat(x, dim) + self.logger(output, shadow_output) + return output + + def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + """ # blank docblock to make autodoc happy + output = self.orig_module.add_relu(x, y) + x = x.dequantize() + y = y.dequantize() + shadow_output = self.shadow_module.add_relu(x, y) + self.logger(output, shadow_output) + return output + + +def prepare_model_with_stubs( + float_module: nn.Module, q_module: nn.Module, + module_swap_list: Set[type], logger_cls: Callable, +) -> None: + r"""Prepare the model by attaching the float module to its matching quantized + module as the shadow if the float module type is in module_swap_list. + + Example usage:: + + prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger) + q_model(data) + ob_dict = get_logger_dict(q_model) + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + module_swap_list: list of float module types to attach the shadow + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs") + + float_module_children = {} + for name, mod in float_module.named_children(): + float_module_children[name] = mod + + reassign = {} + for name, mod in q_module.named_children(): + + if name not in float_module_children: + continue + + float_mod = float_module_children[name] + + if type(float_mod) not in module_swap_list: + prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls) + + # Insert shadow module only if the module is not of the same type as + # the floating point module + if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod): + reassign[name] = Shadow(mod, float_mod, logger_cls) + + for key, value in reassign.items(): + q_module._modules[key] = value + +def _is_identical_module_type(mod1, mod2): + # Compare if two modules have the same dtype + mod1_module_types = [type(mod) for mod in mod1.modules()] + mod2_module_types = [type(mod) for mod in mod2.modules()] + return mod1_module_types == mod2_module_types + + + +def compare_model_stub( + float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type], + *data, logger_cls=ShadowLogger +) -> Dict[str, Dict]: + r"""Compare quantized module in a model with its floating point counterpart, + feeding both of them the same input. Return a dict with key corresponding to + module names and each entry being a dictionary with two keys 'float' and + 'quantized', containing the output tensors of quantized and its matching + float shadow module. This dict can be used to compare and compute the module + level quantization error. + + This function first call prepare_model_with_stubs() to swap the quantized + module that we want to compare with the Shadow module, which takes quantized + module, corresponding float module and logger as input, and creates a forward + path inside to make the float module to shadow quantized module sharing the + same input. The logger can be customizable, default logger is ShadowLogger + and it will save the outputs of the quantized module and float module that + can be used to compute the module level quantization error. + + Example usage:: + + module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock] + ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data) + for key in ob_dict: + print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize())) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + module_swap_list: list of float module types at which shadow modules will + be attached. + data: input data used to run the prepared q_model + logger_cls: type of logger to be used in shadow module to process the outputs of + quantized module and its float shadow module + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub") + prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls) + q_model(*data) + ob_dict = get_logger_dict(q_model) + return ob_dict + + +def get_matching_activations( + float_module: nn.Module, q_module: nn.Module, +) -> Dict[str, Dict[str, torch.Tensor]]: + r"""Find the matching activation between float and quantized modules. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + + Return: + act_dict: dict with key corresponding to quantized module names and each + entry being a dictionary with two keys 'float' and 'quantized', containing + the matching float and quantized activations + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations") + float_dict = get_logger_dict(float_module) + quantized_dict = get_logger_dict(q_module) + act_dict: Dict[str, Dict] = {} + for key in quantized_dict: + if len(quantized_dict[key]["tensor_val"]) == 0: + continue + match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") + if match_key is not None: + act_dict[key] = {} + act_dict[key]["float"] = float_dict[match_key]["tensor_val"] + act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"] + return act_dict + + +def prepare_model_outputs( + float_module: nn.Module, + q_module: nn.Module, + logger_cls=OutputLogger, + allow_list=None +) -> None: + r"""Prepare the model by attaching the logger to both float module + and quantized module if they are in the allow_list. + + Args: + float_module: float module used to generate the q_module + q_module: module quantized from float_module + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs") + if allow_list is None: + allow_list = get_default_compare_output_module_list() + + qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None) + float_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}) + q_module.qconfig = qconfig_debug # type: ignore[assignment] + prepare( + q_module, + inplace=True, + allow_list=allow_list, + observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST, + prepare_custom_config_dict={} + ) + + +def compare_model_outputs( + float_model: nn.Module, + q_model: nn.Module, + *data, + logger_cls=OutputLogger, + allow_list=None +) -> Dict[str, Dict[str, torch.Tensor]]: + r"""Compare output activations between float and quantized models at + corresponding locations for the same input. Return a dict with key corresponding + to quantized module names and each entry being a dictionary with two keys + 'float' and 'quantized', containing the activations of quantized model and + float model at matching locations. This dict can be used to compare and + compute the propagation quantization error. + + Example usage:: + + act_compare_dict = compare_model_outputs(float_model, qmodel, data) + for key in act_compare_dict: + print( + key, + compute_error( + act_compare_dict[key]['float'], + act_compare_dict[key]['quantized'].dequantize() + ) + ) + + Args: + float_model: float model used to generate the q_model + q_model: model quantized from float_model + data: input data used to run the prepared float_model and q_model + logger_cls: type of logger to be attached to float_module and q_module + allow_list: list of module types to attach logger + + Return: + act_compare_dict: dict with key corresponding to quantized module names + and each entry being a dictionary with two keys 'float' and 'quantized', + containing the matching float and quantized activations + """ + torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs") + if allow_list is None: + allow_list = get_default_compare_output_module_list() + prepare_model_outputs(float_model, q_model, logger_cls, allow_list) + float_model(*data) + q_model(*data) + act_compare_dict = get_matching_activations(float_model, q_model) + return act_compare_dict diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dee44a167f3da066dc2a37d62d1a8cb8de6acbbb Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eb6b44225368c08a0b26459bf4279146a545735 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2654b38b4052d9689425ecba76912610dc76fa84 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c71df815049b41b0b59c3774fe314786953e7be1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d375694b88b300cab05154e9cdf0f0088f824575 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py @@ -0,0 +1,275 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.ao.nn.quantized.dynamic as nnqd +import torch.ao.nn.quantized as nnq +import torch.ao.nn.intrinsic.qat as nniqat +import torch.ao.nn.qat as nnqat +import torch.ao.nn.intrinsic as nni +import torch.ao.nn.intrinsic.quantized as nniq +toq = torch.ops.quantized +from torch.fx import GraphModule +from torch.fx.graph import Node + +from .utils import ( + get_target_type_str, + getattr_from_fqn, + return_first_non_observer_node, +) + +from .ns_types import ( + NSSingleResultValuesType, + NSSingleResultType, +) + +from typing import List, Optional, Dict, Callable + +def mod_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod.weight.detach() # type: ignore[operator] + +def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor: + return mod[0].weight.detach() # type: ignore[index] + +def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor: + return mod._weight_bias()[0] # type: ignore[operator] + +def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]: + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type] + if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name: + param_value = mod._flat_weights[idx].detach() # type: ignore[index] + res.append(param_value) + return res + +def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]: + res = [] + for weight_value in mod._all_weight_values: # type: ignore[union-attr] + res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]) + res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]) + return res + +def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor: + if ( + isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) + ): + return mod.weight.detach() + elif ( + isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)) + ): + return mod[0].weight.detach() + else: + return mod._weight_bias()[0] # type: ignore[operator] + +def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor: + if isinstance(mod, nn.Linear): + return mod.weight.detach() + elif isinstance(mod, nni.LinearReLU): + return mod[0].weight.detach() + else: + return mod._weight_bias()[0] # type: ignore[operator] + +def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]: + # TODO(future PR): make more generic, handle everything + if isinstance(mod, nn.LSTM): + res = [] + for idx, param_name in enumerate(mod._flat_weights_names): + if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name: + param_value = mod._flat_weights[idx].detach() + res.append(param_value) + return res + else: + assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet" + res = [] + for weight_value in mod._all_weight_values: + res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]) + res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]) + return res + +def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + weight_arg_node = node.args[1] + assert isinstance(weight_arg_node, Node) + weight_node = return_first_non_observer_node(weight_arg_node, gm) + assert isinstance(weight_node, Node) + assert weight_node.op == 'get_attr' + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + +def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # qconv state is arg 1 + qconv_state_node = node.args[1] + assert isinstance(qconv_state_node, Node) + assert qconv_state_node.op == 'get_attr' + qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type] + return qconv_state_obj.weight() + +def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # traverse backwards from the weight arg, accounting for any observers + # supported patterns: + # weight -> obs -> linear + # weight -> to(torch.float16) -> dequantize -> linear + linear_second_arg = node.args[1] + assert isinstance(linear_second_arg, Node) + + if linear_second_arg.op == 'call_module': + # weight -> obs -> linear + weight_arg_node = node.args[1] + assert isinstance(weight_arg_node, Node) + weight_node = weight_arg_node.args[0] + assert isinstance(weight_node, Node) + assert weight_node.op == 'get_attr' + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + return weight.detach() + elif linear_second_arg.op == 'call_method': + # weight -> to(torch.float16) -> dequantize -> linear + assert linear_second_arg.op == 'call_method' + dequant_node = node.args[1] + assert isinstance(dequant_node, Node) + to_fp16_node = dequant_node.args[0] + assert isinstance(to_fp16_node, Node) + # extract the dtype, so we can cast to it before returning + target_dtype = to_fp16_node.args[1] + weight_node = to_fp16_node.args[0] + assert isinstance(weight_node, Node) + assert weight_node.op == 'get_attr' + weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type] + # return the weight with fp16 cast + return weight.detach().to(target_dtype) + else: + assert linear_second_arg.op == 'get_attr' + weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type] + return weight.detach() + +def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor: + # packed weight is arg 1 + packed_weight_node = node.args[1] + assert isinstance(packed_weight_node, Node) + assert packed_weight_node.op == 'get_attr' + packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type] + # TODO(future PR): why does packed_weight.unpack() not work? + (weight, _bias), _name = packed_weight.__getstate__() + return weight + +def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]: + + op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = { + 'call_module': { + # Conv1d + nn.Conv1d: mod_weight_detach, + nni.ConvReLU1d: mod_0_weight_detach, + nnq.Conv1d: mod_weight_bias_0, + nnqat.Conv1d: mod_weight_detach, + nniqat.ConvBn1d: mod_weight_detach, + nniqat.ConvBnReLU1d: mod_weight_detach, + nniqat.ConvReLU1d: mod_weight_detach, + nniq.ConvReLU1d: mod_weight_bias_0, + # Conv2d + nn.Conv2d: mod_weight_detach, + nni.ConvReLU2d: mod_0_weight_detach, + nnq.Conv2d: mod_weight_bias_0, + nnqat.Conv2d: mod_weight_detach, + nniqat.ConvBn2d: mod_weight_detach, + nniqat.ConvBnReLU2d: mod_weight_detach, + nniqat.ConvReLU2d: mod_weight_detach, + nniq.ConvReLU2d: mod_weight_bias_0, + # Conv3d + nn.Conv3d: mod_weight_detach, + nni.ConvReLU3d: mod_0_weight_detach, + nnq.Conv3d: mod_weight_bias_0, + nnqat.Conv3d: mod_weight_detach, + nniqat.ConvBn3d: mod_weight_detach, + nniqat.ConvBnReLU3d: mod_weight_detach, + nniqat.ConvReLU3d: mod_weight_detach, + nniq.ConvReLU3d: mod_weight_bias_0, + # Linear + nn.Linear: mod_weight_detach, + nnq.Linear: mod_weight_bias_0, + nni.LinearReLU: mod_0_weight_detach, + nniq.LinearReLU: mod_weight_bias_0, + nnqat.Linear: mod_weight_detach, + nnqd.Linear: mod_weight_bias_0, + nniqat.LinearReLU: mod_weight_detach, + nniqat.LinearBn1d: mod_weight_detach, + nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach, + # LSTM + nn.LSTM: get_lstm_weight, + nnqd.LSTM: get_qlstm_weight, + }, + 'call_function': { + # Conv + F.conv1d: get_conv_fun_weight, + F.conv2d: get_conv_fun_weight, + F.conv3d: get_conv_fun_weight, + toq.conv1d: get_qconv_fun_weight, + toq.conv2d: get_qconv_fun_weight, + toq.conv3d: get_qconv_fun_weight, + toq.conv1d_relu: get_qconv_fun_weight, + toq.conv2d_relu: get_qconv_fun_weight, + toq.conv3d_relu: get_qconv_fun_weight, + # Linear + F.linear: get_linear_fun_weight, + toq.linear: get_qlinear_fun_weight, + toq.linear_relu: get_qlinear_fun_weight, + }, + } + + return op_to_type_to_weight_extraction_fn + +def extract_weight_from_node( + node: Node, + gm: GraphModule, + op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None, +) -> Optional[NSSingleResultType]: + res_type = NSSingleResultValuesType.WEIGHT.value + + # Not all graphmodules have _node_name_to_scope, so only fill it + # out if it exists. + fqn = None + if hasattr(gm, '_node_name_to_scope'): + fqn = gm._node_name_to_scope[node.name][0] # type: ignore[index] + + if op_to_type_to_weight_extraction_fn is None: + op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn() + + ref_node_type = get_target_type_str(node, gm) + # for extracting weights, these are always the same + prev_node_type = ref_node_type + + if node.op == 'call_function': + function_mapping = op_to_type_to_weight_extraction_fn['call_function'] + for target_fn_type, weight_extraction_fn in function_mapping.items(): + if node.target == target_fn_type: + weight = weight_extraction_fn(node, gm) + return { + 'type': res_type, + 'values': [weight], + 'prev_node_name': node.name, + 'prev_node_target_type': prev_node_type, + 'ref_node_name': node.name, + 'ref_node_target_type': ref_node_type, + 'index_within_arg': 0, + 'index_of_arg': 0, + 'fqn': fqn, + } + + elif node.op == 'call_module': + # for call_module, we need to look up the modules to do the type check + assert isinstance(node.target, str) + mod = getattr_from_fqn(gm, node.target) + module_mapping = op_to_type_to_weight_extraction_fn['call_module'] + for target_mod_type, weight_extraction_fn in module_mapping.items(): + if type(mod) == target_mod_type: + weight = weight_extraction_fn(mod) + return { + 'type': res_type, + 'values': [weight], + 'prev_node_name': node.name, + 'prev_node_target_type': prev_node_type, + 'ref_node_name': node.name, + 'ref_node_target_type': ref_node_type, + 'index_within_arg': 0, + 'index_of_arg': 0, + 'fqn': fqn, + } + + return None diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05bd070c5e1061b6c069649114a72b39912b35f7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d18cb8b972becfc599b1d4eb4054a5141fc827a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..822153cc602b202031d315895a67ae6e3643fc6b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..756563d1c8b5f141d4601cd0e9b727cd500fe710 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..76280411f2bcf6124c8b659171699806bb01fc2b --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from typing import List + +import torch +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + get_mutable_arg_names, +) +from torch.export import ExportedProgram + + +def unsafe_remove_auto_functionalized_pass( + ep: ExportedProgram, +) -> ExportedProgram: + """ + This pass removes an instances of the higher order op 'auto_functionalized', + and modifies the calling EP inplace to have the original mutator op. + This pass doesn't perform safety checks to make sure that this inplace mutation is safe. + """ + auto_functionalize_nodes: List[torch.fx.Node] = [] + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if node.op == "call_function" and node.target is auto_functionalized: + auto_functionalize_nodes.append(node) + + # Update every use of the HOP + for node in reversed(auto_functionalize_nodes): + func = node.args[0] + original_kwargs = node.kwargs + assert isinstance(func, torch._ops.OpOverload) + + with ep.graph.inserting_before(node): + # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine? + new_node = ep.graph.call_function(func, kwargs=node.kwargs) + for k, v in node.meta.items(): + new_node.meta[k] = v + + # Replace auto_functionalize(func, args) with just func(args) + node.replace_all_uses_with(new_node) + + mutable_args_names = get_mutable_arg_names(new_node.target) + output_specs = ep.graph_signature.output_specs + + # update the users of the auto_func node (the getitem nodes) + for user in list(new_node.users.keys()): + assert user.target == operator.getitem + # getitem corresponding to a mutated input, just replace all uses with the original input + if user.args[1] >= len(func._schema.returns): + assert user.args[1] <= len(func._schema.returns) + len( + mutable_args_names + ) + + # If the result of getitem was used in an output node, update the output spec with the correct name + adusted_index = user.args[1] - len(func._schema.returns) + original_arg = original_kwargs[mutable_args_names[adusted_index]] + for spec in output_specs: + if spec.arg.name == user.name: + spec.arg.name = original_arg.name # pyre-ignore + break + + # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order + # of the getitem calls following the HOP. + user.replace_all_uses_with( + original_kwargs[mutable_args_names[adusted_index]] + ) + + if len(func._schema.returns) == 1: + # If the function has 1 return then it will just directly return the + # result -- we don't need a getitem. So we can replace all the + # getitem(auto_functionalized, 0) with just the note itself. + for user in list(new_node.users.keys()): + if user.args[1] == 0: + user.replace_all_uses_with(new_node) + + # Same case as above, update the output spec if getitem result used in an output node + for spec in output_specs: + if spec.arg.name == user.name: + spec.arg.name = new_node.name + break + + new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)] + ep.graph.erase_node(node) + + ep.graph.eliminate_dead_code() + return ep diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2cb91bbcc1790c419fa603b36cf6bc7afddc18 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,14 @@ +op_add = '+' +op_sub = '-' +op_mul = '*' +op_div = '/' +op_eq = '=' +op_neq = '!=' +op_imp = '=>' +op_matching = '⊳' +op_consistency = '~' +op_precision = '⊑' +op_leq = '≤' +op_lt = '<' +op_gt = '>' +op_mod = '%' diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000000000000000000000000000000..897a79d5697573a51f5886d5e9965a98e2c4cf6a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,29 @@ +try: + import z3 # type: ignore[import] + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort('Dyn') + dyn_type = z3.Const('dyn', dyn) + + # dimension + dim = z3.Datatype('dim') + dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype('TensorType') + tensor_type.declare('Dyn', ('dyn', dyn)) + tensor_type.declare('tensor1', ('0', dim)) + tensor_type.declare('tensor2', ('0', dim), ('1', dim)) + tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) + tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31446d0e61253d7f722a3235e6e4c5788b4b01ba --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py @@ -0,0 +1,4 @@ +# mypy: disable-error-code=attr-defined +from .core import unify, reify # noqa: F403 +from .more import unifiable # noqa: F403 +from .variable import var, isvar, vars, variables, Var # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea2de67532cc12f882905ef17b0157b48ee21735 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0295af0ea6b6b92836e034c1d28cfdf69b1d3ba --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -0,0 +1,3 @@ +from .core import dispatch +from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, + MDNotImplementedError) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bd4ab34e758455f8f04633d2d378d5a6a83590d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26be3d12b03f17b1adf5323447f0f2417c196fa7 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9577d6c66a9ea03c99953fd1a3a8c79296f3a15f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py @@ -0,0 +1,11 @@ +from . import graph_drawer +from . import graph_manipulation +from . import net_min_base +from . import operator_support +from . import param_fetch +from . import reinplace +from . import shape_prop +from . import split_module +from . import split_utils +from . import splitter_base +from . import tools_common diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e710c7fcab347a0e628656bca037332da84c275a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e53f0e969a1614edf8f2b97f6f28fe527b332e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py @@ -0,0 +1,110 @@ +from typing import Any, Dict, List, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import ( + map_arg, + Node, + Target, +) +from torch.fx.passes.shape_prop import ShapeProp + +__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', + 'get_size_of_node'] + +@compatibility(is_backward_compatible=False) +def replace_target_nodes_with( + fx_module: GraphModule, + old_op: str, + old_target: Target, + new_op: str, + new_target: Target, +): + """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, + and updates them to match the new op code and target""" + new_graph = Graph() + val_map: Dict[Node, Node] = {} + for node in fx_module.graph.nodes: + if node.op == old_op and node.target == old_target: + args = map_arg(node.args, lambda n: val_map[n]) + kwargs = map_arg(node.kwargs, lambda n: val_map[n]) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + val_map[node] = new_graph.create_node( + new_op, new_target, args, kwargs, node.name + ) + else: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + fx_module.graph = new_graph + + +@compatibility(is_backward_compatible=False) +class size_bytes(NamedTuple): + output_size: int + total_size: int + + +@compatibility(is_backward_compatible=False) +def get_size_of_all_nodes( + fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None +) -> None: + """Given a fx graph module, update each node with its total size (weights + bias + output) + and its output_size(output). For a non-module node, the total size is the output size. + return total size""" + if args is not None: + # Mark shape and dtype for each node (node.shape and node.dtype) + ShapeProp(fx_module).propagate(*args) + # Calculate the total size of the whole fx graph + total_size_of_graph = 0.0 + for node in fx_module.graph.nodes: + if node.op == "output": + break + node.size_bytes = get_size_of_node(fx_module, node) + return + + +@compatibility(is_backward_compatible=False) +def get_tensor_meta(node: Node) -> Any: + tensor_meta = node.meta.get("tensor_meta") + + if not tensor_meta: + raise RuntimeError( + f"Node {node} has no tensor metadata associated with it! " + f"Check that shape propagation has run." + ) + + return tensor_meta + + +@compatibility(is_backward_compatible=False) +def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: + """Given a node with node.dtype and node.shape, return its total size and its output size. + total_size = weights + bias + output_size + """ + # Total num of elements + total_num_of_elems = 0 + # For a module, conside all parameters + if node.op == "call_module": + submodule_dict = dict(fx_module.named_modules()) + submodule = submodule_dict[node.target] + parameters = submodule.named_parameters() + # Parameters are named tuples + for name, p in parameters: + total_num_of_elems += p.numel() + # Don't forget the output size + # node.shape is the shape of this node's output + tensor_meta = get_tensor_meta(node) + output_elem = tensor_meta.shape.numel() + total_num_of_elems += output_elem + # Assume for now if it's quantized then it's qint8 or quint8 + if tensor_meta.is_quantized: + size_per_elem_bytes = torch._empty_affine_quantized( + [], dtype=tensor_meta.dtype + ).element_size() + else: + size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() + total_size = size_per_elem_bytes * total_num_of_elems + output_size = size_per_elem_bytes * output_elem + return size_bytes(output_size, total_size) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py new file mode 100644 index 0000000000000000000000000000000000000000..98aef4eb54a7664afe54472c0fc1113740e7a0ff --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py @@ -0,0 +1,731 @@ +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.fx + +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg + +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + Names, + NodeList, + NodeSet, + TensorOrTensors, + Tensors, +) + +__all__ = [ + "FxNetMinimizerBadModuleError", + "FxNetMinimizerRunFuncError", + "FxNetMinimizerResultMismatchError", +] + +_LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerBadModuleError(Exception): + """ + Raised if failed to split out a minimize module + """ + + pass + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerRunFuncError(Exception): + """ + Raised if error occurs during run_a or run_b functions + """ + + pass + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerResultMismatchError(Exception): + """ + Raised if comparing function thinks the results are mismatching. + """ + + pass + + +@dataclass +class _MinimizerSettingBase: + """ + Args: + `accumulate_error`: Instead of using a's input for both converted module to verify + , use the previous outputs of each converted module as input to accumulate the + errors. + + `traverse_method`: "sequential" or "binary" or "accumulate" + Determine the way of traverse the nodes in FX module. + + `find_all`: Minimizer will go through the entire model and return all problematic nodes. + + `return_intermediate`: If true, when using `run_nodes()` function to run the + model, intermediate results of all the ops will be returned as output. + """ + + accumulate_error: bool = False + traverse_method: str = "sequential" + find_all: bool = False + return_intermediate: bool = False + + def __str__(self): + settings_str = "FX Minimizer Settings:\n" + + for k, v in vars(self).items(): + settings_str += f"\t{k}: {v}\n" + + return settings_str + + +class _MinimizerBase: + """ + This class is used to automatically find problematic nodes in a model. It takes a FX + graphmodule and generate some submodules while traverse the graph. Then two functions + `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` + will be used to compare the results. + + Currently we provides two ways to traverse the graph and generate submodules. + 1. Sequential traversal: this will traverse the graph node by node and generate + one submodule with one sigle node. + 2. Binary searching: this will do a binary search style traversal on the graph. + + For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Tensors, + compare_fn: Callable[ + [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool] + ], + settings: _MinimizerSettingBase, + module_exporter: Optional[ + Callable[ + [List[torch.Tensor], torch.fx.GraphModule, str], + None + ] + ] = None, + ): + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + self.sample_input = sample_input + self.compare_fn = compare_fn + self.module_exporter = module_exporter + self.settings = settings + + # Stores outputs of run_a function + self.a_outputs: Dict[str, Any] = {} + + # Stores outputs of run_b function + self.b_outputs: Dict[str, Any] = {} + + # Stores the results of compare_fn + self.results: Dict[Any, Any] = {} + + # Stores the report for the runs + self.reports: List[List[str]] = [] + + # Current iteration + self.iteration: int = 0 + + callable_nodes = { + node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS + } + ShapeProp(self.module).propagate(*self.sample_input) + self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() + + # Check if number of input in sample_input matches the number of placeholders + placeholders = [ + node.name for node in self.module.graph.nodes if node.op == "placeholder" + ] + assert len(placeholders) == len(self.sample_input) + + # Store sample_input + for i, name in enumerate(placeholders): + self.a_outputs[name] = sample_input[i] + self.b_outputs[name] = sample_input[i] + + def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_b(). + """ + raise RuntimeError("run_a() is not implemented.") + + def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_a(). + """ + raise RuntimeError("run_b() is not implemented.") + + def _store_outputs( + self, + a_result: TensorOrTensors, + b_result: TensorOrTensors, + submodule: torch.fx.GraphModule, + ): + """ + Store the outputs of self.run_a() and self.run_b() into self.a_outputs and + self.b_outputs, so that we can use them when execute preceding nodes that + use those outputs as inputs. + + Args: + a_result: Output of self.run_a(). Could be a tensor or tensors. + b_result: Output of self.run_b(). Could be a tensor or tensors. + submodule: The module that generates a_result and b_result. + """ + output_node = next( + node for node in submodule.graph.nodes if node.op == "output" + ) + + # Only one output + if isinstance(output_node.args[0], torch.fx.Node): + self.a_outputs[output_node.args[0].name] = a_result + self.b_outputs[output_node.args[0].name] = b_result + # Multiple outputs + else: + for i, arg in enumerate(output_node.args[0]): + self.a_outputs[arg.name] = a_result[i] + self.b_outputs[arg.name] = b_result[i] + + def _get_submod_inputs( + self, main_module: torch.fx.GraphModule, submod_path: str + ) -> Tuple[Tensors, Tensors]: + """ + Try get submodule inputs from stored outputs. If not found then use + torch_glow.get_submod_inputs to get the inputs. + + If accumulate_error is False, use a_input for run_a() and run_b() + otherwise use a_input for run_a and b_input for run_b. + + Args: + main_module: Top-levlel fx module. + submod_path: Path to the submodule we want to run and compare results. + + Returns: + a_input: List of tensor(s) that will be used by run_a() as submodule inputs. + b_input: List of tensor(s) that will be used by run_b() as submodule inputs. + """ + a_input = [] + b_input = [] + submodule = getattr(main_module, submod_path) + placeholders = [ + node.name for node in submodule.graph.nodes if node.op == "placeholder" + ] + + # If all placeholder can be found in stored outputs, use stored + # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` + # to get the inputs. + if set(placeholders) <= self.a_outputs.keys(): + for name in placeholders: + a_input.append(self.a_outputs[name]) + b_input.append(self.b_outputs[name]) + else: + if self.settings.accumulate_error: + print(f"Can't find previous stored outputs named {placeholders}!") + + def get_inputs(self: torch.nn.Module, inputs: Any): + nonlocal a_input + a_input = inputs + + # Use forward hook to get the inputs to the submodule + handle = submodule.register_forward_pre_hook(get_inputs) + main_module(*self.sample_input) + handle.remove() + + b_input = a_input + + if not self.settings.accumulate_error: + return a_input, a_input + + return a_input, b_input + + def _tag_nodes(self, selected_nodes: NodeSet): + """ + Tag selected nodes with tag "minimize". Nodes with the same tags will + be split to the same submodule afterwards. + + Args: + selected_nodes: Nodes that we want to minimize. We will tag those nodes + with "minimize", all preceding nodes with "main_0" and all following + nodes with "main_1". + """ + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node in selected_nodes: + node.tag = "minimize" + elif any( + n.tag in {"minimize", "main_1"} + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS + ): + node.tag = "main_1" + else: + node.tag = "main_0" + + def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]: + """ + Split self.module so that one submodule consists of `nodes` and only `nodes`. + + Args: + nodes: Nodes that we want to include in the minimize submodule. + + Returns: + split_module (torch.fx.GraphModule): the module after split. + submodule_name (str): the name of the submodule that consists of `nodes`. + """ + # Color provided nodes + self._tag_nodes(nodes) + + # Split module based on coloring + split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) + + # Find submodule containing colored nodes + submodule_name: str = "" + for child_name, _ in split_module.named_children(): + # Skip submodules we're not interested in at the moment + if "minimize" not in child_name: + continue + + if submodule_name == "": + submodule_name = child_name + else: + raise FxNetMinimizerBadModuleError( + f"Expected only one minimize submodule with nodes {nodes}" + ) + + if submodule_name == "": + raise FxNetMinimizerBadModuleError( + f"Minimize submodule was not found with nodes {nodes}" + ) + + return split_module, submodule_name + + def _run_and_compare( + self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names + ): + """ + Run the submodule in `split_module` that has name `submod_name` + using `self.run_a` and `self.run_b` and compare their results. + + Args: + split_module: Main module that contains the minimize submodule. + submod_name: Name of the minimize submodule. + output_names: Names of the node we want to output. If None, we + will use the original output. + """ + submodule = getattr(split_module, submod_name) + a_input, b_input = self._get_submod_inputs(split_module, submod_name) + + if len(self.reports) == 0: + self.reports.append([]) + self.iteration = 1 + + report = self.reports[self.iteration - 1] + report.append("Run and compare ...") + + if output_names: + output_nodes: NodeList = [] + for node in submodule.graph.nodes: + if node.op == "output": + submodule.graph.erase_node(node) + + if node.name in output_names: + output_nodes.append(node) + + submodule.graph.output( + output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) + ) + submodule.graph.lint() + submodule.recompile() + + # Use name of args in output node as key to store comparison result + for node in submodule.graph.nodes: + if node.op == "output": + result_key = map_arg(node.args, lambda x: x.name) + + try: + a_result = self.run_a(submodule, a_input) + b_result = self.run_b(submodule, b_input) + self._store_outputs(a_result, b_result, submodule) + except Exception as e: + report.append(f"Exception raised when running {submod_name}: {e}") + raise FxNetMinimizerRunFuncError( # noqa: TRY200 + f"Exception raised when running {submod_name}: {e}" + ) + + # Compare results + names: Names = output_names + if output_names is None: + names = [str(v) for v in result_key] # type: ignore[possibly-undefined] + + numeric_result, bool_result = self.compare_fn(a_result, b_result, names) + + self.results[result_key] = numeric_result # type: ignore[possibly-undefined] + report.append(f"Numerical accuracy = {numeric_result}") + if not bool_result: + report.append(f"Result mismatch for {result_key}") + if self.module_exporter: + self.module_exporter( + List[torch.Tensor](a_input), submodule, str(result_key[0]) + "_cpu", + ) + self.module_exporter( + List[torch.Tensor](b_input), submodule, str(result_key[0]) + "_acc", + ) + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") + + def _binary_search_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Recursive binary search implementation. + """ + nodes: NodeList = all_nodes[start_idx:end_idx] + + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Binary search iteration {self.iteration}.") + report.append( + f"From node index {start_idx} to {end_idx-1}. " + f"Size of the interested node list is {len(nodes)}" + ) + + cur_nodes: NodeSet = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, []) + except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): + + if len(nodes) == 1: + report.append( + f"This is the last node in the sub-module. " + f"Search in the current branch is successful with culprit = {cur_nodes}." + ) + self.print_report(report) + return cur_nodes + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + mid = len(nodes) // 2 + culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) + + if len(culprits) != 0 and not self.settings.find_all: + return culprits + + culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) + + if len(culprits) == 0: + report.append( + f"Further split and lowering found no errors. " + f"Unable to minimize the submodule with list of nodes: {nodes}" + ) + self.print_report(report) + + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _binary_traverse(self, nodes: NodeList) -> NodeSet: + """ + Binary search on `nodes` for culprit. + """ + return self._binary_search_impl(nodes, 0, len(nodes)) + + def _sequential_traverse(self, nodes: NodeList) -> NodeSet: + """ + Traverse `nodes` one by one and determine if any of them is a culprit. + """ + culprits: NodeSet = set() + + for node in nodes: + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Sequential traverse iteration {self.iteration}.") + report.append(f"Visit node: {node.name}") + + _LOGGER.info("Visit node: %s", node.name) + cur_nodes: NodeSet = {node} + + if node in self.fusions: + cur_nodes = self.fusions[node] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [node.name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError): + culprits.add(node) + report.append(f"Found culprit from numeric error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + except (FxNetMinimizerRunFuncError): + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + + return culprits + + def _defined_traverse(self, nodes: NodeList) -> NodeSet: + """ + run user defined `nodes` and determine if it is a culprit. + """ + culprits: NodeSet = set() + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + report = [f"Defined graph from {first_node_name} to {output_node_name}"] + cur_nodes: NodeSet = set(nodes) + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append(f"Found culprit {cur_nodes}") + self.print_report(report) + return culprits + + return culprits + + def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: + culprits: NodeSet = set() + nodes_to_run: NodeSet = set() + + # find_all is not supported for accumulate traversal because all the + # ops run on NNPI. So we return after the first op that raises error. + if self.settings.find_all: + print("'Find All' mode is not supported in accumulate traversal.") + return culprits + + for node in nodes: + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Accumulate traverse iteration {self.iteration}.") + + nodes_to_run.add(node) + + node_name = node.name + if node_name is not None and isinstance(node_name, tuple): + node_name = node_name[0] + assert node_name is not None and isinstance( + node_name, str + ), f"minimize: node_name: {node_name}" + + report.append(f"Add node: {node_name}") + + try: + split_module, submod_name = self._build_submodule(nodes_to_run) + self._run_and_compare(split_module, submod_name, [node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + culprits.add(node) + report.append(f"Found culprit {node}") + self.print_report(report) + return culprits + + return culprits + + def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f" Nodes block {self.iteration}.") + report.append( + f"From node index {start_idx} to {end_idx-1}. " + f"Size of the interested node list is {len(nodes)}" + ) + + cur_nodes: NodeSet = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, []) + except (FxNetMinimizerResultMismatchError): + culprits.update(cur_nodes) + report.append(f"Found culprit from numeric error: {cur_nodes}") + self.print_report(report) + return culprits + except (FxNetMinimizerRunFuncError): + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {node}") + self.print_report(report) + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + + def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + start_idx = 0 + num_nodes = len(all_nodes) + idx = 0 + culprits = set() + while idx < num_nodes: + node = all_nodes[idx] + if (node.name in skip_nodes): # skip the node + if idx > start_idx: + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) + start_idx = idx + 1 + elif idx == num_nodes - 1 and start_idx <= idx: # last node + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1) + idx += 1 + + return culprits + + + + def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: + """ + Collect nodes in the model that between nodes with name of `start` and `end`. + These two nodes are also included. + """ + nodes: NodeList = [] + add_node = start is None + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node.name == start: + add_node = True + + if add_node: + nodes.append(node) + + if node.name == end: + break + + return nodes + + def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): + """ + Run part of the model from `start` node to `end` node. If `start` is None + then we start from the beginning of the model. If `end` is None then we + stop at the end of the model. + + Args: + start: The name of the node which is the first node of the submodule + we want to run. If set to None, then we'll start with the first + node of the model. + end: The name of the node which is the last node of the submodule we + want to run. If set to None, we'll end with the last node of the + model. + """ + nodes = self._collect_nodes(start, end) + cur_nodes = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + output_names = [] + if self.settings.return_intermediate: + output_names = [node.name for node in nodes] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, output_names) + except ( + FxNetMinimizerRunFuncError, + FxNetMinimizerResultMismatchError, + ) as e: + print(e) + + def print_report(self, report: List[str]): + for i in range(len(report)): + if i > 0: + print(" . " + report[i]) + else: + print(report[i]) + + def print_reports(self): + for report in self.reports: + self.print_report(report) + + def minimize( + self, start: Optional[str] = None, end: Optional[str] = None, skip_nodes: Optional[List] = None, + ) -> NodeSet: + """ + Minimizing the model from node with name `start` to node with name `end` base + on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors. + + Args: + start: The name of the node where we want to start minimizing. If set + to None, then we'll start with the first node of the model. + end: The name of the node where we want to terminate minimizing. If + set to None, we'll end with the last node of the model. + + Returns: + nodes: A list of nodes that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors during minimizing. + """ + + print(self.settings) + print(self.module.graph) + + nodes = self._collect_nodes(start, end) + + if self.settings.traverse_method == "sequential": + return self._sequential_traverse(nodes) + + if self.settings.traverse_method == "binary": + return self._binary_traverse(nodes) + + if self.settings.traverse_method == "accumulate": + return self._accumulate_traverse(nodes) + + if self.settings.traverse_method == "skip": + if (skip_nodes is None): + raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.") + return self._skip_traverse(nodes, skip_nodes) + + if self.settings.traverse_method == "defined": + return self._defined_traverse(nodes) + + raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..55d5ea0af54d01942ee310c8136e17d826ea183c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py @@ -0,0 +1,257 @@ +from functools import wraps +from inspect import unwrap +from typing import Callable, List, Optional +import logging + +logger = logging.getLogger(__name__) + +__all__ = [ + "PassManager", + "inplace_wrapper", + "log_hook", + "loop_pass", + "this_before_that_pass_constraint", + "these_before_those_pass_constraint", +] + +# for callables which modify object inplace and return something other than +# the object on which they act +def inplace_wrapper(fn: Callable) -> Callable: + """ + Convenience wrapper for passes which modify an object inplace. This + wrapper makes them return the modified object instead. + + Args: + fn (Callable[Object, Any]) + + Returns: + wrapped_fn (Callable[Object, Object]) + """ + + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + return gm + + return wrapped_fn + +def log_hook(fn: Callable, level=logging.INFO) -> Callable: + """ + Logs callable output. + + This is useful for logging output of passes. Note inplace_wrapper replaces + the pass output with the modified object. If we want to log the original + output, apply this wrapper before inplace_wrapper. + + + ``` + def my_pass(d: Dict) -> bool: + changed = False + if 'foo' in d: + d['foo'] = 'bar' + changed = True + return changed + + pm = PassManager( + passes=[ + inplace_wrapper(log_hook(my_pass)) + ] + ) + ``` + + Args: + fn (Callable[Type1, Type2]) + level: logging level (e.g. logging.INFO) + + Returns: + wrapped_fn (Callable[Type1, Type2]) + """ + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + logger.log(level, "Ran pass %s\t Return value: %s", fn, val) + return val + + return wrapped_fn + + + +def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None): + """ + Convenience wrapper for passes which need to be applied multiple times. + + Exactly one of `n_iter`or `predicate` must be specified. + + Args: + base_pass (Callable[Object, Object]): pass to be applied in loop + n_iter (int, optional): number of times to loop pass + predicate (Callable[Object, bool], optional): + + """ + assert (n_iter is not None) ^ ( + predicate is not None + ), "Exactly one of `n_iter`or `predicate` must be specified." + + @wraps(base_pass) + def new_pass(source): + output = source + if n_iter is not None and n_iter > 0: + for _ in range(n_iter): + output = base_pass(output) + elif predicate is not None: + while predicate(output): + output = base_pass(output) + else: + raise RuntimeError( + f"loop_pass must be given positive int n_iter (given " + f"{n_iter}) xor predicate (given {predicate})" + ) + return output + + return new_pass + + +# Pass Schedule Constraints: +# +# Implemented as 'depends on' operators. A constraint is satisfied iff a list +# has a valid partial ordering according to this comparison operator. +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: List[Callable] +): + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def this_before_that_pass_constraint(this: Callable, that: Callable): + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + """ + + def depends_on(a: Callable, b: Callable): + if a == that and b == this: + return False + return True + + return depends_on + + +def these_before_those_pass_constraint(these: Callable, those: Callable): + """ + Defines a partial order ('depends on' function) where `these` must occur + before `those`. Where the inputs are 'unwrapped' before comparison. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [ + loop_pass(pass_b, 3), + loop_pass(pass_a, 5), + ] + + constraints = [ + these_before_those_pass_constraint(pass_a, pass_b) + ] + ``` + + Args: + these (Callable): pass which should occur first + those (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + if unwrap(a) == those and unwrap(b) == these: + return False + return True + + return depends_on + + +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): list of passes. A pass is a + callable which modifies an object and returns modified object + constraint (Optional[List[Callable]]): list of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + """ + + passes: List[Callable] + constraints: List[Callable] + _validated: bool = False + + def __init__( + self, + passes=None, + constraints=None, + ): + self.passes = passes or [] + self.constraints = constraints or [] + + @classmethod + def build_from_passlist(cls, passes): + pm = PassManager(passes) + # TODO(alexbeloi): add constraint management/validation + return pm + + def add_pass(self, _pass: Callable): + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint): + self.constraints.append(constraint) + self._validated = False + + def remove_pass(self, _passes: List[str]): + if _passes is None: + return + passes_left = [] + for ps in self.passes: + if ps.__name__ not in _passes: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def replace_pass(self, _target, _replacement): + passes_left = [] + for ps in self.passes: + if ps.__name__ == _target.__name__: + passes_left.append(_replacement) + else: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def validate(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def __call__(self, source): + self.validate() + out = source + for _pass in self.passes: + out = _pass(out) + return out diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..4f3eee93eccb4f419df32a4f437c8dbeacbeb11f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py @@ -0,0 +1,514 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING +from collections import OrderedDict +import logging + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + +if TYPE_CHECKING: + import sympy # noqa: F401 + +__all__ = ["Partition", "split_module"] +_LOGGER = logging.getLogger(__name__) + +@compatibility(is_backward_compatible=True) +class Partition: + def __init__(self, name: str): + self.name: str = name + self.submod_name = f"submod_{name}" + self.node_names: List[str] = [] + self.inputs: Dict[str, None] = {} + self.outputs: Dict[str, None] = {} + self.dependencies: Dict[str, None] = {} + self.dependents: Dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: Dict[Node, Node] = {} + self.targets: Dict[str, Any] = {} + + def __repr__(self) -> str: + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions depended on: {self.dependencies},\n" + f" partition dependents: {self.dependents}" + ) + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[Node], int], + qualname_map: Optional[Dict[str, str]] = None, + keep_original_order: Optional[bool] = False, + keep_original_node_name: Optional[bool] = False, +): + """ + Creates subgraphs out of main graph + + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a + mapping from new target names in the module after split to old target + names in the original module. + keep_original_order: Optional[bool]: keep the original order of the GraphModule + or use the Topological order of the new constructed GraphModule + + + Returns: + GraphModule: the module after split. + + Example: + + This is a sample setup: + + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from torch.fx.passes.split_module import split_module + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + + Output looks like this. Original graph is broken into partitions + + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + + Output of split module is the same as output of input traced module. + This is an example within a test setting: + + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + + def construct_graph( + node: Node, + base_mod_env: Dict[str, Node], + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule], + ): + if node.op == "placeholder": + default_value = ( + node.args[0] if len(node.args) > 0 else inspect.Signature.empty + ) + if keep_original_node_name: + args = () if default_value is inspect.Signature.empty else (default_value,) + base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) + else: + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, type_expr=node.type, default_value=default_value + ) + base_mod_env[node.name].meta = node.meta.copy() + elif node.op == "get_attr": + base_mod_env[node.name] = base_mod_graph.get_attr(node.target) + base_mod_env[node.name].meta = node.meta.copy() + attr_val = m + for atom in node.target.split("."): # type: ignore[union-attr] + if not hasattr(attr_val, atom): + raise AttributeError(f"Node target {node.target} not found!") + attr_val = getattr(attr_val, atom) + base_mod_attrs[node.target] = attr_val # type: ignore[index] + return base_mod_env, base_mod_attrs + + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, Node] = {} + symbol_to_node: Dict["sympy.Symbol", Node] = {} + + def record_cross_partition_use( + def_node: Node, use_node: Optional[Node] + ): # noqa: B950 + from torch.fx.experimental.symbolic_shapes import free_symbols + + defined = getattr(def_node, "_fx_partition", None) + used = getattr(use_node, "_fx_partition", None) + if defined != used: + if defined is not None: + def_partition = partitions[defined] + def_partition.outputs.setdefault(def_node.name) + if used is not None: + def_partition.dependents.setdefault(used) + + if used is not None: + use_partition = partitions[used] + use_partition.inputs.setdefault(def_node.name) + if (def_val := def_node.meta.get("example_value")) is not None: + for s in sorted(free_symbols(def_val), key=str): + use_partition.inputs.setdefault(symbol_to_node[s].name) + if defined is not None: + use_partition.dependencies.setdefault(defined) + + def instantiate_node_partition_mapping(node): + partition_name = str(split_callback(node)) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + # Global State Nodes are nodes which by their global state effects, + # "taint" all downstream nodes while they are active. + GLOBAL_STATE_NODES = [ + torch.amp._enter_autocast, + torch.amp._exit_autocast, + torch._C._set_grad_enabled + ] + + # For grad regions: + # ------------------------ + # 1. first region: we do nothing + # 2. subsequent regions: we insert the set_grad at the beginning + grad_regions: OrderedDict[Node, Set[int]] = OrderedDict() + + # For autocast regions: + # ------------------------ + # 1. first region: we will only insert the _exit at the end + # 2. intermediate regions: we will insert both the + # _enter at the beginning and _exit at the end + # 3. last region: we will only insert _enter at the beginning + # We will do so in the order in which the autocasts were instantiated. + autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict() + autocast_exits: Dict[Node, Optional[Node]] = {} + + active_grad = None + active_autocasts = set() + + import sympy # noqa: F811 + + for node in m.graph.nodes: + if node.op in ["placeholder", "get_attr", "output"]: + if ( + node.op == "placeholder" and + (val := node.meta.get("example_value")) is not None and + isinstance(val, torch.SymInt) and + isinstance(val.node.expr, sympy.Symbol) + ): + symbol_to_node[val.node.expr] = node + continue + + instantiate_node_partition_mapping(node) + + if node.op == "call_function" and node.target in GLOBAL_STATE_NODES: + if node.target == torch._C._set_grad_enabled: + assert len(node.args) == 1 + assert isinstance(node.args[0], bool) + active_grad = node + grad_regions[active_grad] = set({split_callback(node)}) + elif node.target == torch.amp._enter_autocast: + # Should all be python constants + assert all(not isinstance(arg, Node) for arg in node.args) + active_autocasts.add(node) + autocast_regions[node] = set({split_callback(node)}) + autocast_exits[node] = None + elif node.target == torch.amp._exit_autocast: + assert len(node.args) == 1 + autocast_regions[node.args[0]].add(split_callback(node)) + active_autocasts.remove(node.args[0]) + autocast_exits[node.args[0]] = node + + if active_grad is not None: + grad_regions[active_grad].add(split_callback(node)) + + for a in active_autocasts: + autocast_regions[a].add(split_callback(node)) + + assert all(v is not None for v in autocast_exits.values()), "autocast must exit" + + autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} + grad_regions = {k: sorted(v) for k, v in grad_regions.items()} + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("autocast_regions: %s", autocast_regions) + _LOGGER.debug("grad_regions: %s", grad_regions) + + assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions) + + # split nodes into partitions + highest_partition = -1 + for node in m.graph.nodes: + orig_nodes[node.name] = node + + # TODO currently placeholders/parameters aren't put into random partitions, + # rather they're added to the graphs where they are used down below + if node.op in ["placeholder", "get_attr"]: + continue + if node.op == "output": + torch.fx.graph.map_arg( + node.args[0], lambda n: record_cross_partition_use(n, None) + ) + continue + + if assert_monotonically_increasing: + pid = split_callback(node) + assert highest_partition <= pid, \ + ("autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}") + highest_partition = pid + + # do not capture cross-partition dependencies for global state nodes as they will be + # self-contained - their setup and unwind will be isolated to each partition submodule. + if node.target not in GLOBAL_STATE_NODES: + torch.fx.graph.map_arg( + node.args, lambda def_node: record_cross_partition_use(def_node, node) + ) + torch.fx.graph.map_arg( + node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) + ) # noqa: B950 + + original_partition_order = list(partitions.keys()) + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.dependencies): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].dependents: + partitions[dependent].dependencies.pop(root_partition) + if not partitions[dependent].dependencies: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # Enter prelude + for regions_mapping in [autocast_regions, grad_regions]: + for node, regions in regions_mapping.items(): + assert len(regions) > 0 + partitions[str(regions[0])].environment[node] = node + for r in regions[1:]: + partition = partitions[str(r)] + new_node = partition.graph.create_node( + op=node.op, + target=node.target, + args=tuple(arg for arg in node.args), + kwargs={}, + type_expr=node.type, + ) + new_node.meta = node.meta.copy() # is it really a good idea to copy this? + partition.environment[node] = new_node + + # add placeholders to partition inputs + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for inp in partition.inputs: + placeholder = partition.graph.placeholder( + inp, + type_expr=orig_nodes[inp].type, + ) + placeholder.meta = orig_nodes[inp].meta.copy() + partition.environment[orig_nodes[inp]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, "_fx_partition"): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg( + node.kwargs, lambda n: environment[n] + ) + + if node.op not in ["call_module", "get_attr"]: + target = node.target + else: + target_atoms = node.target.split(".") + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise AttributeError(f"Operator target {node.target} not found!") + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = "_".join(target_atoms) + partition.targets[target] = target_attr + # Fill in the passed-in mapping from new qualname to old qualname + if qualname_map is not None: + # When creating the split module later, the submodules will have + # path prefix matching the corresponding partition's submod_name + qualname = f"{partition.submod_name}.{target}" + qualname_map[qualname] = node.target + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + name = node.name if keep_original_node_name else None + new_node = partition.graph.create_node( + op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + type_expr=node.type, + name=name, + ) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Exit epilogue + for regions_mapping in [autocast_regions]: + for node in reversed(regions_mapping): + regions = regions_mapping[node] + assert len(regions) > 0 + for r in regions[:-1]: + partition = partitions[str(r)] + exit_node = autocast_exits[node] + assert exit_node is not None, "Missing exit node" + new_node = partition.graph.create_node( + op=exit_node.op, + target=exit_node.target, + args=(partition.environment[node],), + kwargs={}, + type_expr=exit_node.type, + ) + new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this? + + # original module environment dict mapping node names to nodes + orig_mod_env: Dict[str, Node] = {} + # Set up values to construct base module + base_mod_env: Dict[str, Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + if not keep_original_order: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + else: + # Go through the graph to construct the mapping dict + for node in m.graph.nodes: + orig_mod_env[node.name] = node + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order or original order specified by keep_original_order + + construct_order_partitions = ( + sorted_partitions if not keep_original_order else original_partition_order + ) + + already_constructed_attr_nodes = set() + for partition_name in construct_order_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple( + partition.environment[orig_nodes[name]] for name in partition.outputs + ) + + # skip output node generation if there are no output values + num_output_vals = len(output_vals) + if num_output_vals == 1: + partition.graph.output(output_vals[0]) + elif num_output_vals > 1: + partition.graph.output(output_vals) + + if keep_original_order: + # first get the attr nodes required by this partition + orig_mod_attr_nodes: List[Node] = [ + orig_mod_env[key] for key in partition.inputs + ] + # Construct GraphModule for this partition + for node in orig_mod_attr_nodes: # type: ignore[attr-defined] + if node in already_constructed_attr_nodes: + continue + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module( + partition.submod_name, + tuple(base_mod_env[name] for name in partition.inputs), + ) + + num_outputs = len(partition.outputs) + if num_outputs > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + elif num_outputs == 1: + base_mod_env[next(iter(partition.outputs))] = output_val + + for node in m.graph.nodes: + if node.op == "output": + base_mod_graph.output( + torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) + ) # noqa: B950 + + return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3a493f4af335d0cfc0a1f17cf87c3eb8a629f0db --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py @@ -0,0 +1,871 @@ +import argparse +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple +import logging + +import torch +from torch.fx.passes.graph_manipulation import get_size_of_node +from torch.fx.node import map_arg +from torch.fx._compatibility import compatibility + +from .operator_support import ( + get_node_target, + OperatorSupportBase, +) +from .graph_drawer import FxGraphDrawer +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + FxNetAccFusionsFinder, + CALLABLE_NODE_OPS, + Tensors, + NodeList, + NodeSet, + is_node_output_tensor, +) + + +__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] +_LOGGER = logging.getLogger(__name__) + +DEFAULT_MIN_ACC_MODULE_SIZE = 1 +DEFAULT_SKIP_FUSION = False +DEFAULT_ALLOW_NON_TENSOR = False + +class _SplitterSettingBase: + def __init__( + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion=DEFAULT_SKIP_FUSION, + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR + ): + parser = argparse.ArgumentParser() + parser.add_argument( + "--min-acc-module-size", + "--min_acc_module_size", + required=False, + type=int, + help="Minimum size limit of an accelerator subgraph.", + ) + parser.add_argument( + "--skip-fusion", + "--skip_fusion", + default=False, + action="store_true", + help="If true then no fusion groups. Fusion group is used to " + "enforce no non-tensor data flow between submodules. If we don't " + "have this constrain, setting this to false is recommended as it " + "can reduce overhead.", + ) + parser.add_argument( + "--allow-non-tensor", + "--allow_non_tensor", + default=False, + action="store_true", + help="For some backends non-tensor data flow between cpu and them " + "are not allowed. Therefore, if a node supported by accelerator but " + "it has non-tensor inputs or outputs to a cpu node we would want to " + "consider it as a cpu node during splitting. However, for some backends " + "we might not care about non-tensor data flow and we can set this option " + "to true to disable the functionality that prevent non-tensor data flow.", + ) + args, unknown = parser.parse_known_args() + + self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size + self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion + self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + + def reduce_acc_nodes_non_tensor_input_helper( + self, cpu_worklist: NodeList + ): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + self.acc_nodes = { + n + for n in self.module.graph.nodes + if n.op in CALLABLE_NODE_OPS + and self.operator_support.is_node_supported(submodules, n) + } + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + +@compatibility(is_backward_compatible=False) +class FxNetSplitterInternalError(Exception): + pass + +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + nodes: NodeList + + +@compatibility(is_backward_compatible=False) +class SplitResult(NamedTuple): + """ + Stores the results of the splitter. + + Attributes: + split_module: root module after splitting. + submodule_inputs: a dict that maps submodule name to its inputs. + non_acc_submodule_prefix: the prefix for non acc submodules. For + acc submodule the prefix is alwasy "_run_on_acc_". + """ + + split_module: torch.fx.GraphModule + submodule_inputs: Dict[str, Any] + non_acc_submodule_prefix: str + + +@compatibility(is_backward_compatible=False) +def generate_inputs_for_submodules( + model: torch.nn.Module, + inputs: Sequence[Any], + target_submodules: Iterable[str], + deepcopy: bool = False, +) -> Dict[str, Any]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_inputs): + results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs + + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append(mod.register_forward_pre_hook(pre_forward)) + + def clean_up_handles(): + for h in handles: + h.remove() + + try: + with torch.no_grad(): + model(*inputs) + except Exception as e: + clean_up_handles() + raise e + + clean_up_handles() + return results + + +class _SplitterBase: + """ + Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. + Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. + Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. + + Given the following graph: + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, + we will get the following split result: + + main: + def forward(self, a): + run_on_acc_0_0 = self._run_on_acc_0_0(a) + getitem = run_on_acc_0_0[0] + getitem_1 = run_on_acc_0_0[1] + run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) + return run_on_cpu_1_1 + + _run_on_acc_0_0: + def forward(self, a): + sin_1 = torch.sin(a) + cos_1 = torch.cos(a) + return (sin_1, cos_1) + + _run_on_cpu_1_1: + def forward(self, sin_1, cos_1): + add_1 = sin_1 + cos_1 + return add_1 + """ + + # PCIe bandwidth for the backend, default to 100 GB/s + PCIe_BW = 100 * 2 ** 30 + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Sequence[Any], + operator_support: OperatorSupportBase, + settings: _SplitterSettingBase, + non_acc_submodule_name: str = "_run_on_cpu_", + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + ShapeProp(self.module).propagate(*sample_input) + + self.settings = settings + self.operator_support = operator_support + self.sample_input = sample_input + self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = non_acc_submodule_name + self._node_submodule_map: Dict[str, str] = {} + + # =============================================================== + # Helpers for ctor and initial state + # =============================================================== + + def get_node_submodule_map(self) -> Dict[str, str]: + """ Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 + """ + return self._node_submodule_map + + def find_deps(self) -> Dict[torch.fx.Node, NodeSet]: + """ + Builds a graph of node dependencies. Leaf nodes don't have any + dependencies and the "output" node doesn't have nodes depending on it. + + Resulting graph has only direct dependencies, i.e. there are no + transitive dependencies. + """ + deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op != "output": + deps[user].add(node) + return deps + + def update_deps_for_fusions(self): + """ + Updates graph of dependencies so that: + - nodes from the same fusion depend on the same set of outer nodes, + - outer nodes depending on a fusion depend on all nodes in that fusion. + """ + for node in self.fusions: + fusion = self.fusions[node] + for fused_neighbor in fusion: + self.deps[node].update(self.deps[fused_neighbor] - fusion) + + for user in fused_neighbor.users: + if user not in fusion: + self.deps[user].add(node) + + # =============================================================== + # Helpers for preview + # =============================================================== + + def _lower_model_to_backend( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> torch.nn.Module: + """ + Lower the model to a backend. + """ + + return mod + + def _find_culprit( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> str: + """ + When an error occurs during lowering or running the lowered mod, we use this + function to find culprits in the `mod` that causes the error. + """ + + return "Unable to find a culprit because _find_culprit() function is not implemented." + + def _draw_graph_based_on_node_support( + self, mod: torch.fx.GraphModule, supported_nodes: NodeList + ): + color_map = { + "default": "AliceBlue", + "supported": "chartreuse1", + "unsupported": "crimson", + } + + class CustomDrawer(FxGraphDrawer): + def _get_node_style(self, node): + template = super()._get_node_style(node) + if node in supported_nodes: + template["fillcolor"] = color_map["supported"] + elif node.op in CALLABLE_NODE_OPS: + template["fillcolor"] = color_map["unsupported"] + else: + template["fillcolor"] = color_map["default"] + + return template + + drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) + dot_graph = drawer.get_main_dot_graph() + dot_graph.write_raw("node_support.dot") + + def node_support_preview(self, dump_graph: bool = False): + submodules = dict(self.module.named_modules()) + + supported_nodes: NodeList = [] + supported_node_types = defaultdict(set) + unsupported_node_types = defaultdict(set) + + def get_dtype(arg): + tensor_meta = arg.meta.get("tensor_meta") + return getattr(tensor_meta, "dtype", None) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + target = get_node_target(submodules, node) + + # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. + arg_dtypes = [ + get_dtype(arg) if isinstance(arg, torch.fx.Node) else None + for arg in node.args + ] + + # Find last non-None element. If all elements are None, return max_len. + last_index = len(arg_dtypes) - next( + ( + i + for i, dtype in enumerate(reversed(arg_dtypes)) + if dtype is not None + ), + len(arg_dtypes), + ) + + # Strip None elements at the end. + arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) + kwarg_dtypes_tuple = tuple( + (k, get_dtype(arg)) + for k, arg in node.kwargs.items() + if isinstance(arg, torch.fx.Node) + ) + + if self.operator_support.is_node_supported(submodules, node): + supported_nodes.append(node) + supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + else: + unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + + if dump_graph: + self._draw_graph_based_on_node_support(self.module, supported_nodes) + + reports = "\nSupported node types in the model:\n" + for t, dtypes in supported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + reports += "\nUnsupported node types in the model:\n" + for t, dtypes in unsupported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + print(reports) + + # Return reports for testing purpose + return reports + + def split_preview(self, dump_graph: bool = False): + reports = "" + subgraphs = self.put_nodes_into_subgraphs() + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + for i, subgraph in enumerate(subgraphs): + reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " + reports += f"{len(subgraph.nodes)} node(s)\n" + + self.tag(subgraphs) + split_mod = self.split(remove_tag=True) + split_mod.eval() + + if dump_graph: + drawer = FxGraphDrawer( + split_mod, "preview", ignore_getattr=True + ) + dot_graphs = drawer.get_all_dot_graphs() + for name, dot_graph in dot_graphs.items(): + dot_graph.write_raw(f"{name}.dot") + + max_qps: float = self.PCIe_BW + bottleneck_module = "" + + for node in split_mod.graph.nodes: + if node.op == "call_module" and "acc" in node.target: + reports += f"\nProcessing acc submodule {node.target}\n" + + submod = getattr(split_mod, node.target) + + def get_submod_inputs(main_mod, submod, example_inputs): + sub_inputs = None + + def get_inputs(self, inputs): + nonlocal sub_inputs + sub_inputs = inputs + + handle = submod.register_forward_pre_hook(get_inputs) + main_mod(*example_inputs) + handle.remove() + return sub_inputs + + submod_inputs = get_submod_inputs( + split_mod, submod, self.sample_input + ) + ShapeProp(submod).propagate(*submod_inputs) + + total_input_bytes = 0 + total_output_bytes = 0 + + reports += "Checking inputs...\n" + for n in submod.graph.nodes: + if n.op == "placeholder": + if not is_node_output_tensor(n): + reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_input_bytes += get_size_of_node(submod, n)[0] + if n.op == "output": + output_node = n + + reports += "Checking outputs...\n" + + def get_bytes(node: torch.fx.Node): + nonlocal total_output_bytes + nonlocal reports + if not is_node_output_tensor(node): + reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_output_bytes += get_size_of_node(submod, node)[0] + + map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] + qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) + reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," + reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" + + if qps < max_qps: + max_qps = qps + bottleneck_module = node.target + + try: + lowered_submod = self._lower_model_to_backend(submod, submod_inputs) + except RuntimeError: + reports += "Run into an error during lowering!\n" + reports += self._find_culprit(submod, submod_inputs) + continue + + try: + lowered_submod(*submod_inputs) + except RuntimeError: + reports += "Run into an error during inference!\n" + reports += self._find_culprit(submod, submod_inputs) + else: + reports += "Lowering and running succeed!\n" + + reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," + reports += f" bottleneck is submodule {bottleneck_module}." + print(reports) + + # return the reports for testing purposes + return reports + + # =============================================================== + # Helpers for extend_acc_subgraph() method + # =============================================================== + + def find_reverse_deps( + self, tag_id: Optional[int] = None + ) -> Dict[torch.fx.Node, NodeSet]: + """ + Builds reversed topological node dependencies, if tag_id is specified, + we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. + """ + result: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): + result[node].add(user) + + return result + + def update_reverse_deps_for_fusions( + self, deps: Dict[torch.fx.Node, NodeSet] + ): + processed_node = set() + + for node, fusion in self.fusions.items(): + if node in processed_node: + continue + + new_dep = set() + + # Create a new dependency set which include all the + # dependencies of the nodes in the fusion group + for n in fusion: + new_dep.update(deps[n]) + + # Exclude nodes in the fusion + new_dep.difference_update(fusion) + + # Update dependency + for n in fusion: + deps[n] = new_dep + + for arg in n.all_input_nodes: + if arg not in fusion: + deps[arg].update(fusion) + + processed_node.add(n) + + def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: + """ + Finds parent nodes of the `tag` subgraph. + + Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph + and is not a placeholder, we consider it as the parent node of the subgraph. + """ + parent_nodes = set() + + for node in self.module.graph.nodes: + if node.op in CALLABLE_NODE_OPS and node.tag == tag: + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: + parent_nodes.add(arg) + + return parent_nodes + + def extend_acc_subgraph(self, tag: str): + """ + Extend the acc subgraph with `tag` going the reversed topological direction. + """ + # Dict that maps node to its users and ignore users that + # are in the subgraph that has greater tag + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + self.update_reverse_deps_for_fusions(deps) + + # Parent nodes of the subgraph + parent_nodes = self.find_parent_nodes_of_subgraph(tag) + + visited_nodes: NodeSet = set() + + while parent_nodes: + node = None + + # Find a acc node that depends on visited nodes only + for n in parent_nodes: + if deps[n] <= visited_nodes and n in self.acc_nodes: + node = n + break + + if node is None: + break + + # Put the node into `tag` subgraph + node.tag = tag # type: ignore[attr-defined] + parent_nodes.remove(node) + visited_nodes.add(node) + + # If node is in a fusion group, add all fusion buddies to parent nodes + if node in self.fusions: + for fusion_node in self.fusions[node]: + if fusion_node not in visited_nodes: + parent_nodes.add(fusion_node) + + # Add inputs of the node to parent nodes + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: + parent_nodes.add(arg) + + # =============================================================== + # Helpers for split() method + # =============================================================== + + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: + """ + Finds nodes that consume module inputs or get_attr nodes. + """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + return starter_cpu_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> List[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: List[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If nothing was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent CPU submodules. + """ + result: List[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size: + result.append(subgraph) + else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def tag(self, subgraphs: List[Subgraph]): + self.tags: List[str] = [] + for subgraph in subgraphs: + tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: + split_module = split_by_tags(self.module, self.tags) + if remove_tag: + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + return split_module + + def __call__(self) -> torch.fx.GraphModule: + subgraphs = self.put_nodes_into_subgraphs() + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") + self.tag(subgraphs) + return self.split() + + def generate_split_results(self) -> SplitResult: + split_module = self() + submodule_names = [] + for name, mod in split_module.named_children(): + submodule_names.append(name) + submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) + return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5b7d87b8e420bb72dffe30c7535a30052e3c66 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py @@ -0,0 +1,273 @@ +from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional +import collections +from dataclasses import dataclass + +import torch +import torch.fx +from torch.fx.node import _get_qualified_name +from torch.fx._compatibility import compatibility + +__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] + +Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] +TensorOrTensors = Union[torch.Tensor, Tensors] +NodeList = List[torch.fx.Node] +NodeSet = Set[torch.fx.Node] +Names = List[str] +CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} + + +@compatibility(is_backward_compatible=False) +def get_acc_ops_name(k): + if isinstance(k, str): + return k + elif k.__module__ and "acc_ops" in k.__module__: + return f"acc_ops.{k.__name__}" + else: + module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + return f"{module if module else ''}.{k.__name__}" + + +@compatibility(is_backward_compatible=False) +def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str: + """ + Given a `node` returns its target typename. + + For "call_method" node, return node.target which is the name of that method being called. + This could potential lead to conflict but should be okay because normally it's on a tensor. + + For "call_function" node, return typename of node.target. + + For "call_module" node, return typename of the module that node.target point to. + + If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by + "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. + """ + + assert node.op in CALLABLE_NODE_OPS, ( + "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" + ) + + if node.op == "call_module": + assert isinstance(node.target, str) + submod = submodules[node.target] + submod_type = getattr(submod, "_base_class_origin", type(submod)) + return get_acc_ops_name(submod_type) + elif node.op == "call_function": + target: Any = node.target + return ( + f"acc_ops.{target.__name__}" + if target.__module__ is not None and "acc_ops" in target.__module__ + else _get_qualified_name(target) + ) + else: + assert isinstance(node.target, str) + return node.target + +@compatibility(is_backward_compatible=False) +def is_node_output_tensor(node: torch.fx.Node) -> bool: + """Checks if the node output produces a Tensor or not. + + NOTE: This requires to run `ShapeProp` on the containing fx graph before + calling this function. This is because it works by checking the `type` + metadata on the node. This metadata is produced by the `ShapeProp`. + """ + type_ = node.meta.get("type", None) + return type_ is not None and issubclass(type_, torch.Tensor) + +@compatibility(is_backward_compatible=False) +class FxNetAccFusionsFinder: + """ + Finds groups of connected ACC nodes that pass non-tensor data between each other. + Such groups are called fusion groups. + """ + + def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet): + self.module = module + self.nodes = list(module.graph.nodes) + self.acc_nodes = acc_nodes + + @dataclass + class FusionGroup: + # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. + top_node_idx: int + + # Nodes in this fusion group. + nodes: NodeSet + + # Inputs to this fusion group. + inputs: NodeSet + + # Nodes that in the fusion group that haven't been processed yet. + nodes_need_process: NodeSet + + def add_node(self, node): + """ + Add a node to fusion group. + """ + if node in self.nodes: + return + + self.nodes_need_process.add(node) + self.nodes.add(node) + self.inputs.discard(node) + self.inputs.update( + { + n + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS and n not in self.nodes + } + ) + + def recursive_add_node( + self, + fusion_group: "FxNetAccFusionsFinder.FusionGroup", + inputs: Union[NodeSet, NodeList], + visited: Optional[NodeSet] = None, + ): + """ + Start from inputs and going reverse topological order. If any upstream node + is in the fusion group, add all the nodes in this path to fusion group. + """ + for arg in inputs: + # skip the node if already seen + if visited is not None: + if arg in visited: + continue + visited.add(arg) + + # Skip placeholder and get_attr because they won't be in the fusion group. + if arg.op not in CALLABLE_NODE_OPS: + continue + + # If the node has smaller idx, it's already an upstream node of the fusion + # group. We don't need to check it anymore. + if self.nodes.index(arg) < fusion_group.top_node_idx: + continue + + # If the node is in the fusion group, return True. + if arg in fusion_group.nodes: + return True + + # Check the upstream nodes of the node, if any of them is in the fusion group + # we'll add this node to fusion group and return True. + if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited): + fusion_group.add_node(arg) + return True + + return False + + def __call__(self) -> Dict[torch.fx.Node, NodeSet]: + result: Dict[torch.fx.Node, NodeSet] = {} + acc_nodes = list(self.acc_nodes) + + for node in acc_nodes: + if node in result: + continue + if node.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in node.meta: + continue + if node not in self.acc_nodes: + continue + + fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup( + top_node_idx=self.nodes.index(node), + nodes={node}, + inputs=set(node.all_input_nodes), + nodes_need_process={node}, + ) + while fusion_group.nodes_need_process: + node = fusion_group.nodes_need_process.pop() + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Optionally add downstream nodes + if "tensor_meta" not in node.meta: + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + if user in fusion_group.nodes: + continue + + fusion_group.add_node(user) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Add some upstream nodes + for arg in node.all_input_nodes: + if arg.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in arg.meta: + continue + if arg in fusion_group.nodes: + continue + + fusion_group.add_node(arg) + fusion_group.top_node_idx = min( + fusion_group.top_node_idx, self.nodes.index(arg) + ) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + if not (set(fusion_group.nodes) <= self.acc_nodes): + self.acc_nodes -= fusion_group.nodes + else: + for n in fusion_group.nodes: + result[n] = fusion_group.nodes + + return result + + +@compatibility(is_backward_compatible=False) +def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + Returns: + The graph module in-place sorted + """ + indeg = dict.fromkeys(gm.graph.nodes, 0) + new_graph = torch.fx.Graph() + # Track how many unfulfilled dependencies each node has + for node in gm.graph.nodes: + for user in node.users: + indeg[user] += 1 + queue: collections.deque = collections.deque() + # Add all nodes with no dependencies to the queue + for node in gm.graph.nodes: + if indeg[node] == 0: + queue.append(node) + env: Dict[torch.fx.Node, torch.fx.Node] = {} + # Pop nodes from the queue, and add nodes that have had all their + # dependencies fulfilled + while len(queue) > 0: + cur = queue.popleft() + env[cur] = new_graph.node_copy(cur, lambda x: env[x]) + for user in cur.users: + indeg[user] -= 1 + if indeg[user] == 0: + queue.append(user) + # If the new graph's size is not as large as the old one, then there must be + # a cycle (i.e. some node's dependencies were not satisfied.) + if len(new_graph.nodes) < len(gm.graph.nodes): + raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") + new_graph._codegen = gm.graph._codegen + gm.graph = new_graph + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py new file mode 100644 index 0000000000000000000000000000000000000000..a08c7b3141001613c1aa22851dde97f4dd48275c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py @@ -0,0 +1,88 @@ +"""Functionality for Python <-> C++ frontend inter-op.""" + +from torch import nn + + +class OrderedDictWrapper: + """A wrapper around a C++ OrderedDict. + + It dynamically evaluates the OrderedDict getter on a bound C++ module, such + that new changes on the C++ side are picked up. Otherwise accessing e.g. + ``cpp_module._parameters`` just once would get a frozen copy of the parameters + at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__`` + so using properties does not work. + """ + + def __init__(self, cpp_module, attr): + self.cpp_module = cpp_module + self.attr = attr + + @property + def cpp_dict(self): + return getattr(self.cpp_module, self.attr) + + # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we + # must manually override them. + + def items(self): + return self.cpp_dict.items() + + def keys(self): + return self.cpp_dict.keys() + + def values(self): + return self.cpp_dict.values() + + def __iter__(self): + return self.cpp_dict.__iter__() + + def __len__(self): + return self.cpp_dict.__len__() + + def __contains__(self, key): + return self.cpp_dict.__contains__(key) + + def __getitem__(self, key): + return self.cpp_dict.__getitem__(key) + + +class ModuleWrapper(nn.Module): + """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" + + def __init__(self, cpp_module): + # Assign before the super class constructor so ``self.training`` can be + # assigned to in the super class constructor. + self.cpp_module = cpp_module + super().__init__() + self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment] + self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment] + self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment] + for attr in dir(cpp_module): + # Skip magic methods and the three attributes above. + if not attr.startswith("_"): + setattr(self, attr, getattr(self.cpp_module, attr)) + + def _apply(self, fn, recurse=True): + for param in self.parameters(): + # Tensors stored in modules are graph leaves, and we don't + # want to create copy nodes, so we have to unpack the data. + param.data = fn(param.data) + if param._grad is not None: + param._grad.data = fn(param._grad.data) + + for buf in self.buffers(): + buf.data = fn(buf.data) + + return self + + # nn.Module defines training as a boolean + @property # type: ignore[override] + def training(self): + return self.cpp_module.training + + @training.setter + def training(self, mode): + self.cpp_module.train(mode) + + def __repr__(self): + return self.cpp_module.__repr__() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..a2107a00d678a7b79f3e1eb0b0abb87d3c0a268c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py @@ -0,0 +1,5512 @@ +"""Functional interface.""" +from typing import Callable, List, Optional, Tuple, Union +import math +import warnings +import importlib + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +import torch +from torch import _VF +from torch import sym_int as _sym_int +from torch._C import _infer_size, _add_docstr +from torch._torch_docs import reproducibility_notes, tf32_notes, sparse_support_notes +# A workaround to support both TorchScript and MyPy: +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from torch.types import _dtype as DType +else: + # The JIT doesn't understand Union, nor torch.dtype here + DType = int + +from .._jit_internal import boolean_dispatch, _overload, BroadcastingList1, BroadcastingList2, BroadcastingList3 +from ..overrides import ( + has_torch_function, has_torch_function_unary, has_torch_function_variadic, + handle_torch_function) +from . import _reduction as _Reduction +from . import grad # noqa: F401 +from .modules import utils +from .modules.utils import _single, _pair, _triple, _list_with_default + +Tensor = torch.Tensor + +conv1d = _add_docstr( + torch.conv1d, + r""" +conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 1D convolution over an input signal composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or + a one-element tuple `(sW,)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a one-element tuple `(padW,)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + dilation: the spacing between kernel elements. Can be a single number or + a one-element tuple `(dW,)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> inputs = torch.randn(33, 16, 30) + >>> filters = torch.randn(20, 16, 5) + >>> F.conv1d(inputs, filters) +""", +) + +conv2d = _add_docstr( + torch.conv2d, + r""" +conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 2D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: ``None`` + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dH, dW)`. Default: 1 + groups: split input into groups, both :math:`\text{in\_channels}` and :math:`\text{out\_channels}` + should be divisible by the number of groups. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> filters = torch.randn(8, 4, 3, 3) + >>> inputs = torch.randn(1, 4, 5, 5) + >>> F.conv2d(inputs, filters, padding=1) +""", +) # noqa: E501 + +conv3d = _add_docstr( + torch.conv3d, + r""" +conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor + +Applies a 3D convolution over an input image composed of several input +planes. + +{tf32_note} + +See :class:`~torch.nn.Conv3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} + +Note: + This operator supports complex data types i.e. ``complex32, complex64, complex128``. +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias tensor of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple `(sT, sH, sW)`. Default: 1 + padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, + single number or a tuple `(padT, padH, padW)`. Default: 0 + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the same shape as the input. However, this mode + doesn't support any stride values other than 1. + + .. warning:: + For ``padding='same'``, if the ``weight`` is even-length and + ``dilation`` is odd in any dimension, a full :func:`pad` operation + may be needed internally. Lowering performance. + + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by + the number of groups. Default: 1 + +Examples:: + + >>> filters = torch.randn(33, 16, 3, 3, 3) + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> F.conv3d(inputs, filters) +""", +) # noqa: E501 + +conv_transpose1d = _add_docstr( + torch.conv_transpose1d, + r""" +conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 1D transposed convolution operator over an input signal +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose1d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sW,)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padW,)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dW,)``. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50) + >>> weights = torch.randn(16, 33, 5) + >>> F.conv_transpose1d(inputs, weights) +""", +) + +conv_transpose2d = _add_docstr( + torch.conv_transpose2d, + r""" +conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 2D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution". + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose2d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple ``(out_padH, out_padW)``. + Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple ``(dH, dW)``. Default: 1 + +Examples:: + + >>> # With square kernels and equal stride + >>> inputs = torch.randn(1, 4, 5, 5) + >>> weights = torch.randn(4, 8, 3, 3) + >>> F.conv_transpose2d(inputs, weights, padding=1) +""", +) # noqa: E501 + +conv_transpose3d = _add_docstr( + torch.conv_transpose3d, + r""" +conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1) -> Tensor + +Applies a 3D transposed convolution operator over an input image +composed of several input planes, sometimes also called "deconvolution" + +{tf32_note} + +See :class:`~torch.nn.ConvTranspose3d` for details and output shape. + +Note: + {cudnn_reproducibility_note} +""".format( + **reproducibility_notes, **tf32_notes + ) + + r""" + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iT , iH , iW)` + weight: filters of shape :math:`(\text{in\_channels} , \frac{\text{out\_channels}}{\text{groups}} , kT , kH , kW)` + bias: optional bias of shape :math:`(\text{out\_channels})`. Default: None + stride: the stride of the convolving kernel. Can be a single number or a + tuple ``(sT, sH, sW)``. Default: 1 + padding: ``dilation * (kernel_size - 1) - padding`` zero-padding will be added to both + sides of each dimension in the input. Can be a single number or a tuple + ``(padT, padH, padW)``. Default: 0 + output_padding: additional size added to one side of each dimension in the + output shape. Can be a single number or a tuple + ``(out_padT, out_padH, out_padW)``. Default: 0 + groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the + number of groups. Default: 1 + dilation: the spacing between kernel elements. Can be a single number or + a tuple `(dT, dH, dW)`. Default: 1 + +Examples:: + + >>> inputs = torch.randn(20, 16, 50, 10, 20) + >>> weights = torch.randn(16, 33, 3, 3, 3) + >>> F.conv_transpose3d(inputs, weights) +""", +) # noqa: E501 + +conv_tbc = _add_docstr( + torch.conv_tbc, + r""" +Applies a 1-dimensional sequence convolution over an input sequence. +Input and output dimensions are (Time, Batch, Channels) - hence TBC. + +Args: + input: input tensor of shape :math:`(\text{sequence length} \times batch \times \text{in\_channels})` + weight: filter of shape (:math:`\text{kernel width} \times \text{in\_channels} \times \text{out\_channels}`) + bias: bias of shape (:math:`\text{out\_channels}`) + pad: number of timesteps to pad. Default: 0 +""", +) + + +# Pooling +avg_pool1d = _add_docstr( + torch.avg_pool1d, + r""" +avg_pool1d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True) -> Tensor + +Applies a 1D average pooling over an input signal composed of several +input planes. + +See :class:`~torch.nn.AvgPool1d` for details and output shape. + +Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)` + kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padW,)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` to compute the + output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + +Examples:: + + >>> # pool of square window of size=3, stride=2 + >>> input = torch.tensor([[[1, 2, 3, 4, 5, 6, 7]]], dtype=torch.float32) + >>> F.avg_pool1d(input, kernel_size=3, stride=2) + tensor([[[ 2., 4., 6.]]]) + +""", +) + + +avg_pool2d = _add_docstr( + torch._C._nn.avg_pool2d, + r""" +avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size +:math:`sH \times sW` steps. The number of output features is equal to the number of +input planes. + +See :class:`~torch.nn.AvgPool2d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None +""", +) + +avg_pool3d = _add_docstr( + torch._C._nn.avg_pool3d, + r""" +avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) -> Tensor + +Applies 3D average-pooling operation in :math:`kT \times kH \times kW` regions by step +size :math:`sT \times sH \times sW` steps. The number of output features is equal to +:math:`\lfloor\frac{\text{input planes}}{sT}\rfloor`. + +See :class:`~torch.nn.AvgPool3d` for details and output shape. + +Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iT \times iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padT, padH, padW)`, Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape + count_include_pad: when True, will include the zero-padding in the + averaging calculation + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None +""", +) + + +def fractional_max_pool2d_with_indices( + input: Tensor, kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 2D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k`) + or a tuple `(kH, kW)` + output_size: the target output size of the image of the form :math:`oH \times oW`. + Can be a tuple `(oH, oW)` or a single number :math:`oH` for a square image :math:`oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool2d`. + + Examples:: + >>> input = torch.randn(20, 16, 50, 32) + >>> # pool of square window of size=3, and target output size 13x12 + >>> F.fractional_max_pool2d(input, 3, output_size=(13, 12)) + >>> # pool of square window and target output size being half of input image size + >>> F.fractional_max_pool2d(input, 3, output_ratio=(0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError("fractional_max_pool2d requires specifying either an output_size or an output_ratio") + if output_size is None: + assert output_ratio is not None + if len(output_ratio) > 2: + raise ValueError("fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints.") + _output_ratio = _pair(output_ratio) + output_size = [int(input.size(-2) * _output_ratio[0]), int(input.size(-1) * _output_ratio[1])] + + if _random_samples is None: + n_batch = 1 if input.dim() == 3 else input.size(0) + _random_samples = torch.rand(n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device) + return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) + + +def _fractional_max_pool2d( + input: Tensor, kernel_size: BroadcastingList2[int], + output_size: Optional[BroadcastingList2[int]] = None, + output_ratio: Optional[BroadcastingList2[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool2d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool2d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool2d_with_indices, + if_false=_fractional_max_pool2d, + module_name=__name__, + func_name="fractional_max_pool2d", +) + + +def fractional_max_pool3d_with_indices( + input: Tensor, kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) + + Applies 3D fractional max pooling over an input signal composed of several input planes. + + Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham + + The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic + step size determined by the target output size. + The number of output features is equal to the number of input planes. + + Args: + kernel_size: the size of the window to take a max over. + Can be a single number :math:`k` (for a square kernel of :math:`k \times k \times k`) + or a tuple `(kT, kH, kW)` + output_size: the target output size of the form :math:`oT \times oH \times oW`. + Can be a tuple `(oT, oH, oW)` or a single number :math:`oH` for a cubic output + :math:`oH \times oH \times oH` + output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given. + This has to be a number or tuple in the range (0, 1) + return_indices: if ``True``, will return the indices along with the outputs. + Useful to pass to :func:`~torch.nn.functional.max_unpool3d`. + + Shape: + - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`. + - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or + :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})` + + Examples:: + >>> input = torch.randn(20, 16, 50, 32, 16) + >>> # pool of cubic window of size=3, and target output size 13x12x11 + >>> F.fractional_max_pool3d(input, 3, output_size=(13, 12, 11)) + >>> # pool of cubic window and target output size being half of input size + >>> F.fractional_max_pool3d(input, 3, output_ratio=(0.5, 0.5, 0.5)) + + .. _Fractional MaxPooling: + http://arxiv.org/abs/1412.6071 + """ + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d_with_indices, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + if output_size is None and output_ratio is None: + raise ValueError("fractional_max_pool3d requires specifying either an output_size or an output_ratio") + if output_size is None: + assert output_ratio is not None + _output_ratio = _triple(output_ratio) + output_size = [ + int(input.size(-3) * _output_ratio[0]), + int(input.size(-2) * _output_ratio[1]), + int(input.size(-1) * _output_ratio[2]), + ] + + if _random_samples is None: + n_batch = 1 if input.dim() == 4 else input.size(0) + _random_samples = torch.rand(n_batch, input.size(-4), 3, dtype=input.dtype, device=input.device) + return torch._C._nn.fractional_max_pool3d(input, kernel_size, output_size, _random_samples) + + +def _fractional_max_pool3d( + input: Tensor, kernel_size: BroadcastingList3[int], + output_size: Optional[BroadcastingList3[int]] = None, + output_ratio: Optional[BroadcastingList3[float]] = None, + return_indices: bool = False, + _random_samples: Optional[Tensor] = None +) -> Tensor: + if has_torch_function_variadic(input, _random_samples): + return handle_torch_function( + fractional_max_pool3d, + (input, _random_samples), + input, + kernel_size, + output_size=output_size, + output_ratio=output_ratio, + return_indices=return_indices, + _random_samples=_random_samples, + ) + return fractional_max_pool3d_with_indices( + input, kernel_size, output_size, output_ratio, return_indices, _random_samples + )[0] + + +fractional_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=4, + default=False, + if_true=fractional_max_pool3d_with_indices, + if_false=_fractional_max_pool3d, + module_name=__name__, + func_name="fractional_max_pool3d", +) + + +def max_pool1d_with_indices( + input: Tensor, kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 1D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool1d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool1d` for details. + + Args: + input: input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`, minibatch dim optional. + kernel_size: the size of the window. Can be a single number or a + tuple `(kW,)` + stride: the stride of the window. Can be a single number or a tuple + `(sW,)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool1d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + + +def _max_pool1d( + input: Tensor, kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + dilation: BroadcastingList1[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool1d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool1d_with_indices, + if_false=_max_pool1d, + module_name=__name__, + func_name="max_pool1d", +) + + +def max_pool2d_with_indices( + input: Tensor, kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 2D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool2d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool2d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool2d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + + +def _max_pool2d( + input: Tensor, kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + dilation: BroadcastingList2[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool2d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool2d_with_indices, + if_false=_max_pool2d, + module_name=__name__, + func_name="max_pool2d", +) + + +def max_pool3d_with_indices( + input: Tensor, kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) + + Applies a 3D max pooling over an input signal composed of several input + planes. + + .. note:: + The order of :attr:`ceil_mode` and :attr:`return_indices` is different from + what seen in :class:`~torch.nn.MaxPool3d`, and will change in a future release. + + See :class:`~torch.nn.MaxPool3d` for details. + + Args: + input: input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD, iH , iW)`, minibatch dim optional. + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kT, kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sT, sH, sW)`. Default: :attr:`kernel_size` + padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2. + dilation: The stride between elements within a sliding window, must be > 0. + ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This + ensures that every element in the input tensor is covered by a sliding window. + return_indices: If ``True``, will return the argmax along with the max values. + Useful for :class:`torch.nn.functional.max_unpool3d` later + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d_with_indices, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + + +def _max_pool3d( + input: Tensor, kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + dilation: BroadcastingList3[int] = 1, + ceil_mode: bool = False, + return_indices: bool = False +) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + max_pool3d, + (input,), + input, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + return_indices=return_indices, + ) + if stride is None: + stride = torch.jit.annotate(List[int], []) + return torch.max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode) + + +max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=6, + default=False, + if_true=max_pool3d_with_indices, + if_false=_max_pool3d, + module_name=__name__, + func_name="max_pool3d", +) + + +def _unpool_output_size( + input: Tensor, kernel_size: List[int], stride: List[int], padding: List[int], output_size: Optional[List[int]] +) -> List[int]: + input_size = input.size() + default_size = torch.jit.annotate(List[int], []) + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + f"{len(kernel_size)} or {len(kernel_size) + 2} elements, but it has a length of '{len(output_size)}'" + ) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + f'invalid output_size "{output_size}" (dim {d} must be between {min_size} and {max_size})' + ) + + ret = output_size + return ret + + +def max_unpool1d( + input: Tensor, indices: Tensor, + kernel_size: BroadcastingList1[int], + stride: Optional[BroadcastingList1[int]] = None, + padding: BroadcastingList1[int] = 0, + output_size: Optional[BroadcastingList1[int]] = None +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool1d`. + + See :class:`~torch.nn.MaxUnpool1d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool1d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _single(kernel_size) + if stride is not None: + _stride = _single(stride) + else: + _stride = kernel_size + padding = _single(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + if isinstance(output_size, list): + output_size = output_size + [1] + else: + output_size = output_size + (1,) + return torch._C._nn.max_unpool2d(input.unsqueeze(-1), indices.unsqueeze(-1), output_size).squeeze(-1) + + +def max_unpool2d( + input: Tensor, indices: Tensor, + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + padding: BroadcastingList2[int] = 0, + output_size: Optional[BroadcastingList2[int]] = None +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool2d`. + + See :class:`~torch.nn.MaxUnpool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool2d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _pair(kernel_size) + if stride is not None: + _stride = _pair(stride) + else: + _stride = kernel_size + padding = _pair(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool2d(input, indices, output_size) + + +def max_unpool3d( + input: Tensor, indices: Tensor, + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + padding: BroadcastingList3[int] = 0, + output_size: Optional[BroadcastingList3[int]] = None +) -> Tensor: + r"""Compute a partial inverse of :class:`MaxPool3d`. + + See :class:`~torch.nn.MaxUnpool3d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + max_unpool3d, + (input,), + input, + indices, + kernel_size, + stride=stride, + padding=padding, + output_size=output_size, + ) + kernel_size = _triple(kernel_size) + if stride is not None: + _stride = _triple(stride) + else: + _stride = kernel_size + padding = _triple(padding) + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + return torch._C._nn.max_unpool3d(input, indices, output_size, _stride, padding) + + +def lp_pool3d( + input: Tensor, norm_type: Union[int, float], + kernel_size: BroadcastingList3[int], + stride: Optional[BroadcastingList3[int]] = None, + ceil_mode: bool = False +) -> Tensor: + r""" + Apply a 3D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool3d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool3d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) + kd, kw, kh = utils._triple(kernel_size) + if stride is not None: + out = avg_pool3d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool3d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + + return (torch.sign(out) * relu(torch.abs(out))).mul(kd * kw * kh).pow(1.0 / norm_type) + + +def lp_pool2d( + input: Tensor, norm_type: Union[int, float], + kernel_size: BroadcastingList2[int], + stride: Optional[BroadcastingList2[int]] = None, + ceil_mode: bool = False +) -> Tensor: + r""" + Apply a 2D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool2d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool2d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) + kw, kh = utils._pair(kernel_size) + if stride is not None: + out = avg_pool2d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool2d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + + return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1.0 / norm_type) + + +def lp_pool1d( + input: Tensor, norm_type: Union[int, float], + kernel_size: int, + stride: Optional[BroadcastingList1[int]] = None, + ceil_mode: bool = False +) -> Tensor: + r"""Apply a 1D power-average pooling over an input signal composed of several input planes. + + If the sum of all inputs to the power of `p` is + zero, the gradient is set to zero as well. + + See :class:`~torch.nn.LPPool1d` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + lp_pool1d, (input,), input, norm_type, kernel_size, stride=stride, ceil_mode=ceil_mode + ) + if stride is not None: + out = avg_pool1d(input.pow(norm_type), kernel_size, stride, 0, ceil_mode) + else: + out = avg_pool1d(input.pow(norm_type), kernel_size, padding=0, ceil_mode=ceil_mode) + + return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1.0 / norm_type) + + +def adaptive_max_pool1d_with_indices( + input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool1d(input, output_size, return_indices=False) + + Applies a 1D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool1d` for details and output shape. + + Args: + output_size: the target output size (single integer) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d_with_indices, (input,), input, output_size, return_indices=return_indices + ) + return torch.adaptive_max_pool1d(input, output_size) + + +def _adaptive_max_pool1d(input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool1d, (input,), input, output_size, return_indices=return_indices + ) + return adaptive_max_pool1d_with_indices(input, output_size)[0] + + +adaptive_max_pool1d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool1d_with_indices, + if_false=_adaptive_max_pool1d, + module_name=__name__, + func_name="adaptive_max_pool1d", +) + + +def adaptive_max_pool2d_with_indices( + input: Tensor, output_size: BroadcastingList2[int], + return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r"""adaptive_max_pool2d(input, output_size, return_indices=False) + + Applies a 2D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d_with_indices, (input,), input, output_size, return_indices=return_indices + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool2d(input, output_size) + + +def _adaptive_max_pool2d(input: Tensor, output_size: BroadcastingList2[int], return_indices: bool = False) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool2d, (input,), input, output_size, return_indices=return_indices + ) + return adaptive_max_pool2d_with_indices(input, output_size)[0] + + +adaptive_max_pool2d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool2d_with_indices, + if_false=_adaptive_max_pool2d, + module_name=__name__, + func_name="adaptive_max_pool2d", +) + + +def adaptive_max_pool3d_with_indices( + input: Tensor, output_size: BroadcastingList3[int], + return_indices: bool = False +) -> Tuple[Tensor, Tensor]: # noqa: D400 + r""" + adaptive_max_pool3d(input, output_size, return_indices=False) + + Applies a 3D adaptive max pooling over an input signal composed of + several input planes. + + See :class:`~torch.nn.AdaptiveMaxPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + return_indices: whether to return pooling indices. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d_with_indices, (input,), input, output_size, return_indices=return_indices + ) + output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_max_pool3d(input, output_size) + + +def _adaptive_max_pool3d(input: Tensor, output_size: BroadcastingList3[int], return_indices: bool = False) -> Tensor: + if has_torch_function_unary(input): + return handle_torch_function( + adaptive_max_pool3d, (input,), input, output_size, return_indices=return_indices + ) + return adaptive_max_pool3d_with_indices(input, output_size)[0] + + +adaptive_max_pool3d = boolean_dispatch( + arg_name="return_indices", + arg_index=2, + default=False, + if_true=adaptive_max_pool3d_with_indices, + if_false=_adaptive_max_pool3d, + module_name=__name__, + func_name="adaptive_max_pool3d", +) + + +adaptive_avg_pool1d = _add_docstr( + torch.adaptive_avg_pool1d, + r""" +adaptive_avg_pool1d(input, output_size) -> Tensor + +Applies a 1D adaptive average pooling over an input signal composed of +several input planes. + +See :class:`~torch.nn.AdaptiveAvgPool1d` for details and output shape. + +Args: + output_size: the target output size (single integer) +""", +) + + +def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor: + r"""Apply a 2D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. + + Args: + output_size: the target output size (single integer or + double-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool2d(input, _output_size) + + +def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> Tensor: + r"""Apply a 3D adaptive average pooling over an input signal composed of several input planes. + + See :class:`~torch.nn.AdaptiveAvgPool3d` for details and output shape. + + Args: + output_size: the target output size (single integer or + triple-integer tuple) + """ + if has_torch_function_unary(input): + return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) + _output_size = _list_with_default(output_size, input.size()) + return torch._C._nn.adaptive_avg_pool3d(input, _output_size) + + +# Activation functions +def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: + r"""During training, randomly zeroes some elements of the input tensor with probability :attr:`p`. + + Uses samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout` for details. + + Args: + p: probability of an element to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training) + + +def alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: + r"""Apply alpha dropout to the input. + + See :class:`~torch.nn.AlphaDropout` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function(alpha_dropout, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return _VF.alpha_dropout_(input, p, training) if inplace else _VF.alpha_dropout(input, p, training) + + +def dropout1d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 1D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 1D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout1d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function(dropout1d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (2, 3): + raise RuntimeError(f"dropout1d: Expected 2D or 3D input, but received a {inp_dim}D input. " + "Note that dropout1d exists to provide channel-wise dropout on inputs with 1 " + "spatial dimension, a channel dimension, and an optional batch dimension " + "(i.e. 2D or 3D inputs).") + + is_batched = inp_dim == 3 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + + return result + + +def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 2D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 2D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout2d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function(dropout2d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (3, 4): + warn_msg = (f"dropout2d: Received a {inp_dim}-D input to dropout2d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout2d " + "exists to provide channel-wise dropout on inputs with 2 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).") + warnings.warn(warn_msg) + + # TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing + # a 3D input will perform dropout1d behavior instead. This was done historically and the + # behavior is maintained here for now. + # See https://github.com/pytorch/pytorch/issues/77081 + if inp_dim == 3: + warnings.warn("dropout2d: Received a 3D input to dropout2d and assuming that channel-wise " + "1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C " + "is the channel dim. This behavior will change in a future release to interpret the " + "input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D " + "channel-wise dropout behavior, please switch to using dropout1d instead.") + + result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + + return result + + +def dropout3d(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor: + r"""Randomly zero out entire channels (a channel is a 3D feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the + batched input is a 3D tensor :math:`\text{input}[i, j]` of the input tensor. + Each channel will be zeroed out independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + + See :class:`~torch.nn.Dropout3d` for details. + + Args: + p: probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function(dropout3d, (input,), input, p=p, training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + inp_dim = input.dim() + if inp_dim not in (4, 5): + warn_msg = (f"dropout3d: Received a {inp_dim}-D input to dropout3d, which is deprecated " + "and will result in an error in a future release. To retain the behavior " + "and silence this warning, please use dropout instead. Note that dropout3d " + "exists to provide channel-wise dropout on inputs with 3 spatial dimensions, " + "a channel dimension, and an optional batch dimension (i.e. 4D or 5D inputs).") + warnings.warn(warn_msg) + + is_batched = inp_dim == 5 + if not is_batched: + input = input.unsqueeze_(0) if inplace else input.unsqueeze(0) + + result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training) + + if not is_batched: + result = result.squeeze_(0) if inplace else result.squeeze(0) + return result + + +def feature_alpha_dropout(input: Tensor, p: float = 0.5, training: bool = False, inplace: bool = False) -> Tensor: + r"""Randomly masks out entire channels (a channel is a feature map). + + For example, the :math:`j`-th channel of the :math:`i`-th sample in the batch input + is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of + setting activations to zero, as in regular Dropout, the activations are set + to the negative saturation value of the SELU activation function. + + Each element will be masked independently on every forward call with + probability :attr:`p` using samples from a Bernoulli distribution. + The elements to be masked are randomized on every forward call, and scaled + and shifted to maintain zero mean and unit variance. + + See :class:`~torch.nn.FeatureAlphaDropout` for details. + + Args: + p: dropout probability of a channel to be zeroed. Default: 0.5 + training: apply dropout if is ``True``. Default: ``True`` + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + """ + if has_torch_function_unary(input): + return handle_torch_function( + feature_alpha_dropout, (input,), input, p=p, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") + return _VF.feature_alpha_dropout_(input, p, training) if inplace else _VF.feature_alpha_dropout(input, p, training) + + +def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor: + r"""Apply a threshold to each element of the input Tensor. + + See :class:`~torch.nn.Threshold` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(_threshold, (input,), input, threshold, value, inplace=inplace) + if inplace: + result = _VF.threshold_(input, threshold, value) + else: + result = _VF.threshold(input, threshold, value) + return result + + +# We define this function as _threshold because it takes an argument +# named threshold, which clobbers the recursive reference to the +# function needed for __torch_function__ support +threshold = _threshold + +threshold_ = _add_docstr( + _VF.threshold_, + r""" +threshold_(input, threshold, value) -> Tensor + +In-place version of :func:`~threshold`. +""", +) + + +def relu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu(input, inplace=False) -> Tensor + + Applies the rectified linear unit function element-wise. See + :class:`~torch.nn.ReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu, (input,), input, inplace=inplace) + if inplace: + result = torch.relu_(input) + else: + result = torch.relu(input) + return result + + +relu_ = _add_docstr( + torch.relu_, + r""" +relu_(input) -> Tensor + +In-place version of :func:`~relu`. +""", +) + + +def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 + r""" + glu(input, dim=-1) -> Tensor + + The gated linear unit. Computes: + + .. math :: + \text{GLU}(a, b) = a \otimes \sigma(b) + + where `input` is split in half along `dim` to form `a` and `b`, :math:`\sigma` + is the sigmoid function and :math:`\otimes` is the element-wise product between matrices. + + See `Language Modeling with Gated Convolutional Networks `_. + + Args: + input (Tensor): input tensor + dim (int): dimension on which to split the input. Default: -1 + """ + if has_torch_function_unary(input): + return handle_torch_function(glu, (input,), input, dim=dim) + if input.dim() == 0: + raise RuntimeError("glu does not support scalars because halving size must be even") + return torch._C._nn.glu(input, dim) + + +def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor: # noqa: D400,D402 + r""" + hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor + + Applies the HardTanh function element-wise. See :class:`~torch.nn.Hardtanh` for more + details. + """ + if has_torch_function_unary(input): + return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace) + if inplace: + result = torch._C._nn.hardtanh_(input, min_val, max_val) + else: + result = torch._C._nn.hardtanh(input, min_val, max_val) + return result + + +hardtanh_ = _add_docstr( + torch._C._nn.hardtanh_, + r""" +hardtanh_(input, min_val=-1., max_val=1.) -> Tensor + +In-place version of :func:`~hardtanh`. +""", +) + + +def relu6(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""relu6(input, inplace=False) -> Tensor + + Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. + + See :class:`~torch.nn.ReLU6` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(relu6, (input,), input, inplace=inplace) + if inplace: + result = torch._C._nn.relu6_(input) + else: + result = torch._C._nn.relu6(input) + return result + + +def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: + r"""Apply the Exponential Linear Unit (ELU) function element-wise. + + See :class:`~torch.nn.ELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(elu, (input,), input, alpha=alpha, inplace=inplace) + if inplace: + result = torch._C._nn.elu_(input, alpha) + else: + result = torch._C._nn.elu(input, alpha) + return result + + +elu_ = _add_docstr( + torch._C._nn.elu_, + r""" +elu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~elu`. +""", +) + + +def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""selu(input, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))`, + with :math:`\alpha=1.6732632423543772848170429916717` and + :math:`scale=1.0507009873554804934193349852946`. + + See :class:`~torch.nn.SELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(selu, (input,), input, inplace=inplace) + if inplace: + result = torch.selu_(input) + else: + result = torch.selu(input) + return result + + +selu_ = _add_docstr( + torch.selu_, + r""" +selu_(input) -> Tensor + +In-place version of :func:`~selu`. +""", +) + + +def celu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r"""celu(input, alpha=1., inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))`. + + See :class:`~torch.nn.CELU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(celu, (input,), input, alpha=alpha, inplace=inplace) + if inplace: + result = torch.celu_(input, alpha) + else: + result = torch.celu(input, alpha) + return result + + +celu_ = _add_docstr( + torch.celu_, + r""" +celu_(input, alpha=1.) -> Tensor + +In-place version of :func:`~celu`. +""", +) + + +def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor: # noqa: D400,D402 + r""" + leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor + + Applies element-wise, + :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)` + + See :class:`~torch.nn.LeakyReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace) + if inplace: + result = torch._C._nn.leaky_relu_(input, negative_slope) + else: + result = torch._C._nn.leaky_relu(input, negative_slope) + return result + + +leaky_relu_ = _add_docstr( + torch._C._nn.leaky_relu_, + r""" +leaky_relu_(input, negative_slope=0.01) -> Tensor + +In-place version of :func:`~leaky_relu`. +""", +) + + +prelu = _add_docstr( + torch.prelu, + r"""prelu(input, weight) -> Tensor + +Applies element-wise the function +:math:`\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)` where weight is a +learnable parameter. + +.. note:: + `weight` is expected to be a scalar or 1-D tensor. If `weight` is 1-D, + its size must match the number of input channels, determined by + `input.size(1)` when `input.dim() >= 2`, otherwise 1. + In the 1-D case, note that when `input` has dim > 2, `weight` can be expanded + to the shape of `input` in a way that is not possible using normal + :ref:`broadcasting semantics`. + +See :class:`~torch.nn.PReLU` for more details. +""") + + +def rrelu( + input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False +) -> Tensor: # noqa: D400,D402 + r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor + + Randomized leaky ReLU. + + See :class:`~torch.nn.RReLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function( + rrelu, (input,), input, lower=lower, upper=upper, training=training, inplace=inplace + ) + if inplace: + result = torch.rrelu_(input, lower, upper, training) + else: + result = torch.rrelu(input, lower, upper, training) + return result + + +rrelu_ = _add_docstr( + torch.rrelu_, + r""" +rrelu_(input, lower=1./8, upper=1./3, training=False) -> Tensor + +In-place version of :func:`~rrelu`. +""", +) + +logsigmoid = _add_docstr( + torch._C._nn.log_sigmoid, + r""" +logsigmoid(input) -> Tensor + +Applies element-wise :math:`\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)` + +See :class:`~torch.nn.LogSigmoid` for more details. +""", +) + +gelu = _add_docstr( + torch._C._nn.gelu, + r""" +gelu(input, approximate = 'none') -> Tensor + +When the approximate argument is 'none', it applies element-wise the function +:math:`\text{GELU}(x) = x * \Phi(x)` + +where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + +When the approximate argument is 'tanh', Gelu is estimated with + +.. math:: + \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) + +See `Gaussian Error Linear Units (GELUs) `_. +""") + +hardshrink = _add_docstr( + torch.hardshrink, + r""" +hardshrink(input, lambd=0.5) -> Tensor + +Applies the hard shrinkage function element-wise + +See :class:`~torch.nn.Hardshrink` for more details. +""") + + +def tanhshrink(input): # noqa: D400,D402 + r"""tanhshrink(input) -> Tensor + + Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` + + See :class:`~torch.nn.Tanhshrink` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(tanhshrink, (input,), input) + return input - input.tanh() + + +def softsign(input): # noqa: D400,D402 + r"""softsign(input) -> Tensor + + Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}` + + See :class:`~torch.nn.Softsign` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(softsign, (input,), input) + return input / (input.abs() + 1) + + +softplus = _add_docstr( + torch._C._nn.softplus, + r""" +softplus(input, beta=1, threshold=20) -> Tensor + +Applies element-wise, the function :math:`\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))`. + +For numerical stability the implementation reverts to the linear function +when :math:`input \times \beta > threshold`. + +See :class:`~torch.nn.Softplus` for more details. +""", +) + + +def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int: + warnings.warn( + f"Implicit dimension choice for {name} has been deprecated. Change the call to include dim=X as an argument.", + stacklevel=stacklevel, + ) + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + +def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: + r"""Apply a softmin function. + + Note that :math:`\text{Softmin}(x) = \text{Softmax}(-x)`. See softmax definition for mathematical formula. + + See :class:`~torch.nn.Softmin` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmin will be computed (so every slice + along dim will sum to 1). + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function(softmin, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim("softmin", input.dim(), _stacklevel) + if dtype is None: + ret = (-input).softmax(dim) + else: + ret = (-input).softmax(dim, dtype=dtype) + return ret + + +def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: + r"""Apply a softmax function. + + Softmax is defined as: + + :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}` + + It is applied to all slices along dim, and will re-scale them so that the elements + lie in the range `[0, 1]` and sum to 1. + + See :class:`~torch.nn.Softmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + + .. note:: + This function doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use log_softmax instead (it's faster and has better numerical properties). + + """ + if has_torch_function_unary(input): + return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim("softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.softmax(dim) + else: + ret = input.softmax(dim, dtype=dtype) + return ret + + +def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor: + r""" + Sample from the Gumbel-Softmax distribution (`Link 1`_ `Link 2`_) and optionally discretize. + + Args: + logits: `[..., num_features]` unnormalized log probabilities + tau: non-negative scalar temperature + hard: if ``True``, the returned samples will be discretized as one-hot vectors, + but will be differentiated as if it is the soft sample in autograd + dim (int): A dimension along which softmax will be computed. Default: -1. + + Returns: + Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. + If ``hard=True``, the returned samples will be one-hot, otherwise they will + be probability distributions that sum to 1 across `dim`. + + .. note:: + This function is here for legacy reasons, may be removed from nn.Functional in the future. + + .. note:: + The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` + + It achieves two things: + - makes the output value exactly one-hot + (since we add then subtract y_soft value) + - makes the gradient equal to y_soft gradient + (since we strip all other gradients) + + Examples:: + >>> logits = torch.randn(20, 32) + >>> # Sample soft categorical using reparametrization trick: + >>> F.gumbel_softmax(logits, tau=1, hard=False) + >>> # Sample hard categorical using "Straight-through" trick: + >>> F.gumbel_softmax(logits, tau=1, hard=True) + + .. _Link 1: + https://arxiv.org/abs/1611.00712 + .. _Link 2: + https://arxiv.org/abs/1611.01144 + """ + if has_torch_function_unary(logits): + return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim) + if eps != 1e-10: + warnings.warn("`eps` parameter is deprecated and has no effect.") + + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + + if hard: + # Straight through. + index = y_soft.max(dim, keepdim=True)[1] + y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + # Reparametrization trick. + ret = y_soft + return ret + + +def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor: + r"""Apply a softmax followed by a logarithm. + + While mathematically equivalent to log(softmax(x)), doing these two + operations separately is slower and numerically unstable. This function + uses an alternative formulation to compute the output and gradient correctly. + + See :class:`~torch.nn.LogSoftmax` for more details. + + Args: + input (Tensor): input + dim (int): A dimension along which log_softmax will be computed. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is cast to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + """ + if has_torch_function_unary(input): + return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype) + if dim is None: + dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel) + if dtype is None: + ret = input.log_softmax(dim) + else: + ret = input.log_softmax(dim, dtype=dtype) + return ret + + +softshrink = _add_docstr( + torch._C._nn.softshrink, + r""" +softshrink(input, lambd=0.5) -> Tensor + +Applies the soft shrinkage function elementwise + +See :class:`~torch.nn.Softshrink` for more details. +""", +) + + +def tanh(input): # noqa: D400,D402 + r"""tanh(input) -> Tensor + + Applies element-wise, + :math:`\text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}` + + See :class:`~torch.nn.Tanh` for more details. + """ + return input.tanh() + + +def sigmoid(input): # noqa: D400,D402 + r"""sigmoid(input) -> Tensor + + Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` + + See :class:`~torch.nn.Sigmoid` for more details. + """ + return input.sigmoid() + + +def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Hardsigmoid function element-wise. + + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} + + Args: + inplace: If set to ``True``, will do this operation in-place. Default: ``False`` + + See :class:`~torch.nn.Hardsigmoid` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardsigmoid_(input) + return torch._C._nn.hardsigmoid(input) + + +linear = _add_docstr( + torch._C._nn.linear, + r""" +linear(input, weight, bias=None) -> Tensor + +Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. + +This operation supports 2-D :attr:`weight` with :ref:`sparse layout` + +{sparse_beta_warning} + +This operator supports :ref:`TensorFloat32`. + +Shape: + + - Input: :math:`(*, in\_features)` where `*` means any number of + additional dimensions, including none + - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)` + - Bias: :math:`(out\_features)` or :math:`()` + - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight +""".format(**sparse_support_notes)) + + +bilinear = _add_docstr( + torch.bilinear, + r""" +bilinear(input1, input2, weight, bias=None) -> Tensor + +Applies a bilinear transformation to the incoming data: +:math:`y = x_1^T A x_2 + b` + +Shape: + + - input1: :math:`(N, *, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` + and :math:`*` means any number of additional dimensions. + All but the last dimension of the inputs should be the same. + - input2: :math:`(N, *, H_{in2})` where :math:`H_{in2}=\text{in2\_features}` + - weight: :math:`(\text{out\_features}, \text{in1\_features}, + \text{in2\_features})` + - bias: :math:`(\text{out\_features})` + - output: :math:`(N, *, H_{out})` where :math:`H_{out}=\text{out\_features}` + and all but the last dimension are the same shape as the input. +""") + + +def silu(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise. + + The SiLU function is also known as the swish function. + + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + + .. note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + + See :class:`~torch.nn.SiLU` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(silu, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.silu_(input) + return torch._C._nn.silu(input) + + +def mish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply the Mish function, element-wise. + + Mish: A Self Regularized Non-Monotonic Neural Activation Function. + + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + + See :class:`~torch.nn.Mish` for more details. + """ + if has_torch_function_unary(input): + return handle_torch_function(mish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.mish_(input) + return torch._C._nn.mish(input) + + +def hardswish(input: Tensor, inplace: bool = False) -> Tensor: + r"""Apply hardswish function, element-wise. + + Follows implementation as described in the paper: + `Searching for MobileNetV3`_. + + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} + + See :class:`~torch.nn.Hardswish` for more details. + + .. _`Searching for MobileNetV3`: + https://arxiv.org/abs/1905.02244 + """ + if has_torch_function_unary(input): + return handle_torch_function(hardswish, (input,), input, inplace=inplace) + if inplace: + return torch._C._nn.hardswish_(input) + return torch._C._nn.hardswish(input) + + +def _no_grad_embedding_renorm_(weight: Tensor, input: Tensor, max_norm: float, norm_type: float) -> Tuple[Tensor, Tensor]: + torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) + + +def embedding( + input: Tensor, + weight: Tensor, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size. + + This module is often used to retrieve word embeddings using indices. + The input to the module is a list of indices, and the embedding matrix, + and the output is the corresponding word embeddings. + + See :class:`torch.nn.Embedding` for more details. + + .. note:: + Note that the analytical gradients of this function with respect to + entries in :attr:`weight` at the row specified by :attr:`padding_idx` + are expected to differ from the numerical ones. + + .. note:: + Note that `:class:`torch.nn.Embedding` differs from this function in + that it initializes the row of :attr:`weight` specified by + :attr:`padding_idx` to all zeros on construction. + + Args: + input (LongTensor): Tensor containing indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. + + Shape: + - Input: LongTensor of arbitrary shape containing the indices to extract + - Weight: Embedding matrix of floating point type with shape `(V, embedding_dim)`, + where V = maximum index + 1 and embedding_dim = the embedding size + - Output: `(*, embedding_dim)`, where `*` is the input shape + + Examples:: + + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) + >>> # an embedding matrix containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding(input, embedding_matrix) + tensor([[[ 0.8490, 0.9625, 0.6753], + [ 0.9666, 0.7761, 0.6108], + [ 0.6246, 0.9751, 0.3618], + [ 0.4161, 0.2419, 0.7383]], + + [[ 0.6246, 0.9751, 0.3618], + [ 0.0237, 0.7794, 0.0528], + [ 0.9666, 0.7761, 0.6108], + [ 0.3385, 0.8612, 0.1867]]]) + + >>> # example with padding_idx + >>> weights = torch.rand(10, 3) + >>> weights[0, :].zero_() + >>> embedding_matrix = weights + >>> input = torch.tensor([[0, 2, 0, 5]]) + >>> F.embedding(input, embedding_matrix, padding_idx=0) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.5609, 0.5384, 0.8720], + [ 0.0000, 0.0000, 0.0000], + [ 0.6262, 0.2438, 0.7471]]]) + """ + if has_torch_function_variadic(input, weight): + return handle_torch_function( + embedding, + (input, weight), + input, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < weight.size(0), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert padding_idx >= -weight.size(0), "Padding_idx must be within num_embeddings" + padding_idx = weight.size(0) + padding_idx + else: + padding_idx = -1 + if max_norm is not None: + # Note [embedding_renorm contiguous] + # `embedding_renorm_` will call .contiguous() on input anyways, so we + # call it here and take advantage of the improved locality in the + # `embedding` call below too. + input = input.contiguous() + # Note [embedding_renorm set_grad_enabled] + # XXX: equivalent to + # with torch.no_grad(): + # torch.embedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) + + +def embedding_bag( + input: Tensor, + weight: Tensor, + offsets: Optional[Tensor] = None, + max_norm: Optional[float] = None, + norm_type: float = 2, + scale_grad_by_freq: bool = False, + mode: str = "mean", + sparse: bool = False, + per_sample_weights: Optional[Tensor] = None, + include_last_offset: bool = False, + padding_idx: Optional[int] = None, +) -> Tensor: + r"""Compute sums, means or maxes of `bags` of embeddings. + + Calculation is done without instantiating the intermediate embeddings. + See :class:`torch.nn.EmbeddingBag` for more details. + + Note: + {backward_reproducibility_note} + + Args: + input (LongTensor): Tensor containing bags of indices into the embedding matrix + weight (Tensor): The embedding matrix with number of rows equal to the maximum possible index + 1, + and number of columns equal to the embedding size + offsets (LongTensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines + the starting index position of each bag (sequence) in :attr:`input`. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + Note: this will modify :attr:`weight` in-place. + norm_type (float, optional): The ``p`` in the ``p``-norm to compute for the :attr:`max_norm` option. + Default ``2``. + scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + Note: this option is not supported when ``mode="max"``. + mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. + Default: ``"mean"`` + sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` will be a sparse tensor. See Notes under + :class:`torch.nn.Embedding` for more details regarding sparse gradients. + Note: this option is not supported when ``mode="max"``. + per_sample_weights (Tensor, optional): a tensor of float / double weights, or None + to indicate all weights should be taken to be 1. If specified, :attr:`per_sample_weights` + must have exactly the same shape as input and is treated as having the same + :attr:`offsets`, if those are not None. + + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position of the last bag (sequence). + + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the + gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated + during training, i.e. it remains as a fixed "pad". Note that the embedding + vector at :attr:`padding_idx` is excluded from the reduction. + + Shape: + - :attr:`input` (LongTensor) and :attr:`offsets` (LongTensor, optional) + + - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences) + each of fixed length ``N``, and this will return ``B`` values aggregated in a way + depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case. + + - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of + multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing + the starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` + of shape `(B)`, :attr:`input` will be viewed as having ``B`` bags. + Empty bags (i.e., having 0-length) will have returned vectors filled by zeros. + + - :attr:`weight` (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)` + + - :attr:`per_sample_weights` (Tensor, optional). Has the same shape as :attr:`input`. + + - :attr:`output`: aggregated embedding values of shape `(B, embedding_dim)` + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding_matrix = torch.rand(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> F.embedding_bag(input, embedding_matrix, offsets) + tensor([[ 0.3397, 0.3552, 0.5545], + [ 0.5893, 0.4386, 0.5882]]) + + >>> # example with padding_idx + >>> embedding_matrix = torch.rand(10, 3) + >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9]) + >>> offsets = torch.tensor([0, 4]) + >>> F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum') + tensor([[ 0.0000, 0.0000, 0.0000], + [-0.7082, 3.2145, -2.6251]]) + """ + if has_torch_function_variadic(input, weight, offsets, per_sample_weights): + return handle_torch_function( + embedding_bag, + (input, weight, offsets, per_sample_weights), + input, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx, + ) + # Check for backward compatibility. + # Used to be embedding_bag(weight, input, ...) + # Now is embedding_bag(input, weight, ...) + if weight.dtype == torch.long and input.is_floating_point(): + warnings.warn( + "Argument order of nn.functional.embedding_bag was changed. " + "Usage `embedding_bag(weight, input, ...)` is deprecated, " + "and should now be `embedding_bag(input, weight, ...)`." + ) + weight, input = input, weight + + if per_sample_weights is not None and input.size() != per_sample_weights.size(): + raise ValueError( + f"embedding_bag: If per_sample_weights ({per_sample_weights.shape}) is not None, " + f"then it must have the same shape as the input ({input.shape})" + ) + + if not weight.dim() == 2: + raise ValueError( + f"weight has to be a 2D Tensor, but got Tensor of dimension {weight.dim()}" + ) + + if input.dim() == 2: + if offsets is not None: + type_str = "" + # TODO: Remove this once script supports type() calls + if not torch.jit.is_scripting(): + type_str = str(type(offsets)) + raise ValueError( + "if input is 2D, then offsets has to be None" + ", as input is treated is a mini-batch of" + " fixed length sequences. However, found " + f"offsets of type {type_str}" + ) + offsets = torch.arange(0, input.numel(), input.size(1), dtype=input.dtype, device=input.device) + + input = input.reshape(-1) + if per_sample_weights is not None: + per_sample_weights = per_sample_weights.reshape(-1) + elif input.dim() == 1: + if offsets is None: + raise ValueError("offsets has to be a 1D Tensor but got None") + if offsets.dim() != 1: + raise ValueError("offsets has to be a 1D Tensor") + else: + raise ValueError(f"input has to be 1D or 2D Tensor, but got Tensor of dimension {input.dim()}") + if mode == "sum": + mode_enum = 0 + elif mode == "mean": + mode_enum = 1 + elif mode == "max": + mode_enum = 2 + + if scale_grad_by_freq: + raise ValueError("max mode does not support scaling the gradient by the frequency") + + if sparse: + raise ValueError("max mode does not support sparse weights") + + else: + raise ValueError("mode has to be one of sum, mean or max") + + if max_norm is not None: + # XXX: equivalent to + # with torch.no_grad(): + # torch.nembedding_renorm_ + # remove once script supports set_grad_enabled + _no_grad_embedding_renorm_(weight, input, max_norm, norm_type) + + if per_sample_weights is not None and mode != "sum": + raise NotImplementedError( + "embedding_bag: per_sample_weights was not None. " + "per_sample_weights is only supported for mode='sum' " + f"(got mode='{mode}'). Please open a feature request on GitHub." + ) + + ret, _, _, _ = torch.embedding_bag( + weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights, include_last_offset, padding_idx + ) + return ret + + +if embedding_bag.__doc__: + embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes) + + +def _verify_batch_size(size: List[int]) -> None: + # XXX: JIT script does not support the reduce from functools, and mul op is a + # builtin, which cannot be used as a value to a func yet, so rewrite this size + # check to a simple equivalent for loop + # + # TODO: make use of reduce like below when JIT is ready with the missing features: + # from operator import mul + # from functools import reduce + # + # if reduce(mul, size[2:], size[0]) == 1 + size_prods = size[0] + for i in range(len(size) - 2): + size_prods *= size[i + 2] + if size_prods == 1: + raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}") + + +def batch_norm( + input: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + training: bool = False, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Batch Normalization for each channel across a batch of data. + + See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`, + :class:`~torch.nn.BatchNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + batch_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean, + running_var, + weight=weight, + bias=bias, + training=training, + momentum=momentum, + eps=eps, + ) + if training: + _verify_batch_size(input.size()) + + return torch.batch_norm( + input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled + ) + + +def _verify_spatial_size(size: List[int]) -> None: + # Verify that there is > 1 spatial element for instance norm calculation. + size_prods = 1 + for i in range(2, len(size)): + size_prods *= size[i] + if size_prods == 1: + raise ValueError(f"Expected more than 1 spatial element when training, got input size {size}") + + +def instance_norm( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Instance Normalization independently for each channel in every data sample within a batch. + + See :class:`~torch.nn.InstanceNorm1d`, :class:`~torch.nn.InstanceNorm2d`, + :class:`~torch.nn.InstanceNorm3d` for details. + """ + if has_torch_function_variadic(input, running_mean, running_var, weight, bias): + return handle_torch_function( + instance_norm, + (input, running_mean, running_var, weight, bias), + input, + running_mean=running_mean, + running_var=running_var, + weight=weight, + bias=bias, + use_input_stats=use_input_stats, + momentum=momentum, + eps=eps, + ) + if use_input_stats: + _verify_spatial_size(input.size()) + return torch.instance_norm( + input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, torch.backends.cudnn.enabled + ) + + +def layer_norm( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + r"""Apply Layer Normalization for last certain number of dimensions. + + See :class:`~torch.nn.LayerNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function( + layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled) + + +def group_norm( + input: Tensor, num_groups: int, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None, eps: float = 1e-5 +) -> Tensor: + r"""Apply Group Normalization for last certain number of dimensions. + + See :class:`~torch.nn.GroupNorm` for details. + """ + if has_torch_function_variadic(input, weight, bias): + return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) + if input.dim() < 2: + raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") + _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) + return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) + + +def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.0) -> Tensor: + r"""Apply local response normalization over an input signal. + + The input signal is composed of several input planes, where channels occupy the second dimension. + Normalization is applied across channels. + + See :class:`~torch.nn.LocalResponseNorm` for details. + """ + if has_torch_function_unary(input): + return handle_torch_function(local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k) + dim = input.dim() + if dim < 3: + raise ValueError( + f"Expected 3D or higher dimensionality input (got {dim} dimensions)" + ) + + if input.numel() == 0: + return input + + div = input.mul(input) + if dim == 3: + div = div.unsqueeze(1) + div = pad(div, (0, 0, size // 2, (size - 1) // 2)) + div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) + else: + sizes = input.size() + div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) + div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) + div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) + div = div.view(sizes) + div = div.mul(alpha).add(k).pow(beta) + return input / div + + +# loss + + +def ctc_loss( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, +) -> Tensor: + r"""Apply the Connectionist Temporal Classification loss. + + See :class:`~torch.nn.CTCLoss` for details. + + Note: + {cudnn_reproducibility_note} + + Note: + {backward_reproducibility_note} + + Args: + log_probs: :math:`(T, N, C)` or :math:`(T, C)` where `C = number of characters in alphabet including blank`, + `T = input length`, and `N = batch size`. + The logarithmized probabilities of the outputs + (e.g. obtained with :func:`torch.nn.functional.log_softmax`). + targets: :math:`(N, S)` or `(sum(target_lengths))`. + Targets cannot be blank. In the second form, the targets are assumed to be concatenated. + input_lengths: :math:`(N)` or :math:`()`. + Lengths of the inputs (must each be :math:`\leq T`) + target_lengths: :math:`(N)` or :math:`()`. + Lengths of the targets + blank (int, optional): + Blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output will be + summed. Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Example:: + + >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() + >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) + >>> input_lengths = torch.full((16,), 50, dtype=torch.long) + >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) + >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) + >>> loss.backward() + """ + if has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths): + return handle_torch_function( + ctc_loss, + (log_probs, targets, input_lengths, target_lengths), + log_probs, targets, input_lengths, target_lengths, + blank=blank, reduction=reduction, zero_infinity=zero_infinity + ) + return torch.ctc_loss( + log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction), zero_infinity + ) + + +if ctc_loss.__doc__: + ctc_loss.__doc__ = ctc_loss.__doc__.format(**reproducibility_notes) + + +def nll_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the negative log likelihood loss. + + See :class:`~torch.nn.NLLLoss` for details. + + Args: + input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K \geq 1` + in the case of K-dimensional loss. `input` is expected to be log-probabilities. + target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, + or :math:`(N, d_1, d_2, ..., d_K)` where :math:`K \geq 1` for + K-dimensional loss. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Example:: + + >>> # input is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target has to have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> output = F.nll_loss(F.log_softmax(input, dim=1), target) + >>> output.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + nll_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.nll_loss_nd(input, target, weight, _Reduction.get_enum(reduction), ignore_index) + + +def poisson_nll_loss( + input: Tensor, + target: Tensor, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Poisson negative log likelihood loss. + + See :class:`~torch.nn.PoissonNLLLoss` for details. + + Args: + input: expectation of underlying Poisson distribution. + target: random sample :math:`target \sim \text{Poisson}(input)`. + log_input: if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target} * \text{input}`, if ``False`` then loss is + :math:`\text{input} - \text{target} * \log(\text{input}+\text{eps})`. Default: ``True`` + full: whether to compute full loss, i. e. to add the Stirling + approximation term. Default: ``False`` + :math:`\text{target} * \log(\text{target}) - \text{target} + 0.5 * \log(2 * \pi * \text{target})`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input`\ =\ ``False``. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + poisson_nll_loss, + (input, target), + input, + target, + log_input=log_input, + full=full, + size_average=size_average, + eps=eps, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + if reduction != "none" and reduction != "mean" and reduction != "sum": + ret = input + raise ValueError(reduction + " is not a valid value for reduction") + + ret = torch.poisson_nll_loss(input, target, log_input, full, eps, _Reduction.get_enum(reduction)) + return ret + + +def gaussian_nll_loss( + input: Tensor, + target: Tensor, + var: Tensor, + full: bool = False, + eps: float = 1e-6, + reduction: str = "mean", +) -> Tensor: + r"""Gaussian negative log likelihood loss. + + See :class:`~torch.nn.GaussianNLLLoss` for details. + + Args: + input: expectation of the Gaussian distribution. + target: sample from the Gaussian distribution. + var: tensor of positive variance(s), one for each of the expectations + in the input (heteroscedastic), or a single one (homoscedastic). + full (bool, optional): include the constant term in the loss calculation. Default: ``False``. + eps (float, optional): value added to var, for stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output is the average of all batch member losses, + ``'sum'``: the output is the sum of all batch member losses. + Default: ``'mean'``. + """ + if has_torch_function_variadic(input, target, var): + return handle_torch_function( + gaussian_nll_loss, + (input, target, var), + input, + target, + var, + full=full, + eps=eps, + reduction=reduction, + ) + + # Check var size + # If var.size == input.size, the case is heteroscedastic and no further checks are needed. + # Otherwise: + if var.size() != input.size(): + + # If var is one dimension short of input, but the sizes match otherwise, then this is a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2) + # -> unsqueeze var so that var.shape = (10, 2, 1) + # this is done so that broadcasting can happen in the loss calculation + if input.size()[:-1] == var.size(): + var = torch.unsqueeze(var, -1) + + # This checks if the sizes match up to the final dimension, and the final dimension of var is of size 1. + # This is also a homoscedastic case. + # e.g. input.size = (10, 2, 3), var.size = (10, 2, 1) + elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1: # Heteroscedastic case + pass + + # If none of the above pass, then the size of var is incorrect. + else: + raise ValueError("var is of incorrect size") + + # Check validity of reduction mode + if reduction != 'none' and reduction != 'mean' and reduction != 'sum': + raise ValueError(reduction + " is not valid") + + # Entries of var must be non-negative + if torch.any(var < 0): + raise ValueError("var has negative entry/entries") + + # Clamp for stability + var = var.clone() + with torch.no_grad(): + var.clamp_(min=eps) + + # Calculate the loss + loss = 0.5 * (torch.log(var) + (input - target)**2 / var) + if full: + loss += 0.5 * math.log(2 * math.pi) + + if reduction == 'mean': + return loss.mean() + elif reduction == 'sum': + return loss.sum() + else: + return loss + + +def kl_div( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + log_target: bool = False, +) -> Tensor: + r"""Compute the KL Divergence loss. + + Refer - The `Kullback-Leibler divergence Loss + `__ + + See :class:`~torch.nn.KLDivLoss` for details. + + Args: + input: Tensor of arbitrary shape in log-probabilities. + target: Tensor of the same shape as input. See :attr:`log_target` for + the target's interpretation. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``. + ``'none'``: no reduction will be applied + ``'batchmean'``: the sum of the output will be divided by the batchsize + ``'sum'``: the output will be summed + ``'mean'``: the output will be divided by the number of elements in the output + Default: ``'mean'`` + log_target (bool): A flag indicating whether ``target`` is passed in the log space. + It is recommended to pass certain distributions (like ``softmax``) + in the log space to avoid numerical issues caused by explicit ``log``. + Default: ``False`` + + .. note:: + :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, + and in the meantime, specifying either of those two args will override :attr:`reduction`. + + .. warning:: + :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use + :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + kl_div, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + log_target=log_target, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + if reduction == "mean": + warnings.warn( + "reduction: 'mean' divides the total loss by both the batch size and the support size." + "'batchmean' divides only by the batch size, and aligns with the KL div math definition." + "'mean' will be changed to behave the same as 'batchmean' in the next major release." + ) + + # special case for batchmean + if reduction == "batchmean": + reduction_enum = _Reduction.get_enum("sum") + else: + reduction_enum = _Reduction.get_enum(reduction) + + reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) + + if reduction == "batchmean" and input.dim() != 0: + reduced = reduced / input.size()[0] + + return reduced + + +def cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", + label_smoothing: float = 0.0, +) -> Tensor: + r"""Compute the cross entropy loss between input logits and target. + + See :class:`~torch.nn.CrossEntropyLoss` for details. + + Args: + input (Tensor) : Predicted unnormalized logits; + see Shape section below for supported shapes. + target (Tensor) : Ground truth class indices or class probabilities; + see Shape section below for supported shapes. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + Default: -100 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. + If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randint(5, (3,), dtype=torch.int64) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> loss = F.cross_entropy(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + ignore_index=ignore_index, + reduce=reduce, + reduction=reduction, + label_smoothing=label_smoothing, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) + + +def binary_cross_entropy( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Measure Binary Cross Entropy between the target and input probabilities. + + See :class:`~torch.nn.BCELoss` for details. + + Args: + input: Tensor of arbitrary shape as probabilities. + target: Tensor of the same shape as input with values between 0 and 1. + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples:: + + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> loss = F.binary_cross_entropy(torch.sigmoid(input), target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + binary_cross_entropy, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if target.size() != input.size(): + raise ValueError( + "Using a target size ({}) that is different to the input size ({}) is deprecated. " + "Please ensure they have the same size.".format(target.size(), input.size()) + ) + + if weight is not None: + new_size = _infer_size(target.size(), weight.size()) + weight = weight.expand(new_size) + + return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) + + +def binary_cross_entropy_with_logits( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + pos_weight: Optional[Tensor] = None, +) -> Tensor: + r"""Calculate Binary Cross Entropy between target and input logits. + + See :class:`~torch.nn.BCEWithLogitsLoss` for details. + + Args: + input: Tensor of arbitrary shape as unnormalized scores (often referred to as logits). + target: Tensor of the same shape as input with values between 0 and 1 + weight (Tensor, optional): a manual rescaling weight + if provided it's repeated to match input tensor shape + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when reduce is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Examples:: + + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> loss = F.binary_cross_entropy_with_logits(input, target) + >>> loss.backward() + """ + if has_torch_function_variadic(input, target, weight, pos_weight): + return handle_torch_function( + binary_cross_entropy_with_logits, + (input, target, weight, pos_weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + pos_weight=pos_weight, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + + if not (target.size() == input.size()): + raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})") + + return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) + + +def smooth_l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> Tensor: + r"""Compute the Smooth L1 loss. + + Function uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + + See :class:`~torch.nn.SmoothL1Loss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + smooth_l1_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + beta=beta, + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + + if beta == 0.0: + return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + else: + return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta) + + +def huber_loss( + input: Tensor, + target: Tensor, + reduction: str = 'mean', + delta: float = 1.0, +) -> Tensor: + r"""Compute the Huber loss. + + Function uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + + When delta equals 1, this loss is equivalent to SmoothL1Loss. + In general, Huber loss differs from SmoothL1Loss by a factor of delta (AKA beta in Smooth L1). + + See :class:`~torch.nn.HuberLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + huber_loss, + (input, target), + input, + target, + reduction=reduction, + delta=delta, + ) + if not (target.size() == input.size()): + warnings.warn(f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.huber_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), delta) + + +def l1_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + Function that takes the mean element-wise absolute value difference. + + See :class:`~torch.nn.L1Loss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + l1_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + + +def mse_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + Measures the element-wise mean squared error. + See :class:`~torch.nn.MSELoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + if not (target.size() == input.size()): + warnings.warn( + f"Using a target size ({target.size()}) that is different to the input size ({input.size()}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.", + stacklevel=2, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + expanded_input, expanded_target = torch.broadcast_tensors(input, target) + return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction)) + + +def margin_ranking_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MarginRankingLoss` for details. + """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + margin_ranking_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if (input1.dim() != input2.dim() or input1.dim() != target.dim()): + raise RuntimeError( + f"margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.size()}, input2: {input2.size()}, target: {target.size()} " + ) + return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum) + + +def hinge_embedding_loss( + input: Tensor, + target: Tensor, + margin: float = 1.0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.HingeEmbeddingLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + hinge_embedding_loss, + (input, target), + input, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.hinge_embedding_loss(input, target, margin, reduction_enum) + + +def multilabel_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiLabelMarginLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + multilabel_margin_loss, + (input, target), + input, + target, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) + + +def soft_margin_loss( + input: Tensor, + target: Tensor, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r""" + soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.SoftMarginLoss` for details. + """ + if has_torch_function_variadic(input, target): + return handle_torch_function( + soft_margin_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch._C._nn.soft_margin_loss(input, target, reduction_enum) + + +def multilabel_soft_margin_loss( + input: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multilabel_soft_margin_loss, + (input, target, weight), + input, + target, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction = _Reduction.legacy_get_string(size_average, reduce) + + loss = -(target * logsigmoid(input) + (1 - target) * logsigmoid(-input)) + + if weight is not None: + loss = loss * weight + + class_dim = input.dim() - 1 + C = input.size(class_dim) + loss = loss.sum(dim=class_dim) / C # only return N loss values + + if reduction == "none": + ret = loss + elif reduction == "mean": + ret = loss.mean() + elif reduction == "sum": + ret = loss.sum() + else: + ret = input + raise ValueError(reduction + " is not valid") + return ret + + +def cosine_embedding_loss( + input1: Tensor, + input2: Tensor, + target: Tensor, + margin: float = 0, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.CosineEmbeddingLoss` for details. + """ + if has_torch_function_variadic(input1, input2, target): + return handle_torch_function( + cosine_embedding_loss, + (input1, input2, target), + input1, + input2, + target, + margin=margin, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum) + + +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: int = 1, + margin: float = 1.0, + weight: Optional[Tensor] = None, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: # noqa: D400,D402 + r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor + + See :class:`~torch.nn.MultiMarginLoss` for details. + """ + if has_torch_function_variadic(input, target, weight): + return handle_torch_function( + multi_margin_loss, + (input, target, weight), + input, + target, + p=p, + margin=margin, + weight=weight, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None: + if weight.dim() != 1: + raise ValueError("weight must be one-dimensional") + + return torch._C._nn.multi_margin_loss(input, target, p, margin, weight, reduction_enum) + + +pixel_shuffle = _add_docstr( + torch.pixel_shuffle, + r""" +pixel_shuffle(input, upscale_factor) -> Tensor + +Rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)` to a +tensor of shape :math:`(*, C, H \times r, W \times r)`, where r is the :attr:`upscale_factor`. + +See :class:`~torch.nn.PixelShuffle` for details. + +Args: + input (Tensor): the input tensor + upscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 9, 4, 4) + >>> output = torch.nn.functional.pixel_shuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 1, 12, 12]) +""", +) + +pixel_unshuffle = _add_docstr( + torch.pixel_unshuffle, + r""" +pixel_unshuffle(input, downscale_factor) -> Tensor + +Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements in a +tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape +:math:`(*, C \times r^2, H, W)`, where r is the :attr:`downscale_factor`. + +See :class:`~torch.nn.PixelUnshuffle` for details. + +Args: + input (Tensor): the input tensor + downscale_factor (int): factor to increase spatial resolution by + +Examples:: + + >>> input = torch.randn(1, 1, 12, 12) + >>> output = torch.nn.functional.pixel_unshuffle(input, 3) + >>> print(output.size()) + torch.Size([1, 9, 4, 4]) +""", +) + +channel_shuffle = _add_docstr( + torch.channel_shuffle, + r""" +channel_shuffle(input, groups) -> Tensor + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. + +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + +native_channel_shuffle = _add_docstr( + torch.native_channel_shuffle, + r""" +native_channel_shuffle(input, groups) -> Tensor + +Native kernel level implementation of the `channel_shuffle`. +This function might become private in future releases, use with caution. + +Divide the channels in a tensor of shape :math:`(*, C , H, W)` +into g groups and rearrange them as :math:`(*, C \frac g, g, H, W)`, +while keeping the original tensor shape. + +See :class:`~torch.nn.ChannelShuffle` for details. + +Args: + input (Tensor): the input tensor + groups (int): number of groups to divide channels in and rearrange. + +Examples:: + + >>> input = torch.randn(1, 4, 2, 2) + >>> print(input) + [[[[1, 2], + [3, 4]], + [[5, 6], + [7, 8]], + [[9, 10], + [11, 12]], + [[13, 14], + [15, 16]], + ]] + >>> output = torch.nn.functional.native_channel_shuffle(input, 2) + >>> print(output) + [[[[1, 2], + [3, 4]], + [[9, 10], + [11, 12]], + [[5, 6], + [7, 8]], + [[13, 14], + [15, 16]], + ]] +""", +) + +@_overload # noqa: F811 +def upsample(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950 + pass + + +@_overload # noqa: F811 +def upsample(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None) -> Tensor: # noqa: F811,B950 + pass + + +def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # noqa: F811 + r"""Upsample input. + + Provided tensor is upsampled to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(...)``. + + Note: + {backward_reproducibility_note} + + The algorithm used for upsampling is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric upsampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for upsampling are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only) + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, and `trilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + + """ + warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode, align_corners) + + +if upsample.__doc__: + upsample.__doc__ = upsample.__doc__.format(**reproducibility_notes) + + +def _is_integer(x) -> bool: + r"""Type check the input number is an integer. + + Will return True for int, SymInt, Numpy integers and Tensors with integer elements. + """ + if isinstance(x, (int, torch.SymInt)): + return True + if np is not None and isinstance(x, np.integer): + return True + return isinstance(x, Tensor) and not x.is_floating_point() + + +@_overload # noqa: F811 +def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 + pass + + +@_overload # noqa: F811 +def interpolate(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 + pass + + +@_overload # noqa: F811 +def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 + pass + + +@_overload # noqa: F811 +def interpolate( # noqa: F811 + input: Tensor, + size: Optional[List[int]] = None, + scale_factor: Optional[float] = None, + mode: str = "nearest", + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None, + antialias: bool = False, +) -> Tensor: # noqa: F811 + pass + +def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None, mode: str = 'nearest', align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> Tensor: # noqa: F811,B950 + r"""Down/up samples the input. + + Tensor interpolated to either the given :attr:`size` or the given + :attr:`scale_factor` + + The algorithm used for interpolation is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric sampling are supported, i.e. + expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. If `scale_factor` is a tuple, + its length has to match the number of spatial dimensions; `input.dim() - 2`. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + recompute_scale_factor (bool, optional): recompute the scale_factor for use in the + interpolation calculation. If `recompute_scale_factor` is ``True``, then + `scale_factor` must be passed in and `scale_factor` is used to compute the + output `size`. The computed output `size` will be used to infer new scales for + the interpolation. Note that when `scale_factor` is floating-point, it may differ + from the recomputed `scale_factor` due to rounding and precision issues. + If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will + be used directly for interpolation. Default: ``None``. + antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias + option together with ``align_corners=False``, interpolation result would match Pillow + result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``. + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. note:: + Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation + algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep + backward compatibility. + Mode ``mode='nearest'`` matches buggy OpenCV's ``INTER_NEAREST`` interpolation algorithm. + + .. note:: + The gradients for the dtype ``float16`` on CUDA may be inaccurate in the upsample operation + when using modes ``['linear', 'bilinear', 'bicubic', 'trilinear', 'area']``. + For more details, please refer to the discussion in + `issue#104157 `_. + + Note: + {backward_reproducibility_note} + """ + if has_torch_function_unary(input): + return handle_torch_function( + interpolate, + (input,), + input, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor, + antialias=antialias + ) + + if mode in ("nearest", "area", "nearest-exact"): + if align_corners is not None: + raise ValueError( + "align_corners option can only be set with the " + "interpolating modes: linear | bilinear | bicubic | trilinear" + ) + else: + if align_corners is None: + align_corners = False + + dim = input.dim() - 2 # Number of spatial dimensions. + + # Process size and scale_factor. Validate that exactly one is set. + # Validate its length if it is a list, or expand it if it is a scalar. + # After this block, exactly one of output_size and scale_factors will + # be non-None, and it will be a list (or tuple). + if size is not None and scale_factor is not None: + raise ValueError("only one of size or scale_factor should be defined") + elif size is not None: + assert scale_factor is None + scale_factors = None + if isinstance(size, (list, tuple)): + if len(size) != dim: + raise ValueError( + "Input and output must have the same number of spatial dimensions, but got " + f"input with spatial dimensions of {list(input.shape[2:])} and output size of {size}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "output size in (o1, o2, ...,oK) format." + ) + if not torch.jit.is_scripting(): + if not all(_is_integer(x) for x in size): + raise TypeError( + "expected size to be one of int or Tuple[int] or Tuple[int, int] or " + f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}" + ) + output_size = size + else: + output_size = [size for _ in range(dim)] + elif scale_factor is not None: + assert size is None + output_size = None + if isinstance(scale_factor, (list, tuple)): + if len(scale_factor) != dim: + raise ValueError( + "Input and scale_factor must have the same number of spatial dimensions, but " + f"got input with spatial dimensions of {list(input.shape[2:])} and " + f"scale_factor of shape {scale_factor}. " + "Please provide input tensor in (N, C, d1, d2, ...,dK) format and " + "scale_factor in (s1, s2, ...,sK) format." + ) + scale_factors = scale_factor + else: + scale_factors = [scale_factor for _ in range(dim)] + else: + raise ValueError("either size or scale_factor should be defined") + + if recompute_scale_factor is not None and recompute_scale_factor and size is not None: + raise ValueError("recompute_scale_factor is not meaningful with an explicit size.") + + # "area" mode always requires an explicit size rather than scale factor. + # Re-use the recompute_scale_factor code path. + if mode == "area" and output_size is None: + recompute_scale_factor = True + + if recompute_scale_factor is not None and recompute_scale_factor: + # We compute output_size here, then un-set scale_factors. + # The C++ code will recompute it based on the (integer) output size. + assert scale_factors is not None + if not torch.jit.is_scripting() and torch._C._get_tracing_state(): + # make scale_factor a tensor in tracing so constant doesn't get baked in + output_size = [ + (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) + for i in range(dim) + ] + elif torch.jit.is_scripting(): + output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) + for i in range(dim)] + else: + output_size = [ + _sym_int(input.size(i + 2) * scale_factors[i]) + for i in range(dim) + ] + scale_factors = None + + if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + raise ValueError("Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input") + + if input.dim() == 3 and mode == "nearest": + return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest": + return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest": + return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) + if input.dim() == 4 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) + if input.dim() == 5 and mode == "nearest-exact": + return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) + + if input.dim() == 3 and mode == "area": + assert output_size is not None + return adaptive_avg_pool1d(input, output_size) + if input.dim() == 4 and mode == "area": + assert output_size is not None + return adaptive_avg_pool2d(input, output_size) + if input.dim() == 5 and mode == "area": + assert output_size is not None + return adaptive_avg_pool3d(input, output_size) + + if input.dim() == 3 and mode == "linear": + assert align_corners is not None + return torch._C._nn.upsample_linear1d(input, output_size, align_corners, scale_factors) + if input.dim() == 4 and mode == "bilinear": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors) + # Two levels are necessary to prevent TorchScript from touching + # are_deterministic_algorithms_enabled. + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and input.is_cuda: + # Use slow decomp whose backward will be in terms of index_put + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module('torch._decomp.decompositions')._upsample_linear_vec( + input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors) + if input.dim() == 5 and mode == "trilinear": + assert align_corners is not None + return torch._C._nn.upsample_trilinear3d(input, output_size, align_corners, scale_factors) + if input.dim() == 4 and mode == "bicubic": + assert align_corners is not None + if antialias: + return torch._C._nn._upsample_bicubic2d_aa(input, output_size, align_corners, scale_factors) + return torch._C._nn.upsample_bicubic2d(input, output_size, align_corners, scale_factors) + + if input.dim() == 3 and mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if input.dim() == 3 and mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if input.dim() == 4 and mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if input.dim() == 4 and mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if input.dim() == 5 and mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if input.dim() == 5 and mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + f" (got {mode})" + ) + + +if interpolate.__doc__: + interpolate.__doc__ = interpolate.__doc__.format(**reproducibility_notes) + + +@_overload # noqa: F811 +def upsample_nearest(input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None) -> Tensor: # noqa: F811 + pass + + +@_overload # noqa: F811 +def upsample_nearest(input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None) -> Tensor: # noqa: F811 + pass + + +def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with ``nn.functional.interpolate(..., mode='nearest')``. + + Currently spatial and volumetric upsampling are supported (i.e. expected + inputs are 4 or 5 dimensional). + + Args: + input (Tensor): input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatia + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn("nn.functional.upsample_nearest is deprecated. Use nn.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode="nearest") + + +if upsample_nearest.__doc__: + upsample_nearest.__doc__ = upsample_nearest.__doc__.format(**reproducibility_notes) + + +@_overload # noqa: F811 +def upsample_bilinear( + input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 + pass + + +@_overload # noqa: F811 +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[float] = None +) -> Tensor: # noqa: F811 + pass + + +@_overload # noqa: F811 +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[int] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 + pass + + +@_overload # noqa: F811 +def upsample_bilinear( # noqa: F811 + input: Tensor, size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None +) -> Tensor: # noqa: F811 + pass + + +def upsample_bilinear(input, size=None, scale_factor=None): # noqa: F811 + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of :func:`torch.nn.functional.interpolate`. + This is equivalent with + ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + Expected inputs are spatial (4 dimensional). Use `upsample_trilinear` fo + volumetric (5 dimensional) inputs. + + Args: + input (Tensor): input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + + Note: + {backward_reproducibility_note} + """ + # DeprecationWarning is ignored by default + warnings.warn("nn.functional.upsample_bilinear is deprecated. Use nn.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True) + + +if upsample_bilinear.__doc__: + upsample_bilinear.__doc__ = upsample_bilinear.__doc__.format(**reproducibility_notes) + +GRID_SAMPLE_INTERPOLATION_MODES = { + "bilinear": 0, + "nearest": 1, + "bicubic": 2, +} + +GRID_SAMPLE_PADDING_MODES = { + "zeros": 0, + "border": 1, + "reflection": 2, +} + + +def grid_sample( + input: Tensor, + grid: Tensor, + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> Tensor: + r"""Compute grid sample. + + Given an :attr:`input` and a flow-field :attr:`grid`, computes the + ``output`` using :attr:`input` values and pixel locations from :attr:`grid`. + + Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are + supported. + + In the spatial (4-D) case, for :attr:`input` with shape + :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape + :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape + :math:`(N, C, H_\text{out}, W_\text{out})`. + + For each output location ``output[n, :, h, w]``, the size-2 vector + ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``, + which are used to interpolate the output value ``output[n, :, h, w]``. + In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the + ``x``, ``y``, ``z`` pixel locations for interpolating + ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or + ``bilinear`` interpolation method to sample the input pixels. + + :attr:`grid` specifies the sampling pixel locations normalized by the + :attr:`input` spatial dimensions. Therefore, it should have most values in + the range of ``[-1, 1]``. For example, values ``x = -1, y = -1`` is the + left-top pixel of :attr:`input`, and values ``x = 1, y = 1`` is the + right-bottom pixel of :attr:`input`. + + If :attr:`grid` has values outside the range of ``[-1, 1]``, the corresponding + outputs are handled as defined by :attr:`padding_mode`. Options are + + * ``padding_mode="zeros"``: use ``0`` for out-of-bound grid locations, + * ``padding_mode="border"``: use border values for out-of-bound grid locations, + * ``padding_mode="reflection"``: use values at locations reflected by + the border for out-of-bound grid locations. For location far away + from the border, it will keep being reflected until becoming in bound, + e.g., (normalized) pixel location ``x = -3.5`` reflects by border ``-1`` + and becomes ``x' = 1.5``, then reflects by border ``1`` and becomes + ``x'' = -0.5``. + + Note: + This function is often used in conjunction with :func:`affine_grid` + to build `Spatial Transformer Networks`_ . + + Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + + Note: + NaN values in :attr:`grid` would be interpreted as ``-1``. + + Args: + input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case) + or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case) + grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case) + or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case) + mode (str): interpolation mode to calculate output values + ``'bilinear'`` | ``'nearest'`` | ``'bicubic'``. Default: ``'bilinear'`` + Note: ``mode='bicubic'`` supports only 4-D input. + When ``mode='bilinear'`` and the input is 5-D, the interpolation mode + used internally will actually be trilinear. However, when the input is 4-D, + the interpolation mode will legitimately be bilinear. + padding_mode (str): padding mode for outside grid values + ``'zeros'`` | ``'border'`` | ``'reflection'``. Default: ``'zeros'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input as squares rather than points. + If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring + to the center points of the input's corner pixels. If set to ``False``, they + are instead considered as referring to the corner points of the input's corner + pixels, making the sampling more resolution agnostic. + This option parallels the ``align_corners`` option in + :func:`interpolate`, and so whichever option is used here + should also be used there to resize the input image before grid sampling. + Default: ``False`` + + Returns: + output (Tensor): output Tensor + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + + .. note:: + ``mode='bicubic'`` is implemented using the `cubic convolution algorithm`_ with :math:`\alpha=-0.75`. + The constant :math:`\alpha` might be different from packages to packages. + For example, `PIL`_ and `OpenCV`_ use -0.5 and -0.75 respectively. + This algorithm may "overshoot" the range of values it's interpolating. + For example, it may produce negative values or values greater than 255 when interpolating input in [0, 255]. + Clamp the results with :func:`torch.clamp` to ensure they are within the valid range. + .. _`cubic convolution algorithm`: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _`PIL`: https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/src/libImaging/Resample.c#L51 + .. _`OpenCV`: https://github.com/opencv/opencv/blob/f345ed564a06178670750bad59526cfa4033be55/modules/imgproc/src/resize.cpp#L908 + """ + if has_torch_function_variadic(input, grid): + return handle_torch_function( + grid_sample, (input, grid), input, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners + ) + if mode != "bilinear" and mode != "nearest" and mode != "bicubic": + raise ValueError( + f"nn.functional.grid_sample(): expected mode to be 'bilinear', 'nearest' or 'bicubic', but got: '{mode}'" + ) + if padding_mode != "zeros" and padding_mode != "border" and padding_mode != "reflection": + raise ValueError( + "nn.functional.grid_sample(): expected padding_mode " + "to be 'zeros', 'border', or 'reflection', " + f"but got: '{padding_mode}'" + ) + + if mode == "bilinear": + mode_enum = 0 + elif mode == "nearest": + mode_enum = 1 + else: # mode == 'bicubic' + mode_enum = 2 + + if padding_mode == "zeros": + padding_mode_enum = 0 + elif padding_mode == "border": + padding_mode_enum = 1 + else: # padding_mode == 'reflection' + padding_mode_enum = 2 + + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum, align_corners) + + +def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[bool] = None) -> Tensor: + r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`. + + .. note:: + This function is often used in conjunction with :func:`grid_sample` + to build `Spatial Transformer Networks`_ . + + Args: + theta (Tensor): input batch of affine matrices with shape + (:math:`N \times 2 \times 3`) for 2D or + (:math:`N \times 3 \times 4`) for 3D + size (torch.Size): the target output image size. + (:math:`N \times C \times H \times W` for 2D or + :math:`N \times C \times D \times H \times W` for 3D) + Example: torch.Size((32, 3, 24, 24)) + align_corners (bool, optional): if ``True``, consider ``-1`` and ``1`` + to refer to the centers of the corner pixels rather than the image corners. + Refer to :func:`grid_sample` for a more complete description. + A grid generated by :func:`affine_grid` should be passed to :func:`grid_sample` + with the same setting for this option. + Default: ``False`` + + Returns: + output (Tensor): output Tensor of size (:math:`N \times H \times W \times 2`) + + .. _`Spatial Transformer Networks`: + https://arxiv.org/abs/1506.02025 + + .. warning:: + When ``align_corners = True``, the grid positions depend on the pixel + size relative to the input image size, and so the locations sampled by + :func:`grid_sample` will differ for the same input given at different + resolutions (that is, after being upsampled or downsampled). + The default behavior up to version 1.2.0 was ``align_corners = True``. + Since then, the default behavior has been changed to ``align_corners = False``, + in order to bring it in line with the default for :func:`interpolate`. + .. warning:: + When ``align_corners = True``, 2D affine transforms on 1D data and + 3D affine transforms on 2D data (that is, when one of the spatial + dimensions has unit size) are ill-defined, and not an intended use case. + This is not a problem when ``align_corners = False``. + Up to version 1.2.0, all grid points along a unit dimension were + considered arbitrarily to be at ``-1``. + From version 1.3.0, under ``align_corners = True`` all grid points + along a unit dimension are considered to be at ``0`` + (the center of the input image). + """ + if has_torch_function_unary(theta): + return handle_torch_function(affine_grid, (theta,), theta, size, align_corners=align_corners) + if align_corners is None: + warnings.warn( + "Default grid_sample and affine_grid behavior has changed " + "to align_corners=False since 1.3.0. Please specify " + "align_corners=True if the old behavior is desired. " + "See the documentation of grid_sample for details." + ) + align_corners = False + + # enforce floating point dtype on theta + if not theta.is_floating_point(): + raise ValueError(f"Expected theta to have floating point type, but got {theta.dtype}") + # check that shapes and sizes match + if len(size) == 4: + if theta.dim() != 3 or theta.shape[-2] != 2 or theta.shape[-1] != 3: + raise ValueError( + f"Expected a batch of 2D affine matrices of shape Nx2x3 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-2:] # spatial dimension sizes + elif len(size) == 5: + if theta.dim() != 3 or theta.shape[-2] != 3 or theta.shape[-1] != 4: + raise ValueError( + f"Expected a batch of 3D affine matrices of shape Nx3x4 for size {size}. Got {theta.shape}." + ) + spatial_size = size[-3:] # spatial dimension sizes + else: + raise NotImplementedError( + "affine_grid only supports 4D and 5D sizes, " + "for 2D and 3D affine transforms, respectively. " + f"Got size {size}." + ) + # check for empty span + if align_corners and min(spatial_size) == 1: + warnings.warn( + "Since version 1.3.0, affine_grid behavior has changed " + "for unit-size grids when align_corners=True. " + "This is not an intended use case of affine_grid. " + "See the documentation of affine_grid for details." + ) + elif min(size) <= 0: + raise ValueError(f"Expected non-zero, positive output size. Got {size}") + + return torch.affine_grid_generator(theta, size, align_corners) + + +def pad(input: Tensor, pad: List[int], mode: str = "constant", value: Optional[float] = None) -> Tensor: + r""" +pad(input, pad, mode="constant", value=None) -> Tensor + +Pads tensor. + +Padding size: + The padding size by which to pad some dimensions of :attr:`input` + are described starting from the last dimension and moving forward. + :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions + of ``input`` will be padded. + For example, to pad only the last dimension of the input tensor, then + :attr:`pad` has the form + :math:`(\text{padding\_left}, \text{padding\_right})`; + to pad the last 2 dimensions of the input tensor, then use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom})`; + to pad the last 3 dimensions, use + :math:`(\text{padding\_left}, \text{padding\_right},` + :math:`\text{padding\_top}, \text{padding\_bottom}` + :math:`\text{padding\_front}, \text{padding\_back})`. + +Padding mode: + See :class:`torch.nn.CircularPad2d`, :class:`torch.nn.ConstantPad2d`, + :class:`torch.nn.ReflectionPad2d`, and :class:`torch.nn.ReplicationPad2d` + for concrete examples on how each of the padding modes works. Constant + padding is implemented for arbitrary dimensions. Circular, replicate and + reflection padding are implemented for padding the last 3 dimensions of a + 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, + or the last dimension of a 2D or 3D input tensor. + +Note: + When using the CUDA backend, this operation may induce nondeterministic + behaviour in its backward pass that is not easily switched off. + Please see the notes on :doc:`/notes/randomness` for background. + +Args: + input (Tensor): N-dimensional tensor + pad (tuple): m-elements tuple, where + :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even. + mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + Default: ``'constant'`` + value: fill value for ``'constant'`` padding. Default: ``0`` + +Examples:: + + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p1d = (1, 1) # pad last dim by 1 on each side + >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding + >>> print(out.size()) + torch.Size([3, 3, 4, 4]) + >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) + >>> out = F.pad(t4d, p2d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 3, 8, 4]) + >>> t4d = torch.empty(3, 3, 4, 2) + >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) + >>> out = F.pad(t4d, p3d, "constant", 0) + >>> print(out.size()) + torch.Size([3, 9, 7, 3]) + +""" + if has_torch_function_unary(input): + return handle_torch_function( + torch.nn.functional.pad, (input,), input, pad, mode=mode, value=value) + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and input.is_cuda: + if mode == 'replicate': + # Use slow decomp whose backward will be in terms of index_put. + # importlib is required because the import cannot be top level + # (cycle) and cannot be nested (TS doesn't support) + return importlib.import_module('torch._decomp.decompositions')._replication_pad( + input, pad + ) + return torch._C._nn.pad(input, pad, mode, value) + +# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798 +pad.__module__ = "torch.nn.functional" + +# distance + + +pairwise_distance = _add_docstr( + torch.pairwise_distance, + r""" +pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False) -> Tensor + +See :class:`torch.nn.PairwiseDistance` for details +""") + + +pdist = _add_docstr( + torch.pdist, + r""" +pdist(input, p=2) -> Tensor + +Computes the p-norm distance between every pair of row vectors in the input. +This is identical to the upper triangular portion, excluding the diagonal, of +`torch.norm(input[:, None] - input, dim=2, p=p)`. This function will be faster +if the rows are contiguous. + +If input has shape :math:`N \times M` then the output will have shape +:math:`\frac{1}{2} N (N - 1)`. + +This function is equivalent to ``scipy.spatial.distance.pdist(input, +'minkowski', p=p)`` if :math:`p \in (0, \infty)`. When :math:`p = 0` it is +equivalent to ``scipy.spatial.distance.pdist(input, 'hamming') * M``. +When :math:`p = \infty`, the closest scipy function is +``scipy.spatial.distance.pdist(xn, lambda x, y: np.abs(x - y).max())``. + +Args: + input: input tensor of shape :math:`N \times M`. + p: p value for the p-norm distance to calculate between each vector pair + :math:`\in [0, \infty]`. +""", +) + + +cosine_similarity = _add_docstr( + torch.cosine_similarity, + r""" +cosine_similarity(x1, x2, dim=1, eps=1e-8) -> Tensor + +Returns cosine similarity between ``x1`` and ``x2``, computed along dim. ``x1`` and ``x2`` must be broadcastable +to a common shape. ``dim`` refers to the dimension in this common shape. Dimension ``dim`` of the output is +squeezed (see :func:`torch.squeeze`), resulting in the +output tensor having 1 fewer dimension. + +.. math :: + \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2, \epsilon) \cdot \max(\Vert x_2 \Vert _2, \epsilon)} + +Supports :ref:`type promotion `. + +Args: + x1 (Tensor): First input. + x2 (Tensor): Second input. + dim (int, optional): Dimension along which cosine similarity is computed. Default: 1 + eps (float, optional): Small value to avoid division by zero. + Default: 1e-8 + +Example:: + + >>> input1 = torch.randn(100, 128) + >>> input2 = torch.randn(100, 128) + >>> output = F.cosine_similarity(input1, input2) + >>> print(output) +""", +) + + +one_hot = _add_docstr( + torch._C._nn.one_hot, + r""" +one_hot(tensor, num_classes=-1) -> LongTensor + +Takes LongTensor with index values of shape ``(*)`` and returns a tensor +of shape ``(*, num_classes)`` that have zeros everywhere except where the +index of last dimension matches the corresponding value of the input tensor, +in which case it will be 1. + +See also `One-hot on Wikipedia`_ . + +.. _One-hot on Wikipedia: + https://en.wikipedia.org/wiki/One-hot + +Arguments: + tensor (LongTensor): class values of any shape. + num_classes (int): Total number of classes. If set to -1, the number + of classes will be inferred as one greater than the largest class + value in the input tensor. + +Returns: + LongTensor that has one more dimension with 1 values at the + index of last dimension indicated by the input, and 0 everywhere + else. + +Examples: + >>> F.one_hot(torch.arange(0, 5) % 3) + tensor([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) + tensor([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0]]) + >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) + tensor([[[1, 0, 0], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0]], + [[0, 1, 0], + [0, 0, 1]]]) +""", +) + + +def triplet_margin_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> Tensor: + r"""Compute the triplet loss between given input tensors and a margin greater than 0. + + See :class:`~torch.nn.TripletMarginLoss` for details. + """ + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + margin=margin, + p=p, + eps=eps, + swap=swap, + size_average=size_average, + reduce=reduce, + reduction=reduction, + ) + if size_average is not None or reduce is not None: + reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) + else: + reduction_enum = _Reduction.get_enum(reduction) + return torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction_enum) + + +def triplet_margin_with_distance_loss( + anchor: Tensor, + positive: Tensor, + negative: Tensor, + *, + distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean" +) -> Tensor: + r"""Compute the triplet margin loss for input tensors using a custom distance function. + + See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details. + """ + if torch.jit.is_scripting(): + raise NotImplementedError( + "F.triplet_margin_with_distance_loss does not support JIT scripting: " + "functions requiring Callables cannot be scripted." + ) + + if has_torch_function_variadic(anchor, positive, negative): + return handle_torch_function( + triplet_margin_with_distance_loss, + (anchor, positive, negative), + anchor, + positive, + negative, + distance_function=distance_function, + margin=margin, + swap=swap, + reduction=reduction, + ) + + # Check validity of reduction mode + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + # Check dimensions + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + if not (a_dim == p_dim and p_dim == n_dim): + raise RuntimeError( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs") + + # Calculate loss + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + + # Apply reduction + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor: + r"""Perform :math:`L_p` normalization of inputs over specified dimension. + + For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each + :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as + + .. math:: + v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. + + With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. + + Args: + input: input tensor of any shape + p (float): the exponent value in the norm formulation. Default: 2 + dim (int or tuple of ints): the dimension to reduce. Default: 1 + eps (float): small value to avoid division by zero. Default: 1e-12 + out (Tensor, optional): the output tensor. If :attr:`out` is used, this + operation won't be differentiable. + """ + if has_torch_function_variadic(input, out): + return handle_torch_function(normalize, (input, out), input, p=p, dim=dim, eps=eps, out=out) + if out is None: + denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input) + return input / denom + else: + denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input) + return torch.div(input, denom, out=out) + + +def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None: + assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name) + + +def unfold( + input: Tensor, kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1 +) -> Tensor: + r"""Extract sliding local blocks from a batched input tensor. + + .. warning:: + Currently, only 4-D input tensors (batched image-like tensors) are + supported. + + .. warning:: + + More than one element of the unfolded tensor may refer to a single + memory location. As a result, in-place operations (especially ones that + are vectorized) may result in incorrect behavior. If you need to write + to the tensor, please clone it first. + + + See :class:`torch.nn.Unfold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride + ) + return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)) + + +def fold( + input: Tensor, output_size: BroadcastingList2[int], + kernel_size: BroadcastingList2[int], + dilation: BroadcastingList2[int] = 1, + padding: BroadcastingList2[int] = 0, + stride: BroadcastingList2[int] = 1 +) -> Tensor: + r"""Combine an array of sliding local blocks into a large containing tensor. + + .. warning:: + Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported. + + See :class:`torch.nn.Fold` for details + """ + if has_torch_function_unary(input): + return handle_torch_function( + fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride + ) + return torch._C._nn.col2im( + input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride) + ) + +# +# multihead attention +# + +def _in_projection_packed( + q: Tensor, + k: Tensor, + v: Tensor, + w: Tensor, + b: Optional[Tensor] = None, +) -> List[Tensor]: + r"""Perform the in-projection step of the attention operation, using packed weights. + + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. For self-attention, + these are typically the same tensor; for encoder-decoder attention, + k and v are typically the same tensor. (We take advantage of these + identities for performance if they are present.) Regardless, q, k and v + must share a common embedding dimension; otherwise their shapes may vary. + w: projection weights for q, k and v, packed into a single tensor. Weights + are packed along dimension 0, in q, k, v order. + b: optional projection biases for q, k and v, packed into a single tensor + in q, k, v order. + + Shape: + Inputs: + - q: :math:`(..., E)` where E is the embedding dimension + - k: :math:`(..., E)` where E is the embedding dimension + - v: :math:`(..., E)` where E is the embedding dimension + - w: :math:`(E * 3, E)` where E is the embedding dimension + - b: :math:`E * 3` where E is the embedding dimension + + Output: + - in output list :math:`[q', k', v']`, each output tensor will have the + same shape as the corresponding input tensor. + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + proj = linear(q, w, b) + # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() + proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return proj[0], proj[1], proj[2] + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + q_proj = linear(q, w_q, b_q) + kv_proj = linear(k, w_kv, b_kv) + # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() + kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + return (q_proj, kv_proj[0], kv_proj[1]) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Perform the in-projection step of the attention operation. + + This is simply a triple of linear projections, + with shape constraints on the weights which + ensure embedding dimension uniformity in the projected outputs. + Output is a triple containing projection tensors for query, key and value. + + Args: + q, k, v: query, key and value tensors to be projected. + w_q, w_k, w_v: weights for q, k and v, respectively. + b_q, b_k, b_v: optional biases for q, k and v, respectively. + + Shape: + Inputs: + - q: :math:`(Qdims..., Eq)` where Eq is the query embedding dimension and Qdims are any + number of leading dimensions. + - k: :math:`(Kdims..., Ek)` where Ek is the key embedding dimension and Kdims are any + number of leading dimensions. + - v: :math:`(Vdims..., Ev)` where Ev is the value embedding dimension and Vdims are any + number of leading dimensions. + - w_q: :math:`(Eq, Eq)` + - w_k: :math:`(Eq, Ek)` + - w_v: :math:`(Eq, Ev)` + - b_q: :math:`(Eq)` + - b_k: :math:`(Eq)` + - b_v: :math:`(Eq)` + + Output: in output triple :math:`(q', k', v')`, + - q': :math:`[Qdims..., Eq]` + - k': :math:`[Kdims..., Eq]` + - v': :math:`[Vdims..., Eq]` + + """ + Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) + assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + +scaled_dot_product_attention = _add_docstr( + torch._C._nn.scaled_dot_product_attention, r""" +scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor: + +Computes scaled dot product attention on query, key and value tensors, using +an optional attention mask if passed, and applying dropout if a probability +greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument. + +.. code-block:: python + + # Efficient implementation equivalent to the following: + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + +.. warning:: This function is beta and subject to change. + +Note: + + There are currently three supported implementations of scaled dot product attention: + + - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ + - `Memory-Efficient Attention`_ + - A PyTorch implementation defined in C++ matching the above formulation + + The function may call optimized kernels for improved performance when using the CUDA backend. + For all other backends, the PyTorch implementation will be used. + + All implementations are enabled by default. Scaled dot product attention attempts to automatically select the + most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation + is used, the following functions are provided for enabling and disabling implementations. + The context manager is the preferred mechanism: + + - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. + - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. + - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention. + - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation. + + Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, + disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. + In the event that a fused implementation is not available, a warning will be raised with the + reasons why the fused implementation cannot run. + + Due to the nature of fusing floating point operations, the output of this function may be different + depending on what backend kernel is chosen. + The c++ implementation supports torch.float64 and can be used when higher precision is required. + For more information please see :doc:`/notes/numerical_accuracy` + +Note: + {cudnn_reproducibility_note} +""".format(**reproducibility_notes) + + r""" +Args: + query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal + are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + + +Returns: + output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. + +Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + +Examples: + + >>> # Optionally use the context manager to ensure one of the fused kernels is run + >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") + >>> with torch.backends.cuda.sdp_kernel(enable_math=False): + >>> F.scaled_dot_product_attention(query,key,value) + + +.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: + https://arxiv.org/abs/2307.08691 +.. _Memory-Efficient Attention: + https://github.com/facebookresearch/xformers + +""") + +def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor, + key_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor], num_heads: int): + # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + # and returns if the input is batched or not. + # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + + # Shape check. + if query.dim() == 3: + # Batched Inputs + is_batched = True + assert key.dim() == 3 and value.dim() == 3, \ + ("For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + if key_padding_mask is not None: + assert key_padding_mask.dim() == 2, \ + ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.dim()}-D tensor instead") + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), \ + ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead") + elif query.dim() == 2: + # Unbatched Inputs + is_batched = False + assert key.dim() == 2 and value.dim() == 2, \ + ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.dim()}-D and {value.dim()}-D tensors respectively") + + if key_padding_mask is not None: + assert key_padding_mask.dim() == 1, \ + ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.dim()}-D tensor instead") + + if attn_mask is not None: + assert attn_mask.dim() in (2, 3), \ + ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.dim()}-D tensor instead") + if attn_mask.dim() == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert attn_mask.shape == expected_shape, \ + (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor") + + return is_batched + +def _canonical_mask( + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, +) -> Optional[Tensor]: + + if mask is not None: + _mask_dtype = mask.dtype + _mask_is_float = torch.is_floating_point(mask) + if _mask_dtype != torch.bool and not _mask_is_float: + raise AssertionError( + f"only bool and floating types of {mask_name} are supported") + if check_other and other_type is not None: + if _mask_dtype != other_type: + warnings.warn( + f"Support for mismatched {mask_name} and {other_name} " + "is deprecated. Use same type for both instead." + ) + if not _mask_is_float: + mask = ( + torch.zeros_like(mask, dtype=target_type) + .masked_fill_(mask, float("-inf")) + ) + return mask + +def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: + if input is None: + return None + elif isinstance(input, torch.Tensor): + return input.dtype + raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Optional[Tensor], + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + average_attn_weights: bool = True, + is_causal: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Forward method for MultiHeadAttention. + + See :class:`torch.nn.MultiheadAttention` for details. + + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + Default: `True` + Note: `needs_weight` defaults to `True`, but should be set to `False` + For best performance when attention weights are not needed. + *Setting needs_weights to `True` + leads to a significant performance degradation.* + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + is_causal: If specified, applies a causal mask as attention mask, and ignores + attn_mask for computing scaled dot product attention. + Default: ``False``. + .. warning:: + is_causal is provides a hint that the attn_mask is the + causal mask.Providing incorrect hints can result in + incorrect execution, including forward and backward + compatibility. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads. + Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect + when ``need_weights=True.``. Default: True + + + Shape: + Inputs: + - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a FloatTensor is provided, it will be directly added to the value. + If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns + attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`. + """ + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + is_causal=is_causal, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + average_attn_weights=average_attn_weights, + ) + + is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input + # is batched, run the computation and before returning squeeze the + # batch dimension so that the output doesn't carry this temporary batch dimension. + if not is_batched: + # unsqueeze if the input is unbatched + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.unsqueeze(0) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + + key_padding_mask = _canonical_mask( + mask=key_padding_mask, + mask_name="key_padding_mask", + other_type=_none_or_dtype(attn_mask), + other_name="attn_mask", + target_type=query.dtype + ) + + if is_causal and attn_mask is None: + raise RuntimeError( + "Need attn_mask if specifying the is_causal hint. " + "You may use the Transformer module method " + "`generate_square_subsequent_mask` to create this mask." + ) + + if is_causal and key_padding_mask is None and not need_weights: + # when we have a kpm or need weights, we need attn_mask + # Otherwise, we use the is_causal hint go as is_causal + # indicator to SDPA. + attn_mask = None + else: + attn_mask = _canonical_mask( + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) + + if key_padding_mask is not None: + # We have the attn_mask, and use that to merge kpm into it. + # Turn off use of is_causal hint, as the merged mask is no + # longer causal. + is_causal = False + + assert embed_dim == embed_dim_to_check, \ + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], \ + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + + if attn_mask is not None: + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.size(0) == bsz * num_heads, \ + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, \ + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.size(0) == bsz * num_heads, \ + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, \ + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, src_len), \ + f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ + expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + if attn_mask is None: + attn_mask = key_padding_mask + else: + attn_mask = attn_mask + key_padding_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + + if need_weights: + B, Nt, E = q.shape + q_scaled = q * math.sqrt(1.0 / float(E)) + + assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights" + + if attn_mask is not None: + attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1)) + else: + attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + attn_output_weights = softmax(attn_output_weights, dim=-1) + if dropout_p > 0.0: + attn_output_weights = dropout(attn_output_weights, p=dropout_p) + + attn_output = torch.bmm(attn_output_weights, v) + + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.mean(dim=1) + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + else: + # attn_mask can be either (L,S) or (N*num_heads, L, S) + # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S) + # in order to match the input for SDPA of (N, num_heads, L, S) + if attn_mask is not None: + if attn_mask.size(0) == 1 and attn_mask.dim() == 3: + attn_mask = attn_mask.unsqueeze(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + return attn_output, None diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b65e2f2482d21a4211b28a9165bdbfc95dcc638f Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8dffbae5e5602f42db8e3a951efab0f417a6b7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r"""Intrinsic QAT Modules. + +This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/intrinsic/qat/modules`, +while adding an import statement here. +""" + +__all__ = [ + 'LinearReLU', +] + +from torch.ao.nn.intrinsic.qat import LinearReLU diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..293549d00159b85bcac3c5d437c4614bd180d6b6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d79bdbfe83209f18b17cc8c7b245f322871d6c0 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py @@ -0,0 +1 @@ +from .modules import * # noqa: F403 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1885a6aec4e570a8eed81bd0cce61bf0e8390a --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py @@ -0,0 +1,5 @@ +from .linear_relu import LinearReLU + +__all__ = [ + 'LinearReLU', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bcc0a3cca5596661942c1c7a070893d8775851c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..7682b4f8ae426b4aa0537505e55c9c98efd47b94 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py @@ -0,0 +1,7 @@ +from torch.ao.nn.intrinsic.quantized import BNReLU2d +from torch.ao.nn.intrinsic.quantized import BNReLU3d + +__all__ = [ + 'BNReLU2d', + 'BNReLU3d', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..3e89e9b5821f78bf18dca840a0834cba70ab8f0d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py @@ -0,0 +1,5 @@ +from torch.ao.nn.intrinsic.quantized import LinearReLU + +__all__ = [ + 'LinearReLU', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddf06224616bd579c1022c90d3037595f48990f2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc979ce283c0484f5d1099528fa3d6032037552e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fce27f73ff632e59c06e6c491a8b3ce362b0b30 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae78dbbc4ed914fff8bf4ccd0459df46b20f4cb1 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c66fcec07babbc936c7548557f849c393ef717aa Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a81202f08e07fc3be77a198681fe8b90c4a0be8d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab9dbe8dcd6a771cbf61d72fb7a026ff3ddd11b4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa947b91ed15ee3cd0f2bed950c52e52c0a54685 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pixelshuffle.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pooling.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pooling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45c9bc0e7870b631b5f9f93407ed52a6473f6ee Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/pooling.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b48e59c565a93d6dbdd0021b8b3d745f8329b46 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/sparse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/sparse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cda46af61d5cee12623631e34dcee138e1acb5d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/sparse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5aae71ac07178c8d47586cb0a24ad9a4cd71940 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/upsampling.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa9e3ac14963e141c95f62544de06595ab4c3e43 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3f894532a64ee8148124a4b37af98a06c3903c3e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py @@ -0,0 +1,1602 @@ +import math +import warnings + +import torch +from torch import Tensor +from torch.nn.parameter import Parameter, UninitializedParameter +from .. import functional as F +from .. import init +from .lazy import LazyModuleMixin +from .module import Module +from .utils import _single, _pair, _triple, _reverse_repeat_tuple +from torch._torch_docs import reproducibility_notes + +from ..common_types import _size_1_t, _size_2_t, _size_3_t +from typing import Optional, List, Tuple, Union + +__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', + 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', + 'LazyConvTranspose3d'] + +convolution_notes = \ + {"groups_note": r"""* :attr:`groups` controls the connections between inputs and outputs. + :attr:`in_channels` and :attr:`out_channels` must both be divisible by + :attr:`groups`. For example, + + * At groups=1, all inputs are convolved to all outputs. + * At groups=2, the operation becomes equivalent to having two conv + layers side by side, each seeing half the input channels + and producing half the output channels, and both subsequently + concatenated. + * At groups= :attr:`in_channels`, each input channel is convolved with + its own set of filters (of size + :math:`\frac{\text{out\_channels}}{\text{in\_channels}}`).""", + + "depthwise_separable_note": r"""When `groups == in_channels` and `out_channels == K * in_channels`, + where `K` is a positive integer, this operation is also known as a "depthwise convolution". + + In other words, for an input of size :math:`(N, C_{in}, L_{in})`, + a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments + :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`."""} # noqa: B950 + + + + + +class _ConvNd(Module): + + __constants__ = ['stride', 'padding', 'dilation', 'groups', + 'padding_mode', 'output_padding', 'in_channels', + 'out_channels', 'kernel_size'] + __annotations__ = {'bias': Optional[torch.Tensor]} + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body] + ... + + in_channels: int + _reversed_padding_repeated_twice: List[int] + out_channels: int + kernel_size: Tuple[int, ...] + stride: Tuple[int, ...] + padding: Union[str, Tuple[int, ...]] + dilation: Tuple[int, ...] + transposed: bool + output_padding: Tuple[int, ...] + groups: int + padding_mode: str + weight: Tensor + bias: Optional[Tensor] + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + dilation: Tuple[int, ...], + transposed: bool, + output_padding: Tuple[int, ...], + groups: int, + bias: bool, + padding_mode: str, + device=None, + dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if groups <= 0: + raise ValueError('groups must be a positive integer') + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + valid_padding_strings = {'same', 'valid'} + if isinstance(padding, str): + if padding not in valid_padding_strings: + raise ValueError( + f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}") + if padding == 'same' and any(s != 1 for s in stride): + raise ValueError("padding='same' is not supported for strided convolutions") + + valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} + if padding_mode not in valid_padding_modes: + raise ValueError(f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.padding_mode = padding_mode + # `_reversed_padding_repeated_twice` is the padding to be passed to + # `F.pad` if needed (e.g., for non-zero padding types that are + # implemented as two ops: padding + conv). `F.pad` accepts paddings in + # reverse order than the dimension. + if isinstance(self.padding, str): + self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) + if padding == 'same': + for d, k, i in zip(dilation, kernel_size, + range(len(kernel_size) - 1, -1, -1)): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + self._reversed_padding_repeated_twice[2 * i] = left_pad + self._reversed_padding_repeated_twice[2 * i + 1] = ( + total_padding - left_pad) + else: + self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2) + + if transposed: + self.weight = Parameter(torch.empty( + (in_channels, out_channels // groups, *kernel_size), **factory_kwargs)) + else: + self.weight = Parameter(torch.empty( + (out_channels, in_channels // groups, *kernel_size), **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(out_channels, **factory_kwargs)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(k), 1/sqrt(k)), where k = weight.size(1) * prod(*kernel_size) + # For more details see: https://github.com/pytorch/pytorch/issues/15314#issuecomment-477448573 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + return s.format(**self.__dict__) + + def __setstate__(self, state): + super().__setstate__(state) + if not hasattr(self, 'padding_mode'): + self.padding_mode = 'zeros' + + +class Conv1d(_ConvNd): + __doc__ = r"""Applies a 1D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, L)` and output :math:`(N, C_{\text{out}}, L_{\text{out}})` can be + precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) + \star \text{input}(N_i, k) + + where :math:`\star` is the valid `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`L` is a length of signal sequence. + """ + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a one-element tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. + + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + Note: + {depthwise_separable_note} + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to both sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = \left\lfloor\frac{L_{in} + 2 \times \text{padding} - \text{dilation} + \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, + \frac{\text{in\_channels}}{\text{groups}}, \text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \text{kernel\_size}}` + + Examples:: + + >>> m = nn.Conv1d(16, 33, 3, stride=2) + >>> input = torch.randn(20, 16, 50) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + # we create new variables below to make mypy happy since kernel_size has + # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + kernel_size_ = _single(kernel_size) + stride_ = _single(stride) + padding_ = padding if isinstance(padding, str) else _single(padding) + dilation_ = _single(dilation) + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _single(0), groups, bias, padding_mode, **factory_kwargs) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, bias, self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, weight, bias, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + +class Conv2d(_ConvNd): + __doc__ = r"""Applies a 2D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size + :math:`(N, C_{\text{in}}, H, W)` and output :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` + can be precisely described as: + + .. math:: + \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k) + + + where :math:`\star` is the valid 2D `cross-correlation`_ operator, + :math:`N` is a batch size, :math:`C` denotes a number of channels, + :math:`H` is a height of input planes in pixels, and :math:`W` is + width in pixels. + """ + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation, a single + number or a tuple. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the + amount of implicit padding applied on both sides. + + * :attr:`dilation` controls the spacing between the kernel points; also + known as the à trous algorithm. It is harder to describe, but this `link`_ + has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimension + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape + (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples: + + >>> # With square kernels and equal stride + >>> m = nn.Conv2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> # non-square kernels and unequal stride and with padding and dilation + >>> m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, bias, self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, bias, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + +class Conv3d(_ConvNd): + __doc__ = r"""Applies a 3D convolution over an input signal composed of several input + planes. + + In the simplest case, the output value of the layer with input size :math:`(N, C_{in}, D, H, W)` + and output :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` can be precisely described as: + + .. math:: + out(N_i, C_{out_j}) = bias(C_{out_j}) + + \sum_{k = 0}^{C_{in} - 1} weight(C_{out_j}, k) \star input(N_i, k) + + where :math:`\star` is the valid 3D `cross-correlation`_ operator + """ + r""" + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of padding applied to the input. It + can be either a string {{'valid', 'same'}} or a tuple of ints giving the + amount of implicit padding applied on both sides. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimension + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + {depthwise_separable_note} + + Note: + {cudnn_reproducibility_note} + + Note: + ``padding='valid'`` is the same as no padding. ``padding='same'`` pads + the input so the output has the shape as the input. However, this mode + doesn't support any stride values other than 1. + + Note: + This module supports complex data types i.e. ``complex32, complex64, complex128``. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all six sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, + where + + .. math:: + D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor + + .. math:: + H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor + + .. math:: + W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``, + then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.Conv3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _cross-correlation: + https://en.wikipedia.org/wiki/Cross-correlation + + .. _link: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: Union[str, _size_3_t] = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size_ = _triple(kernel_size) + stride_ = _triple(stride) + padding_ = padding if isinstance(padding, str) else _triple(padding) + dilation_ = _triple(dilation) + super().__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _triple(0), groups, bias, padding_mode, **factory_kwargs) + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + if self.padding_mode != "zeros": + return F.conv3d( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + weight, + bias, + self.stride, + _triple(0), + self.dilation, + self.groups, + ) + return F.conv3d( + input, weight, bias, self.stride, self.padding, self.dilation, self.groups + ) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) + + + +class _ConvTransposeNd(_ConvNd): + def __init__(self, in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, device=None, dtype=None) -> None: + if padding_mode != 'zeros': + raise ValueError(f'Only "zeros" padding mode is supported for {self.__class__.__name__}') + + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + in_channels, out_channels, kernel_size, stride, + padding, dilation, transposed, output_padding, + groups, bias, padding_mode, **factory_kwargs) + + # dilation being an optional parameter is for backwards + # compatibility + def _output_padding(self, input: Tensor, output_size: Optional[List[int]], + stride: List[int], padding: List[int], kernel_size: List[int], + num_spatial_dims: int, dilation: Optional[List[int]] = None) -> List[int]: + if output_size is None: + ret = _single(self.output_padding) # converting to list if was not already + else: + has_batch_dim = input.dim() == num_spatial_dims + 2 + num_non_spatial_dims = 2 if has_batch_dim else 1 + if len(output_size) == num_non_spatial_dims + num_spatial_dims: + output_size = output_size[num_non_spatial_dims:] + if len(output_size) != num_spatial_dims: + raise ValueError( + "ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})" + .format(num_spatial_dims, input.dim(), num_spatial_dims, + num_non_spatial_dims + num_spatial_dims, len(output_size))) + + min_sizes = torch.jit.annotate(List[int], []) + max_sizes = torch.jit.annotate(List[int], []) + for d in range(num_spatial_dims): + dim_size = ((input.size(d + num_non_spatial_dims) - 1) * stride[d] - + 2 * padding[d] + + (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1) + 1) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError( + f"requested an output size of {output_size}, but valid sizes range " + f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})") + + res = torch.jit.annotate(List[int], []) + for d in range(num_spatial_dims): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret + + +class ConvTranspose1d(_ConvTransposeNd): + __doc__ = r"""Applies a 1D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv1d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv1d` and a :class:`~torch.nn.ConvTranspose1d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv1d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, L_{in})` or :math:`(C_{in}, L_{in})` + - Output: :math:`(N, C_{out}, L_{out})` or :math:`(C_{out}, L_{out})`, where + + .. math:: + L_{out} = (L_{in} - 1) \times \text{stride} - 2 \times \text{padding} + \text{dilation} + \times (\text{kernel\_size} - 1) + \text{output\_padding} + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + bias (Tensor): the learnable bias of the module of shape (out_channels). + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \text{kernel\_size}}` + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _single(kernel_size) + stride = _single(stride) + padding = _single(padding) + dilation = _single(dilation) + output_padding = _single(output_padding) + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 1 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, self.dilation) # type: ignore[arg-type] + return F.conv_transpose1d( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + +class ConvTranspose2d(_ConvTransposeNd): + __doc__ = r"""Applies a 2D transposed convolution operator over an input image + composed of several input planes. + + This module can be seen as the gradient of Conv2d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the height and width dimensions + - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension, + and the second `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv2d` and a :class:`~torch.nn.ConvTranspose2d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv2d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, H_{in}, W_{in})` or :math:`(C_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(C_{out}, H_{out}, W_{out})`, where + + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{1}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose2d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)) + >>> input = torch.randn(20, 16, 50, 100) + >>> output = m(input) + >>> # exact output size can be also specified as an argument + >>> input = torch.randn(1, 16, 12, 12) + >>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1) + >>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) + >>> h = downsample(input) + >>> h.size() + torch.Size([1, 16, 6, 6]) + >>> output = upsample(h, output_size=input.size()) + >>> output.size() + torch.Size([1, 16, 12, 12]) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_2_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + output_padding = _pair(output_padding) + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d') + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, self.dilation) # type: ignore[arg-type] + + return F.conv_transpose2d( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + +class ConvTranspose3d(_ConvTransposeNd): + __doc__ = r"""Applies a 3D transposed convolution operator over an input image composed of several input + planes. + The transposed convolution operator multiplies each input value element-wise by a learnable kernel, + and sums over the outputs from all input feature planes. + + This module can be seen as the gradient of Conv3d with respect to its input. + It is also known as a fractionally-strided convolution or + a deconvolution (although it is not an actual deconvolution operation as it does + not compute a true inverse of convolution). For more information, see the visualizations + `here`_ and the `Deconvolutional Networks`_ paper. + + This module supports :ref:`TensorFloat32`. + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + * :attr:`stride` controls the stride for the cross-correlation. + + * :attr:`padding` controls the amount of implicit zero padding on both + sides for ``dilation * (kernel_size - 1) - padding`` number of points. See note + below for details. + + * :attr:`output_padding` controls the additional size added to one side + of the output shape. See note below for details. + + * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm. + It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. + + {groups_note} + + The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` + can either be: + + - a single ``int`` -- in which case the same value is used for the depth, height and width dimensions + - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension, + the second `int` for the height dimension and the third `int` for the width dimension + + Note: + The :attr:`padding` argument effectively adds ``dilation * (kernel_size - 1) - padding`` + amount of zero padding to both sizes of the input. This is set so that + when a :class:`~torch.nn.Conv3d` and a :class:`~torch.nn.ConvTranspose3d` + are initialized with same parameters, they are inverses of each other in + regard to the input and output shapes. However, when ``stride > 1``, + :class:`~torch.nn.Conv3d` maps multiple input shapes to the same output + shape. :attr:`output_padding` is provided to resolve this ambiguity by + effectively increasing the calculated output shape on one side. Note + that :attr:`output_padding` is only used to find output shape, but does + not actually add zero-padding to output. + + Note: + {cudnn_reproducibility_note} + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + """.format(**reproducibility_notes, **convolution_notes) + r""" + + Shape: + - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})` or :math:`(C_{in}, D_{in}, H_{in}, W_{in})` + - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` or + :math:`(C_{out}, D_{out}, H_{out}, W_{out})`, where + + .. math:: + D_{out} = (D_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{dilation}[0] + \times (\text{kernel\_size}[0] - 1) + \text{output\_padding}[0] + 1 + .. math:: + H_{out} = (H_{in} - 1) \times \text{stride}[1] - 2 \times \text{padding}[1] + \text{dilation}[1] + \times (\text{kernel\_size}[1] - 1) + \text{output\_padding}[1] + 1 + .. math:: + W_{out} = (W_{in} - 1) \times \text{stride}[2] - 2 \times \text{padding}[2] + \text{dilation}[2] + \times (\text{kernel\_size}[2] - 1) + \text{output\_padding}[2] + 1 + + + Attributes: + weight (Tensor): the learnable weights of the module of shape + :math:`(\text{in\_channels}, \frac{\text{out\_channels}}{\text{groups}},` + :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]}, \text{kernel\_size[2]})`. + The values of these weights are sampled from + :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + bias (Tensor): the learnable bias of the module of shape (out_channels) + If :attr:`bias` is ``True``, then the values of these weights are + sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where + :math:`k = \frac{groups}{C_\text{out} * \prod_{i=0}^{2}\text{kernel\_size}[i]}` + + Examples:: + + >>> # With square kernels and equal stride + >>> m = nn.ConvTranspose3d(16, 33, 3, stride=2) + >>> # non-square kernels and unequal stride and with padding + >>> m = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2)) + >>> input = torch.randn(20, 16, 10, 50, 100) + >>> output = m(input) + + .. _`here`: + https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + + .. _`Deconvolutional Networks`: + https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size = _triple(kernel_size) + stride = _triple(stride) + padding = _triple(padding) + dilation = _triple(dilation) + output_padding = _triple(output_padding) + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + True, output_padding, groups, bias, padding_mode, **factory_kwargs) + + def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + if self.padding_mode != 'zeros': + raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d') + + assert isinstance(self.padding, tuple) + # One cannot replace List by Tuple or Sequence in "_output_padding" because + # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + num_spatial_dims = 3 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type] + num_spatial_dims, self.dilation) # type: ignore[arg-type] + + return F.conv_transpose3d( + input, self.weight, self.bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + +# TODO: Deprecate and remove the following alias `_ConvTransposeMixin`. +# +# `_ConvTransposeMixin` was a mixin that was removed. It is meant to be used +# with `_ConvNd` to construct actual module classes that implements conv +# transpose ops: +# +# class MyConvTranspose(_ConvNd, _ConvTransposeMixin): +# ... +# +# In PyTorch, it has been replaced by `_ConvTransposeNd`, which is a proper +# subclass of `_ConvNd`. However, some user code in the wild still (incorrectly) +# use the internal class `_ConvTransposeMixin`. Hence, we provide this alias +# for BC, because it is cheap and easy for us to do so, even though that +# `_ConvTransposeNd` is really not a mixin anymore (but multiple inheritance as +# above would still work). +class _ConvTransposeMixin(_ConvTransposeNd): + def __init__(self, *args, **kwargs): + warnings.warn( + "_ConvTransposeMixin is a deprecated internal class. " + "Please consider using public APIs.") + super().__init__(*args, **kwargs) + + +# TODO: Conv2dLocal +# TODO: Conv2dMap +# TODO: ConvTranspose2dMap + + +class _LazyConvXdMixin(LazyModuleMixin): + groups: int + transposed: bool + in_channels: int + out_channels: int + kernel_size: Tuple[int, ...] + weight: UninitializedParameter + bias: UninitializedParameter + + def reset_parameters(self) -> None: + # has_uninitialized_params is defined in parent class and it is using a protocol on self + if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc] + # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined + # in super class. Turns out that it is defined in _ConvND which is inherited by any class + # that also inherits _LazyConvXdMixin + super().reset_parameters() # type: ignore[misc] + + # Signature of "initialize_parameters" is incompatible with the definition in supertype LazyModuleMixin + def initialize_parameters(self, input) -> None: # type: ignore[override] + # defined by parent class but using a protocol + if self.has_uninitialized_params(): # type: ignore[misc] + self.in_channels = self._get_in_channels(input) + if self.in_channels % self.groups != 0: + raise ValueError('in_channels must be divisible by groups') + assert isinstance(self.weight, UninitializedParameter) + if self.transposed: + self.weight.materialize(( + self.in_channels, self.out_channels // self.groups, *self.kernel_size)) + else: + self.weight.materialize(( + self.out_channels, self.in_channels // self.groups, *self.kernel_size)) + if self.bias is not None: + assert isinstance(self.bias, UninitializedParameter) + self.bias.materialize((self.out_channels,)) + self.reset_parameters() + + # Function to extract in_channels from first input. + def _get_in_channels(self, input: Tensor) -> int: + num_spatial_dims = self._get_num_spatial_dims() + num_dims_no_batch = num_spatial_dims + 1 # +1 for channels dim + num_dims_batch = num_dims_no_batch + 1 + if input.dim() not in (num_dims_no_batch, num_dims_batch): + raise RuntimeError("Expected {}D (unbatched) or {}D (batched) input to {}, but " + "got input of size: {}".format(num_dims_no_batch, num_dims_batch, + self.__class__.__name__, input.shape)) + return input.shape[1] if input.dim() == num_dims_batch else input.shape[0] + + # Function to return the number of spatial dims expected for inputs to the module. + # This is expected to be implemented by subclasses. + def _get_num_spatial_dims(self) -> int: + raise NotImplementedError() + + +# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv1d` is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + dilation: _size_1_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv2d` that is inferred from the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] + r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`Conv3d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of + the input. Default: 0 + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` + dilation (int or tuple, optional): Spacing between kernel + elements. Default: 1 + groups (int, optional): Number of blocked connections from input + channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the + output. Default: ``True`` + + .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = Conv3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + dilation: _size_3_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + dilation, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 + + +# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose1d` that is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose1d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: _size_1_t = 0, + output_padding: _size_1_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_1_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 1 + + +# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose2d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose2d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + output_padding: _size_2_t = 0, + groups: int = 1, + bias: bool = True, + dilation: int = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 2 + + +# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter +class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc] + r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument. + + The ``in_channels`` argument of the :class:`ConvTranspose3d` is inferred from + the ``input.size(1)``. + The attributes that will be lazily initialized are `weight` and `bias`. + + Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation + on lazy modules and their limitations. + + Args: + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): ``dilation * (kernel_size - 1) - padding`` zero-padding + will be added to both sides of each dimension in the input. Default: 0 + output_padding (int or tuple, optional): Additional size added to one side + of each dimension in the output shape. Default: 0 + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 + + .. seealso:: :class:`torch.nn.ConvTranspose3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` + """ + + # super class define this variable as None. "type: ignore[..] is required + # since we are redefining the variable. + cls_to_become = ConvTranspose3d # type: ignore[assignment] + + def __init__( + self, + out_channels: int, + kernel_size: _size_3_t, + stride: _size_3_t = 1, + padding: _size_3_t = 0, + output_padding: _size_3_t = 0, + groups: int = 1, + bias: bool = True, + dilation: _size_3_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__( + 0, + 0, + kernel_size, + stride, + padding, + output_padding, + groups, + # bias is hardcoded to False to avoid creating tensor + # that will soon be overwritten. + False, + dilation, + padding_mode, + **factory_kwargs + ) + self.weight = UninitializedParameter(**factory_kwargs) + self.out_channels = out_channels + if bias: + self.bias = UninitializedParameter(**factory_kwargs) + + def _get_num_spatial_dims(self) -> int: + return 3 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7173e81797df5620d4e3908cfa6035dac2ae9c60 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/loss.py @@ -0,0 +1,1790 @@ +import warnings + +from .distance import PairwiseDistance +from .module import Module +from .. import functional as F +from .. import _reduction as _Reduction + +from torch import Tensor +from typing import Callable, Optional + +__all__ = ['L1Loss', 'NLLLoss', 'NLLLoss2d', 'PoissonNLLLoss', 'GaussianNLLLoss', 'KLDivLoss', + 'MSELoss', 'BCELoss', 'BCEWithLogitsLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', + 'SmoothL1Loss', 'HuberLoss', 'SoftMarginLoss', 'CrossEntropyLoss', 'MultiLabelSoftMarginLoss', + 'CosineEmbeddingLoss', 'MarginRankingLoss', 'MultiMarginLoss', 'TripletMarginLoss', + 'TripletMarginWithDistanceLoss', 'CTCLoss'] + +class _Loss(Module): + reduction: str + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__() + if size_average is not None or reduce is not None: + self.reduction: str = _Reduction.legacy_get_string(size_average, reduce) + else: + self.reduction = reduction + + +class _WeightedLoss(_Loss): + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer('weight', weight) + self.weight: Optional[Tensor] + + +class L1Loss(_Loss): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then + :math:`(*)`, same shape as the input. + + Examples:: + + >>> loss = nn.L1Loss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.l1_loss(input, target, reduction=self.reduction) + + +class NLLLoss(_WeightedLoss): + r"""The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. + + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. + + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. + + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``None`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When + :attr:`size_average` is ``True``, the loss is averaged over + non-ignored targets. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``None`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. + + Examples:: + + >>> m = nn.LogSoftmax(dim=1) + >>> loss = nn.NLLLoss() + >>> # input is of size N x C = 3 x 5 + >>> input = torch.randn(3, 5, requires_grad=True) + >>> # each element in target has to have 0 <= value < C + >>> target = torch.tensor([1, 0, 4]) + >>> output = loss(m(input), target) + >>> output.backward() + >>> + >>> + >>> # 2D loss example (used, for example, with image inputs) + >>> N, C = 5, 4 + >>> loss = nn.NLLLoss() + >>> # input is of size N x C x height x width + >>> data = torch.randn(N, 16, 10, 10) + >>> conv = nn.Conv2d(16, C, (3, 3)) + >>> m = nn.LogSoftmax(dim=1) + >>> # each element in target has to have 0 <= value < C + >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C) + >>> output = loss(m(conv(data)), target) + >>> output.backward() + """ + __constants__ = ['ignore_index', 'reduction'] + ignore_index: int + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) + + +class NLLLoss2d(NLLLoss): + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean') -> None: + warnings.warn("NLLLoss2d has been deprecated. " + "Please use NLLLoss instead as a drop-in replacement and see " + "https://pytorch.org/docs/master/nn.html#torch.nn.NLLLoss for more details.") + super().__init__(weight, size_average, ignore_index, reduce, reduction) + + +class PoissonNLLLoss(_Loss): + r"""Negative log likelihood loss with Poisson distribution of target. + + The loss can be described as: + + .. math:: + \text{target} \sim \mathrm{Poisson}(\text{input}) + + \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input}) + + \log(\text{target!}) + + The last term can be omitted or approximated with Stirling formula. The + approximation is used for target values more than 1. For targets less or + equal to 1 zeros are added to the loss. + + Args: + log_input (bool, optional): if ``True`` the loss is computed as + :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is + :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`. + full (bool, optional): whether to compute full loss, i. e. to add the + Stirling approximation term + + .. math:: + \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}). + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when + :attr:`log_input = False`. Default: 1e-8 + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Examples:: + + >>> loss = nn.PoissonNLLLoss() + >>> log_input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> output = loss(log_input, target) + >>> output.backward() + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`, + the same shape as the input. + """ + __constants__ = ['log_input', 'full', 'eps', 'reduction'] + log_input: bool + full: bool + eps: float + + def __init__(self, log_input: bool = True, full: bool = False, size_average=None, + eps: float = 1e-8, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.log_input = log_input + self.full = full + self.eps = eps + + def forward(self, log_input: Tensor, target: Tensor) -> Tensor: + return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, + eps=self.eps, reduction=self.reduction) + + +class GaussianNLLLoss(_Loss): + r"""Gaussian negative log likelihood loss. + + The targets are treated as samples from Gaussian distributions with + expectations and variances predicted by the neural network. For a + ``target`` tensor modelled as having Gaussian distribution with a tensor + of expectations ``input`` and a tensor of positive variances ``var`` the loss is: + + .. math:: + \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var}, + \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2} + {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.} + + where :attr:`eps` is used for stability. By default, the constant term of + the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same + size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension + of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting. + + Args: + full (bool, optional): include the constant term in the loss + calculation. Default: ``False``. + eps (float, optional): value used to clamp ``var`` (see note below), for + stability. Default: 1e-6. + reduction (str, optional): specifies the reduction to apply to the + output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction + will be applied, ``'mean'``: the output is the average of all batch + member losses, ``'sum'``: the output is the sum of all batch member + losses. Default: ``'mean'``. + + Shape: + - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional + dimensions + - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input + but with one dimension equal to 1 (to allow for broadcasting) + - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but + with one dimension equal to 1, or same shape as the input but with one fewer + dimension (to allow for broadcasting) + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same + shape as the input + + Examples:: + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + >>> loss = nn.GaussianNLLLoss() + >>> input = torch.randn(5, 2, requires_grad=True) + >>> target = torch.randn(5, 2) + >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic + >>> output = loss(input, target, var) + >>> output.backward() + + Note: + The clamping of ``var`` is ignored with respect to autograd, and so the + gradients are unaffected by it. + + Reference: + Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the + target probability distribution", Proceedings of 1994 IEEE International + Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60 + vol.1, doi: 10.1109/ICNN.1994.374138. + """ + __constants__ = ['full', 'eps', 'reduction'] + full: bool + eps: float + + def __init__(self, *, full: bool = False, eps: float = 1e-6, reduction: str = 'mean') -> None: + super().__init__(None, None, reduction) + self.full = full + self.eps = eps + + def forward(self, input: Tensor, target: Tensor, var: Tensor) -> Tensor: + return F.gaussian_nll_loss(input, target, var, full=self.full, eps=self.eps, reduction=self.reduction) + + +class KLDivLoss(_Loss): + r"""The Kullback-Leibler divergence loss. + + For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`, + where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the + :attr:`target`, we define the **pointwise KL-divergence** as + + .. math:: + + L(y_{\text{pred}},\ y_{\text{true}}) + = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}} + = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}}) + + To avoid underflow issues when computing this quantity, this loss expects the argument + :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the + log-space if :attr:`log_target`\ `= True`. + + To summarise, this function is roughly equivalent to computing + + .. code-block:: python + + if not log_target: # default + loss_pointwise = target * (target.log() - input) + else: + loss_pointwise = target.exp() * (target - input) + + and then reducing this result depending on the argument :attr:`reduction` as + + .. code-block:: python + + if reduction == "mean": # default + loss = loss_pointwise.mean() + elif reduction == "batchmean": # mathematically correct + loss = loss_pointwise.sum() / input.size(0) + elif reduction == "sum": + loss = loss_pointwise.sum() + else: # reduction == "none" + loss = loss_pointwise + + .. note:: + As all the other losses in PyTorch, this function expects the first argument, + :attr:`input`, to be the output of the model (e.g. the neural network) + and the second, :attr:`target`, to be the observations in the dataset. + This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where + :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model. + + .. warning:: + :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use + :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to `False`, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is `False`. Default: `True` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: `True` + reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"` + log_target (bool, optional): Specifies whether `target` is the log space. Default: `False` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`, + same shape as the input. + + Examples:: + + >>> import torch.nn.functional as F + >>> kl_loss = nn.KLDivLoss(reduction="batchmean") + >>> # input should be a distribution in the log space + >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1) + >>> # Sample a batch of distributions. Usually this would come from the dataset + >>> target = F.softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, target) + + >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True) + >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1) + >>> output = kl_loss(input, log_target) + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', log_target: bool = False) -> None: + super().__init__(size_average, reduce, reduction) + self.log_target = log_target + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.kl_div(input, target, reduction=self.reduction, log_target=self.log_target) + + +class MSELoss(_Loss): + r"""Creates a criterion that measures the mean squared error (squared L2 norm) between + each element in the input :math:`x` and target :math:`y`. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left( x_n - y_n \right)^2, + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The mean operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + + Examples:: + + >>> loss = nn.MSELoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.mse_loss(input, target, reduction=self.reduction) + + +class BCELoss(_WeightedLoss): + r"""Creates a criterion that measures the Binary Cross Entropy between the target and + the input probabilities: + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets :math:`y` should be numbers + between 0 and 1. + + Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be + mathematically undefined in the above loss equation. PyTorch chooses to set + :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`. + However, an infinite term in the loss equation is not desirable for several reasons. + + For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be + multiplying 0 with infinity. Secondly, if we have an infinite loss value, then + we would also have an infinite term in our gradient, since + :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`. + This would make BCELoss's backward method nonlinear with respect to :math:`x_n`, + and using it for things like linear regression would not be straight-forward. + + Our solution is that BCELoss clamps its log function outputs to be greater than + or equal to -100. This way, we can always have a finite loss value and a linear + backward method. + + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples:: + + >>> m = nn.Sigmoid() + >>> loss = nn.BCELoss() + >>> input = torch.randn(3, 2, requires_grad=True) + >>> target = torch.rand(3, 2, requires_grad=False) + >>> output = loss(m(input), target) + >>> output.backward() + """ + __constants__ = ['reduction'] + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) + + +class BCEWithLogitsLoss(_Loss): + r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single + class. This version is more numerically stable than using a plain `Sigmoid` + followed by a `BCELoss` as, by combining the operations into one layer, + we take advantage of the log-sum-exp trick for numerical stability. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_n \left[ y_n \cdot \log \sigma(x_n) + + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right], + + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + This is used for measuring the error of a reconstruction in for example + an auto-encoder. Note that the targets `t[i]` should be numbers + between 0 and 1. + + It's possible to trade off recall and precision by adding weights to positive examples. + In the case of multi-label classification the loss can be described as: + + .. math:: + \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad + l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c}) + + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right], + + where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification, + :math:`c = 1` for single-label binary classification), + :math:`n` is the number of the sample in the batch and + :math:`p_c` is the weight of the positive answer for the class :math:`c`. + + :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision. + + For example, if a dataset contains 100 positive and 300 negative examples of a single class, + then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`. + The loss would act as if the dataset contains :math:`3\times 100=300` positive examples. + + Examples:: + + >>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10 + >>> output = torch.full([10, 64], 1.5) # A prediction (logit) + >>> pos_weight = torch.ones([64]) # All weights are equal to 1 + >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) + >>> criterion(output, target) # -log(sigmoid(1.5)) + tensor(0.20...) + + In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes + in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the + loss function based on the imbalance between negative and positive samples for the respective class. + This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss + calculation accurately accounts for the distribution in each class. + + Args: + weight (Tensor, optional): a manual rescaling weight given to the loss + of each batch element. If given, has to be a Tensor of size `nbatch`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target. + Must be a tensor with equal size along the class dimension to the number of classes. + Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired + operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of + size [B, C, H, W] will apply different pos_weights to each element of the batch or + [C, H, W] the same pos_weights across the batch. To apply the same positive weight + along all spacial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1]. + Default: ``None`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + Examples:: + + >>> loss = nn.BCEWithLogitsLoss() + >>> input = torch.randn(3, requires_grad=True) + >>> target = torch.empty(3).random_(2) + >>> output = loss(input, target) + >>> output.backward() + """ + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', + pos_weight: Optional[Tensor] = None) -> None: + super().__init__(size_average, reduce, reduction) + self.register_buffer('weight', weight) + self.register_buffer('pos_weight', pos_weight) + self.weight: Optional[Tensor] + self.pos_weight: Optional[Tensor] + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.binary_cross_entropy_with_logits(input, target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction) + + +class HingeEmbeddingLoss(_Loss): + r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y` + (containing 1 or -1). + This is usually used for measuring whether two inputs are similar or + dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically + used for learning nonlinear embeddings or semi-supervised learning. + + The loss function for :math:`n`-th sample in the mini-batch is + + .. math:: + l_n = \begin{cases} + x_n, & \text{if}\; y_n = 1,\\ + \max \{0, margin - x_n\}, & \text{if}\; y_n = -1, + \end{cases} + + and the total loss functions is + + .. math:: + \ell(x, y) = \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + where :math:`L = \{l_1,\dots,l_N\}^\top`. + + Args: + margin (float, optional): Has a default value of `1`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation + operates over all the elements. + - Target: :math:`(*)`, same shape as the input + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) + + +class MultiLabelMarginLoss(_Loss): + r"""Creates a criterion that optimizes a multi-class multi-classification + hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) + and output :math:`y` (which is a 2D `Tensor` of target class indices). + For each sample in the mini-batch: + + .. math:: + \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)} + + where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \ + :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \ + :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \ + and :math:`i \neq y[j]` for all :math:`i` and :math:`j`. + + :math:`y` and :math:`x` must have the same size. + + The criterion only considers a contiguous block of non-negative targets that + starts at the front. + + This allows for different samples to have variable amounts of target classes. + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C` + is the number of classes. + - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + + Examples:: + + >>> loss = nn.MultiLabelMarginLoss() + >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]]) + >>> # for target y, only consider labels 3 and 0, not after label -1 + >>> y = torch.LongTensor([[3, 0, -1, 1]]) + >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.85...) + + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_margin_loss(input, target, reduction=self.reduction) + + +class SmoothL1Loss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below beta and an L1 term otherwise. + It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases + prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick). + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\ + |x_n - y_n| - 0.5 * beta, & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta` + portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`. + The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`. + + .. note:: + Smooth L1 loss is closely related to :class:`HuberLoss`, being + equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is + also known as delta for Huber). This leads to the following differences: + + * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss` + converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss. + * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while + :class:`HuberLoss` converges to :class:`MSELoss`. + * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1. + For :class:`HuberLoss`, the slope of the L1 segment is beta. + + .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083 + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss. + The value must be non-negative. Default: 1.0 + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean', beta: float = 1.0) -> None: + super().__init__(size_average, reduce, reduction) + self.beta = beta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta) + + +class HuberLoss(_Loss): + r"""Creates a criterion that uses a squared term if the absolute + element-wise error falls below delta and a delta-scaled L1 term otherwise. + This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the + delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`, + while the L2 region provides smoothness over :class:`L1Loss` near 0. See + `Huber loss `_ for more information. + + For a batch of size :math:`N`, the unreduced loss can be described as: + + .. math:: + \ell(x, y) = L = \{l_1, ..., l_N\}^T + + with + + .. math:: + l_n = \begin{cases} + 0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\ + delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise } + \end{cases} + + If `reduction` is not `none`, then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`. + In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta + in Smooth L1). + See :class:`SmoothL1Loss` for additional discussion on the differences in behavior + between the two losses. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + The value must be positive. Default: 1.0 + + Shape: + - Input: :math:`(*)` where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. + """ + __constants__ = ['reduction', 'delta'] + + def __init__(self, reduction: str = 'mean', delta: float = 1.0) -> None: + super().__init__(reduction=reduction) + self.delta = delta + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta) + + +class SoftMarginLoss(_Loss): + r"""Creates a criterion that optimizes a two-class classification + logistic loss between input tensor :math:`x` and target tensor :math:`y` + (containing 1 or -1). + + .. math:: + \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()} + + Args: + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same + shape as input. + + """ + __constants__ = ['reduction'] + + def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.soft_margin_loss(input, target, reduction=self.reduction) + + +class CrossEntropyLoss(_WeightedLoss): + r"""This criterion computes the cross entropy loss between input logits + and target. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. + + The `input` is expected to contain the unnormalized logits for each class (which do `not` need + to be positive or sum to 1, in general). + `input` has to be a Tensor of size :math:`(C)` for unbatched input, + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the + `K`-dimensional case. The last being useful for higher dimension inputs, such + as computing cross entropy loss per-pixel for 2D images. + + The `target` that this criterion expects should contain either: + + - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Note that this case is equivalent to applying :class:`~torch.nn.LogSoftmax` + on an input, followed by :class:`~torch.nn.NLLLoss`. + + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C` and floating point dtype + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. + If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. + + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> loss = nn.CrossEntropyLoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.empty(3, dtype=torch.long).random_(5) + >>> output = loss(input, target) + >>> output.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] + ignore_index: int + label_smoothing: float + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100, + reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None: + super().__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.cross_entropy(input, target, weight=self.weight, + ignore_index=self.ignore_index, reduction=self.reduction, + label_smoothing=self.label_smoothing) + + +class MultiLabelSoftMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-label one-versus-all + loss based on max-entropy, between input :math:`x` and target :math:`y` of size + :math:`(N, C)`. + For each sample in the minibatch: + + .. math:: + loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1}) + + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right) + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`, + :math:`y[i] \in \left\{0, \; 1\right\}`. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes. + - Target: :math:`(N, C)`, label targets must have the same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`. + """ + __constants__ = ['reduction'] + + def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) + + +class CosineEmbeddingLoss(_Loss): + r"""Creates a criterion that measures the loss given input tensors + :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1. + Use (:math:`y=1`) to maximize the cosine similarity of two inputs, and (:math:`y=-1`) otherwise. + This is typically used for learning nonlinear + embeddings or semi-supervised learning. + + The loss function for each sample is: + + .. math:: + \text{loss}(x, y) = + \begin{cases} + 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ + \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 + \end{cases} + + Args: + margin (float, optional): Should be a number from :math:`-1` to :math:`1`, + :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the + default value is :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension. + - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1. + - Target: :math:`(N)` or :math:`()`. + - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar. + + Examples:: + + >>> loss = nn.CosineEmbeddingLoss() + >>> input1 = torch.randn(3, 5, requires_grad=True) + >>> input2 = torch.randn(3, 5, requires_grad=True) + >>> target = torch.ones(3) + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + + +class MarginRankingLoss(_Loss): + r"""Creates a criterion that measures the loss given + inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`, + and a label 1D mini-batch or 0D `Tensor` :math:`y` (containing 1 or -1). + + If :math:`y = 1` then it assumed the first input should be ranked higher + (have a larger value) than the second input, and vice-versa for :math:`y = -1`. + + The loss function for each pair of samples in the mini-batch is: + + .. math:: + \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin}) + + Args: + margin (float, optional): Has a default value of :math:`0`. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input1: :math:`(N)` or :math:`()` where `N` is the batch size. + - Input2: :math:`(N)` or :math:`()`, same shape as the Input1. + - Target: :math:`(N)` or :math:`()`, same shape as the inputs. + - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`. + + Examples:: + + >>> loss = nn.MarginRankingLoss() + >>> input1 = torch.randn(3, requires_grad=True) + >>> input2 = torch.randn(3, requires_grad=True) + >>> target = torch.randn(3).sign() + >>> output = loss(input1, input2, target) + >>> output.backward() + """ + __constants__ = ['margin', 'reduction'] + margin: float + + def __init__(self, margin: float = 0., size_average=None, reduce=None, reduction: str = 'mean') -> None: + super().__init__(size_average, reduce, reduction) + self.margin = margin + + def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor: + return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) + + +class MultiMarginLoss(_WeightedLoss): + r"""Creates a criterion that optimizes a multi-class classification hinge + loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and + output :math:`y` (which is a 1D tensor of target class indices, + :math:`0 \leq y \leq \text{x.size}(1)-1`): + + For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar + output :math:`y` is: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Optionally, you can give non-equal weighting on the classes by passing + a 1D :attr:`weight` tensor into the constructor. + + The loss function then becomes: + + .. math:: + \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)} + + Args: + p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2` + are the only supported values. + margin (float, optional): Has a default value of :math:`1`. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, it has to be a Tensor of size `C`. Otherwise, it is + treated as if having all ones. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes. + - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target. + + Examples:: + + >>> loss = nn.MultiMarginLoss() + >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]]) + >>> y = torch.tensor([3]) + >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4))) + >>> loss(x, y) + tensor(0.32...) + """ + __constants__ = ['p', 'margin', 'reduction'] + margin: float + p: int + + def __init__(self, p: int = 1, margin: float = 1., weight: Optional[Tensor] = None, size_average=None, + reduce=None, reduction: str = 'mean') -> None: + super().__init__(weight, size_average, reduce, reduction) + if p != 1 and p != 2: + raise ValueError("only p == 1 and p == 2 supported") + if weight is not None and weight.dim() != 1 : + raise ValueError( + f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead" + ) + self.p = p + self.margin = margin + + def forward(self, input: Tensor, target: Tensor) -> Tensor: + return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, + weight=self.weight, reduction=self.reduction) + + +class TripletMarginLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, D)`. + + The distance swap is described in detail in the paper `Learning shallow + convolutional feature descriptors with triplet losses`_ by + V. Balntas, E. Riba et al. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is + added for numerical stability. + + See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the + triplet margin loss for input tensors using a custom distance function. + + Args: + margin (float, optional): Default: :math:`1`. + p (int, optional): The norm degree for pairwise distance. Default: :math:`2`. + eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`. + swap (bool, optional): The distance swap is described in detail in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. Default: ``False``. + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and + input shape is :math:`(N, D)`; a scalar otherwise. + + Examples:: + + >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + >>> anchor = torch.randn(100, 128, requires_grad=True) + >>> positive = torch.randn(100, 128, requires_grad=True) + >>> negative = torch.randn(100, 128, requires_grad=True) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + .. _Learning shallow convolutional feature descriptors with triplet losses: + http://www.bmva.org/bmvc/2016/papers/paper119/index.html + """ + __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction'] + margin: float + p: float + eps: float + swap: bool + + def __init__(self, margin: float = 1.0, p: float = 2., eps: float = 1e-6, swap: bool = False, size_average=None, + reduce=None, reduction: str = 'mean'): + super().__init__(size_average, reduce, reduction) + self.margin = margin + self.p = p + self.eps = eps + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, + eps=self.eps, swap=self.swap, reduction=self.reduction) + + +class TripletMarginWithDistanceLoss(_Loss): + r"""Creates a criterion that measures the triplet loss given input + tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor, + positive, and negative examples, respectively), and a nonnegative, + real-valued function ("distance function") used to compute the relationship + between the anchor and positive example ("positive distance") and the + anchor and negative example ("negative distance"). + + The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``) + can be described as: + + .. math:: + \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad + l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} + + where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function + quantifying the closeness of two tensors, referred to as the :attr:`distance_function`; + and :math:`margin` is a nonnegative margin representing the minimum difference + between the positive and negative distances that is required for the loss to + be 0. The input tensors have :math:`N` elements each and can be of any shape + that the distance function can handle. + + If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet + loss for input tensors using the :math:`l_p` distance as the distance function. + + Args: + distance_function (Callable, optional): A nonnegative, real-valued function that + quantifies the closeness of two tensors. If not specified, + `nn.PairwiseDistance` will be used. Default: ``None`` + margin (float, optional): A nonnegative margin representing the minimum difference + between the positive and negative distances required for the loss to be 0. Larger + margins penalize cases where the negative examples are not distant enough from the + anchors, relative to the positives. Default: :math:`1`. + swap (bool, optional): Whether to use the distance swap described in the paper + `Learning shallow convolutional feature descriptors with triplet losses` by + V. Balntas, E. Riba et al. If True, and if the positive example is closer to the + negative example than the anchor is, swaps the positive example and the anchor in + the loss computation. Default: ``False``. + reduction (str, optional): Specifies the (optional) reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` + + + Shape: + - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions + as supported by the distance function. + - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar + otherwise. + + Examples:: + + >>> # Initialize embeddings + >>> embedding = nn.Embedding(1000, 128) + >>> anchor_ids = torch.randint(0, 1000, (1,)) + >>> positive_ids = torch.randint(0, 1000, (1,)) + >>> negative_ids = torch.randint(0, 1000, (1,)) + >>> anchor = embedding(anchor_ids) + >>> positive = embedding(positive_ids) + >>> negative = embedding(negative_ids) + >>> + >>> # Built-in Distance Function + >>> triplet_loss = \ + >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance()) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function + >>> def l_infinity(x1, x2): + >>> return torch.max(torch.abs(x1 - x2), dim=1).values + >>> + >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time") + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + >>> + >>> # Custom Distance Function (Lambda) + >>> triplet_loss = ( + >>> nn.TripletMarginWithDistanceLoss( + >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))) + >>> output = triplet_loss(anchor, positive, negative) + >>> output.backward() + + Reference: + V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses: + http://www.bmva.org/bmvc/2016/papers/paper119/index.html + """ + __constants__ = ['margin', 'swap', 'reduction'] + margin: float + swap: bool + + def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None, + margin: float = 1.0, swap: bool = False, reduction: str = 'mean'): + super().__init__(size_average=None, reduce=None, reduction=reduction) + self.distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = \ + distance_function if distance_function is not None else PairwiseDistance() + self.margin = margin + self.swap = swap + + def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor: + return F.triplet_margin_with_distance_loss(anchor, positive, negative, + distance_function=self.distance_function, + margin=self.margin, swap=self.swap, reduction=self.reduction) + + +class CTCLoss(_Loss): + r"""The Connectionist Temporal Classification loss. + + Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the + probability of possible alignments of input to target, producing a loss value which is differentiable + with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which + limits the length of the target sequence such that it must be :math:`\leq` the input length. + + Args: + blank (int, optional): blank label. Default :math:`0`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the output losses will be divided by the target lengths and + then the mean over the batch is taken, ``'sum'``: the output losses will be summed. + Default: ``'mean'`` + zero_infinity (bool, optional): + Whether to zero infinite losses and the associated gradients. + Default: ``False`` + Infinite losses mainly occur when the inputs are too short + to be aligned to the targets. + + Shape: + - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`, + where :math:`T = \text{input length}`, + :math:`N = \text{batch size}`, and + :math:`C = \text{number of classes (including blank)}`. + The logarithmized probabilities of the outputs (e.g. obtained with + :func:`torch.nn.functional.log_softmax`). + - Targets: Tensor of size :math:`(N, S)` or + :math:`(\operatorname{sum}(\text{target\_lengths}))`, + where :math:`N = \text{batch size}` and + :math:`S = \text{max target length, if shape is } (N, S)`. + It represent the target sequences. Each element in the target + sequence is a class index. And the target index cannot be blank (default=0). + In the :math:`(N, S)` form, targets are padded to the + length of the longest sequence, and stacked. + In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form, + the targets are assumed to be un-padded and + concatenated within 1 dimension. + - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represent the lengths of the + inputs (must each be :math:`\leq T`). And the lengths are specified + for each sequence to achieve masking under the assumption that sequences + are padded to equal lengths. + - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`, + where :math:`N = \text{batch size}`. It represent lengths of the targets. + Lengths are specified for each sequence to achieve masking under the + assumption that sequences are padded to equal lengths. If target shape is + :math:`(N,S)`, target_lengths are effectively the stop index + :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for + each target in a batch. Lengths must each be :math:`\leq S` + If the targets are given as a 1d tensor that is the concatenation of individual + targets, the target_lengths must add up to the total length of the tensor. + - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or + ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or + :math:`()` if input is unbatched, where :math:`N = \text{batch size}`. + + Examples:: + + >>> # Target are to be padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> S = 30 # Target sequence length of longest target in batch (padding length) + >>> S_min = 10 # Minimum target length, for demonstration purposes + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) + >>> + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> N = 16 # Batch size + >>> + >>> # Initialize random batch of input vectors, for *size = (T,N,C) + >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() + >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long) + >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + >>> + >>> + >>> # Target are to be un-padded and unbatched (effectively N=1) + >>> T = 50 # Input sequence length + >>> C = 20 # Number of classes (including blank) + >>> + >>> # Initialize random batch of input vectors, for *size = (T,C) + >>> # xdoctest: +SKIP("FIXME: error in doctest") + >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_() + >>> input_lengths = torch.tensor(T, dtype=torch.long) + >>> + >>> # Initialize random batch of targets (0 = blank, 1:C = classes) + >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long) + >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long) + >>> ctc_loss = nn.CTCLoss() + >>> loss = ctc_loss(input, target, input_lengths, target_lengths) + >>> loss.backward() + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + Note: + In order to use CuDNN, the following must be satisfied: :attr:`targets` must be + in concatenated format, all :attr:`input_lengths` must be `T`. :math:`blank=0`, + :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of + dtype :attr:`torch.int32`. + + The regular implementation uses the (more common in PyTorch) `torch.long` dtype. + + + Note: + In some circumstances when using the CUDA backend with CuDNN, this operator + may select a nondeterministic algorithm to increase performance. If this is + undesirable, you can try to make the operation deterministic (potentially at + a performance cost) by setting ``torch.backends.cudnn.deterministic = + True``. + Please see the notes on :doc:`/notes/randomness` for background. + """ + __constants__ = ['blank', 'reduction'] + blank: int + zero_infinity: bool + + def __init__(self, blank: int = 0, reduction: str = 'mean', zero_infinity: bool = False): + super().__init__(reduction=reduction) + self.blank = blank + self.zero_infinity = zero_infinity + + def forward(self, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) -> Tensor: + return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, + self.zero_infinity) + +# TODO: L1HingeEmbeddingCriterion +# TODO: MSECriterion weight +# TODO: ClassSimplexCriterion diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd81d734bc6725027b6ca88c367069fb87d5eec --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py @@ -0,0 +1,2577 @@ +from collections import OrderedDict, namedtuple +import itertools +import warnings +import functools +import weakref + +import torch +from torch._prims_common import DeviceLikeType +from ..parameter import Parameter +import torch.utils.hooks as hooks + +from torch import Tensor, device, dtype +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List +from typing_extensions import Self +from ...utils.hooks import RemovableHandle +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +__all__ = ['register_module_forward_pre_hook', 'register_module_forward_hook', + 'register_module_full_backward_pre_hook', 'register_module_backward_hook', + 'register_module_full_backward_hook', 'register_module_buffer_registration_hook', + 'register_module_module_registration_hook', 'register_module_parameter_registration_hook', 'Module'] + +_grad_t = Union[Tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar('T', bound='Module') + + +class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): + def __repr__(self): + if not self.missing_keys and not self.unexpected_keys: + return '' + return super().__repr__() + + __str__ = __repr__ + + +def _addindent(s_, numSpaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + +r"""This tracks hooks common to all modules that are executed immediately before +.registering the buffer/module/parameter""" +_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() +_global_module_registration_hooks: Dict[int, Callable] = OrderedDict() +_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() + +class _WrappedHook: + def __init__(self, hook: Callable, module: Optional["Module"] = None): + self.hook: Callable = hook + functools.update_wrapper(self, hook) + + self.with_module: bool = False + + if module is not None: + self.module: weakref.ReferenceType[Module] = weakref.ref(module) + self.with_module = True + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if self.with_module: + module = self.module() + if module is None: + raise RuntimeError("You are trying to call the hook of a dead Module!") + return self.hook(module, *args, **kwargs) + return self.hook(*args, **kwargs) + + def __getstate__(self) -> Dict: + result = {"hook": self.hook, "with_module": self.with_module} + if self.with_module: + result["module"] = self.module() + + return result + + def __setstate__(self, state: Dict): + self.hook = state["hook"] + self.with_module = state["with_module"] + + if self.with_module: + if state["module"] is None: + raise RuntimeError("You are trying to revive the hook of a dead Module!") + self.module = weakref.ref(state["module"]) + + +r"""This tracks hooks common to all modules that are executed before/after +calling forward and backward. This is global state used for debugging/profiling +purposes""" +_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() +_global_backward_hooks: Dict[int, Callable] = OrderedDict() +_global_is_full_backward_hook: Optional[bool] = None +_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() +_global_forward_hooks: Dict[int, Callable] = OrderedDict() +_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() + +_EXTRA_STATE_KEY_SUFFIX = '_extra_state' + + +def register_module_buffer_registration_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a buffer registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_buffer` is invoked. + It should have the following signature:: + + hook(module, name, buffer) -> None or new buffer + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_buffer_registration_hooks) + _global_buffer_registration_hooks[handle.id] = hook + return handle + + +def register_module_module_registration_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a module registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_module` is invoked. + It should have the following signature:: + + hook(module, name, submodule) -> None or new submodule + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_module_registration_hooks) + _global_module_registration_hooks[handle.id] = hook + return handle + + +def register_module_parameter_registration_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a parameter registration hook common to all modules. + + .. warning :: + + This adds global state to the `nn.Module` module + + The hook will be called every time :func:`register_parameter` is invoked. + It should have the following signature:: + + hook(module, name, param) -> None or new parameter + + The hook can modify the input or return a single modified value in the hook. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_parameter_registration_hooks) + _global_parameter_registration_hooks[handle.id] = hook + return handle + + +def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: + r"""Register a forward pre-hook common to all modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time before :func:`forward` is invoked. + It should have the following signature:: + + hook(module, input) -> None or modified input + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + The hook can modify the input. User can either return a tuple or a + single modified value in the hook. We will wrap the value into a tuple + if a single value is returned(unless that value is already a tuple). + + This hook has precedence over the specific module hooks registered with + ``register_forward_pre_hook``. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(_global_forward_pre_hooks) + _global_forward_pre_hooks[handle.id] = hook + return handle + + +def register_module_forward_hook(hook: Callable[..., None], *, always_call: bool = False) -> RemovableHandle: + r"""Register a global forward hook for all the modules. + + .. warning :: + + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + The hook will be called every time after :func:`forward` has computed an output. + It should have the following signature:: + + hook(module, input, output) -> None or modified output + + The input contains only the positional arguments given to the module. + Keyword arguments won't be passed to the hooks and only to the ``forward``. + The hook can modify the output. It can modify the input inplace but + it will not have effect on forward since this is called after + :func:`forward` is called. + + Parameters: + hook (Callable): The user defined hook to be registered. + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + This hook will be executed before specific module hooks registered with + ``register_forward_hook``. + """ + handle = hooks.RemovableHandle(_global_forward_hooks, + extra_dict=_global_forward_hooks_always_called) + _global_forward_hooks[handle.id] = hook + if always_call: + _global_forward_hooks_always_called[handle.id] = True + return handle + + +def register_module_backward_hook( + hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + This function is deprecated in favor of + :func:`torch.nn.modules.module.register_module_full_backward_hook` + and the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is True: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them.") + + _global_is_full_backward_hook = False + + handle = hooks.RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_pre_hook( + hook: Callable[['Module', _grad_t], Union[None, _grad_t]] +) -> RemovableHandle: + r"""Register a backward pre-hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_pre_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = hooks.RemovableHandle(_global_backward_pre_hooks) + _global_backward_pre_hooks[handle.id] = hook + return handle + + +def register_module_full_backward_hook( + hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] +) -> RemovableHandle: + r"""Register a backward hook common to all the modules. + + .. warning :: + This adds global state to the `nn.module` module + and it is only intended for debugging/profiling purposes. + + Hooks registered using this function behave in the same way as those + registered by :meth:`torch.nn.Module.register_full_backward_hook`. + Refer to its documentation for more details. + + Hooks registered using this function will be called before hooks registered + using :meth:`torch.nn.Module.register_full_backward_hook`. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + global _global_is_full_backward_hook + if _global_is_full_backward_hook is False: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " + "global Module hook. Please use only one of them.") + + _global_is_full_backward_hook = True + + handle = hooks.RemovableHandle(_global_backward_hooks) + _global_backward_hooks[handle.id] = hook + return handle + + +# Trick mypy into not applying contravariance rules to inputs by defining +# forward as a value, rather than a function. See also +# https://github.com/python/mypy/issues/8795 +def _forward_unimplemented(self, *input: Any) -> None: + r"""Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + """ + raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function") + + +class Module: + r"""Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing to nest them in + a tree structure. You can assign the submodules as regular attributes:: + + import torch.nn as nn + import torch.nn.functional as F + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + Submodules assigned in this way will be registered, and will have their + parameters converted too when you call :meth:`to`, etc. + + .. note:: + As per the example above, an ``__init__()`` call to the parent class + must be made before assignment on the child. + + :ivar training: Boolean represents whether this module is in training or + evaluation mode. + :vartype training: bool + """ + + dump_patches: bool = False + + _version: int = 1 + r"""This allows better BC support for :meth:`load_state_dict`. In + :meth:`state_dict`, the version number will be saved as in the attribute + `_metadata` of the returned state dict, and thus pickled. `_metadata` is a + dictionary with keys that follow the naming convention of state dict. See + ``_load_from_state_dict`` on how to use this information in loading. + + If new parameters/buffers are added/removed from a module, this number shall + be bumped, and the module's `_load_from_state_dict` method can compare the + version number and do appropriate changes if the state dict is from before + the change.""" + + training: bool + _parameters: Dict[str, Optional[Parameter]] + _buffers: Dict[str, Optional[Tensor]] + _non_persistent_buffers_set: Set[str] + _backward_pre_hooks: Dict[int, Callable] + _backward_hooks: Dict[int, Callable] + _is_full_backward_hook: Optional[bool] + _forward_hooks: Dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support Set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_hooks_with_kwargs: Dict[int, bool] + # forward hooks that should always be called even if an exception is raised + _forward_hooks_always_called: Dict[int, bool] + _forward_pre_hooks: Dict[int, Callable] + # Marks whether the corresponding _forward_hooks accept kwargs or not. + # As JIT does not support Set[int], this dict is used as a set, where all + # hooks represented in this dict accept kwargs. + _forward_pre_hooks_with_kwargs: Dict[int, bool] + _state_dict_hooks: Dict[int, Callable] + _load_state_dict_pre_hooks: Dict[int, Callable] + _state_dict_pre_hooks: Dict[int, Callable] + _load_state_dict_post_hooks: Dict[int, Callable] + _modules: Dict[str, Optional['Module']] + call_super_init: bool = False + _compiled_call_impl : Optional[Callable] = None + + def __init__(self, *args, **kwargs) -> None: + """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" + torch._C._log_api_usage_once("python.nn_module") + + # Backward compatibility: no args used to be allowed when call_super_init=False + if self.call_super_init is False and bool(kwargs): + raise TypeError("{}.__init__() got an unexpected keyword argument '{}'" + "".format(type(self).__name__, next(iter(kwargs)))) + + if self.call_super_init is False and bool(args): + raise TypeError(f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" + " given") + + """ + Calls super().__setattr__('a', a) instead of the typical self.a = a + to avoid Module.__setattr__ overhead. Module's __setattr__ has special + handling for parameters, submodules, and buffers but simply calls into + super().__setattr__ for all other attributes. + """ + super().__setattr__('training', True) + super().__setattr__('_parameters', OrderedDict()) + super().__setattr__('_buffers', OrderedDict()) + super().__setattr__('_non_persistent_buffers_set', set()) + super().__setattr__('_backward_pre_hooks', OrderedDict()) + super().__setattr__('_backward_hooks', OrderedDict()) + super().__setattr__('_is_full_backward_hook', None) + super().__setattr__('_forward_hooks', OrderedDict()) + super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) + super().__setattr__('_forward_hooks_always_called', OrderedDict()) + super().__setattr__('_forward_pre_hooks', OrderedDict()) + super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) + super().__setattr__('_state_dict_hooks', OrderedDict()) + super().__setattr__('_state_dict_pre_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) + super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) + super().__setattr__('_modules', OrderedDict()) + + if self.call_super_init: + super().__init__(*args, **kwargs) + + forward: Callable[..., Any] = _forward_unimplemented + + def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: + r"""Add a buffer to the module. + + This is typically used to register a buffer that should not to be + considered a model parameter. For example, BatchNorm's ``running_mean`` + is not a parameter, but is part of the module's state. Buffers, by + default, are persistent and will be saved alongside parameters. This + behavior can be changed by setting :attr:`persistent` to ``False``. The + only difference between a persistent buffer and a non-persistent buffer + is that the latter will not be a part of this module's + :attr:`state_dict`. + + Buffers can be accessed as attributes using given names. + + Args: + name (str): name of the buffer. The buffer can be accessed + from this module using the given name + tensor (Tensor or None): buffer to be registered. If ``None``, then operations + that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, + the buffer is **not** included in the module's :attr:`state_dict`. + persistent (bool): whether the buffer is part of this module's + :attr:`state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> self.register_buffer('running_mean', torch.zeros(num_features)) + + """ + if persistent is False and isinstance(self, torch.jit.ScriptModule): + raise RuntimeError("ScriptModule does not support non-persistent buffers") + + if '_buffers' not in self.__dict__: + raise AttributeError( + "cannot assign buffer before Module.__init__() call") + elif not isinstance(name, str): + raise TypeError(f"buffer name should be a string. Got {torch.typename(name)}") + elif '.' in name: + raise KeyError("buffer name can't contain \".\"") + elif name == '': + raise KeyError("buffer name can't be empty string \"\"") + elif hasattr(self, name) and name not in self._buffers: + raise KeyError(f"attribute '{name}' already exists") + elif tensor is not None and not isinstance(tensor, torch.Tensor): + raise TypeError(f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " + "(torch Tensor or None required)" + ) + else: + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, tensor) + if output is not None: + tensor = output + self._buffers[name] = tensor + if persistent: + self._non_persistent_buffers_set.discard(name) + else: + self._non_persistent_buffers_set.add(name) + + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + r"""Add a parameter to the module. + + The parameter can be accessed as an attribute using given name. + + Args: + name (str): name of the parameter. The parameter can be accessed + from this module using the given name + param (Parameter or None): parameter to be added to the module. If + ``None``, then operations that run on parameters, such as :attr:`cuda`, + are ignored. If ``None``, the parameter is **not** included in the + module's :attr:`state_dict`. + """ + if '_parameters' not in self.__dict__: + raise AttributeError( + "cannot assign parameter before Module.__init__() call") + + elif not isinstance(name, str): + raise TypeError(f"parameter name should be a string. Got {torch.typename(name)}") + elif '.' in name: + raise KeyError("parameter name can't contain \".\"") + elif name == '': + raise KeyError("parameter name can't be empty string \"\"") + elif hasattr(self, name) and name not in self._parameters: + raise KeyError(f"attribute '{name}' already exists") + + if param is None: + self._parameters[name] = None + elif not isinstance(param, Parameter): + raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " + "(torch.nn.Parameter or None required)" + ) + elif param.grad_fn: + raise ValueError( + f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " + f"parameters must be created explicitly. To express '{name}' " + "as a function of another Tensor, compute the value in " + "the forward() method.") + else: + for hook in _global_parameter_registration_hooks.values(): + output = hook(self, name, param) + if output is not None: + param = output + self._parameters[name] = param + + def add_module(self, name: str, module: Optional['Module']) -> None: + r"""Add a child module to the current module. + + The module can be accessed as an attribute using the given name. + + Args: + name (str): name of the child module. The child module can be + accessed from this module using the given name + module (Module): child module to be added to the module. + """ + if not isinstance(module, Module) and module is not None: + raise TypeError(f"{torch.typename(module)} is not a Module subclass") + elif not isinstance(name, str): + raise TypeError(f"module name should be a string. Got {torch.typename(name)}") + elif hasattr(self, name) and name not in self._modules: + raise KeyError(f"attribute '{name}' already exists") + elif '.' in name: + raise KeyError(f"module name can't contain \".\", got: {name}") + elif name == '': + raise KeyError("module name can't be empty string \"\"") + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, module) + if output is not None: + module = output + self._modules[name] = module + + def register_module(self, name: str, module: Optional['Module']) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + + def get_submodule(self, target: str) -> "Module": + """Return the submodule given by ``target`` if it exists, otherwise throw an error. + + For example, let's say you have an ``nn.Module`` ``A`` that + looks like this: + + .. code-block:: text + + A( + (net_b): Module( + (net_c): Module( + (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) + ) + (linear): Linear(in_features=100, out_features=200, bias=True) + ) + ) + + (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested + submodule ``net_b``, which itself has two submodules ``net_c`` + and ``linear``. ``net_c`` then has a submodule ``conv``.) + + To check whether or not we have the ``linear`` submodule, we + would call ``get_submodule("net_b.linear")``. To check whether + we have the ``conv`` submodule, we would call + ``get_submodule("net_b.net_c.conv")``. + + The runtime of ``get_submodule`` is bounded by the degree + of module nesting in ``target``. A query against + ``named_modules`` achieves the same result, but it is O(N) in + the number of transitive modules. So, for a simple check to see + if some submodule exists, ``get_submodule`` should always be + used. + + Args: + target: The fully-qualified string name of the submodule + to look for. (See above example for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Module: The submodule referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Module`` + """ + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: torch.nn.Module = self + + for item in atoms: + + if not hasattr(mod, item): + raise AttributeError(mod._get_name() + " has no " + "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + raise AttributeError("`" + item + "` is not " + "an nn.Module") + + return mod + + def get_parameter(self, target: str) -> "Parameter": + """Return the parameter given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + if not isinstance(param, torch.nn.Parameter): + raise AttributeError("`" + param_name + "` is not an " + "nn.Parameter") + + return param + + def get_buffer(self, target: str) -> "Tensor": + """Return the buffer given by ``target`` if it exists, otherwise throw an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the buffer + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.Tensor: The buffer referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not a + buffer + """ + module_path, _, buffer_name = target.rpartition(".") + + mod: torch.nn.Module = self.get_submodule(module_path) + + if not hasattr(mod, buffer_name): + raise AttributeError(mod._get_name() + " has no attribute `" + + buffer_name + "`") + + buffer: torch.Tensor = getattr(mod, buffer_name) + + if buffer_name not in mod._buffers: + raise AttributeError("`" + buffer_name + "` is not a buffer") + + return buffer + + def get_extra_state(self) -> Any: + """Return any extra state to include in the module's state_dict. + + Implement this and a corresponding :func:`set_extra_state` for your module + if you need to store extra state. This function is called when building the + module's `state_dict()`. + + Note that extra state should be picklable to ensure working serialization + of the state_dict. We only provide provide backwards compatibility guarantees + for serializing Tensors; other objects may break backwards compatibility if + their serialized pickled form changes. + + Returns: + object: Any extra state to store in the module's state_dict + """ + raise RuntimeError( + "Reached a code path in Module.get_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug.") + + def set_extra_state(self, state: Any) -> None: + """Set extra state contained in the loaded `state_dict`. + + This function is called from :func:`load_state_dict` to handle any extra state + found within the `state_dict`. Implement this function and a corresponding + :func:`get_extra_state` for your module if you need to store extra state within its + `state_dict`. + + Args: + state (dict): Extra state from the `state_dict` + """ + raise RuntimeError( + "Reached a code path in Module.set_extra_state() that should never be called. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "to report this bug.") + + def _apply(self, fn, recurse=True): + if recurse: + for module in self.children(): + module._apply(fn) + + def compute_should_use_set_data(tensor, tensor_applied): + if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): + # If the new tensor has compatible tensor type as the existing tensor, + # the current behavior is to change the tensor in-place using `.data =`, + # and the future behavior is to overwrite the existing tensor. However, + # changing the current behavior is a BC-breaking change, and we want it + # to happen in future releases. So for now we introduce the + # `torch.__future__.get_overwrite_module_params_on_conversion()` + # global flag to let the user control whether they want the future + # behavior of overwriting the existing tensor or not. + return not torch.__future__.get_overwrite_module_params_on_conversion() + else: + return False + + should_use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + + for key, param in self._parameters.items(): + if param is None: + continue + # Tensors stored in modules are graph leaves, and we don't want to + # track autograd history of `param_applied`, so we have to use + # `with torch.no_grad():` + with torch.no_grad(): + param_applied = fn(param) + p_should_use_set_data = compute_should_use_set_data(param, param_applied) + + # subclasses may have multiple child tensors so we need to use swap_tensors + p_should_use_swap_tensors = should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) + + param_grad = param.grad + if p_should_use_swap_tensors: + try: + if param_grad is not None: + # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. + # Decrement use count of the gradient by setting to None + param.grad = None + param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad) + torch.utils.swap_tensors(param, param_applied) + except Exception as e: + if param_grad is not None: + param.grad = param_grad + raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e + out_param = param + elif p_should_use_set_data: + param.data = param_applied + out_param = param + else: + assert isinstance(param, Parameter) + assert param.is_leaf + out_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = out_param + + if param_grad is not None: + with torch.no_grad(): + grad_applied = fn(param_grad) + g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied) + if p_should_use_swap_tensors: + grad_applied.requires_grad_(param_grad.requires_grad) + try: + torch.utils.swap_tensors(param_grad, grad_applied) + except Exception as e: + raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e + out_param.grad = param_grad + elif g_should_use_set_data: + assert out_param.grad is not None + out_param.grad.data = grad_applied + else: + assert param_grad.is_leaf + out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad) + + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + + return self + + def apply(self: T, fn: Callable[['Module'], None]) -> T: + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model + (see also :ref:`nn-init-doc`). + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + + Example:: + + >>> @torch.no_grad() + >>> def init_weights(m): + >>> print(m) + >>> if type(m) == nn.Linear: + >>> m.weight.fill_(1.0) + >>> print(m.weight) + >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) + >>> net.apply(init_weights) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Linear(in_features=2, out_features=2, bias=True) + Parameter containing: + tensor([[1., 1.], + [1., 1.]], requires_grad=True) + Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + + """ + for module in self.children(): + module.apply(fn) + fn(self) + return self + + def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: + r"""Move all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Args: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.cuda(device)) + + def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: + r"""Move all model parameters and buffers to the IPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on IPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.ipu(device)) + + def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: + r"""Move all model parameters and buffers to the XPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on XPU while being optimized. + + .. note:: + This method modifies the module in-place. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + return self._apply(lambda t: t.xpu(device)) + + def cpu(self: T) -> T: + r"""Move all model parameters and buffers to the CPU. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.cpu()) + + def type(self: T, dst_type: Union[dtype, str]) -> T: + r"""Casts all parameters and buffers to :attr:`dst_type`. + + .. note:: + This method modifies the module in-place. + + Args: + dst_type (type or string): the desired type + + Returns: + Module: self + """ + return self._apply(lambda t: t.type(dst_type)) + + def float(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``float`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.float() if t.is_floating_point() else t) + + def double(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``double`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.double() if t.is_floating_point() else t) + + def half(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``half`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.half() if t.is_floating_point() else t) + + def bfloat16(self: T) -> T: + r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. + + .. note:: + This method modifies the module in-place. + + Returns: + Module: self + """ + return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) + + def to_empty(self: T, *, device: Optional[DeviceLikeType], recurse: bool = True) -> T: + r"""Move the parameters and buffers to the specified device without copying storage. + + Args: + device (:class:`torch.device`): The desired device of the parameters + and buffers in this module. + recurse (bool): Whether parameters and buffers of submodules should + be recursively moved to the specified device. + + Returns: + Module: self + """ + return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse) + + @overload + def to(self, device: Optional[DeviceLikeType] = ..., dtype: Optional[dtype] = ..., + non_blocking: bool = ...) -> Self: + ... + + @overload + def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: + ... + + @overload + def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: + ... + + def to(self, *args, **kwargs): + r"""Move and/or cast the parameters and buffers. + + This can be called as + + .. function:: to(device=None, dtype=None, non_blocking=False) + :noindex: + + .. function:: to(dtype, non_blocking=False) + :noindex: + + .. function:: to(tensor, non_blocking=False) + :noindex: + + .. function:: to(memory_format=torch.channels_last) + :noindex: + + Its signature is similar to :meth:`torch.Tensor.to`, but only accepts + floating point or complex :attr:`dtype`\ s. In addition, this method will + only cast the floating point or complex parameters and buffers to :attr:`dtype` + (if given). The integral parameters and buffers will be moved + :attr:`device`, if that is given, but with dtypes unchanged. When + :attr:`non_blocking` is set, it tries to convert/move asynchronously + with respect to the host if possible, e.g., moving CPU Tensors with + pinned memory to CUDA devices. + + See below for examples. + + .. note:: + This method modifies the module in-place. + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> linear = nn.Linear(2, 2) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]]) + >>> linear.to(torch.double) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1913, -0.3420], + [-0.5113, -0.2325]], dtype=torch.float64) + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) + >>> gpu1 = torch.device("cuda:1") + >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') + >>> cpu = torch.device("cpu") + >>> linear.to(cpu) + Linear(in_features=2, out_features=2, bias=True) + >>> linear.weight + Parameter containing: + tensor([[ 0.1914, -0.3420], + [-0.5112, -0.2324]], dtype=torch.float16) + + >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) + >>> linear.weight + Parameter containing: + tensor([[ 0.3741+0.j, 0.2382+0.j], + [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) + >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) + tensor([[0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j], + [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) + + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if dtype is not None: + if not (dtype.is_floating_point or dtype.is_complex): + raise TypeError('nn.Module.to only accepts floating point or complex ' + f'dtypes, but got desired dtype={dtype}') + if dtype.is_complex: + warnings.warn( + "Complex modules are a new feature under active development whose design may change, " + "and some modules might not work as expected when using complex tensors as parameters or buffers. " + "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " + "if a complex module does not work as expected.") + + def convert(t): + try: + if convert_to_format is not None and t.dim() in (4, 5): + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + memory_format=convert_to_format, + ) + return t.to( + device, + dtype if t.is_floating_point() or t.is_complex() else None, + non_blocking, + ) + except NotImplementedError as e: + if str(e) == "Cannot copy out of meta tensor; no data!": + raise NotImplementedError( + f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " + f"when moving module from meta to a different device." + ) from None + else: + raise + + return self._apply(convert) + + def register_full_backward_pre_hook( + self, + hook: Callable[["Module", _grad_t], Union[None, _grad_t]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward pre-hook on the module. + + The hook will be called every time the gradients for the module are computed. + The hook should have the following signature:: + + hook(module, grad_output) -> tuple[Tensor] or None + + The :attr:`grad_output` is a tuple. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the output that will be used in place of :attr:`grad_output` in + subsequent computations. Entries in :attr:`grad_output` will be ``None`` for + all non-Tensor arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward_pre`` hooks on this + :class:`torch.nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward_pre`` hooks + on this :class:`torch.nn.modules.Module`. Note that global + ``backward_pre`` hooks registered with + :func:`register_module_full_backward_pre_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + handle = hooks.RemovableHandle(self._backward_pre_hooks) + self._backward_pre_hooks[handle.id] = hook + if prepend: + self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_backward_hook( + self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, _grad_t]] + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and + the behavior of this function will change in future versions. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is True: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them.") + + self._is_full_backward_hook = False + + handle = hooks.RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + return handle + + def register_full_backward_hook( + self, + hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], + prepend: bool = False, + ) -> RemovableHandle: + r"""Register a backward hook on the module. + + The hook will be called every time the gradients with respect to a module + are computed, i.e. the hook will execute if and only if the gradients with + respect to module outputs are computed. The hook should have the following + signature:: + + hook(module, grad_input, grad_output) -> tuple(Tensor) or None + + The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients + with respect to the inputs and outputs respectively. The hook should + not modify its arguments, but it can optionally return a new gradient with + respect to the input that will be used in place of :attr:`grad_input` in + subsequent computations. :attr:`grad_input` will only correspond to the inputs given + as positional arguments and all kwarg arguments are ignored. Entries + in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor + arguments. + + For technical reasons, when this hook is applied to a Module, its forward function will + receive a view of each Tensor passed to the Module. Similarly the caller will receive a view + of each Tensor returned by the Module's forward function. + + .. warning :: + Modifying inputs or outputs inplace is not allowed when using backward hooks and + will raise an error. + + Args: + hook (Callable): The user-defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``backward`` hooks on this + :class:`torch.nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``backward`` hooks on + this :class:`torch.nn.modules.Module`. Note that global + ``backward`` hooks registered with + :func:`register_module_full_backward_hook` will fire before + all hooks registered by this method. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + + """ + if self._is_full_backward_hook is False: + raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " + "single Module. Please use only one of them.") + + self._is_full_backward_hook = True + + handle = hooks.RemovableHandle(self._backward_hooks) + self._backward_hooks[handle.id] = hook + if prepend: + self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _get_backward_hooks(self): + r"""Return the backward hooks for use in the call function. + + It returns two lists, one with the full backward hooks and one with the non-full + backward hooks. + """ + full_backward_hooks: List[Callable] = [] + if (_global_is_full_backward_hook is True): + full_backward_hooks += _global_backward_hooks.values() + if (self._is_full_backward_hook is True): + full_backward_hooks += self._backward_hooks.values() + + non_full_backward_hooks: List[Callable] = [] + if (_global_is_full_backward_hook is False): + non_full_backward_hooks += _global_backward_hooks.values() + if (self._is_full_backward_hook is False): + non_full_backward_hooks += self._backward_hooks.values() + + return full_backward_hooks, non_full_backward_hooks + + def _get_backward_pre_hooks(self): + backward_pre_hooks: List[Callable] = [] + backward_pre_hooks += _global_backward_pre_hooks.values() + backward_pre_hooks += self._backward_pre_hooks.values() + + return backward_pre_hooks + + def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): + if not isinstance(result, torch.Tensor): + if not (isinstance(result, tuple) and all(isinstance(r, torch.Tensor) for r in result)): + warnings.warn("Using non-full backward hooks on a Module that does not return a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_output. " + "Please use register_full_backward_hook to get the documented behavior.") + return + else: + result = (result,) + + if not isinstance(inputs, torch.Tensor): + if not (isinstance(inputs, tuple) and all(isinstance(i, torch.Tensor) for i in inputs)): + warnings.warn("Using non-full backward hooks on a Module that does not take as input a " + "single Tensor or a tuple of Tensors is deprecated and will be removed " + "in future versions. This hook will be missing some of the grad_input. " + "Please use register_full_backward_hook to get the documented behavior.") + return + else: + inputs = (inputs,) + + # At this point we are sure that inputs and result are tuple of Tensors + out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} + if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): + warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output.") + elif len(out_grad_fn) > 1: + warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_output. Please use register_full_backward_hook to get the documented behavior.") + else: + # At this point the grad_output part of the hook will most likely be correct + inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} + + next_functions = {n[0] for n in grad_fn.next_functions} + + if inputs_grad_fn != next_functions: + warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " + "is deprecated and will be removed in future versions. This hook will be missing " + "some grad_input. Please use register_full_backward_hook to get the documented " + "behavior.") + + def register_forward_pre_hook( + self, + hook: Union[ + Callable[[T, Tuple[Any, ...]], Optional[Any]], + Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + r"""Register a forward pre-hook on the module. + + The hook will be called every time before :func:`forward` is invoked. + + + If ``with_kwargs`` is false or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + input. User can either return a tuple or a single modified value in the + hook. We will wrap the value into a tuple if a single value is returned + (unless that value is already a tuple). The hook should have the + following signature:: + + hook(module, args) -> None or modified input + + If ``with_kwargs`` is true, the forward pre-hook will be passed the + kwargs given to the forward function. And if the hook modifies the + input, both the args and kwargs should be returned. The hook should have + the following signature:: + + hook(module, args, kwargs) -> None or a tuple of modified input and kwargs + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If true, the provided ``hook`` will be fired before + all existing ``forward_pre`` hooks on this + :class:`torch.nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward_pre`` hooks + on this :class:`torch.nn.modules.Module`. Note that global + ``forward_pre`` hooks registered with + :func:`register_module_forward_pre_hook` will fire before all + hooks registered by this method. + Default: ``False`` + with_kwargs (bool): If true, the ``hook`` will be passed the kwargs + given to the forward function. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle( + self._forward_pre_hooks, + extra_dict=self._forward_pre_hooks_with_kwargs + ) + self._forward_pre_hooks[handle.id] = hook + if with_kwargs: + self._forward_pre_hooks_with_kwargs[handle.id] = True + + if prepend: + self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def register_forward_hook( + self, + hook: Union[ + Callable[[T, Tuple[Any, ...], Any], Optional[Any]], + Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], + ], + *, + prepend: bool = False, + with_kwargs: bool = False, + always_call: bool = False, + ) -> RemovableHandle: + r"""Register a forward hook on the module. + + The hook will be called every time after :func:`forward` has computed an output. + + If ``with_kwargs`` is ``False`` or not specified, the input contains only + the positional arguments given to the module. Keyword arguments won't be + passed to the hooks and only to the ``forward``. The hook can modify the + output. It can modify the input inplace but it will not have effect on + forward since this is called after :func:`forward` is called. The hook + should have the following signature:: + + hook(module, args, output) -> None or modified output + + If ``with_kwargs`` is ``True``, the forward hook will be passed the + ``kwargs`` given to the forward function and be expected to return the + output possibly modified. The hook should have the following signature:: + + hook(module, args, kwargs, output) -> None or modified output + + Args: + hook (Callable): The user defined hook to be registered. + prepend (bool): If ``True``, the provided ``hook`` will be fired + before all existing ``forward`` hooks on this + :class:`torch.nn.modules.Module`. Otherwise, the provided + ``hook`` will be fired after all existing ``forward`` hooks on + this :class:`torch.nn.modules.Module`. Note that global + ``forward`` hooks registered with + :func:`register_module_forward_hook` will fire before all hooks + registered by this method. + Default: ``False`` + with_kwargs (bool): If ``True``, the ``hook`` will be passed the + kwargs given to the forward function. + Default: ``False`` + always_call (bool): If ``True`` the ``hook`` will be run regardless of + whether an exception is raised while calling the Module. + Default: ``False`` + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle( + self._forward_hooks, + extra_dict=[self._forward_hooks_with_kwargs, self._forward_hooks_always_called], + ) + self._forward_hooks[handle.id] = hook + if with_kwargs: + self._forward_hooks_with_kwargs[handle.id] = True + if always_call: + self._forward_hooks_always_called[handle.id] = True + if prepend: + self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] + return handle + + def _slow_forward(self, *input, **kwargs): + tracing_state = torch._C._get_tracing_state() + if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): + return self.forward(*input, **kwargs) + recording_scopes = torch.jit._trace._trace_module_map is not None + if recording_scopes: + # type ignore was added because at this point one knows that + # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] + name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 + if name: + tracing_state.push_scope(name) + else: + recording_scopes = False + try: + result = self.forward(*input, **kwargs) + finally: + if recording_scopes: + tracing_state.pop_scope() + return result + + def _wrapped_call_impl(self, *args, **kwargs): + if self._compiled_call_impl is not None: + return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] + else: + return self._call_impl(*args, **kwargs) + + def _call_impl(self, *args, **kwargs): + forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + # If we don't have any hooks, we want to skip the rest of the logic in + # this function, and just call forward. + if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks + or _global_backward_pre_hooks or _global_backward_hooks + or _global_forward_hooks or _global_forward_pre_hooks): + return forward_call(*args, **kwargs) + + try: + result = None + called_always_called_hooks = set() + + full_backward_hooks, non_full_backward_hooks = [], [] + backward_pre_hooks = [] + if self._backward_pre_hooks or _global_backward_pre_hooks: + backward_pre_hooks = self._get_backward_pre_hooks() + + if self._backward_hooks or _global_backward_hooks: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + + if _global_forward_pre_hooks or self._forward_pre_hooks: + for hook_id, hook in ( + *_global_forward_pre_hooks.items(), + *self._forward_pre_hooks.items(), + ): + if hook_id in self._forward_pre_hooks_with_kwargs: + args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] + if args_kwargs_result is not None: + if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: + args, kwargs = args_kwargs_result + else: + raise RuntimeError( + "forward pre-hook must return None or a tuple " + f"of (new_args, new_kwargs), but got {args_kwargs_result}." + ) + else: + args_result = hook(self, args) + if args_result is not None: + if not isinstance(args_result, tuple): + args_result = (args_result,) + args = args_result + + bw_hook = None + if full_backward_hooks or backward_pre_hooks: + bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) + args = bw_hook.setup_input_hook(args) + + result = forward_call(*args, **kwargs) + if _global_forward_hooks or self._forward_hooks: + for hook_id, hook in ( + *_global_forward_hooks.items(), + *self._forward_hooks.items(), + ): + # mark that always called hook is run + if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: + called_always_called_hooks.add(hook_id) + + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) + else: + hook_result = hook(self, args, result) + + if hook_result is not None: + result = hook_result + + if bw_hook: + if not isinstance(result, (torch.Tensor, tuple)): + warnings.warn("For backward hooks to be called," + " module output should be a Tensor or a tuple of Tensors" + f" but received {type(result)}") + result = bw_hook.setup_output_hook(result) + + # Handle the non-full backward hooks + if non_full_backward_hooks: + var = result + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next(v for v in var.values() if isinstance(v, torch.Tensor)) + else: + var = var[0] + grad_fn = var.grad_fn + if grad_fn is not None: + for hook in non_full_backward_hooks: + grad_fn.register_hook(_WrappedHook(hook, self)) + self._maybe_warn_non_full_backward_hook(args, result, grad_fn) + + return result + + except Exception: + # run always called hooks if they have not already been run + # For now only forward hooks have the always_call option but perhaps + # this functionality should be added to full backward hooks as well. + for hook_id, hook in _global_forward_hooks.items(): + if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("global module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + + for hook_id, hook in self._forward_hooks.items(): + if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] + try: + if hook_id in self._forward_hooks_with_kwargs: + hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] + else: + hook_result = hook(self, args, result) # type: ignore[possibly-undefined] + if hook_result is not None: + result = hook_result + except Exception as e: + warnings.warn("module forward hook with ``always_call=True`` raised an exception " + f"that was silenced as another error was raised in forward: {str(e)}") + continue + # raise exception raised in try block + raise + + + __call__ : Callable[..., Any] = _wrapped_call_impl + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_compiled_call_impl", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + # Support loading old checkpoints that don't have the following attrs: + if '_forward_pre_hooks' not in self.__dict__: + self._forward_pre_hooks = OrderedDict() + if '_forward_pre_hooks_with_kwargs' not in self.__dict__: + self._forward_pre_hooks_with_kwargs = OrderedDict() + if '_forward_hooks_with_kwargs' not in self.__dict__: + self._forward_hooks_with_kwargs = OrderedDict() + if '_forward_hooks_always_called' not in self.__dict__: + self._forward_hooks_always_called = OrderedDict() + if '_state_dict_hooks' not in self.__dict__: + self._state_dict_hooks = OrderedDict() + if '_state_dict_pre_hooks' not in self.__dict__: + self._state_dict_pre_hooks = OrderedDict() + if '_load_state_dict_pre_hooks' not in self.__dict__: + self._load_state_dict_pre_hooks = OrderedDict() + if '_load_state_dict_post_hooks' not in self.__dict__: + self._load_state_dict_post_hooks = OrderedDict() + if '_non_persistent_buffers_set' not in self.__dict__: + self._non_persistent_buffers_set = set() + if '_is_full_backward_hook' not in self.__dict__: + self._is_full_backward_hook = None + if '_backward_pre_hooks' not in self.__dict__: + self._backward_pre_hooks = OrderedDict() + + # On the return type: + # We choose to return `Any` in the `__getattr__` type signature instead of a more strict `Union[Tensor, Module]`. + # This is done for better interop with various type checkers for the end users. + # Having a stricter return type doesn't play nicely with `register_buffer()` and forces + # people to excessively use type-ignores, asserts, casts, etc. + # See full discussion on the problems with returning `Union` here + # https://github.com/microsoft/pyright/issues/4213 + def __getattr__(self, name: str) -> Any: + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + return _parameters[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + if '_modules' in self.__dict__: + modules = self.__dict__['_modules'] + if name in modules: + return modules[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: + def remove_from(*dicts_or_sets): + for d in dicts_or_sets: + if name in d: + if isinstance(d, dict): + del d[name] + else: + d.discard(name) + + params = self.__dict__.get('_parameters') + if isinstance(value, Parameter): + if params is None: + raise AttributeError( + "cannot assign parameters before Module.__init__() call") + remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) + self.register_parameter(name, value) + elif params is not None and name in params: + if value is not None: + raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' " + "(torch.nn.Parameter or None expected)" + ) + self.register_parameter(name, value) + else: + modules = self.__dict__.get('_modules') + if isinstance(value, Module): + if modules is None: + raise AttributeError( + "cannot assign module before Module.__init__() call") + remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + elif modules is not None and name in modules: + if value is not None: + raise TypeError(f"cannot assign '{torch.typename(value)}' as child module '{name}' " + "(torch.nn.Module or None expected)" + ) + for hook in _global_module_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + modules[name] = value + else: + buffers = self.__dict__.get('_buffers') + if buffers is not None and name in buffers: + if value is not None and not isinstance(value, torch.Tensor): + raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " + "(torch.Tensor or None expected)" + ) + for hook in _global_buffer_registration_hooks.values(): + output = hook(self, name, value) + if output is not None: + value = output + buffers[name] = value + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if name in self._parameters: + del self._parameters[name] + elif name in self._buffers: + del self._buffers[name] + self._non_persistent_buffers_set.discard(name) + elif name in self._modules: + del self._modules[name] + else: + super().__delattr__(name) + + def _register_state_dict_hook(self, hook): + r"""Register a state-dict hook. + + These hooks will be called with arguments: `self`, `state_dict`, + `prefix`, `local_metadata`, after the `state_dict` of `self` is set. + Note that only parameters and buffers of `self` or its children are + guaranteed to exist in `state_dict`. The hooks may modify `state_dict` + inplace or return a new one. + """ + handle = hooks.RemovableHandle(self._state_dict_hooks) + self._state_dict_hooks[handle.id] = hook + return handle + + def register_state_dict_pre_hook(self, hook): + r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. + + These hooks will be called with arguments: ``self``, ``prefix``, + and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered + hooks can be used to perform pre-processing before the ``state_dict`` + call is made. + """ + handle = hooks.RemovableHandle(self._state_dict_pre_hooks) + self._state_dict_pre_hooks[handle.id] = hook + return handle + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Save module state to the `destination` dictionary. + + The `destination` dictionary will contain the state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. + T_destination = TypeVar('T_destination', bound=Dict[str, Any]) + + @overload + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: + ... + + @overload + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: + ... + + # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. + # Also remove the logic for arg parsing together. + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + r"""Return a dictionary containing references to the whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. note:: + The returned object is a shallow copy. It contains references + to the module's parameters and buffers. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == '': + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.") + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + for hook in self._state_dict_pre_hooks.values(): + hook(self, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _register_load_state_dict_pre_hook(self, hook, with_module=False): + r"""Register a pre-hook for the :meth:`~torch.nn.Module.load_state_dict` method. + + These hooks will be called with arguments: `state_dict`, `prefix`, + `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, + `error_msgs`, before loading `state_dict` into `self`. These arguments + are exactly the same as those of `_load_from_state_dict`. + + If ``with_module`` is ``True``, then the first argument to the hook is + an instance of the module. + + Arguments: + hook (Callable): Callable hook that will be invoked before + loading the state dict. + with_module (bool, optional): Whether or not to pass the module + instance to the hook as the first parameter. + """ + handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) + self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) + return handle + + def register_load_state_dict_post_hook(self, hook): + r"""Register a post hook to be run after module's ``load_state_dict`` is called. + + It should have the following signature:: + hook(module, incompatible_keys) -> None + + The ``module`` argument is the current module that this hook is registered + on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting + of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` + is a ``list`` of ``str`` containing the missing keys and + ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. + + The given incompatible_keys can be modified inplace if needed. + + Note that the checks performed when calling :func:`load_state_dict` with + ``strict=True`` are affected by modifications the hook makes to + ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either + set of keys will result in an error being thrown when ``strict=True``, and + clearing out both missing and unexpected keys will avoid an error. + + Returns: + :class:`torch.utils.hooks.RemovableHandle`: + a handle that can be used to remove the added hook by calling + ``handle.remove()`` + """ + handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) + self._load_state_dict_post_hooks[handle.id] = hook + return handle + + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. + + This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + Additionally, :attr:`local_metadata` can also contain the key + `assign_to_params_buffers` that indicates whether keys should be + assigned their corresponding tensor in the state_dict. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append(f'While copying the parameter named "{key}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + f'received {type(input_param)}' + ) + continue + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.' + .format(key, input_param.shape, param.shape)) + continue + + if param.is_meta and not input_param.is_meta and not assign_to_params_buffers: + warnings.warn(f'for {key}: copying from a non-meta parameter in the checkpoint to a meta ' + 'parameter in the current model, which is a no-op. (Did you mean to ' + 'pass `assign=True` to assign items in the state dictionary to their ' + 'corresponding key in the module instead of copying them in place?)') + + try: + with torch.no_grad(): + if use_swap_tensors: + new_input_param = param.module_load(input_param, assign=assign_to_params_buffers) + if id(new_input_param) == id(input_param) or id(new_input_param) == id(param): + raise RuntimeError("module_load returned one of self or other, please .detach() " + "the result if returning one of the inputs in module_load") + if (isinstance(param, torch.nn.Parameter)): + if not isinstance(new_input_param, torch.nn.Parameter): + new_input_param = torch.nn.Parameter(new_input_param, requires_grad=param.requires_grad) + else: + new_input_param.requires_grad_(param.requires_grad) + torch.utils.swap_tensors(param, new_input_param) + del new_input_param + elif assign_to_params_buffers: + # Shape checks are already done above + if (isinstance(param, torch.nn.Parameter)): + if not isinstance(input_param, torch.nn.Parameter): + input_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) + else: + input_param.requires_grad_(param.requires_grad) + setattr(self, name, input_param) + else: + param.copy_(input_param) + except Exception as ex: + action = "swapping" if use_swap_tensors else "copying" + error_msgs.append(f'While {action} the parameter named "{key}", ' + f'whose dimensions in the model are {param.size()} and ' + f'whose dimensions in the checkpoint are {input_param.size()}, ' + f'an exception occurred : {ex.args}.' + ) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + + def load_state_dict(self, state_dict: Mapping[str, Any], + strict: bool = True, assign: bool = False): + r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. + + If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + .. warning:: + If :attr:`assign` is ``True`` the optimizer must be created after + the call to :attr:`load_state_dict` unless + :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + assign (bool, optional): When ``False``, the properties of the tensors + in the current module are preserved while when ``True``, the + properties of the Tensors in the state dict are preserved. The only + exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s + for which the value from the module is preserved. + Default: ``False`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + if not isinstance(state_dict, Mapping): + raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, local_state_dict, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata['assign_to_params_buffers'] = assign + module._load_from_state_dict( + local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + child_prefix = prefix + name + '.' + child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} + load(child, child_state_dict, child_prefix) # noqa: F821 + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self, state_dict) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True): + r"""Help yield various names + members of modules.""" + memo = set() + modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)] + for module_prefix, module in modules: + members = get_members_fn(module) + for k, v in members: + if v is None or v in memo: + continue + if remove_duplicate: + memo.add(v) + name = module_prefix + ('.' if module_prefix else '') + k + yield name, v + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + r"""Return an iterator over module parameters. + + This is typically passed to an optimizer. + + Args: + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + Parameter: module parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for param in model.parameters(): + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for name, param in self.named_parameters(recurse=recurse): + yield param + + def named_parameters( + self, + prefix: str = '', + recurse: bool = True, + remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Parameter]]: + r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + remove_duplicate (bool, optional): whether to remove the duplicated + parameters in the result. Defaults to True. + + Yields: + (str, Parameter): Tuple containing the name and parameter + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, param in self.named_parameters(): + >>> if name in ['bias']: + >>> print(param.size()) + + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + yield from gen + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: + r"""Return an iterator over module buffers. + + Args: + recurse (bool): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. + + Yields: + torch.Tensor: module buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for buf in model.buffers(): + >>> print(type(buf), buf.size()) + (20L,) + (20L, 1L, 5L, 5L) + + """ + for _, buf in self.named_buffers(recurse=recurse): + yield buf + + def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: + r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. + + Args: + prefix (str): prefix to prepend to all buffer names. + recurse (bool, optional): if True, then yields buffers of this module + and all submodules. Otherwise, yields only buffers that + are direct members of this module. Defaults to True. + remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. + + Yields: + (str, torch.Tensor): Tuple containing the name and buffer + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, buf in self.named_buffers(): + >>> if name in ['running_var']: + >>> print(buf.size()) + + """ + gen = self._named_members( + lambda module: module._buffers.items(), + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + yield from gen + + def children(self) -> Iterator['Module']: + r"""Return an iterator over immediate children modules. + + Yields: + Module: a child module + """ + for name, module in self.named_children(): + yield module + + def named_children(self) -> Iterator[Tuple[str, 'Module']]: + r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. + + Yields: + (str, Module): Tuple containing a name and child module + + Example:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> for name, module in model.named_children(): + >>> if name in ['conv4', 'conv5']: + >>> print(module) + + """ + memo = set() + for name, module in self._modules.items(): + if module is not None and module not in memo: + memo.add(module) + yield name, module + + def modules(self) -> Iterator['Module']: + r"""Return an iterator over all modules in the network. + + Yields: + Module: a module in the network + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.modules()): + ... print(idx, '->', m) + + 0 -> Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + ) + 1 -> Linear(in_features=2, out_features=2, bias=True) + + """ + for _, module in self.named_modules(): + yield module + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (str, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.named_modules()): + ... print(idx, '->', m) + + 0 -> ('', Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + )) + 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) + + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + yield from module.named_modules(memo, submodule_prefix, remove_duplicate) + + def train(self: T, mode: bool = True) -> T: + r"""Set the module in training mode. + + This has any effect only on certain modules. See documentations of + particular modules for details of their behaviors in training/evaluation + mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + Args: + mode (bool): whether to set training mode (``True``) or evaluation + mode (``False``). Default: ``True``. + + Returns: + Module: self + """ + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + self.training = mode + for module in self.children(): + module.train(mode) + return self + + def eval(self: T) -> T: + r"""Set the module in evaluation mode. + + This has any effect only on certain modules. See documentations of + particular modules for details of their behaviors in training/evaluation + mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, + etc. + + This is equivalent with :meth:`self.train(False) `. + + See :ref:`locally-disable-grad-doc` for a comparison between + `.eval()` and several similar mechanisms that may be confused with it. + + Returns: + Module: self + """ + return self.train(False) + + def requires_grad_(self: T, requires_grad: bool = True) -> T: + r"""Change if autograd should record operations on parameters in this module. + + This method sets the parameters' :attr:`requires_grad` attributes + in-place. + + This method is helpful for freezing part of the module for finetuning + or training parts of a model individually (e.g., GAN training). + + See :ref:`locally-disable-grad-doc` for a comparison between + `.requires_grad_()` and several similar mechanisms that may be confused with it. + + Args: + requires_grad (bool): whether autograd should record operations on + parameters in this module. Default: ``True``. + + Returns: + Module: self + """ + for p in self.parameters(): + p.requires_grad_(requires_grad) + return self + + def zero_grad(self, set_to_none: bool = True) -> None: + r"""Reset gradients of all model parameters. + + See similar function under :class:`torch.optim.Optimizer` for more context. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + See :meth:`torch.optim.Optimizer.zero_grad` for details. + """ + if getattr(self, '_is_replica', False): + warnings.warn( + "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " + "The parameters are copied (in a differentiable manner) from the original module. " + "This means they are not leaf nodes in autograd and so don't accumulate gradients. " + "If you need gradients in your forward method, consider using autograd.grad instead.") + + for p in self.parameters(): + if p.grad is not None: + if set_to_none: + p.grad = None + else: + if p.grad.grad_fn is not None: + p.grad.detach_() + else: + p.grad.requires_grad_(False) + p.grad.zero_() + + def share_memory(self: T) -> T: + r"""See :meth:`torch.Tensor.share_memory_`.""" + return self._apply(lambda t: t.share_memory_()) + + def _get_name(self): + return self.__class__.__name__ + + def extra_repr(self) -> str: + r"""Set the extra representation of the module. + + To print customized extra information, you should re-implement + this method in your own modules. Both single-line and multi-line + strings are acceptable. + """ + return '' + + def __repr__(self): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + for key, module in self._modules.items(): + mod_str = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + lines = extra_lines + child_lines + + main_str = self._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + return main_str + + def __dir__(self): + module_attrs = dir(self.__class__) + attrs = list(self.__dict__.keys()) + parameters = list(self._parameters.keys()) + modules = list(self._modules.keys()) + buffers = list(self._buffers.keys()) + keys = module_attrs + attrs + parameters + modules + buffers + + # Eliminate attrs that are not legal Python variable names + keys = [key for key in keys if not key[0].isdigit()] + + return sorted(keys) + + def _replicate_for_data_parallel(self): + replica = self.__new__(type(self)) + replica.__dict__ = self.__dict__.copy() + + # replicas do not have parameters themselves, the replicas reference the original + # module. + replica._parameters = OrderedDict() + replica._buffers = replica._buffers.copy() + replica._modules = replica._modules.copy() + replica._is_replica = True # type: ignore[assignment] + + return replica + + def compile(self, *args, **kwargs): + """ + Compile this Module's forward using :func:`torch.compile`. + + This Module's `__call__` method is compiled and all arguments are passed as-is + to :func:`torch.compile`. + + See :func:`torch.compile` for details on the arguments for this function. + """ + self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/rnn.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b3ac4e51976f2704d5df1fd218358129669c5e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/rnn.py @@ -0,0 +1,1480 @@ +import math +import warnings +import numbers +import weakref +from typing import List, Tuple, Optional, overload + +import torch +from torch import Tensor +from .module import Module +from ..parameter import Parameter +from ..utils.rnn import PackedSequence +from .. import init +from ... import _VF + +__all__ = ['RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell'] + +_rnn_impls = { + 'RNN_TANH': _VF.rnn_tanh, + 'RNN_RELU': _VF.rnn_relu, +} + + +def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + return tensor.index_select(dim, permutation) + + +def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: + warnings.warn("apply_permutation is deprecated, please use tensor.index_select(dim, permutation) instead") + return _apply_permutation(tensor, permutation, dim) + + +class RNNBase(Module): + r"""Base class for RNN modules (RNN, LSTM, GRU). + + Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization + and utility methods for parameter storage management. + + .. note:: + The forward method is not implemented by the RNNBase class. + + .. note:: + LSTM and GRU classes override some methods implemented by RNNBase. + """ + + __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', + 'batch_first', 'dropout', 'bidirectional', 'proj_size'] + __jit_unused_properties__ = ['all_weights'] + + mode: str + input_size: int + hidden_size: int + num_layers: int + bias: bool + batch_first: bool + dropout: float + bidirectional: bool + proj_size: int + + def __init__(self, mode: str, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, batch_first: bool = False, + dropout: float = 0., bidirectional: bool = False, proj_size: int = 0, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.proj_size = proj_size + self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = [] + num_directions = 2 if bidirectional else 1 + + if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ + isinstance(dropout, bool): + raise ValueError("dropout should be a number in range [0, 1] " + "representing the probability of an element being " + "zeroed") + if dropout > 0 and num_layers == 1: + warnings.warn("dropout option adds dropout after all but last " + "recurrent layer, so non-zero dropout expects " + f"num_layers greater than 1, but got dropout={dropout} and " + f"num_layers={num_layers}") + + if not isinstance(hidden_size, int): + raise TypeError(f"hidden_size should be of type int, got: {type(hidden_size).__name__}") + if hidden_size <= 0: + raise ValueError("hidden_size must be greater than zero") + if num_layers <= 0: + raise ValueError("num_layers must be greater than zero") + if proj_size < 0: + raise ValueError("proj_size should be a positive integer or zero to disable projections") + if proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") + + if mode == 'LSTM': + gate_size = 4 * hidden_size + elif mode == 'GRU': + gate_size = 3 * hidden_size + elif mode == 'RNN_TANH': + gate_size = hidden_size + elif mode == 'RNN_RELU': + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + real_hidden_size = proj_size if proj_size > 0 else hidden_size + layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions + + w_ih = Parameter(torch.empty((gate_size, layer_input_size), **factory_kwargs)) + w_hh = Parameter(torch.empty((gate_size, real_hidden_size), **factory_kwargs)) + b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) + layer_params: Tuple[Tensor, ...] = () + if self.proj_size == 0: + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh) + else: + layer_params = (w_ih, w_hh) + else: + w_hr = Parameter(torch.empty((proj_size, hidden_size), **factory_kwargs)) + if bias: + layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) + else: + layer_params = (w_ih, w_hh, w_hr) + + suffix = '_reverse' if direction == 1 else '' + param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] + if bias: + param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] + if self.proj_size > 0: + param_names += ['weight_hr_l{}{}'] + param_names = [x.format(layer, suffix) for x in param_names] + + for name, param in zip(param_names, layer_params): + setattr(self, name, param) + self._flat_weights_names.extend(param_names) + self._all_weights.append(param_names) + + self._init_flat_weights() + + self.reset_parameters() + + def _init_flat_weights(self): + self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names] + self._flat_weight_refs = [weakref.ref(w) if w is not None else None + for w in self._flat_weights] + self.flatten_parameters() + + def __setattr__(self, attr, value): + if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: + # keep self._flat_weights up to date if you do self.weight = ... + idx = self._flat_weights_names.index(attr) + self._flat_weights[idx] = value + super().__setattr__(attr, value) + + def flatten_parameters(self) -> None: + """Reset parameter data pointer so that they can use faster code paths. + + Right now, this works only if the module is on the GPU and cuDNN is enabled. + Otherwise, it's a no-op. + """ + # Short-circuits if _flat_weights is only partially instantiated + if len(self._flat_weights) != len(self._flat_weights_names): + return + + for w in self._flat_weights: + if not isinstance(w, Tensor): + return + # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN + # or the tensors in _flat_weights are of different dtypes + + first_fw = self._flat_weights[0] + dtype = first_fw.dtype + for fw in self._flat_weights: + if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or + not fw.data.is_cuda or + not torch.backends.cudnn.is_acceptable(fw.data)): + return + + # If any parameters alias, we fall back to the slower, copying code path. This is + # a sufficient check, because overlapping parameter buffers that don't completely + # alias would break the assumptions of the uniqueness check in + # Module.named_parameters(). + unique_data_ptrs = {p.data_ptr() for p in self._flat_weights} + if len(unique_data_ptrs) != len(self._flat_weights): + return + + with torch.cuda.device_of(first_fw): + import torch.backends.cudnn.rnn as rnn + + # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is + # an inplace operation on self._flat_weights + with torch.no_grad(): + if torch._use_cudnn_rnn_flatten_weight(): + num_weights = 4 if self.bias else 2 + if self.proj_size > 0: + num_weights += 1 + torch._cudnn_rnn_flatten_weight( + self._flat_weights, num_weights, + self.input_size, rnn.get_cudnn_mode(self.mode), + self.hidden_size, self.proj_size, self.num_layers, + self.batch_first, bool(self.bidirectional)) + + def _apply(self, fn, recurse=True): + self._flat_weight_refs = [] + ret = super()._apply(fn, recurse) + + # Resets _flat_weights + # Note: be v. careful before removing this, as 3rd party device types + # likely rely on this behavior to properly .to() modules like LSTM. + self._init_flat_weights() + + return ret + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: + if not torch.jit.is_scripting(): + if input.dtype != self._flat_weights[0].dtype and not torch._C._is_any_autocast_enabled(): + raise ValueError(f'input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}') + expected_input_dim = 2 if batch_sizes is not None else 3 + if input.dim() != expected_input_dim: + raise RuntimeError( + f'input must have {expected_input_dim} dimensions, got {input.dim()}') + if self.input_size != input.size(-1): + raise RuntimeError( + f'input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}') + + def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + if self.proj_size > 0: + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.proj_size) + else: + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) + return expected_hidden_size + + def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int], + msg: str = 'Expected hidden size {}, got {}') -> None: + if hx.size() != expected_hidden_size: + raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) + + def _weights_have_changed(self): + # Returns True if the weight tensors have changed since the last forward pass. + # This is the case when used with torch.func.functional_call(), for example. + weights_changed = False + for ref, name in zip(self._flat_weight_refs, self._flat_weights_names): + weight = getattr(self, name) if hasattr(self, name) else None + if weight is not None and ref is not None and ref() is not weight: + weights_changed = True + break + return weights_changed + + def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]): + self.check_input(input, batch_sizes) + expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) + + self.check_hidden_size(hidden, expected_hidden_size) + + def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): + if permutation is None: + return hx + return _apply_permutation(hx, permutation) + + + def extra_repr(self) -> str: + s = '{input_size}, {hidden_size}' + if self.proj_size != 0: + s += ', proj_size={proj_size}' + if self.num_layers != 1: + s += ', num_layers={num_layers}' + if self.bias is not True: + s += ', bias={bias}' + if self.batch_first is not False: + s += ', batch_first={batch_first}' + if self.dropout != 0: + s += ', dropout={dropout}' + if self.bidirectional is not False: + s += ', bidirectional={bidirectional}' + return s.format(**self.__dict__) + + def _update_flat_weights(self): + if not torch.jit.is_scripting(): + if self._weights_have_changed(): + self._init_flat_weights() + + def __getstate__(self): + # If weights have been changed, update the _flat_weights in __getstate__ here. + self._update_flat_weights() + # Don't serialize the weight references. + state = self.__dict__.copy() + del state['_flat_weight_refs'] + return state + + def __setstate__(self, d): + super().__setstate__(d) + if 'all_weights' in d: + self._all_weights = d['all_weights'] + # In PyTorch 1.8 we added a proj_size member variable to LSTM. + # LSTMs that were serialized via torch.save(module) before PyTorch 1.8 + # don't have it, so to preserve compatibility we set proj_size here. + if 'proj_size' not in d: + self.proj_size = 0 + + if not isinstance(self._all_weights[0][0], str): + num_layers = self.num_layers + num_directions = 2 if self.bidirectional else 1 + self._flat_weights_names = [] + self._all_weights = [] + for layer in range(num_layers): + for direction in range(num_directions): + suffix = '_reverse' if direction == 1 else '' + weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', + 'bias_hh_l{}{}', 'weight_hr_l{}{}'] + weights = [x.format(layer, suffix) for x in weights] + if self.bias: + if self.proj_size > 0: + self._all_weights += [weights] + self._flat_weights_names.extend(weights) + else: + self._all_weights += [weights[:4]] + self._flat_weights_names.extend(weights[:4]) + else: + if self.proj_size > 0: + self._all_weights += [weights[:2]] + [weights[-1:]] + self._flat_weights_names.extend(weights[:2] + [weights[-1:]]) + else: + self._all_weights += [weights[:2]] + self._flat_weights_names.extend(weights[:2]) + self._flat_weights = [getattr(self, wn) if hasattr(self, wn) else None + for wn in self._flat_weights_names] + + self._flat_weight_refs = [weakref.ref(w) if w is not None else None + for w in self._flat_weights] + + @property + def all_weights(self) -> List[List[Parameter]]: + return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] + + def _replicate_for_data_parallel(self): + replica = super()._replicate_for_data_parallel() + # Need to copy these caches, otherwise the replica will share the same + # flat weights list. + replica._flat_weights = replica._flat_weights[:] + replica._flat_weights_names = replica._flat_weights_names[:] + return replica + + +class RNN(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` + non-linearity to an input sequence. For each element in the input sequence, + each layer computes the following function: + + .. math:: + h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is + the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the + previous layer at time `t-1` or the initial hidden state at time `0`. + If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. + + .. code-block:: python + + # Efficient implementation equivalent to the following with bidirectional=False + def forward(x, h_0=None): + if batch_first: + x = x.transpose(0, 1) + seq_len, batch_size, _ = x.size() + if h_0 is None: + h_0 = torch.zeros(num_layers, batch_size, hidden_size) + h_t_minus_1 = h_0 + h_t = h_0 + output = [] + for t in range(seq_len): + for layer in range(num_layers): + h_t[layer] = torch.tanh( + x[t] @ weight_ih[layer].T + + bias_ih[layer] + + h_t_minus_1[layer] @ weight_hh[layer].T + + bias_hh[layer] + ) + output.append(h_t[-1]) + h_t_minus_1 = h_t + output = torch.stack(output) + if batch_first: + output = output.transpose(0, 1) + return output, h_t + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two RNNs together to form a `stacked RNN`, + with the second RNN taking in outputs of the first RNN and + computing the final results. Default: 1 + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + RNN layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden + state for the input sequence batch. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the RNN, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for each element in the batch. + + Attributes: + weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, + of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is + `(hidden_size, num_directions * hidden_size)` + weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, + of shape `(hidden_size, hidden_size)` + bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, + of shape `(hidden_size)` + bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, + of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.RNN(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, + nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False, + dropout: float = 0., bidirectional: bool = False, device=None, + dtype=None) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... + + def __init__(self, *args, **kwargs): + if 'proj_size' in kwargs: + raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") + if len(args) > 3: + self.nonlinearity = args[3] + args = args[:3] + args[4:] + else: + self.nonlinearity = kwargs.pop('nonlinearity', 'tanh') + if self.nonlinearity == 'tanh': + mode = 'RNN_TANH' + elif self.nonlinearity == 'relu': + mode = 'RNN_RELU' + else: + raise ValueError(f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'.") + super().__init__(mode, *args, **kwargs) + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + num_directions = 2 if self.bidirectional else 1 + orig_input = input + + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate + if hx is None: + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError(f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead") + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + assert hx is not None + self.check_forward_args(input, hx, batch_sizes) + assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU' + if batch_sizes is None: + if self.mode == 'RNN_TANH': + result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, + self.batch_first) + else: + result = _VF.rnn_relu(input, hx, self._flat_weights, self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, + self.batch_first) + else: + if self.mode == 'RNN_TANH': + result = _VF.rnn_tanh(input, batch_sizes, hx, self._flat_weights, self.bias, + self.num_layers, self.dropout, self.training, + self.bidirectional) + else: + result = _VF.rnn_relu(input, batch_sizes, hx, self._flat_weights, self.bias, + self.num_layers, self.dropout, self.training, + self.bidirectional) + + output = result[0] + hidden = result[1] + + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + +# XXX: LSTM and GRU implementation is different from RNNBase, this is because: +# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in +# its current state could not support the python Union Type or Any Type +# 2. TorchScript static typing does not allow a Function or Callable type in +# Dict values, so we have to separately call _VF instead of using _rnn_impls +# 3. This is temporary only and in the transition state that we want to make it +# on time for the release +# +# More discussion details in https://github.com/pytorch/pytorch/pull/23266 +# +# TODO: remove the overriding implementations for LSTM and GRU when TorchScript +# support expressing these two modules generally. + + +class LSTM(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None) + + Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} \\ + i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ + f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ + g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ + o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ + c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ + h_t = o_t \odot \tanh(c_t) \\ + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell + state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` + is the hidden state of the layer at time `t-1` or the initial hidden + state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, + :math:`o_t` are the input, forget, cell, and output gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes + the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from + ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). + Second, the output hidden state of each layer will be multiplied by a learnable projection + matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output + of LSTM network will be of different shape as well. See Inputs/Outputs sections below for exact + dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two LSTMs together to form a `stacked LSTM`, + with the second LSTM taking in outputs of the first LSTM and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + LSTM layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` + proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 + + Inputs: input, (h_0, c_0) + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + initial hidden state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + initial cell state for each element in the input sequence. + Defaults to zeros if (h_0, c_0) is not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{cell} ={} & \text{hidden\_size} \\ + H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ + \end{aligned} + + Outputs: output, (h_n, c_n) + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the LSTM, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. When ``bidirectional=True``, `output` will contain + a concatenation of the forward and reverse hidden states at each time step in the sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the + final hidden state for each element in the sequence. When ``bidirectional=True``, + `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively. + * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or + :math:`(D * \text{num\_layers}, N, H_{cell})` containing the + final cell state for each element in the sequence. When ``bidirectional=True``, + `c_n` will contain a concatenation of the final forward and reverse cell states, respectively. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If + ``proj_size > 0`` was specified, the shape will be + `(4*hidden_size, num_directions * proj_size)` for `k > 0` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` + was specified, the shape will be `(4*hidden_size, proj_size)`. + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` + weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer + of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was + specified. + weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction. + Only present when ``bidirectional=True``. + weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction. + Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified. + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the + former contains the final forward and reverse hidden states, while the latter contains the + final forward hidden state and the initial reverse hidden state. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + ``proj_size`` should be smaller than ``hidden_size``. + + .. include:: ../cudnn_rnn_determinism.rst + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> c0 = torch.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + """ + + @overload + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, + batch_first: bool = False, dropout: float = 0., bidirectional: bool = False, + proj_size: int = 0, device=None, dtype=None) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... + + def __init__(self, *args, **kwargs): + super().__init__('LSTM', *args, **kwargs) + + def get_expected_cell_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]: + if batch_sizes is not None: + mini_batch = int(batch_sizes[0]) + else: + mini_batch = input.size(0) if self.batch_first else input.size(1) + num_directions = 2 if self.bidirectional else 1 + expected_hidden_size = (self.num_layers * num_directions, + mini_batch, self.hidden_size) + return expected_hidden_size + + # In the future, we should prevent mypy from applying contravariance rules here. + # See torch/nn/modules/module.py::_forward_unimplemented + def check_forward_args(self, # type: ignore[override] + input: Tensor, + hidden: Tuple[Tensor, Tensor], + batch_sizes: Optional[Tensor], + ): + self.check_input(input, batch_sizes) + self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes), + 'Expected hidden[0] size {}, got {}') + self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes), + 'Expected hidden[1] size {}, got {}') + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + def permute_hidden(self, # type: ignore[override] + hx: Tuple[Tensor, Tensor], + permutation: Optional[Tensor] + ) -> Tuple[Tensor, Tensor]: + if permutation is None: + return hx + return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation) + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811 + pass + + # Same as above, see torch/nn/modules/module.py::_forward_unimplemented + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + batch_sizes = None + do_permute = False + num_directions = 2 if self.bidirectional else 1 + real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + h_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, real_hidden_size, + dtype=input.dtype, device=input.device) + c_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = (h_zeros, c_zeros) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + if input.dim() not in (2, 3): + raise ValueError(f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead") + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + h_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, real_hidden_size, + dtype=input.dtype, device=input.device) + c_zeros = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + hx = (h_zeros, c_zeros) + self.check_forward_args(input, hx, batch_sizes) + else: + if is_batched: + if (hx[0].dim() != 3 or hx[1].dim() != 3): + msg = ("For batched 3-D input, hx and cx should " + f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + raise RuntimeError(msg) + else: + if hx[0].dim() != 2 or hx[1].dim() != 2: + msg = ("For unbatched 2-D input, hx and cx should " + f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors") + raise RuntimeError(msg) + hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + self.check_forward_args(input, hx, batch_sizes) + hx = self.permute_hidden(hx, sorted_indices) + + if batch_sizes is None: + result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, self.batch_first) + else: + result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias, + self.num_layers, self.dropout, self.training, self.bidirectional) + output = result[0] + hidden = result[1:] + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) + return output, self.permute_hidden(hidden, unsorted_indices) + + +class GRU(RNNBase): + r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) + + Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + For each element in the input sequence, each layer computes the following + function: + + .. math:: + \begin{array}{ll} + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} + \end{array} + + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random + variable which is :math:`0` with probability :attr:`dropout`. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` + would mean stacking two GRUs together to form a `stacked GRU`, + with the second GRU taking in outputs of the first GRU and + computing the final results. Default: 1 + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + batch_first: If ``True``, then the input and output tensors are provided + as `(batch, seq, feature)` instead of `(seq, batch, feature)`. + Note that this does not apply to hidden or cell states. See the + Inputs/Outputs sections below for details. Default: ``False`` + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each + GRU layer except the last layer, with dropout probability equal to + :attr:`dropout`. Default: 0 + bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` + + Inputs: input, h_0 + * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, + :math:`(L, N, H_{in})` when ``batch_first=False`` or + :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of + the input sequence. The input can also be a packed variable length sequence. + See :func:`torch.nn.utils.rnn.pack_padded_sequence` or + :func:`torch.nn.utils.rnn.pack_sequence` for details. + * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` + containing the initial hidden state for the input sequence. Defaults to zeros if not provided. + + where: + + .. math:: + \begin{aligned} + N ={} & \text{batch size} \\ + L ={} & \text{sequence length} \\ + D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ + H_{in} ={} & \text{input\_size} \\ + H_{out} ={} & \text{hidden\_size} + \end{aligned} + + Outputs: output, h_n + * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, + :math:`(L, N, D * H_{out})` when ``batch_first=False`` or + :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features + `(h_t)` from the last layer of the GRU, for each `t`. If a + :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output + will also be a packed sequence. + * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or + :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state + for the input sequence. + + Attributes: + weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer + (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. + Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` + weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer + (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` + bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer + (b_ir|b_iz|b_in), of shape `(3*hidden_size)` + bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer + (b_hr|b_hz|b_hn), of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + .. note:: + For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. + Example of splitting the output layers when ``batch_first=False``: + ``output.view(seq_len, batch, num_directions, hidden_size)``. + + .. note:: + ``batch_first`` argument is ignored for unbatched inputs. + + .. note:: + The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. + In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the + previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix + `W` and addition of bias: + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) + \end{aligned} + + This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}` + + .. math:: + \begin{aligned} + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) + \end{aligned} + + This implementation differs on purpose for efficiency. + + .. include:: ../cudnn_persistent_rnn.rst + + Examples:: + + >>> rnn = nn.GRU(10, 20, 2) + >>> input = torch.randn(5, 3, 10) + >>> h0 = torch.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + """ + + @overload + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, bias: bool = True, + batch_first: bool = False, dropout: float = 0., bidirectional: bool = False, + device=None, dtype=None) -> None: + ... + + @overload + def __init__(self, *args, **kwargs): + ... + + def __init__(self, *args, **kwargs): + if 'proj_size' in kwargs: + raise ValueError("proj_size argument is only supported for LSTM, not RNN or GRU") + super().__init__('GRU', *args, **kwargs) + + @overload # type: ignore[override] + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: # noqa: F811 + pass + + @overload + @torch._jit_internal._overload_method # noqa: F811 + def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]: # noqa: F811 + pass + + def forward(self, input, hx=None): # noqa: F811 + self._update_flat_weights() + + orig_input = input + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + input, batch_sizes, sorted_indices, unsorted_indices = input + max_batch_size = batch_sizes[0] + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + else: + batch_sizes = None + if input.dim() not in (2, 3): + raise ValueError(f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead") + is_batched = input.dim() == 3 + batch_dim = 0 if self.batch_first else 1 + if not is_batched: + input = input.unsqueeze(batch_dim) + if hx is not None: + if hx.dim() != 2: + raise RuntimeError( + f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor") + hx = hx.unsqueeze(1) + else: + if hx is not None and hx.dim() != 3: + raise RuntimeError( + f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor") + max_batch_size = input.size(0) if self.batch_first else input.size(1) + sorted_indices = None + unsorted_indices = None + if hx is None: + num_directions = 2 if self.bidirectional else 1 + hx = torch.zeros(self.num_layers * num_directions, + max_batch_size, self.hidden_size, + dtype=input.dtype, device=input.device) + else: + # Each batch of the hidden state should match the input sequence that + # the user believes he/she is passing in. + hx = self.permute_hidden(hx, sorted_indices) + + self.check_forward_args(input, hx, batch_sizes) + if batch_sizes is None: + result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers, + self.dropout, self.training, self.bidirectional, self.batch_first) + else: + result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias, + self.num_layers, self.dropout, self.training, self.bidirectional) + output = result[0] + hidden = result[1] + + # xxx: isinstance check needs to be in conditional for TorchScript to compile + if isinstance(orig_input, PackedSequence): + output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) + return output_packed, self.permute_hidden(hidden, unsorted_indices) + else: + if not is_batched: # type: ignore[possibly-undefined] + output = output.squeeze(batch_dim) # type: ignore[possibly-undefined] + hidden = hidden.squeeze(1) + + return output, self.permute_hidden(hidden, unsorted_indices) + + +class RNNCellBase(Module): + __constants__ = ['input_size', 'hidden_size', 'bias'] + + input_size: int + hidden_size: int + bias: bool + weight_ih: Tensor + weight_hh: Tensor + # WARNING: bias_ih and bias_hh purposely not defined here. + # See https://github.com/pytorch/pytorch/issues/39670 + + def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.weight_ih = Parameter(torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)) + self.weight_hh = Parameter(torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)) + if bias: + self.bias_ih = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs)) + self.bias_hh = Parameter(torch.empty(num_chunks * hidden_size, **factory_kwargs)) + else: + self.register_parameter('bias_ih', None) + self.register_parameter('bias_hh', None) + + self.reset_parameters() + + def extra_repr(self) -> str: + s = '{input_size}, {hidden_size}' + if 'bias' in self.__dict__ and self.bias is not True: + s += ', bias={bias}' + if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh": + s += ', nonlinearity={nonlinearity}' + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + +class RNNCell(RNNCellBase): + r"""An Elman RNN cell with tanh or ReLU non-linearity. + + .. math:: + + h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) + + If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. + Default: ``True`` + nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` + + Inputs: input, hidden + - **input**: tensor containing input features + - **hidden**: tensor containing the initial hidden state + Defaults to zero if not provided. + + Outputs: h' + - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + Examples:: + + >>> rnn = nn.RNNCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity'] + nonlinearity: str + + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh", + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) + self.nonlinearity = nonlinearity + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError(f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + if hx is not None and hx.dim() not in (1, 2): + raise ValueError(f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead") + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + if self.nonlinearity == "tanh": + ret = _VF.rnn_tanh_cell( + input, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.rnn_relu_cell( + input, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError( + f"Unknown nonlinearity: {self.nonlinearity}") + + if not is_batched: + ret = ret.squeeze(0) + + return ret + + +class LSTMCell(RNNCellBase): + r"""A long short-term memory (LSTM) cell. + + .. math:: + + \begin{array}{ll} + i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ + f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ + g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ + o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ + c' = f \odot c + i \odot g \\ + h' = o \odot \tanh(c') \\ + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, (h_0, c_0) + - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features + - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state + - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state + + If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. + + Outputs: (h_1, c_1) + - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state + - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(4*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(4*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) + >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) + >>> hx = torch.randn(3, 20) # (batch, hidden_size) + >>> cx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(input.size()[0]): + ... hx, cx = rnn(input[i], (hx, cx)) + ... output.append(hx) + >>> output = torch.stack(output, dim=0) + """ + + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) + + def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: + if input.dim() not in (1, 2): + raise ValueError(f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + if hx is not None: + for idx, value in enumerate(hx): + if value.dim() not in (1, 2): + raise ValueError(f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead") + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + hx = (zeros, zeros) + else: + hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx + + ret = _VF.lstm_cell( + input, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + + if not is_batched: + ret = (ret[0].squeeze(0), ret[1].squeeze(0)) + return ret + + +class GRUCell(RNNCellBase): + r"""A gated recurrent unit (GRU) cell. + + .. math:: + + \begin{array}{ll} + r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ + z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ + n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ + h' = (1 - z) \odot n + z \odot h + \end{array} + + where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. + + Args: + input_size: The number of expected features in the input `x` + hidden_size: The number of features in the hidden state `h` + bias: If ``False``, then the layer does not use bias weights `b_ih` and + `b_hh`. Default: ``True`` + + Inputs: input, hidden + - **input** : tensor containing input features + - **hidden** : tensor containing the initial hidden + state for each element in the batch. + Defaults to zero if not provided. + + Outputs: h' + - **h'** : tensor containing the next hidden state + for each element in the batch + + Shape: + - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where + :math:`H_{in}` = `input_size`. + - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden + state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. + - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. + + Attributes: + weight_ih: the learnable input-hidden weights, of shape + `(3*hidden_size, input_size)` + weight_hh: the learnable hidden-hidden weights, of shape + `(3*hidden_size, hidden_size)` + bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` + bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` + + .. note:: + All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` + where :math:`k = \frac{1}{\text{hidden\_size}}` + + On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision` for backward. + + Examples:: + + >>> rnn = nn.GRUCell(10, 20) + >>> input = torch.randn(6, 3, 10) + >>> hx = torch.randn(3, 20) + >>> output = [] + >>> for i in range(6): + ... hx = rnn(input[i], hx) + ... output.append(hx) + """ + + def __init__(self, input_size: int, hidden_size: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if input.dim() not in (1, 2): + raise ValueError(f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead") + if hx is not None and hx.dim() not in (1, 2): + raise ValueError(f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead") + is_batched = input.dim() == 2 + if not is_batched: + input = input.unsqueeze(0) + + if hx is None: + hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device) + else: + hx = hx.unsqueeze(0) if not is_batched else hx + + ret = _VF.gru_cell( + input, hx, + self.weight_ih, self.weight_hh, + self.bias_ih, self.bias_hh, + ) + + if not is_batched: + ret = ret.squeeze(0) + + return ret diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..120e47eda9b2b4dedf0044e0c0089e8349d337a5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/__pycache__/_functions.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/_functions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..d987ed2bc427462bf517665bda8e1889d5a58c90 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/_functions.py @@ -0,0 +1,126 @@ +import warnings + +import torch +from . import comm +from torch.autograd import Function +from torch._utils import _get_device_index +from typing import List, Optional + + +class Broadcast(Function): + + @staticmethod + def forward(ctx, target_gpus, *inputs): + assert all(i.device.type != 'cpu' for i in inputs), ( + 'Broadcast function not implemented for CPU tensors' + ) + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.target_gpus = target_gpus + if len(inputs) == 0: + return tuple() + ctx.num_inputs = len(inputs) + ctx.input_device = inputs[0].get_device() + outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) + non_differentiables = [] + for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]): + if not input_requires_grad: + for output in outputs: + non_differentiables.append(output[idx]) + ctx.mark_non_differentiable(*non_differentiables) + return tuple([t for tensors in outputs for t in tensors]) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs) + + +class ReduceAddCoalesced(Function): + + @staticmethod + def forward(ctx, destination, num_inputs, *grads): + ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)] + + grads_ = [grads[i:i + num_inputs] + for i in range(0, len(grads), num_inputs)] + return comm.reduce_add_coalesced(grads_, destination) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs) + + +class Gather(Function): + + @staticmethod + def forward(ctx, target_device, dim, *inputs): + assert all(i.device.type != 'cpu' for i in inputs), ( + 'Gather function not implemented for CPU tensors' + ) + if (target_device == 'cpu'): + ctx.target_device = 'cpu' + else: + target_device = _get_device_index(target_device, True) + ctx.target_device = target_device + ctx.dim = dim + ctx.input_gpus = tuple(i.get_device() for i in inputs) + if all(t.dim() == 0 for t in inputs) and dim == 0: + inputs = tuple(t.view(1) for t in inputs) + warnings.warn('Was asked to gather along dimension 0, but all ' + 'input tensors were scalars; will instead unsqueeze ' + 'and return a vector.') + ctx.unsqueezed_scalar = True + else: + ctx.unsqueezed_scalar = False + ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs) + return comm.gather(inputs, ctx.dim, ctx.target_device) + + @staticmethod + def backward(ctx, grad_output): + scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output) + if ctx.unsqueezed_scalar: + scattered_grads = tuple(g[0] for g in scattered_grads) + return (None, None) + scattered_grads + + +class Scatter(Function): + + @staticmethod + def forward(ctx, target_gpus, chunk_sizes, dim, input): + target_gpus = [_get_device_index(x, True) for x in target_gpus] + ctx.dim = dim + ctx.input_device = input.get_device() if input.device.type != "cpu" else -1 + streams = None + if torch.cuda.is_available() and ctx.input_device == -1: + # Perform CPU to GPU copies in a background stream + streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus] + outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) + # Synchronize with the copy stream + if streams is not None: + for i, output in enumerate(outputs): + with torch.cuda.device(target_gpus[i]): + main_stream = torch.cuda.current_stream() + main_stream.wait_stream(streams[i]) + output.record_stream(main_stream) + return outputs + + @staticmethod + def backward(ctx, *grad_output): + return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output) + + +# background streams used for copying +_streams: Optional[List[Optional[torch.Stream]]] = None + +def _get_stream(device: torch.device): + """Get a background stream for copying between CPU and target device.""" + global _streams + if device.type == "cpu": + return None + device_mod = getattr(torch, device.type, None) + if device_mod is None: + return None + if _streams is None: + _streams = [None] * device_mod.device_count() + if _streams[device.index] is None: + _streams[device.index] = device_mod.Stream(device.index) + return _streams[device.index] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/comm.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..764775587d6859665fb13e2140c588ed5d91dc1c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/comm.py @@ -0,0 +1,236 @@ +import warnings +import torch +from torch.cuda import nccl +from torch._utils import _take_tensors, _flatten_dense_tensors, \ + _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index, _handle_complex +from typing import List + +def broadcast(tensor, devices=None, *, out=None): + r"""Broadcasts a tensor to specified GPU devices. + + Args: + tensor (Tensor): tensor to broadcast. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to broadcast. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. + + Returns: + - If :attr:`devices` is specified, + a tuple containing copies of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a copy of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if not ((devices is None) ^ (out is None)): + raise RuntimeError( + f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}") + if devices is not None: + devices = [_get_device_index(d) for d in devices] + return torch._C._broadcast(tensor, devices) + else: + return torch._C._broadcast_out(tensor, out) + + +def broadcast_coalesced(tensors, devices, buffer_size=10485760): + """Broadcast a sequence of tensors to the specified GPUs. + + Small tensors are first coalesced into a buffer to reduce the number of synchronizations. + + Args: + tensors (sequence): tensors to broadcast. Must be on the same device, + either CPU or GPU. + devices (Iterable[torch.device, str or int]): an iterable of GPU + devices, among which to broadcast. + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`. + """ + devices = [_get_device_index(d) for d in devices] + tensors = [_handle_complex(t) for t in tensors] + return torch._C._broadcast_coalesced(tensors, devices, buffer_size) + + +def reduce_add(inputs, destination=None): + """Sum tensors from multiple GPUs. + + All inputs should have matching shapes, dtype, and layout. The output tensor + will be of the same shape, dtype, and layout. + + Args: + inputs (Iterable[Tensor]): an iterable of tensors to add. + destination (int, optional): a device on which the output will be + placed (default: current device). + + Returns: + A tensor containing an elementwise sum of all inputs, placed on the + :attr:`destination` device. + """ + destination = _get_device_index(destination, optional=True) + input_size = inputs[0].size() + root_index = None # index of input tensor that already is on the correct device + for i, inp in enumerate(inputs): + assert inp.device.type != "cpu", "reduce_add expects all inputs to be on GPUs" + if inp.get_device() == destination: + root_index = i + if inp.size() != input_size: + got = 'x'.join(str(x) for x in inp.size()) + expected = 'x'.join(str(x) for x in input_size) + raise ValueError(f"input {i} has invalid size: got {got}, but expected {expected}") + if root_index is None: + raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors") + + if len(inputs) == 1: + return inputs[0] + + if nccl.is_available(inputs): + result = torch.empty_like(inputs[root_index]) + nccl.reduce(inputs, output=result, root=root_index) + else: + destination_device = torch.device(inputs[root_index].device.type, destination) + nonroot = [t for i, t in enumerate(inputs) if i != root_index] + # make a new tensor w/o clone + result = inputs[root_index] + nonroot[0].to(device=destination_device, non_blocking=True) + for other in nonroot[1:]: + result.add_(other.to(device=destination_device, non_blocking=True)) + return result + + +def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): + """Sum tensors from multiple GPUs. + + Small tensors are first coalesced into a buffer to reduce the number + of synchronizations. + + Args: + inputs (Iterable[Iterable[Tensor]]): iterable of iterables that + contain tensors from a single device. + destination (int, optional): a device on which the output will be + placed (default: current device). + buffer_size (int): maximum size of the buffer used for coalescing + + Returns: + A tuple of tensors containing an elementwise sum of each group of + inputs, placed on the ``destination`` device. + """ + # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just + # return `inputs`. + dense_tensors: List[List] = [[] for _ in inputs] # shape (num_gpus, num_tensors) + output = [] + ref_order = [] + # process sparse ones first since they may have different sizes on different gpus + for tensor_at_gpus in zip(*inputs): + if all(t.is_sparse for t in tensor_at_gpus): + result = reduce_add(tensor_at_gpus, destination) # this will be sparse too + output.append(result) + ref_order.append(tensor_at_gpus[0]) + else: + for coll, t in zip(dense_tensors, tensor_at_gpus): + coll.append(t.to_dense() if t.is_sparse else t) + ref_order.append(dense_tensors[0][-1]) + itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors] + # now the dense ones, which have consistent sizes + for chunks in zip(*itrs): + flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks] # (num_gpus,) + flat_result = reduce_add(flat_tensors, destination) + for t in _unflatten_dense_tensors(flat_result, chunks[0]): + # The unflattened tensors do not share storage, and we don't expose + # base flat tensor anyways, so give them different version counters. + # See NOTE [ Version Counter in comm.*_coalesced ] + output.append(t.data) + return tuple(_reorder_tensors_as(output, ref_order)) + + +def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None): + """Scatters tensor across multiple GPUs. + + Args: + tensor (Tensor): tensor to scatter. Can be on CPU or GPU. + devices (Iterable[torch.device, str or int], optional): an iterable of + GPU devices, among which to scatter. + chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on + each device. It should match :attr:`devices` in length and sums to + ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided + into equal chunks. + dim (int, optional): A dimension along which to chunk :attr:`tensor`. + Default: ``0``. + streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among + which to execute the scatter. If not specified, the default stream will + be utilized. + out (Sequence[Tensor], optional, keyword-only): the GPU tensors to + store output results. Sizes of these tensors must match that of + :attr:`tensor`, except for :attr:`dim`, where the total size must + sum to ``tensor.size(dim)``. + + .. note:: + Exactly one of :attr:`devices` and :attr:`out` must be specified. When + :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and + will be inferred from sizes of :attr:`out`. + + Returns: + - If :attr:`devices` is specified, + a tuple containing chunks of :attr:`tensor`, placed on + :attr:`devices`. + - If :attr:`out` is specified, + a tuple containing :attr:`out` tensors, each containing a chunk of + :attr:`tensor`. + """ + tensor = _handle_complex(tensor) + if out is None: + devices = [_get_device_index(d) for d in devices] + return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) + else: + if devices is not None: + raise RuntimeError( + f"'devices' must not be specified when 'out' is specified, but got devices={devices}") + if chunk_sizes is not None: + raise RuntimeError( + f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}") + return tuple(torch._C._scatter_out(tensor, out, dim, streams)) + + +def gather(tensors, dim=0, destination=None, *, out=None): + r"""Gathers tensors from multiple GPU devices. + + Args: + tensors (Iterable[Tensor]): an iterable of tensors to gather. + Tensor sizes in all dimensions other than :attr:`dim` have to match. + dim (int, optional): a dimension along which the tensors will be + concatenated. Default: ``0``. + destination (torch.device, str, or int, optional): the output device. + Can be CPU or CUDA. Default: the current CUDA device. + out (Tensor, optional, keyword-only): the tensor to store gather result. + Its sizes must match those of :attr:`tensors`, except for :attr:`dim`, + where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``. + Can be on CPU or CUDA. + + .. note:: + :attr:`destination` must not be specified when :attr:`out` is specified. + + Returns: + - If :attr:`destination` is specified, + a tensor located on :attr:`destination` device, that is a result of + concatenating :attr:`tensors` along :attr:`dim`. + - If :attr:`out` is specified, + the :attr:`out` tensor, now containing results of concatenating + :attr:`tensors` along :attr:`dim`. + """ + tensors = [_handle_complex(t) for t in tensors] + if out is None: + if destination == -1: + warnings.warn( + 'Using -1 to represent CPU tensor is deprecated. Please use a ' + 'device object or string instead, e.g., "cpu".') + destination = _get_device_index(destination, allow_cpu=True, optional=True) + return torch._C._gather(tensors, dim, destination) + else: + if destination is not None: + raise RuntimeError( + f"'destination' must not be specified when 'out' is specified, but got destination={destination}") + return torch._C._gather_out(tensors, out, dim) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..4471cee6f379fec8850e37163c11cde6838338fd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parallel/data_parallel.py @@ -0,0 +1,269 @@ +import operator +import torch +import warnings +from itertools import chain +from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union +from ..modules import Module +from .scatter_gather import scatter_kwargs, gather +from .replicate import replicate +from .parallel_apply import parallel_apply +from torch._utils import ( + _get_all_device_indices, + _get_available_device_type, + _get_device_index, + _get_devices_properties +) + +__all__ = ['DataParallel', 'data_parallel'] + +def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: + imbalance_warn = """ + There is an imbalance between your GPUs. You may want to exclude GPU {} which + has less than 75% of the memory or cores of GPU {}. You can do so by setting + the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES + environment variable.""" + device_ids = [_get_device_index(x, True) for x in device_ids] + dev_props = _get_devices_properties(device_ids) + + def warn_imbalance(get_prop): + values = [get_prop(props) for props in dev_props] + min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) + max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) + if min_val / max_val < 0.75: + warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) + return True + return False + + if warn_imbalance(lambda props: props.total_memory): + return + if warn_imbalance(lambda props: props.multi_processor_count): + return + + +T = TypeVar("T", bound=Module) + + +class DataParallel(Module, Generic[T]): + r"""Implements data parallelism at the module level. + + This container parallelizes the application of the given :attr:`module` by + splitting the input across the specified devices by chunking in the batch + dimension (other objects will be copied once per device). In the forward + pass, the module is replicated on each device, and each replica handles a + portion of the input. During the backwards pass, gradients from each replica + are summed into the original module. + + The batch size should be larger than the number of GPUs used. + + .. warning:: + It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`, + instead of this class, to do multi-GPU training, even if there is only a single + node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`. + + Arbitrary positional and keyword inputs are allowed to be passed into + DataParallel but some types are specially handled. tensors will be + **scattered** on dim specified (default 0). tuple, list and dict types will + be shallow copied. The other types will be shared among different threads + and can be corrupted if written to in the model's forward pass. + + The parallelized :attr:`module` must have its parameters and buffers on + ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel` + module. + + .. warning:: + In each forward, :attr:`module` is **replicated** on each device, so any + updates to the running module in ``forward`` will be lost. For example, + if :attr:`module` has a counter attribute that is incremented in each + ``forward``, it will always stay at the initial value because the update + is done on the replicas which are destroyed after ``forward``. However, + :class:`~torch.nn.DataParallel` guarantees that the replica on + ``device[0]`` will have its parameters and buffers sharing storage with + the base parallelized :attr:`module`. So **in-place** updates to the + parameters or buffers on ``device[0]`` will be recorded. E.g., + :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm` + rely on this behavior to update the buffers. + + .. warning:: + Forward and backward hooks defined on :attr:`module` and its submodules + will be invoked ``len(device_ids)`` times, each with inputs located on + a particular device. Particularly, the hooks are only guaranteed to be + executed in correct order with respect to operations on corresponding + devices. For example, it is not guaranteed that hooks set via + :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before + `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but + that each such hook be executed before the corresponding + :meth:`~torch.nn.Module.forward` call of that device. + + .. warning:: + When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in + :func:`forward`, this wrapper will return a vector of length equal to + number of devices used in data parallelism, containing the result from + each device. + + .. note:: + There is a subtlety in using the + ``pack sequence -> recurrent network -> unpack sequence`` pattern in a + :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. + See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for + details. + + + Args: + module (Module): module to be parallelized + device_ids (list of int or torch.device): CUDA devices (default: all devices) + output_device (int or torch.device): device location of output (default: device_ids[0]) + + Attributes: + module (Module): the module to be parallelized + + Example:: + + >>> # xdoctest: +SKIP + >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) + >>> output = net(input_var) # input_var can be on any device, including CPU + """ + + # TODO: update notes/cuda.rst when this class handles 8+ GPUs well + + def __init__( + self, + module: T, + device_ids: Optional[Sequence[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + ) -> None: + super().__init__() + torch._C._log_api_usage_once("torch.nn.parallel.DataParallel") + device_type = _get_available_device_type() + if device_type is None: + self.module = module + self.device_ids = [] + return + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + self.dim = dim + self.module = module + self.device_ids = [_get_device_index(x, True) for x in device_ids] + self.output_device = _get_device_index(output_device, True) + self.src_device_obj = torch.device(device_type, self.device_ids[0]) + + if device_type == "cuda": + _check_balance(self.device_ids) + + if len(self.device_ids) == 1: + self.module.to(self.src_device_obj) + + def forward(self, *inputs: Any, **kwargs: Any) -> Any: + with torch.autograd.profiler.record_function("DataParallel.forward"): + if not self.device_ids: + return self.module(*inputs, **kwargs) + + for t in chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError("module must have its parameters and buffers " + f"on device {self.src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}") + + inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids) + # for forward function without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + if len(self.device_ids) == 1: + return self.module(*inputs[0], **module_kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, module_kwargs) + return self.gather(outputs, self.output_device) + + def replicate(self, module: T, device_ids: Sequence[Union[int, torch.device]]) -> List[T]: + return replicate(module, device_ids, not torch.is_grad_enabled()) + + def scatter( + self, + inputs: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]], + device_ids: Sequence[Union[int, torch.device]], + ) -> Any: + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def parallel_apply(self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any) -> List[Any]: + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any: + return gather(outputs, output_device, dim=self.dim) + + +def data_parallel( + module: Module, + inputs: Any, + device_ids: Optional[Sequence[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + module_kwargs: Optional[Any] = None, +) -> torch.Tensor: + r"""Evaluate module(input) in parallel across the GPUs given in device_ids. + + This is the functional version of the DataParallel module. + + Args: + module (Module): the module to evaluate in parallel + inputs (Tensor): inputs to the module + device_ids (list of int or torch.device): GPU ids on which to replicate module + output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU. + (default: device_ids[0]) + Returns: + a Tensor containing the result of module(input) located on + output_device + """ + if not isinstance(inputs, tuple): + inputs = (inputs,) if inputs is not None else () + + device_type = _get_available_device_type() + + if device_type is None: + raise RuntimeError("device type could not be determined") + + if device_ids is None: + device_ids = _get_all_device_indices() + + if device_ids is None: + raise RuntimeError("no available devices were found") + + if output_device is None: + output_device = device_ids[0] + + device_ids = [_get_device_index(x, True) for x in device_ids] + output_device = _get_device_index(output_device, True) + src_device_obj = torch.device(device_type, device_ids[0]) + + for t in chain(module.parameters(), module.buffers()): + if t.device != src_device_obj: + raise RuntimeError("module must have its parameters and buffers " + f"on device {src_device_obj} (device_ids[0]) but found one of " + f"them on device: {t.device}") + + inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) + # for module without any inputs, empty list and dict will be created + # so the module can be executed on one device which is the first one in device_ids + if not inputs and not module_kwargs: + inputs = ((),) + module_kwargs = ({},) + + assert module_kwargs is not None + + if len(device_ids) == 1: + return module(*inputs[0], **module_kwargs[0]) + used_device_ids = device_ids[:len(inputs)] + replicas = replicate(module, used_device_ids) + outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) + return gather(outputs, output_device, dim) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.pyi b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.pyi new file mode 100644 index 0000000000000000000000000000000000000000..219bb6d4efa2a0bb1843ca3b38a4183a346a995c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/parameter.pyi @@ -0,0 +1,40 @@ +import builtins +from typing import Optional, Tuple + +import torch +from torch import Tensor + +class Parameter(Tensor): + def __init__( + self, + data: Tensor = ..., + requires_grad: builtins.bool = ..., + ): ... + +def is_lazy(param: Tensor): ... + +class UninitializedParameter(Tensor): + def __init__( + self, + data: Tensor = ..., + requires_grad: builtins.bool = ..., + ): ... + def materialize( + self, + shape: Tuple[int, ...], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): ... + +class UninitializedBuffer(Tensor): + def __init__( + self, + data: Tensor = ..., + requires_grad: builtins.bool = ..., + ): ... + def materialize( + self, + shape: Tuple[int, ...], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): ... diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ad7466bdaaaedd6ba455e0188caf170b3533064 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a15179a29e95784f36074284f365c87960e75667 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7de4e239e50ee263536bef64b3ef30a776a8ad --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__init__.py @@ -0,0 +1,32 @@ +# flake8: noqa: F401 +r"""Quantized Dynamic Modules. + +This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, +and is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/dynamic`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.dynamic.modules import conv +from torch.ao.nn.quantized.dynamic.modules import linear +from torch.ao.nn.quantized.dynamic.modules import rnn + +from torch.ao.nn.quantized.dynamic.modules.conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d +from torch.ao.nn.quantized.dynamic.modules.linear import Linear +from torch.ao.nn.quantized.dynamic.modules.rnn import LSTM, GRU, LSTMCell, RNNCell, GRUCell + +__all__ = [ + 'Linear', + 'LSTM', + 'GRU', + 'LSTMCell', + 'RNNCell', + 'GRUCell', + 'Conv1d', + 'Conv2d', + 'Conv3d', + 'ConvTranspose1d', + 'ConvTranspose2d', + 'ConvTranspose3d', +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290a81ead8e044223bd653426f159969790d1a5b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f28eaedfb2ed0e2bce26c3b8b0d90a63a41a24b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef87cc6de83065273a18a8083c3f489f5f4fe053 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/batchnorm.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87b76c499245ae50154b48983af330888451d788 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/__pycache__/normalization.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/activation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..43d7fbf19c38453198446ca1b99ab8570a9ef122 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/activation.py @@ -0,0 +1,18 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +from torch.ao.nn.quantized.modules.activation import ELU +from torch.ao.nn.quantized.modules.activation import Hardswish +from torch.ao.nn.quantized.modules.activation import LeakyReLU +from torch.ao.nn.quantized.modules.activation import MultiheadAttention +from torch.ao.nn.quantized.modules.activation import PReLU +from torch.ao.nn.quantized.modules.activation import ReLU6 +from torch.ao.nn.quantized.modules.activation import Sigmoid +from torch.ao.nn.quantized.modules.activation import Softmax diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/dropout.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..c42d68d595075045712d587f6218e52df810cc97 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/dropout.py @@ -0,0 +1,13 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +__all__ = ['Dropout'] + +from torch.ao.nn.quantized.modules.dropout import Dropout diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/embedding_ops.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..73c8d84c76c28584396b59ba1aad08da8ca6d686 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/quantized/modules/embedding_ops.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r"""Quantized Modules. + +This file is in the process of migration to `torch/ao/nn/quantized`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate file under the `torch/ao/nn/quantized/modules`, +while adding an import statement here. +""" + +__all__ = ['EmbeddingPackedParams', 'Embedding', 'EmbeddingBag'] + +from torch.ao.nn.quantized.modules.embedding_ops import Embedding +from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingBag +from torch.ao.nn.quantized.modules.embedding_ops import EmbeddingPackedParams