// Copyright © 2024 Apple Inc.

#pragma once

// NOTE(review): the header operand was missing from this directive in the
// reviewed text; <metal_stdlib> is the standard Metal header — confirm it is
// the one originally intended.
#include <metal_stdlib>

/// Convert a linear element index into offsets for two strided operands.
///
/// The index is decomposed into per-dimension coordinates, starting from the
/// last dimension (`ndim - 1`) and moving toward dimension 0. For each
/// dimension the coordinate is `elem % shape[i]`; it is multiplied by each
/// operand's stride for that dimension and accumulated into that operand's
/// offset.
///
/// @param elem       Linear index to decompose (modified locally only).
/// @param shape      Extent of each dimension; at least `ndim` entries are
///                   read. Entries that are visited must be non-zero, since
///                   they are used as divisors.
/// @param a_strides  Per-dimension strides of the first operand.
/// @param b_strides  Per-dimension strides of the second operand.
/// @param ndim       Number of dimensions; `ndim <= 0` yields `(0, 0)`.
/// @return `ulong2(loc_a, loc_b)`, the accumulated offsets, in the same units
///         as the strides.
///
/// NOTE(review): the name suggests broadcast dimensions are expressed by the
/// caller as stride 0 — confirm against call sites.
METAL_FUNC ulong2 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  ulong loc_a{0};
  ulong loc_b{0};
  // Once `elem` reaches 0 every remaining coordinate is 0 and contributes
  // nothing, so the loop can stop early.
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int pos_in_dim = (elem % shape[i]);
    elem /= shape[i];
    // int * int64_t is evaluated in 64 bits before being added to the offset.
    loc_a += pos_in_dim * a_strides[i];
    loc_b += pos_in_dim * b_strides[i];
  }
  return ulong2(loc_a, loc_b);
}

/// Convert a linear element index into offsets for three strided operands.
///
/// Same decomposition as the two-operand overload: coordinates are peeled off
/// from the last dimension toward dimension 0 with `elem % shape[i]`, and each
/// coordinate is multiplied by the matching stride of every operand.
///
/// @param elem       Linear index to decompose (modified locally only).
/// @param shape      Extent of each dimension; at least `ndim` entries are
///                   read. Entries that are visited must be non-zero, since
///                   they are used as divisors.
/// @param a_strides  Per-dimension strides of the first operand.
/// @param b_strides  Per-dimension strides of the second operand.
/// @param c_strides  Per-dimension strides of the third operand.
/// @param ndim       Number of dimensions; `ndim <= 0` yields `(0, 0, 0)`.
/// @return `ulong3(loc_a, loc_b, loc_c)`, the accumulated offsets, in the same
///         units as the strides.
METAL_FUNC ulong3 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  ulong loc_a{0};
  ulong loc_b{0};
  ulong loc_c{0};
  // Once `elem` reaches 0 every remaining coordinate is 0 and contributes
  // nothing, so the loop can stop early.
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    int pos_in_dim = (elem % shape[i]);
    elem /= shape[i];
    // int * int64_t is evaluated in 64 bits before being added to the offset.
    loc_a += pos_in_dim * a_strides[i];
    loc_b += pos_in_dim * b_strides[i];
    loc_c += pos_in_dim * c_strides[i];
  }
  return ulong3(loc_a, loc_b, loc_c);
}