| | |
| |
|
| | #pragma once |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
// instantiate_metal_math_funcs(itype, otype, ctype, mfast)
//
// Expands to a family of METAL_FUNC math overloads that accept `itype`
// arguments and return `otype`. Every generated function follows the same
// pattern: each argument is static_cast to the compute type `ctype`, the
// matching `__metal_*` builtin is called, and the result is static_cast to
// `otype`.
//
// Parameters:
//   itype - argument type of the generated overloads.
//   otype - return type of the generated overloads.
//   ctype - type the computation is actually carried out in.
//   mfast - extra trailing argument forwarded verbatim to the builtins; the
//           instantiations in this file pass one of the __METAL_*_MATH__
//           macros here.
//
// Details that are easy to miss when reading the expansion:
//   * abs and fabs both forward to __metal_fabs.
//   * max/max3/median3/min/min3 forward to the same builtins as
//     fmax/fmax3/fmedian3/fmin/fmin3.
//   * fma, frexp and nextafter are called WITHOUT the `mfast` argument.
//   * frexp takes the exponent as `thread int&` and passes its address on.
//   * fdim does not call a builtin: `x - y` is evaluated on the itype
//     operands, converted to ctype, and then passed through select() with the
//     condition `t < 0 || x == y` (presumably yielding 0 when the condition
//     holds, per Metal's select(a, b, c) semantics - confirm).
//
// NOTE(review): every line here is wrapped in `| | ... |` table-cell markup,
// which looks like extraction residue. For the macro to work the markup has to
// be removed so that the continuation backslash is the last character on each
// line.
| | #define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ |
| | \ |
| | METAL_FUNC otype abs(itype x) { \ |
| | return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype acos(itype x) { \ |
| | return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype acosh(itype x) { \ |
| | return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype asin(itype x) { \ |
| | return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype asinh(itype x) { \ |
| | return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype atan(itype y_over_x) { \ |
| | return static_cast<otype>( \ |
| | __metal_atan(static_cast<ctype>(y_over_x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype atan2(itype y, itype x) { \ |
| | return static_cast<otype>( \ |
| | __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype atanh(itype x) { \ |
| | return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype ceil(itype x) { \ |
| | return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype cos(itype x) { \ |
| | return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype cosh(itype x) { \ |
| | return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype cospi(itype x) { \ |
| | return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype divide(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype exp(itype x) { \ |
| | return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype exp10(itype x) { \ |
| | return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype exp2(itype x) { \ |
| | return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fabs(itype x) { \ |
| | return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fdim(itype x, itype y) { \ |
| | ctype t = static_cast<ctype>(x - y); \ |
| | return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \ |
| | } \ |
| | METAL_FUNC otype floor(itype x) { \ |
| | return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fma(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fma( \ |
| | static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \ |
| | } \ |
| | METAL_FUNC otype fmax(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmax3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmedian3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype fmin(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmin3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype fmod(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype fract(itype x) { \ |
| | return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype frexp(itype x, thread int& exp) { \ |
| | return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \ |
| | } \ |
| | METAL_FUNC otype ldexp(itype x, int k) { \ |
| | return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \ |
| | } \ |
| | METAL_FUNC otype log(itype x) { \ |
| | return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype log10(itype x) { \ |
| | return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype log2(itype x) { \ |
| | return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype max(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype max3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmax3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype median3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmedian3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype min(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype min3(itype x, itype y, itype z) { \ |
| | return static_cast<otype>(__metal_fmin3( \ |
| | static_cast<ctype>(x), \ |
| | static_cast<ctype>(y), \ |
| | static_cast<ctype>(z), \ |
| | mfast)); \ |
| | } \ |
| | METAL_FUNC otype nextafter(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \ |
| | } \ |
| | METAL_FUNC otype pow(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype powr(itype x, itype y) { \ |
| | return static_cast<otype>( \ |
| | __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \ |
| | } \ |
| | METAL_FUNC otype rint(itype x) { \ |
| | return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype round(itype x) { \ |
| | return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype rsqrt(itype x) { \ |
| | return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype sin(itype x) { \ |
| | return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype sinh(itype x) { \ |
| | return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype sinpi(itype x) { \ |
| | return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype sqrt(itype x) { \ |
| | return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype tan(itype x) { \ |
| | return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype tanh(itype x) { \ |
| | return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype tanpi(itype x) { \ |
| | return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \ |
| | } \ |
| | METAL_FUNC otype trunc(itype x) { \ |
| | return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \ |
| | } |
| |
|
// bfloat16_t math overloads.
//
// The math macro is expanded three times with identical types (bfloat16_t in,
// bfloat16_t out, float as the compute type) and a different math-mode
// argument each time:
//   metal::          -> __METAL_MAYBE_FAST_MATH__
//   metal::fast::    -> __METAL_FAST_MATH__
//   metal::precise:: -> __METAL_PRECISE_MATH__
// NOTE(review): the meaning of the three __METAL_*_MATH__ macros is not
// visible here; presumably the "maybe" variant follows the compiler's
// fast-math setting - confirm against the Metal standard library headers.
| | namespace metal { |
| |
|
| | instantiate_metal_math_funcs( |
| | bfloat16_t, |
| | bfloat16_t, |
| | float, |
| | __METAL_MAYBE_FAST_MATH__); |
| |
|
// Same overload set, always instantiated with the fast-math argument.
| | namespace fast { |
| |
|
| | instantiate_metal_math_funcs( |
| | bfloat16_t, |
| | bfloat16_t, |
| | float, |
| | __METAL_FAST_MATH__); |
| |
|
// Closes namespace fast.
| | } |
| |
|
// Same overload set, always instantiated with the precise-math argument.
| | namespace precise { |
| |
|
| | instantiate_metal_math_funcs( |
| | bfloat16_t, |
| | bfloat16_t, |
| | float, |
| | __METAL_PRECISE_MATH__); |
| |
|
// Closes namespace precise.
| | } |
| |
|
// Closes namespace metal.
| | } |
| |
|
| | |
| | |
| | |
| |
|
// instantiate_metal_simd_comm_funcs(
//     itype, otype, ctype, itype_to_ctype, ctype_to_otype)
//
// Expands to METAL_FUNC overloads of the SIMD-group broadcast and shuffle
// functions for `itype`. Each generated function converts its data argument(s)
// with `itype_to_ctype(...)`, calls the matching `__metal_simd_*` builtin, and
// converts the result back with `ctype_to_otype(...)`.
//
// Parameters:
//   itype          - type of the data arguments of the generated overloads.
//   otype          - return type of the generated overloads.
//   ctype          - named carrier type; note that the macro body itself never
//                    spells `ctype`, it only uses the two converters.
//   itype_to_ctype - callable applied to every data argument before the
//                    builtin is invoked.
//   ctype_to_otype - callable applied to the builtin's result.
//
// Lane ids, deltas, masks and modulo values are taken as `ushort` and passed
// to the builtin unchanged. The three-argument forms of
// simd_shuffle_and_fill_down / simd_shuffle_and_fill_up supply
// `__metal_get_simdgroup_size(ushort())` as the modulo argument.
//
// NOTE(review): every line here is wrapped in `| | ... |` table-cell markup,
// which looks like extraction residue. For the macro to work the markup has to
// be removed so that the continuation backslash is the last character on each
// line.
| | #define instantiate_metal_simd_comm_funcs( \ |
| | itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ |
| | \ |
| | METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_and_fill_down( \ |
| | itype data, itype filling_data, ushort delta, ushort modulo) { \ |
| | return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ |
| | itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_and_fill_down( \ |
| | itype data, itype filling_data, ushort delta) { \ |
| | return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ |
| | itype_to_ctype(data), \ |
| | itype_to_ctype(filling_data), \ |
| | delta, \ |
| | __metal_get_simdgroup_size(ushort()))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_and_fill_up( \ |
| | itype data, itype filling_data, ushort delta, ushort modulo) { \ |
| | return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ |
| | itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_and_fill_up( \ |
| | itype data, itype filling_data, ushort delta) { \ |
| | return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ |
| | itype_to_ctype(data), \ |
| | itype_to_ctype(filling_data), \ |
| | delta, \ |
| | __metal_get_simdgroup_size(ushort()))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ |
| | return ctype_to_otype( \ |
| | __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ |
| | } |
| |
|
// instantiate_metal_simd_reduction_funcs(itype, otype, ctype)
//
// Expands to METAL_FUNC overloads of the SIMD-group reduction and prefix-scan
// functions for `itype`: simd_max, simd_min, simd_prefix_exclusive_product,
// simd_prefix_exclusive_sum, simd_prefix_inclusive_product,
// simd_prefix_inclusive_sum, simd_product, simd_sum and simd_xor.
//
// Each generated function static_casts `data` to `ctype`, calls the
// `__metal_simd_*` builtin of the same name, and static_casts the result to
// `otype`.
//
// Parameters:
//   itype - type of the `data` argument.
//   otype - return type of the generated overloads.
//   ctype - type the builtin is invoked with.
//
// NOTE(review): this file instantiates the macro with ctype = float, which
// means simd_xor expands to __metal_simd_xor(float). Confirm that the builtin
// accepts a floating-point operand.
// NOTE(review): every line here is wrapped in `| | ... |` table-cell markup,
// which looks like extraction residue. For the macro to work the markup has to
// be removed so that the continuation backslash is the last character on each
// line.
| | #define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ |
| | \ |
| | METAL_FUNC otype simd_max(itype data) { \ |
| | return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_min(itype data) { \ |
| | return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ |
| | return static_cast<otype>( \ |
| | __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ |
| | return static_cast<otype>( \ |
| | __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ |
| | return static_cast<otype>( \ |
| | __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ |
| | return static_cast<otype>( \ |
| | __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_product(itype data) { \ |
| | return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_sum(itype data) { \ |
| | return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \ |
| | } \ |
| | \ |
| | METAL_FUNC otype simd_xor(itype data) { \ |
| | return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ |
| | } |
| |
|
// bfloat16_t SIMD-group overloads, added to namespace metal.
//
// Broadcast/shuffle functions carry the value as uint16_t, converting with
// bfloat16_to_uint16 on the way in and uint16_to_bfloat16 on the way out.
// Both converters are defined outside this view; presumably they reinterpret
// the 16-bit pattern rather than convert numerically - confirm.
//
// Reduction and prefix-scan functions use float as the compute type, with
// static_cast in both directions.
| | namespace metal { |
| |
|
| | instantiate_metal_simd_comm_funcs( |
| | bfloat16_t, |
| | bfloat16_t, |
| | uint16_t, |
| | bfloat16_to_uint16, |
| | uint16_to_bfloat16); |
| | instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); |
| |
|
// Closes namespace metal.
| | } |