Kernels
wyldecat Claude Opus 4.6 (1M context) commited on
Commit
f06406d
·
1 Parent(s): 656a6f4

test: add scores and hidden_clamp tests for fused_mul_grouped_poly_norm

Browse files

- Rename test file to match fused_mul_grouped_poly_norm convention
- Add forward/backward tests for scores fusion
- Add forward/backward tests for hidden_clamp fusion (clamp values 10.0, 1.0, 0.5)
- Adjust weight/bias grad tolerance for atomicAdd accumulation order
- 192 tests, all passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

tests/{test_grouped_fused_mul_poly_norm.py → test_fused_mul_grouped_poly_norm.py} RENAMED
@@ -3,11 +3,11 @@ import torch
3
 
4
  from grouped_poly_norm import (
5
  HAS_TRITON,
6
- grouped_fused_mul_poly_norm_ref,
7
  )
8
 
9
  if HAS_TRITON:
10
- from grouped_poly_norm import grouped_fused_mul_poly_norm
11
 
12
  from .utils import assert_close
13
 
@@ -26,7 +26,7 @@ CUDA_DEVICES = ["cuda:0"]
26
  def _counts_to_offsets(counts_list, device):
27
  """Convert list of counts to cumsum offsets tensor."""
28
  return torch.cumsum(
29
- torch.tensor(counts_list, device=device, dtype=torch.int32), dim=0)
30
 
31
 
32
  def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
@@ -54,32 +54,45 @@ def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
54
  return input_t, mul_t, weight, bias, offsets
55
 
56
 
57
- def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0):
 
 
 
 
 
 
58
  """Run reference forward + backward, return output and grads."""
59
  inp = input_t.clone().detach().requires_grad_(True)
60
  m = mul_t.clone().detach().requires_grad_(True)
61
  w = weight.clone().detach().requires_grad_(True)
62
  b = bias.clone().detach().requires_grad_(True)
 
63
 
64
- out = grouped_fused_mul_poly_norm_ref(inp, m, w, b, offsets,
65
- expert_offset=expert_offset)
 
66
  out.sum().backward()
67
 
68
- return out, inp.grad, m.grad, w.grad, b.grad
 
69
 
70
 
71
- def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0):
72
- """Run Triton forward + backward, return output and grads."""
 
73
  inp = input_t.clone().detach().requires_grad_(True)
74
  m = mul_t.clone().detach().requires_grad_(True)
75
  w = weight.clone().detach().requires_grad_(True)
76
  b = bias.clone().detach().requires_grad_(True)
 
77
 
78
- out = grouped_fused_mul_poly_norm(inp, m, w, b, offsets,
79
- expert_offset=expert_offset)
 
80
  out.sum().backward()
81
 
82
- return out, inp.grad, m.grad, w.grad, b.grad
 
83
 
84
 
85
  @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@@ -90,7 +103,7 @@ def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0):
90
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
91
  @pytest.mark.parametrize("seed", SEEDS)
92
  @pytest.mark.parametrize("device", CUDA_DEVICES)
93
- def test_grouped_fused_mul_poly_norm_forward(
94
  num_tokens: int,
95
  d: int,
96
  num_experts: int,
@@ -105,10 +118,10 @@ def test_grouped_fused_mul_poly_norm_forward(
105
  num_tokens, d, num_experts, dtype, device, seed,
106
  expert_offset=expert_offset)
107
 
108
- out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias,
109
  offsets,
110
  expert_offset=expert_offset)
111
- out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias,
112
  offsets,
113
  expert_offset=expert_offset)
114
 
@@ -129,7 +142,7 @@ def test_grouped_fused_mul_poly_norm_forward(
129
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
130
  @pytest.mark.parametrize("seed", SEEDS)
131
  @pytest.mark.parametrize("device", CUDA_DEVICES)
132
- def test_grouped_fused_mul_poly_norm_backward(
133
  num_tokens: int,
134
  d: int,
135
  num_experts: int,
@@ -144,9 +157,9 @@ def test_grouped_fused_mul_poly_norm_backward(
144
  num_tokens, d, num_experts, dtype, device, seed,
145
  expert_offset=expert_offset)
146
 
147
- _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref = _run_ref(
148
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
149
- _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri = _run_triton(
150
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
151
 
152
  if dtype == torch.float32:
@@ -164,7 +177,7 @@ def test_grouped_fused_mul_poly_norm_backward(
164
  @pytest.mark.parametrize("dtype", DTYPES)
165
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
166
  @pytest.mark.parametrize("device", CUDA_DEVICES)
167
- def test_grouped_fused_mul_poly_norm_zero_token_experts(
168
  dtype: torch.dtype,
169
  expert_offset: int,
170
  device: str,
@@ -184,10 +197,10 @@ def test_grouped_fused_mul_poly_norm_zero_token_experts(
184
  bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
185
  offsets = _counts_to_offsets(counts, device)
186
 
187
- out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias,
188
  offsets,
189
  expert_offset=expert_offset)
190
- out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias,
191
  offsets,
192
  expert_offset=expert_offset)
193
 
@@ -197,12 +210,12 @@ def test_grouped_fused_mul_poly_norm_zero_token_experts(
197
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
198
 
199
  # Check backward with zero-token experts
200
- _, _, _, w_grad_ref, b_grad_ref = _run_ref(input_t, mul_t, weight, bias,
201
- offsets,
202
- expert_offset=expert_offset)
203
- _, _, _, w_grad_tri, b_grad_tri = _run_triton(input_t, mul_t, weight, bias,
204
- offsets,
205
- expert_offset=expert_offset)
206
 
207
  if dtype == torch.float32:
208
  atol, rtol = 1e-3, 1e-3
@@ -227,7 +240,7 @@ def test_grouped_fused_mul_poly_norm_zero_token_experts(
227
  @pytest.mark.parametrize("dtype", DTYPES)
228
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
229
  @pytest.mark.parametrize("device", CUDA_DEVICES)
230
- def test_grouped_fused_mul_poly_norm_no_nan_inf(
231
  dtype: torch.dtype,
232
  expert_offset: int,
233
  device: str,
@@ -237,7 +250,7 @@ def test_grouped_fused_mul_poly_norm_no_nan_inf(
237
  input_t, mul_t, weight, bias, offsets = _make_inputs(
238
  4096, 256, 8, dtype, device, expert_offset=expert_offset)
239
 
240
- out, inp_grad, mul_grad, w_grad, b_grad = _run_triton(
241
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
242
 
243
  assert not out.isnan().any(), "Output contains NaN"
@@ -247,3 +260,128 @@ def test_grouped_fused_mul_poly_norm_no_nan_inf(
247
  ("weight", w_grad), ("bias", b_grad)]:
248
  assert not grad.isnan().any(), f"{name}_grad contains NaN"
249
  assert not grad.isinf().any(), f"{name}_grad contains Inf"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  from grouped_poly_norm import (
5
  HAS_TRITON,
6
+ fused_mul_grouped_poly_norm_ref,
7
  )
8
 
9
  if HAS_TRITON:
10
+ from grouped_poly_norm import fused_mul_grouped_poly_norm
11
 
12
  from .utils import assert_close
13
 
 
26
  def _counts_to_offsets(counts_list, device):
27
  """Convert list of counts to cumsum offsets tensor."""
28
  return torch.cumsum(
29
+ torch.tensor(counts_list, device=device, dtype=torch.int32), dim=0).to(torch.int32)
30
 
31
 
32
  def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
 
54
  return input_t, mul_t, weight, bias, offsets
55
 
56
 
57
+ def _make_scores(total_tokens, device, dtype=torch.float32):
58
+ """Create random scores (N, 1) in fp32."""
59
+ return torch.rand(total_tokens, 1, device=device, dtype=dtype) * 0.5 + 0.5
60
+
61
+
62
+ def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0,
63
+ scores=None, hidden_clamp=None):
64
  """Run reference forward + backward, return output and grads."""
65
  inp = input_t.clone().detach().requires_grad_(True)
66
  m = mul_t.clone().detach().requires_grad_(True)
67
  w = weight.clone().detach().requires_grad_(True)
68
  b = bias.clone().detach().requires_grad_(True)
69
+ s = scores.clone().detach().requires_grad_(True) if scores is not None else None
70
 
71
+ out = fused_mul_grouped_poly_norm_ref(inp, m, w, b, offsets,
72
+ expert_offset=expert_offset,
73
+ scores=s, hidden_clamp=hidden_clamp)
74
  out.sum().backward()
75
 
76
+ grads = (out, inp.grad, m.grad, w.grad, b.grad)
77
+ return grads + (s.grad,) if s is not None else grads + (None,)
78
 
79
 
80
+ def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0,
81
+ scores=None, hidden_clamp=None):
82
+ """Run Triton/CUDA forward + backward, return output and grads."""
83
  inp = input_t.clone().detach().requires_grad_(True)
84
  m = mul_t.clone().detach().requires_grad_(True)
85
  w = weight.clone().detach().requires_grad_(True)
86
  b = bias.clone().detach().requires_grad_(True)
87
+ s = scores.clone().detach().requires_grad_(True) if scores is not None else None
88
 
89
+ out = fused_mul_grouped_poly_norm(inp, m, w, b, offsets,
90
+ expert_offset=expert_offset,
91
+ scores=s, hidden_clamp=hidden_clamp)
92
  out.sum().backward()
93
 
94
+ grads = (out, inp.grad, m.grad, w.grad, b.grad)
95
+ return grads + (s.grad,) if s is not None else grads + (None,)
96
 
97
 
98
  @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
 
103
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
104
  @pytest.mark.parametrize("seed", SEEDS)
105
  @pytest.mark.parametrize("device", CUDA_DEVICES)
106
+ def test_fused_mul_grouped_poly_norm_forward(
107
  num_tokens: int,
108
  d: int,
109
  num_experts: int,
 
118
  num_tokens, d, num_experts, dtype, device, seed,
119
  expert_offset=expert_offset)
120
 
121
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
122
  offsets,
123
  expert_offset=expert_offset)
124
+ out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
125
  offsets,
126
  expert_offset=expert_offset)
127
 
 
142
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
143
  @pytest.mark.parametrize("seed", SEEDS)
144
  @pytest.mark.parametrize("device", CUDA_DEVICES)
145
+ def test_fused_mul_grouped_poly_norm_backward(
146
  num_tokens: int,
147
  d: int,
148
  num_experts: int,
 
157
  num_tokens, d, num_experts, dtype, device, seed,
158
  expert_offset=expert_offset)
159
 
160
+ _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref, _ = _run_ref(
161
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
162
+ _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri, _ = _run_triton(
163
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
164
 
165
  if dtype == torch.float32:
 
177
  @pytest.mark.parametrize("dtype", DTYPES)
178
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
179
  @pytest.mark.parametrize("device", CUDA_DEVICES)
180
+ def test_fused_mul_grouped_poly_norm_zero_token_experts(
181
  dtype: torch.dtype,
182
  expert_offset: int,
183
  device: str,
 
197
  bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
198
  offsets = _counts_to_offsets(counts, device)
199
 
200
+ out_ref = fused_mul_grouped_poly_norm_ref(input_t, mul_t, weight, bias,
201
  offsets,
202
  expert_offset=expert_offset)
203
+ out_tri = fused_mul_grouped_poly_norm(input_t, mul_t, weight, bias,
204
  offsets,
205
  expert_offset=expert_offset)
206
 
 
210
  assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
211
 
212
  # Check backward with zero-token experts
213
+ _, _, _, w_grad_ref, b_grad_ref, _ = _run_ref(input_t, mul_t, weight, bias,
214
+ offsets,
215
+ expert_offset=expert_offset)
216
+ _, _, _, w_grad_tri, b_grad_tri, _ = _run_triton(input_t, mul_t, weight, bias,
217
+ offsets,
218
+ expert_offset=expert_offset)
219
 
220
  if dtype == torch.float32:
221
  atol, rtol = 1e-3, 1e-3
 
240
  @pytest.mark.parametrize("dtype", DTYPES)
241
  @pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
242
  @pytest.mark.parametrize("device", CUDA_DEVICES)
243
+ def test_fused_mul_grouped_poly_norm_no_nan_inf(
244
  dtype: torch.dtype,
245
  expert_offset: int,
246
  device: str,
 
250
  input_t, mul_t, weight, bias, offsets = _make_inputs(
251
  4096, 256, 8, dtype, device, expert_offset=expert_offset)
252
 
253
+ out, inp_grad, mul_grad, w_grad, b_grad, _ = _run_triton(
254
  input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
255
 
256
  assert not out.isnan().any(), "Output contains NaN"
 
260
  ("weight", w_grad), ("bias", b_grad)]:
261
  assert not grad.isnan().any(), f"{name}_grad contains NaN"
262
  assert not grad.isinf().any(), f"{name}_grad contains Inf"
263
+
264
+
265
+ # ---------------------------------------------------------------------------
266
+ # Scores tests
267
+ # ---------------------------------------------------------------------------
268
+ @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
269
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
270
+ @pytest.mark.parametrize("d", D)
271
+ @pytest.mark.parametrize("num_experts", [8, 48])
272
+ @pytest.mark.parametrize("dtype", DTYPES)
273
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
274
+ def test_fused_mul_grouped_poly_norm_scores_forward(
275
+ num_tokens, d, num_experts, dtype, device,
276
+ ):
277
+ """Forward with scores should match reference."""
278
+ torch.set_default_device(device)
279
+ input_t, mul_t, weight, bias, offsets = _make_inputs(
280
+ num_tokens, d, num_experts, dtype, device)
281
+ scores = _make_scores(num_tokens, device)
282
+
283
+ out_ref = fused_mul_grouped_poly_norm_ref(
284
+ input_t, mul_t, weight, bias, offsets, scores=scores)
285
+ out_tri = fused_mul_grouped_poly_norm(
286
+ input_t, mul_t, weight, bias, offsets, scores=scores)
287
+
288
+ atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
289
+ assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
290
+
291
+
292
+ @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
293
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
294
+ @pytest.mark.parametrize("d", D)
295
+ @pytest.mark.parametrize("num_experts", [8, 48])
296
+ @pytest.mark.parametrize("dtype", DTYPES)
297
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
298
+ def test_fused_mul_grouped_poly_norm_scores_backward(
299
+ num_tokens, d, num_experts, dtype, device,
300
+ ):
301
+ """Backward with scores should match reference."""
302
+ torch.set_default_device(device)
303
+ input_t, mul_t, weight, bias, offsets = _make_inputs(
304
+ num_tokens, d, num_experts, dtype, device)
305
+ scores = _make_scores(num_tokens, device)
306
+
307
+ out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
308
+ input_t, mul_t, weight, bias, offsets, scores=scores)
309
+ out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_triton(
310
+ input_t, mul_t, weight, bias, offsets, scores=scores)
311
+
312
+ atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
313
+ # weight/bias grads use atomicAdd accumulation across tokens,
314
+ # so allow slightly higher tolerance for fp32
315
+ wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
316
+ assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
317
+ assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
318
+ assert_close(wg_ref, wg_tri, atol=wg_atol, rtol=wg_atol)
319
+ assert_close(bg_ref, bg_tri, atol=wg_atol, rtol=wg_atol)
320
+ assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)
321
+
322
+
323
+ # ---------------------------------------------------------------------------
324
+ # Hidden clamp tests
325
+ # ---------------------------------------------------------------------------
326
+ CLAMP_VALUES = [10.0, 1.0, 0.5]
327
+
328
+
329
+ @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
330
+ @pytest.mark.parametrize("num_tokens", [4096])
331
+ @pytest.mark.parametrize("d", [256, 1280])
332
+ @pytest.mark.parametrize("num_experts", [8])
333
+ @pytest.mark.parametrize("dtype", DTYPES)
334
+ @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
335
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
336
+ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
337
+ num_tokens, d, num_experts, dtype, hidden_clamp, device,
338
+ ):
339
+ """Forward with hidden_clamp should match reference."""
340
+ torch.set_default_device(device)
341
+ input_t, mul_t, weight, bias, offsets = _make_inputs(
342
+ num_tokens, d, num_experts, dtype, device)
343
+ scores = _make_scores(num_tokens, device)
344
+
345
+ out_ref = fused_mul_grouped_poly_norm_ref(
346
+ input_t, mul_t, weight, bias, offsets,
347
+ scores=scores, hidden_clamp=hidden_clamp)
348
+ out_tri = fused_mul_grouped_poly_norm(
349
+ input_t, mul_t, weight, bias, offsets,
350
+ scores=scores, hidden_clamp=hidden_clamp)
351
+
352
+ atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (1e-2, 1e-2)
353
+ assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
354
+
355
+
356
+ @pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
357
+ @pytest.mark.parametrize("num_tokens", [4096])
358
+ @pytest.mark.parametrize("d", [256, 1280])
359
+ @pytest.mark.parametrize("num_experts", [8])
360
+ @pytest.mark.parametrize("dtype", DTYPES)
361
+ @pytest.mark.parametrize("hidden_clamp", CLAMP_VALUES)
362
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
363
+ def test_fused_mul_grouped_poly_norm_hidden_clamp_backward(
364
+ num_tokens, d, num_experts, dtype, hidden_clamp, device,
365
+ ):
366
+ """Backward with hidden_clamp should match reference."""
367
+ torch.set_default_device(device)
368
+ input_t, mul_t, weight, bias, offsets = _make_inputs(
369
+ num_tokens, d, num_experts, dtype, device)
370
+ scores = _make_scores(num_tokens, device)
371
+
372
+ out_ref, ig_ref, mg_ref, wg_ref, bg_ref, sg_ref = _run_ref(
373
+ input_t, mul_t, weight, bias, offsets,
374
+ scores=scores, hidden_clamp=hidden_clamp)
375
+ out_tri, ig_tri, mg_tri, wg_tri, bg_tri, sg_tri = _run_triton(
376
+ input_t, mul_t, weight, bias, offsets,
377
+ scores=scores, hidden_clamp=hidden_clamp)
378
+
379
+ atol, rtol = (1e-4, 1e-4) if dtype == torch.float32 else (5e-2, 5e-2)
380
+ # weight/bias grads use atomicAdd accumulation across tokens,
381
+ # so allow slightly higher tolerance for fp32
382
+ wg_atol = 5e-4 if dtype == torch.float32 else 5e-2
383
+ assert_close(ig_ref, ig_tri, atol=atol, rtol=rtol)
384
+ assert_close(mg_ref, mg_tri, atol=atol, rtol=rtol)
385
+ assert_close(wg_ref, wg_tri, atol=wg_atol, rtol=wg_atol)
386
+ assert_close(bg_ref, bg_tri, atol=wg_atol, rtol=wg_atol)
387
+ assert_close(sg_ref, sg_tri, atol=atol, rtol=rtol)