| import unittest |
| from functools import partial |
| from itertools import product |
| from typing import List |
|
| import numpy as np |
|
| import torch |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_device_type import ( |
| precisionOverride, |
| tol, |
| toleranceOverride, |
| ) |
| from torch.testing._internal.common_dtype import all_types_and, floating_types |
| from torch.testing._internal.common_utils import TEST_SCIPY, torch_to_numpy_dtype_dict |
| from torch.testing._internal.opinfo.core import ( |
| BinaryUfuncInfo, |
| DecorateInfo, |
| L, |
| NumericsFilter, |
| OpInfo, |
| S, |
| SampleInput, |
| UnaryUfuncInfo, |
| ) |
| from torch.testing._internal.opinfo.refs import ( |
| ElementwiseBinaryPythonRefInfo, |
| ElementwiseUnaryPythonRefInfo, |
| ) |
| from torch.testing._internal.opinfo.utils import ( |
| np_unary_ufunc_integer_promotion_wrapper, |
| ) |
|
|
| if TEST_SCIPY: |
| import scipy.special |
|
|
| # TODO: Consolidate `i0e` with sample_inputs_unary once `make_tensor` |
| # supports an `exclude` argument |
| def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs): |
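| # torch.special.i0e's derivative is discontinuous at zero, so exclude zero |
| # from samples that will be differentiated |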
| exclude_zero = requires_grad and op_info.op == torch.special.i0e |
| make_arg = partial( |
| make_tensor, |
| dtype=dtype, |
| device=device, |
| requires_grad=requires_grad, |
| exclude_zero=exclude_zero, |
| ) |
| yield SampleInput(make_arg((S,))) |
| yield SampleInput(make_arg(())) |
|
| if requires_grad and not exclude_zero: |
| # Special case for gradients: include a sample with `0` in the input |
| t = make_arg((S,)) |
| t[0] = 0 |
|
| yield SampleInput(t) |
|
|
| def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs): |
| make_arg = partial( |
| make_tensor, device=device, dtype=dtype, requires_grad=requires_grad |
| ) |
| tensor_shapes = ((S, S), ()) |
| ns = (1, 2, 3, 4, 5) |
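| # n == 0 is exercised separately via `sample_kwargs` in the |
| # "special_polygamma_n_0" OpInfo variant below |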
|
| for shape, n in product(tensor_shapes, ns): |
| yield SampleInput(make_arg(shape), args=(n,)) |
|
|
| def reference_polygamma(x, n): |
| # `scipy.special.polygamma` is inconsistent about output dtypes: 0-d float32 |
| # inputs are promoted to float64 while 1-d float32 inputs are not, e.g. |
| # >>> scipy.special.polygamma(0, np.array(501, dtype=np.float32)).dtype |
| # dtype('float64') |
| # >>> scipy.special.polygamma(0, np.array([501], dtype=np.float32)).dtype |
| # dtype('float32') |
| # Thus we cast the result to the default torch dtype (preserving double). |
| result_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] |
| if x.dtype == np.double: |
| result_dtype = np.double |
| return scipy.special.polygamma(n, x).astype(result_dtype) |
|
|
| def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs): |
| low, _ = op_info.domain |
|
| if requires_grad: |
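| # entr(x) = -x * log(x) is not differentiable at 0, so nudge the lower |
| # bound into the interior of the domain when gradients are required |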
| low = 0 + op_info._domain_eps |
|
| return ( |
| SampleInput( |
| make_tensor( |
| (L,), dtype=dtype, device=device, low=low, requires_grad=requires_grad |
| ) |
| ), |
| SampleInput( |
| make_tensor( |
| (), dtype=dtype, device=device, low=low, requires_grad=requires_grad |
| ) |
| ), |
| ) |
|
|
| op_db: List[OpInfo] = [ |
| UnaryUfuncInfo( |
| "special.i0e", |
| aten_name="special_i0e", |
| ref=scipy.special.i0e if TEST_SCIPY else None, |
| decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), |
| dtypes=all_types_and(torch.bool, torch.bfloat16), |
| dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), |
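| # backward is only tested for floating-point dtypes |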
| backward_dtypes=floating_types(), |
| sample_inputs_func=sample_inputs_i0_i1, |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.i1", |
| aten_name="special_i1", |
| ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) |
| if TEST_SCIPY |
| else None, |
| dtypes=all_types_and(torch.bool), |
| dtypesIfCUDA=all_types_and(torch.bool), |
| sample_inputs_func=sample_inputs_i0_i1, |
| decorators=( |
| DecorateInfo( |
| toleranceOverride( |
| { |
| torch.float32: tol(atol=1e-4, rtol=0), |
| torch.bool: tol(atol=1e-4, rtol=0), |
| } |
| ) |
| ), |
| ), |
| skips=( |
| DecorateInfo( |
| unittest.skip("Incorrect result!"), |
| "TestUnaryUfuncs", |
| "test_reference_numerics_large", |
| dtypes=(torch.int8,), |
| ), |
| ), |
| supports_fwgrad_bwgrad=True, |
| supports_forward_ad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.i1e", |
| aten_name="special_i1e", |
| ref=scipy.special.i1e if TEST_SCIPY else None, |
| dtypes=all_types_and(torch.bool), |
| dtypesIfCUDA=all_types_and(torch.bool), |
| sample_inputs_func=sample_inputs_i0_i1, |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.ndtr", |
| aten_name="special_ndtr", |
| decorators=(precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-4}),), |
| ref=scipy.special.ndtr if TEST_SCIPY else None, |
| dtypes=all_types_and(torch.bool, torch.bfloat16), |
| dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16), |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| skips=( |
| # forward-over-reverse gradients are expected to fail on the meta device |
| DecorateInfo( |
| unittest.expectedFailure, |
| "TestGradients", |
| "test_fn_fwgrad_bwgrad", |
| device_type="meta", |
| ), |
| ), |
| ), |
| # A separate OpInfo entry is needed for the `special.polygamma` alias because |
| # its argument order differs: torch.special.polygamma takes `n` first |
| UnaryUfuncInfo( |
| "special.polygamma", |
| op=lambda x, n, **kwargs: torch.special.polygamma(n, x, **kwargs), |
| variant_test_name="special_polygamma_n_0", |
| ref=reference_polygamma if TEST_SCIPY else None, |
| dtypes=all_types_and(torch.bool, torch.bfloat16), |
| dtypesIfCUDA=all_types_and(torch.bool, torch.half), |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| sample_inputs_func=sample_inputs_polygamma, |
| skips=( |
| # the op is a lambda, which the JIT and operator-normalization tests cannot handle |
| DecorateInfo( |
| unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" |
| ), |
| DecorateInfo( |
| unittest.expectedFailure, |
| "TestNormalizeOperators", |
| "test_normalize_operator_exhaustive", |
| ), |
| ), |
| sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}), |
| # polygamma functions have multiple singularities at x <= 0 |
| reference_numerics_filter=NumericsFilter( |
| condition=lambda x: x < 0.1, safe_val=1 |
| ), |
| ), |
| BinaryUfuncInfo( |
| "special.xlog1py", |
| aten_name="special_xlog1py", |
| dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), |
| promotes_int_to_float=True, |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| supports_one_python_scalar=True, |
| # xlog1py(x, y) is only finite for y > -1, so sample the rhs above -1 |
| rhs_make_tensor_kwargs=dict(low=-0.99), |
| ), |
| BinaryUfuncInfo( |
| "special.zeta", |
| aten_name="special_zeta", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| supports_autograd=False, |
| supports_one_python_scalar=True, |
| ), |
| # OpInfo entry to verify the gradient formula of `other`/`q` |
| # BinaryUfuncInfo( |
| #     "special.zeta", |
| #     op=lambda q, x, **kwargs: torch.special.zeta(x, q, **kwargs), |
| #     aten_name="special_zeta", |
| #     variant_test_name="grad", |
| #     dtypes=all_types_and(torch.bool), |
| #     promotes_int_to_float=True, |
| #     supports_autograd=True, |
| #     supports_one_python_scalar=True, |
| #     skips=( |
| #         # Lambda doesn't work in JIT test |
| #         # AssertionError: JIT Test does not execute any logic |
| #         DecorateInfo( |
| #             unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" |
| #         ), |
| #     ), |
| # ), |
| UnaryUfuncInfo( |
| "special.entr", |
| ref=scipy.special.entr if TEST_SCIPY else None, |
| aten_name="special_entr", |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| decorators=(precisionOverride({torch.float16: 1e-1, torch.bfloat16: 1e-1}),), |
| dtypes=all_types_and(torch.bool, torch.bfloat16), |
| dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), |
| skips=( |
| DecorateInfo( |
| unittest.skip("Skipped!"), |
| "TestUnaryUfuncs", |
| "test_reference_numerics_large", |
| dtypes=[torch.bfloat16, torch.float16], |
| ), |
| ), |
| supports_inplace_autograd=False, |
| sample_inputs_func=sample_inputs_entr, |
| ), |
| UnaryUfuncInfo( |
| "special.ndtri", |
| ref=scipy.special.ndtri if TEST_SCIPY else None, |
| domain=(0, 1), |
| aten_name="special_ndtri", |
| dtypes=all_types_and(torch.bool), |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.log_ndtr", |
| aten_name="special_log_ndtr", |
| ref=scipy.special.log_ndtr if TEST_SCIPY else None, |
| dtypes=all_types_and(torch.bool), |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.erfcx", |
| ref=scipy.special.erfcx if TEST_SCIPY else None, |
| aten_name="special_erfcx", |
| decorators=( |
| toleranceOverride( |
| { |
| torch.float32: tol(atol=0, rtol=4e-6), |
| } |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| supports_forward_ad=True, |
| supports_fwgrad_bwgrad=True, |
| ), |
| UnaryUfuncInfo( |
| "special.airy_ai", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-03, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
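| # scipy.special.airy returns the tuple (Ai, Aip, Bi, Bip); Ai is the reference |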
| ref=(lambda x: scipy.special.airy(x)[0]) if TEST_SCIPY else None, |
| skips=( |
| DecorateInfo( |
| unittest.skip("Skipped!"), |
| "TestUnaryUfuncs", |
| "test_reference_numerics_large", |
| ), |
| ), |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.bessel_j0", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-04, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.j0 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.bessel_j1", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-04, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.j1 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.bessel_y0", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-04, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.y0 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.bessel_y1", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-04, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.y1 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.chebyshev_polynomial_t", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.chebyshev_polynomial_u", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.chebyshev_polynomial_v", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.chebyshev_polynomial_w", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.hermite_polynomial_h", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.hermite_polynomial_he", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.laguerre_polynomial_l", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.legendre_polynomial_p", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.modified_bessel_i0", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-03, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.i0 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.modified_bessel_i1", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-03, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.i1 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.modified_bessel_k0", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-03, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.k0 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.modified_bessel_k1", |
| decorators=( |
| precisionOverride( |
| { |
| torch.float32: 1e-03, |
| torch.float64: 1e-05, |
| }, |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.k1 if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.scaled_modified_bessel_k0", |
| decorators=( |
| toleranceOverride( |
| { |
| torch.float32: tol(atol=1e-03, rtol=1e-03), |
| torch.float64: tol(atol=1e-05, rtol=1e-03), |
| } |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.k0e if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.scaled_modified_bessel_k1", |
| decorators=( |
| toleranceOverride( |
| { |
| torch.float32: tol(atol=1e-03, rtol=1e-03), |
| torch.float64: tol(atol=1e-05, rtol=1e-03), |
| } |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
| ref=scipy.special.k1e if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.shifted_chebyshev_polynomial_t", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.shifted_chebyshev_polynomial_u", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.shifted_chebyshev_polynomial_v", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| BinaryUfuncInfo( |
| "special.shifted_chebyshev_polynomial_w", |
| dtypes=all_types_and(torch.bool), |
| promotes_int_to_float=True, |
| skips=( |
| DecorateInfo( |
| unittest.skip( |
| "Skipping - testing takes an unreasonably long time, #79528" |
| ) |
| ), |
| DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"), |
| DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"), |
| ), |
| supports_one_python_scalar=True, |
| supports_autograd=False, |
| ), |
| UnaryUfuncInfo( |
| "special.spherical_bessel_j0", |
| decorators=( |
| toleranceOverride( |
| { |
| torch.float32: tol(atol=1e-03, rtol=1e-03), |
| torch.float64: tol(atol=1e-05, rtol=1e-03), |
| } |
| ), |
| ), |
| dtypes=all_types_and(torch.bool), |
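| # scipy.special.spherical_jn(0, x) computes the spherical Bessel function |
| # of the first kind of order zero |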
| ref=(lambda x: scipy.special.spherical_jn(0, x)) if TEST_SCIPY else None, |
| supports_autograd=False, |
| ), |
| ] |
|
|
| python_ref_db: List[OpInfo] = [ |
| # |
| # Elementwise Unary Special OpInfos |
| # |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.bessel_j0", |
| torch_opinfo_name="special.bessel_j0", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.bessel_j1", |
| torch_opinfo_name="special.bessel_j1", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.i0e", |
| torch_opinfo_name="special.i0e", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.i1", |
| torch_opinfo_name="special.i1", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.i1e", |
| torch_opinfo_name="special.i1e", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ElementwiseUnaryPythonRefInfo( |
| "_refs.special.spherical_bessel_j0", |
| torch_opinfo_name="special.spherical_bessel_j0", |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| # |
| # Elementwise Binary Special OpInfos |
| # |
| ElementwiseBinaryPythonRefInfo( |
| "_refs.special.zeta", |
| torch_opinfo_name="special.zeta", |
| supports_one_python_scalar=True, |
| supports_nvfuser=False, |
| op_db=op_db, |
| ), |
| ] |
|