#pragma once
#include <c10/metal/utils.h>
#include <metal_compute>
namespace c10 {
namespace metal {
constant constexpr ushort simdgroup_size = 32;
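// For every type except long, simd_sum and simd_prod forward directly to the
// hardware simdgroup reduction intrinsics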
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
return ::metal::simd_sum(val);
}
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T val) {
return ::metal::simd_product(val);
}
// Metal does not support SIMD reductions over 64-bit types, but one can be
// implemented using simd_shuffle_down, which yields the result in
// log2(simdgroup_size) iterations. Use the fill variant, as shuffle_down
// returns garbage when an inactive thread is referenced (on M1/M2; it works
// fine on M4), and broadcast the result to all threads at the end.
// The implementation heavily borrows from
// https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L16
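// For example, with simdgroup_size == 32 the loop below runs with offsets
// 16, 8, 4, 2, 1: after the first iteration lane i holds the sum of lanes i
// and i + 16, after the second the sum of lanes i, i + 8, i + 16, i + 24,
// and so on until lane 0 holds the sum of all 32 lanes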
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val += as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
}
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val *= as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
}
// The algorithms below are written with the hardcoded assumption that the
// simdgroup size is 32 and the threadgroup size is at most 1024, i.e. that a
// threadgroup reduction can always be done in at most two stages
template <typename T>
opmath_t<T> threadgroup_sum(
threadgroup opmath_t<T>* data,
T val,
unsigned idx,
unsigned size) {
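  // Stage 1: reduce within each simdgroup, then have the first lane of every
  // simdgroup publish its partial sum to the threadgroup buffer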
auto rc = simd_sum(static_cast<opmath_t<T>>(val));
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
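  // Stage 2: if more than one simdgroup contributed, reduce the per-simdgroup
  // partials; since size <= 1024 there are at most 32 of them, so they all
  // fit in the first simdgroup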
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_sum(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
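// A minimal usage sketch (hypothetical kernel; the kernel name, buffer
// bindings, and single-threadgroup dispatch are illustrative assumptions;
// `shared` needs one slot per simdgroup, i.e. up to 1024 / 32 entries):
//
//   kernel void sum_all(
//       device const float* input [[buffer(0)]],
//       device float* output [[buffer(1)]],
//       uint idx [[thread_position_in_threadgroup]],
//       uint size [[threads_per_threadgroup]]) {
//     threadgroup float shared[32];
//     auto total = c10::metal::threadgroup_sum(shared, input[idx], idx, size);
//     if (idx == 0) {
//       output[0] = total;
//     }
//   }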
template <typename T>
opmath_t<T> threadgroup_prod(
threadgroup opmath_t<T>* data,
T val,
unsigned idx,
unsigned size) {
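  // Same two-stage scheme as threadgroup_sum above, with products instead of
  // sums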
auto rc = simd_prod(static_cast<opmath_t<T>>(val));
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_prod(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
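// Single-threaded Welford reduction over a threadgroup buffer: maintains a
// running mean m and a running sum of squared deviations m2, and returns
// (mean, m2, count); the (biased) variance can be recovered as m2 / count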
template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float m = data[0];
float m2 = 0;
for (unsigned idx = 1; idx < size; ++idx) {
float delta = data[idx] - m;
m += delta / (idx + 1);
m2 += delta * (data[idx] - m);
}
return float3(m, m2, size);
}
// Each vec3-typed argument is a tuple of (mean, m2, weight)
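// Combines two partial Welford results in one step (the parallel variant due
// to Chan et al.):
//   delta = mean_b - mean_a
//   mean  = mean_a + delta * w_b / (w_a + w_b)
//   m2    = m2_a + m2_b + delta^2 * w_a * w_b / (w_a + w_b)
// For example, combining (1, 0, 1) with (3, 0, 1) yields (2, 2, 2), matching
// the mean and sum of squared deviations of {1, 3}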
template <typename T>
float3 welford_combine(T a, T b) {
float delta = b.x - a.x;
float new_weight = a.z + b.z;
auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0;
return float3(
a.x + delta * w2_over_w,
a.y + b.y + delta * delta * a.z * w2_over_w,
new_weight);
}
template <typename T>
float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float3 rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = welford_combine(rc, data[idx]);
}
return rc;
}
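// The reductions below are plain serial scans over the threadgroup buffer:
// after the barrier, every thread walks the full buffer and arrives at the
// same result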
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::max(rc, data[idx]);
}
return rc;
}
template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::min(rc, data[idx]);
}
return rc;
}
template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] > data[rc]) {
rc = idx;
}
}
return rc;
}
template <typename T>
int threadgroup_argmin(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] < data[rc]) {
rc = idx;
}
}
return rc;
}
} // namespace metal
} // namespace c10