robtaylor-chipflow committed on
Commit
6c81401
·
0 Parent(s):

Add Metal fused_add_rms_norm + rms_norm kernels for vLLM

Browse files

New Metal kernels implementing both rms_norm and fused_add_rms_norm
with the exact signatures vLLM expects.

rms_norm(out, input, weight, epsilon):
out = (input / RMS(input)) * weight

fused_add_rms_norm(input, residual, weight, epsilon):
residual += input
input = (residual / RMS(residual)) * weight

The fused variant saves memory bandwidth by combining residual
addition and variance accumulation into a single pass. Every
transformer layer calls this operation.

Features:
- Supports fp16, bf16, fp32 dtypes
- Threadgroup-wide reduction using simd_sum + shared memory
- Float32 accumulation for numerical stability
- Handles strided input layouts (input_stride parameter)
- Comprehensive tests with property-based checks

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)

build.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "fused_rms_norm"
3
+ backends = ["metal"]
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h",
9
+ ]
10
+
11
+ [kernel.fused_rms_norm_metal]
12
+ backend = "metal"
13
+ src = [
14
+ "fused-rms-norm-metal/rms_norm.metal",
15
+ "fused-rms-norm-metal/rms_norm.mm",
16
+ "fused-rms-norm-metal/utils.metal",
17
+ ]
18
+ depends = ["torch"]
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for fused RMS normalization kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ inherit self;
15
+ path = ./.;
16
+ };
17
+ }
fused-rms-norm-metal/rms_norm.metal ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ #include "utils.metal"
3
+
4
+ using namespace metal;
5
+
6
+ // Maximum number of simdgroups per threadgroup for reduction.
7
+ // 512 threads / 32 threads per simdgroup = 16 simdgroups max.
8
+ constant constexpr int MAX_SIMDGROUPS = 16;
9
+
10
+ // Threadgroup-wide sum reduction using simdgroups.
11
+ // Each thread contributes a value; returns the total sum to all threads.
12
+ static inline float threadgroup_reduce_sum(
13
+ float value,
14
+ threadgroup float *shared [[threadgroup(0)]],
15
+ uint tid [[thread_position_in_threadgroup]],
16
+ uint tg_size [[threads_per_threadgroup]]) {
17
+
18
+ // Phase 1: reduce within each simdgroup.
19
+ float simd_val = simd_sum(value);
20
+
21
+ // Phase 2: first thread of each simdgroup writes to shared memory.
22
+ uint simdgroup_id = tid / 32;
23
+ uint lane_id = tid % 32;
24
+ if (lane_id == 0) {
25
+ shared[simdgroup_id] = simd_val;
26
+ }
27
+ threadgroup_barrier(mem_flags::mem_threadgroup);
28
+
29
+ // Phase 3: first simdgroup reduces across simdgroup partial sums.
30
+ uint num_simdgroups = (tg_size + 31) / 32;
31
+ float result = 0.0f;
32
+ if (tid < num_simdgroups) {
33
+ result = shared[tid];
34
+ }
35
+ result = simd_sum(result);
36
+
37
+ // Broadcast result to all threads via shared memory.
38
+ if (tid == 0) {
39
+ shared[0] = result;
40
+ }
41
+ threadgroup_barrier(mem_flags::mem_threadgroup);
42
+ return shared[0];
43
+ }
44
+
45
+ // RMS normalization kernel.
46
+ // out[token, i] = (input[token, i] / RMS(input[token, :])) * weight[i]
47
+ // where RMS = sqrt(mean(x^2) + epsilon)
48
+ //
49
+ // One threadgroup per token. Threads stride across hidden_size.
50
+ template <typename scalar_t>
51
+ kernel void rms_norm_kernel(
52
+ device scalar_t *out [[buffer(0)]],
53
+ const device scalar_t *input [[buffer(1)]],
54
+ const device scalar_t *weight [[buffer(2)]],
55
+ const device float &epsilon [[buffer(3)]],
56
+ const device int &num_tokens [[buffer(4)]],
57
+ const device int &hidden_size [[buffer(5)]],
58
+ const device int64_t &input_stride [[buffer(6)]],
59
+ threadgroup float *shared [[threadgroup(0)]],
60
+ uint token_idx [[threadgroup_position_in_grid]],
61
+ uint tid [[thread_position_in_threadgroup]],
62
+ uint tg_size [[threads_per_threadgroup]]) {
63
+
64
+ // Phase 1: accumulate sum of squares for variance.
65
+ float variance = 0.0f;
66
+ for (int i = tid; i < hidden_size; i += tg_size) {
67
+ float x = static_cast<float>(input[token_idx * input_stride + i]);
68
+ variance += x * x;
69
+ }
70
+
71
+ // Phase 2: reduce variance across threadgroup.
72
+ variance = threadgroup_reduce_sum(variance, shared, tid, tg_size);
73
+
74
+ // Phase 3: compute scaling factor.
75
+ float s_variance = rsqrt(variance / static_cast<float>(hidden_size) + epsilon);
76
+
77
+ // Phase 4: normalize and scale.
78
+ for (int i = tid; i < hidden_size; i += tg_size) {
79
+ float x = static_cast<float>(input[token_idx * input_stride + i]);
80
+ float w = static_cast<float>(weight[i]);
81
+ out[token_idx * hidden_size + i] = static_cast<scalar_t>(x * s_variance * w);
82
+ }
83
+ }
84
+
85
+ // Fused residual addition + RMS normalization kernel.
86
+ //
87
+ // After execution:
88
+ // residual[token, i] = old_residual[token, i] + old_input[token, i]
89
+ // input[token, i] = rms_norm(new_residual[token, :]) * weight[i]
90
+ //
91
+ // This fuses two memory passes into one: the residual addition and variance
92
+ // accumulation happen in the same loop, saving memory bandwidth.
93
+ template <typename scalar_t>
94
+ kernel void fused_add_rms_norm_kernel(
95
+ device scalar_t *input [[buffer(0)]],
96
+ device scalar_t *residual [[buffer(1)]],
97
+ const device scalar_t *weight [[buffer(2)]],
98
+ const device float &epsilon [[buffer(3)]],
99
+ const device int &num_tokens [[buffer(4)]],
100
+ const device int &hidden_size [[buffer(5)]],
101
+ const device int64_t &input_stride [[buffer(6)]],
102
+ threadgroup float *shared [[threadgroup(0)]],
103
+ uint token_idx [[threadgroup_position_in_grid]],
104
+ uint tid [[thread_position_in_threadgroup]],
105
+ uint tg_size [[threads_per_threadgroup]]) {
106
+
107
+ // Phase 1: add residual and accumulate variance in one pass.
108
+ float variance = 0.0f;
109
+ for (int i = tid; i < hidden_size; i += tg_size) {
110
+ float inp = static_cast<float>(input[token_idx * input_stride + i]);
111
+ float res = static_cast<float>(residual[token_idx * hidden_size + i]);
112
+ float z = inp + res;
113
+ variance += z * z;
114
+ residual[token_idx * hidden_size + i] = static_cast<scalar_t>(z);
115
+ }
116
+
117
+ // Phase 2: reduce variance across threadgroup.
118
+ variance = threadgroup_reduce_sum(variance, shared, tid, tg_size);
119
+
120
+ // Phase 3: compute scaling factor.
121
+ float s_variance = rsqrt(variance / static_cast<float>(hidden_size) + epsilon);
122
+
123
+ // Phase 4: read updated residual, normalize, and write to input.
124
+ for (int i = tid; i < hidden_size; i += tg_size) {
125
+ float x = static_cast<float>(residual[token_idx * hidden_size + i]);
126
+ float w = static_cast<float>(weight[i]);
127
+ input[token_idx * input_stride + i] = static_cast<scalar_t>(x * s_variance * w);
128
+ }
129
+ }
130
+
131
+ // Instantiate kernel variants.
132
+ #define instantiate_rms_norm(type) \
133
+ template [[host_name("rms_norm_" #type)]] [[kernel]] void \
134
+ rms_norm_kernel<type>( \
135
+ device type *out [[buffer(0)]], \
136
+ const device type *input [[buffer(1)]], \
137
+ const device type *weight [[buffer(2)]], \
138
+ const device float &epsilon [[buffer(3)]], \
139
+ const device int &num_tokens [[buffer(4)]], \
140
+ const device int &hidden_size [[buffer(5)]], \
141
+ const device int64_t &input_stride [[buffer(6)]], \
142
+ threadgroup float *shared [[threadgroup(0)]], \
143
+ uint token_idx [[threadgroup_position_in_grid]], \
144
+ uint tid [[thread_position_in_threadgroup]], \
145
+ uint tg_size [[threads_per_threadgroup]]);
146
+
147
+ #define instantiate_fused_add_rms_norm(type) \
148
+ template [[host_name("fused_add_rms_norm_" #type)]] [[kernel]] void \
149
+ fused_add_rms_norm_kernel<type>( \
150
+ device type *input [[buffer(0)]], \
151
+ device type *residual [[buffer(1)]], \
152
+ const device type *weight [[buffer(2)]], \
153
+ const device float &epsilon [[buffer(3)]], \
154
+ const device int &num_tokens [[buffer(4)]], \
155
+ const device int &hidden_size [[buffer(5)]], \
156
+ const device int64_t &input_stride [[buffer(6)]], \
157
+ threadgroup float *shared [[threadgroup(0)]], \
158
+ uint token_idx [[threadgroup_position_in_grid]], \
159
+ uint tid [[thread_position_in_threadgroup]], \
160
+ uint tg_size [[threads_per_threadgroup]]);
161
+
162
+ instantiate_rms_norm(float);
163
+ instantiate_rms_norm(half);
164
+ instantiate_rms_norm(bfloat16_t);
165
+
166
+ instantiate_fused_add_rms_norm(float);
167
+ instantiate_fused_add_rms_norm(half);
168
+ instantiate_fused_add_rms_norm(bfloat16_t);
fused-rms-norm-metal/rms_norm.mm ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/mps/MPSDevice.h>
2
+ #include <ATen/mps/MPSStream.h>
3
+ #include <torch/torch.h>
4
+
5
+ #import <Foundation/Foundation.h>
6
+ #import <Metal/Metal.h>
7
+ #include <dlfcn.h>
8
+ #include <string>
9
+
10
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
11
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
12
+ }
13
+
14
+ static std::string getModuleDirectory() {
15
+ Dl_info dl_info;
16
+ if (dladdr((void *)getModuleDirectory, &dl_info)) {
17
+ std::string path(dl_info.dli_fname);
18
+ size_t pos = path.find_last_of('/');
19
+ if (pos != std::string::npos) {
20
+ return path.substr(0, pos);
21
+ }
22
+ }
23
+ return ".";
24
+ }
25
+
26
+ // Helper to select kernel name by dtype.
27
+ static NSString *kernelNameForDtype(const char *prefix,
28
+ torch::ScalarType dtype) {
29
+ switch (dtype) {
30
+ case torch::kFloat:
31
+ return [NSString stringWithFormat:@"%s_float", prefix];
32
+ case torch::kHalf:
33
+ return [NSString stringWithFormat:@"%s_half", prefix];
34
+ case torch::kBFloat16:
35
+ return [NSString stringWithFormat:@"%s_bfloat16_t", prefix];
36
+ default:
37
+ TORCH_CHECK(false, "Unsupported dtype: ", dtype);
38
+ return nil;
39
+ }
40
+ }
41
+
42
+ // Helper to load metallib and create pipeline state.
43
+ static id<MTLComputePipelineState>
44
+ createPipeline(id<MTLDevice> device, NSString *kernName, NSError **error) {
45
+ std::string moduleDir = getModuleDirectory();
46
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
47
+
48
+ NSString *metallibPathStr =
49
+ [NSString stringWithUTF8String:metallibPath.c_str()];
50
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
51
+
52
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:error];
53
+ TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath,
54
+ *error ? [NSString stringWithFormat:@": %@",
55
+ (*error).localizedDescription]
56
+ .UTF8String
57
+ : "");
58
+
59
+ id<MTLFunction> fn = [lib newFunctionWithName:kernName];
60
+ TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
61
+
62
+ return [device newComputePipelineStateWithFunction:fn error:error];
63
+ }
64
+
65
+ // Dispatch a layernorm kernel with 7 buffer bindings + threadgroup memory.
66
+ static void dispatchNormKernel(id<MTLComputePipelineState> pso,
67
+ at::mps::MPSStream *stream,
68
+ id<MTLCommandBuffer> cmdBuf,
69
+ // Buffers 0-2: tensor buffers with offsets
70
+ id<MTLBuffer> buf0, NSUInteger off0,
71
+ id<MTLBuffer> buf1, NSUInteger off1,
72
+ id<MTLBuffer> buf2, NSUInteger off2,
73
+ // Scalars
74
+ float epsilon, int32_t num_tokens,
75
+ int32_t hidden_size, int64_t input_stride,
76
+ // Grid
77
+ uint32_t threadgroups,
78
+ uint32_t threads_per_tg) {
79
+ // Shared memory: MAX_SIMDGROUPS (16) floats for reduction.
80
+ const uint32_t shared_mem_size = 16 * sizeof(float);
81
+
82
+ dispatch_queue_t q = stream->queue();
83
+ dispatch_sync(q, ^{
84
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
85
+ TORCH_CHECK(enc, "Failed to create compute encoder");
86
+
87
+ [enc setComputePipelineState:pso];
88
+
89
+ [enc setBuffer:buf0 offset:off0 atIndex:0];
90
+ [enc setBuffer:buf1 offset:off1 atIndex:1];
91
+ [enc setBuffer:buf2 offset:off2 atIndex:2];
92
+
93
+ [enc setBytes:&epsilon length:sizeof(float) atIndex:3];
94
+ [enc setBytes:&num_tokens length:sizeof(int32_t) atIndex:4];
95
+ [enc setBytes:&hidden_size length:sizeof(int32_t) atIndex:5];
96
+ [enc setBytes:&input_stride length:sizeof(int64_t) atIndex:6];
97
+
98
+ [enc setThreadgroupMemoryLength:shared_mem_size atIndex:0];
99
+
100
+ MTLSize grid = MTLSizeMake(threadgroups, 1, 1);
101
+ MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
102
+ [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
103
+ [enc endEncoding];
104
+ });
105
+
106
+ stream->synchronize(at::mps::SyncType::COMMIT);
107
+ }
108
+
109
+ void rms_norm(torch::Tensor &out, torch::Tensor &input,
110
+ torch::Tensor &weight, double epsilon) {
111
+ TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
112
+ TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
113
+ TORCH_CHECK(input.device().is_mps(), "input must be on MPS device");
114
+
115
+ const int hidden_size = input.size(-1);
116
+ const int64_t input_stride = input.stride(-2);
117
+ const int num_tokens =
118
+ static_cast<int>(input.numel() / hidden_size);
119
+
120
+ @autoreleasepool {
121
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
122
+ TORCH_CHECK(stream, "Failed to get MPS stream");
123
+
124
+ id<MTLDevice> device = stream->device();
125
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
126
+ NSError *error = nil;
127
+
128
+ NSString *kernName = kernelNameForDtype("rms_norm", input.scalar_type());
129
+ id<MTLComputePipelineState> pso = createPipeline(device, kernName, &error);
130
+ TORCH_CHECK(pso, "Pipeline creation failed",
131
+ error ? [NSString stringWithFormat:@": %@",
132
+ error.localizedDescription]
133
+ .UTF8String
134
+ : "");
135
+
136
+ const uint32_t threads_per_tg =
137
+ std::min<uint32_t>(512, hidden_size);
138
+
139
+ dispatchNormKernel(
140
+ pso, stream, cmdBuf,
141
+ getMTLBufferStorage(out),
142
+ out.storage_offset() * out.element_size(),
143
+ getMTLBufferStorage(input),
144
+ input.storage_offset() * input.element_size(),
145
+ getMTLBufferStorage(weight),
146
+ weight.storage_offset() * weight.element_size(),
147
+ static_cast<float>(epsilon), static_cast<int32_t>(num_tokens),
148
+ static_cast<int32_t>(hidden_size), input_stride,
149
+ static_cast<uint32_t>(num_tokens), threads_per_tg);
150
+ }
151
+ }
152
+
153
+ void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
154
+ torch::Tensor &weight, double epsilon) {
155
+ TORCH_CHECK(residual.is_contiguous(), "residual must be contiguous");
156
+ TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
157
+ TORCH_CHECK(input.device().is_mps(), "input must be on MPS device");
158
+ TORCH_CHECK(input.scalar_type() == residual.scalar_type(),
159
+ "input and residual must have same dtype");
160
+
161
+ const int hidden_size = input.size(-1);
162
+ const int64_t input_stride = input.stride(-2);
163
+ const int num_tokens =
164
+ static_cast<int>(input.numel() / hidden_size);
165
+
166
+ @autoreleasepool {
167
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
168
+ TORCH_CHECK(stream, "Failed to get MPS stream");
169
+
170
+ id<MTLDevice> device = stream->device();
171
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
172
+ NSError *error = nil;
173
+
174
+ NSString *kernName =
175
+ kernelNameForDtype("fused_add_rms_norm", input.scalar_type());
176
+ id<MTLComputePipelineState> pso = createPipeline(device, kernName, &error);
177
+ TORCH_CHECK(pso, "Pipeline creation failed",
178
+ error ? [NSString stringWithFormat:@": %@",
179
+ error.localizedDescription]
180
+ .UTF8String
181
+ : "");
182
+
183
+ const uint32_t threads_per_tg =
184
+ std::min<uint32_t>(512, hidden_size);
185
+
186
+ dispatchNormKernel(
187
+ pso, stream, cmdBuf,
188
+ getMTLBufferStorage(input),
189
+ input.storage_offset() * input.element_size(),
190
+ getMTLBufferStorage(residual),
191
+ residual.storage_offset() * residual.element_size(),
192
+ getMTLBufferStorage(weight),
193
+ weight.storage_offset() * weight.element_size(),
194
+ static_cast<float>(epsilon), static_cast<int32_t>(num_tokens),
195
+ static_cast<int32_t>(hidden_size), input_stride,
196
+ static_cast<uint32_t>(num_tokens), threads_per_tg);
197
+ }
198
+ }
fused-rms-norm-metal/utils.metal ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ #if defined(__HAVE_BFLOAT__)
5
+
6
+ typedef bfloat bfloat16_t;
7
+
8
+ #else
9
+
10
+ constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
11
+ if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
12
+ _fp_encoding_traits<float>::inf_mask) {
13
+ return uint16_t(as_type<uint32_t>(0x7FC0));
14
+ }
15
+ uint32_t float_bits = as_type<uint32_t>(x);
16
+ float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
17
+ return float_bits >> 16;
18
+ }
19
+
20
+ constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
21
+ return as_type<float>((uint32_t)x << 16);
22
+ }
23
+
24
+ struct _MLX_BFloat16;
25
+
26
+ template <typename T>
27
+ static constexpr constant bool can_convert_to_bfloat =
28
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
29
+
30
+ template <typename T>
31
+ static constexpr constant bool can_convert_from_bfloat =
32
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
33
+
34
+ struct _MLX_BFloat16 {
35
+ uint16_t bits_;
36
+ _MLX_BFloat16() thread = default;
37
+ _MLX_BFloat16() threadgroup = default;
38
+ _MLX_BFloat16() device = default;
39
+ _MLX_BFloat16() constant = default;
40
+
41
+ struct bits_to_bfloat_struct {};
42
+ static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
43
+ return bits_to_bfloat_struct();
44
+ }
45
+ constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
46
+ : bits_(bits) {}
47
+
48
+ template <typename T,
49
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
50
+ constexpr METAL_FUNC _MLX_BFloat16(T x) thread
51
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
52
+
53
+ template <typename T,
54
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
55
+ constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
56
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
57
+
58
+ template <typename T,
59
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
60
+ constexpr METAL_FUNC _MLX_BFloat16(T x) device
61
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
62
+
63
+ template <typename T,
64
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
65
+ constexpr METAL_FUNC _MLX_BFloat16(T x) constant
66
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
67
+
68
+ template <typename T,
69
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
70
+ constexpr METAL_FUNC operator T() const thread {
71
+ return static_cast<T>(bfloat_bits_to_float(bits_));
72
+ }
73
+
74
+ template <typename T,
75
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
76
+ constexpr METAL_FUNC operator T() const threadgroup {
77
+ return static_cast<T>(bfloat_bits_to_float(bits_));
78
+ }
79
+
80
+ template <typename T,
81
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
82
+ constexpr METAL_FUNC operator T() const device {
83
+ return static_cast<T>(bfloat_bits_to_float(bits_));
84
+ }
85
+
86
+ template <typename T,
87
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
88
+ constexpr METAL_FUNC operator T() constant {
89
+ return static_cast<T>(bfloat_bits_to_float(bits_));
90
+ }
91
+ };
92
+
93
+ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
94
+ return -static_cast<float>(x);
95
+ }
96
+
97
+ #define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
98
+ constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
99
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
100
+ }
101
+
102
+ #define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
103
+ constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
104
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
105
+ } \
106
+ constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
107
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
108
+ }
109
+
110
+ #define bfloat_binop(_op_, _operator_) \
111
+ bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
112
+ _MLX_BFloat16, float); \
113
+ bfloat_binop_helper(_op_, _operator_, float, float, float); \
114
+ bfloat_binop_helper(_op_, _operator_, float, half, float); \
115
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
116
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
117
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
118
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
119
+
120
+ bfloat_binop(+, operator+);
121
+ bfloat_binop(-, operator-);
122
+ bfloat_binop(*, operator*);
123
+ bfloat_binop(/, operator/);
124
+
125
+ #undef bfloat_binop_base
126
+ #undef bfloat_binop_helper
127
+ #undef bfloat_binop
128
+
129
+ typedef struct _MLX_BFloat16 bfloat16_t;
130
+
131
+ #endif
tests/test_rms_norm.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Metal RMS normalization kernels.
2
+
3
+ Validates rms_norm and fused_add_rms_norm against PyTorch reference
4
+ implementations across dtypes and hidden sizes.
5
+ """
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ import fused_rms_norm as ops
11
+
12
+
13
+ def _is_mps_available() -> bool:
14
+ return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
15
+
16
+
17
+ if _is_mps_available():
18
+ DEVICES = ["mps"]
19
+ else:
20
+ DEVICES = [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
21
+
22
+ DTYPES = [torch.float32, torch.float16, torch.bfloat16]
23
+ HIDDEN_SIZES = [128, 768, 2048, 4096]
24
+ NUM_TOKENS = [1, 7, 32]
25
+ EPSILON = 1e-6
26
+
27
+
28
+ def _ref_rms_norm(
29
+ input: torch.Tensor,
30
+ weight: torch.Tensor,
31
+ epsilon: float,
32
+ ) -> torch.Tensor:
33
+ """Pure-PyTorch reference for RMS normalization."""
34
+ variance = input.float().pow(2).mean(dim=-1, keepdim=True)
35
+ inv_rms = torch.rsqrt(variance + epsilon)
36
+ return (input.float() * inv_rms * weight.float()).to(input.dtype)
37
+
38
+
39
+ def _ref_fused_add_rms_norm(
40
+ input: torch.Tensor,
41
+ residual: torch.Tensor,
42
+ weight: torch.Tensor,
43
+ epsilon: float,
44
+ ) -> tuple[torch.Tensor, torch.Tensor]:
45
+ """Pure-PyTorch reference for fused residual add + RMS norm.
46
+
47
+ Returns (normalized_output, updated_residual).
48
+ """
49
+ new_residual = residual.float() + input.float()
50
+ variance = new_residual.pow(2).mean(dim=-1, keepdim=True)
51
+ inv_rms = torch.rsqrt(variance + epsilon)
52
+ normalized = (new_residual * inv_rms * weight.float()).to(input.dtype)
53
+ return normalized, new_residual.to(residual.dtype)
54
+
55
+
56
+ @pytest.mark.parametrize("device", DEVICES)
57
+ @pytest.mark.parametrize("dtype", DTYPES)
58
+ @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
59
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
60
+ @torch.inference_mode()
61
+ def test_rms_norm(
62
+ device: str,
63
+ dtype: torch.dtype,
64
+ hidden_size: int,
65
+ num_tokens: int,
66
+ ) -> None:
67
+ torch.manual_seed(42)
68
+ input = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
69
+ weight = torch.randn(hidden_size, dtype=dtype, device=device)
70
+ out = torch.empty_like(input)
71
+
72
+ # Run kernel.
73
+ ops.rms_norm(out, input, weight, EPSILON)
74
+
75
+ # Run reference on CPU.
76
+ ref = _ref_rms_norm(input.cpu(), weight.cpu(), EPSILON)
77
+
78
+ # Compare.
79
+ if dtype == torch.float32:
80
+ atol, rtol = 1e-5, 1e-5
81
+ elif dtype == torch.float16:
82
+ atol, rtol = 1e-3, 1e-3
83
+ else: # bfloat16
84
+ atol, rtol = 2e-2, 2e-2
85
+
86
+ torch.testing.assert_close(out.cpu(), ref, atol=atol, rtol=rtol)
87
+
88
+
89
+ @pytest.mark.parametrize("device", DEVICES)
90
+ @pytest.mark.parametrize("dtype", DTYPES)
91
+ @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
92
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
93
+ @torch.inference_mode()
94
+ def test_fused_add_rms_norm(
95
+ device: str,
96
+ dtype: torch.dtype,
97
+ hidden_size: int,
98
+ num_tokens: int,
99
+ ) -> None:
100
+ torch.manual_seed(42)
101
+ input = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
102
+ residual = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
103
+ weight = torch.randn(hidden_size, dtype=dtype, device=device)
104
+
105
+ # Compute reference on CPU BEFORE running kernel (kernel modifies in-place).
106
+ ref_normalized, ref_residual = _ref_fused_add_rms_norm(
107
+ input.cpu(), residual.cpu(), weight.cpu(), EPSILON
108
+ )
109
+
110
+ # Run kernel (modifies input and residual in-place).
111
+ ops.fused_add_rms_norm(input, residual, weight, EPSILON)
112
+
113
+ # Compare.
114
+ if dtype == torch.float32:
115
+ atol, rtol = 1e-5, 1e-5
116
+ elif dtype == torch.float16:
117
+ atol, rtol = 1e-3, 1e-3
118
+ else: # bfloat16
119
+ atol, rtol = 2e-2, 2e-2
120
+
121
+ torch.testing.assert_close(
122
+ residual.cpu(), ref_residual, atol=atol, rtol=rtol
123
+ )
124
+ torch.testing.assert_close(
125
+ input.cpu(), ref_normalized, atol=atol, rtol=rtol
126
+ )
127
+
128
+
129
+ @pytest.mark.parametrize("device", DEVICES)
130
+ @pytest.mark.parametrize("dtype", [torch.float32])
131
+ @torch.inference_mode()
132
+ def test_rms_norm_weight_scaling(
133
+ device: str,
134
+ dtype: torch.dtype,
135
+ ) -> None:
136
+ """Verify that weight=1 gives pure RMS normalization."""
137
+ hidden_size = 256
138
+ num_tokens = 4
139
+ input = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
140
+ weight_ones = torch.ones(hidden_size, dtype=dtype, device=device)
141
+ weight_twos = 2.0 * torch.ones(hidden_size, dtype=dtype, device=device)
142
+
143
+ out_ones = torch.empty_like(input)
144
+ out_twos = torch.empty_like(input)
145
+
146
+ ops.rms_norm(out_ones, input, weight_ones, EPSILON)
147
+ ops.rms_norm(out_twos, input, weight_twos, EPSILON)
148
+
149
+ # weight=2 should produce exactly 2x the weight=1 result.
150
+ torch.testing.assert_close(
151
+ out_twos.cpu(), 2.0 * out_ones.cpu(), atol=1e-5, rtol=1e-5
152
+ )
153
+
154
+
155
+ @pytest.mark.parametrize("device", DEVICES)
156
+ @pytest.mark.parametrize("dtype", [torch.float32])
157
+ @torch.inference_mode()
158
+ def test_fused_add_rms_norm_residual_accumulation(
159
+ device: str,
160
+ dtype: torch.dtype,
161
+ ) -> None:
162
+ """Verify residual is correctly accumulated (residual += input)."""
163
+ hidden_size = 128
164
+ num_tokens = 2
165
+ input = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
166
+ residual = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
167
+ weight = torch.ones(hidden_size, dtype=dtype, device=device)
168
+
169
+ expected_residual = (residual + input).cpu()
170
+
171
+ ops.fused_add_rms_norm(input, residual, weight, EPSILON)
172
+
173
+ torch.testing.assert_close(
174
+ residual.cpu(), expected_residual, atol=1e-5, rtol=1e-5
175
+ )
torch-ext/fused_rms_norm/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import fused_add_rms_norm, rms_norm
2
+ from ._ops import ops
3
+
4
+ __all__ = [
5
+ "fused_add_rms_norm",
6
+ "ops",
7
+ "rms_norm",
8
+ ]
torch-ext/fused_rms_norm/_custom_ops.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ._ops import ops
4
+
5
+
6
+ def rms_norm(
7
+ out: torch.Tensor,
8
+ input: torch.Tensor,
9
+ weight: torch.Tensor,
10
+ epsilon: float,
11
+ ) -> None:
12
+ ops.rms_norm(out, input, weight, epsilon)
13
+
14
+
15
+ def fused_add_rms_norm(
16
+ input: torch.Tensor,
17
+ residual: torch.Tensor,
18
+ weight: torch.Tensor,
19
+ epsilon: float,
20
+ ) -> None:
21
+ ops.fused_add_rms_norm(input, residual, weight, epsilon)
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def(
8
+ "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> ()");
9
+ #if defined(METAL_KERNEL)
10
+ ops.impl("rms_norm", torch::kMPS, rms_norm);
11
+ #endif
12
+
13
+ ops.def(
14
+ "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, float epsilon) -> ()");
15
+ #if defined(METAL_KERNEL)
16
+ ops.impl("fused_add_rms_norm", torch::kMPS, fused_add_rms_norm);
17
+ #endif
18
+ }
19
+
20
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ void rms_norm(torch::Tensor& out, torch::Tensor& input,
6
+ torch::Tensor& weight, double epsilon);
7
+
8
+ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
9
+ torch::Tensor& weight, double epsilon);