|
|
import math |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch._prims as prims |
|
|
import torch._prims_common as utils |
|
|
import torch._refs as refs |
|
|
|
|
|
from torch import Tensor |
|
|
from torch._decomp import register_decomposition |
|
|
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, TensorLikeType |
|
|
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper |
|
|
from torch._refs import ( |
|
|
_make_elementwise_binary_reference, |
|
|
_make_elementwise_unary_reference, |
|
|
) |
|
|
|
|
|
|
|
|
# Public API of this module: the names exported by a star-import.
__all__ = [
    # Wrappers around the ``prims.bessel_*`` primitives.
    "bessel_j0",
    "bessel_j1",
    "i0e",
    "i1",
    "i1e",
    # Decompositions written in terms of torch operations.
    "logit",
    "multigammaln",
    # Wrappers around ``prims.spherical_bessel_j0`` and ``prims.zeta``.
    "spherical_bessel_j0",
    "zeta",
]
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_bessel_j0
)
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_bessel_j0``; forwards ``a`` to ``prims.bessel_j0``.

    The wrapping decorator is configured with ``INT_TO_FLOAT`` type promotion
    and receives ``torch.ops.aten.special_bessel_j0`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.bessel_j0`` produces for ``a``.
    """
    result = prims.bessel_j0(a)
    return result
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_bessel_j1
)
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_bessel_j1``; forwards ``a`` to ``prims.bessel_j1``.

    The wrapping decorator is configured with ``INT_TO_FLOAT`` type promotion
    and receives ``torch.ops.aten.special_bessel_j1`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.bessel_j1`` produces for ``a``.
    """
    result = prims.bessel_j1(a)
    return result
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_i0e,
)
def i0e(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_i0e``; forwards ``a`` to ``prims.bessel_i0e``.

    The wrapping decorator is configured with ``INT_TO_FLOAT`` type promotion
    and receives ``torch.ops.aten.special_i0e`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.bessel_i0e`` produces for ``a``.
    """
    result = prims.bessel_i0e(a)
    return result
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_i1,
)
def i1(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_i1``; forwards ``a`` to ``prims.bessel_i1``.

    The wrapping decorator is configured with ``INT_TO_FLOAT`` type promotion
    and receives ``torch.ops.aten.special_i1`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.bessel_i1`` produces for ``a``.
    """
    result = prims.bessel_i1(a)
    return result
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_i1e,
)
def i1e(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_i1e``; forwards ``a`` to ``prims.bessel_i1e``.

    The wrapping decorator is configured with ``INT_TO_FLOAT`` type promotion
    and receives ``torch.ops.aten.special_i1e`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.bessel_i1e`` produces for ``a``.
    """
    result = prims.bessel_i1e(a)
    return result
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    """Compute ``log(x / (1 - x))`` after clamping ``x`` to ``[eps, 1 - eps]``.

    Registered as the decomposition for ``aten.logit``; ``self`` is the
    argument named for ``INT_TO_FLOAT`` type promotion.

    Args:
        self: Input tensor-like value.
        eps: Clamp margin. ``None`` is replaced by ``-1.0``.

    Returns:
        The element-wise logit of the clamped input.
    """
    # With the -1.0 stand-in the clamp range becomes [-1, 2], so every value
    # inside [0, 1] passes through the clamp unchanged.
    margin = -1.0 if eps is None else eps
    clamped = torch.clamp(self, margin, 1 - margin)
    odds = torch.true_divide(clamped, torch.sub(1, clamped))
    return torch.log(odds)
|
|
|
|
|
|
|
|
@register_decomposition(torch.ops.aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    """Compute the multivariate log-gamma of ``a`` with dimension ``p``.

    Element-wise this evaluates::

        sum_{i=1..p} lgamma(a - (i - 1) / 2) + p * (p - 1) / 4 * log(pi)

    Registered as the decomposition for ``aten.mvlgamma``; ``a`` is the
    argument named for ``INT_TO_FLOAT`` type promotion.

    Args:
        a: Input tensor-like value.
        p: Number of dimensions; must be at least 1.

    Returns:
        A tensor with the same shape as ``a`` holding the result.

    Raises:
        RuntimeError: If ``p`` is smaller than 1.
    """
    # Without this guard p == 0 produces an empty ``arange`` below, so the sum
    # collapses to 0 and an invalid dimension silently yields zeros.
    if p < 1:
        raise RuntimeError("p has to be greater than or equal to 1")

    # Additive constant p * (p - 1) / 4 * log(pi).
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    # Offsets (1 - p) / 2, ..., -1/2, 0: one per term of the sum.
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    # Broadcast ``a`` against the offsets on a new trailing axis, then reduce it.
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
|
|
|
|
|
|
|
|
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_spherical_bessel_j0,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.special_spherical_bessel_j0``.

    Forwards ``a`` to ``prims.spherical_bessel_j0``. The wrapping decorator is
    configured with ``INT_TO_FLOAT`` type promotion and receives
    ``torch.ops.aten.special_spherical_bessel_j0`` as its ``aten_op``.

    Args:
        a: Input tensor-like value.

    Returns:
        Whatever ``prims.spherical_bessel_j0`` produces for ``a``.
    """
    result = prims.spherical_bessel_j0(a)
    return result
|
|
|
|
|
|
|
|
# Binary elementwise reference built from ``prims.zeta``: constructed with
# ``INT_TO_FLOAT`` type promotion and ``aten.special_zeta`` as its ``aten_op``.
zeta = _make_elementwise_binary_reference(
    prims.zeta,
    aten_op=torch.ops.aten.special_zeta,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
|
|
|