Kernels
wyldecat Claude Opus 4.6 (1M context) commited on
Commit
906e125
·
1 Parent(s): f06406d

refactor: remove Triton kernels, add hidden_clamp to unscored ops

Browse files

- Remove all Triton kernel code (fwd/bwd kernels, autotune configs,
triton import) — replaced by CUDA kernels in grouped_poly_norm.cu
- Add hidden_clamp parameter to unscored C++ ops (forward/backward)
so both scored and unscored paths support clamping
- Update register_fake, autograd Function, and dispatch for unscored ops
- Replace HAS_TRITON with _has_cuda_ops in tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

activation/grouped_poly_norm.cu CHANGED
@@ -609,8 +609,9 @@ std::tuple<torch::Tensor, torch::Tensor>
609
  grouped_poly_norm_forward(
610
  const torch::Tensor &input, const torch::Tensor &mul,
611
  const torch::Tensor &weight, const torch::Tensor &bias,
612
- const torch::Tensor &offsets, double eps, int64_t expert_offset) {
613
- return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, -1.0);
 
614
  }
615
 
616
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -618,11 +619,12 @@ grouped_poly_norm_backward(
618
  const torch::Tensor &grad_output, const torch::Tensor &input,
619
  const torch::Tensor &mul, const torch::Tensor &weight,
620
  const torch::Tensor &bias, const torch::Tensor &offsets,
621
- const torch::Tensor &inv_rms, double eps, int64_t expert_offset) {
 
622
  const int64_t N = input.size(0);
623
  auto [ig, mg, wg, bg, _] = _bwd_impl(
624
  grad_output, input, mul, weight, bias, offsets, inv_rms,
625
- nullptr, nullptr, N, eps, expert_offset, -1.0);
626
  return {ig, mg, wg, bg};
627
  }
628
 
 
609
  grouped_poly_norm_forward(
610
  const torch::Tensor &input, const torch::Tensor &mul,
611
  const torch::Tensor &weight, const torch::Tensor &bias,
612
+ const torch::Tensor &offsets, double eps, int64_t expert_offset,
613
+ double hidden_clamp) {
614
+ return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, hidden_clamp);
615
  }
616
 
617
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 
619
  const torch::Tensor &grad_output, const torch::Tensor &input,
620
  const torch::Tensor &mul, const torch::Tensor &weight,
621
  const torch::Tensor &bias, const torch::Tensor &offsets,
622
+ const torch::Tensor &inv_rms, double eps, int64_t expert_offset,
623
+ double hidden_clamp) {
624
  const int64_t N = input.size(0);
625
  auto [ig, mg, wg, bg, _] = _bwd_impl(
626
  grad_output, input, mul, weight, bias, offsets, inv_rms,
627
+ nullptr, nullptr, N, eps, expert_offset, hidden_clamp);
628
  return {ig, mg, wg, bg};
629
  }
630
 
tests/test_fused_mul_grouped_poly_norm.py CHANGED
@@ -2,11 +2,11 @@ import pytest
2
  import torch
3
 
4
  from grouped_poly_norm import (
5
- HAS_TRITON,
6
  fused_mul_grouped_poly_norm_ref,
7
  )
8
 
9
- if HAS_TRITON:
10
  from grouped_poly_norm import fused_mul_grouped_poly_norm
11
 
12
  from .utils import assert_close
@@ -95,7 +95,7 @@ def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0,
95
  return grads + (s.grad,) if s is not None else grads + (None,)
96
 
97
 
98
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
99
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
100
  @pytest.mark.parametrize("d", D)
101
  @pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@@ -134,7 +134,7 @@ def test_fused_mul_grouped_poly_norm_forward(
134
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
135
 
136
 
137
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
138
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
139
  @pytest.mark.parametrize("d", D)
140
  @pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@@ -173,7 +173,7 @@ def test_fused_mul_grouped_poly_norm_backward(
173
  assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
174
 
175
 
176
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
177
  @pytest.mark.parametrize("dtype", DTYPES)
178
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
179
  @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -236,7 +236,7 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
236
  f"but got max={b_grad_tri[wi].abs().max().item()}")
237
 
238
 
239
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
240
  @pytest.mark.parametrize("dtype", DTYPES)
241
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
242
  @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -265,7 +265,7 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
265
  # ---------------------------------------------------------------------------
266
  # Scores tests
267
  # ---------------------------------------------------------------------------
268
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
269
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
270
  @pytest.mark.parametrize("d", D)
271
  @pytest.mark.parametrize("num_experts", [8, 48])
@@ -289,7 +289,7 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
289
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
290
 
291
 
292
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
293
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
294
  @pytest.mark.parametrize("d", D)
295
  @pytest.mark.parametrize("num_experts", [8, 48])
@@ -326,7 +326,7 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
326
  CLAMP_VALUES = [10.0, 1.0, 0.5]
327
 
328
 
329
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
330
  @pytest.mark.parametrize("num_tokens", [4096])
331
  @pytest.mark.parametrize("d", [256, 1280])
332
  @pytest.mark.parametrize("num_experts", [8])
@@ -353,7 +353,7 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
353
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
354
 
355
 
356
- @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
357
  @pytest.mark.parametrize("num_tokens", [4096])
358
  @pytest.mark.parametrize("d", [256, 1280])
359
  @pytest.mark.parametrize("num_experts", [8])
 
2
  import torch
3
 
4
  from grouped_poly_norm import (
5
+ _has_cuda_ops,
6
  fused_mul_grouped_poly_norm_ref,
7
  )
8
 
9
+ if _has_cuda_ops:
10
  from grouped_poly_norm import fused_mul_grouped_poly_norm
11
 
12
  from .utils import assert_close
 
95
  return grads + (s.grad,) if s is not None else grads + (None,)
96
 
97
 
98
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
99
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
100
  @pytest.mark.parametrize("d", D)
101
  @pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
 
134
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
135
 
136
 
137
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
138
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
139
  @pytest.mark.parametrize("d", D)
140
  @pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
 
173
  assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
174
 
175
 
176
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
177
  @pytest.mark.parametrize("dtype", DTYPES)
178
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
179
  @pytest.mark.parametrize("device", CUDA_DEVICES)
 
236
  f"but got max={b_grad_tri[wi].abs().max().item()}")
237
 
238
 
239
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
240
  @pytest.mark.parametrize("dtype", DTYPES)
241
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
242
  @pytest.mark.parametrize("device", CUDA_DEVICES)
 
265
  # ---------------------------------------------------------------------------
266
  # Scores tests
267
  # ---------------------------------------------------------------------------
268
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
269
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
270
  @pytest.mark.parametrize("d", D)
271
  @pytest.mark.parametrize("num_experts", [8, 48])
 
289
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
290
 
291
 
292
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
293
  @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
294
  @pytest.mark.parametrize("d", D)
295
  @pytest.mark.parametrize("num_experts", [8, 48])
 
326
  CLAMP_VALUES = [10.0, 1.0, 0.5]
327
 
328
 
329
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
330
  @pytest.mark.parametrize("num_tokens", [4096])
331
  @pytest.mark.parametrize("d", [256, 1280])
332
  @pytest.mark.parametrize("num_experts", [8])
 
353
  assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
354
 
355
 
356
+ @pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
357
  @pytest.mark.parametrize("num_tokens", [4096])
358
  @pytest.mark.parametrize("d", [256, 1280])
359
  @pytest.mark.parametrize("num_experts", [8])
torch-ext/activation/grouped_poly_norm.py CHANGED
@@ -1,49 +1,26 @@
1
- """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
 
3
- Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
- eliminating multiple intermediate tensors and kernel launches.
5
 
6
  PolyNorm formula (per row):
7
  poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
- output = poly * mul
 
9
 
10
  where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
 
12
- Performance optimizations:
13
- - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
- hidden dimension.
15
- - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
- across the reduction and output phases, eliminating redundant global reads.
17
- - Multi-tile software pipelining: explicit num_stages in autotune configs
18
- enables overlapping memory loads with computation across loop iterations.
19
- - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
- launches (torch.arange + torch.bucketize) per forward/backward call.
21
- - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
- with dot product accumulation, pass 2 computes gradients. This reduces
23
- memory traffic compared to a naive 3-pass approach.
24
-
25
- Forward kernel: one program per row, tiles over D dimension.
26
- - Computes x, x^2, x^3 in registers
27
- - Computes three RMS norms in a single pass (shared variance reduction)
28
- - Applies polynomial weights + bias + mul in-place
29
-
30
- Backward kernel: one program per row, tiles over D dimension.
31
- - Recomputes forward intermediates from saved inputs (activation recomputation)
32
- - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
- - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
  """
35
 
36
  import torch
37
  from torch import Tensor
38
 
39
- try:
40
- import triton
41
- import triton.language as tl
42
-
43
- HAS_TRITON = True
44
- except ImportError:
45
- HAS_TRITON = False
46
-
47
  # Try to load CUDA ops at module level
48
  _ops = None
49
  try:
@@ -61,14 +38,15 @@ _has_cuda_ops = _ops is not None and hasattr(_ops, 'grouped_poly_norm_forward')
61
  if _has_cuda_ops:
62
  try:
63
  @torch.library.register_fake("_activation::grouped_poly_norm_forward")
64
- def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset):
 
65
  return (torch.empty_like(input),
66
  torch.empty(input.shape[0], 3, dtype=torch.float32,
67
  device=input.device))
68
 
69
  @torch.library.register_fake("_activation::grouped_poly_norm_backward")
70
  def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
71
- eps, expert_offset):
72
  return (torch.empty_like(input),
73
  torch.empty_like(mul),
74
  torch.empty_like(weight),
@@ -164,383 +142,32 @@ def fused_mul_grouped_poly_norm_ref(
164
 
165
 
166
  # ---------------------------------------------------------------------------
167
- # Triton kernel implementation
168
  # ---------------------------------------------------------------------------
169
- if HAS_TRITON:
170
- # --- Autotune configurations ---
171
- _GROUPED_POLYNORM_FWD_CONFIGS = [
172
- triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
173
- triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
174
- triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
175
- triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
176
- triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
177
- triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
178
- triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
179
- triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
180
- triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
181
- triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
182
- triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
183
- triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
184
- triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
185
- triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
186
- triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
187
- triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
188
- triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
189
- triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
190
- triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
191
- triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
192
- ]
193
-
194
- _GROUPED_POLYNORM_BWD_CONFIGS = [
195
- # Low-warp configs for high SM occupancy (latency hiding)
196
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=2, num_stages=1),
197
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=4, num_stages=1),
198
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=2, num_stages=1),
199
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=4, num_stages=1),
200
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=2, num_stages=1),
201
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=4, num_stages=1),
202
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=2, num_stages=1),
203
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=4, num_stages=1),
204
- # Medium-warp configs
205
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=8, num_stages=1),
206
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=8, num_stages=1),
207
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=8, num_stages=1),
208
- triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=8, num_stages=1),
209
- # Multi-tile configs (BLOCK_D=1024 for D=1280 -> 2 tiles, no mask waste)
210
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 1}, num_warps=2, num_stages=2),
211
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 1}, num_warps=4, num_stages=2),
212
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 2}, num_warps=2, num_stages=2),
213
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 2}, num_warps=4, num_stages=2),
214
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 4}, num_warps=4, num_stages=2),
215
- triton.Config({"BLOCK_D": 1024, "BLOCK_N": 8}, num_warps=4, num_stages=2),
216
- ]
217
-
218
- @triton.autotune(
219
- configs=_GROUPED_POLYNORM_FWD_CONFIGS,
220
- key=["D"],
221
- )
222
- @triton.jit
223
- def _grouped_polynorm_fwd_kernel(
224
- input_ptr,
225
- mul_ptr,
226
- weight_ptr,
227
- bias_ptr,
228
- offsets_ptr,
229
- output_ptr,
230
- inv_rms_ptr,
231
- N,
232
- D,
233
- num_experts,
234
- eps,
235
- expert_offset,
236
- stride_input_row,
237
- stride_mul_row,
238
- stride_out_row,
239
- BLOCK_D: tl.constexpr,
240
- ):
241
- """Forward kernel: one program per row. Saves inv_rms for backward."""
242
- row = tl.program_id(0)
243
- if row >= N:
244
- return
245
-
246
- # Binary search for expert index (12 iters covers up to 4096 experts)
247
- lo = 0
248
- hi = num_experts
249
- for _ in range(12):
250
- if lo < hi:
251
- mid = (lo + hi) // 2
252
- if tl.load(offsets_ptr + mid) <= row:
253
- lo = mid + 1
254
- else:
255
- hi = mid
256
- eidx = lo + expert_offset
257
-
258
- w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
259
- w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
260
- w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
261
- b = tl.load(bias_ptr + eidx).to(tl.float32)
262
-
263
- input_row_ptr = input_ptr + row * stride_input_row
264
- mul_row_ptr = mul_ptr + row * stride_mul_row
265
- out_row_ptr = output_ptr + row * stride_out_row
266
-
267
- D_float = D.to(tl.float32)
268
-
269
- # --- Single-tile path ---
270
- if D <= BLOCK_D:
271
- d_offs = tl.arange(0, BLOCK_D)
272
- mask = d_offs < D
273
-
274
- x = tl.load(input_row_ptr + d_offs, mask=mask,
275
- other=0.0).to(tl.float32)
276
- m = tl.load(mul_row_ptr + d_offs, mask=mask,
277
- other=0.0).to(tl.float32)
278
-
279
- x2 = x * x
280
- x3 = x2 * x
281
-
282
- inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
283
- inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
284
- inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x3 * x3) / D_float + eps)
285
-
286
- # Save inv_rms for backward
287
- tl.store(inv_rms_ptr + row * 3 + 0, inv_rms_x)
288
- tl.store(inv_rms_ptr + row * 3 + 1, inv_rms_x2)
289
- tl.store(inv_rms_ptr + row * 3 + 2, inv_rms_x3)
290
-
291
- # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
292
- w0_inv = w0 * inv_rms_x3
293
- w1_inv = w1 * inv_rms_x2
294
- w2_inv = w2 * inv_rms_x
295
-
296
- poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
297
- tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
298
- else:
299
- # --- Multi-tile: two-pass approach ---
300
- sum_x2 = tl.zeros((), dtype=tl.float32)
301
- sum_x4 = tl.zeros((), dtype=tl.float32)
302
- sum_x6 = tl.zeros((), dtype=tl.float32)
303
-
304
- for d_start in range(0, D, BLOCK_D):
305
- d_offs = d_start + tl.arange(0, BLOCK_D)
306
- mask = d_offs < D
307
- x = tl.load(input_row_ptr + d_offs, mask=mask,
308
- other=0.0).to(tl.float32)
309
- x2 = x * x
310
- x3 = x2 * x
311
- sum_x2 += tl.sum(x2)
312
- sum_x4 += tl.sum(x2 * x2)
313
- sum_x6 += tl.sum(x3 * x3)
314
-
315
- inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
316
- inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
317
- inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
318
-
319
- # Save inv_rms for backward
320
- tl.store(inv_rms_ptr + row * 3 + 0, inv_rms_x)
321
- tl.store(inv_rms_ptr + row * 3 + 1, inv_rms_x2)
322
- tl.store(inv_rms_ptr + row * 3 + 2, inv_rms_x3)
323
-
324
- # Pre-multiply scalar weight * inv_rms
325
- w0_inv = w0 * inv_rms_x3
326
- w1_inv = w1 * inv_rms_x2
327
- w2_inv = w2 * inv_rms_x
328
-
329
- for d_start in range(0, D, BLOCK_D):
330
- d_offs = d_start + tl.arange(0, BLOCK_D)
331
- mask = d_offs < D
332
- x = tl.load(input_row_ptr + d_offs, mask=mask,
333
- other=0.0).to(tl.float32)
334
- m = tl.load(mul_row_ptr + d_offs, mask=mask,
335
- other=0.0).to(tl.float32)
336
- x2 = x * x
337
- x3 = x2 * x
338
- poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
339
- tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
340
-
341
- @triton.jit
342
- def _grouped_polynorm_bwd_kernel(
343
- grad_out_ptr,
344
- input_ptr,
345
- mul_ptr,
346
- weight_ptr,
347
- bias_ptr,
348
- offsets_ptr,
349
- inv_rms_ptr,
350
- grad_input_ptr,
351
- grad_mul_ptr,
352
- grad_weight_ptr,
353
- grad_bias_ptr,
354
- N,
355
- D,
356
- num_experts,
357
- eps,
358
- expert_offset,
359
- stride_row,
360
- BLOCK_D: tl.constexpr,
361
- BLOCK_N: tl.constexpr,
362
- ):
363
- """Backward kernel: BLOCK_N rows per program. Loads saved inv_rms.
364
-
365
- Each program processes BLOCK_N consecutive rows. Since MoE tokens
366
- are sorted by expert, consecutive rows often share the same expert,
367
- allowing weight/bias load reuse and amortized binary search.
368
- """
369
- pid = tl.program_id(0)
370
- row_start = pid * BLOCK_N
371
- D_float = D.to(tl.float32)
372
- d_offs = tl.arange(0, BLOCK_D)
373
- d_mask = d_offs < D
374
-
375
- for row_off in tl.static_range(BLOCK_N):
376
- row = row_start + row_off
377
- if row < N:
378
- # Binary search for expert index
379
- lo = 0
380
- hi = num_experts
381
- for _ in range(12):
382
- if lo < hi:
383
- mid = (lo + hi) // 2
384
- if tl.load(offsets_ptr + mid) <= row:
385
- lo = mid + 1
386
- else:
387
- hi = mid
388
- eidx = lo + expert_offset
389
-
390
- w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
391
- w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
392
- w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
393
- b_val = tl.load(bias_ptr + eidx).to(tl.float32)
394
-
395
- input_row_ptr = input_ptr + row * stride_row
396
- mul_row_ptr = mul_ptr + row * stride_row
397
- grad_out_row_ptr = grad_out_ptr + row * stride_row
398
- grad_input_row_ptr = grad_input_ptr + row * stride_row
399
- grad_mul_row_ptr = grad_mul_ptr + row * stride_row
400
-
401
- # --- Single-tile path ---
402
- if D <= BLOCK_D:
403
- x = tl.load(input_row_ptr + d_offs, mask=d_mask,
404
- other=0.0).to(tl.float32)
405
- m = tl.load(mul_row_ptr + d_offs, mask=d_mask,
406
- other=0.0).to(tl.float32)
407
- go = tl.load(grad_out_row_ptr + d_offs, mask=d_mask,
408
- other=0.0).to(tl.float32)
409
-
410
- x2 = x * x
411
- x3 = x2 * x
412
-
413
- # Load saved inv_rms from forward
414
- inv_rms_x = tl.load(inv_rms_ptr + row * 3 + 0)
415
- inv_rms_x2 = tl.load(inv_rms_ptr + row * 3 + 1)
416
- inv_rms_x3 = tl.load(inv_rms_ptr + row * 3 + 2)
417
-
418
- w0_inv = w0 * inv_rms_x3
419
- w1_inv = w1 * inv_rms_x2
420
- w2_inv = w2 * inv_rms_x
421
-
422
- dpoly = go * m
423
-
424
- sum_dpoly_x = tl.sum(dpoly * x)
425
- sum_dpoly_x2 = tl.sum(dpoly * x2)
426
- sum_dpoly_x3 = tl.sum(dpoly * x3)
427
- grad_b_acc = tl.sum(dpoly)
428
-
429
- grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
430
- grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
431
- grad_w2_acc = inv_rms_x * sum_dpoly_x
432
-
433
- coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
434
- coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
435
- coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
436
-
437
- # grad_mul
438
- poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
439
- tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val),
440
- mask=d_mask)
441
-
442
- # grad_input
443
- g = inv_rms_x * (w2 * dpoly - x * coeff_x)
444
- g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
445
- g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
446
- tl.store(grad_input_row_ptr + d_offs, g, mask=d_mask)
447
-
448
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
449
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
450
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
451
- tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
452
- else:
453
- # --- Multi-tile: dot products pass ---
454
- # Load saved inv_rms from forward
455
- inv_rms_x = tl.load(inv_rms_ptr + row * 3 + 0)
456
- inv_rms_x2 = tl.load(inv_rms_ptr + row * 3 + 1)
457
- inv_rms_x3 = tl.load(inv_rms_ptr + row * 3 + 2)
458
-
459
- sum_dpoly_x = tl.zeros((), dtype=tl.float32)
460
- sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
461
- sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
462
- grad_b_acc = tl.zeros((), dtype=tl.float32)
463
-
464
- for d_start in range(0, D, BLOCK_D):
465
- tile_offs = d_start + d_offs
466
- tile_mask = tile_offs < D
467
- x = tl.load(input_row_ptr + tile_offs, mask=tile_mask,
468
- other=0.0).to(tl.float32)
469
- m = tl.load(mul_row_ptr + tile_offs, mask=tile_mask,
470
- other=0.0).to(tl.float32)
471
- go = tl.load(grad_out_row_ptr + tile_offs,
472
- mask=tile_mask, other=0.0).to(tl.float32)
473
-
474
- x2 = x * x
475
- x3 = x2 * x
476
- dpoly = go * m
477
-
478
- sum_dpoly_x += tl.sum(dpoly * x)
479
- sum_dpoly_x2 += tl.sum(dpoly * x2)
480
- sum_dpoly_x3 += tl.sum(dpoly * x3)
481
- grad_b_acc += tl.sum(dpoly)
482
-
483
- w0_inv = w0 * inv_rms_x3
484
- w1_inv = w1 * inv_rms_x2
485
- w2_inv = w2 * inv_rms_x
486
-
487
- grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
488
- grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
489
- grad_w2_acc = inv_rms_x * sum_dpoly_x
490
-
491
- coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
492
- coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
493
- coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
494
-
495
- for d_start in range(0, D, BLOCK_D):
496
- tile_offs = d_start + d_offs
497
- tile_mask = tile_offs < D
498
- x = tl.load(input_row_ptr + tile_offs, mask=tile_mask,
499
- other=0.0).to(tl.float32)
500
- m = tl.load(mul_row_ptr + tile_offs, mask=tile_mask,
501
- other=0.0).to(tl.float32)
502
- go = tl.load(grad_out_row_ptr + tile_offs,
503
- mask=tile_mask, other=0.0).to(tl.float32)
504
-
505
- x2 = x * x
506
- x3 = x2 * x
507
-
508
- poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
509
- tl.store(grad_mul_row_ptr + tile_offs,
510
- go * (poly + b_val), mask=tile_mask)
511
-
512
- dpoly = go * m
513
- g = inv_rms_x * (w2 * dpoly - x * coeff_x)
514
- g += (2.0 * x * inv_rms_x2 *
515
- (w1 * dpoly - x2 * coeff_x2))
516
- g += (3.0 * x2 * inv_rms_x3 *
517
- (w0 * dpoly - x3 * coeff_x3))
518
- tl.store(grad_input_row_ptr + tile_offs, g,
519
- mask=tile_mask)
520
-
521
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
522
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
523
- tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
524
- tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
525
 
526
  class _GroupedPolyNormFn(torch.autograd.Function):
527
  """Without scores — follows poly_norm.py pattern."""
528
 
529
  @staticmethod
530
- def forward(input, mul, weight, bias, offsets, eps, expert_offset):
 
531
  input = input.contiguous()
532
  mul = mul.contiguous()
533
  output, inv_rms = _ops.grouped_poly_norm_forward(
534
- input, mul, weight, bias, offsets, eps, expert_offset)
 
535
  return output, inv_rms
536
 
537
  @staticmethod
538
  def setup_context(ctx, inputs, output):
539
- input, mul, weight, bias, offsets, eps, expert_offset = inputs
 
540
  _, inv_rms = output
541
  ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms)
542
  ctx.eps = eps
543
  ctx.expert_offset = expert_offset
 
544
 
545
  @staticmethod
546
  def backward(ctx, grad_output, _grad_inv_rms):
@@ -548,8 +175,8 @@ if HAS_TRITON:
548
  grad_output = grad_output.contiguous()
549
  gi, gm, gw, gb = _ops.grouped_poly_norm_backward(
550
  grad_output, input, mul, weight, bias, offsets, inv_rms,
551
- ctx.eps, ctx.expert_offset)
552
- return gi, gm, gw, gb, None, None, None
553
 
554
  class _GroupedPolyNormScoredFn(torch.autograd.Function):
555
  """With scores — same pattern, adds scores + hidden_clamp."""
@@ -622,7 +249,8 @@ if HAS_TRITON:
622
  expert_offset, clamp_val)
623
  else:
624
  output, _ = _GroupedPolyNormFn.apply(
625
- input, mul, weight, bias, offsets, eps, expert_offset)
 
626
  return output
627
 
628
  else:
@@ -639,5 +267,5 @@ else:
639
  hidden_clamp: float | None = None,
640
  ) -> Tensor:
641
  raise RuntimeError(
642
- "Triton is not available. Install triton to use "
643
- "fused_mul_grouped_poly_norm.")
 
1
+ """Grouped FusedMulPolyNorm for MoE — CUDA kernel with autograd wrappers.
2
 
3
+ Fuses the entire PolyNorm computation into CUDA kernels (fwd + bwd),
4
+ with optional scores multiplication and hidden_clamp fusion.
5
 
6
  PolyNorm formula (per row):
7
  poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul * score
9
+ output = clamp(output, -hidden_clamp, hidden_clamp) (if enabled)
10
 
11
  where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
12
 
13
+ CUDA kernel (activation/grouped_poly_norm.cu):
14
+ - Vectorized loads (width=8 for bf16/fp16, width=4 for fp32)
15
+ - In-kernel binary search for expert mapping
16
+ - 2-pass forward (RMS stats + output), 2-pass backward (dot products + grads)
17
+ - Scores and hidden_clamp fused in-kernel (no extra kernel launches)
18
+ - Weight/bias gradients via atomicAdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  """
20
 
21
  import torch
22
  from torch import Tensor
23
 
 
 
 
 
 
 
 
 
24
  # Try to load CUDA ops at module level
25
  _ops = None
26
  try:
 
38
  if _has_cuda_ops:
39
  try:
40
  @torch.library.register_fake("_activation::grouped_poly_norm_forward")
41
+ def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
42
+ hidden_clamp):
43
  return (torch.empty_like(input),
44
  torch.empty(input.shape[0], 3, dtype=torch.float32,
45
  device=input.device))
46
 
47
  @torch.library.register_fake("_activation::grouped_poly_norm_backward")
48
  def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
49
+ eps, expert_offset, hidden_clamp):
50
  return (torch.empty_like(input),
51
  torch.empty_like(mul),
52
  torch.empty_like(weight),
 
142
 
143
 
144
  # ---------------------------------------------------------------------------
145
+ # CUDA kernel autograd functions
146
  # ---------------------------------------------------------------------------
147
+ if _has_cuda_ops:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  class _GroupedPolyNormFn(torch.autograd.Function):
150
  """Without scores — follows poly_norm.py pattern."""
151
 
152
  @staticmethod
153
+ def forward(input, mul, weight, bias, offsets, eps, expert_offset,
154
+ hidden_clamp):
155
  input = input.contiguous()
156
  mul = mul.contiguous()
157
  output, inv_rms = _ops.grouped_poly_norm_forward(
158
+ input, mul, weight, bias, offsets, eps, expert_offset,
159
+ hidden_clamp)
160
  return output, inv_rms
161
 
162
  @staticmethod
163
  def setup_context(ctx, inputs, output):
164
+ (input, mul, weight, bias, offsets, eps, expert_offset,
165
+ hidden_clamp) = inputs
166
  _, inv_rms = output
167
  ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms)
168
  ctx.eps = eps
169
  ctx.expert_offset = expert_offset
170
+ ctx.hidden_clamp = hidden_clamp
171
 
172
  @staticmethod
173
  def backward(ctx, grad_output, _grad_inv_rms):
 
175
  grad_output = grad_output.contiguous()
176
  gi, gm, gw, gb = _ops.grouped_poly_norm_backward(
177
  grad_output, input, mul, weight, bias, offsets, inv_rms,
178
+ ctx.eps, ctx.expert_offset, ctx.hidden_clamp)
179
+ return gi, gm, gw, gb, None, None, None, None
180
 
181
  class _GroupedPolyNormScoredFn(torch.autograd.Function):
182
  """With scores — same pattern, adds scores + hidden_clamp."""
 
249
  expert_offset, clamp_val)
250
  else:
251
  output, _ = _GroupedPolyNormFn.apply(
252
+ input, mul, weight, bias, offsets, eps, expert_offset,
253
+ clamp_val)
254
  return output
255
 
256
  else:
 
267
  hidden_clamp: float | None = None,
268
  ) -> Tensor:
269
  raise RuntimeError(
270
+ "CUDA ops not available. Build with setup.py or kernel-builder "
271
+ "to use fused_mul_grouped_poly_norm.")
torch-ext/torch_binding.cpp CHANGED
@@ -49,18 +49,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
49
  ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
50
  &fused_add_rms_norm_backward);
51
 
52
- // grouped_poly_norm (without scores)
53
  ops.def("grouped_poly_norm_forward("
54
  "Tensor input, Tensor mul, Tensor weight, "
55
  "Tensor bias, Tensor offsets, "
56
- "float eps, int expert_offset) -> (Tensor, Tensor)");
57
  ops.impl("grouped_poly_norm_forward", torch::kCUDA,
58
  &grouped_poly_norm_forward);
59
 
60
  ops.def("grouped_poly_norm_backward("
61
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
62
  "Tensor bias, Tensor offsets, Tensor inv_rms, "
63
- "float eps, int expert_offset) -> (Tensor, Tensor, Tensor, Tensor)");
64
  ops.impl("grouped_poly_norm_backward", torch::kCUDA,
65
  &grouped_poly_norm_backward);
66
 
 
49
  ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
50
  &fused_add_rms_norm_backward);
51
 
52
+ // grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
53
  ops.def("grouped_poly_norm_forward("
54
  "Tensor input, Tensor mul, Tensor weight, "
55
  "Tensor bias, Tensor offsets, "
56
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
57
  ops.impl("grouped_poly_norm_forward", torch::kCUDA,
58
  &grouped_poly_norm_forward);
59
 
60
  ops.def("grouped_poly_norm_backward("
61
  "Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
62
  "Tensor bias, Tensor offsets, Tensor inv_rms, "
63
+ "float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor)");
64
  ops.impl("grouped_poly_norm_backward", torch::kCUDA,
65
  &grouped_poly_norm_backward);
66
 
torch-ext/torch_binding.h CHANGED
@@ -36,19 +36,21 @@ std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
36
  const torch::Tensor &input, const torch::Tensor &weight, double eps,
37
  bool need_input_grad);
38
 
39
- // Without scores
40
  std::tuple<torch::Tensor, torch::Tensor>
41
  grouped_poly_norm_forward(
42
  const torch::Tensor &input, const torch::Tensor &mul,
43
  const torch::Tensor &weight, const torch::Tensor &bias,
44
- const torch::Tensor &offsets, double eps, int64_t expert_offset);
 
45
 
46
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
47
  grouped_poly_norm_backward(
48
  const torch::Tensor &grad_output, const torch::Tensor &input,
49
  const torch::Tensor &mul, const torch::Tensor &weight,
50
  const torch::Tensor &bias, const torch::Tensor &offsets,
51
- const torch::Tensor &inv_rms, double eps, int64_t expert_offset);
 
52
 
53
  // With scores (hidden_clamp < 0 = disabled)
54
  std::tuple<torch::Tensor, torch::Tensor>
 
36
  const torch::Tensor &input, const torch::Tensor &weight, double eps,
37
  bool need_input_grad);
38
 
39
+ // Without scores (hidden_clamp < 0 = disabled)
40
  std::tuple<torch::Tensor, torch::Tensor>
41
  grouped_poly_norm_forward(
42
  const torch::Tensor &input, const torch::Tensor &mul,
43
  const torch::Tensor &weight, const torch::Tensor &bias,
44
+ const torch::Tensor &offsets, double eps, int64_t expert_offset,
45
+ double hidden_clamp);
46
 
47
  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
48
  grouped_poly_norm_backward(
49
  const torch::Tensor &grad_output, const torch::Tensor &input,
50
  const torch::Tensor &mul, const torch::Tensor &weight,
51
  const torch::Tensor &bias, const torch::Tensor &offsets,
52
+ const torch::Tensor &inv_rms, double eps, int64_t expert_offset,
53
+ double hidden_clamp);
54
 
55
  // With scores (hidden_clamp < 0 = disabled)
56
  std::tuple<torch::Tensor, torch::Tensor>