| | |
| |
|
| | #pragma once |
| |
|
| | #include <metal_stdlib> |
| |
|
// Converts a flat element index into memory offsets for two strided operands
// that are walked over one common logical shape (presumably the inputs of a
// broadcasting element-wise op — confirm against callers).
//
// The flat index is decomposed innermost dimension first (dimension ndim - 1
// varies fastest). Each per-dimension coordinate is multiplied by that
// operand's stride, and the products are summed.
//
//   elem       flat index of the element within `shape`
//   shape      extent of each of the `ndim` dimensions
//   a_strides  per-dimension stride of operand A
//   b_strides  per-dimension stride of operand B
//   ndim       number of entries read from shape / a_strides / b_strides
//
// Returns (offset into A, offset into B) packed as a ulong2.
//
// Offsets accumulate in unsigned 64-bit arithmetic, so a negative stride
// contributes modulo 2^64. NOTE(review): confirm this is what callers expect.
METAL_FUNC ulong2 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  uint remaining = elem;
  ulong offset_a = 0;
  ulong offset_b = 0;

  int dim = ndim - 1;
  // Once the remaining index is 0, every outer coordinate is 0 and adds
  // nothing, so the walk can stop early.
  while (dim >= 0 && remaining != 0) {
    // shape[] is int while the index is uint. The usual arithmetic
    // conversions perform % and / in unsigned arithmetic; that is spelled
    // out explicitly here.
    const uint extent = uint(shape[dim]);
    const int coord = int(remaining % extent);
    remaining /= extent;

    offset_a += coord * a_strides[dim];
    offset_b += coord * b_strides[dim];
    --dim;
  }

  return ulong2(offset_a, offset_b);
}
| |
|
// Three-operand variant: converts a flat element index into memory offsets for
// three strided operands walked over one common logical shape (presumably the
// operands of a broadcasting ternary element-wise op — confirm against callers).
//
// The flat index is decomposed innermost dimension first (dimension ndim - 1
// varies fastest). Each per-dimension coordinate is multiplied by each
// operand's stride for that dimension, and the products are summed.
//
//   elem       flat index of the element within `shape`
//   shape      extent of each of the `ndim` dimensions
//   a_strides  per-dimension stride of operand A
//   b_strides  per-dimension stride of operand B
//   c_strides  per-dimension stride of operand C
//   ndim       number of entries read from shape and each stride array
//
// Returns (offset into A, offset into B, offset into C) packed as a ulong3.
//
// Offsets accumulate in unsigned 64-bit arithmetic, so a negative stride
// contributes modulo 2^64. NOTE(review): confirm this is what callers expect.
METAL_FUNC ulong3 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  uint remaining = elem;
  ulong offset_a = 0;
  ulong offset_b = 0;
  ulong offset_c = 0;

  int dim = ndim - 1;
  // Once the remaining index is 0, every outer coordinate is 0 and adds
  // nothing, so the walk can stop early.
  while (dim >= 0 && remaining != 0) {
    // shape[] is int while the index is uint. The usual arithmetic
    // conversions perform % and / in unsigned arithmetic; that is spelled
    // out explicitly here.
    const uint extent = uint(shape[dim]);
    const int coord = int(remaining % extent);
    remaining /= extent;

    offset_a += coord * a_strides[dim];
    offset_b += coord * b_strides[dim];
    offset_c += coord * c_strides[dim];
    --dim;
  }

  return ulong3(offset_a, offset_b, offset_c);
}