File size: 5,544 Bytes
d1d4335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#pragma once

#include <c10/metal/utils.h>
#include <metal_compute>

namespace c10 {
namespace metal {

// Number of threads in a simdgroup; the reduction helpers in this header
// hardcode this value (see the two-stage comment further below).
constant constexpr ushort simdgroup_size = 32;

// Sum a value across the simdgroup. This overload forwards straight to the
// hardware intrinsic and is disabled for `long`, which Metal cannot reduce
// natively (a shuffle-based specialization handles that case).
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T x) {
  const auto total = ::metal::simd_sum(x);
  return total;
}

// Multiply a value across the simdgroup. Forwards to the hardware intrinsic
// and is disabled for `long`, which Metal cannot reduce natively (a
// shuffle-based specialization handles that case).
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T x) {
  const auto total = ::metal::simd_product(x);
  return total;
}

// Metal does not support SIMD reductions over 64-bit types, but they can be
// implemented using simd_shuffle_down, which yields the result in
// log2(simdgroup_size) iterations. Use the fill variant, as shuffle down
// returns garbage when an inactive thread is referenced (on M1/M2; it works
// fine on M4), and broadcast the result to all threads at the end.
// Implementation heavily borrows from
// https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L16
// 64-bit simd_sum: butterfly reduction halving the stride each step, so
// after log2(simdgroup_size) iterations lane 0 holds the sum of all lanes.
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
  for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
    // The long is moved through the shuffle as an int2 bit pattern, since
    // Metal's simd intrinsics lack 64-bit integer overloads (see comment
    // above); the fill value of int2(0) is the additive identity.
    val += as_type<T>(
        ::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
  }
  // Only lane 0 is guaranteed to hold the complete sum; broadcast it so
  // every lane returns the same value.
  return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
}

// 64-bit simd_prod: same butterfly/shuffle scheme as the 64-bit simd_sum
// above, accumulating with multiplication instead of addition.
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
  for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
    // NOTE(review): the fill value 0 is not the multiplicative identity, so
    // lanes that shuffle past the end get zeroed; correctness relies on the
    // broadcast below taking lane 0, whose reduction tree reads only
    // in-range lanes — this mirrors the MLX implementation cited above.
    val *= as_type<T>(
        ::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
  }
  return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
}

// The algorithms below are written with the hardcoded assumption that the
// simdgroup size is 32 and the maximum threadgroup size is 1024, i.e. a
// reduction can always be done in at most two stages.
// Two-stage threadgroup sum:
//   1. every simdgroup reduces its lanes; lane 0 of each simdgroup writes
//      the partial sum into `data`
//   2. if there was more than one simdgroup, the first ceil(size/32)
//      threads reduce those partials into data[0]
// `data` must hold at least ceil(size/32) elements; `idx` is the thread's
// index within the threadgroup and `size` the threadgroup size.
template <typename T>
opmath_t<T> threadgroup_sum(

    threadgroup opmath_t<T>* data,

    T val,

    unsigned idx,

    unsigned size) {
  // Stage 1: reduce within each simdgroup, in the wider opmath_t precision.
  auto rc = simd_sum(static_cast<opmath_t<T>>(val));
  if (idx % simdgroup_size == 0) {
    // One partial result per simdgroup.
    data[idx / simdgroup_size] = rc;
  }
  if (size > simdgroup_size) {
    // Make stage-1 partials visible before stage 2 reads them.
    ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
    if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
      // Stage 2: reduce the ceil(size/32) partials; only thread 0 stores.
      auto rc1 = simd_sum(data[idx]);
      if (idx == 0) {
        data[0] = rc1;
      }
    }
  }
  // Ensure data[0] is fully written before every thread reads it back.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  return data[0];
}

// Two-stage threadgroup product; identical structure to threadgroup_sum
// above, accumulating with multiplication instead of addition.
// `data` must hold at least ceil(size/32) elements; `idx` is the thread's
// index within the threadgroup and `size` the threadgroup size.
template <typename T>
opmath_t<T> threadgroup_prod(

    threadgroup opmath_t<T>* data,

    T val,

    unsigned idx,

    unsigned size) {
  // Stage 1: reduce within each simdgroup, in the wider opmath_t precision.
  auto rc = simd_prod(static_cast<opmath_t<T>>(val));
  if (idx % simdgroup_size == 0) {
    // One partial result per simdgroup.
    data[idx / simdgroup_size] = rc;
  }
  if (size > simdgroup_size) {
    // Make stage-1 partials visible before stage 2 reads them.
    ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
    if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
      // Stage 2: reduce the ceil(size/32) partials; only thread 0 stores.
      auto rc1 = simd_prod(data[idx]);
      if (idx == 0) {
        data[0] = rc1;
      }
    }
  }
  // Ensure data[0] is fully written before every thread reads it back.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  return data[0];
}

// Serially fold `size` raw elements of threadgroup memory into Welford
// statistics, returned as float3(mean, m2, weight).
template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
  // Wait until every thread has published its element into `data`.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  float mean = data[0];
  float m2 = 0;
  for (unsigned i = 1; i < size; ++i) {
    const float delta = data[i] - mean;
    mean += delta / (i + 1);
    m2 += delta * (data[i] - mean);
  }
  return float3(mean, m2, size);
}

// Merge two Welford accumulators. Each vec3-typed argument is a tuple of
// mean (.x), m2 (.y) and weight (.z); the merged tuple is returned.
template <typename T>
float3 welford_combine(T a, T b) {
  const float delta = b.x - a.x;
  const float new_weight = a.z + b.z;
  // Guard against 0/0 when both accumulators are empty.
  const auto frac_b = new_weight != 0 ? b.z / new_weight : 0.0;
  const float merged_mean = a.x + delta * frac_b;
  const float merged_m2 = a.y + b.y + delta * delta * a.z * frac_b;
  return float3(merged_mean, merged_m2, new_weight);
}

// Combine the per-thread Welford partials stored in threadgroup memory into
// a single (mean, m2, weight) triple.
template <typename T>
float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
  // All partials must be written before this serial pass reads them.
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  float3 acc = data[0];
  for (unsigned i = 1; i < size; ++i) {
    acc = welford_combine(acc, data[i]);
  }
  return acc;
}

// Serial maximum over `size` elements of threadgroup memory.
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  T best = data[0];
  for (unsigned i = 1; i < size; ++i) {
    best = ::c10::metal::max(best, data[i]);
  }
  return best;
}

// Serial minimum over `size` elements of threadgroup memory.
template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  T best = data[0];
  for (unsigned i = 1; i < size; ++i) {
    best = ::c10::metal::min(best, data[i]);
  }
  return best;
}

// Index of the largest element in threadgroup memory; ties resolve to the
// earliest index because the comparison is strict.
template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  int best = 0;
  for (unsigned i = 1; i < size; ++i) {
    if (data[i] > data[best]) {
      best = i;
    }
  }
  return best;
}

// Index of the smallest element in threadgroup memory; ties resolve to the
// earliest index because the comparison is strict.
template <typename T>
int threadgroup_argmin(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  int best = 0;
  for (unsigned i = 1; i < size; ++i) {
    if (data[i] < data[best]) {
      best = i;
    }
  }
  return best;
}

} // namespace metal
} // namespace c10