|
|
import builtins |
|
|
import collections |
|
|
import math |
|
|
import operator |
|
|
import warnings |
|
|
|
|
|
from collections.abc import Iterable |
|
|
from enum import Enum |
|
|
from functools import partial, reduce, singledispatch, wraps |
|
|
from typing import Callable, List, Optional, overload, Sequence, Tuple, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
import torch._prims as prims |
|
|
import torch._prims_common as utils |
|
|
from torch._prims_common import ( |
|
|
check, |
|
|
DeviceLikeType, |
|
|
DimsSequenceType, |
|
|
DimsType, |
|
|
dtype_to_type, |
|
|
ELEMENTWISE_TYPE_PROMOTION_KIND, |
|
|
is_weakly_lesser_type, |
|
|
Number, |
|
|
NumberType, |
|
|
REDUCTION_OUTPUT_TYPE_KIND, |
|
|
ShapeType, |
|
|
StrideType, |
|
|
TensorLike, |
|
|
TensorLikeType, |
|
|
TensorOrNumberLikeType, |
|
|
TensorSequenceType, |
|
|
) |
|
|
from torch._prims_common.wrappers import ( |
|
|
_maybe_convert_to_dtype, |
|
|
_maybe_resize_out, |
|
|
_safe_copy_out, |
|
|
elementwise_type_promotion_wrapper, |
|
|
elementwise_unary_scalar_wrapper, |
|
|
out_wrapper, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
|
|
|
|
|
|
|
|
"abs", |
|
|
"acos", |
|
|
"acosh", |
|
|
"asinh", |
|
|
"asin", |
|
|
"atan", |
|
|
"atanh", |
|
|
"bitwise_not", |
|
|
|
|
|
"ceil", |
|
|
"conj_physical", |
|
|
"cos", |
|
|
"cosh", |
|
|
"digamma", |
|
|
"erf", |
|
|
"erfinv", |
|
|
"erfc", |
|
|
"exp", |
|
|
"expm1", |
|
|
"exp2", |
|
|
"fill", |
|
|
"floor", |
|
|
"frac", |
|
|
"index_add", |
|
|
"index_add_", |
|
|
"index_copy", |
|
|
"index_copy_", |
|
|
"index_select", |
|
|
"index_fill", |
|
|
"index_fill_", |
|
|
"isfinite", |
|
|
"isinf", |
|
|
"isnan", |
|
|
"isreal", |
|
|
"i0", |
|
|
"lgamma", |
|
|
"log", |
|
|
"log1p", |
|
|
"log2", |
|
|
"log10", |
|
|
"nan_to_num", |
|
|
"neg", |
|
|
"positive", |
|
|
"reciprocal", |
|
|
"round", |
|
|
"sigmoid", |
|
|
"sgn", |
|
|
"sign", |
|
|
"signbit", |
|
|
"sin", |
|
|
"sinc", |
|
|
"sinh", |
|
|
"sqrt", |
|
|
"square", |
|
|
"tan", |
|
|
"tanh", |
|
|
"trace", |
|
|
"trunc", |
|
|
|
|
|
|
|
|
|
|
|
"add", |
|
|
"atan2", |
|
|
"bitwise_and", |
|
|
"bitwise_left_shift", |
|
|
"bitwise_or", |
|
|
"bitwise_right_shift", |
|
|
"bitwise_xor", |
|
|
|
|
|
"copysign", |
|
|
"div", |
|
|
"eq", |
|
|
"float_power", |
|
|
"floor_divide", |
|
|
"fmax", |
|
|
"fmin", |
|
|
"fmod", |
|
|
"gcd", |
|
|
"ge", |
|
|
"gt", |
|
|
"heaviside", |
|
|
"hypot", |
|
|
"igamma", |
|
|
"igammac", |
|
|
"imag", |
|
|
"isclose", |
|
|
"lcm", |
|
|
|
|
|
"le", |
|
|
"logical_and", |
|
|
"logical_not", |
|
|
"logical_or", |
|
|
"logical_xor", |
|
|
"lt", |
|
|
|
|
|
"maximum", |
|
|
|
|
|
"minimum", |
|
|
"mul", |
|
|
"ne", |
|
|
"nextafter", |
|
|
|
|
|
"pow", |
|
|
"real", |
|
|
"rpow", |
|
|
"remainder", |
|
|
"rsub", |
|
|
"rtruediv", |
|
|
"rfloordiv", |
|
|
|
|
|
|
|
|
"sub", |
|
|
"true_divide", |
|
|
"trunc_divide", |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"addcdiv", |
|
|
"clamp", |
|
|
|
|
|
|
|
|
|
|
|
"masked_fill", |
|
|
"where", |
|
|
|
|
|
|
|
|
|
|
|
"clone", |
|
|
"copy_to", |
|
|
"item", |
|
|
"to", |
|
|
|
|
|
|
|
|
|
|
|
"all", |
|
|
"amax", |
|
|
"amin", |
|
|
"any", |
|
|
"mean", |
|
|
"std_mean", |
|
|
"var_mean", |
|
|
"sum", |
|
|
"sum_to_size", |
|
|
"prod", |
|
|
"var", |
|
|
|
|
|
|
|
|
|
|
|
"addr", |
|
|
|
|
|
|
|
|
|
|
|
"atleast_1d", |
|
|
"atleast_2d", |
|
|
"atleast_3d", |
|
|
"as_strided", |
|
|
"broadcast_shapes", |
|
|
"broadcast_tensors", |
|
|
"broadcast_to", |
|
|
"cat", |
|
|
"chunk", |
|
|
"column_stack", |
|
|
"conj", |
|
|
"constant_pad_nd", |
|
|
"contiguous", |
|
|
"diag_embed", |
|
|
"diagonal", |
|
|
"dsplit", |
|
|
"dstack", |
|
|
"expand", |
|
|
"expand_as", |
|
|
"flatten", |
|
|
"flip", |
|
|
"fliplr", |
|
|
"flipud", |
|
|
"hsplit", |
|
|
"hstack", |
|
|
"meshgrid", |
|
|
"movedim", |
|
|
"narrow", |
|
|
"native_layer_norm", |
|
|
"permute", |
|
|
"ravel", |
|
|
"repeat", |
|
|
"reshape", |
|
|
"roll", |
|
|
"rot90", |
|
|
"rsqrt", |
|
|
"stack", |
|
|
"swap_axes", |
|
|
"squeeze", |
|
|
"t", |
|
|
"tensor_split", |
|
|
"transpose", |
|
|
"unfold_copy", |
|
|
"unsqueeze", |
|
|
"view", |
|
|
"vsplit", |
|
|
"vstack", |
|
|
"unflatten", |
|
|
"unbind", |
|
|
"triu", |
|
|
"tril", |
|
|
"triu_indices", |
|
|
"tril_indices", |
|
|
|
|
|
|
|
|
|
|
|
"arange", |
|
|
"empty", |
|
|
"empty_like", |
|
|
"empty_strided", |
|
|
"eye", |
|
|
"full", |
|
|
"full_like", |
|
|
"linspace", |
|
|
"logspace", |
|
|
"ones", |
|
|
"ones_like", |
|
|
"randn", |
|
|
"scalar_tensor", |
|
|
"zeros", |
|
|
"zeros_like", |
|
|
|
|
|
|
|
|
|
|
|
"uniform", |
|
|
|
|
|
|
|
|
|
|
|
"allclose", |
|
|
"equal", |
|
|
] |
|
|
|
|
|
# Local aliases for frequently used torch types.
Tensor = torch.Tensor

DispatchKey = torch._C.DispatchKey
|
|
|
|
|
|
|
|
def _broadcast_shapes(*_shapes): |
|
|
shapes = tuple( |
|
|
(x,) if isinstance(x, int) else x |
|
|
for x in filter(lambda x: x is not None, _shapes) |
|
|
) |
|
|
|
|
|
|
|
|
if len(shapes) == 0: |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
for shape in shapes: |
|
|
assert isinstance(shape, Sequence) |
|
|
|
|
|
|
|
|
common_shape = [ |
|
|
1, |
|
|
] * reduce(max, (len(shape) for shape in shapes)) |
|
|
for shape in shapes: |
|
|
for idx in range(-1, -1 - len(shape), -1): |
|
|
if common_shape[idx] == 1: |
|
|
if shape[idx] < 0: |
|
|
raise ValueError( |
|
|
"Attempting to broadcast a dimension with negative length!" |
|
|
) |
|
|
common_shape[idx] = shape[idx] |
|
|
elif shape[idx] != 1: |
|
|
if common_shape[idx] != shape[idx]: |
|
|
raise RuntimeError( |
|
|
"Attempting to broadcast a dimension of length ", |
|
|
str(shape[idx]), |
|
|
"!", |
|
|
) |
|
|
|
|
|
return common_shape |
|
|
|
|
|
|
|
|
def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    """Broadcasts every tensor argument to the arguments' common shape.

    Numbers and None entries pass through unchanged. CPU scalar tensors are
    also left alone when preserve_cpu_scalar_tensors is True — presumably so
    they keep behaving like scalars downstream (TODO confirm with callers).
    """
    # Derive the target shape from the tensor arguments only.
    common_shape = _broadcast_shapes(
        *map(lambda t: t.shape if isinstance(t, TensorLike) else None, args)
    )

    def _broadcast_single(x, shape):
        if x is None:
            return None
        if isinstance(x, Number):
            return x
        if not isinstance(x, TensorLike):
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )
        if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
            return x
        if utils.same_shape(x.shape, common_shape):
            return x
        return x.expand(common_shape)

    return tuple(_broadcast_single(x, common_shape) for x in args)
|
|
|
|
|
|
|
|
|
|
|
from torch._decomp import register_decomposition |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sentinel default for aten_op arguments: "infer the ATen op from the prim's
# name". Distinct from None, which means "do not register a decomposition".
infer_aten_op = object()
|
|
|
|
|
|
|
|
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    disable_meta=False,
    extra_meta=None,
) -> Callable:
    """Decorator factory for elementwise unary reference implementations.

    The returned decorator wraps a prim-calling function with out= support,
    scalar input handling, and elementwise type promotion, and (unless
    aten_op is None) registers the wrapped function as a decomposition for
    the corresponding ATen operator.
    """

    def inner(prim: Callable):
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper()
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if not isinstance(a, TensorLike):
                raise RuntimeError(
                    "Expected a tensor input for an elementwise unary operation!"
                )

            # Optional extra validation hook (e.g. dtype checks) run before
            # calling the prim.
            if extra_meta is not None:
                extra_meta(a)

            return prim(a)

        # infer_aten_op sentinel: look up torch.ops.aten.<prim name>; an
        # explicit None opts out of decomposition registration entirely.
        if aten_op is infer_aten_op:
            aten_op = getattr(torch.ops.aten, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op, disable_meta=disable_meta)(_ref)

        return _ref

    return inner
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
def abs(a):
    """Elementwise absolute value; complex inputs yield float results."""
    return prims.abs(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    """Elementwise arccosine; delegates to prims.acos."""
    return prims.acos(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    """Elementwise inverse hyperbolic cosine; delegates to prims.acosh."""
    return prims.acosh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    """Elementwise arcsine; delegates to prims.asin."""
    return prims.asin(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asinh(a):
    """Elementwise inverse hyperbolic sine; delegates to prims.asinh."""
    return prims.asinh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    """Elementwise arctangent; delegates to prims.atan."""
    return prims.atan(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atanh(a):
    """Elementwise inverse hyperbolic tangent; delegates to prims.atanh."""
    return prims.atanh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    """Elementwise bitwise NOT; delegates to prims.bitwise_not."""
    return prims.bitwise_not(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def ceil(a):
    """Elementwise ceiling; delegates to prims.ceil."""
    return prims.ceil(a)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    """Reference for torch.conj_physical.

    Non-complex tensors are returned unchanged (conjugation is the identity
    for them); complex tensors delegate to prims.conj_physical.
    """
    if not utils.is_complex_dtype(input.dtype):
        return input
    return prims.conj_physical(input)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    """Elementwise cosine; delegates to prims.cos."""
    return prims.cos(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    """Elementwise hyperbolic cosine; delegates to prims.cosh."""
    return prims.cosh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    """Elementwise digamma function; delegates to prims.digamma."""
    return prims.digamma(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    """Elementwise error function; delegates to prims.erf."""
    return prims.erf(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    """Elementwise inverse error function; note the prim is named erf_inv."""
    return prims.erf_inv(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    """Elementwise complementary error function; delegates to prims.erfc."""
    return prims.erfc(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    """Elementwise natural exponential; delegates to prims.exp."""
    return prims.exp(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    """Elementwise exp(a) - 1; delegates to prims.expm1."""
    return prims.expm1(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    """Elementwise base-2 exponential; delegates to prims.exp2."""
    return prims.exp2(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
@elementwise_type_promotion_wrapper(
    # Bug fix: this was ("a,") — the single string "a," in parentheses —
    # rather than the intended one-element tuple ("a",).
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """Returns a tensor like `a` with every element set to `value`.

    Raises ValueError when `value`'s Python type cannot be safely cast to
    `a`'s dtype (e.g. a float value for an integer tensor).
    """
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = "value argument of type {0} cannot be safely cast to type {1}!".format(
            type(value), python_type
        )
        raise ValueError(msg)

    return prims.fill(a, value)
|
|
|
|
|
|
|
|
def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """In-place fill: overwrites every element of `a` with `value`.

    Returns `a` itself, following the in-place op convention.
    """
    filled = prims.fill(a, value)
    prims.copy_to(a, filled)
    return a
|
|
|
|
|
|
|
|
def zero_(a: TensorLikeType) -> TensorLikeType:
    """In-place zeroing of `a` via fill + copy_to; returns `a`."""
    zeroed = prims.fill(a, 0)
    prims.copy_to(a, zeroed)
    return a
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def floor(a):
    """Elementwise floor; delegates to prims.floor."""
    return prims.floor(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def frac(x: TensorLikeType) -> TensorLikeType:
    """Fractional part of x, computed as x - trunc(x)."""
    # floor(|x|) * sign(x) is truncation toward zero.
    trunc_x = mul(floor(abs(x)), sign(x))
    return sub(x, trunc_x)
|
|
|
|
|
|
|
|
|
|
|
def imag(a: TensorLikeType) -> TensorLikeType:
    """Reference for torch.imag: the imaginary component of a complex tensor.

    Unlike real(), non-complex inputs are rejected rather than passed through.
    """
    assert isinstance(a, TensorLike)
    utils.check(
        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
    )
    return prims.imag(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # no ATen decomposition is registered for this ref
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    """torch.isfinite reference: True where `a` is neither inf nor NaN."""
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        return prims.isfinite(a)

    # Integer and boolean values are always finite.
    return ones_like(a, dtype=torch.bool)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    """torch.isinf reference: True where `a` is positive or negative infinity."""
    if utils.is_complex_dtype(a.dtype):
        # A complex value is infinite when either component is infinite.
        return logical_or(isinf(real(a)), isinf(imag(a)))
    # Not NaN and not finite <=> infinite.
    return logical_not(logical_or(isnan(a), isfinite(a)))
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    """torch.isposinf reference: True where `a` == +inf; complex is rejected."""
    utils.check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return eq(a, float("inf"))
    # Integer/bool tensors cannot hold infinity.
    return zeros_like(a, dtype=torch.bool)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    """torch.isneginf reference: True where `a` == -inf; complex is rejected."""
    utils.check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return eq(a, float("-inf"))
    # Integer/bool tensors cannot hold infinity.
    return zeros_like(a, dtype=torch.bool)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    """torch.isnan reference: NaN is the only value that does not equal itself."""
    return prims.ne(a, a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # no ATen decomposition is registered for this ref
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    """torch.isreal reference: True where the imaginary component is zero."""
    if utils.is_complex_dtype(a.dtype):
        return torch.imag(a) == 0
    # Non-complex values are real by definition.
    return torch.ones_like(a, dtype=torch.bool)
|
|
|
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0
)
def i0(a):
    """torch.i0 reference; delegates to prims.bessel_i0 and registers against
    the aten.special_i0 op (the prim name differs from the ATen name)."""
    return prims.bessel_i0(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    """Elementwise log-gamma; delegates to prims.lgamma."""
    return prims.lgamma(a)
|
|
|
|
|
|
|
|
|
|
|
# mvlgamma is aliased directly to torch.special.multigammaln.
mvlgamma = torch.special.multigammaln
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    """Elementwise natural logarithm; delegates to prims.log."""
    return prims.log(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    """Elementwise log(1 + a); delegates to prims.log1p."""
    return prims.log1p(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    """Elementwise base-2 logarithm; delegates to prims.log2."""
    return prims.log2(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    """Elementwise base-10 logarithm; delegates to prims.log10."""
    return prims.log10(a)
|
|
|
|
|
|
|
|
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference for log_softmax along `dim`.

    Computed as a - logsumexp(a, dim) in a higher-precision computation dtype,
    then cast to `dtype` (or back to the input dtype when `dtype` is None).
    """
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)
|
|
|
|
|
|
|
|
@out_wrapper()
def logsumexp(
    a: TensorLikeType,
    dim: DimsType,
    keepdim: bool = False,
) -> TensorLikeType:
    """Reference for logsumexp: log(sum(exp(a))) over `dim`.

    Uses the max-subtraction trick for floating/complex inputs so that large
    magnitudes do not overflow exp().
    """
    dim = utils.canonicalize_dims(a.ndim, dim)

    # Normalize a single dim to a 1-tuple.
    if not isinstance(dim, Iterable):
        dim = (dim,)
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        # Subtract the rowwise max before exponentiating; infinite maxima are
        # replaced with 0 so that inf - inf does not produce NaN.
        a_max = amax(a, dim, keepdim=True)
        a_max = where(abs(a_max) == float("inf"), 0.0, a_max)
        a_max_squeezed = prims.squeeze(a_max, dim) if not keepdim else a_max
        result = log(sum(exp(a - a_max), dim, keepdim=keepdim)) + a_max_squeezed
    else:
        # Non-floating inputs cannot overflow, so no stabilization is needed.
        result = log(sum(exp(a), dim, keepdim=keepdim))
    return result
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    """Reference for torch.nan_to_num: replaces NaN with `nan`, +inf with
    `posinf`, and -inf with `neginf` (dtype extremes when left as None)."""
    assert isinstance(a, TensorLike)

    # Integer/bool tensors cannot hold NaN or inf; just return a copy.
    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return clone(a)

    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = prims.maximum_value(a.dtype)

    if neginf is None:
        neginf = prims.minimum_value(a.dtype)

    result = where(isnan(a), nan, a)

    # signbit (rather than a < 0) distinguishes the sign of infinities.
    is_neg = signbit(a)
    is_neginf = bitwise_and(isinf(a), is_neg)
    result = where(is_neginf, neginf, result)

    is_posinf = bitwise_and(isinf(a), bitwise_not(is_neg))
    result = where(is_posinf, posinf, result)
    return result
|
|
|
|
|
|
|
|
def _neg_meta(a: TensorLikeType): |
|
|
if a.dtype is torch.bool: |
|
|
msg = "neg is not supported on bool tensors." |
|
|
raise RuntimeError(msg) |
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    """Elementwise negation; _neg_meta rejects bool tensors before the prim runs."""
    return prims.neg(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def positive(a: TensorLikeType) -> TensorLikeType:
    """Reference for torch.positive: returns the input tensor unchanged.

    Boolean tensors are rejected.
    """
    assert isinstance(a, TensorLike)
    if a.dtype is torch.bool:
        raise RuntimeError("positive does not support bool tensors.")
    return a
|
|
|
|
|
|
|
|
|
|
|
def real(a: TensorLikeType) -> TensorLikeType:
    """Reference for torch.real: the real component of a complex tensor.

    Non-complex tensors are returned unchanged.
    """
    assert isinstance(a, TensorLike)
    if not utils.is_complex_dtype(a.dtype):
        return a
    return prims.real(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    """Elementwise 1/a; delegates to prims.reciprocal."""
    return prims.reciprocal(a)
|
|
|
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # no ATen decomposition is registered for this ref
)
def round(a):
    """Elementwise rounding via prims.round (intentionally shadows builtins.round)."""
    return prims.round(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
    """Elementwise reciprocal square root; delegates to prims.rsqrt."""
    return prims.rsqrt(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    """Logistic sigmoid computed as 1 / (1 + exp(-a))."""
    return true_divide(1, add(1, exp(neg(a))))
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sgn(a):
    """torch.sgn reference: sign() for real tensors, a / |a| for complex ones."""
    if utils.is_complex_dtype(a.dtype):
        a_abs = a.abs()
        # Guard against division by zero: elements with |a| == 0 map to 0.
        return torch.where(a_abs == 0, 0, a / a_abs)
    else:
        return a.sign()
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def sign(a):
    """Elementwise sign; delegates to prims.sign."""
    return prims.sign(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def signbit(a):
    """Elementwise sign-bit test (bool result); delegates to prims.signbit."""
    return prims.signbit(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    """Elementwise sine; delegates to prims.sin."""
    return prims.sin(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinc(a):
    """Normalized sinc: sin(pi*a) / (pi*a), with the value 1 where a == 0."""
    a = math.pi * a
    return torch.where(a == 0, 1, torch.sin(a) / a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    """Elementwise hyperbolic sine; delegates to prims.sinh."""
    return prims.sinh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    """Elementwise square root; delegates to prims.sqrt."""
    return prims.sqrt(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # no ATen decomposition is registered for this ref
)
def square(a: TensorLikeType) -> TensorLikeType:
    """Elementwise square, computed as a * a; bool inputs promote to long."""
    return mul(a, a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    """Elementwise tangent; delegates to prims.tan."""
    return prims.tan(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    """Elementwise hyperbolic tangent; delegates to prims.tanh."""
    return prims.tanh(a)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
    """Elementwise truncation toward zero; delegates to prims.trunc."""
    return prims.trunc(a)
|
|
|
|
|
|
|
|
def _make_elementwise_binary_reference(
    prim: Callable,
    *,
    type_promotion_kind,
    aten_op=infer_aten_op,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    disable_meta=False,
) -> Callable:
    """Builds an elementwise binary reference around `prim`.

    The wrapper validates Python-scalar operands, applies elementwise type
    promotion and broadcasting, optionally adds out= support, and (unless
    aten_op is None) registers the result as a decomposition for the
    corresponding ATen operator.
    """

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a", "b"),
        type_promotion_kind=type_promotion_kind,
    )
    def _ref(
        a: Union[Tensor, NumberType],
        b: Union[Tensor, NumberType],
    ) -> Tensor:
        if not supports_lhs_python_scalar and isinstance(a, Number):
            raise ValueError(
                "Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!"
            )

        if not supports_rhs_python_scalar and isinstance(b, Number):
            raise ValueError(
                "Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!"
            )

        # At least one operand must be a tensor.
        if isinstance(a, Number) and isinstance(b, Number):
            raise ValueError(
                "Receive two Number inputs to an elementwise binary operation!"
            )

        a, b = _maybe_broadcast(a, b)
        return prim(a, b)

    if has_out:
        _ref = out_wrapper()(_ref)

    # infer_aten_op sentinel: derive the ATen op from the prim's name; an
    # explicit None opts out of decomposition registration entirely.
    if aten_op is infer_aten_op:
        aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0])
    if aten_op is not None:
        register_decomposition(aten_op, disable_meta=disable_meta)(_ref)

    return _ref
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add

    Computes a + alpha * b; at least one of a/b must be a tensor, and alpha
    must be safely castable to the operands' computation type.
    """

    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    a, b = _maybe_broadcast(a, b)

    # alpha scales b before the addition; e.g. a float alpha is rejected when
    # adding integer tensors because the cast would be lossy.
    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = (
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), python_type
                )
            )
            raise ValueError(msg)
        b = prims.mul(b, alpha)

    return prims.add(a, b)
|
|
|
|
|
|
|
|
|
|
|
# Elementwise binary references constructed directly from prims.

# Quadrant-aware arctangent; integer inputs promote to float; tensor-only.
atan2 = _make_elementwise_binary_reference(
    prims.atan2,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

bitwise_and = _make_elementwise_binary_reference(
    prims.bitwise_and,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

# The prim name (shift_left) differs from the ATen op, so aten_op is explicit.
bitwise_left_shift = _make_elementwise_binary_reference(
    prims.shift_left,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.bitwise_left_shift,
)

bitwise_or = _make_elementwise_binary_reference(
    prims.bitwise_or,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

# Arithmetic right shift prim; aten op name differs, so it is explicit.
bitwise_right_shift = _make_elementwise_binary_reference(
    prims.shift_right_arithmetic,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.bitwise_right_shift,
)

bitwise_xor = _make_elementwise_binary_reference(
    prims.bitwise_xor,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
|
|
|
|
|
|
|
|
def _copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    """Helper for copysign: the magnitude of `a` with the sign of `b`.

    A scalar `b` is materialized as a tensor on `a`'s device; two tensors on
    different devices are rejected.
    """
    if isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
            a.device, b.device
        )
        raise RuntimeError(msg)
    # signbit (rather than b < 0) distinguishes -0.0 from +0.0.
    return where(signbit(b), neg(abs(a)), abs(a))
|
|
|
|
|
|
|
|
|
|
|
# copysign wraps _copysign; lhs scalars are rejected, and the aten op is
# given explicitly because the helper's name does not match it.
copysign = _make_elementwise_binary_reference(
    _copysign,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    aten_op=torch.ops.aten.copysign,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div

    rounding_mode selects the variant: None -> true division, "trunc" ->
    truncating division, "floor" -> flooring division.
    """
    if rounding_mode is None:
        return true_divide(a, b)
    elif rounding_mode == "trunc":
        return trunc_divide(a, b)
    elif rounding_mode == "floor":
        return floor_divide(a, b)
    else:
        msg = (
            "div expected rounding_mode to be one of None, 'trunc', or 'floor' "
            "but found {0}.".format(rounding_mode)
        )
        raise ValueError(msg)
|
|
|
|
|
|
|
|
|
|
|
# Elementwise equality; comparison refs always return bool and reject lhs
# Python scalars.
eq = _make_elementwise_binary_reference(
    prims.eq,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
|
|
|
|
|
|
|
|
def _pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    """Helper for pow with fast paths for the scalar exponents 1, 2, and 0.5."""
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        if b == 1.0:
            # x ** 1 is x, but clone to avoid aliasing the input.
            return a.clone()
        elif b == 2.0:
            return a * a
        elif b == 0.5:
            return torch.sqrt(a)
    return prims.pow(a, b)
|
|
|
|
|
|
|
|
|
|
|
# pow wraps _pow (which has fast paths for exponents 1, 2, and 0.5); bool
# inputs promote to long. Intentionally shadows builtins.pow.
pow = _make_elementwise_binary_reference(
    _pow,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=torch.ops.aten.pow,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    """Reference for torch.float_power: pow computed in double precision.

    Both operands are cast to complex128 when either is complex, otherwise to
    float64, before broadcasting and calling pow.
    """
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Choose the double-precision computation dtype.
    dtype = utils.get_higher_dtype(a, b)
    assert dtype is not None
    if utils.is_complex_dtype(dtype):
        dtype = torch.complex128
    else:
        dtype = torch.float64

    # Cast tensor operands up front so pow runs entirely in that dtype.
    if isinstance(a, TensorLike) and a.dtype != dtype:
        a = prims.to_dtype(a, dtype)
    if isinstance(b, TensorLike) and b.dtype != dtype:
        b = prims.to_dtype(b, dtype)

    a, b = _maybe_broadcast(a, b)
    return pow(a, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    """Helper for floor_divide: division rounded toward negative infinity.

    Scalar operands are materialized as tensors; tensors on different devices
    are reconciled by moving the divisor (or rejected when the dividend is on
    CPU). Dispatches to the float or integer implementation by dtype.
    """
    # Wrap scalar operands into tensors.
    if isinstance(a, Number) and isinstance(b, Number):
        a = scalar_tensor(a)
        b = scalar_tensor(b)
    elif isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Number) and isinstance(b, Tensor):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        if a.device == torch.device("cpu"):
            msg = "Expected divisor (b) to be on the same device ({0}) as dividend (a), but it is found on {1}!".format(
                a.device, b.device
            )
            raise RuntimeError(msg)
        else:
            b = prims.device_put(b, device=a.device)

    assert isinstance(a, Tensor) and isinstance(b, Tensor)
    dtype = a.dtype
    if utils.is_float_dtype(dtype):
        return _floor_divide_float(a, b)
    elif utils.is_integer_dtype(dtype):
        return _floor_divide_integer(a, b)
    else:
        # check(False, ...) always raises for unsupported dtypes.
        check(False, lambda: f"{dtype} not supported for floor_divide")
|
|
|
|
|
|
|
|
def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
    """Integer floor division: truncating prims.div corrected toward -inf."""
    a, b = _maybe_broadcast(a, b)

    if not a.dtype.is_signed:
        # Unsigned quotients cannot be negative, so truncation already floors.
        return prims.div(a, b)

    # Convert truncation to flooring: subtract 1 from the quotient when the
    # operands' signs differ and the division leaves a non-zero remainder.
    offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
    return prims.div(a, b) - prims.convert_element_type(offset, a.dtype)
|
|
|
|
|
|
|
|
def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
    """Float flooring division reconstructed from fmod, with sign and
    rounding corrections applied to the quotient."""
    mod = fmod(a, b)
    div = true_divide(sub(a, mod), b)

    # Subtract 1 when the operands' signs differ and the remainder is
    # non-zero (truncation vs flooring correction).
    different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
    non_zero_remainder = ne(mod, 0)
    mask = bitwise_and(non_zero_remainder, different_signed_inputs)
    div = where(mask, sub(div, 1), div)

    # Snap quotients that floated just below an integer back up to it.
    floor_div = floor(div)
    mask = gt(sub(div, floor_div), 0.5)
    floor_div = where(mask, add(floor_div, 1), floor_div)

    basic_div = true_divide(a, b)
    zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)

    # Zero quotients take their sign bit from the true-division result.
    floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))

    # Division by zero falls back to the true-division result.
    return where(ne(b, 0), floor_div, basic_div)
|
|
|
|
|
|
|
|
|
|
|
# floor_divide wraps _floor_divide, which dispatches on float vs integer dtype.
floor_divide = _make_elementwise_binary_reference(
    _floor_divide,
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.floor_divide,
)
|
|
|
|
|
|
|
|
|
|
|
# fmax/fmin delegate to the corresponding prims; Python scalars are rejected
# on both sides.
fmax = _make_elementwise_binary_reference(
    prims.fmax,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmax,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

fmin = _make_elementwise_binary_reference(
    prims.fmin,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmin,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

# fmod permits a scalar divisor but not a scalar dividend.
fmod = _make_elementwise_binary_reference(
    prims.fmod,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.fmod,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=True,
)
|
|
|
|
|
|
|
|
# Greatest common divisor; tensor-only operands.
gcd = _make_elementwise_binary_reference(
    prims.gcd,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.gcd,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

# Comparisons return bool and reject lhs Python scalars.
ge = _make_elementwise_binary_reference(
    prims.ge,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)

gt = _make_elementwise_binary_reference(
    prims.gt,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
|
|
|
|
|
|
|
|
def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
    """Elementwise Heaviside step function.

    Produces 0 where ``input`` is negative (or NaN), 1 where it is
    positive, and the corresponding element of ``values`` where it is
    exactly zero.
    """
    is_zero = eq(input, 0)
    # NaN inputs are grouped with the negative branch, so they map to 0.
    neg_or_nan = logical_or(lt(input, 0), isnan(input))
    step = where(neg_or_nan, 0, 1)
    return where(is_zero, values, step)
|
|
|
|
|
|
|
|
# heaviside: decomposed through _heaviside above; NO_OPMATH keeps the
# computation in the promoted dtype.
heaviside = _make_elementwise_binary_reference(
    _heaviside,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
    aten_op=torch.ops.aten.heaviside,
)

# hypot: no aten_op kwarg, so no decomposition is registered here.
hypot = _make_elementwise_binary_reference(
    prims.hypot,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

# igamma / igammac: integer inputs are promoted to float.
igamma = _make_elementwise_binary_reference(
    prims.igamma,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

igammac = _make_elementwise_binary_reference(
    prims.igammac,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
|
|
|
|
|
|
|
|
def _check_close_args( |
|
|
name: str, |
|
|
a: TensorLikeType, |
|
|
b: TensorLikeType, |
|
|
rtol: float, |
|
|
atol: float, |
|
|
) -> None: |
|
|
check( |
|
|
a.dtype == b.dtype, |
|
|
lambda: "{0}: Attempting to compare tensors of different dtypes {1} and {2}!".format( |
|
|
name, a.dtype, b.dtype |
|
|
), |
|
|
ValueError, |
|
|
) |
|
|
check( |
|
|
rtol >= 0, |
|
|
lambda: "{0}: rtol must be greater than or equal to zero, but got {1}!".format( |
|
|
name, rtol |
|
|
), |
|
|
) |
|
|
check( |
|
|
atol >= 0, |
|
|
lambda: "{0}: atol must be greater than or equal to zero, but got {1}!".format( |
|
|
name, atol |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.isclose.

    Elements are close when they compare equal, or when their absolute
    difference is finite and within atol + rtol * |b|.
    """
    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)

    exact_match = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        # Treat matching NaNs as equal when requested.
        exact_match = logical_or(exact_match, logical_and(isnan(a), isnan(b)))

    # With zero tolerances only exact matches are close.
    if atol == 0 and rtol == 0:
        return exact_match

    # Tolerance math requires a floating/complex computation type.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    tolerance = add(atol, abs(mul(b, rtol)))
    deviation = abs(sub(a, b))

    # A non-finite deviation (inf/NaN) never counts as within tolerance.
    return logical_or(
        exact_match, logical_and(isfinite(deviation), le(deviation, tolerance))
    )
|
|
|
|
|
|
|
|
def _lcm(a: TensorLikeType, b: TensorLikeType): |
|
|
dtype = a.dtype |
|
|
promote_to_int = dtype in (torch.int8, torch.int16) |
|
|
if promote_to_int: |
|
|
a = prims.convert_element_type(a, torch.int32) |
|
|
b = prims.convert_element_type(b, torch.int32) |
|
|
|
|
|
g = torch.gcd(a, b) |
|
|
|
|
|
g = torch.where(g == 0, 1, g) |
|
|
res = torch.abs(prims.div(a, g) * b) |
|
|
return res if not promote_to_int else prims.convert_element_type(res, dtype) |
|
|
|
|
|
|
|
|
|
|
|
# lcm: decomposed through _lcm above.
lcm = _make_elementwise_binary_reference(
    _lcm,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.lcm,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

# le: comparison, always returns bool; lhs Python scalars disallowed.
le = _make_elementwise_binary_reference(
    prims.le,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
|
|
|
|
|
|
|
|
def _logical_and(a: TensorLikeType, b: TensorLikeType): |
|
|
if not utils.is_boolean_dtype(a.dtype): |
|
|
a = a != 0 |
|
|
if not utils.is_boolean_dtype(b.dtype): |
|
|
b = b != 0 |
|
|
return a & b |
|
|
|
|
|
|
|
|
# logical_and: always returns bool; decomposed through _logical_and above.
logical_and = _make_elementwise_binary_reference(
    _logical_and,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_and,
)
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not
)
def logical_not(a: TensorLikeType):
    """Boolean NOT; non-bool inputs map nonzero -> False and zero -> True."""
    if utils.is_boolean_dtype(a.dtype):
        return ~a
    return a == 0
|
|
|
|
|
|
|
|
def _logical_or(a: TensorLikeType, b: TensorLikeType):
    """Boolean OR; non-bool operands are interpreted as (x != 0)."""
    a_bool = a if utils.is_boolean_dtype(a.dtype) else a != 0
    b_bool = b if utils.is_boolean_dtype(b.dtype) else b != 0
    # Route through the bitwise_or reference (not the `|` operator) so the
    # decomposition stays within this module's ops.
    return bitwise_or(a_bool, b_bool)
|
|
|
|
|
|
|
|
# logical_or: always returns bool; decomposed through _logical_or above.
logical_or = _make_elementwise_binary_reference(
    _logical_or,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_or,
)
|
|
|
|
|
|
|
|
def _logical_xor(a: TensorLikeType, b: TensorLikeType): |
|
|
if not utils.is_boolean_dtype(a.dtype): |
|
|
a = a != 0 |
|
|
if not utils.is_boolean_dtype(b.dtype): |
|
|
b = b != 0 |
|
|
return a ^ b |
|
|
|
|
|
|
|
|
|
|
|
# logical_xor: always returns bool; decomposed through _logical_xor above.
logical_xor = _make_elementwise_binary_reference(
    _logical_xor,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=torch.ops.aten.logical_xor,
)

# lt: comparison, always returns bool; lhs Python scalars disallowed.
lt = _make_elementwise_binary_reference(
    prims.lt,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)

# maximum / minimum / mul: straight pass-throughs to the prims with
# default type promotion.
maximum = _make_elementwise_binary_reference(
    prims.maximum,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

minimum = _make_elementwise_binary_reference(
    prims.minimum,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

mul = _make_elementwise_binary_reference(
    prims.mul,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)

# ne: comparison, always returns bool; lhs Python scalars disallowed.
ne = _make_elementwise_binary_reference(
    prims.ne,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)

# nextafter: NO_OPMATH keeps the computation in the promoted dtype.
nextafter = _make_elementwise_binary_reference(
    prims.nextafter,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)

# remainder: Python-style modulus, registered for aten.remainder.
remainder = _make_elementwise_binary_reference(
    prims.remainder,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=torch.ops.aten.remainder,
)
|
|
|
|
|
|
|
|
def rsub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """Reference implementation of torch.rsub: computes b - alpha * a."""
    if isinstance(a, Number):
        raise ValueError(
            "Received a Number for the first argument, but expected a Tensor"
        )
    # Reversed subtraction simply swaps the operands of sub.
    return sub(b, a, alpha=alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.sub)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """Reference implementation of torch.sub (a - alpha * b)."""
    # At least one operand must be a tensor.
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        # alpha must be safely representable in the computation dtype.
        tensor_dtype = a.dtype if isinstance(a, TensorLike) else b.dtype
        py_type = utils.dtype_to_type(tensor_dtype)
        if not utils.is_weakly_lesser_type(type(alpha), py_type):
            raise ValueError(
                "alpha argument of type {0} cannot be safely cast to type {1}!".format(
                    type(alpha), py_type
                )
            )
        b = prims.mul(b, alpha)

    return prims.sub(a, b)
|
|
|
|
|
|
|
|
|
|
|
# true_divide: division that promotes integer inputs to float
# (INT_TO_FLOAT); aten_op=None, so no decomposition is registered.
true_divide = _make_elementwise_binary_reference(
    prims.div,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=None,
)
|
|
|
|
|
|
|
|
def _trunc_divide( |
|
|
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] |
|
|
): |
|
|
dtype = utils.get_dtype(a) |
|
|
if utils.is_integer_dtype(dtype): |
|
|
return prims.div(a, b) |
|
|
|
|
|
return trunc(prims.div(a, b)) |
|
|
|
|
|
|
|
|
|
|
|
# trunc_divide: round-toward-zero division, decomposed through
# _trunc_divide above; aten_op=None, so no decomposition is registered.
trunc_divide = _make_elementwise_binary_reference(
    _trunc_divide,
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.addcdiv)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def addcdiv(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """Reference implementation of torch.addcdiv:
    self + value * tensor1 / tensor2.
    """
    if value is not None:
        # `value` must be safely convertible to the (promoted) dtype of self.
        expected_type = utils.dtype_to_type(self.dtype)
        if not utils.is_weakly_lesser_type(type(value), expected_type):
            raise ValueError(
                "value argument of type {0} cannot be safely cast to type {1}!".format(
                    type(value), expected_type
                )
            )

    return self + value * tensor1 / tensor2
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.clamp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    """Reference implementation of torch.clamp; NaN elements of `a` pass
    through unchanged."""
    if min is None and max is None:
        raise ValueError("clamp called but both min and max are none!")

    # Clamp from below: keep `a` where it is >= min or NaN, else take min.
    if min is not None:
        nan_mask = torch.isnan(a)
        keep = torch.bitwise_or(torch.ge(a, min), nan_mask)
        a = torch.where(keep, a, min)
    # Clamp from above, symmetrically.
    if max is not None:
        nan_mask = torch.isnan(a)
        keep = torch.bitwise_or(torch.le(a, max), nan_mask)
        a = torch.where(keep, a, max)

    return a
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.clamp_min)
@out_wrapper()
def clamp_min(
    self: TensorLikeType,
    min: TensorOrNumberLikeType = None,
) -> TensorLikeType:
    """Clamp `self` from below at `min`; delegates to torch.clamp."""
    return torch.clamp(self, min=min)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.clamp_max)
@out_wrapper()
def clamp_max(
    self: TensorLikeType,
    max: TensorOrNumberLikeType = None,
) -> TensorLikeType:
    """Clamp `self` from above at `max`; delegates to torch.clamp."""
    return torch.clamp(self, max=max)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.where)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def where(
    pred: Tensor,
    a: Optional[TensorOrNumberLikeType] = None,
    b: Optional[TensorOrNumberLikeType] = None,
):
    """Reference implementation of torch.where(pred, a, b)."""
    # The one-argument form (torch.where(cond), i.e. nonzero) is not handled.
    if a is None or b is None:
        raise NotImplementedError

    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    check(
        pred.dtype is torch.bool,
        lambda: f"expected predicate to be bool, got {pred.dtype}",
    )

    pred, a, b = _maybe_broadcast(pred, a, b)
    return prims.where(pred, a, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.clone)
def clone(
    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    """Return a copy of `a` laid out according to `memory_format`."""
    out = torch.empty_like(
        a, requires_grad=a.requires_grad, memory_format=memory_format
    )
    copy_to(out, a)
    return out
|
|
|
|
|
|
|
|
def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
    """Copy the contents of `b` into `a`.

    Raises RuntimeError when the tensors live on different devices and
    cross-device copies are disallowed.
    """
    if not allow_cross_device and a.device != b.device:
        raise RuntimeError(
            "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
                b.device, a.device
            )
        )

    return prims.copy_to(a, b)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.item)
def item(a: TensorLikeType) -> NumberType:
    """Extract the single element of `a` as a Python number."""
    numel = a.numel()
    if numel != 1:
        raise ValueError(f"Can't convert a tensor with {numel} elements to a number!")

    # Convert through the Python type that corresponds to the tensor dtype.
    number_type = utils.dtype_to_type(a.dtype)
    return number_type(prims.item(a))
|
|
|
|
|
|
|
|
|
|
|
def _to_will_alias( |
|
|
a: TensorLikeType, |
|
|
device: Optional[torch.device] = None, |
|
|
dtype: Optional[torch.dtype] = None, |
|
|
copy: Optional[bool] = None, |
|
|
layout: Optional[torch.layout] = None, |
|
|
memory_format: Optional[torch.memory_format] = None, |
|
|
pin_memory: Optional[bool] = False, |
|
|
non_blocking: bool = False, |
|
|
) -> bool: |
|
|
return ( |
|
|
not copy |
|
|
and (device is None or a.device == device) |
|
|
and (dtype is None or a.dtype == dtype) |
|
|
and (layout is None or a.layout == layout) |
|
|
|
|
|
|
|
|
and ( |
|
|
memory_format is None |
|
|
or memory_format == torch.preserve_format |
|
|
or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
@singledispatch |
|
|
def _to_dispatch(*args, **kwargs): |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
@_to_dispatch.register
def _to_device(
    device: torch.device,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    """Normalize Tensor.to(device, dtype, ...) positional form into kwargs."""
    return {
        "device": device,
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
|
|
|
|
|
|
|
|
@_to_dispatch.register
def _to_device_str(
    device: str,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    """Like _to_device, but converts a string device spec to torch.device."""
    return {
        "device": torch.device(device),
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
|
|
|
|
|
|
|
|
@_to_dispatch.register
def _to_dtype(
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    """Normalize Tensor.to(dtype, ...) positional form into kwargs."""
    return {
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
|
|
|
|
|
|
|
|
@_to_dispatch.register
def _to_other(
    other: Tensor,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
):
    """Normalize Tensor.to(other, ...) into kwargs mirroring `other`'s
    device, dtype, and layout."""
    return {
        "device": other.device,
        "dtype": other.dtype,
        "layout": other.layout,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
|
|
|
|
|
|
|
|
|
|
|
def canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
    """Remove entries of `to_kwargs` (in place) that are no-ops for `a`.

    An option is dropped when applying it would change nothing:
    torch.preserve_format, a device matching a.device (an unspecified
    index matches any index of that device type), or a value equal to the
    corresponding attribute of `a`.
    """
    # Normalize a string device spec ("cuda:0") into a torch.device first.
    dev = to_kwargs.get("device")
    if isinstance(dev, str):
        to_kwargs["device"] = torch.device(dev)

    def _is_noop(kw: str) -> bool:
        val = to_kwargs[kw]
        if kw == "memory_format" and val is torch.preserve_format:
            return True
        if kw == "device" and val.type == a.device.type:
            if not val.index or val.index == a.device.index:
                return True
        return getattr(a, kw, None) == val

    for kw in ("dtype", "device", "layout", "memory_format"):
        if kw in to_kwargs and _is_noop(kw):
            to_kwargs.pop(kw)
|
|
|
|
|
|
|
|
def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
    """Reference implementation of Tensor.to.

    Positional overloads (device/dtype/other-tensor forms) are normalized
    into kwargs via _to_dispatch, no-op options are stripped, and then the
    conversion is performed as either an aliasing return, a pure dtype
    conversion, or an empty_like + copy.
    """
    # Normalize the positional overloads into a kwargs dict.
    if len(args) != 0:
        kwargs = _to_dispatch(*args, **kwargs)

    # pin_memory is not supported by this reference.
    assert "pin_memory" not in kwargs
    # Drop options that would not change `a` (mutates kwargs in place).
    canonicalize_to_arguments(a, kwargs)

    # If nothing remains that forces a new tensor, alias the input.
    if _to_will_alias(a, **kwargs):
        return a

    copy = kwargs.pop("copy") if "copy" in kwargs else False
    non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False

    # Fast path: only the dtype changes (or an explicit copy of just the
    # dtype), so a plain element-type conversion suffices.
    if (
        (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
        and (not non_blocking)
        and ("memory_format" not in kwargs)
        and ("device" not in kwargs)
        and ("layout" not in kwargs)
    ):
        return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))

    # General path: allocate with the remaining options and copy the data.
    result = torch.empty_like(a, **kwargs)
    copy_to(result, a)
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _reduction(
    a: TensorLikeType,
    prim: Callable,
    *,
    has_identity: bool = True,  # False for ops like amin/amax with no identity element
    accepts_dim_tuple: bool = True,  # False for ops that only take a single int dim
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
) -> TensorLikeType:
    """Shared driver for the reduction references (sum, prod, amin, var, ...).

    Validates arguments, converts to the computation dtype, applies `prim`
    over the canonicalized dims, re-broadcasts for keepdims, and handles the
    `out` tensor and the result-dtype conversion.
    """
    assert isinstance(a, TensorLike)
    # Matches ATen's hard limit on tensor rank for reductions.
    if a.ndim > 64:
        raise RuntimeError(
            "Received a tensor with {0} dimensions, but only tensors with up to 64 dims are supported!".format(
                a.ndim
            )
        )

    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # An explicit dtype must agree with the provided out tensor.
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)
    # Canonicalizes negative dims and expands None to "all dims".
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        # Without an identity element, reducing a zero-size dim is undefined.
        valid_shape = a.ndim == 0 or py_all(a.shape[i] for i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    computation_dtype, result_dtype = utils.reduction_dtypes(
        a, output_dtype_kind, dtype
    )
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)
    if keepdims:
        # Re-insert the reduced dims as size-1 so the rank matches the input.
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)

    if out is not None:
        assert result_dtype is not None
        if dtype is not None and result_dtype != out.dtype:
            raise RuntimeError(
                "Expected the dtype of reduction result and out to match"
            )
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)

    # Downcast from the computation dtype to the final result dtype.
    if result.dtype != result_dtype and result_dtype is not None:
        result = prims.convert_element_type(result, result_dtype)

    return result
|
|
|
|
|
|
|
|
|
|
|
# Alias the builtin before it is shadowed by the `all` reduction defined below.
py_all = all
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.all)
@out_wrapper()
def all(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """Reference implementation of the torch.all reduction."""
    if isinstance(dim, int):
        dim = (dim,)

    # Every element is truthy iff the count of falsy elements is zero.
    as_bool = _maybe_convert_to_dtype(a, torch.bool)
    result = eq(sum(logical_not(as_bool), dim=dim, keepdim=keepdim), 0)

    # torch.all on a uint8 input returns uint8, not bool.
    if a.dtype is torch.uint8:
        result = prims.convert_element_type(result, torch.uint8)

    return result
|
|
|
|
|
|
|
|
|
|
|
# Alias the builtin before it is shadowed by the `any` reduction defined below.
py_any = any
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.any)
@out_wrapper()
def any(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """Reference implementation of the torch.any reduction."""
    # Some element is truthy iff the boolean sum is nonzero.
    as_bool = _maybe_convert_to_dtype(a, torch.bool)
    result = ne(sum(as_bool, dim=dim, keepdim=keepdim), False)

    # torch.any on a uint8 input returns uint8, not bool.
    if a.dtype is torch.uint8:
        result = prims.convert_element_type(result, torch.uint8)

    return result
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.sum)
def sum(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of torch.sum."""
    if dtype is None:
        # Bool and integer sums accumulate in int64, matching eager.
        low_prec = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype)
        dtype = torch.int64 if low_prec else a.dtype
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
|
|
|
|
|
|
|
|
def sum_to_size(
    a: Tensor,
    *shape,
) -> Tensor:
    """Reference for Tensor.sum_to_size: reduce `a` down to a shape it is
    expandable from (the inverse of broadcasting)."""
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    utils.check(
        utils.is_expandable_to(shape, a.shape),
        lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
    )
    # Same (non-scalar) shape: nothing to reduce, return a view of the input.
    if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
        return prims.view_of(a)
    # Dims of `a` not covered by `shape` are reduced entirely; covered dims
    # are reduced where the target size is 1 but the input size is not.
    leading_dims = a.ndim - len(shape)
    reduce_dims = tuple(range(leading_dims)) + tuple(
        i
        # NOTE(review): the upper bound looks like it should be a.ndim rather
        # than len(shape) -- when leading_dims > 0 the last dims of `a` are
        # never inspected. Confirm against aten::sum_to_size before changing.
        for i in range(leading_dims, len(shape))
        if shape[i - leading_dims] == 1 and a.shape[i] != 1
    )
    # NOTE(review): keepdim=True keeps reduced leading dims as size-1, so the
    # result rank stays a.ndim instead of len(shape) -- verify this matches
    # the eager behavior of Tensor.sum_to_size.
    return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.prod)
def prod(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of torch.prod."""
    if dtype is None:
        # Bool and integer products accumulate in int64, matching eager.
        low_prec = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype)
        dtype = torch.int64 if low_prec else a.dtype
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.prod,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.amin)
def amin(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of torch.amin."""
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None

    # min has no identity element, so zero-size reduced dims are rejected.
    return _reduction(
        a,
        prims.amin,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.amax)
def amax(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Reference implementation of torch.amax."""
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None

    # max has no identity element, so zero-size reduced dims are rejected.
    return _reduction(
        a,
        prims.amax,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
|
|
|
|
|
|
|
|
def _dim_var_dispatch(dim=None, unbiased=None): |
|
|
|
|
|
|
|
|
|
|
|
if unbiased is None and isinstance(dim, bool): |
|
|
unbiased = dim |
|
|
dim = None |
|
|
return dim, unbiased |
|
|
|
|
|
|
|
|
@out_wrapper()
def var(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    """Reference implementation of torch.var."""
    # Handle var(a, unbiased) where `unbiased` lands in the dim slot.
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        partial(prims.var, correction=correction),
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
|
|
|
|
|
|
|
|
@out_wrapper()
def std(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
) -> TensorLikeType:
    """Reference implementation of torch.std."""
    # Handle std(a, unbiased) where `unbiased` lands in the dim slot.
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None

    # Compute the variance in the op-math dtype, then take the square root
    # and convert down (complex inputs yield a real-valued result).
    opmath_dtype, dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    variance = _reduction(
        a,
        partial(prims.var, correction=correction),
        dims=dim,
        keepdims=keepdim,
        dtype=opmath_dtype,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
    return _maybe_convert_to_dtype(sqrt(variance), dtype)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.mean)
def mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out=None,
) -> TensorLikeType:
    """Reference implementation of torch.mean.

    Computed as a sum in the promoted dtype followed by a true division by
    the number of reduced elements; integer/bool dtypes are rejected.
    """
    # An empty dim sequence means "reduce over everything".
    if dim == () or dim == []:
        dim = None
    if dtype is None:
        dtype = a.dtype
    # An explicit out tensor must already have the requested dtype.
    if out is not None and out.dtype != dtype:
        raise RuntimeError("expected out dtype and dtype to match")
    # Sum in the promoted computation dtype; the division below finishes
    # the mean, so `out` is handled manually at the end.
    result = _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=None,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
    )
    # mean is only defined for floating point / complex results.
    if utils.is_integer_dtype(dtype):
        raise RuntimeError("result type should be floating point or complex")
    if isinstance(dim, int):
        dim = (dim,)
    dims = utils.reduction_dims(a.shape, dim)
    # Number of elements reduced over (1 for a scalar input).
    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
    result = true_divide(result, nelem)
    # `dtype` was defaulted above, so this picks the explicit or input dtype.
    result_dtype = a.dtype if dtype is None else dtype
    result = _maybe_convert_to_dtype(result, result_dtype)
    if out is not None:
        assert isinstance(out, TensorLike)
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)
    return result
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.std_mean.correction)
def std_mean(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[List[int]]] = None,
    *,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    correction: Optional[int] = None,
):
    """Return the pair (std, mean) of `a` over `dim`."""
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    deviation = std(a, dim, unbiased, keepdim, correction=correction)
    center = mean(a, dim, keepdim)
    return deviation, center
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.var_mean)
def var_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[int] = None,
):
    """Return the pair (var, mean) of `a` over `dim`."""
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    variance = var(a, dim, unbiased, keepdim, correction=correction)
    center = mean(a, dim, keepdim)
    return variance, center
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.addr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "vec1", "vec2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addr(
    self: TensorLikeType,
    vec1: TensorLikeType,
    vec2: TensorLikeType,
    *,
    beta: NumberType = 1,
    alpha: NumberType = 1,
) -> TensorLikeType:
    """Reference implementation of torch.addr.

    Computes beta * self + alpha * torch.outer(vec1, vec2). For boolean
    tensors the scalars act as logical gates and the addition becomes a
    logical OR.
    """
    check(
        vec1.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
    )
    check(
        vec2.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
    )
    self = self.expand(vec1.shape[0], vec2.shape[0])
    if utils.is_boolean_dtype(self.dtype):
        # Boolean addr only accepts bool/int scalars for beta and alpha.
        check(
            is_weakly_lesser_type(type(beta), int),
            lambda: f"expected bool/int beta but got {type(beta)}",
        )
        check(
            is_weakly_lesser_type(type(alpha), int),
            # Bug fix: this message previously interpolated type(beta).
            lambda: f"expected bool/int alpha but got {type(alpha)}",
        )
        if not beta:
            # beta == 0 drops self entirely; alpha gates the outer product.
            return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
        else:
            return torch.logical_or(
                self,
                torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
            )
    else:
        check(
            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
        )
        check(
            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
        )
        if beta == 0:
            # beta == 0 skips the multiply so NaN/inf in self are dropped.
            return alpha * torch.outer(vec1, vec2)
        else:
            return beta * self + alpha * torch.outer(vec1, vec2)
|
|
|
|
|
|
|
|
|
|
|
def atleast_1d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_1d`."""
    # A lone sequence argument means "apply to each of its elements".
    if not args and isinstance(arg, collections.abc.Sequence):
        tensors = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg,) + args
    promoted = tuple(t if t.ndim >= 1 else unsqueeze(t, 0) for t in tensors)
    return promoted if len(promoted) > 1 else promoted[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _unsqueeze_atleast(
    at_least_fn: Callable, dim: int, arg: TensorLikeType
) -> TensorLikeType:
    """Apply `at_least_fn` to `arg`, then unsqueeze the result at `dim`."""
    promoted = at_least_fn(arg)
    assert isinstance(promoted, TensorLike)
    return unsqueeze(promoted, dim)
|
|
|
|
|
|
|
|
|
|
|
def atleast_2d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_2d`."""
    # A lone sequence argument means "apply to each of its elements".
    if not args and isinstance(arg, collections.abc.Sequence):
        tensors = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg,) + args
    # Promote below-2-D tensors by prepending a dim after atleast_1d.
    promote = partial(_unsqueeze_atleast, atleast_1d, 0)
    res = tuple(t if t.ndim >= 2 else promote(t) for t in tensors)
    return res if len(res) > 1 else res[0]
|
|
|
|
|
|
|
|
|
|
|
def atleast_3d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_3d`."""
    # A lone sequence argument means "apply to each of its elements".
    if not args and isinstance(arg, collections.abc.Sequence):
        tensors = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg,) + args
    # Promote below-3-D tensors by appending a dim after atleast_2d.
    promote = partial(_unsqueeze_atleast, atleast_2d, -1)
    res = tuple(t if t.ndim >= 3 else promote(t) for t in tensors)
    return res if len(res) > 1 else res[0]
|
|
|
|
|
|
|
|
def as_strided(
    a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int = 0
) -> TensorLikeType:
    """Reference for torch.as_strided; forwards directly to the prim."""
    view = prims.as_strided(a, size, stride, storage_offset)
    return view
|
|
|
|
|
|
|
|
def broadcast_shapes(*shapes) -> ShapeType:
    # Reference implementation of :func:`torch.broadcast_shapes`: the actual
    # broadcasting rules live in the file-level helper ``_broadcast_shapes``;
    # the result is wrapped in a ``torch.Size`` so callers get the usual
    # shape type back.
    return torch.Size(_broadcast_shapes(*shapes))
|
|
|
|
|
|
|
|
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
    """Reference implementation of :func:`torch.broadcast_tensors`."""
    # A single non-tensor argument is treated as a sequence of tensors.
    if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
        (tensors,) = tensors
    return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
|
|
|
|
|
|
|
|
|
|
|
def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.broadcast_to`.

    The existing dimensions of ``a`` are mapped onto the *trailing* dimensions
    of ``size``; new leading dimensions are broadcast in.
    """
    leading = len(size) - a.ndim
    dims = tuple(range(leading, a.ndim + leading))
    return prims.broadcast_in_dim(a, size, dims)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.cat)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("tensors",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of :func:`torch.cat`.

    Matches eager cat's legacy behavior: 1-D empty tensors are ignored when
    concatenating, and concatenating only such tensors yields a 1-D empty
    result.
    """
    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    # dim is validated against the first tensor's rank.
    dim = utils.canonicalize_dim(tensors[0].ndim, dim)
    utils.validate_idx(tensors[0].ndim, dim)

    # Legacy behavior: 1-D tensors with zero elements do not participate.
    filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0))
    if len(filtered) == 0:
        t = tensors[0]

        # NOTE(review): the broad try/except presumably guards against tensor
        # subclasses where ``requires_grad`` access can fail — confirm; it
        # falls back to a non-differentiable empty result.
        try:
            requires_grad = any(x.requires_grad for x in tensors)
        except Exception:
            requires_grad = False

        # All inputs were empty 1-D tensors: result is an empty 1-D tensor
        # with the (promoted) dtype/device of the first input.
        return empty((0,), dtype=t.dtype, device=t.device, requires_grad=requires_grad)

    return prims.cat(filtered, dim)
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.column_stack`."""
    # Promote 0-D/1-D inputs to single-column matrices before joining columns.
    columns = tuple(
        t if t.ndim > 1 else t.reshape((t.numel(), 1)) for t in tensors
    )
    return cat(columns, 1)
|
|
|
|
|
|
|
|
def conj(input: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.conj`.

    Real tensors are returned unchanged; sparse complex tensors are
    materialized via ``conj_physical``; dense complex tensors get the lazy
    (conj-bit) prim.
    """
    if utils.is_complex_dtype(input.dtype):
        if input.is_sparse:
            return torch.conj_physical(input)
        return prims.conj(input)
    # Conjugation is the identity for real dtypes.
    return input
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.constant_pad_nd)
def constant_pad_nd(
    input: TensorLikeType, pad: List[int], value: NumberType = 0
) -> TensorLikeType:
    """Reference implementation of ``aten::constant_pad_nd``.

    ``pad`` holds (before, after) pairs for the *last* ``len(pad)//2``
    dimensions, ordered from the last dimension backwards. Negative entries
    crop instead of pad.
    """
    check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )

    input_sizes = input.shape
    l_inp = len(input_sizes)

    # Number of (trailing) dimensions being padded, and how many leading
    # dimensions are untouched.
    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad

    check(
        l_inp >= l_pad,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{l_inp} dimensions.",
    )

    # First apply all *negative* padding by narrowing (cropping) the input.
    c_input = input
    for i in range(l_diff, l_inp):
        # pad is ordered last-dim-first: pad[0:2] pads dim l_inp-1, etc.
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] < 0:
            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])

        if pad[pad_idx + 1] < 0:
            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])

    # No positive padding anywhere: the cropped input (cloned) is the result.
    if builtins.all(p <= 0 for p in pad):
        return c_input.clone()

    # Output shape: untouched leading dims, then each padded dim grown (or
    # shrunk) by its (before, after) amounts.
    new_shape = list(input_sizes[:l_diff])

    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        check(
            new_dim > 0,
            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
            f"which is invalid. Check dimension {l_diff + i} of your input.",
        )
        new_shape.append(new_dim)

    memory_format = utils.suggest_memory_format(input)
    output = torch.empty(
        new_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )

    # fill() with an int 0 would fail for bool tensors; coerce to False.
    if value == 0 and input.dtype == torch.bool:
        value = False

    # Fill the whole output with the pad value, then overwrite the interior
    # region with the (cropped) input.
    output = torch.fill(output, value)

    # Narrow the output down to the window where the input data lands.
    c_output = output
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] > 0:
            c_output = c_output.narrow(
                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
            )
        if pad[pad_idx + 1] > 0:
            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])

    # c_output is a view into output, so this writes the input into place.
    prims.copy_to(c_output, c_input)
    return output
|
|
|
|
|
|
|
|
def contiguous(
    a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor:
    """Reference implementation of :meth:`torch.Tensor.contiguous`."""
    check(
        memory_format != torch.preserve_format,
        lambda: "preserve memory format is unsupported by the contiguous operator",
    )

    # Only copy when the current layout does not already satisfy the request.
    if not utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
        return torch.clone(a, memory_format=memory_format)
    return a
|
|
|
|
|
|
|
|
@out_wrapper()
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.dstack`."""
    check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
    # Promote everything to at least 3-D, then join along the depth axis.
    return cat(atleast_3d(*tensors), 2)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.expand, disable_meta=True)
def expand(a: Tensor, *shape) -> Tensor:
    """Reference implementation of :meth:`torch.Tensor.expand`.

    Accepts both ``expand(a, 2, 3)`` and ``expand(a, (2, 3))``. A requested
    length of -1 keeps the corresponding input length.
    """
    # Unwrap a single sequence argument into the varargs form.
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = tuple(shape[0])

    check(
        len(shape) >= len(a.shape),
        lambda: "expand: the requested shape has too few dimensions!",
    )

    # Input dimensions align with the trailing requested dimensions.
    offset = len(shape) - len(a.shape)
    shape_ = list(shape)
    for idx, x in enumerate(a.shape):
        offset_idx = idx + offset
        requested_length = shape[offset_idx]
        # Only size-1 dims may be expanded; -1 means "keep as is".
        check(
            requested_length == x or x == 1 or requested_length == -1,
            lambda: f"expand: attempting to expand a dimension of length {x}!",
        )

        shape_[offset_idx] = requested_length if requested_length != -1 else x

    # At this point shape must be valid (all -1 placeholders resolved).
    utils.validate_shape(shape_)

    return prims.broadcast_in_dim(
        a, shape_, tuple(range(offset, len(a.shape) + offset))
    )
|
|
|
|
|
|
|
|
|
|
|
def expand_as(a: Tensor, b: Tensor) -> Tensor:
    """Expand ``a`` to the shape of ``b`` (see :meth:`torch.Tensor.expand_as`)."""
    return a.expand(b.size())
|
|
|
|
|
|
|
|
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.chunk`.

    May return fewer than ``chunks`` pieces; the final piece is smaller when
    the dimension's length is not evenly divisible.
    """
    if chunks <= 0:
        msg = "Expected at least one chunk, but got {0}!".format(chunks)
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]
    # Full chunks hold ceil(length / chunks) elements each; the remainder
    # forms one shorter tail chunk.
    # NOTE(review): a zero-length dim makes chunk_size == 0 and the floor
    # division below divide by zero — confirm callers never hit this.
    chunk_size = math.ceil(length / chunks)
    full_chunks = math.floor(length / chunk_size)
    tail_chunk_size = length % chunk_size

    pieces = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]
    if tail_chunk_size != 0:
        pieces.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(pieces)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    """Reference implementation of :func:`torch.flatten`.

    Returns a view when the strides permit one, otherwise a copy.
    """
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Nothing to merge — but 0-D tensors still fall through so they become 1-D.
    if start_dim == end_dim and a.ndim != 0:
        return a

    # Probe whether the range can be collapsed as a view; if not, fall back
    # to the copying collapse prim.
    shape_hint, _ = prims._collapse_view_helper(a, start_dim, end_dim + 1)
    if shape_hint is None:
        return prims.collapse(a, start_dim, end_dim + 1)
    return prims.collapse_view(a, start_dim, end_dim + 1)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.flip)
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.flip`."""
    if not isinstance(dims, (tuple, list)):
        raise ValueError("dims has to be a sequence of ints")
    # Wrap negative dims and reject duplicates before reversing.
    dims = utils.canonicalize_dims(a.ndim, dims)
    utils.validate_no_repeating_dims(dims)
    return prims.rev(a, dims)
|
|
|
|
|
|
|
|
|
|
|
def fliplr(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.fliplr` (reverse dim 1)."""
    if a.ndim >= 2:
        return flip(a, (1,))
    raise RuntimeError("Input must be >= 2-d.")
|
|
|
|
|
|
|
|
|
|
|
def flipud(a: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.flipud` (reverse dim 0)."""
    if a.ndim >= 1:
        return flip(a, (0,))
    raise RuntimeError("Input must be >= 1-d.")
|
|
|
|
|
|
|
|
|
|
|
def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType:
    """Reference implementation of :func:`torch.narrow`: a view of ``length``
    elements of ``a`` along ``dim`` starting at ``start``."""
    axis = utils.canonicalize_dim(a.ndim, dim)
    return prims.slice_in_dim(a, start, start + length, axis=axis)
|
|
|
|
|
|
|
|
def _normalize( |
|
|
a: Tensor, norm_dims: DimsType, eps: float |
|
|
) -> Tuple[Tensor, Tensor, Tensor]: |
|
|
"""Computes mean and 1/std of a tensor along norm_dims. |
|
|
|
|
|
Used as a helper function for normalization layers. |
|
|
|
|
|
Args: |
|
|
a (Tensor): input tensor |
|
|
norm_dims (DimsType): dimensions to normalize over |
|
|
eps (float): epsilon for numerical stability |
|
|
|
|
|
Returns: |
|
|
out (Tensor): normalized tensor. |
|
|
mean (Tensor): mean of the tensor along norm_dims. |
|
|
rstd (Tensor): 1/std of the tensor along norm_dims. |
|
|
""" |
|
|
computation_dtype = utils.get_computation_dtype(a.dtype) |
|
|
a_acc = _maybe_convert_to_dtype(a, computation_dtype) |
|
|
assert isinstance(a_acc, TensorLike) |
|
|
biased_var, mean = torch.var_mean( |
|
|
a_acc, dim=norm_dims, unbiased=False, keepdim=True |
|
|
) |
|
|
rstd = torch.rsqrt(biased_var + eps) |
|
|
out = (a - mean) * rstd |
|
|
return out, mean, rstd |
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.native_layer_norm)
def native_layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Reference implementation of ``aten::native_layer_norm``.

    Normalizes over the trailing ``len(normalized_shape)`` dimensions of
    ``input`` and applies the optional affine transform. Returns
    ``(out, mean, rstd)``.
    """
    normalized_ndim = len(normalized_shape)
    utils.check(
        normalized_ndim >= 1,
        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
        + "containing at least one element, but got normalized_shape = "
        + str(normalized_shape),
    )

    # weight/bias, when given, must match normalized_shape exactly.
    utils.check(
        weight is None or weight.shape == tuple(normalized_shape),
        lambda: "Expected weight to be of same shape as normalized_shape, but got "
        + "weight of shape "
        + str(weight.shape)
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    utils.check(
        bias is None or bias.shape == tuple(normalized_shape),
        lambda: "Expected bias to be of same shape as normalized_shape, but got "
        + "bias of shape "
        + str(bias.shape)
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    # The trailing dims of input must equal normalized_shape.
    utils.check(
        input.ndim >= normalized_ndim
        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
        lambda: "Given normalized_shape="
        + str(normalized_shape)
        + ", expected input with shape "
        + str(normalized_shape)
        + ", but got input of size "
        + str(input.shape),
    )

    # Contiguity matches the eager kernel's behavior.
    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    # Reduce over the trailing normalized dimensions.
    axis = input.ndim - normalized_ndim
    reduction_dims = list(range(axis, input.ndim))
    out, mean, rstd = _normalize(input, reduction_dims, eps)

    # Optional affine transform (broadcast over the leading dims).
    if weight is None and bias is not None:
        out = out + bias
    elif weight is not None and bias is None:
        out = out * weight
    elif weight is not None and bias is not None:
        out = out * weight + bias

    out = prims.convert_element_type(out, input.dtype)
    # CPU kernels return mean/rstd in the input dtype; CUDA keeps the
    # accumulation dtype.
    if input.device.type == "cpu":
        mean = prims.convert_element_type(mean, input.dtype)
        rstd = prims.convert_element_type(rstd, input.dtype)
    return (out, mean, rstd)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.permute, disable_meta=True)
def permute(a: TensorLikeType, *dims) -> TensorLikeType:
    """Reference implementation of :func:`torch.permute`.

    Accepts both ``permute(a, 0, 2, 1)`` and ``permute(a, (0, 2, 1))``.
    """
    order = utils.canonicalize_dims(
        a.ndim, utils.extract_dims_from_varargs(dims)
    )
    return prims.transpose(a, order)
|
|
|
|
|
|
|
|
|
|
|
def _get_unfold_copy_shape_stride(
    a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
):
    """Compute the (shape, stride) produced by unfolding ``size``-length
    windows with the given ``step`` along ``dimension`` of a tensor described
    by ``a_shape``/``a_stride``. A new trailing dimension of length ``size``
    indexes within each window."""
    a_ndim = len(a_shape)
    dimension = utils.canonicalize_dim(a_ndim, dimension)
    # A 0-D tensor behaves like a single element with stride 1.
    max_size = 1 if a_ndim == 0 else a_shape[dimension]
    last_stride = 1 if a_ndim == 0 else a_stride[dimension]

    utils.check(
        size <= max_size,
        lambda: "Maximum size for tensor at dimension "
        + str(dimension)
        + " is "
        + str(max_size)
        + " but size is "
        + str(size),
    )

    utils.check(
        step > 0,
        lambda: "Step is " + str(step) + " but must be > 0",
    )

    # The unfolded dim holds the number of complete windows; its stride jumps
    # by ``step`` original elements. Every other dim is unchanged.
    window_shape = [
        (d_size - size) // step + 1 if d == dimension else d_size
        for d, d_size in enumerate(a_shape)
    ]
    window_stride = [
        step * d_stride if d == dimension else d_stride
        for d, d_stride in enumerate(a_stride)
    ]
    # Trailing dimension indexes within each window.
    window_shape.append(size)
    window_stride.append(last_stride)
    return window_shape, window_stride
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.repeat)
def repeat(a: Tensor, *repeat_shape) -> Tensor:
    """Reference implementation of :meth:`torch.Tensor.repeat`.

    Tiles ``a`` ``repeat_shape[i]`` times along each dimension; extra leading
    entries in ``repeat_shape`` introduce new dimensions.
    """
    repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
    utils.check(
        len(repeat_shape) >= len(a.shape),
        lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )

    # No repeats requested (only legal for 0-D input): plain copy.
    if len(repeat_shape) == 0:
        return torch.clone(a)

    # Left-pad the input shape with 1s so it aligns with repeat_shape.
    num_new_dimensions = len(repeat_shape) - a.ndim
    padded_shape = [1] * num_new_dimensions
    for dim_size in a.shape:
        padded_shape.append(dim_size)

    target_shape = tuple(
        padded_size * repeat_size
        for padded_size, repeat_size in zip(padded_shape, repeat_shape)
    )

    # Any zero repeat factor yields an empty result of the target shape.
    if 0 in repeat_shape:
        return torch.empty(
            target_shape,
            dtype=a.dtype,
            device=a.device,
            requires_grad=a.requires_grad,
            memory_format=utils.suggest_memory_format(a),
        )

    # Build an "urtensor" layout: unfold each target dim into
    # (repeat_count, original_size) windows so that expand() below can
    # broadcast the input into every repeated copy at once.
    urtensor_shape = target_shape
    urtensor_stride = utils.make_contiguous_strides_for(target_shape)
    for dim, dim_size in enumerate(padded_shape):
        # (Step max(dim_size, 1) keeps size-0 input dims well-formed.)
        urtensor_shape, urtensor_stride = _get_unfold_copy_shape_stride(
            urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
        )

    # Derive the permutation that re-sorts the unfolded dims back into
    # contiguous (descending-stride) order.
    enumerated_stride = list(enumerate(urtensor_stride))
    enumerated_stride.sort(key=lambda item: item[1], reverse=True)
    permute_order, sorted_stride = zip(*enumerated_stride)

    # Broadcast the input across all repeat copies...
    repeat_xtensor = a.expand(urtensor_shape)

    # ...materialize it (expand produces overlapping views)...
    cloned_result = torch.clone(repeat_xtensor)

    # ...then permute and merge the (repeat, size) dim pairs back together.
    permuted_result = cloned_result.permute(permute_order)

    return permuted_result.reshape(target_shape)
|
|
|
|
|
|
|
|
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
    """Shared implementation of reshape/view.

    Greedily matches each requested length against the input's dimensions,
    flattening and splitting as needed. When a flatten cannot be expressed as
    a view, either copies (``allow_copy=True``, reshape) or raises (view).
    """
    # Accept both varargs and a single sequence; resolve any -1 entry.
    shape = utils.extract_shape_from_varargs(shape, validate=False)

    shape = utils.infer_size(shape, a.numel())

    # Trivial case: shape unchanged.
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    # Zero-element tensors can adopt any compatible shape directly.
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # 0-D -> n-D: every requested length must be 1; add unit dims.
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # n-D -> 0-D: every input length must be 1; drop unit dims.
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # General case: walk the requested lengths, reshaping a_ incrementally so
    # its leading idx dims match shape[:idx].
    idx = 0
    a_ = a
    for length in shape:
        # Ran out of input dims: only unit lengths remain; manufacture them by
        # splitting the last dim (x -> (x, 1)).
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Current dim already matches.
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Accumulate input dims until their product is divisible by length.
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # Check whether dims [idx, end] can be flattened as a view.
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    # reshape: fall back to a copying prim on the original.
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape
                )
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Peel off ``length`` from the (possibly flattened) dim.
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Any leftover input dims must be unit-sized; squeeze them away.
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    # Reference implementation of :func:`torch.reshape`: returns a view when
    # the strides allow one, and otherwise falls back to a copy
    # (allow_copy=True); contrast with ``view`` which must not copy.
    return _reshape_view_helper(a, *shape, allow_copy=True)
|
|
|
|
|
|
|
|
|
|
|
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    """Reference implementation of :meth:`torch.Tensor.reshape_as`."""
    return self.reshape(other.shape)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.roll)
def roll(
    a: TensorLikeType, shifts: DimsType, dims: DimsType = tuple()
) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`."""
    dims = utils.canonicalize_dims(a.ndim, dims)

    # Normalize scalars to 1-tuples so the length logic below is uniform.
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Rolling an empty tensor is a no-op (also avoids the % size below with
    # size == 0).
    if a.numel() == 0:
        return clone(a)

    len_shifts = len(shifts)
    len_dims = len(dims)
    # Anything other than the single (shift, dim) pair is handled by
    # delegating back to torch.roll.
    if len_shifts != 1 or len_dims != 1:
        if len_shifts == 0:
            raise RuntimeError("`shifts` required")

        # One shift, no dims: flatten, roll, restore the shape.
        if len_dims == 0 and len_shifts == 1:
            return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
        if len_shifts != len_dims:
            raise RuntimeError(
                f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
            )
        assert len_dims > 1
        # Multiple dims: roll the first, then recurse on the rest.
        tail_shifts = shifts[1:]
        tail_dims = dims[1:]
        first_dim_rolled = torch.roll(a, shifts[0], dims[0])
        return torch.roll(first_dim_rolled, tail_shifts, tail_dims)

    # Base case — one shift over one dim: split at the wrap point and swap the
    # two pieces.
    dim = dims[0]
    size = a.shape[dim]
    start = (size - shifts[0]) % size
    t0 = torch.narrow(a, dim, start, size - start)
    t1 = torch.narrow(a, dim, 0, start)
    return torch.cat((t0, t1), dim)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.rot90)
def rot90(
    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
) -> TensorLikeType:
    """Reference implementation of :func:`torch.rot90`."""
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if a.ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")

    dims = utils.canonicalize_dims(a.ndim, dims)

    if dims[0] == dims[1]:
        raise RuntimeError(
            f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
        )

    k = k % 4
    if k == 0:
        # Full turn: just copy.
        return clone(a)
    if k == 2:
        # Half turn: flip both rotation axes.
        return torch.flip(a, dims)
    # Quarter turn: flip one axis then transpose; direction picks the axis.
    flip_dim = dims[1] if k == 1 else dims[0]
    return torch.transpose(torch.flip(a, (flip_dim,)), dims[0], dims[1])
|
|
|
|
|
|
|
|
def _check_stack_inputs(tensors: TensorSequenceType) -> None: |
|
|
entry_shape = tensors[0].shape |
|
|
for i in range(1, len(tensors)): |
|
|
assert tensors[i].shape == entry_shape, ( |
|
|
f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0" |
|
|
f"and {tensors[i].shape} at entry {i}" |
|
|
) |
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.stack)
@out_wrapper()
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of :func:`torch.stack`."""
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    # The result has one more dim than the inputs, so wrap against ndim + 1.
    wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)

    if wrapped_dim >= tensors[0].ndim:
        # dim is past the last input axis: unsqueeze each tensor, then cat.
        return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)

    # Otherwise cat along the wrapped dim and carve out the stack axis with a
    # view; inputs must all share one shape for the view to be valid.
    _check_stack_inputs(tensors)
    result_sizes = list(tensors[0].shape)
    result_sizes.insert(wrapped_dim, len(tensors))
    return torch.cat(tensors, wrapped_dim).view(result_sizes)
|
|
|
|
|
|
|
|
@out_wrapper()
def softmax(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of :func:`torch.softmax` (numerically
    stabilized via max-subtraction)."""
    result_dtype = dtype or a.dtype
    # Compute in a wider dtype for stability; cast back at the end.
    a_ = _maybe_convert_to_dtype(a, utils.get_computation_dtype(a.dtype))
    assert isinstance(a_, TensorLike)
    # Subtracting the per-slice max keeps exp() from overflowing.
    shifted = a_ - amax(a_, dim, keepdim=True)
    numerator = exp(shifted)
    denominator = sum(numerator, dim, keepdim=True)
    return _maybe_convert_to_dtype(true_divide(numerator, denominator), result_dtype)
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.hstack`."""
    check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
    aligned = atleast_1d(*tensors)
    # 1-D inputs join along dim 0; everything else joins along columns.
    join_dim = 0 if aligned[0].ndim == 1 else 1
    return cat(aligned, join_dim)
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Reference implementation of :func:`torch.vstack`."""
    check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
    # Promote to at least 2-D, then stack rows.
    return cat(atleast_2d(*tensors), 0)
|
|
|
|
|
|
|
|
|
|
|
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
    """Reference implementation of :func:`torch.unflatten`: expand dimension
    ``dim`` of ``a`` into the dimensions given by ``sizes``."""
    dim = utils.canonicalize_dim(a.ndim, dim)
    utils.check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
    new_shape = tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])
    return a.view(new_shape)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    """Reference implementation of :func:`torch.unbind`."""
    dim = utils.canonicalize_dim(t.ndim, dim)
    check(
        len(t.shape) > 0,
        lambda: "dimension specified as 0 but tensor has no dimensions",
        IndexError,
    )
    # One slice per index along ``dim``; drop the resulting unit dimension.
    slices = torch.tensor_split(t, t.shape[dim], dim)
    return tuple(torch.squeeze(piece, dim) for piece in slices)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_copy)
@out_wrapper()
def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    """Out-of-place :func:`torch.index_copy`: mutate a clone of ``x``."""
    result = x.clone()
    return result.index_copy_(dim, index, tensor)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_copy_)
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    """In-place :meth:`torch.Tensor.index_copy_` on ``x``; returns ``x``."""
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )

    # Treat 0-D input as 1-D so advanced indexing works; the unsqueezed view
    # aliases x's storage, so the write still lands in x.
    target = x.unsqueeze(0) if x.ndim == 0 else x
    selector = tuple(slice(None) for _ in range(dim)) + (index,)
    target[selector] = tensor
    return x
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_fill)
def index_fill(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    """Out-of-place :func:`torch.index_fill`: fill a clone of ``x`` at
    ``index`` along ``dim`` with ``value``."""
    result = x.clone()
    return result.index_fill_(dim, index, value)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_fill_)
def index_fill_(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    # In-place fill of the entries of ``x`` selected by ``index`` along
    # ``dim``; ``value`` may be a Python number or a 0-D tensor.
    if isinstance(value, TensorLike):
        utils.check(
            value.ndim == 0,
            lambda: "Only supports 0-dimensional value tensor. "
            f"Got a tensor with {value.ndim} dimensions.",
        )
        # NOTE(review): this branch mutates ``x.clone()`` and returns the
        # clone, so ``x`` itself is NOT modified even though the trailing
        # underscore names an in-place op — confirm whether this is
        # intentional or should operate on ``x`` directly.
        return x.clone().index_copy_(dim, index, value)
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    idx = (slice(None),) * dim + (index,)

    # 0-D tensors are viewed as 1-D so advanced indexing works; the view
    # aliases x's storage, so the assignment mutates x.
    y = x.unsqueeze(0) if x.ndim == 0 else x
    y[idx] = value
    return x
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_add)
@out_wrapper()
def index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    """Out-of-place :func:`torch.index_add`, built on the in-place variant."""
    result = x.clone()
    return result.index_add_(dim, index, tensor, alpha=alpha)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def index_add_(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    """In-place :meth:`torch.Tensor.index_add_` on ``x``; returns ``x``."""
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if alpha != 1:
        python_type = utils.dtype_to_type(x.dtype)
        utils.check(
            utils.is_weakly_lesser_type(type(alpha), python_type),
            lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
        )
        # Pre-scale the source once rather than scaling per write.
        tensor = prims.mul(tensor, alpha)

    # View 0-D input as 1-D so the advanced-indexing += below works; the view
    # aliases x's storage, so x is mutated.
    target = x.unsqueeze(0) if x.ndim == 0 else x
    selector = tuple(slice(None) for _ in range(dim)) + (index,)
    target[selector] += tensor
    return x
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.index_select, disable_meta=True)
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
    """Reference implementation of :func:`torch.index_select`."""
    dim = utils.canonicalize_dims(x.ndim, dim)
    utils.check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )

    if x.ndim == 0:
        # 0-D input: index a temporary 1-D view, then drop the extra dim.
        return x.unsqueeze(0)[index].squeeze(0).clone()
    selector = (slice(None),) * dim + (index,)
    return x[selector]
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.squeeze, disable_meta=True)
def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
    """Reference implementation of :func:`torch.squeeze`."""
    if dim is None:
        # No dim given: remove every size-1 dimension.
        unit_dims = tuple(d for d, length in enumerate(a.shape) if length == 1)
        return prims.squeeze(a, unit_dims)

    dim = utils.canonicalize_dim(a.ndim, dim)
    if len(a.shape) == 0:
        # Squeezing a 0-D tensor is a no-op (only dim 0 is legal).
        assert dim == 0
        return prims.view_of(a)
    if a.shape[dim] != 1:
        # Non-unit dimensions are left untouched.
        return prims.view_of(a)
    return prims.squeeze(a, (dim,))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.tensor_split`.

    Splits ``a`` along ``dim`` either into ``indices_or_sections``
    nearly-equal parts (int / 0-D long tensor) or at the given indices
    (sequence / 1-D long tensor).

    Raises:
        ValueError: for rank-0 input, a non-positive section count, or an
            invalid ``indices_or_sections`` tensor (wrong device, dtype, or
            rank).
    """
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor it must be a CPU long tensor.
    if isinstance(indices_or_sections, TensorLike):
        if not indices_or_sections.device.type == "cpu":
            msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
                indices_or_sections.device
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            # Fix: these two fragments were previously separate statements,
            # so the dtype detail was built and then silently discarded;
            # parenthesizing concatenates them into one message.
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                " but received one with dtype {0}".format(indices_or_sections.dtype)
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a 0-D tensor: a count of
    # (nearly) equal sections.
    if isinstance(indices_or_sections, int) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = "tensor_split: number of sections must be greater than 0, but was {0}".format(
                sections
            )
            raise ValueError(msg)

        splits = []
        dim_size = a.shape[_dim]
        # The first (dim_size % sections) sections receive one extra element.
        min_split_size = math.floor(dim_size / sections)
        num_splits_one_extra = dim_size % sections
        start_idx = 0
        for split_idx in range(sections):
            split_size = (
                min_split_size + 1
                if (split_idx < num_splits_one_extra)
                else min_split_size
            )
            s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
            splits.append(s)
            start_idx = start_idx + split_size

        return tuple(splits)
    # Case 1 -- indices_or_sections is a sequence of integers or a 1-D tensor
    # of split points.
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                # Fix: same discarded-statement bug as above — the ndim
                # detail was never attached to the raised message.
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    "but received a tensor with {0} dimensions".format(
                        indices_or_sections.ndim
                    )
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        # Slice between consecutive split points; the final slice runs to the
        # end of the dimension.
        splits = []
        start_idx = 0
        for x in indices:
            splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
            start_idx = x
        splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
        return tuple(splits)
|
|
|
|
|
|
|
|
|
|
|
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.hsplit`."""
    check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    # 1-D tensors split along dim 0; everything else splits along columns.
    dim = 0 if a.ndim == 1 else 1

    if isinstance(indices_or_sections, int):
        split_size = indices_or_sections
        check(
            (split_size != 0 and a.shape[dim] % split_size == 0),
            lambda: (
                "torch.hsplit attempted to split along dimension "
                + str(dim)
                + ", but the size of the dimension "
                + str(a.shape[dim])
                + " is not divisible by the split_size "
                + str(split_size)
                + "!"
            ),
        )
        return tensor_split(a, split_size, dim)

    check(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "hsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
        exc_type=TypeError,
    )
    return tensor_split(a, indices_or_sections, dim)
|
|
|
|
|
|
|
|
|
|
|
def vsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
    """Reference implementation of :func:`torch.vsplit` (split along rows)."""
    check(
        a.ndim >= 2,
        lambda: (
            "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    if isinstance(indices_or_sections, int):
        split_size = indices_or_sections
        check(
            (split_size != 0 and a.shape[0] % split_size == 0),
            lambda: (
                "torch.vsplit attempted to split along dimension 0 "
                + ", but the size of the dimension "
                + str(a.shape[0])
                + " is not divisible by the split_size "
                + str(split_size)
                + "!"
            ),
        )
        return tensor_split(a, split_size, 0)

    check(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "vsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
        exc_type=TypeError,
    )
    return tensor_split(a, indices_or_sections, 0)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.diagonal, disable_meta=True)
def diagonal(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diagonal

    Builds the diagonal view directly with as_strided: the two selected
    dimensions collapse into one trailing diagonal dimension whose stride is
    the sum of their strides.
    """
    num_dims = self.dim()
    dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    storage_offset = self.storage_offset()

    # Diagonal length: clipped to zero when the offset pushes it off the
    # matrix. Positive offsets move above the main diagonal (consume dim2),
    # negative below it (consume dim1).
    if offset >= 0:
        diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
    else:
        diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)

    # Shift the storage offset to the first diagonal element (only when the
    # diagonal is non-empty, so we never point outside the storage).
    if diag_size > 0:
        if offset >= 0:
            storage_offset += offset * self.stride()[dim2]
        else:
            storage_offset -= offset * self.stride()[dim1]

    # Remaining dims keep their sizes/strides; the diagonal dim goes last.
    sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
    sizes.append(diag_size)

    # Stepping by stride(dim1) + stride(dim2) walks the diagonal.
    strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
    strides.append(self.stride()[dim1] + self.stride()[dim2])

    result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)

    return result
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.diag_embed)
def diag_embed(
    t: TensorLikeType,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diag_embed

    Embeds the last dimension of `t` as the `offset` diagonal of new
    (dim1, dim2) planes in a result with one more dimension than `t`.
    """

    # Swapping dim1/dim2 while negating offset is symmetric for this op.
    # NOTE(review): the comparison happens on the *raw* (possibly negative)
    # dims, before canonicalization — confirm mixed-sign dim pairs are
    # ordered as intended.
    if dim1 > dim2:
        dim1, dim2 = dim2, dim1
        offset = -offset

    # The output has one more dimension than the input.
    rank = t.ndim + 1
    dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
    dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    # The last input dimension holds the values to place on the diagonal.
    last_dim = t.size(-1)

    if offset != 0:
        # Pad the diagonal dimension with |offset| zeros so the shifted
        # diagonal fits: zeros in front for positive offsets, behind for
        # negative ones.
        t_shape = list(t.shape)
        t_shape[-1] = builtins.abs(offset)
        z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
        pair = (z, t) if offset > 0 else (t, z)
        t = torch.cat(pair, dim=-1)

        last_dim += builtins.abs(offset)

    # Insert the new dimension at dim1 and move the (padded) diagonal
    # dimension into position dim2.
    t = t.unsqueeze(dim1).movedim(-1, dim2)

    # Build a boolean mask that is True exactly on the requested diagonal.
    a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
    b_range = torch.arange(
        offset, last_dim + offset, device=t.device, dtype=torch.int64
    )

    cond = a_range == b_range.unsqueeze(-1)
    # Reshape the mask so it broadcasts only over dim1/dim2.
    cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
    cond = cond.reshape(cond_shape)
    # Zero out everything off the diagonal.
    return utils.mask_tensor(cond, t)
|
|
|
|
|
|
|
|
|
|
|
def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
    """Splits `a` depthwise (along dimension 2), mirroring torch.dsplit."""
    ndim = a.ndim
    if ndim < 3:
        raise RuntimeError(
            f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {ndim} dimensions!"
        )
    # An int section count must divide the size of dim 2 evenly.
    invalid_sections = isinstance(sections, int) and (
        sections == 0 or a.shape[2] % sections != 0
    )
    if invalid_sections:
        raise RuntimeError(
            "torch._refs.dsplit attempted to split along dimension 2, "
            + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
        )
    return tensor_split(a, sections, 2)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.t.default, disable_meta=True)
def t(a: TensorLikeType):
    """Reference implementation of torch.t: transposes a tensor with at most
    two dimensions; 0-D and 1-D inputs are returned as-is (via a no-op
    transpose)."""
    if a.ndim > 2:
        raise RuntimeError(
            f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
        )
    # For 0-D/1-D tensors transpose(0, 0) is an identity view.
    second_dim = 1 if a.ndim == 2 else 0
    return torch.transpose(a, 0, second_dim)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.transpose, disable_meta=True)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Reference implementation of torch.transpose.

    Swaps dimensions dim0 and dim1 (negative indices allowed); returns a
    plain view when no swap is needed.
    """
    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))

    # Fix: compare the *canonicalized* dims. The original compared the raw
    # arguments (dim0 == dim1), which misses aliases such as (0, -ndim) and
    # needlessly fell through to a permute.
    if a.ndim <= 1 or _dim0 == _dim1:
        return prims.view_of(a)

    _permutation = list(range(0, a.ndim))
    _permutation[_dim0] = _dim1
    _permutation[_dim1] = _dim0
    return torch.permute(a, _permutation)
|
|
|
|
|
|
|
|
|
|
|
# torch.swapaxes is an alias of torch.transpose; reuse the ref directly.
swap_axes = transpose
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.unfold_copy)
def unfold_copy(a: TensorLikeType, dimension: int, size: int, step: int):
    # Reference implementation of aten.unfold_copy: the output shape/stride
    # pair is computed by _get_unfold_copy_shape_stride (defined elsewhere in
    # this file) and applied with as_strided.
    # NOTE(review): as_strided returns a view rather than a copy, despite the
    # op's _copy name — confirm a copy is materialized by the caller/wrapper.
    new_size, new_stride = _get_unfold_copy_shape_stride(
        a.shape, a.stride(), dimension, size, step
    )
    return a.as_strided(new_size, new_stride)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.cumsum)
def cumsum(
    a: TensorLikeType,
    dim: int,
    *,
    keepdim: bool = False,
    # NOTE(review): `Tensor` is not among this chunk's visible imports —
    # confirm it is imported elsewhere in the file.
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # Reference implementation of torch.cumsum via mask-and-reduce:
    # an extra axis is inserted after `dim`, a triangular boolean mask selects
    # each position's prefix, and `sum` (this module's reduction ref, not the
    # builtin) collapses the original axis.
    ndim = a.ndim
    dim = utils.canonicalize_dim(ndim, dim)
    if ndim == 0:
        # 0-D input: promote to 1-D so the reduction below is well formed.
        return sum(a.unsqueeze(0), dim=0, keepdim=keepdim, dtype=dtype, out=out)
    a = a.unsqueeze(dim + 1)
    rg = torch.arange(a.shape[dim], device=a.device)
    # mask[i][j] is True iff i <= j, so after masking, summing over axis `dim`
    # (index i) leaves the prefix sum at each output index j.
    mask = rg.unsqueeze(1) <= rg
    # Append singleton dims so the mask broadcasts against trailing axes.
    for _ in range(ndim - dim - 1):
        mask = mask.unsqueeze(-1)
    masked_a = utils.mask_tensor(mask, a)
    return sum(masked_a, dim=dim, keepdim=keepdim, dtype=dtype, out=out)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.unsqueeze, disable_meta=True)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    """Returns a view of `a` with a size-one dimension inserted at `dim`."""
    # Canonicalize against rank + 1: unsqueeze accepts `dim == a.ndim` (and
    # the corresponding negative index) to create a new innermost dimension.
    ndim = a.ndim + 1
    dim = utils.canonicalize_dim(ndim, dim)
    return prims.expand_dims(a, (dim,), ndim=ndim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.view, disable_meta=True)
def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    # Reference implementation of torch.Tensor.view: same machinery as
    # reshape, but allow_copy=False makes the helper fail instead of copying
    # when the requested shape cannot be produced as a view.
    return _reshape_view_helper(a, *shape, allow_copy=False)
|
|
|
|
|
|
|
|
|
|
|
def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    """Returns a view of `self` shaped like `other` (torch.Tensor.view_as)."""
    target_shape = other.size()
    return self.view(target_shape)
|
|
|
|
|
|
|
|
|
|
|
def ravel(a: TensorLikeType) -> TensorLikeType:
    """Flattens `a` into a one-dimensional tensor, like torch.ravel."""
    flat_shape = (-1,)
    return reshape(a, flat_shape)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.empty)
@out_wrapper()
def empty(
    *shape,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
    memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType:
    """Reference implementation of torch.empty: an uninitialized tensor whose
    strides are derived from the requested memory format."""
    check(
        memory_format != torch.preserve_format,
        lambda: "torch.empty: the Preserve memory format is not supported",
    )

    # Accept both empty(1, 2, 3) and empty((1, 2, 3)).
    shape = utils.extract_shape_from_varargs(shape)

    # Translate the memory format into an explicit stride order.
    if memory_format == torch.contiguous_format:
        strides = utils.make_contiguous_strides_for(shape)
    elif memory_format == torch.channels_last_3d:
        strides = utils.make_channels_last_3d_strides_for(shape)
    else:
        # Only channels_last (2d) remains valid at this point.
        check(
            memory_format == torch.channels_last,
            lambda: f"torch.empty: received an unknown memory format {memory_format}!",
        )
        strides = utils.make_channels_last_2d_strides_for(shape)

    return torch.empty_strided(
        shape,
        strides,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.new_empty)
def new_empty(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """Like torch.Tensor.new_empty: an uninitialized tensor of `size` whose
    dtype/layout/device default to those of `a`."""
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.empty(
        size,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.new_empty_strided)
def new_empty_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.Tensor.new_empty_strided

    dtype/layout/device default to those of `a`.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.empty_strided(
        size,
        stride,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.zeros)
@out_wrapper()
def zeros(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.zeros."""
    # Accept both zeros(1, 2) and zeros((1, 2)).
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()
    # bool tensors are filled with False rather than the number 0.
    zero_value = False if dtype == torch.bool else 0

    return torch.full(
        size,
        zero_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.new_zeros)
def new_zeros(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Like torch.Tensor.new_zeros: a zero-filled tensor of `size` whose
    dtype/layout/device default to those of `a`."""
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    # bool tensors are filled with False rather than the number 0.
    zero_value = False if dtype == torch.bool else 0
    return torch.full(
        size,
        zero_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.ones)
@out_wrapper()
def ones(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.ones."""
    # Accept both ones(1, 2) and ones((1, 2)).
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()
    # bool tensors are filled with True rather than the number 1.
    one_value = True if dtype == torch.bool else 1

    return torch.full(
        size,
        one_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.new_ones)
def new_ones(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Like torch.Tensor.new_ones: a ones-filled tensor of `size` whose
    dtype/layout/device default to those of `a`."""
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    # bool tensors are filled with True rather than the number 1.
    one_value = True if dtype == torch.bool else 1
    return torch.full(
        size,
        one_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.new_full)
def new_full(
    a: TensorLikeType,
    size: ShapeType,
    fill_value: Union[int, float, bool],
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """Like torch.Tensor.new_full: a tensor of `size` filled with
    `fill_value`, defaulting dtype/layout/device to those of `a`."""
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.empty_like)
def empty_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: Optional[torch.layout] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """Reference implementation of torch.empty_like: an uninitialized tensor
    with a's shape, defaulting dtype/layout/device to those of `a`."""
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    strides: Tuple[int, ...]

    # For an explicit memory format, delegate stride selection to empty.
    if memory_format != torch.preserve_format:
        return torch.empty(
            a.shape,
            dtype=dtype,
            layout=layout,
            device=device,
            requires_grad=requires_grad,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    # preserve_format: derive the strides from a's own layout.
    strides = utils.compute_elementwise_output_strides(a)
    return torch.empty_strided(
        a.shape,
        strides,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(
    [
        torch.ops.aten.arange.default,
        torch.ops.aten.arange.start,
        torch.ops.aten.arange.start_step,
    ]
)
@out_wrapper()
def arange(
    start: NumberType = 0,
    end: Optional[NumberType] = None,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.arange covering all three overloads
    (end-only, start/end, start/end/step)."""
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    # Single-positional-argument form: arange(end) means arange(0, end).
    if end is None:
        end = start
        start = 0
    return prims.arange(
        start,
        end,
        step,
        dtype=dtype,
        # layout/pin_memory were validated above and are not forwarded.
        device=device,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.linspace)
@out_wrapper()
def linspace(
    start: NumberType,
    end: NumberType,
    steps: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.linspace: `steps` evenly spaced
    values from `start` to `end` inclusive."""
    if dtype is None:
        dtype = torch.get_default_dtype()

    # For integral dtypes, float endpoints are truncated toward zero (int())
    # up front.
    if prims.utils.is_integer_dtype(dtype):
        if isinstance(start, float):
            start = int(start)
        if isinstance(end, float):
            end = int(end)

    # Complex inputs are not supported by this ref.
    if py_any(isinstance(arg, complex) for arg in (start, end, steps)):
        raise NotImplementedError
    assert not isinstance(start, complex) and not isinstance(end, complex)

    check(
        isinstance(steps, int),
        lambda: "steps must be int, not float",
        exc_type=TypeError,
    )
    assert isinstance(steps, int)
    check(steps >= 0, lambda: "number of steps must be non-negative")

    factory_kwargs = {
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "requires_grad": requires_grad,
    }
    # Degenerate cases: no steps, one step, or a flat range.
    if steps == 0:
        ret = torch.full((0,), 0, dtype=dtype, **factory_kwargs)
    elif steps == 1:
        ret = torch.full((1,), start, dtype=dtype, **factory_kwargs)
    elif start == end:
        ret = torch.full((steps,), start, dtype=dtype, **factory_kwargs)
    else:
        if prims.utils.is_integer_dtype(dtype):
            # Work in exact integer arithmetic scaled by denom = steps - 1 so
            # no float rounding occurs; dividing back truncates per element.
            assert isinstance(start, int) and isinstance(end, int)
            step_size_x_denom = end - start
            # eps nudges the scaled endpoint so arange includes it.
            eps = 1 if end > start else -1
            denom = steps - 1
            ret = prims.to_dtype(
                torch.arange(
                    start * denom,
                    end * denom + eps,
                    step_size_x_denom,
                    dtype=torch.int64,
                    **factory_kwargs,
                )
                / denom,
                dtype,
            )
        else:
            step_size = (end - start) / (steps - 1)
            # eps = half a step keeps `end` inside the arange bound despite
            # floating-point rounding of the step size.
            eps = step_size / 2
            ret = prims.to_dtype(
                torch.arange(
                    start, end + eps, step_size, dtype=torch.float64, **factory_kwargs
                ),
                dtype,
            )

    return ret
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.logspace)
@out_wrapper()
def logspace(
    start: NumberType,
    end: NumberType,
    steps: NumberType,
    base: NumberType = 10,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.logspace: `steps` values of
    base**x for x evenly spaced from `start` to `end`."""
    if dtype is None:
        dtype = torch.get_default_dtype()

    # For integral dtypes, float exponent endpoints are truncated first.
    if prims.utils.is_integer_dtype(dtype):
        if isinstance(start, float):
            start = int(start)
        if isinstance(end, float):
            end = int(end)

    assert not isinstance(base, complex)
    # Negative bases are not supported by this ref.
    if base < 0:
        raise NotImplementedError
    # Generate the exponents in float64 for accuracy; the final base**x is
    # cast to the requested dtype at the end.
    ret = torch.linspace(
        start,
        end,
        steps,
        dtype=torch.float64,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return prims.to_dtype(torch.pow(base, ret), dtype)
|
|
|
|
|
|
|
|
# Overload stubs declaring meshgrid's two calling conventions:
# a single sequence of tensors, or tensors passed as varargs.
@overload
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
    pass


@overload
def meshgrid(*tensors: TensorLikeType, indexing: str):
    pass
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.meshgrid)
def meshgrid(
    *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
    indexing: str,
) -> List[TensorLikeType]:
    """Reference implementation of torch.meshgrid: expands N 0-D/1-D tensors
    into N coordinate grids over their Cartesian product."""
    # Normalize the sequence overload (meshgrid([a, b], ...)) to varargs form.
    if isinstance(tensors[0], list) or isinstance(tensors[0], tuple):
        assert len(tensors) == 1
        tensors = tuple(tensors[0])

    check(
        py_all(isinstance(a, TensorLike) for a in tensors),
        lambda: "meshgrid expects its inputs to be tensors",
    )

    check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")

    # All inputs must agree on dtype and device (pairwise check).
    for i in range(len(tensors) - 1):
        check(
            tensors[i].dtype == tensors[i + 1].dtype,
            lambda: "meshgrid expects all tensors to have the same dtype",
        )
        check(
            tensors[i].device == tensors[i + 1].device,
            lambda: "meshgrid expects all tensors to have the same device",
        )

    # "xy" (Cartesian) indexing is implemented as "ij" with the first two
    # inputs swapped, then the first two outputs swapped back.
    swap_first_and_second_tensors = False
    if indexing == "xy":
        swap_first_and_second_tensors = len(tensors) >= 2
        if swap_first_and_second_tensors:
            tensors = (tensors[1], tensors[0], *tensors[2:])
    else:
        check(
            indexing == "ij",
            lambda: (
                'torch.meshgrid: indexing must be one of "xy" or "ij", '
                f"but received: {indexing}"
            ),
        )

    # The full grid shape is the element count of each input, in order.
    result_shape: List[int] = []
    for t in tensors:
        assert isinstance(t, TensorLike)
        check(
            t.ndim == 0 or t.ndim == 1,
            lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
        )
        result_shape.append(t.numel())

    # Broadcast each (1-D-ified) input along its own axis of the grid.
    grids: List[TensorLikeType] = []
    for i, t in enumerate(tensors):
        assert isinstance(t, TensorLike)
        if t.ndim == 0:
            t = t.view((1,))
        grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))

    if swap_first_and_second_tensors:
        # Undo the input swap on the outputs.
        grids[0], grids[1] = grids[1], grids[0]

    return grids
|
|
|
|
|
|
|
|
|
|
|
def movedim(
    input: TensorLikeType,
    source: Union[int, DimsSequenceType],
    destination: Union[int, DimsSequenceType],
) -> TensorLikeType:
    """
    Reference implementation of torch.movedim

    Moves each dim in `source` to the corresponding position in
    `destination`, shifting the remaining dims to fill the gaps; implemented
    as a single permute.
    """
    # Normalize scalar dims to one-element tuples.
    if type(source) is int:
        source = (source,)
    if type(destination) is int:
        destination = (destination,)

    utils.check(
        len(source) == len(destination),
        lambda: (
            "movedim: Invalid source or destination dims: source "
            f"({source} dims) should contain the same number of dims as "
            f"destination ({destination} dims)"
        ),
    )

    rank = input.ndim
    ss = tuple(utils.canonicalize_dims(rank=rank, indices=source))
    ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination))

    sss = set(ss)
    dss = set(ds)

    # Duplicates are rejected by comparing against the deduplicated sets.
    utils.check(
        len(ss) == len(sss),
        lambda: f"movedim: repeated dim in `source` {source}",
    )
    utils.check(
        len(ds) == len(dss),
        lambda: f"movedim: repeated dim in `destination` {destination}",
    )

    # Map each destination position to the source dim that lands there.
    m = dict(zip(ds, ss))
    dims = []
    si = 0  # next not-yet-placed source dim
    for di in range(rank):
        # Explicitly mapped destination positions take their source dim...
        s = m.get(di)
        if s is not None:
            dims.append(s)
        else:
            # ...all other positions take the remaining dims in order,
            # skipping dims that were explicitly moved.
            while si in sss:
                si += 1
            dims.append(si)
            si += 1

    result = torch.permute(input, tuple(dims))

    return result
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.empty_strided)
def empty_strided(
    shape: Union[ShapeType, Tuple[ShapeType]],
    strides: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    """Reference implementation of torch.empty_strided: uninitialized storage
    with the exact shape/stride pair given."""
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    shape = utils.extract_shape_from_varargs(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")

    return prims.empty_strided(
        shape,
        strides,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.eye)
@out_wrapper()
def eye(
    n: int,
    m: Optional[int] = None,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.eye
    """
    # Square matrix when m is omitted.
    if m is None:
        m = n

    check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
    check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")

    range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
    range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)

    # Broadcast compare: cond[i][j] is True exactly on the main diagonal.
    cond = range_n.unsqueeze(-1) == range_m
    if dtype is torch.bool:
        return cond
    else:
        # A one-element tensor carries the requested dtype/layout into where.
        one = torch.ones(
            (1,),
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=False,
        )
        # NOTE(review): requires_grad is accepted but never applied to the
        # result here — confirm a wrapper handles it.
        return torch.where(cond, one, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@out_wrapper()
def full(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    # Reference implementation of torch.full: allocate uninitialized storage
    # with empty, then write fill_value into every element with fill (this
    # module's ref).
    e = empty(
        shape,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return fill(e, fill_value)
|
|
|
|
|
|
|
|
def full_like(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """Reference implementation of torch.full_like: a tensor shaped like `a`
    filled with `fill_value`."""
    # Allocate uninitialized storage like `a`, then fill it.
    result = torch.empty_like(
        a,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )
    return fill(result, fill_value)
|
|
|
|
|
|
|
|
# zeros_like(a, **kw) is full_like(a, fill_value=False, **kw).
zeros_like = partial(full_like, fill_value=False)
|
|
|
|
|
|
|
|
# ones_like(a, **kw) is full_like(a, fill_value=True, **kw).
ones_like = partial(full_like, fill_value=True)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.randn)
@out_wrapper()
def randn(
    *shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: Optional[torch.layout] = None,
    requires_grad: bool = False,
    pin_memory: Optional[bool] = None,
) -> TensorLikeType:
    """Reference implementation of torch.randn: samples from a normal
    distribution with mean 0.0 and std 1.0."""
    check(pin_memory is None, lambda: "pin_memory parameter is not supported!")

    # Accept both randn(2, 3) and randn((2, 3)).
    shape_ = utils.extract_shape_from_varargs(shape)

    dtype = utils.dtype_or_default(dtype)
    device = utils.device_or_default(device)
    layout = utils.layout_or_default(layout)
    # NOTE(review): layout is resolved above but not forwarded to the prim —
    # confirm this is intentional.

    return prims.normal(
        shape_,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
|
|
|
|
|
|
|
|
def scalar_tensor(
    a: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """Wraps the Python number `a` in a 0-dimensional tensor."""
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    if dtype is None:
        # Infer the dtype from the Python type of the scalar.
        dtype = utils.type_to_dtype(type(a))
    if device is None:
        device = torch.device("cpu")
    return prims.scalar_tensor(a, dtype=dtype, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.uniform)
def uniform(
    shape: ShapeType,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
    *,
    dtype: torch.dtype,
    device: DeviceLikeType,
) -> TensorLikeType:
    """Reference implementation of aten.uniform: a tensor of `shape` with
    values drawn uniformly between `low` and `high`."""
    utils.validate_shape(shape)

    assert isinstance(low, (bool, int, float))
    assert isinstance(high, (bool, int, float))
    # The bounds are normalized to float before reaching the prim.
    low = float(low)
    high = float(high)

    assert isinstance(dtype, torch.dtype)
    device = utils.canonicalize_device(device)

    return prims.uniform(shape, low=low, high=high, dtype=dtype, device=device)
|
|
|
|
|
|
|
|
@register_decomposition(
    [torch.ops.aten.masked_fill.Scalar, torch.ops.aten.masked_fill.Tensor]
)
def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
    """Reference implementation of torch.masked_fill: where `mask` is True,
    elements of `a` are replaced with `value` (a number or 0-D tensor)."""
    python_type = utils.dtype_to_type(a.dtype)
    if isinstance(value, Number):
        value_type = type(value)
    else:
        # Tensor `value`: must be 0-dimensional.
        value_ndim = value.ndim
        check(
            value_ndim == 0,
            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
        )
        # When `a` is on CUDA, a value tensor on another device is accepted;
        # otherwise the devices must match.
        check(
            a.device.type == "cuda" or value.device == a.device,
            lambda: "Expected `value` to be on same device as `a`",
        )
        value_type = utils.dtype_to_type(value.dtype)
        if utils.is_cpu_scalar_tensor(value):
            # A CPU scalar tensor is unwrapped and handled like a number.
            value = value.item()

    if value_type is complex:
        # A complex fill is only allowed if `a`'s dtype can represent it;
        # otherwise the imaginary part would be silently dropped.
        check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # Number path: cast the scalar to a's corresponding Python type.
    if isinstance(value, Number):
        return torch.where(mask, python_type(value), a)

    # Tensor path: cast the 0-D tensor to a's dtype before the select.
    assert isinstance(value, TensorLike)
    return torch.where(mask, prims.to_dtype(value, a.dtype), a)
|
|
|
|
|
|
|
|
|
|
|
def allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose

    Returns True iff every element of `a` is close to the corresponding
    element of `b` within the given tolerances.
    """
    # Shared argument-validation helper (defined elsewhere in this file).
    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    # .item() collapses the 0-D result of torch.all to a Python scalar.
    return bool(
        torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
    )
|
|
|
|
|
|
|
|
|
|
|
def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    """Reference implementation of torch.equal: True iff `a` and `b` share
    device, dtype, and shape, and are elementwise equal."""
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape comparison: rank first, then each dimension.
    if a.ndim != b.ndim:
        return False

    for x, y in zip(a.shape, b.shape):
        if x != y:
            return False

    # Empty tensors of matching shape are trivially equal.
    if a.numel() == 0:
        return True

    # item/all/eq are this module's refs, not the Python builtins.
    return item(all(eq(a, b)))
|
|
|
|
|
|
|
|
@out_wrapper(exact_dtype=True)
def norm(
    input: TensorLikeType,
    p: Optional[Union[float, str]] = "fro",
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of torch.norm, dispatching to
    torch.linalg.matrix_norm / vector_norm."""
    # The Frobenius norm over at most 2 dims (and p=None) is just the 2-norm.
    if (
        p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2)
    ) or p is None:
        p = 2
    if isinstance(dim, int):
        dim = [dim]
    if isinstance(p, str):
        # A string p that survived the rewrite above is a matrix norm;
        # matrix_norm needs an explicit dim, so default to all dims.
        if dim is None:
            dim = tuple(range(input.ndim))
        return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
    else:
        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.trace)
def trace(self: TensorLikeType) -> TensorLikeType:
    """Reference implementation of torch.trace: the sum of the main-diagonal
    elements of a 2-D tensor."""
    utils.check(
        # Bug fix: the message was a plain string, so "{self.ndim}" printed
        # literally; it is now an f-string.
        self.ndim == 2,
        lambda: f"expected a matrix, but got tensor with dim {self.ndim}",
    )
    return torch.sum(torch.diag(self, 0))
|
|
|
|
|
|
|
|
def _make_r_binary_op(base_op): |
|
|
def rop( |
|
|
a: Union[TensorLikeType, NumberType], |
|
|
b: Union[TensorLikeType, NumberType], |
|
|
) -> TensorLikeType: |
|
|
return base_op(b, a) |
|
|
|
|
|
return rop |
|
|
|
|
|
|
|
|
# Reflected binary ops: rtruediv(a, b) computes true_divide(b, a), etc.
rtruediv = _make_r_binary_op(true_divide)
rfloordiv = _make_r_binary_op(floor_divide)
rpow = _make_r_binary_op(pow)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    """Reference implementation of torch.triu: zeroes the elements below the
    `diagonal` diagonal of the last two dimensions."""
    utils.check(
        a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    # mask[i][j] is True iff (j - i) >= diagonal, i.e. the element lies on or
    # above the requested diagonal; broadcasting applies it to every batch.
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) >= diagonal

    return utils.mask_tensor(mask, a)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    """Reference implementation of torch.tril: zeroes the elements above the
    `diagonal` diagonal of the last two dimensions."""
    utils.check(
        a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    # mask[i][j] is True iff (j - i) <= diagonal, i.e. the element lies on or
    # below the requested diagonal; broadcasting applies it to every batch.
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) <= diagonal

    return utils.mask_tensor(mask, a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]: |
|
|
if row == 0 or col == 0: |
|
|
return 0, 0, 0 |
|
|
|
|
|
m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) |
|
|
m_last_row = max(0, min(col, row + offset)) |
|
|
n_row_all = max(0, min(row, row + offset)) |
|
|
n_row_trapezoid = m_last_row - m_first_row + 1 |
|
|
|
|
|
|
|
|
trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 |
|
|
|
|
|
diff_row = n_row_all - n_row_trapezoid |
|
|
rectangle_size = max(0, diff_row * col) |
|
|
|
|
|
return trapezoid_size, rectangle_size, m_first_row |
|
|
|
|
|
|
|
|
def _trilu_checks(
    name: str,
    row: int,
    col: int,
    dtype: torch.dtype,
    layout: torch.layout,
    pin_memory: bool,
):
    # Shared argument validation for tril_indices / triu_indices.
    # `layout` and `pin_memory` are accepted for signature parity but are not
    # validated here.
    check(row >= 0, lambda: f"row must be non-negative, got {row}")
    check(col >= 0, lambda: f"col must be non-negative, got {col}")
    check(
        dtype in (torch.int32, torch.int64),
        lambda: f"\"{name}\" not implemented for '{dtype}'",
    )
|
|
|
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.tril_indices)
def tril_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    """Reference implementation of ``torch.tril_indices``.

    Returns a 2 x N tensor whose two rows are the row- and column-indices of
    the elements of a ``row`` x ``col`` matrix lying on or below the diagonal
    shifted by ``offset``, generated in row-major order.
    """
    _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)

    # The lower-triangular region splits into a top "trapezoid" (rows whose
    # selected prefix is shorter than `col`) and a bottom rectangle of fully
    # selected rows; m_first_row is the length of the trapezoid's first row.
    trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
    # First matrix row containing any selected element (nonzero only for
    # negative offsets).
    row_offset = max(0, -offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # Trapezoid part.  Row k (0-based within the trapezoid) holds
    # m_first_row + k elements, so the flat index of its first element is
    # k*m_first_row + k*(k-1)/2.  Inverting that quadratic in k gives
    # k = floor(-b + sqrt(b*b + 2*x)) with b = m_first_row - 0.5; computed in
    # float64 so the sqrt stays accurate for realistic index magnitudes.
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = m_first_row - 0.5
    row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
    # Column = flat index minus the element count of all preceding rows,
    # which equals (2*m_first_row - 1 + k) * k / 2.
    col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
    row_inds1 = prims.to_dtype(row_inds1 + row_offset, dtype)
    col_inds1 = prims.to_dtype(col_inds1, dtype)

    # Rectangle part: every row is completely selected, so integer div/mod
    # by `col` recovers the coordinates directly.
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    # (col - m_first_row + 1 + row_offset) is the matrix row on which the
    # rectangle starts, i.e. the first row after the trapezoid.
    row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
    col_inds2 = xs2 % col

    # Trapezoid indices come first: they cover the earlier matrix rows.
    return torch.stack(
        (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
    """Decompose the upper-triangular region of a ``row`` x ``col`` matrix
    (diagonal shifted by ``offset``) into a top rectangle of fully-selected
    rows plus a trapezoid of shrinking rows.

    Returns ``(trapezoid_size, rectangle_size, m_first_row)`` where
    ``m_first_row`` is the number of selected elements in the trapezoid's
    first row.
    """
    if row == 0 or col == 0:
        return 0, 0, 0

    # Length of the trapezoid's first (longest) row.
    first_len = max(0, col - offset) if offset > 0 else col

    # Rows strictly above the shifted diagonal are fully selected; this is
    # only nonzero for negative offsets.
    rectangle = max(0, min(row, -offset) * col)

    # triu(offset) is the complement of tril(offset - 1), so count the
    # complement by reusing the tril helper.
    tril_trapezoid, tril_rectangle, _ = _get_tril_sizes(row, col, offset - 1)
    total = row * col - (tril_trapezoid + tril_rectangle)
    trapezoid = total - rectangle

    return trapezoid, rectangle, first_len
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.triu_indices)
def triu_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    """Reference implementation of ``torch.triu_indices``.

    Returns a 2 x N tensor whose two rows are the row- and column-indices of
    the elements of a ``row`` x ``col`` matrix lying on or above the diagonal
    shifted by ``offset``, generated in row-major order.
    """
    _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)

    # The upper-triangular region splits into a top rectangle of fully
    # selected rows (nonempty only for negative offsets) followed by a
    # "trapezoid" of shrinking rows; m_first_row is the length of the
    # trapezoid's first row.
    trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
    # First selected column of the trapezoid's first row (nonzero only for
    # positive offsets).
    col_offset = max(0, offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # Rectangle part: every row is completely selected, so integer div/mod
    # by `col` recovers the coordinates directly.
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    row_inds2 = xs2 // col
    col_inds2 = xs2 % col

    # Trapezoid part.  Row k (0-based within the trapezoid) holds
    # m_first_row - k elements, so the flat index of its first element is
    # k*m_first_row - k*(k-1)/2.  Inverting that quadratic in k gives
    # k = floor(-b - sqrt(b*b - 2*x)) with b = -0.5 - m_first_row; computed
    # in float64 so the sqrt stays accurate for realistic index magnitudes.
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = -0.5 - m_first_row
    row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
    # Column (before shifting) = flat index minus the element count of all
    # preceding rows, which equals (2*m_first_row - 1 - k) * k / 2.
    col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
    row_inds1 = prims.to_dtype(row_inds1, dtype)
    col_inds1 = prims.to_dtype(col_inds1, dtype)

    # Shift the trapezoid: its rows start after the rectangle's
    # (rectangle_size // col) full rows, and its columns start at col_offset.
    # `col` is nonzero here unless the whole result is empty.
    if col:
        row_inds1 = row_inds1 + (rectangle_size // col)
    col_inds1 = col_inds1 + col_offset

    # Rectangle indices come first: they cover the earlier matrix rows.
    return torch.stack(
        (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
    )
|
|
|
|
|
|
|
|
import torch._refs.fft |
|
|
import torch._refs.linalg |
|
|
import torch._refs.nn.functional |
|
|
import torch._refs.special |
|
|
|