Kernels
Commit e195bbb (unverified) · 1 parent: 46020a2
Committed by: TaehyunKim, Claude Opus 4.6, github-actions[bot]

feat: add GroupedFusedMulPolyNorm Triton kernel for MoE models (#16)

* feat: add GroupedFusedMulPolyNorm Triton kernel for MoE models

Fuses the full PolyNorm computation into two Triton kernels (fwd + bwd)
with per-expert weights/bias and in-kernel binary search for expert mapping.
Includes benchmarks, tests, and README documentation with B200 results.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Add built binary [skip-build]

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. README.md +129 -1
  2. benchmarks/cases/grouped_mul_poly.py +122 -0
  3. benchmarks/common/bench_framework.py +24 -4
  4. benchmarks/run_cases.py +138 -75
  5. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +2 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  7. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  8. build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
  9. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +2 -0
  10. build/torch210-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  11. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  12. build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py +583 -0
  13. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +2 -0
  14. build/torch210-cxx11-cu130-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  15. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
  16. build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py +583 -0
  17. build/torch210-cxx11-rocm70-x86_64-linux/__init__.py +2 -0
  18. build/torch210-cxx11-rocm70-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  19. build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
  20. build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py +583 -0
  21. build/torch210-cxx11-rocm71-x86_64-linux/__init__.py +2 -0
  22. build/torch210-cxx11-rocm71-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  23. build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
  24. build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py +583 -0
  25. build/torch28-cxx11-cu126-x86_64-linux/__init__.py +2 -0
  26. build/torch28-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  27. build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  28. build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
  29. build/torch28-cxx11-cu128-x86_64-linux/__init__.py +2 -0
  30. build/torch28-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  31. build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
  32. build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py +583 -0
  33. build/torch28-cxx11-cu129-x86_64-linux/__init__.py +2 -0
  34. build/torch28-cxx11-cu129-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  35. build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
  36. build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py +583 -0
  37. build/torch28-cxx11-rocm63-x86_64-linux/__init__.py +2 -0
  38. build/torch28-cxx11-rocm63-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  39. build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
  40. build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py +583 -0
  41. build/torch28-cxx11-rocm64-x86_64-linux/__init__.py +2 -0
  42. build/torch28-cxx11-rocm64-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  43. build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +3 -3
  44. build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py +583 -0
  45. build/torch29-cxx11-cu126-x86_64-linux/__init__.py +2 -0
  46. build/torch29-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
  47. build/torch29-cxx11-cu126-x86_64-linux/_ops.py +3 -3
  48. build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
  49. build/torch29-cxx11-cu128-x86_64-linux/__init__.py +2 -0
  50. build/torch29-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
README.md CHANGED
@@ -19,7 +19,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
 ```python
 y = x + residual
 hidden_state = rms_norm(y, weight, eps)
-out = y + some_op(hidden_state)
+out = y + some_op(hidden_state)
 ```
 
 - Fused as:
@@ -45,6 +45,22 @@ Activation is a python package that contains custom CUDA-based activation kernel
 out = fused_mul_poly_norm(x, a, weight, bias, eps)
 ```
 
+- **GroupedFusedMulPolyNorm** (Triton)
+
+  A Triton-accelerated grouped variant of FusedMulPolyNorm for **MoE (Mixture of Experts)** models. Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), with per-expert weights/bias and in-kernel binary search for expert mapping.
+  - Instead of:
+
+    ```python
+    for i, expert in enumerate(experts):
+        out[start:end] = fused_mul_poly_norm(x[start:end], mul[start:end], weight[i], bias[i], eps)
+    ```
+
+  - Fused as:
+
+    ```python
+    out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps)
+    ```
+
 ## Usage
 
 ```python
@@ -214,6 +230,118 @@ print(poly_norm(x))
 
 </details>
 
+---
+
+### GroupedFusedMulPolyNorm (Triton)
+
+> [!NOTE]
+> This kernel is implemented in Triton (JIT-compiled, no CUDA C++ build required).
+> Benchmarks compare three variants: **Naive** (raw PyTorch reference), **Compiled** (`torch.compile`'d reference), and **Triton** (fused Triton kernel).
+> Benchmark dimension: 1280, 384 experts.
+
+#### B200 Results (bf16)
+
+<details>
+<summary>Forward Performance</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x |
+| 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x |
+| 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x |
+| 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x |
+| 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x |
+| 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x |
+| 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x |
+| 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x |
+| 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x |
+| 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x |
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x |
+| 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x |
+| 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x |
+| 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x |
+| 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x |
+| 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x |
+| 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x |
+
+</details>
+
+<details>
+<summary>Forward + Backward Combined</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
+|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
+| 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x |
+| 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x |
+| 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x |
+| 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x |
+| 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x |
+| 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x |
+| 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x |
+| 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x |
+| 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x |
+
+</details>
+
+#### B200 Results (fp32)
+
+<details>
+<summary>Forward Performance</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x |
+| 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x |
+| 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x |
+| 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x |
+| 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x |
+| 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x |
+| 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x |
+| 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x |
+| 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x |
+
+</details>
+
+<details>
+<summary>Backward Performance</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x |
+| 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x |
+| 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x |
+| 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x |
+| 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x |
+| 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x |
+| 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x |
+
+</details>
+
+<details>
+<summary>Forward + Backward Combined</summary>
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
+|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
+| 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x |
+| 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x |
+| 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x |
+| 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x |
+| 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x |
+| 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x |
+| 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x |
+| 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x |
+| 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x |
+
+</details>
+
 ## Pre-commit Hooks
 
 This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
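The "in-kernel binary search for expert mapping" mentioned in the README section above can be illustrated in plain Python: given `offsets` (the inclusive cumulative token counts per expert), a token's expert is the first index whose offset exceeds the token's row index. A minimal sketch using the stdlib's `bisect` (illustrative only; the actual kernel performs this search inside Triton):

```python
import bisect

def expert_of(token_idx, offsets):
    # offsets[i] = total tokens owned by experts 0..i (inclusive cumsum).
    # The owning expert is the first i with token_idx < offsets[i],
    # which is exactly bisect_right over the offsets array.
    return bisect.bisect_right(offsets, token_idx)

offsets = [3, 5, 9]  # expert 0: rows 0-2, expert 1: rows 3-4, expert 2: rows 5-8
print([expert_of(t, offsets) for t in range(9)])
# [0, 0, 0, 1, 1, 2, 2, 2, 2]
```

This is the same mapping the PyTorch reference obtains from `torch.bucketize`; doing it inside the kernel avoids the extra `torch.arange` + `torch.bucketize` launches per call.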
benchmarks/cases/grouped_mul_poly.py ADDED
@@ -0,0 +1,122 @@
+import torch
+import torch._functorch.config
+from common.diff_engine import DiffCase
+
+torch._functorch.config.donated_buffer = False
+
+from grouped_poly_norm import (
+    grouped_fused_mul_poly_norm,
+    grouped_fused_mul_poly_norm_ref,
+)
+
+NUM_EXPERTS = 384
+
+
+class GroupedRefModule(torch.nn.Module):
+    """Wraps the PyTorch reference for grouped FusedMulPolyNorm."""
+
+    def __init__(self, weight, bias, offsets, eps, expert_offset=0):
+        super().__init__()
+        self.weight = torch.nn.Parameter(weight)
+        self.bias = torch.nn.Parameter(bias)
+        self.offsets = offsets
+        self.eps = eps
+        self.expert_offset = expert_offset
+
+    def forward(self, x, mul):
+        return grouped_fused_mul_poly_norm_ref(x, mul, self.weight, self.bias,
+                                               self.offsets, self.eps,
+                                               expert_offset=self.expert_offset)
+
+
+class GroupedTritonModule(torch.nn.Module):
+    """Wraps the Triton kernel for grouped FusedMulPolyNorm."""
+
+    def __init__(self, weight, bias, offsets, eps, expert_offset=0):
+        super().__init__()
+        self.weight = torch.nn.Parameter(weight)
+        self.bias = torch.nn.Parameter(bias)
+        self.offsets = offsets
+        self.eps = eps
+        self.expert_offset = expert_offset
+
+    def forward(self, x, mul):
+        return grouped_fused_mul_poly_norm(x, mul, self.weight, self.bias,
+                                           self.offsets, self.eps,
+                                           expert_offset=self.expert_offset)
+
+
+class GroupedMulPoly(DiffCase):
+    """Benchmark case for Grouped FusedMulPolyNorm (MoE).
+
+    Maps the framework's (bs, sl, hidden) to grouped polynorm's
+    (total_tokens, D) where total_tokens = bs * sl.
+    Uses a fixed number of experts with uniform token distribution.
+    """
+
+    def build_inputs(self, bs, sl, hidden, dtype, eps):
+        total_tokens = bs * sl
+        num_experts = min(NUM_EXPERTS, total_tokens)
+
+        torch.manual_seed(42)
+        probs = torch.ones(num_experts) / num_experts
+        assignments = torch.multinomial(probs, total_tokens, replacement=True)
+        counts = torch.bincount(assignments, minlength=num_experts).tolist()
+        offsets = torch.cumsum(
+            torch.tensor(counts, dtype=torch.int32), dim=0)
+
+        return {
+            "x":
+                torch.randn(total_tokens, hidden, dtype=dtype,
+                            requires_grad=True) * 0.5,
+            "mul":
+                torch.randn(total_tokens, hidden, dtype=dtype,
+                            requires_grad=True) * 0.5,
+            "weight":
+                torch.ones(num_experts, 3, dtype=dtype) / 3 +
+                torch.randn(num_experts, 3, dtype=dtype) * 0.01,
+            "bias":
+                torch.randn(num_experts, 1, dtype=dtype) * 0.01,
+            "offsets": offsets,
+            "dim": hidden,
+            "eps": eps,
+            "dtype": dtype,
+        }
+
+    def make_naive(self, I):
+        return GroupedRefModule(
+            I["weight"].detach().clone(),
+            I["bias"].detach().clone(),
+            I["offsets"],
+            I["eps"],
+        )
+
+    def make_compiled(self, I):
+        m = GroupedRefModule(
+            I["weight"].detach().clone(),
+            I["bias"].detach().clone(),
+            I["offsets"],
+            I["eps"],
+        )
+        return torch.compile(m)
+
+    def make_cuda(self, I):
+        return GroupedTritonModule(
+            I["weight"].detach().clone(),
+            I["bias"].detach().clone(),
+            I["offsets"],
+            I["eps"],
+        )
+
+    def forward(self, obj, I):
+        return obj(I["x"], I["mul"])
+
+    def grad_inputs(self, I):
+        return [I["x"], I["mul"]]
+
+
+CASE = GroupedMulPoly()
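The `build_inputs` above derives `offsets` as an inclusive cumulative sum of per-expert token counts, so expert `i` owns the contiguous row range `[offsets[i-1], offsets[i])` of the packed token matrix. A minimal pure-Python sketch of that bookkeeping (the `counts`/`offsets` names mirror the case above; no torch required):

```python
# Per-expert token counts -> inclusive cumulative offsets,
# mirroring torch.cumsum(torch.tensor(counts), dim=0) in build_inputs.
counts = [3, 2, 4]
offsets = []
total = 0
for c in counts:
    total += c
    offsets.append(total)

# Expert i owns rows [offsets[i-1], offsets[i]) of the packed token matrix.
ranges = [(0 if i == 0 else offsets[i - 1], offsets[i])
          for i in range(len(offsets))]
print(offsets)  # [3, 5, 9]
print(ranges)   # [(0, 3), (3, 5), (5, 9)]
```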
benchmarks/common/bench_framework.py CHANGED
@@ -57,7 +57,12 @@ def make_fwd_benchmark_for_case(
     I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
     if provider == "speedup":
         return timings_ms["naive"][key] / timings_ms["cuda"][key]
-    obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+    if provider == "naive":
+        obj = case.make_naive(I)
+    elif provider == "compiled" and hasattr(case, "make_compiled"):
+        obj = case.make_compiled(I)
+    else:
+        obj = case.make_cuda(I)
     run = lambda: case.forward(obj, I)
     ms = triton.testing.do_bench(run)
     timings_ms[provider][key] = ms
@@ -101,7 +106,12 @@ def make_fwd_benchmark_plot_for_case(
         return 1.00
     batch_size, seq_len, dim = parse_config_string(config)
     I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
-    obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+    if provider == "naive":
+        obj = case.make_naive(I)
+    elif provider == "compiled" and hasattr(case, "make_compiled"):
+        obj = case.make_compiled(I)
+    else:
+        obj = case.make_cuda(I)
     run = lambda: case.forward(obj, I)
     ms = triton.testing.do_bench(run)
     timings_ms[provider][config] = ms
@@ -146,7 +156,12 @@ def make_bwd_benchmark_for_case(
     I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
     if provider == "speedup":
         return timings_ms["naive"][key] / timings_ms["cuda"][key]
-    obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+    if provider == "naive":
+        obj = case.make_naive(I)
+    elif provider == "compiled" and hasattr(case, "make_compiled"):
+        obj = case.make_compiled(I)
+    else:
+        obj = case.make_cuda(I)
     y = case.forward(obj, I)
     gin = list(case.grad_inputs(I)) + list(obj.parameters())
     if isinstance(y, torch.Tensor):
@@ -201,7 +216,12 @@ def make_bwd_benchmark_plot_for_case(
         return 1.00
     batch_size, seq_len, dim = parse_config_string(config)
     I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
-    obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+    if provider == "naive":
+        obj = case.make_naive(I)
+    elif provider == "compiled" and hasattr(case, "make_compiled"):
+        obj = case.make_compiled(I)
+    else:
+        obj = case.make_cuda(I)
     y = case.forward(obj, I)
     gin = list(case.grad_inputs(I)) + list(obj.parameters())
     if isinstance(y, torch.Tensor):
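The provider dispatch introduced in this file falls back to the kernel object when a case does not define `make_compiled`. A minimal self-contained sketch of that `hasattr`-guarded dispatch (the `Case` and `build` names here are illustrative, not from the repo):

```python
class Case:
    """A benchmark case that only provides naive and kernel builders."""

    def make_naive(self):
        return "naive-obj"

    def make_cuda(self):
        return "cuda-obj"
    # Note: no make_compiled defined.


def build(case, provider):
    # Mirrors the framework's dispatch: "compiled" is only honored
    # when the case actually implements make_compiled.
    if provider == "naive":
        return case.make_naive()
    elif provider == "compiled" and hasattr(case, "make_compiled"):
        return case.make_compiled()
    else:
        return case.make_cuda()


print(build(Case(), "compiled"))  # falls back: prints "cuda-obj"
```

This keeps older cases (which predate the Compiled variant) working unchanged while new cases like `grouped_mul_poly` opt in by defining `make_compiled`.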
benchmarks/run_cases.py CHANGED
@@ -23,12 +23,15 @@ def make_title_tag():
     return f"[{dev_name} | torch {torch_ver}]"
 
 
-def plot_result(r_path):
+def plot_result(r_path, columns=None):
     import matplotlib.pyplot as plt
     import pandas as pd
     df = pd.read_csv(r_path + ".csv")
+    if columns is None:
+        columns = [c for c in ["Naive", "Compiled", "Cuda", "Triton"]
+                   if c in df.columns]
     plt.figure(figsize=(12, 6))
-    ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca())
+    ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
     ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
                  fontsize=14,
                  fontweight="bold")
@@ -44,9 +47,10 @@ def plot_result(r_path):
 
 def main():
     ap = argparse.ArgumentParser()
-    ap.add_argument("--case",
-                    choices=["rms", "add_rms", "poly", "mul_poly"],
-                    required=True)
+    ap.add_argument(
+        "--case",
+        choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"],
+        required=True)
     ap.add_argument("--plot", action="store_true")
     ap.add_argument(
         "--save-path",
@@ -54,8 +58,25 @@ def main():
         default="./configs/",
         help="Path to save benchmark results",
     )
+    ap.add_argument(
+        "--dtype",
+        choices=["fp16", "bf16", "fp32", "all"],
+        default="bf16",
+        help="Data type for benchmarking (default: bf16)",
+    )
     args = ap.parse_args()
 
+    dtype_map = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "fp32": torch.float32,
+    }
+    if args.dtype == "all":
+        dtypes = [("fp16", torch.float16), ("bf16", torch.bfloat16),
+                  ("fp32", torch.float32)]
+    else:
+        dtypes = [(args.dtype, dtype_map[args.dtype])]
+
     torch.set_default_device("cuda")
     mod = importlib.import_module(f"cases.{args.case}")
     case: DiffCase = mod.CASE
@@ -67,76 +88,118 @@ def main():
         hidden_size=4096,
     )
 
-    save_dir = os.path.join(args.save_path, args.case)
-    if args.plot:
-        batch_size_range = [1]
-        seq_length_range = [4096, 8192, 16384]
-        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
-        configs = list(
-            itertools.product(batch_size_range, seq_length_range, dim))
-        plot_name = f"plot_{args.case}-fwd-perf"
-        bench = make_fwd_benchmark_plot_for_case(
-            case=case,
-            configs=configs,
-            plot_name=plot_name,
-            line_names={
-                "naive": "Naive",
-                "cuda": "Cuda",
-            },
-        )
-        bench.run(print_data=True, save_path=save_dir)
-        plot_result(os.path.join(save_dir, plot_name))
-
-        plot_name = f"plot_{args.case}-bwd-perf"
-        bench = make_bwd_benchmark_plot_for_case(
-            case=case,
-            configs=configs,
-            plot_name=plot_name,
-            line_names={
-                "naive": "Naive",
-                "cuda": "Cuda",
-            },
-        )
-        bench.run(print_data=True, save_path=save_dir)
-        plot_result(os.path.join(save_dir, plot_name))
-        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
-                os.path.join(save_dir, "*.csv")):
-            os.remove(f)
-    else:
-        batch_size_range = [2**i for i in range(0, 4, 1)]
-        seq_length_range = [2**i for i in range(10, 14, 1)]
-        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
-        configs = list(
-            itertools.product(dim, batch_size_range, seq_length_range))
-
-        bench = make_fwd_benchmark_for_case(
-            case=case,
-            configs=configs,
-            plot_name=f"{args.case}-fwd-perf",
-            line_names={
-                "naive": "Naive",
-                "cuda": "Cuda",
-                "speedup": "SpeedUp"
-            },
-        )
-
-        bench.run(print_data=True, save_path=save_dir)
-
-        bench = make_bwd_benchmark_for_case(
-            case=case,
-            configs=configs,
-            plot_name=f"{args.case}-bwd-perf",
-            line_names={
-                "naive": "Naive",
-                "cuda": "Cuda",
-                "speedup": "SpeedUp"
-            },
-        )
-
-        bench.run(print_data=True, save_path=save_dir)
-        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
-                os.path.join(save_dir, "*.png")):
-            os.remove(f)
+    for dtype_name, dtype in dtypes:
+        print(f"\n{'=' * 60}")
+        print(f" Benchmarking dtype: {dtype_name} ({dtype})")
+        print(f"{'=' * 60}\n")
+
+        save_dir = os.path.join(args.save_path, args.case, dtype_name)
+        is_grouped = args.case == "grouped_mul_poly"
+
+        if args.plot:
+            batch_size_range = [1]
+            seq_length_range = [4096, 8192, 16384]
+            if is_grouped:
+                dim = [1280]
+            elif "poly" in args.case:
+                dim = [8192, 16384]
+            else:
+                dim = [2048, 4096]
+            configs = list(
+                itertools.product(batch_size_range, seq_length_range, dim))
+
+            if is_grouped:
+                plot_line_vals = ("naive", "compiled", "cuda")
+                plot_line_names = {
+                    "naive": "Naive",
+                    "compiled": "Compiled",
+                    "cuda": "Triton",
+                }
+            else:
+                plot_line_vals = ("naive", "cuda")
+                plot_line_names = {
+                    "naive": "Naive",
+                    "cuda": "Cuda",
+                }
+
+            plot_name = f"plot_{args.case}-{dtype_name}-fwd-perf"
+            bench = make_fwd_benchmark_plot_for_case(
+                case=case,
+                configs=configs,
+                plot_name=plot_name,
+                dtype=dtype,
+                line_vals=plot_line_vals,
+                line_names=plot_line_names,
+            )
+            bench.run(print_data=True, save_path=save_dir)
+            plot_result(os.path.join(save_dir, plot_name))
+
+            plot_name = f"plot_{args.case}-{dtype_name}-bwd-perf"
+            bench = make_bwd_benchmark_plot_for_case(
+                case=case,
+                configs=configs,
+                plot_name=plot_name,
+                dtype=dtype,
+                line_vals=plot_line_vals,
+                line_names=plot_line_names,
+            )
+            bench.run(print_data=True, save_path=save_dir)
+            plot_result(os.path.join(save_dir, plot_name))
+            for f in glob.glob(os.path.join(save_dir, "*.html")) + \
+                    glob.glob(os.path.join(save_dir, "*.csv")):
+                os.remove(f)
+        else:
+            batch_size_range = [2**i for i in range(0, 4, 1)]
+            seq_length_range = [2**i for i in range(10, 14, 1)]
+            if is_grouped:
+                dim = [1280]
+            elif "poly" in args.case:
+                dim = [8192, 16384]
+            else:
+                dim = [2048, 4096]
+            configs = list(
+                itertools.product(dim, batch_size_range, seq_length_range))
+
+            if is_grouped:
+                csv_line_vals = ("naive", "compiled", "cuda", "speedup")
+                csv_line_names = {
+                    "naive": "Naive",
+                    "compiled": "Compiled",
+                    "cuda": "Triton",
+                    "speedup": "SpeedUp",
+                }
+            else:
+                csv_line_vals = ("naive", "cuda", "speedup")
+                csv_line_names = {
+                    "naive": "Naive",
+                    "cuda": "Cuda",
+                    "speedup": "SpeedUp",
+                }
+
+            bench = make_fwd_benchmark_for_case(
+                case=case,
+                configs=configs,
+                plot_name=f"{args.case}-{dtype_name}-fwd-perf",
+                dtype=dtype,
+                line_vals=csv_line_vals,
+                line_names=csv_line_names,
+            )
+
+            bench.run(print_data=True, save_path=save_dir)
+
+            bench = make_bwd_benchmark_for_case(
+                case=case,
+                configs=configs,
+                plot_name=f"{args.case}-{dtype_name}-bwd-perf",
+                dtype=dtype,
+                line_vals=csv_line_vals,
+                line_names=csv_line_names,
+            )
+
+            bench.run(print_data=True, save_path=save_dir)
+            for f in glob.glob(os.path.join(save_dir, "*.html")) + \
+                    glob.glob(os.path.join(save_dir, "*.png")):
+                os.remove(f)
 
 
 if __name__ == "__main__":
build/torch210-cxx11-cu126-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch210-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:39a7e25002120a73ea83ac813276c0518086fae2236f528dadf96bac4876a270
+oid sha256:f31dfeac9b22c01a027f858b3d8beaf87eea9adf8dc45902f0e43d6c264fd985
 size 10775296
build/torch210-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+    poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+    output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+  hidden dimension.
+- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+  across the reduction and output phases, eliminating redundant global reads.
+- Multi-tile software pipelining: explicit num_stages in autotune configs
+  enables overlapping memory loads with computation across loop iterations.
+- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+  launches (torch.arange + torch.bucketize) per forward/backward call.
+- Backward 2-pass optimization: pass 1 merges RMS statistics computation
+  with dot product accumulation, pass 2 computes gradients. This reduces
+  memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+- Computes x, x^2, x^3 in registers
+- Computes three RMS norms in a single pass (shared variance reduction)
+- Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+- Recomputes forward intermediates from saved inputs (activation recomputation)
+- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+- Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
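For intuition: the three `rms_norm` calls in the reference collapse into a single pass over `x`, because `rms_norm(x^k)` for k = 1, 2, 3 only needs the reductions sum(x^2), sum(x^4), sum(x^6). That is the factorization the Triton kernels use. A minimal pure-Python sketch of the equivalence (stdlib only; the helper names here are illustrative, not part of the module):

```python
import math

def polynorm_row(x, m, w, b, eps=1e-6):
    """Direct form: (w0*rms_norm(x^3) + w1*rms_norm(x^2) + w2*rms_norm(x) + b) * m."""
    def rms_norm(v):
        inv = 1.0 / math.sqrt(sum(t * t for t in v) / len(v) + eps)
        return [t * inv for t in v]
    x2 = [t * t for t in x]
    x3 = [t * s for t, s in zip(x, x2)]
    n3, n2, n1 = rms_norm(x3), rms_norm(x2), rms_norm(x)
    return [(w[0] * a + w[1] * c + w[2] * d + b) * g
            for a, c, d, g in zip(n3, n2, n1, m)]

def polynorm_row_fused(x, m, w, b, eps=1e-6):
    """Single-pass form: only sums of x^2, x^4, x^6 are reduced over the row."""
    D = len(x)
    s2 = sum(t ** 2 for t in x)
    s4 = sum(t ** 4 for t in x)
    s6 = sum(t ** 6 for t in x)
    r1 = 1.0 / math.sqrt(s2 / D + eps)   # inv RMS of x
    r2 = 1.0 / math.sqrt(s4 / D + eps)   # inv RMS of x^2
    r3 = 1.0 / math.sqrt(s6 / D + eps)   # inv RMS of x^3
    return [(w[0] * t ** 3 * r3 + w[1] * t ** 2 * r2 + w[2] * t * r1 + b) * g
            for t, g in zip(x, m)]

# Both forms agree up to floating-point rounding on an example row.
x = [0.3, -1.2, 0.7, 2.0]
m = [1.1, 0.4, -0.9, 0.5]
w, b = [0.05, 0.2, 0.9], 0.1
direct = polynorm_row(x, m, w, b)
fused = polynorm_row_fused(x, m, w, b)
```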


# ---------------------------------------------------------------------------
# Triton kernel implementation
# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]

    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row."""
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)

    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
        """
        row = tl.program_id(0)
        if row >= N:
            return

        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
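The expert lookup both kernels perform can be modeled outside Triton. A small pure-Python sketch (stdlib only, helper name hypothetical) that checks the fixed-iteration search against `bisect.bisect_right`, which has the same semantics as the `torch.bucketize(..., right=True)` call in the reference implementation:

```python
import bisect

def expert_of(row, offsets):
    # Fixed 12-iteration binary search, as in the kernels: finds the
    # smallest e with offsets[e] > row (2**12 = 4096 bounds num_experts).
    lo, hi = 0, len(offsets)
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if offsets[mid] <= row:
                lo = mid + 1
            else:
                hi = mid
    return lo

counts = [3, 0, 5, 2]          # example: tokens routed to each expert
offsets = []
total = 0
for c in counts:
    total += c
    offsets.append(total)      # cumsum -> [3, 3, 8, 10]

# Rows 0-2 -> expert 0, rows 3-7 -> expert 2 (expert 1 is empty),
# rows 8-9 -> expert 3; empty experts are skipped automatically.
matches = all(expert_of(r, offsets) == bisect.bisect_right(offsets, r)
              for r in range(total))
```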
    class _GroupedPolyNormFn(torch.autograd.Function):

        @staticmethod
        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
            N, D = input.shape
            input = input.contiguous()
            mul = mul.contiguous()
            output = torch.empty_like(input)

            num_experts = offsets.shape[0]
            assert num_experts <= 4096, (
                f"Supports at most 4096 experts, got {num_experts}.")

            _grouped_polynorm_fwd_kernel[(N,)](
                input,
                mul,
                weight,
                bias,
                offsets,
                output,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_input_row=input.stride(0),
                stride_mul_row=mul.stride(0),
                stride_out_row=output.stride(0),
            )

            ctx.save_for_backward(input, mul, weight, bias, offsets)
            ctx.eps = eps
            ctx.expert_offset = expert_offset
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, mul, weight, bias, offsets = ctx.saved_tensors
            eps = ctx.eps
            expert_offset = ctx.expert_offset
            N, D = input.shape

            grad_output = grad_output.contiguous()
            grad_input = torch.empty_like(input)
            grad_mul = torch.empty_like(mul)
            grad_weight = torch.zeros(weight.shape[0],
                                      3,
                                      device=weight.device,
                                      dtype=torch.float32)
            grad_bias = torch.zeros(bias.shape[0],
                                    device=bias.device,
                                    dtype=torch.float32)

            num_experts = offsets.shape[0]

            _grouped_polynorm_bwd_kernel[(N,)](
                grad_output,
                input,
                mul,
                weight,
                bias,
                offsets,
                grad_input,
                grad_mul,
                grad_weight,
                grad_bias,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_row=input.stride(0),
            )

            grad_weight = grad_weight.to(weight.dtype)
            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Triton-accelerated Grouped FusedMulPolyNorm.

        Args:
            input: (total_tokens, D) - concatenated tokens for all experts
            mul: (total_tokens, D) - gate values to multiply with
            weight: (num_experts, 3) - per-expert polynomial weights
            bias: (num_experts, 1) - per-expert polynomial bias
            offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
            eps: numerical stability epsilon
            expert_offset: offset to add to expert index

        Returns:
            (total_tokens, D) - output tensor
        """
        return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
                                        expert_offset)

else:

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
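The gradient algebra the backward kernel hard-codes (the `coeff_x*` terms come from differentiating through the three inverse-RMS factors) can be sanity-checked with central finite differences. An illustrative pure-Python transcription (no torch/triton; helper names are ours, not part of the module):

```python
import math

EPS = 1e-6

def _inv_rms_pows(x):
    # Inverse RMS of x, x^2, x^3 from the shared reductions sum(x^2/4/6).
    D = len(x)
    r1 = 1.0 / math.sqrt(sum(t ** 2 for t in x) / D + EPS)
    r2 = 1.0 / math.sqrt(sum(t ** 4 for t in x) / D + EPS)
    r3 = 1.0 / math.sqrt(sum(t ** 6 for t in x) / D + EPS)
    return r1, r2, r3

def forward_row(x, m, w, b):
    """out_i = (w0*x_i^3*r3 + w1*x_i^2*r2 + w2*x_i*r1 + b) * m_i."""
    r1, r2, r3 = _inv_rms_pows(x)
    return [(w[0] * t ** 3 * r3 + w[1] * t ** 2 * r2 + w[2] * t * r1 + b) * g
            for t, g in zip(x, m)]

def grad_input_row(x, m, go, w):
    """The analytic grad_input from the backward kernel, transcribed."""
    D = len(x)
    r1, r2, r3 = _inv_rms_pows(x)
    dpoly = [g * mm for g, mm in zip(go, m)]
    s1 = sum(d * t for d, t in zip(dpoly, x))
    s2 = sum(d * t ** 2 for d, t in zip(dpoly, x))
    s3 = sum(d * t ** 3 for d, t in zip(dpoly, x))
    c1 = w[2] * s1 * r1 * r1 / D
    c2 = w[1] * s2 * r2 * r2 / D
    c3 = w[0] * s3 * r3 * r3 / D
    return [r1 * (w[2] * d - t * c1)
            + 2.0 * t * r2 * (w[1] * d - t ** 2 * c2)
            + 3.0 * t ** 2 * r3 * (w[0] * d - t ** 3 * c3)
            for d, t in zip(dpoly, x)]

def loss(x, m, go, w, b):
    # Scalar surrogate: <grad_output, forward(x)>.
    return sum(g * o for g, o in zip(go, forward_row(x, m, w, b)))

x = [0.3, -1.2, 0.7, 2.0]
m = [1.1, 0.4, -0.9, 0.5]
go = [0.2, -0.7, 1.3, 0.6]
w, b = [0.05, 0.2, 0.9], 0.1

analytic = grad_input_row(x, m, go, w)
h = 1e-5
numeric = []
for i in range(len(x)):
    xp = list(x); xp[i] += h
    xn = list(x); xn[i] -= h
    numeric.append((loss(xp, m, go, w, b) - loss(xn, m, go, w, b)) / (2 * h))
max_err = max(abs(n - a) for n, a in zip(numeric, analytic))
```

Note the bias drops out of `grad_input` (it only feeds `grad_mul` and `grad_bias`), which matches the kernel computing `poly` without `b_val` when forming `g`.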
build/torch210-cxx11-cu128-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch210-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:078853c2db399a227822ea0c8e70c2e13bad41bfa370657dd19aa2efb3b503e9
+oid sha256:4b8370d2e1f5561ae4b77ac8ae7b3a084e33a0d1952a8f5f9bf4700375313b35
 size 15815392
build/torch210-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
315
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
316
+
317
+ input_row_ptr = input_ptr + row * stride_row
318
+ mul_row_ptr = mul_ptr + row * stride_row
319
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
320
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
321
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
322
+
323
+ D_float = D.to(tl.float32)
324
+
325
+ # --- Single-tile path ---
326
+ if D <= BLOCK_D:
327
+ d_offs = tl.arange(0, BLOCK_D)
328
+ mask = d_offs < D
329
+
330
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
331
+ other=0.0).to(tl.float32)
332
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
333
+ other=0.0).to(tl.float32)
334
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
335
+ other=0.0).to(tl.float32)
336
+
337
+ x2 = x * x
338
+ x3 = x2 * x
339
+
340
+ # Compute RMS stats (x4 inlined to reduce register pressure)
341
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
342
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
343
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
344
+
345
+ w0_inv = w0 * inv_rms_x3
346
+ w1_inv = w1 * inv_rms_x2
347
+ w2_inv = w2 * inv_rms_x
348
+
349
+ dpoly = go * m
350
+
351
+ # Dot products for coefficients and weight grads
352
+ sum_dpoly_x = tl.sum(dpoly * x)
353
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
354
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
355
+ grad_b_acc = tl.sum(dpoly)
356
+
357
+ # Weight grads
358
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
359
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
360
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
361
+
362
+ # Coefficients for grad_input
363
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
364
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
365
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
366
+
367
+ # grad_mul
368
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
369
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
370
+
371
+ # grad_input (in-place accumulation to reduce register pressure)
372
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
373
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
374
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
375
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
376
+
377
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
378
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
379
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
380
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
381
+ else:
382
+ # --- Multi-tile: 2-pass ---
383
+ # Pass 1: RMS stats + dot products + bias grad
384
+ sum_x2 = tl.zeros((), dtype=tl.float32)
385
+ sum_x4 = tl.zeros((), dtype=tl.float32)
386
+ sum_x6 = tl.zeros((), dtype=tl.float32)
387
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
388
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
389
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
390
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
391
+
392
+ for d_start in range(0, D, BLOCK_D):
393
+ d_offs = d_start + tl.arange(0, BLOCK_D)
394
+ mask = d_offs < D
395
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
396
+ other=0.0).to(tl.float32)
397
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
398
+ other=0.0).to(tl.float32)
399
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
400
+ other=0.0).to(tl.float32)
401
+
402
+ x2 = x * x
403
+ x3 = x2 * x
404
+ dpoly = go * m
405
+
406
+ sum_x2 += tl.sum(x2)
407
+ sum_x4 += tl.sum(x2 * x2)
408
+ sum_x6 += tl.sum(x2 * x2 * x2)
409
+ sum_dpoly_x += tl.sum(dpoly * x)
410
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
411
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
412
+ grad_b_acc += tl.sum(dpoly)
413
+
414
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
415
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
416
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
417
+
418
+ w0_inv = w0 * inv_rms_x3
419
+ w1_inv = w1 * inv_rms_x2
420
+ w2_inv = w2 * inv_rms_x
421
+
422
+ # Weight grads
423
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
424
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
425
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
426
+
427
+ # Coefficients for grad_input
428
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
429
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
430
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
431
+
432
+ # Pass 2: grad_input + grad_mul
433
+ for d_start in range(0, D, BLOCK_D):
434
+ d_offs = d_start + tl.arange(0, BLOCK_D)
435
+ mask = d_offs < D
436
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
437
+ other=0.0).to(tl.float32)
438
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
439
+ other=0.0).to(tl.float32)
440
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
441
+ other=0.0).to(tl.float32)
442
+
443
+ x2 = x * x
444
+ x3 = x2 * x
445
+
446
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
447
+ tl.store(grad_mul_row_ptr + d_offs,
448
+ go * (poly + b_val),
449
+ mask=mask)
450
+
451
+ dpoly = go * m
452
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
453
+ g += (2.0 * x * inv_rms_x2 *
454
+ (w1 * dpoly - x2 * coeff_x2))
455
+ g += (3.0 * x2 * inv_rms_x3 *
456
+ (w0 * dpoly - x3 * coeff_x3))
457
+
458
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
459
+
460
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
461
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
462
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
463
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
464
+
465
+ class _GroupedPolyNormFn(torch.autograd.Function):
466
+
467
+ @staticmethod
468
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
469
+ N, D = input.shape
470
+ input = input.contiguous()
471
+ mul = mul.contiguous()
472
+ output = torch.empty_like(input)
473
+
474
+ num_experts = offsets.shape[0]
475
+ assert num_experts <= 4096, (
476
+ f"Supports at most 4096 experts, got {num_experts}.")
477
+
478
+ _grouped_polynorm_fwd_kernel[(N,)](
479
+ input,
480
+ mul,
481
+ weight,
482
+ bias,
483
+ offsets,
484
+ output,
485
+ N,
486
+ D,
487
+ num_experts,
488
+ eps,
489
+ expert_offset,
490
+ stride_input_row=input.stride(0),
491
+ stride_mul_row=mul.stride(0),
492
+ stride_out_row=output.stride(0),
493
+ )
494
+
495
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
496
+ ctx.eps = eps
497
+ ctx.expert_offset = expert_offset
498
+ return output
499
+
500
+ @staticmethod
501
+ def backward(ctx, grad_output):
502
+ input, mul, weight, bias, offsets = ctx.saved_tensors
503
+ eps = ctx.eps
504
+ expert_offset = ctx.expert_offset
505
+ N, D = input.shape
506
+
507
+ grad_output = grad_output.contiguous()
508
+ grad_input = torch.empty_like(input)
509
+ grad_mul = torch.empty_like(mul)
510
+ grad_weight = torch.zeros(weight.shape[0],
511
+ 3,
512
+ device=weight.device,
513
+ dtype=torch.float32)
514
+ grad_bias = torch.zeros(bias.shape[0],
515
+ device=bias.device,
516
+ dtype=torch.float32)
517
+
518
+ num_experts = offsets.shape[0]
519
+
520
+ _grouped_polynorm_bwd_kernel[(N,)](
521
+ grad_output,
522
+ input,
523
+ mul,
524
+ weight,
525
+ bias,
526
+ offsets,
527
+ grad_input,
528
+ grad_mul,
529
+ grad_weight,
530
+ grad_bias,
531
+ N,
532
+ D,
533
+ num_experts,
534
+ eps,
535
+ expert_offset,
536
+ stride_row=input.stride(0),
537
+ )
538
+
539
+ grad_weight = grad_weight.to(weight.dtype)
540
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
541
+
542
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
543
+
544
+ def grouped_fused_mul_poly_norm(
545
+ input: Tensor,
546
+ mul: Tensor,
547
+ weight: Tensor,
548
+ bias: Tensor,
549
+ offsets: Tensor,
550
+ eps: float = 1e-6,
551
+ expert_offset: int = 0,
552
+ ) -> Tensor:
553
+ """Triton-accelerated Grouped FusedMulPolyNorm.
554
+
555
+ Args:
556
+ input: (total_tokens, D) - concatenated tokens for all experts
557
+ mul: (total_tokens, D) - gate values to multiply with
558
+ weight: (num_experts, 3) - per-expert polynomial weights
559
+ bias: (num_experts, 1) - per-expert polynomial bias
560
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
561
+ eps: numerical stability epsilon
562
+ expert_offset: offset to add to expert index
563
+
564
+ Returns:
565
+ (total_tokens, D) - output tensor
566
+ """
567
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
568
+ expert_offset)
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
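
The in-kernel expert lookup above is a fixed-iteration lower-bound binary search over the cumulative token offsets. A minimal pure-Python sketch of the same logic, on hypothetical offsets (this sketch is not part of the kernel source):

```python
def expert_for_row(row, offsets, num_iters=12):
    # Mirrors the Triton kernel: lo ends up as the count of offsets <= row,
    # i.e. the index of the expert that owns this row. A fixed 12 iterations
    # cover up to 2**12 = 4096 experts, matching the kernel's assert.
    lo, hi = 0, len(offsets)
    for _ in range(num_iters):
        if lo < hi:
            mid = (lo + hi) // 2
            if offsets[mid] <= row:
                lo = mid + 1
            else:
                hi = mid
    return lo

# offsets = cumsum of tokens per expert, e.g. 3 experts with [3, 0, 2] tokens.
offsets = [3, 3, 5]
mapping = [expert_for_row(r, offsets) for r in range(5)]
print(mapping)  # [0, 0, 0, 2, 2]
```

Note that an expert with zero tokens produces a repeated offset value and is simply never selected (expert 1 above), which matches the `torch.bucketize(..., right=True)` mapping used by the PyTorch reference path.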
build/torch210-cxx11-cu130-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch210-cxx11-cu130-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:59e2c13071e1807a6225c5ad7a4a7eb04d46b1f177ae6344d199a9e7f14daf92
+oid sha256:edf3fca2079788750c4e0497012ba93c34c770aca4c9d4f22d03be4a86a2ce8c
 size 13520952
build/torch210-cxx11-cu130-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
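
The multi-tile paths of the grouped PolyNorm kernels follow a classic two-pass row reduction: accumulate the power sums tile by tile, fold the scalar weights into the inverse-RMS factors once, then stream the tiles a second time to emit output. A pure-Python sketch of that pattern, with a hypothetical tile size and scalar math standing in for Triton vectors:

```python
import math

def two_pass_polynorm(x, mul, w, b, block=4, eps=1e-6):
    D = len(x)
    # Pass 1: tile-wise accumulation of sum(x^2), sum(x^4), sum(x^6).
    s2 = s4 = s6 = 0.0
    for start in range(0, D, block):
        for v in x[start:start + block]:
            v2 = v * v
            s2 += v2
            s4 += v2 * v2
            s6 += v2 * v2 * v2
    r1 = 1.0 / math.sqrt(s2 / D + eps)
    r2 = 1.0 / math.sqrt(s4 / D + eps)
    r3 = 1.0 / math.sqrt(s6 / D + eps)
    # Fold the scalar weights into the inverse-RMS factors once per row,
    # as the kernel does with w0_inv/w1_inv/w2_inv.
    w0r, w1r, w2r = w[0] * r3, w[1] * r2, w[2] * r1
    # Pass 2: stream the tiles again and emit the gated output.
    out = []
    for start in range(0, D, block):
        for v, m in zip(x[start:start + block], mul[start:start + block]):
            v2 = v * v
            out.append((v2 * v * w0r + v2 * w1r + v * w2r + b) * m)
    return out
```

Because the reduction result is independent of the tiling, any tile size yields the same output, which is what lets autotuning choose BLOCK_D freely.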
build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+    poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+    output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+  hidden dimension.
+- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+  across the reduction and output phases, eliminating redundant global reads.
+- Multi-tile software pipelining: explicit num_stages in autotune configs
+  enables overlapping memory loads with computation across loop iterations.
+- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+  launches (torch.arange + torch.bucketize) per forward/backward call.
+- Backward 2-pass optimization: pass 1 merges RMS statistics computation
+  with dot product accumulation, pass 2 computes gradients. This reduces
+  memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+- Computes x, x^2, x^3 in registers
+- Computes three RMS norms in a single pass (shared variance reduction)
+- Applies polynomial weights + bias + mul in a single pass
+
+Backward kernel: one program per row, tiles over D dimension.
+- Recomputes forward intermediates from saved inputs (activation recomputation)
+- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+- Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+    import triton
+    import triton.language as tl
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+    """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+    input: Tensor,
+    mul: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    offsets: Tensor,
+    eps: float = 1e-6,
+    expert_offset: int = 0,
+) -> Tensor:
+    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+    for all tokens at once. torch.compile friendly.
+
+    Args:
+        input: (total_tokens, D) - concatenated tokens for all experts
+        mul: (total_tokens, D) - gate values to multiply with
+        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+        bias: (num_experts, 1) - per-expert polynomial bias
+        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon
+        expert_offset: offset added to the computed expert index
+
+    Returns:
+        (total_tokens, D) - output tensor
+    """
+    orig_dtype = input.dtype
+
+    token_positions = torch.arange(input.shape[0], device=input.device)
+    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+    weight_fp32 = weight.float()
+    bias_fp32 = bias.float()
+
+    per_token_w = weight_fp32[expert_idx]
+    per_token_b = bias_fp32[expert_idx]
+
+    x = input.float()
+    m = mul.float()
+
+    x2 = x * x
+    x3 = x2 * x
+
+    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+            per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+            per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+    # --- Autotune configurations ---
+    _GROUPED_POLYNORM_FWD_CONFIGS = [
+        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+    ]
+
+    _GROUPED_POLYNORM_BWD_CONFIGS = [
+        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+    ]
+
+    @triton.autotune(
+        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+        key=["D"],
+    )
+    @triton.jit
+    def _grouped_polynorm_fwd_kernel(
+        input_ptr,
+        mul_ptr,
+        weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        output_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_input_row,
+        stride_mul_row,
+        stride_out_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Forward kernel: one program per row."""
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        # Binary search for expert index (12 iters covers up to 4096 experts)
+        lo = 0
+        hi = num_experts
+        for _ in range(12):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_input_row
+        mul_row_ptr = mul_ptr + row * stride_mul_row
+        out_row_ptr = output_ptr + row * stride_out_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+        else:
+            # --- Multi-tile: two-pass approach ---
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                x2 = x * x
+                sum_x2 += tl.sum(x2)
+                sum_x4 += tl.sum(x2 * x2)
+                sum_x6 += tl.sum(x2 * x2 * x2)
+
+            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                x2 = x * x
+                x3 = x2 * x
+                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+    @triton.autotune(
+        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+        key=["D"],
+        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+    )
+    @triton.jit
+    def _grouped_polynorm_bwd_kernel(
+        grad_out_ptr,
+        input_ptr,
+        mul_ptr,
+        weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        grad_input_ptr,
+        grad_mul_ptr,
+        grad_weight_ptr,
+        grad_bias_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias grad
+        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+        """
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        lo = 0
+        hi = num_experts
+        for _ in range(12):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_row
+        mul_row_ptr = mul_ptr + row * stride_row
+        grad_out_row_ptr = grad_out_ptr + row * stride_row
+        grad_input_row_ptr = grad_input_ptr + row * stride_row
+        grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            # Compute RMS stats (x4 inlined to reduce register pressure)
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            dpoly = go * m
+
+            # Dot products for coefficients and weight grads
+            sum_dpoly_x = tl.sum(dpoly * x)
+            sum_dpoly_x2 = tl.sum(dpoly * x2)
+            sum_dpoly_x3 = tl.sum(dpoly * x3)
+            grad_b_acc = tl.sum(dpoly)
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # grad_mul
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+            # grad_input (in-place accumulation to reduce register pressure)
+            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+        else:
+            # --- Multi-tile: 2-pass ---
+            # Pass 1: RMS stats + dot products + bias grad
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+            grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+                dpoly = go * m
+
+                sum_x2 += tl.sum(x2)
+                sum_x4 += tl.sum(x2 * x2)
+                sum_x6 += tl.sum(x2 * x2 * x2)
+                sum_dpoly_x += tl.sum(dpoly * x)
+                sum_dpoly_x2 += tl.sum(dpoly * x2)
+                sum_dpoly_x3 += tl.sum(dpoly * x3)
+                grad_b_acc += tl.sum(dpoly)
+
+            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # Pass 2: grad_input + grad_mul
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+
+                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+                tl.store(grad_mul_row_ptr + d_offs,
+                         go * (poly + b_val),
+                         mask=mask)
+
+                dpoly = go * m
+                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+                g += (2.0 * x * inv_rms_x2 *
+                      (w1 * dpoly - x2 * coeff_x2))
+                g += (3.0 * x2 * inv_rms_x3 *
+                      (w0 * dpoly - x3 * coeff_x3))
+
+                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+    class _GroupedPolyNormFn(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+            N, D = input.shape
+            input = input.contiguous()
+            mul = mul.contiguous()
+            output = torch.empty_like(input)
+
+            num_experts = offsets.shape[0]
+            assert num_experts <= 4096, (
+                f"Supports at most 4096 experts, got {num_experts}.")
+
+            _grouped_polynorm_fwd_kernel[(N,)](
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                output,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_input_row=input.stride(0),
+                stride_mul_row=mul.stride(0),
+                stride_out_row=output.stride(0),
+            )
+
+            ctx.save_for_backward(input, mul, weight, bias, offsets)
+            ctx.eps = eps
+            ctx.expert_offset = expert_offset
+            return output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, mul, weight, bias, offsets = ctx.saved_tensors
+            eps = ctx.eps
+            expert_offset = ctx.expert_offset
+            N, D = input.shape
+
+            grad_output = grad_output.contiguous()
+            grad_input = torch.empty_like(input)
+            grad_mul = torch.empty_like(mul)
+            grad_weight = torch.zeros(weight.shape[0],
+                                      3,
+                                      device=weight.device,
+                                      dtype=torch.float32)
+            grad_bias = torch.zeros(bias.shape[0],
+                                    device=bias.device,
+                                    dtype=torch.float32)
+
+            num_experts = offsets.shape[0]
+
+            _grouped_polynorm_bwd_kernel[(N,)](
+                grad_output,
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                grad_input,
+                grad_mul,
+                grad_weight,
+                grad_bias,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_row=input.stride(0),
+            )
+
+            grad_weight = grad_weight.to(weight.dtype)
+            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+    def grouped_fused_mul_poly_norm(
+        input: Tensor,
+        mul: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        offsets: Tensor,
+        eps: float = 1e-6,
+        expert_offset: int = 0,
+    ) -> Tensor:
+        """Triton-accelerated Grouped FusedMulPolyNorm.
+
+        Args:
+            input: (total_tokens, D) - concatenated tokens for all experts
+            mul: (total_tokens, D) - gate values to multiply with
+            weight: (num_experts, 3) - per-expert polynomial weights
+            bias: (num_experts, 1) - per-expert polynomial bias
+            offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+            eps: numerical stability epsilon
+            expert_offset: offset to add to expert index
+
+        Returns:
+            (total_tokens, D) - output tensor
+        """
+        return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+                                        expert_offset)
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
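For readers who want to check the PolyNorm formula by hand before reaching for the Triton path, here is a dependency-free sketch of the per-row math from the module docstring (pure Python; `fused_mul_poly_norm_row` and `rms_norm` here are my illustrative names, not exported symbols):

```python
import math

def rms_norm(xs, eps=1e-6):
    # x / sqrt(mean(x^2) + eps) over a single row
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v * inv_rms for v in xs]

def fused_mul_poly_norm_row(xs, ms, w, b, eps=1e-6):
    # poly = w[0]*rms_norm(x^3) + w[1]*rms_norm(x^2) + w[2]*rms_norm(x) + b
    # output = poly * mul
    n3 = rms_norm([v ** 3 for v in xs], eps)
    n2 = rms_norm([v ** 2 for v in xs], eps)
    n1 = rms_norm(xs, eps)
    return [(w[0] * a3 + w[1] * a2 + w[2] * a1 + b) * m
            for a3, a2, a1, m in zip(n3, n2, n1, ms)]

# With w = (0, 0, 1) and b = 0 the op degenerates to rms_norm(x) * mul.
row = [1.0, -2.0, 2.0]
out = fused_mul_poly_norm_row(row, [1.0, 1.0, 1.0], (0.0, 0.0, 1.0), 0.0)
ref = rms_norm(row)
assert all(abs(a - c) < 1e-12 for a, c in zip(out, ref))
```

The kernels compute the same three inverse-RMS factors from a single pass over x², x⁴, and x⁶ (since `mean((x^k)^2) = mean(x^{2k})`), which is why only sums of even powers are accumulated.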
build/torch210-cxx11-rocm70-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch

 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction

@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch210-cxx11-rocm70-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:45ff2b71abb33d840d92116980e519786ed06f1e337d681d0e3301dba241ff63
+oid sha256:5f8d5a173c51cb2dabe3554da743aed307c04b5d51c9d0d460a8fa5a821b5495
 size 2919488
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+ eliminating multiple intermediate tensors and kernel launches.
+
+ PolyNorm formula (per row):
+     poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+     output = poly * mul
+
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+ Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+   hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+   across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+   enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+   launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+   with dot product accumulation, pass 2 computes gradients. This reduces
+   memory traffic compared to a naive 3-pass approach.
+
+ Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+ Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+ """
+
+ import torch
+ from torch import Tensor
+
+ try:
+     import triton
+     import triton.language as tl
+
+     HAS_TRITON = True
+ except ImportError:
+     HAS_TRITON = False
+
+
+ # ---------------------------------------------------------------------------
+ # PyTorch reference implementation (for testing and benchmarking)
+ # ---------------------------------------------------------------------------
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
+     """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+     return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+ def grouped_fused_mul_poly_norm_ref(
+     input: Tensor,
+     mul: Tensor,
+     weight: Tensor,
+     bias: Tensor,
+     offsets: Tensor,
+     eps: float = 1e-6,
+     expert_offset: int = 0,
+ ) -> Tensor:
+     """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+     Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+     for all tokens at once. torch.compile friendly.
+
+     Args:
+         input: (total_tokens, D) - concatenated tokens for all experts
+         mul: (total_tokens, D) - gate values to multiply with
+         weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+         bias: (num_experts, 1) - per-expert polynomial bias
+         offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+         eps: numerical stability epsilon
+
+     Returns:
+         (total_tokens, D) - output tensor
+     """
+     orig_dtype = input.dtype
+
+     token_positions = torch.arange(input.shape[0], device=input.device)
+     expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+     weight_fp32 = weight.float()
+     bias_fp32 = bias.float()
+
+     per_token_w = weight_fp32[expert_idx]
+     per_token_b = bias_fp32[expert_idx]
+
+     x = input.float()
+     m = mul.float()
+
+     x2 = x * x
+     x3 = x2 * x
+
+     poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+             per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+             per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+     return (poly * m).to(orig_dtype)
+
+
+ # ---------------------------------------------------------------------------
+ # Triton kernel implementation
+ # ---------------------------------------------------------------------------
+ if HAS_TRITON:
+     # --- Autotune configurations ---
+     _GROUPED_POLYNORM_FWD_CONFIGS = [
+         triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+         triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+         triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+         triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+         triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+         triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+         triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+         triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+         triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+     ]
+
+     _GROUPED_POLYNORM_BWD_CONFIGS = [
+         triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+         triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+         triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+         triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+         triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+         triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+         triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+         triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+         triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+         triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+         triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+     ]
+
+     @triton.autotune(
+         configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+         key=["D"],
+     )
+     @triton.jit
+     def _grouped_polynorm_fwd_kernel(
+         input_ptr,
+         mul_ptr,
+         weight_ptr,
+         bias_ptr,
+         offsets_ptr,
+         output_ptr,
+         N,
+         D,
+         num_experts,
+         eps,
+         expert_offset,
+         stride_input_row,
+         stride_mul_row,
+         stride_out_row,
+         BLOCK_D: tl.constexpr,
+     ):
+         """Forward kernel: one program per row."""
+         row = tl.program_id(0)
+         if row >= N:
+             return
+
+         # Binary search for expert index (12 iters covers up to 4096 experts)
+         lo = 0
+         hi = num_experts
+         for _ in range(12):
+             if lo < hi:
+                 mid = (lo + hi) // 2
+                 if tl.load(offsets_ptr + mid) <= row:
+                     lo = mid + 1
+                 else:
+                     hi = mid
+         eidx = lo + expert_offset
+
+         w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+         w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+         w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+         b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+         input_row_ptr = input_ptr + row * stride_input_row
+         mul_row_ptr = mul_ptr + row * stride_mul_row
+         out_row_ptr = output_ptr + row * stride_out_row
+
+         D_float = D.to(tl.float32)
+
+         # --- Single-tile path ---
+         if D <= BLOCK_D:
+             d_offs = tl.arange(0, BLOCK_D)
+             mask = d_offs < D
+
+             x = tl.load(input_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+             m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+             x2 = x * x
+             x3 = x2 * x
+
+             inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+             inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+             inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+             # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+             w0_inv = w0 * inv_rms_x3
+             w1_inv = w1 * inv_rms_x2
+             w2_inv = w2 * inv_rms_x
+
+             poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+             tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+         else:
+             # --- Multi-tile: two-pass approach ---
+             sum_x2 = tl.zeros((), dtype=tl.float32)
+             sum_x4 = tl.zeros((), dtype=tl.float32)
+             sum_x6 = tl.zeros((), dtype=tl.float32)
+
+             for d_start in range(0, D, BLOCK_D):
+                 d_offs = d_start + tl.arange(0, BLOCK_D)
+                 mask = d_offs < D
+                 x = tl.load(input_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 x2 = x * x
+                 sum_x2 += tl.sum(x2)
+                 sum_x4 += tl.sum(x2 * x2)
+                 sum_x6 += tl.sum(x2 * x2 * x2)
+
+             inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+             inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+             inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+             # Pre-multiply scalar weight * inv_rms
+             w0_inv = w0 * inv_rms_x3
+             w1_inv = w1 * inv_rms_x2
+             w2_inv = w2 * inv_rms_x
+
+             for d_start in range(0, D, BLOCK_D):
+                 d_offs = d_start + tl.arange(0, BLOCK_D)
+                 mask = d_offs < D
+                 x = tl.load(input_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 x2 = x * x
+                 x3 = x2 * x
+                 poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+                 tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+     @triton.autotune(
+         configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+         key=["D"],
+         reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+     )
+     @triton.jit
+     def _grouped_polynorm_bwd_kernel(
+         grad_out_ptr,
+         input_ptr,
+         mul_ptr,
+         weight_ptr,
+         bias_ptr,
+         offsets_ptr,
+         grad_input_ptr,
+         grad_mul_ptr,
+         grad_weight_ptr,
+         grad_bias_ptr,
+         N,
+         D,
+         num_experts,
+         eps,
+         expert_offset,
+         stride_row,
+         BLOCK_D: tl.constexpr,
+     ):
+         """Backward kernel: one program per row, 2-pass approach.
+
+         Pass 1: RMS stats + dot products + bias grad
+         Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+         """
+         row = tl.program_id(0)
+         if row >= N:
+             return
+
+         lo = 0
+         hi = num_experts
+         for _ in range(12):
+             if lo < hi:
+                 mid = (lo + hi) // 2
+                 if tl.load(offsets_ptr + mid) <= row:
+                     lo = mid + 1
+                 else:
+                     hi = mid
+         eidx = lo + expert_offset
+
+         w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+         w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+         w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+         b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+         input_row_ptr = input_ptr + row * stride_row
+         mul_row_ptr = mul_ptr + row * stride_row
+         grad_out_row_ptr = grad_out_ptr + row * stride_row
+         grad_input_row_ptr = grad_input_ptr + row * stride_row
+         grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+         D_float = D.to(tl.float32)
+
+         # --- Single-tile path ---
+         if D <= BLOCK_D:
+             d_offs = tl.arange(0, BLOCK_D)
+             mask = d_offs < D
+
+             x = tl.load(input_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+             m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+             go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                          other=0.0).to(tl.float32)
+
+             x2 = x * x
+             x3 = x2 * x
+
+             # Compute RMS stats (x4 inlined to reduce register pressure)
+             inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+             inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+             inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+             w0_inv = w0 * inv_rms_x3
+             w1_inv = w1 * inv_rms_x2
+             w2_inv = w2 * inv_rms_x
+
+             dpoly = go * m
+
+             # Dot products for coefficients and weight grads
+             sum_dpoly_x = tl.sum(dpoly * x)
+             sum_dpoly_x2 = tl.sum(dpoly * x2)
+             sum_dpoly_x3 = tl.sum(dpoly * x3)
+             grad_b_acc = tl.sum(dpoly)
+
+             # Weight grads
+             grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+             grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+             grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+             # Coefficients for grad_input
+             coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+             coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+             coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+             # grad_mul
+             poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+             tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+             # grad_input (in-place accumulation to reduce register pressure)
+             g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+             g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+             g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+             tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+             tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+         else:
+             # --- Multi-tile: 2-pass ---
+             # Pass 1: RMS stats + dot products + bias grad
+             sum_x2 = tl.zeros((), dtype=tl.float32)
+             sum_x4 = tl.zeros((), dtype=tl.float32)
+             sum_x6 = tl.zeros((), dtype=tl.float32)
+             sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+             sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+             sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+             grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+             for d_start in range(0, D, BLOCK_D):
+                 d_offs = d_start + tl.arange(0, BLOCK_D)
+                 mask = d_offs < D
+                 x = tl.load(input_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                              other=0.0).to(tl.float32)
+
+                 x2 = x * x
+                 x3 = x2 * x
+                 dpoly = go * m
+
+                 sum_x2 += tl.sum(x2)
+                 sum_x4 += tl.sum(x2 * x2)
+                 sum_x6 += tl.sum(x2 * x2 * x2)
+                 sum_dpoly_x += tl.sum(dpoly * x)
+                 sum_dpoly_x2 += tl.sum(dpoly * x2)
+                 sum_dpoly_x3 += tl.sum(dpoly * x3)
+                 grad_b_acc += tl.sum(dpoly)
+
+             inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+             inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+             inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+             w0_inv = w0 * inv_rms_x3
+             w1_inv = w1 * inv_rms_x2
+             w2_inv = w2 * inv_rms_x
+
+             # Weight grads
+             grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+             grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+             grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+             # Coefficients for grad_input
+             coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+             coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+             coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+             # Pass 2: grad_input + grad_mul
+             for d_start in range(0, D, BLOCK_D):
+                 d_offs = d_start + tl.arange(0, BLOCK_D)
+                 mask = d_offs < D
+                 x = tl.load(input_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+                 go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                              other=0.0).to(tl.float32)
+
+                 x2 = x * x
+                 x3 = x2 * x
+
+                 poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+                 tl.store(grad_mul_row_ptr + d_offs,
+                          go * (poly + b_val),
+                          mask=mask)
+
+                 dpoly = go * m
+                 g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+                 g += (2.0 * x * inv_rms_x2 *
+                       (w1 * dpoly - x2 * coeff_x2))
+                 g += (3.0 * x2 * inv_rms_x3 *
+                       (w0 * dpoly - x3 * coeff_x3))
+
+                 tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+             tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+             tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+     class _GroupedPolyNormFn(torch.autograd.Function):
+
+         @staticmethod
+         def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+             N, D = input.shape
+             input = input.contiguous()
+             mul = mul.contiguous()
+             output = torch.empty_like(input)
+
+             num_experts = offsets.shape[0]
+             assert num_experts <= 4096, (
+                 f"Supports at most 4096 experts, got {num_experts}.")
+
+             _grouped_polynorm_fwd_kernel[(N,)](
+                 input,
+                 mul,
+                 weight,
+                 bias,
+                 offsets,
+                 output,
+                 N,
+                 D,
+                 num_experts,
+                 eps,
+                 expert_offset,
+                 stride_input_row=input.stride(0),
+                 stride_mul_row=mul.stride(0),
+                 stride_out_row=output.stride(0),
+             )
+
+             ctx.save_for_backward(input, mul, weight, bias, offsets)
+             ctx.eps = eps
+             ctx.expert_offset = expert_offset
+             return output
+
+         @staticmethod
+         def backward(ctx, grad_output):
+             input, mul, weight, bias, offsets = ctx.saved_tensors
+             eps = ctx.eps
+             expert_offset = ctx.expert_offset
+             N, D = input.shape
+
+             grad_output = grad_output.contiguous()
+             grad_input = torch.empty_like(input)
+             grad_mul = torch.empty_like(mul)
+             grad_weight = torch.zeros(weight.shape[0],
+                                       3,
+                                       device=weight.device,
+                                       dtype=torch.float32)
+             grad_bias = torch.zeros(bias.shape[0],
+                                     device=bias.device,
+                                     dtype=torch.float32)
+
+             num_experts = offsets.shape[0]
+
+             _grouped_polynorm_bwd_kernel[(N,)](
+                 grad_output,
+                 input,
+                 mul,
+                 weight,
+                 bias,
+                 offsets,
+                 grad_input,
+                 grad_mul,
+                 grad_weight,
+                 grad_bias,
+                 N,
+                 D,
+                 num_experts,
+                 eps,
+                 expert_offset,
+                 stride_row=input.stride(0),
+             )
+
+             grad_weight = grad_weight.to(weight.dtype)
+             grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+             return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+     def grouped_fused_mul_poly_norm(
+         input: Tensor,
+         mul: Tensor,
+         weight: Tensor,
+         bias: Tensor,
+         offsets: Tensor,
+         eps: float = 1e-6,
+         expert_offset: int = 0,
+     ) -> Tensor:
+         """Triton-accelerated Grouped FusedMulPolyNorm.
+
+         Args:
+             input: (total_tokens, D) - concatenated tokens for all experts
+             mul: (total_tokens, D) - gate values to multiply with
+             weight: (num_experts, 3) - per-expert polynomial weights
+             bias: (num_experts, 1) - per-expert polynomial bias
+             offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+             eps: numerical stability epsilon
+             expert_offset: offset to add to expert index
+
+         Returns:
+             (total_tokens, D) - output tensor
+         """
+         return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+                                         expert_offset)
+
+ else:
+
+     def grouped_fused_mul_poly_norm(
+         input: Tensor,
+         mul: Tensor,
+         weight: Tensor,
+         bias: Tensor,
+         offsets: Tensor,
+         eps: float = 1e-6,
+         expert_offset: int = 0,
+     ) -> Tensor:
+         raise RuntimeError(
+             "Triton is not available. Install triton to use "
+             "grouped_fused_mul_poly_norm.")
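The in-kernel expert lookup is a binary search over the cumulative-token `offsets` array: find the first expert whose cumulative end-offset exceeds the row index. A standalone sketch of the same mapping (pure Python; `expert_index` is my illustrative name, not an exported symbol), which agrees with the `torch.bucketize(..., right=True)` call used by the reference path:

```python
import bisect

def expert_index(row, offsets, expert_offset=0):
    # Find the first expert whose cumulative end-offset is > row,
    # i.e. the expert that owns this token row. Mirrors the fixed
    # 12-iteration loop in the kernels (2**12 = 4096 experts max).
    lo, hi = 0, len(offsets)
    while lo < hi:
        mid = (lo + hi) // 2
        if offsets[mid] <= row:
            lo = mid + 1
        else:
            hi = mid
    return lo + expert_offset

# offsets is the cumsum of tokens per expert: expert 0 owns rows [0, 3),
# expert 1 owns no rows, expert 2 owns rows [3, 7).
offsets = [3, 3, 7]
assert [expert_index(r, offsets) for r in range(7)] == [0, 0, 0, 2, 2, 2, 2]
assert all(expert_index(r, offsets) == bisect.bisect_right(offsets, r)
           for r in range(7))
```

Empty experts are skipped naturally because their cumulative offset equals the previous expert's, so the search never lands on them.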
build/torch210-cxx11-rocm71-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch

 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction

@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch210-cxx11-rocm71-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:af4db38e8d5ad56226f5a95a86c2b5fc726bd9d576d07df2f07d3f03c1b6b35b
+oid sha256:28e6de9c3cd172e95284b0df0e15a2afb21ab7b89dd624e69b1361942095e8be
 size 2911200
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
1
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
+
3
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
+ eliminating multiple intermediate tensors and kernel launches.
5
+
6
+ PolyNorm formula (per row):
7
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul
9
+
10
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
+
12
+ Performance optimizations:
13
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
+ hidden dimension.
15
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
+ across the reduction and output phases, eliminating redundant global reads.
17
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
18
+ enables overlapping memory loads with computation across loop iterations.
19
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
+ launches (torch.arange + torch.bucketize) per forward/backward call.
21
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
+ with dot product accumulation, pass 2 computes gradients. This reduces
23
+ memory traffic compared to a naive 3-pass approach.
24
+
25
+ Forward kernel: one program per row, tiles over D dimension.
26
+ - Computes x, x^2, x^3 in registers
27
+ - Computes three RMS norms in a single pass (shared variance reduction)
28
+ - Applies polynomial weights + bias + mul in-place
29
+
30
+ Backward kernel: one program per row, tiles over D dimension.
31
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
32
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
+ """
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
+ expert_offset: offset added to each bucketized expert index
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Triton kernel implementation
107
+ # ---------------------------------------------------------------------------
108
+ if HAS_TRITON:
109
+ # --- Autotune configurations ---
110
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
111
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
112
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
113
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
114
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
115
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
116
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
117
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
118
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
119
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
120
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
121
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
122
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
123
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
124
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
125
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
126
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
127
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
128
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
129
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
130
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
131
+ ]
132
+
133
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
134
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
135
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
136
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
137
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
138
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
139
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
140
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
141
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
142
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
143
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
144
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
145
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
146
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
147
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
148
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
149
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
150
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
151
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
152
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
153
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
154
+ ]
155
+
156
+ @triton.autotune(
157
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
158
+ key=["D"],
159
+ )
160
+ @triton.jit
161
+ def _grouped_polynorm_fwd_kernel(
162
+ input_ptr,
163
+ mul_ptr,
164
+ weight_ptr,
165
+ bias_ptr,
166
+ offsets_ptr,
167
+ output_ptr,
168
+ N,
169
+ D,
170
+ num_experts,
171
+ eps,
172
+ expert_offset,
173
+ stride_input_row,
174
+ stride_mul_row,
175
+ stride_out_row,
176
+ BLOCK_D: tl.constexpr,
177
+ ):
178
+ """Forward kernel: one program per row."""
179
+ row = tl.program_id(0)
180
+ if row >= N:
181
+ return
182
+
183
+ # Binary search for expert index (12 iters covers up to 4096 experts)
184
+ lo = 0
185
+ hi = num_experts
186
+ for _ in range(12):
187
+ if lo < hi:
188
+ mid = (lo + hi) // 2
189
+ if tl.load(offsets_ptr + mid) <= row:
190
+ lo = mid + 1
191
+ else:
192
+ hi = mid
193
+ eidx = lo + expert_offset
194
+
195
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
196
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
197
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
198
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
199
+
200
+ input_row_ptr = input_ptr + row * stride_input_row
201
+ mul_row_ptr = mul_ptr + row * stride_mul_row
202
+ out_row_ptr = output_ptr + row * stride_out_row
203
+
204
+ D_float = D.to(tl.float32)
205
+
206
+ # --- Single-tile path ---
207
+ if D <= BLOCK_D:
208
+ d_offs = tl.arange(0, BLOCK_D)
209
+ mask = d_offs < D
210
+
211
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
212
+ other=0.0).to(tl.float32)
213
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
214
+ other=0.0).to(tl.float32)
215
+
216
+ x2 = x * x
217
+ x3 = x2 * x
218
+
219
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
220
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
221
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
222
+
223
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
224
+ w0_inv = w0 * inv_rms_x3
225
+ w1_inv = w1 * inv_rms_x2
226
+ w2_inv = w2 * inv_rms_x
227
+
228
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
229
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
230
+ else:
231
+ # --- Multi-tile: two-pass approach ---
232
+ sum_x2 = tl.zeros((), dtype=tl.float32)
233
+ sum_x4 = tl.zeros((), dtype=tl.float32)
234
+ sum_x6 = tl.zeros((), dtype=tl.float32)
235
+
236
+ for d_start in range(0, D, BLOCK_D):
237
+ d_offs = d_start + tl.arange(0, BLOCK_D)
238
+ mask = d_offs < D
239
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
240
+ other=0.0).to(tl.float32)
241
+ x2 = x * x
242
+ sum_x2 += tl.sum(x2)
243
+ sum_x4 += tl.sum(x2 * x2)
244
+ sum_x6 += tl.sum(x2 * x2 * x2)
245
+
246
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
247
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
248
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
249
+
250
+ # Pre-multiply scalar weight * inv_rms
251
+ w0_inv = w0 * inv_rms_x3
252
+ w1_inv = w1 * inv_rms_x2
253
+ w2_inv = w2 * inv_rms_x
254
+
255
+ for d_start in range(0, D, BLOCK_D):
256
+ d_offs = d_start + tl.arange(0, BLOCK_D)
257
+ mask = d_offs < D
258
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
259
+ other=0.0).to(tl.float32)
260
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
261
+ other=0.0).to(tl.float32)
262
+ x2 = x * x
263
+ x3 = x2 * x
264
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
265
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
266
+
267
+ @triton.autotune(
268
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
269
+ key=["D"],
270
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
271
+ )
272
+ @triton.jit
273
+ def _grouped_polynorm_bwd_kernel(
274
+ grad_out_ptr,
275
+ input_ptr,
276
+ mul_ptr,
277
+ weight_ptr,
278
+ bias_ptr,
279
+ offsets_ptr,
280
+ grad_input_ptr,
281
+ grad_mul_ptr,
282
+ grad_weight_ptr,
283
+ grad_bias_ptr,
284
+ N,
285
+ D,
286
+ num_experts,
287
+ eps,
288
+ expert_offset,
289
+ stride_row,
290
+ BLOCK_D: tl.constexpr,
291
+ ):
292
+ """Backward kernel: one program per row, 2-pass approach.
293
+
294
+ Pass 1: RMS stats + dot products + bias grad
295
+ Pass 2: grad_input + grad_mul; weight/bias grads flushed via tl.atomic_add
296
+ """
297
+ row = tl.program_id(0)
298
+ if row >= N:
299
+ return
300
+
301
+ lo = 0
302
+ hi = num_experts
303
+ for _ in range(12):
304
+ if lo < hi:
305
+ mid = (lo + hi) // 2
306
+ if tl.load(offsets_ptr + mid) <= row:
307
+ lo = mid + 1
308
+ else:
309
+ hi = mid
310
+ eidx = lo + expert_offset
311
+
312
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
313
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
314
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
315
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
316
+
317
+ input_row_ptr = input_ptr + row * stride_row
318
+ mul_row_ptr = mul_ptr + row * stride_row
319
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
320
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
321
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
322
+
323
+ D_float = D.to(tl.float32)
324
+
325
+ # --- Single-tile path ---
326
+ if D <= BLOCK_D:
327
+ d_offs = tl.arange(0, BLOCK_D)
328
+ mask = d_offs < D
329
+
330
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
331
+ other=0.0).to(tl.float32)
332
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
333
+ other=0.0).to(tl.float32)
334
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
335
+ other=0.0).to(tl.float32)
336
+
337
+ x2 = x * x
338
+ x3 = x2 * x
339
+
340
+ # Compute RMS stats (x4 inlined to reduce register pressure)
341
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
342
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
343
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
344
+
345
+ w0_inv = w0 * inv_rms_x3
346
+ w1_inv = w1 * inv_rms_x2
347
+ w2_inv = w2 * inv_rms_x
348
+
349
+ dpoly = go * m
350
+
351
+ # Dot products for coefficients and weight grads
352
+ sum_dpoly_x = tl.sum(dpoly * x)
353
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
354
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
355
+ grad_b_acc = tl.sum(dpoly)
356
+
357
+ # Weight grads
358
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
359
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
360
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
361
+
362
+ # Coefficients for grad_input
363
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
364
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
365
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
366
+
367
+ # grad_mul
368
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
369
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
370
+
371
+ # grad_input (in-place accumulation to reduce register pressure)
372
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
373
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
374
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
375
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
376
+
377
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
378
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
379
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
380
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
381
+ else:
382
+ # --- Multi-tile: 2-pass ---
383
+ # Pass 1: RMS stats + dot products + bias grad
384
+ sum_x2 = tl.zeros((), dtype=tl.float32)
385
+ sum_x4 = tl.zeros((), dtype=tl.float32)
386
+ sum_x6 = tl.zeros((), dtype=tl.float32)
387
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
388
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
389
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
390
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
391
+
392
+ for d_start in range(0, D, BLOCK_D):
393
+ d_offs = d_start + tl.arange(0, BLOCK_D)
394
+ mask = d_offs < D
395
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
396
+ other=0.0).to(tl.float32)
397
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
398
+ other=0.0).to(tl.float32)
399
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
400
+ other=0.0).to(tl.float32)
401
+
402
+ x2 = x * x
403
+ x3 = x2 * x
404
+ dpoly = go * m
405
+
406
+ sum_x2 += tl.sum(x2)
407
+ sum_x4 += tl.sum(x2 * x2)
408
+ sum_x6 += tl.sum(x2 * x2 * x2)
409
+ sum_dpoly_x += tl.sum(dpoly * x)
410
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
411
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
412
+ grad_b_acc += tl.sum(dpoly)
413
+
414
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
415
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
416
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
417
+
418
+ w0_inv = w0 * inv_rms_x3
419
+ w1_inv = w1 * inv_rms_x2
420
+ w2_inv = w2 * inv_rms_x
421
+
422
+ # Weight grads
423
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
424
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
425
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
426
+
427
+ # Coefficients for grad_input
428
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
429
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
430
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
431
+
432
+ # Pass 2: grad_input + grad_mul
433
+ for d_start in range(0, D, BLOCK_D):
434
+ d_offs = d_start + tl.arange(0, BLOCK_D)
435
+ mask = d_offs < D
436
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
437
+ other=0.0).to(tl.float32)
438
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
439
+ other=0.0).to(tl.float32)
440
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
441
+ other=0.0).to(tl.float32)
442
+
443
+ x2 = x * x
444
+ x3 = x2 * x
445
+
446
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
447
+ tl.store(grad_mul_row_ptr + d_offs,
448
+ go * (poly + b_val),
449
+ mask=mask)
450
+
451
+ dpoly = go * m
452
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
453
+ g += (2.0 * x * inv_rms_x2 *
454
+ (w1 * dpoly - x2 * coeff_x2))
455
+ g += (3.0 * x2 * inv_rms_x3 *
456
+ (w0 * dpoly - x3 * coeff_x3))
457
+
458
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
459
+
460
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
461
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
462
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
463
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
464
+
465
+ class _GroupedPolyNormFn(torch.autograd.Function):
466
+
467
+ @staticmethod
468
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
469
+ N, D = input.shape
470
+ input = input.contiguous()
471
+ mul = mul.contiguous()
472
+ output = torch.empty_like(input)
473
+
474
+ num_experts = offsets.shape[0]
475
+ assert num_experts <= 4096, (
476
+ f"Supports at most 4096 experts, got {num_experts}.")
477
+
478
+ _grouped_polynorm_fwd_kernel[(N,)](
479
+ input,
480
+ mul,
481
+ weight,
482
+ bias,
483
+ offsets,
484
+ output,
485
+ N,
486
+ D,
487
+ num_experts,
488
+ eps,
489
+ expert_offset,
490
+ stride_input_row=input.stride(0),
491
+ stride_mul_row=mul.stride(0),
492
+ stride_out_row=output.stride(0),
493
+ )
494
+
495
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
496
+ ctx.eps = eps
497
+ ctx.expert_offset = expert_offset
498
+ return output
499
+
500
+ @staticmethod
501
+ def backward(ctx, grad_output):
502
+ input, mul, weight, bias, offsets = ctx.saved_tensors
503
+ eps = ctx.eps
504
+ expert_offset = ctx.expert_offset
505
+ N, D = input.shape
506
+
507
+ grad_output = grad_output.contiguous()
508
+ grad_input = torch.empty_like(input)
509
+ grad_mul = torch.empty_like(mul)
510
+ grad_weight = torch.zeros(weight.shape[0],
511
+ 3,
512
+ device=weight.device,
513
+ dtype=torch.float32)
514
+ grad_bias = torch.zeros(bias.shape[0],
515
+ device=bias.device,
516
+ dtype=torch.float32)
517
+
518
+ num_experts = offsets.shape[0]
519
+
520
+ _grouped_polynorm_bwd_kernel[(N,)](
521
+ grad_output,
522
+ input,
523
+ mul,
524
+ weight,
525
+ bias,
526
+ offsets,
527
+ grad_input,
528
+ grad_mul,
529
+ grad_weight,
530
+ grad_bias,
531
+ N,
532
+ D,
533
+ num_experts,
534
+ eps,
535
+ expert_offset,
536
+ stride_row=input.stride(0),
537
+ )
538
+
539
+ grad_weight = grad_weight.to(weight.dtype)
540
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
541
+
542
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
543
+
544
+ def grouped_fused_mul_poly_norm(
545
+ input: Tensor,
546
+ mul: Tensor,
547
+ weight: Tensor,
548
+ bias: Tensor,
549
+ offsets: Tensor,
550
+ eps: float = 1e-6,
551
+ expert_offset: int = 0,
552
+ ) -> Tensor:
553
+ """Triton-accelerated Grouped FusedMulPolyNorm.
554
+
555
+ Args:
556
+ input: (total_tokens, D) - concatenated tokens for all experts
557
+ mul: (total_tokens, D) - gate values to multiply with
558
+ weight: (num_experts, 3) - per-expert polynomial weights
559
+ bias: (num_experts, 1) - per-expert polynomial bias
560
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
561
+ eps: numerical stability epsilon
562
+ expert_offset: offset to add to expert index
563
+
564
+ Returns:
565
+ (total_tokens, D) - output tensor
566
+ """
567
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
568
+ expert_offset)
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
build/torch28-cxx11-cu126-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
 
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
7
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
45
  __all__ = [
46
  "poly_norm",
47
  "fused_mul_poly_norm",
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
5
+ from .grouped_poly_norm import grouped_fused_mul_poly_norm
6
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
7
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
8
 
 
46
  __all__ = [
47
  "poly_norm",
48
  "fused_mul_poly_norm",
49
+ "grouped_fused_mul_poly_norm",
50
  "rms_norm",
51
  "fused_add_rms_norm",
52
  "layers",
build/torch28-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0cc858ef4dce1d14e3a74f5f4a17a2d6c6a8c54cba436f938449c859dd84c3b1
3
  size 10756352
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51ac098828ee90af0d1d17ae75326f89777b1b2e7ef57e00035aed560c434a20
3
  size 10756352
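The grad_input arithmetic in the backward kernel above (the `coeff_*` terms and the `inv_rms * (w * dpoly - x^p * coeff)` pattern) follows from differentiating the RMS normalization. A sketch of the derivation, using the kernel's own symbols:

```latex
% rms(u)_k = u_k r,  with  r = (\tfrac{1}{D}\sum_i u_i^2 + \varepsilon)^{-1/2}:
\frac{\partial\,\mathrm{rms}(u)_k}{\partial u_j}
  = r\,\delta_{kj} - \frac{u_k\,u_j\,r^3}{D}
% With d = grad_out \odot mul and u = x^p (p = 1, 2, 3), the chain rule
% through u_j = x_j^p contributes a factor p\,x_j^{p-1}:
g_j^{(p)} = p\,x_j^{p-1}\,r_p\!\left(w_p\,d_j
          - u_j\,\underbrace{\frac{w_p\,(d\cdot u)\,r_p^2}{D}}_{\texttt{coeff}}\right)
```

Summing $g^{(1)} + g^{(2)} + g^{(3)}$ reproduces the three accumulated terms of `g` in the kernel, with the leading factors $1$, $2x$, and $3x^2$.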
build/torch28-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_18b7543_dirty
3
- ops = torch.ops._activation_18b7543_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_18b7543_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_0e6f27f_dirty
3
+ ops = torch.ops._activation_0e6f27f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_0e6f27f_dirty::{op_name}"
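The in-kernel binary search over `offsets` described earlier replaces `torch.bucketize(token_positions, offsets, right=True)`. A plain-Python sketch of the same lookup (the name `expert_of_row` is illustrative):

```python
def expert_of_row(row, offsets):
    # offsets[i] is the cumulative token count through expert i,
    # e.g. tokens_per_expert = [3, 0, 2] -> offsets = [3, 3, 5].
    # Return the first expert whose offset is strictly greater than `row`,
    # matching torch.bucketize(row, offsets, right=True).
    lo, hi = 0, len(offsets)
    while lo < hi:
        mid = (lo + hi) // 2
        if offsets[mid] <= row:
            lo = mid + 1
        else:
            hi = mid
    return lo
```

Because the bounds shrink by half each step, the kernel's fixed 12 iterations suffice for up to 2^12 = 4096 experts; experts with zero tokens are skipped naturally, as no row falls in their (empty) offset range.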
build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
1
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
+
3
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
+ eliminating multiple intermediate tensors and kernel launches.
5
+
6
+ PolyNorm formula (per row):
7
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul
9
+
10
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
+
12
+ Performance optimizations:
13
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
+ hidden dimension.
15
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
+ across the reduction and output phases, eliminating redundant global reads.
17
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
18
+ enables overlapping memory loads with computation across loop iterations.
19
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
+ launches (torch.arange + torch.bucketize) per forward/backward call.
21
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
+ with dot product accumulation, pass 2 computes gradients. This reduces
23
+ memory traffic compared to a naive 3-pass approach.
24
+
25
+ Forward kernel: one program per row, tiles over D dimension.
26
+ - Computes x, x^2, x^3 in registers
27
+ - Computes three RMS norms in a single pass (shared variance reduction)
28
+ - Applies polynomial weights + bias + mul in-place
29
+
30
+ Backward kernel: one program per row, tiles over D dimension.
31
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
32
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
+ """
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
+ expert_offset: offset added to each bucketized expert index
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Triton kernel implementation
107
+ # ---------------------------------------------------------------------------
108
+ if HAS_TRITON:
109
+ # --- Autotune configurations ---
110
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
111
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
112
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
113
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
114
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
115
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
116
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
117
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
118
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
119
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
120
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
121
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
122
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
123
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
124
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
125
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
126
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
127
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
128
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
129
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
130
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
131
+ ]
132
+
133
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
134
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
135
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
136
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
137
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
138
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
139
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
140
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
141
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
142
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
143
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
144
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
145
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
146
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
147
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
148
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
149
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
150
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
151
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
152
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
153
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
154
+ ]
155
+
156
+ @triton.autotune(
157
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
158
+ key=["D"],
159
+ )
160
+ @triton.jit
161
+ def _grouped_polynorm_fwd_kernel(
162
+ input_ptr,
163
+ mul_ptr,
164
+ weight_ptr,
165
+ bias_ptr,
166
+ offsets_ptr,
167
+ output_ptr,
168
+ N,
169
+ D,
170
+ num_experts,
171
+ eps,
172
+ expert_offset,
173
+ stride_input_row,
174
+ stride_mul_row,
175
+ stride_out_row,
176
+ BLOCK_D: tl.constexpr,
177
+ ):
178
+ """Forward kernel: one program per row."""
179
+ row = tl.program_id(0)
180
+ if row >= N:
181
+ return
182
+
183
+ # Binary search for expert index (12 iters covers up to 4096 experts)
184
+ lo = 0
185
+ hi = num_experts
186
+ for _ in range(12):
187
+ if lo < hi:
188
+ mid = (lo + hi) // 2
189
+ if tl.load(offsets_ptr + mid) <= row:
190
+ lo = mid + 1
191
+ else:
192
+ hi = mid
193
+ eidx = lo + expert_offset
194
+
195
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
196
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
197
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
198
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
199
+
200
+ input_row_ptr = input_ptr + row * stride_input_row
201
+ mul_row_ptr = mul_ptr + row * stride_mul_row
202
+ out_row_ptr = output_ptr + row * stride_out_row
203
+
204
+ D_float = D.to(tl.float32)
205
+
206
+ # --- Single-tile path ---
207
+ if D <= BLOCK_D:
208
+ d_offs = tl.arange(0, BLOCK_D)
209
+ mask = d_offs < D
210
+
211
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
212
+ other=0.0).to(tl.float32)
213
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
214
+ other=0.0).to(tl.float32)
215
+
216
+ x2 = x * x
217
+ x3 = x2 * x
218
+
219
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
220
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
221
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
222
+
223
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
224
+ w0_inv = w0 * inv_rms_x3
225
+ w1_inv = w1 * inv_rms_x2
226
+ w2_inv = w2 * inv_rms_x
227
+
228
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
229
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
230
+ else:
231
+ # --- Multi-tile: two-pass approach ---
232
+ sum_x2 = tl.zeros((), dtype=tl.float32)
233
+ sum_x4 = tl.zeros((), dtype=tl.float32)
234
+ sum_x6 = tl.zeros((), dtype=tl.float32)
235
+
236
+ for d_start in range(0, D, BLOCK_D):
237
+ d_offs = d_start + tl.arange(0, BLOCK_D)
238
+ mask = d_offs < D
239
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
240
+ other=0.0).to(tl.float32)
241
+ x2 = x * x
242
+ sum_x2 += tl.sum(x2)
243
+ sum_x4 += tl.sum(x2 * x2)
244
+ sum_x6 += tl.sum(x2 * x2 * x2)
245
+
246
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
247
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
248
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
249
+
250
+ # Pre-multiply scalar weight * inv_rms
251
+ w0_inv = w0 * inv_rms_x3
252
+ w1_inv = w1 * inv_rms_x2
253
+ w2_inv = w2 * inv_rms_x
254
+
255
+ for d_start in range(0, D, BLOCK_D):
256
+ d_offs = d_start + tl.arange(0, BLOCK_D)
257
+ mask = d_offs < D
258
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
259
+ other=0.0).to(tl.float32)
260
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
261
+ other=0.0).to(tl.float32)
262
+ x2 = x * x
263
+ x3 = x2 * x
264
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
265
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
266
+
267
+ @triton.autotune(
268
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
269
+ key=["D"],
270
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
271
+ )
272
+ @triton.jit
273
+ def _grouped_polynorm_bwd_kernel(
274
+ grad_out_ptr,
275
+ input_ptr,
276
+ mul_ptr,
277
+ weight_ptr,
278
+ bias_ptr,
279
+ offsets_ptr,
280
+ grad_input_ptr,
281
+ grad_mul_ptr,
282
+ grad_weight_ptr,
283
+ grad_bias_ptr,
284
+ N,
285
+ D,
286
+ num_experts,
287
+ eps,
288
+ expert_offset,
289
+ stride_row,
290
+ BLOCK_D: tl.constexpr,
291
+ ):
292
+ """Backward kernel: one program per row, 2-pass approach.
293
+
294
+ Pass 1: RMS stats + dot products + bias grad
295
+ Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
296
+ """
297
+ row = tl.program_id(0)
298
+ if row >= N:
299
+ return
300
+
301
+ lo = 0
302
+ hi = num_experts
303
+ for _ in range(12):
304
+ if lo < hi:
305
+ mid = (lo + hi) // 2
306
+ if tl.load(offsets_ptr + mid) <= row:
307
+ lo = mid + 1
308
+ else:
309
+ hi = mid
310
+ eidx = lo + expert_offset
311
+
312
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
313
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
314
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
315
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
316
+
317
+ input_row_ptr = input_ptr + row * stride_row
318
+ mul_row_ptr = mul_ptr + row * stride_row
319
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
320
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
321
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
322
+
323
+ D_float = D.to(tl.float32)
324
+
325
+ # --- Single-tile path ---
326
+ if D <= BLOCK_D:
327
+ d_offs = tl.arange(0, BLOCK_D)
328
+ mask = d_offs < D
329
+
330
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
331
+ other=0.0).to(tl.float32)
332
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
333
+ other=0.0).to(tl.float32)
334
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
335
+ other=0.0).to(tl.float32)
336
+
337
+ x2 = x * x
338
+ x3 = x2 * x
339
+
340
+ # Compute RMS stats (x4 inlined to reduce register pressure)
341
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
342
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
343
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
344
+
345
+ w0_inv = w0 * inv_rms_x3
346
+ w1_inv = w1 * inv_rms_x2
347
+ w2_inv = w2 * inv_rms_x
348
+
349
+ dpoly = go * m
350
+
351
+ # Dot products for coefficients and weight grads
352
+ sum_dpoly_x = tl.sum(dpoly * x)
353
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
354
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
355
+ grad_b_acc = tl.sum(dpoly)
356
+
357
+ # Weight grads
358
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
359
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
360
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
361
+
362
+ # Coefficients for grad_input
363
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
364
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
365
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
366
+
367
+ # grad_mul
368
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
369
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
370
+
371
+ # grad_input (in-place accumulation to reduce register pressure)
372
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
373
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
374
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
375
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
376
+
377
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
378
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
379
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
380
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
381
+ else:
382
+ # --- Multi-tile: 2-pass ---
383
+ # Pass 1: RMS stats + dot products + bias grad
384
+ sum_x2 = tl.zeros((), dtype=tl.float32)
385
+ sum_x4 = tl.zeros((), dtype=tl.float32)
386
+ sum_x6 = tl.zeros((), dtype=tl.float32)
387
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
388
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
389
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
390
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
391
+
392
+ for d_start in range(0, D, BLOCK_D):
393
+ d_offs = d_start + tl.arange(0, BLOCK_D)
394
+ mask = d_offs < D
395
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
396
+ other=0.0).to(tl.float32)
397
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
398
+ other=0.0).to(tl.float32)
399
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
400
+ other=0.0).to(tl.float32)
401
+
402
+ x2 = x * x
403
+ x3 = x2 * x
404
+ dpoly = go * m
405
+
406
+ sum_x2 += tl.sum(x2)
407
+ sum_x4 += tl.sum(x2 * x2)
408
+ sum_x6 += tl.sum(x2 * x2 * x2)
409
+ sum_dpoly_x += tl.sum(dpoly * x)
410
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
411
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
412
+ grad_b_acc += tl.sum(dpoly)
413
+
414
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
415
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
416
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
417
+
418
+ w0_inv = w0 * inv_rms_x3
419
+ w1_inv = w1 * inv_rms_x2
420
+ w2_inv = w2 * inv_rms_x
421
+
422
+ # Weight grads
423
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
424
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
425
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
426
+
427
+ # Coefficients for grad_input
428
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
429
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
430
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
431
+
432
+ # Pass 2: grad_input + grad_mul
433
+ for d_start in range(0, D, BLOCK_D):
434
+ d_offs = d_start + tl.arange(0, BLOCK_D)
435
+ mask = d_offs < D
436
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
437
+ other=0.0).to(tl.float32)
438
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
439
+ other=0.0).to(tl.float32)
440
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
441
+ other=0.0).to(tl.float32)
442
+
443
+ x2 = x * x
444
+ x3 = x2 * x
445
+
446
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
447
+ tl.store(grad_mul_row_ptr + d_offs,
448
+ go * (poly + b_val),
449
+ mask=mask)
450
+
451
+ dpoly = go * m
452
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
453
+ g += (2.0 * x * inv_rms_x2 *
454
+ (w1 * dpoly - x2 * coeff_x2))
455
+ g += (3.0 * x2 * inv_rms_x3 *
456
+ (w0 * dpoly - x3 * coeff_x3))
457
+
458
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
459
+
460
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
461
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
462
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
463
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
464
+
465
+ class _GroupedPolyNormFn(torch.autograd.Function):
466
+
467
+ @staticmethod
468
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
469
+ N, D = input.shape
470
+ input = input.contiguous()
471
+ mul = mul.contiguous()
472
+ output = torch.empty_like(input)
473
+
474
+ num_experts = offsets.shape[0]
475
+ assert num_experts <= 4096, (
476
+ f"Supports at most 4096 experts, got {num_experts}.")
477
+
478
+ _grouped_polynorm_fwd_kernel[(N,)](
479
+ input,
480
+ mul,
481
+ weight,
482
+ bias,
483
+ offsets,
484
+ output,
485
+ N,
486
+ D,
487
+ num_experts,
488
+ eps,
489
+ expert_offset,
490
+ stride_input_row=input.stride(0),
491
+ stride_mul_row=mul.stride(0),
492
+ stride_out_row=output.stride(0),
493
+ )
494
+
495
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
496
+ ctx.eps = eps
497
+ ctx.expert_offset = expert_offset
498
+ return output
499
+
500
+ @staticmethod
501
+ def backward(ctx, grad_output):
502
+ input, mul, weight, bias, offsets = ctx.saved_tensors
503
+ eps = ctx.eps
504
+ expert_offset = ctx.expert_offset
505
+ N, D = input.shape
506
+
507
+ grad_output = grad_output.contiguous()
508
+ grad_input = torch.empty_like(input)
509
+ grad_mul = torch.empty_like(mul)
510
+ grad_weight = torch.zeros(weight.shape[0],
511
+ 3,
512
+ device=weight.device,
513
+ dtype=torch.float32)
514
+ grad_bias = torch.zeros(bias.shape[0],
515
+ device=bias.device,
516
+ dtype=torch.float32)
517
+
518
+ num_experts = offsets.shape[0]
519
+
520
+ _grouped_polynorm_bwd_kernel[(N,)](
521
+ grad_output,
522
+ input,
523
+ mul,
524
+ weight,
525
+ bias,
526
+ offsets,
527
+ grad_input,
528
+ grad_mul,
529
+ grad_weight,
530
+ grad_bias,
531
+ N,
532
+ D,
533
+ num_experts,
534
+ eps,
535
+ expert_offset,
536
+ stride_row=input.stride(0),
537
+ )
538
+
539
+ grad_weight = grad_weight.to(weight.dtype)
540
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
541
+
542
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
543
+
544
+ def grouped_fused_mul_poly_norm(
545
+ input: Tensor,
546
+ mul: Tensor,
547
+ weight: Tensor,
548
+ bias: Tensor,
549
+ offsets: Tensor,
550
+ eps: float = 1e-6,
551
+ expert_offset: int = 0,
552
+ ) -> Tensor:
553
+ """Triton-accelerated Grouped FusedMulPolyNorm.
554
+
555
+ Args:
556
+ input: (total_tokens, D) - concatenated tokens for all experts
557
+ mul: (total_tokens, D) - gate values to multiply with
558
+ weight: (num_experts, 3) - per-expert polynomial weights
559
+ bias: (num_experts, 1) - per-expert polynomial bias
560
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
561
+ eps: numerical stability epsilon
562
+ expert_offset: offset to add to expert index
563
+
564
+ Returns:
565
+ (total_tokens, D) - output tensor
566
+ """
567
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
568
+ expert_offset)
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
build/torch28-cxx11-cu128-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
    "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch28-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f3e704c3833d0d41cf43549ddad15c36c67e4b80b0b2a7af5c6a9a2c488690b
+oid sha256:7edb027993454d74da9632c7368edc2b0526b5f1ef33ae9e790d49bdf7285640
 size 15804360
build/torch28-cxx11-cu128-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+    poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+    output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+  hidden dimension.
+- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+  across the reduction and output phases, eliminating redundant global reads.
+- Multi-tile software pipelining: explicit num_stages in autotune configs
+  enables overlapping memory loads with computation across loop iterations.
+- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+  launches (torch.arange + torch.bucketize) per forward/backward call.
+- Backward 2-pass optimization: pass 1 merges RMS statistics computation
+  with dot product accumulation, pass 2 computes gradients. This reduces
+  memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+- Computes x, x^2, x^3 in registers
+- Computes three RMS norms in a single pass (shared variance reduction)
+- Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+- Recomputes forward intermediates from saved inputs (activation recomputation)
+- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+- Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+    import triton
+    import triton.language as tl
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+    """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+    input: Tensor,
+    mul: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    offsets: Tensor,
+    eps: float = 1e-6,
+    expert_offset: int = 0,
+) -> Tensor:
+    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+    for all tokens at once. torch.compile friendly.
+
+    Args:
+        input: (total_tokens, D) - concatenated tokens for all experts
+        mul: (total_tokens, D) - gate values to multiply with
+        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+        bias: (num_experts, 1) - per-expert polynomial bias
+        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon
+
+    Returns:
+        (total_tokens, D) - output tensor
+    """
+    orig_dtype = input.dtype
+
+    token_positions = torch.arange(input.shape[0], device=input.device)
+    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+    weight_fp32 = weight.float()
+    bias_fp32 = bias.float()
+
+    per_token_w = weight_fp32[expert_idx]
+    per_token_b = bias_fp32[expert_idx]
+
+    x = input.float()
+    m = mul.float()
+
+    x2 = x * x
+    x3 = x2 * x
+
+    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+            per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+            per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+    # --- Autotune configurations ---
+    _GROUPED_POLYNORM_FWD_CONFIGS = [
+        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+    ]
+
+    _GROUPED_POLYNORM_BWD_CONFIGS = [
+        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+    ]
+
+    @triton.autotune(
+        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+        key=["D"],
+    )
+    @triton.jit
+    def _grouped_polynorm_fwd_kernel(
+        input_ptr,
+        mul_ptr,
+        weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        output_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_input_row,
+        stride_mul_row,
+        stride_out_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Forward kernel: one program per row."""
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        # Binary search for expert index (12 iters covers up to 4096 experts)
+        lo = 0
+        hi = num_experts
+        for _ in range(12):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_input_row
+        mul_row_ptr = mul_ptr + row * stride_mul_row
+        out_row_ptr = output_ptr + row * stride_out_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+        else:
+            # --- Multi-tile: two-pass approach ---
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                x2 = x * x
+                sum_x2 += tl.sum(x2)
+                sum_x4 += tl.sum(x2 * x2)
+                sum_x6 += tl.sum(x2 * x2 * x2)
+
+            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                x2 = x * x
+                x3 = x2 * x
+                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+    @triton.autotune(
+        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+        key=["D"],
+        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+    )
+    @triton.jit
+    def _grouped_polynorm_bwd_kernel(
+        grad_out_ptr,
+        input_ptr,
+        mul_ptr,
+        weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        grad_input_ptr,
+        grad_mul_ptr,
+        grad_weight_ptr,
+        grad_bias_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias grad
+        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+        """
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        lo = 0
+        hi = num_experts
+        for _ in range(12):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_row
+        mul_row_ptr = mul_ptr + row * stride_row
+        grad_out_row_ptr = grad_out_ptr + row * stride_row
+        grad_input_row_ptr = grad_input_ptr + row * stride_row
+        grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            # Compute RMS stats (x4 inlined to reduce register pressure)
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            dpoly = go * m
+
+            # Dot products for coefficients and weight grads
+            sum_dpoly_x = tl.sum(dpoly * x)
+            sum_dpoly_x2 = tl.sum(dpoly * x2)
+            sum_dpoly_x3 = tl.sum(dpoly * x3)
+            grad_b_acc = tl.sum(dpoly)
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # grad_mul
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+            # grad_input (in-place accumulation to reduce register pressure)
+            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+        else:
+            # --- Multi-tile: 2-pass ---
+            # Pass 1: RMS stats + dot products + bias grad
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+            grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+                dpoly = go * m
+
+                sum_x2 += tl.sum(x2)
+                sum_x4 += tl.sum(x2 * x2)
+                sum_x6 += tl.sum(x2 * x2 * x2)
+                sum_dpoly_x += tl.sum(dpoly * x)
+                sum_dpoly_x2 += tl.sum(dpoly * x2)
+                sum_dpoly_x3 += tl.sum(dpoly * x3)
+                grad_b_acc += tl.sum(dpoly)
+
+            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # Pass 2: grad_input + grad_mul
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+
+                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+                tl.store(grad_mul_row_ptr + d_offs,
+                         go * (poly + b_val),
+                         mask=mask)
+
+                dpoly = go * m
+                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+                g += (2.0 * x * inv_rms_x2 *
+                      (w1 * dpoly - x2 * coeff_x2))
+                g += (3.0 * x2 * inv_rms_x3 *
+                      (w0 * dpoly - x3 * coeff_x3))
+
+                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+    class _GroupedPolyNormFn(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+            N, D = input.shape
+            input = input.contiguous()
+            mul = mul.contiguous()
+            output = torch.empty_like(input)
+
+            num_experts = offsets.shape[0]
+            assert num_experts <= 4096, (
+                f"Supports at most 4096 experts, got {num_experts}.")
+
+            _grouped_polynorm_fwd_kernel[(N,)](
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                output,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_input_row=input.stride(0),
+                stride_mul_row=mul.stride(0),
+                stride_out_row=output.stride(0),
+            )
+
+            ctx.save_for_backward(input, mul, weight, bias, offsets)
+            ctx.eps = eps
+            ctx.expert_offset = expert_offset
+            return output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, mul, weight, bias, offsets = ctx.saved_tensors
+            eps = ctx.eps
+            expert_offset = ctx.expert_offset
+            N, D = input.shape
+
+            grad_output = grad_output.contiguous()
+            grad_input = torch.empty_like(input)
+            grad_mul = torch.empty_like(mul)
+            grad_weight = torch.zeros(weight.shape[0],
+                                      3,
+                                      device=weight.device,
+                                      dtype=torch.float32)
+            grad_bias = torch.zeros(bias.shape[0],
+                                    device=bias.device,
+                                    dtype=torch.float32)
+
+            num_experts = offsets.shape[0]
+
+            _grouped_polynorm_bwd_kernel[(N,)](
+                grad_output,
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                grad_input,
+                grad_mul,
+                grad_weight,
+                grad_bias,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_row=input.stride(0),
+            )
+
+            grad_weight = grad_weight.to(weight.dtype)
+            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+    def grouped_fused_mul_poly_norm(
+        input: Tensor,
+        mul: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        offsets: Tensor,
+        eps: float = 1e-6,
+        expert_offset: int = 0,
+    ) -> Tensor:
+        """Triton-accelerated Grouped FusedMulPolyNorm.
+
+        Args:
+            input: (total_tokens, D) - concatenated tokens for all experts
+            mul: (total_tokens, D) - gate values to multiply with
+            weight: (num_experts, 3) - per-expert polynomial weights
+            bias: (num_experts, 1) - per-expert polynomial bias
+            offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+            eps: numerical stability epsilon
+            expert_offset: offset to add to expert index
+
+        Returns:
+            (total_tokens, D) - output tensor
+        """
+        return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+                                        expert_offset)
+
+else:
+
+    def grouped_fused_mul_poly_norm(
+        input: Tensor,
+        mul: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        offsets: Tensor,
+        eps: float = 1e-6,
+        expert_offset: int = 0,
+    ) -> Tensor:
+        raise RuntimeError(
+            "Triton is not available. Install triton to use "
+            "grouped_fused_mul_poly_norm.")
build/torch28-cxx11-cu129-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
 
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
7
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
45
  __all__ = [
46
  "poly_norm",
47
  "fused_mul_poly_norm",
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
5
+ from .grouped_poly_norm import grouped_fused_mul_poly_norm
6
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
7
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
8
 
 
46
  __all__ = [
47
  "poly_norm",
48
  "fused_mul_poly_norm",
49
+ "grouped_fused_mul_poly_norm",
50
  "rms_norm",
51
  "fused_add_rms_norm",
52
  "layers",
build/torch28-cxx11-cu129-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f589ae5cb26c7fb85fac22ee789749c54897bf08b170dbf77b9d62eb98ee8b53
3
  size 15795640
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f06fb9594dcf9c0bc3a8af619fec3a541b775f0dad304a7978314d32fae8d244
3
  size 15795640
build/torch28-cxx11-cu129-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_18b7543_dirty
3
- ops = torch.ops._activation_18b7543_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_18b7543_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_0e6f27f_dirty
3
+ ops = torch.ops._activation_0e6f27f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_0e6f27f_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
1
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
+
3
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
+ eliminating multiple intermediate tensors and kernel launches.
5
+
6
+ PolyNorm formula (per row):
7
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul
9
+
10
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
+
12
+ Performance optimizations:
13
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
+ hidden dimension.
15
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
+ across the reduction and output phases, eliminating redundant global reads.
17
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
18
+ enables overlapping memory loads with computation across loop iterations.
19
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
+ launches (torch.arange + torch.bucketize) per forward/backward call.
21
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
+ with dot product accumulation, pass 2 computes gradients. This reduces
23
+ memory traffic compared to a naive 3-pass approach.
24
+
25
+ Forward kernel: one program per row, tiles over D dimension.
26
+ - Computes x, x^2, x^3 in registers
27
+ - Computes three RMS norms in a single pass (shared variance reduction)
28
+ - Applies polynomial weights + bias + mul in-place
29
+
30
+ Backward kernel: one program per row, tiles over D dimension.
31
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
32
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
+ """
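The per-row formula above can be sketched in plain Python (a minimal sketch with hypothetical helper names; the real kernels operate on GPU tiles in float32):

```python
import math

def rms_norm(v, eps=1e-6):
    # v / sqrt(mean(v^2) + eps), matching the docstring's definition
    inv = 1.0 / math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x * inv for x in v]

def poly_norm_row(x, mul, w, b, eps=1e-6):
    # poly = w[0]*rms_norm(x^3) + w[1]*rms_norm(x^2) + w[2]*rms_norm(x) + b
    x2 = [v * v for v in x]
    x3 = [v2 * v for v2, v in zip(x2, x)]
    n3, n2, n1 = rms_norm(x3, eps), rms_norm(x2, eps), rms_norm(x, eps)
    poly = [w[0] * a + w[1] * c + w[2] * d + b for a, c, d in zip(n3, n2, n1)]
    return [p * m for p, m in zip(poly, mul)]
```

For an all-ones row with unit weights and zero bias, each RMS norm is ~1, so each output element is ~3.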
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
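The token-to-expert mapping used by this reference can be illustrated with the standard library: for an inclusive cumulative-sum `offsets` layout, `bisect.bisect_right` gives the same answer as `torch.bucketize(..., right=True)`, including correctly skipping experts that receive zero tokens. (A hedged pure-Python sketch; `expert_for_row` is an illustrative name.)

```python
import bisect

def expert_for_row(row, offsets, expert_offset=0):
    # first expert whose cumulative offset exceeds `row`
    return bisect.bisect_right(offsets, row) + expert_offset

tokens_per_expert = [2, 0, 3]          # expert 1 receives no tokens
offsets = []
running = 0
for n in tokens_per_expert:
    running += n
    offsets.append(running)            # -> [2, 2, 5]

mapping = [expert_for_row(r, offsets) for r in range(offsets[-1])]
# rows 0-1 -> expert 0, rows 2-4 -> expert 2 (the empty expert 1 is skipped)
```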
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Triton kernel implementation
107
+ # ---------------------------------------------------------------------------
108
+ if HAS_TRITON:
109
+ # --- Autotune configurations ---
110
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
111
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
112
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
113
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
114
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
115
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
116
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
117
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
118
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
119
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
120
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
121
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
122
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
123
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
124
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
125
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
126
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
127
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
128
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
129
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
130
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
131
+ ]
132
+
133
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
134
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
135
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
136
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
137
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
138
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
139
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
140
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
141
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
142
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
143
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
144
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
145
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
146
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
147
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
148
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
149
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
150
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
151
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
152
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
153
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
154
+ ]
155
+
156
+ @triton.autotune(
157
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
158
+ key=["D"],
159
+ )
160
+ @triton.jit
161
+ def _grouped_polynorm_fwd_kernel(
162
+ input_ptr,
163
+ mul_ptr,
164
+ weight_ptr,
165
+ bias_ptr,
166
+ offsets_ptr,
167
+ output_ptr,
168
+ N,
169
+ D,
170
+ num_experts,
171
+ eps,
172
+ expert_offset,
173
+ stride_input_row,
174
+ stride_mul_row,
175
+ stride_out_row,
176
+ BLOCK_D: tl.constexpr,
177
+ ):
178
+ """Forward kernel: one program per row."""
179
+ row = tl.program_id(0)
180
+ if row >= N:
181
+ return
182
+
183
+ # Binary search for expert index (12 iterations cover up to 2^12 = 4096 experts)
184
+ lo = 0
185
+ hi = num_experts
186
+ for _ in range(12):
187
+ if lo < hi:
188
+ mid = (lo + hi) // 2
189
+ if tl.load(offsets_ptr + mid) <= row:
190
+ lo = mid + 1
191
+ else:
192
+ hi = mid
193
+ eidx = lo + expert_offset
194
+
195
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
196
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
197
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
198
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
199
+
200
+ input_row_ptr = input_ptr + row * stride_input_row
201
+ mul_row_ptr = mul_ptr + row * stride_mul_row
202
+ out_row_ptr = output_ptr + row * stride_out_row
203
+
204
+ D_float = D.to(tl.float32)
205
+
206
+ # --- Single-tile path ---
207
+ if D <= BLOCK_D:
208
+ d_offs = tl.arange(0, BLOCK_D)
209
+ mask = d_offs < D
210
+
211
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
212
+ other=0.0).to(tl.float32)
213
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
214
+ other=0.0).to(tl.float32)
215
+
216
+ x2 = x * x
217
+ x3 = x2 * x
218
+
219
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
220
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
221
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
222
+
223
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
224
+ w0_inv = w0 * inv_rms_x3
225
+ w1_inv = w1 * inv_rms_x2
226
+ w2_inv = w2 * inv_rms_x
227
+
228
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
229
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
230
+ else:
231
+ # --- Multi-tile: two-pass approach ---
232
+ sum_x2 = tl.zeros((), dtype=tl.float32)
233
+ sum_x4 = tl.zeros((), dtype=tl.float32)
234
+ sum_x6 = tl.zeros((), dtype=tl.float32)
235
+
236
+ for d_start in range(0, D, BLOCK_D):
237
+ d_offs = d_start + tl.arange(0, BLOCK_D)
238
+ mask = d_offs < D
239
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
240
+ other=0.0).to(tl.float32)
241
+ x2 = x * x
242
+ sum_x2 += tl.sum(x2)
243
+ sum_x4 += tl.sum(x2 * x2)
244
+ sum_x6 += tl.sum(x2 * x2 * x2)
245
+
246
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
247
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
248
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
249
+
250
+ # Pre-multiply scalar weight * inv_rms
251
+ w0_inv = w0 * inv_rms_x3
252
+ w1_inv = w1 * inv_rms_x2
253
+ w2_inv = w2 * inv_rms_x
254
+
255
+ for d_start in range(0, D, BLOCK_D):
256
+ d_offs = d_start + tl.arange(0, BLOCK_D)
257
+ mask = d_offs < D
258
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
259
+ other=0.0).to(tl.float32)
260
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
261
+ other=0.0).to(tl.float32)
262
+ x2 = x * x
263
+ x3 = x2 * x
264
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
265
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
266
+
267
+ @triton.autotune(
268
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
269
+ key=["D"],
270
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
271
+ )
272
+ @triton.jit
273
+ def _grouped_polynorm_bwd_kernel(
274
+ grad_out_ptr,
275
+ input_ptr,
276
+ mul_ptr,
277
+ weight_ptr,
278
+ bias_ptr,
279
+ offsets_ptr,
280
+ grad_input_ptr,
281
+ grad_mul_ptr,
282
+ grad_weight_ptr,
283
+ grad_bias_ptr,
284
+ N,
285
+ D,
286
+ num_experts,
287
+ eps,
288
+ expert_offset,
289
+ stride_row,
290
+ BLOCK_D: tl.constexpr,
291
+ ):
292
+ """Backward kernel: one program per row, 2-pass approach.
293
+
294
+ Pass 1: RMS stats + dot products + bias grad
295
+ Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
296
+ """
297
+ row = tl.program_id(0)
298
+ if row >= N:
299
+ return
300
+
301
+ lo = 0
302
+ hi = num_experts
303
+ for _ in range(12):
304
+ if lo < hi:
305
+ mid = (lo + hi) // 2
306
+ if tl.load(offsets_ptr + mid) <= row:
307
+ lo = mid + 1
308
+ else:
309
+ hi = mid
310
+ eidx = lo + expert_offset
311
+
312
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
313
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
314
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
315
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
316
+
317
+ input_row_ptr = input_ptr + row * stride_row
318
+ mul_row_ptr = mul_ptr + row * stride_row
319
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
320
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
321
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
322
+
323
+ D_float = D.to(tl.float32)
324
+
325
+ # --- Single-tile path ---
326
+ if D <= BLOCK_D:
327
+ d_offs = tl.arange(0, BLOCK_D)
328
+ mask = d_offs < D
329
+
330
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
331
+ other=0.0).to(tl.float32)
332
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
333
+ other=0.0).to(tl.float32)
334
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
335
+ other=0.0).to(tl.float32)
336
+
337
+ x2 = x * x
338
+ x3 = x2 * x
339
+
340
+ # Compute RMS stats (x4 inlined to reduce register pressure)
341
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
342
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
343
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
344
+
345
+ w0_inv = w0 * inv_rms_x3
346
+ w1_inv = w1 * inv_rms_x2
347
+ w2_inv = w2 * inv_rms_x
348
+
349
+ dpoly = go * m
350
+
351
+ # Dot products for coefficients and weight grads
352
+ sum_dpoly_x = tl.sum(dpoly * x)
353
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
354
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
355
+ grad_b_acc = tl.sum(dpoly)
356
+
357
+ # Weight grads
358
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
359
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
360
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
361
+
362
+ # Coefficients for grad_input
363
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
364
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
365
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
366
+
367
+ # grad_mul
368
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
369
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
370
+
371
+ # grad_input (in-place accumulation to reduce register pressure)
372
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
373
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
374
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
375
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
376
+
377
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
378
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
379
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
380
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
381
+ else:
382
+ # --- Multi-tile: 2-pass ---
383
+ # Pass 1: RMS stats + dot products + bias grad
384
+ sum_x2 = tl.zeros((), dtype=tl.float32)
385
+ sum_x4 = tl.zeros((), dtype=tl.float32)
386
+ sum_x6 = tl.zeros((), dtype=tl.float32)
387
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
388
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
389
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
390
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
391
+
392
+ for d_start in range(0, D, BLOCK_D):
393
+ d_offs = d_start + tl.arange(0, BLOCK_D)
394
+ mask = d_offs < D
395
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
396
+ other=0.0).to(tl.float32)
397
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
398
+ other=0.0).to(tl.float32)
399
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
400
+ other=0.0).to(tl.float32)
401
+
402
+ x2 = x * x
403
+ x3 = x2 * x
404
+ dpoly = go * m
405
+
406
+ sum_x2 += tl.sum(x2)
407
+ sum_x4 += tl.sum(x2 * x2)
408
+ sum_x6 += tl.sum(x2 * x2 * x2)
409
+ sum_dpoly_x += tl.sum(dpoly * x)
410
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
411
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
412
+ grad_b_acc += tl.sum(dpoly)
413
+
414
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
415
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
416
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
417
+
418
+ w0_inv = w0 * inv_rms_x3
419
+ w1_inv = w1 * inv_rms_x2
420
+ w2_inv = w2 * inv_rms_x
421
+
422
+ # Weight grads
423
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
424
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
425
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
426
+
427
+ # Coefficients for grad_input
428
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
429
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
430
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
431
+
432
+ # Pass 2: grad_input + grad_mul
433
+ for d_start in range(0, D, BLOCK_D):
434
+ d_offs = d_start + tl.arange(0, BLOCK_D)
435
+ mask = d_offs < D
436
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
437
+ other=0.0).to(tl.float32)
438
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
439
+ other=0.0).to(tl.float32)
440
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
441
+ other=0.0).to(tl.float32)
442
+
443
+ x2 = x * x
444
+ x3 = x2 * x
445
+
446
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
447
+ tl.store(grad_mul_row_ptr + d_offs,
448
+ go * (poly + b_val),
449
+ mask=mask)
450
+
451
+ dpoly = go * m
452
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
453
+ g += (2.0 * x * inv_rms_x2 *
454
+ (w1 * dpoly - x2 * coeff_x2))
455
+ g += (3.0 * x2 * inv_rms_x3 *
456
+ (w0 * dpoly - x3 * coeff_x3))
457
+
458
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
459
+
460
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
461
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
462
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
463
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
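The `coeff_*` terms above come from differentiating `g . rms_norm(x)`: with `inv = 1/sqrt(mean(x^2) + eps)`, the gradient is `inv * (g - x * (g . x) * inv^2 / D)`, which is the `inv_rms * (w * dpoly - x * coeff)` pattern in the kernel. A pure-Python finite-difference sanity check of that identity (hypothetical helper names, not part of this module):

```python
import math

def rms_inv(x, eps=1e-6):
    # 1 / sqrt(mean(x^2) + eps), matching the kernel's inv_rms_* scalars
    return 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)

def dotted_loss(x, g, eps=1e-6):
    # L = g . rms_norm(x)
    inv = rms_inv(x, eps)
    return sum(gi * xi * inv for gi, xi in zip(g, x))

def analytic_grad(x, g, eps=1e-6):
    inv = rms_inv(x, eps)
    # coeff mirrors e.g. coeff_x = sum(dpoly * x) * inv_rms^2 / D (weight folded into g)
    coeff = sum(gi * xi for gi, xi in zip(g, x)) * inv * inv / len(x)
    return [inv * (gi - xi * coeff) for gi, xi in zip(g, x)]

x = [0.5, -1.0, 2.0, 0.25]
g = [1.0, 0.5, -0.75, 2.0]
an = analytic_grad(x, g)
h = 1e-6
fd = []
for i in range(len(x)):
    xp, xm = list(x), list(x)
    xp[i] += h
    xm[i] -= h
    fd.append((dotted_loss(xp, g) - dotted_loss(xm, g)) / (2 * h))
```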
464
+
465
+ class _GroupedPolyNormFn(torch.autograd.Function):
466
+
467
+ @staticmethod
468
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
469
+ N, D = input.shape
470
+ input = input.contiguous()
471
+ mul = mul.contiguous()
472
+ output = torch.empty_like(input)
473
+
474
+ num_experts = offsets.shape[0]
475
+ assert num_experts <= 4096, (
476
+ f"Supports at most 4096 experts, got {num_experts}.")
477
+
478
+ _grouped_polynorm_fwd_kernel[(N,)](
479
+ input,
480
+ mul,
481
+ weight,
482
+ bias,
483
+ offsets,
484
+ output,
485
+ N,
486
+ D,
487
+ num_experts,
488
+ eps,
489
+ expert_offset,
490
+ stride_input_row=input.stride(0),
491
+ stride_mul_row=mul.stride(0),
492
+ stride_out_row=output.stride(0),
493
+ )
494
+
495
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
496
+ ctx.eps = eps
497
+ ctx.expert_offset = expert_offset
498
+ return output
499
+
500
+ @staticmethod
501
+ def backward(ctx, grad_output):
502
+ input, mul, weight, bias, offsets = ctx.saved_tensors
503
+ eps = ctx.eps
504
+ expert_offset = ctx.expert_offset
505
+ N, D = input.shape
506
+
507
+ grad_output = grad_output.contiguous()
508
+ grad_input = torch.empty_like(input)
509
+ grad_mul = torch.empty_like(mul)
510
+ grad_weight = torch.zeros(weight.shape[0],
511
+ 3,
512
+ device=weight.device,
513
+ dtype=torch.float32)
514
+ grad_bias = torch.zeros(bias.shape[0],
515
+ device=bias.device,
516
+ dtype=torch.float32)
517
+
518
+ num_experts = offsets.shape[0]
519
+
520
+ _grouped_polynorm_bwd_kernel[(N,)](
521
+ grad_output,
522
+ input,
523
+ mul,
524
+ weight,
525
+ bias,
526
+ offsets,
527
+ grad_input,
528
+ grad_mul,
529
+ grad_weight,
530
+ grad_bias,
531
+ N,
532
+ D,
533
+ num_experts,
534
+ eps,
535
+ expert_offset,
536
+ stride_row=input.stride(0),
537
+ )
538
+
539
+ grad_weight = grad_weight.to(weight.dtype)
540
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
541
+
542
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
543
+
544
+ def grouped_fused_mul_poly_norm(
545
+ input: Tensor,
546
+ mul: Tensor,
547
+ weight: Tensor,
548
+ bias: Tensor,
549
+ offsets: Tensor,
550
+ eps: float = 1e-6,
551
+ expert_offset: int = 0,
552
+ ) -> Tensor:
553
+ """Triton-accelerated Grouped FusedMulPolyNorm.
554
+
555
+ Args:
556
+ input: (total_tokens, D) - concatenated tokens for all experts
557
+ mul: (total_tokens, D) - gate values to multiply with
558
+ weight: (num_experts, 3) - per-expert polynomial weights
559
+ bias: (num_experts, 1) - per-expert polynomial bias
560
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
561
+ eps: numerical stability epsilon
562
+ expert_offset: offset to add to expert index
563
+
564
+ Returns:
565
+ (total_tokens, D) - output tensor
566
+ """
567
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
568
+ expert_offset)
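A hedged sketch of how `offsets` and `expert_offset` fit together under expert parallelism (all names and counts below are illustrative): `offsets` describes only the local experts' token layout, while `expert_offset` shifts local indices into the global `weight`/`bias` tables.

```python
num_global_experts = 8
local_experts = [4, 5, 6, 7]             # this rank owns the second half
tokens_per_local_expert = [3, 1, 0, 2]   # an empty local expert is fine

offsets = []
running = 0
for n in tokens_per_local_expert:
    running += n
    offsets.append(running)              # -> [3, 4, 4, 6]; int32 on device in practice

expert_offset = local_experts[0]         # 4: local expert 0 reads global row 4
total_tokens = offsets[-1]               # row count of `input` / `mul` on this rank
assert expert_offset + len(tokens_per_local_expert) <= num_global_experts
```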
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
build/torch28-cxx11-rocm63-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
 
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
7
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
45
  __all__ = [
46
  "poly_norm",
47
  "fused_mul_poly_norm",
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
5
+ from .grouped_poly_norm import grouped_fused_mul_poly_norm
6
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
7
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
8
 
 
46
  __all__ = [
47
  "poly_norm",
48
  "fused_mul_poly_norm",
49
+ "grouped_fused_mul_poly_norm",
50
  "rms_norm",
51
  "fused_add_rms_norm",
52
  "layers",
build/torch28-cxx11-rocm63-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e753cc4a2aaa76aea3b821b245d4da638d7c94059a660f8233e17a54df379813
3
  size 2788456
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:448ab4d8a859725a3200d95d2164a3fe261f67b20e834f4f7062485cf729cf88
3
  size 2788456
build/torch28-cxx11-rocm63-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_18b7543_dirty
3
- ops = torch.ops._activation_18b7543_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_18b7543_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_0e6f27f_dirty
3
+ ops = torch.ops._activation_0e6f27f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_0e6f27f_dirty::{op_name}"
build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
1
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
+
3
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
+ eliminating multiple intermediate tensors and kernel launches.
5
+
6
+ PolyNorm formula (per row):
7
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul
9
+
10
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
+
12
+ Performance optimizations:
13
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
+ hidden dimension.
15
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
+ across the reduction and output phases, eliminating redundant global reads.
17
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
18
+ enables overlapping memory loads with computation across loop iterations.
19
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
+ launches (torch.arange + torch.bucketize) per forward/backward call.
21
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
+ with dot product accumulation, pass 2 computes gradients. This reduces
23
+ memory traffic compared to a naive 3-pass approach.
24
+
25
+ Forward kernel: one program per row, tiles over D dimension.
26
+ - Computes x, x^2, x^3 in registers
27
+ - Computes three RMS norms in a single pass (shared variance reduction)
28
+ - Applies polynomial weights + bias + mul in-place
29
+
30
+ Backward kernel: one program per row, tiles over D dimension.
31
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
32
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
+ """
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Triton kernel implementation
107
+ # ---------------------------------------------------------------------------
108
+ if HAS_TRITON:
109
+ # --- Autotune configurations ---
110
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
111
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
112
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
113
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
114
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
115
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
116
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
117
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
118
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
119
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
120
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
121
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
122
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
123
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
124
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
125
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
126
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
127
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
128
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
129
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
130
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
131
+ ]
132
+
133
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
134
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
135
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
136
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
137
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
138
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
139
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
140
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
141
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
142
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
143
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
144
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
145
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
146
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
147
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
148
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
149
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
150
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
151
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
152
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
153
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
154
+ ]
155
+
156
+ @triton.autotune(
157
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
158
+ key=["D"],
159
+ )
160
+ @triton.jit
161
+ def _grouped_polynorm_fwd_kernel(
162
+ input_ptr,
163
+ mul_ptr,
164
+ weight_ptr,
165
+ bias_ptr,
166
+ offsets_ptr,
167
+ output_ptr,
168
+ N,
169
+ D,
170
+ num_experts,
171
+ eps,
172
+ expert_offset,
173
+ stride_input_row,
174
+ stride_mul_row,
175
+ stride_out_row,
176
+ BLOCK_D: tl.constexpr,
177
+ ):
178
+ """Forward kernel: one program per row."""
179
+ row = tl.program_id(0)
180
+ if row >= N:
181
+ return
182
+
183
+ # Binary search for expert index (12 iters covers up to 4096 experts)
184
+ lo = 0
185
+ hi = num_experts
186
+ for _ in range(12):
187
+ if lo < hi:
188
+ mid = (lo + hi) // 2
189
+ if tl.load(offsets_ptr + mid) <= row:
190
+ lo = mid + 1
191
+ else:
192
+ hi = mid
193
+ eidx = lo + expert_offset
194
+
195
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
196
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
197
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
198
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
199
+
200
+ input_row_ptr = input_ptr + row * stride_input_row
201
+ mul_row_ptr = mul_ptr + row * stride_mul_row
202
+ out_row_ptr = output_ptr + row * stride_out_row
203
+
204
+ D_float = D.to(tl.float32)
205
+
206
+ # --- Single-tile path ---
207
+ if D <= BLOCK_D:
208
+ d_offs = tl.arange(0, BLOCK_D)
209
+ mask = d_offs < D
210
+
211
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
212
+ other=0.0).to(tl.float32)
213
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
214
+ other=0.0).to(tl.float32)
215
+
216
+ x2 = x * x
217
+ x3 = x2 * x
218
+
219
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
220
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
221
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
222
+
223
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
224
+ w0_inv = w0 * inv_rms_x3
225
+ w1_inv = w1 * inv_rms_x2
226
+ w2_inv = w2 * inv_rms_x
227
+
228
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
229
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
230
+ else:
231
+ # --- Multi-tile: two-pass approach ---
232
+ sum_x2 = tl.zeros((), dtype=tl.float32)
233
+ sum_x4 = tl.zeros((), dtype=tl.float32)
234
+ sum_x6 = tl.zeros((), dtype=tl.float32)
235
+
236
+ for d_start in range(0, D, BLOCK_D):
237
+ d_offs = d_start + tl.arange(0, BLOCK_D)
238
+ mask = d_offs < D
239
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
240
+ other=0.0).to(tl.float32)
241
+ x2 = x * x
242
+ sum_x2 += tl.sum(x2)
243
+ sum_x4 += tl.sum(x2 * x2)
244
+ sum_x6 += tl.sum(x2 * x2 * x2)
245
+
246
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
247
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
248
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
249
+
250
+ # Pre-multiply scalar weight * inv_rms
251
+ w0_inv = w0 * inv_rms_x3
252
+ w1_inv = w1 * inv_rms_x2
253
+ w2_inv = w2 * inv_rms_x
254
+
255
+ for d_start in range(0, D, BLOCK_D):
256
+ d_offs = d_start + tl.arange(0, BLOCK_D)
257
+ mask = d_offs < D
258
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
259
+ other=0.0).to(tl.float32)
260
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
261
+ other=0.0).to(tl.float32)
262
+ x2 = x * x
263
+ x3 = x2 * x
264
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
265
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
266
+
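The forward kernel's math can be restated for one row in plain Python as a readable cross-check — a hedged sketch of the same formula (three RMS norms over x, x², x³, then `poly * mul`), with illustrative names and no Triton dependency:

```python
import math

def polynorm_row(x, mul, w, b, eps=1e-6):
    """One row of the forward computation:
    poly = w[0]*rms(x^3) + w[1]*rms(x^2) + w[2]*rms(x) + b; output = poly * mul."""
    d = len(x)

    def inv_rms(v):
        # 1 / sqrt(mean(v^2) + eps), matching the kernel's inv_rms_* scalars
        return 1.0 / math.sqrt(sum(t * t for t in v) / d + eps)

    x2 = [t * t for t in x]
    x3 = [t2 * t for t2, t in zip(x2, x)]
    # pre-multiplied weight * inv_rms, as in the kernel
    w0i, w1i, w2i = w[0] * inv_rms(x3), w[1] * inv_rms(x2), w[2] * inv_rms(x)
    return [(t3 * w0i + t2 * w1i + t * w2i + b) * m
            for t3, t2, t, m in zip(x3, x2, x, mul)]
```

With `w = (0, 0, 1)` and `b = 0` this reduces to `rms_norm(x) * mul`, which is a convenient sanity case when comparing against the Triton output.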
267
+ @triton.autotune(
268
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
269
+ key=["D"],
270
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
271
+ )
272
+ @triton.jit
273
+ def _grouped_polynorm_bwd_kernel(
274
+ grad_out_ptr,
275
+ input_ptr,
276
+ mul_ptr,
277
+ weight_ptr,
278
+ bias_ptr,
279
+ offsets_ptr,
280
+ grad_input_ptr,
281
+ grad_mul_ptr,
282
+ grad_weight_ptr,
283
+ grad_bias_ptr,
284
+ N,
285
+ D,
286
+ num_experts,
287
+ eps,
288
+ expert_offset,
289
+ stride_row,
290
+ BLOCK_D: tl.constexpr,
291
+ ):
292
+ """Backward kernel: one program per row, 2-pass approach.
293
+
294
+ Pass 1: RMS stats + dot products + bias grad
295
+ Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
296
+ """
297
+ row = tl.program_id(0)
298
+ if row >= N:
299
+ return
300
+
301
+ lo = 0
302
+ hi = num_experts
303
+ for _ in range(12):
304
+ if lo < hi:
305
+ mid = (lo + hi) // 2
306
+ if tl.load(offsets_ptr + mid) <= row:
307
+ lo = mid + 1
308
+ else:
309
+ hi = mid
310
+ eidx = lo + expert_offset
311
+
312
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
313
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
314
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
315
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
316
+
317
+ input_row_ptr = input_ptr + row * stride_row
318
+ mul_row_ptr = mul_ptr + row * stride_row
319
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
320
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
321
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
322
+
323
+ D_float = D.to(tl.float32)
324
+
325
+ # --- Single-tile path ---
326
+ if D <= BLOCK_D:
327
+ d_offs = tl.arange(0, BLOCK_D)
328
+ mask = d_offs < D
329
+
330
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
331
+ other=0.0).to(tl.float32)
332
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
333
+ other=0.0).to(tl.float32)
334
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
335
+ other=0.0).to(tl.float32)
336
+
337
+ x2 = x * x
338
+ x3 = x2 * x
339
+
340
+ # Compute RMS stats (x4 inlined to reduce register pressure)
341
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
342
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
343
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
344
+
345
+ w0_inv = w0 * inv_rms_x3
346
+ w1_inv = w1 * inv_rms_x2
347
+ w2_inv = w2 * inv_rms_x
348
+
349
+ dpoly = go * m
350
+
351
+ # Dot products for coefficients and weight grads
352
+ sum_dpoly_x = tl.sum(dpoly * x)
353
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
354
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
355
+ grad_b_acc = tl.sum(dpoly)
356
+
357
+ # Weight grads
358
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
359
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
360
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
361
+
362
+ # Coefficients for grad_input
363
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
364
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
365
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
366
+
367
+ # grad_mul
368
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
369
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
370
+
371
+ # grad_input (in-place accumulation to reduce register pressure)
372
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
373
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
374
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
375
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
376
+
377
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
378
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
379
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
380
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
381
+ else:
382
+ # --- Multi-tile: 2-pass ---
383
+ # Pass 1: RMS stats + dot products + bias grad
384
+ sum_x2 = tl.zeros((), dtype=tl.float32)
385
+ sum_x4 = tl.zeros((), dtype=tl.float32)
386
+ sum_x6 = tl.zeros((), dtype=tl.float32)
387
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
388
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
389
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
390
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
391
+
392
+ for d_start in range(0, D, BLOCK_D):
393
+ d_offs = d_start + tl.arange(0, BLOCK_D)
394
+ mask = d_offs < D
395
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
396
+ other=0.0).to(tl.float32)
397
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
398
+ other=0.0).to(tl.float32)
399
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
400
+ other=0.0).to(tl.float32)
401
+
402
+ x2 = x * x
403
+ x3 = x2 * x
404
+ dpoly = go * m
405
+
406
+ sum_x2 += tl.sum(x2)
407
+ sum_x4 += tl.sum(x2 * x2)
408
+ sum_x6 += tl.sum(x2 * x2 * x2)
409
+ sum_dpoly_x += tl.sum(dpoly * x)
410
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
411
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
412
+ grad_b_acc += tl.sum(dpoly)
413
+
414
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
415
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
416
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
417
+
418
+ w0_inv = w0 * inv_rms_x3
419
+ w1_inv = w1 * inv_rms_x2
420
+ w2_inv = w2 * inv_rms_x
421
+
422
+ # Weight grads
423
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
424
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
425
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
426
+
427
+ # Coefficients for grad_input
428
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
429
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
430
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
431
+
432
+ # Pass 2: grad_input + grad_mul
433
+ for d_start in range(0, D, BLOCK_D):
434
+ d_offs = d_start + tl.arange(0, BLOCK_D)
435
+ mask = d_offs < D
436
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
437
+ other=0.0).to(tl.float32)
438
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
439
+ other=0.0).to(tl.float32)
440
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
441
+ other=0.0).to(tl.float32)
442
+
443
+ x2 = x * x
444
+ x3 = x2 * x
445
+
446
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
447
+ tl.store(grad_mul_row_ptr + d_offs,
448
+ go * (poly + b_val),
449
+ mask=mask)
450
+
451
+ dpoly = go * m
452
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
453
+ g += (2.0 * x * inv_rms_x2 *
454
+ (w1 * dpoly - x2 * coeff_x2))
455
+ g += (3.0 * x2 * inv_rms_x3 *
456
+ (w0 * dpoly - x3 * coeff_x3))
457
+
458
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
459
+
460
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
461
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
462
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
463
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
464
+
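The `coeff_x` / `inv_rms_x` expressions above come from differentiating `w * x / rms(x)` with respect to `x`. A hedged plain-Python finite-difference check of that closed form for the linear term (the x² and x³ terms follow the same pattern with their chain-rule factors; names are illustrative):

```python
import math

def rms_term(x, w, eps):
    # forward: y_i = w * x_i / sqrt(mean(x^2) + eps)
    r = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / r for v in x]

def rms_term_grad(x, dpoly, w, eps):
    # closed form used by the kernel: g_j = inv_r * (w * dpoly_j - x_j * coeff)
    # with coeff = w * sum(dpoly * x) * inv_r^2 / D
    d = len(x)
    inv_r = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    coeff = w * sum(dp * v for dp, v in zip(dpoly, x)) * inv_r * inv_r / d
    return [inv_r * (w * dp - v * coeff) for dp, v in zip(dpoly, x)]

def numeric_grad(x, dpoly, w, eps, h=1e-6):
    # central differences on loss = sum(dpoly_i * rms_term(x)_i)
    def loss(xs):
        return sum(dp * y for dp, y in zip(dpoly, rms_term(xs, w, eps)))
    g = []
    for j in range(len(x)):
        xp = list(x); xp[j] += h
        xm = list(x); xm[j] -= h
        g.append((loss(xp) - loss(xm)) / (2 * h))
    return g
```

Agreement between `rms_term_grad` and `numeric_grad` confirms the coefficient structure the kernel uses: the first term is the direct gradient through the numerator, the subtracted `x_j * coeff` term is the gradient through the shared RMS denominator.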
465
+ class _GroupedPolyNormFn(torch.autograd.Function):
466
+
467
+ @staticmethod
468
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
469
+ N, D = input.shape
470
+ input = input.contiguous()
471
+ mul = mul.contiguous()
472
+ output = torch.empty_like(input)
473
+
474
+ num_experts = offsets.shape[0]
475
+ assert num_experts <= 4096, (
476
+ f"Supports at most 4096 experts, got {num_experts}.")
477
+
478
+ _grouped_polynorm_fwd_kernel[(N,)](
479
+ input,
480
+ mul,
481
+ weight,
482
+ bias,
483
+ offsets,
484
+ output,
485
+ N,
486
+ D,
487
+ num_experts,
488
+ eps,
489
+ expert_offset,
490
+ stride_input_row=input.stride(0),
491
+ stride_mul_row=mul.stride(0),
492
+ stride_out_row=output.stride(0),
493
+ )
494
+
495
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
496
+ ctx.eps = eps
497
+ ctx.expert_offset = expert_offset
498
+ return output
499
+
500
+ @staticmethod
501
+ def backward(ctx, grad_output):
502
+ input, mul, weight, bias, offsets = ctx.saved_tensors
503
+ eps = ctx.eps
504
+ expert_offset = ctx.expert_offset
505
+ N, D = input.shape
506
+
507
+ grad_output = grad_output.contiguous()
508
+ grad_input = torch.empty_like(input)
509
+ grad_mul = torch.empty_like(mul)
510
+ grad_weight = torch.zeros(weight.shape[0],
511
+ 3,
512
+ device=weight.device,
513
+ dtype=torch.float32)
514
+ grad_bias = torch.zeros(bias.shape[0],
515
+ device=bias.device,
516
+ dtype=torch.float32)
517
+
518
+ num_experts = offsets.shape[0]
519
+
520
+ _grouped_polynorm_bwd_kernel[(N,)](
521
+ grad_output,
522
+ input,
523
+ mul,
524
+ weight,
525
+ bias,
526
+ offsets,
527
+ grad_input,
528
+ grad_mul,
529
+ grad_weight,
530
+ grad_bias,
531
+ N,
532
+ D,
533
+ num_experts,
534
+ eps,
535
+ expert_offset,
536
+ stride_row=input.stride(0),
537
+ )
538
+
539
+ grad_weight = grad_weight.to(weight.dtype)
540
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
541
+
542
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
543
+
544
+ def grouped_fused_mul_poly_norm(
545
+ input: Tensor,
546
+ mul: Tensor,
547
+ weight: Tensor,
548
+ bias: Tensor,
549
+ offsets: Tensor,
550
+ eps: float = 1e-6,
551
+ expert_offset: int = 0,
552
+ ) -> Tensor:
553
+ """Triton-accelerated Grouped FusedMulPolyNorm.
554
+
555
+ Args:
556
+ input: (total_tokens, D) - concatenated tokens for all experts
557
+ mul: (total_tokens, D) - gate values to multiply with
558
+ weight: (num_experts, 3) - per-expert polynomial weights
559
+ bias: (num_experts, 1) - per-expert polynomial bias
560
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
561
+ eps: numerical stability epsilon
562
+ expert_offset: offset to add to expert index
563
+
564
+ Returns:
565
+ (total_tokens, D) - output tensor
566
+ """
567
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
568
+ expert_offset)
569
+
570
+ else:
571
+
572
+ def grouped_fused_mul_poly_norm(
573
+ input: Tensor,
574
+ mul: Tensor,
575
+ weight: Tensor,
576
+ bias: Tensor,
577
+ offsets: Tensor,
578
+ eps: float = 1e-6,
579
+ expert_offset: int = 0,
580
+ ) -> Tensor:
581
+ raise RuntimeError(
582
+ "Triton is not available. Install triton to use "
583
+ "grouped_fused_mul_poly_norm.")
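The `offsets` argument is documented as the cumsum of `num_tokens_per_expert`. A hedged stdlib sketch of how a caller might construct it (in a real MoE path this would be a `torch.cumsum` on device; the function name is illustrative):

```python
from itertools import accumulate

def make_offsets(num_tokens_per_expert):
    """Inclusive cumulative sums: the kernels' binary search treats
    offsets[e] as the exclusive end (one past the last row) of expert e."""
    return list(accumulate(num_tokens_per_expert))
```

For counts `[3, 0, 5]` this gives `[3, 3, 8]`: an expert with zero tokens contributes a repeated boundary, and the `offsets[mid] <= row` comparison in the binary search steps past it, so empty experts are handled without special-casing.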
build/torch28-cxx11-rocm64-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
 
5
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
6
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
7
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
45
  __all__ = [
46
  "poly_norm",
47
  "fused_mul_poly_norm",
 
48
  "rms_norm",
49
  "fused_add_rms_norm",
50
  "layers",
 
2
 
3
  from . import layers, parallel_style
4
  from ._ops import ops
5
+ from .grouped_poly_norm import grouped_fused_mul_poly_norm
6
  from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
7
  from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
8
 
 
46
  __all__ = [
47
  "poly_norm",
48
  "fused_mul_poly_norm",
49
+ "grouped_fused_mul_poly_norm",
50
  "rms_norm",
51
  "fused_add_rms_norm",
52
  "layers",
build/torch28-cxx11-rocm64-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5213ce21365e97016b4a04bfa304e03c9fc1dc6780f62b3312fb65f96e8f6381
3
  size 2794152
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:388f461d91124b99544cbdbd4dc4d98f24c010d7a0dc2e9389648860a809a51a
3
  size 2794152
build/torch28-cxx11-rocm64-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_18b7543_dirty
3
- ops = torch.ops._activation_18b7543_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_18b7543_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_0e6f27f_dirty
3
+ ops = torch.ops._activation_0e6f27f_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_activation_0e6f27f_dirty::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py ADDED
1
+ """Triton-accelerated Grouped FusedMulPolyNorm for MoE.
2
+
3
+ Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
4
+ eliminating multiple intermediate tensors and kernel launches.
5
+
6
+ PolyNorm formula (per row):
7
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
8
+ output = poly * mul
9
+
10
+ where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
11
+
12
+ Performance optimizations:
13
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
14
+ hidden dimension.
15
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
16
+ across the reduction and output phases, eliminating redundant global reads.
17
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
18
+ enables overlapping memory loads with computation across loop iterations.
19
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
20
+ launches (torch.arange + torch.bucketize) per forward/backward call.
21
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
22
+ with dot product accumulation, pass 2 computes gradients. This reduces
23
+ memory traffic compared to a naive 3-pass approach.
24
+
25
+ Forward kernel: one program per row, tiles over D dimension.
26
+ - Computes x, x^2, x^3 in registers
27
+ - Computes three RMS norms in a single pass (shared variance reduction)
28
+ - Applies polynomial weights + bias + mul in-place
29
+
30
+ Backward kernel: one program per row, tiles over D dimension.
31
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
32
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
33
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
34
+ """
35
+
36
+ import torch
37
+ from torch import Tensor
38
+
39
+ try:
40
+ import triton
41
+ import triton.language as tl
42
+
43
+ HAS_TRITON = True
44
+ except ImportError:
45
+ HAS_TRITON = False
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # PyTorch reference implementation (for testing and benchmarking)
50
+ # ---------------------------------------------------------------------------
51
+ def _rms_norm(x: Tensor, eps: float) -> Tensor:
52
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
53
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
54
+
55
+
56
+ def grouped_fused_mul_poly_norm_ref(
57
+ input: Tensor,
58
+ mul: Tensor,
59
+ weight: Tensor,
60
+ bias: Tensor,
61
+ offsets: Tensor,
62
+ eps: float = 1e-6,
63
+ expert_offset: int = 0,
64
+ ) -> Tensor:
65
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
66
+
67
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
68
+ for all tokens at once. torch.compile friendly.
69
+
70
+ Args:
71
+ input: (total_tokens, D) - concatenated tokens for all experts
72
+ mul: (total_tokens, D) - gate values to multiply with
73
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
74
+ bias: (num_experts, 1) - per-expert polynomial bias
75
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
76
+ eps: numerical stability epsilon
77
+
78
+ Returns:
79
+ (total_tokens, D) - output tensor
80
+ """
81
+ orig_dtype = input.dtype
82
+
83
+ token_positions = torch.arange(input.shape[0], device=input.device)
84
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
85
+
86
+ weight_fp32 = weight.float()
87
+ bias_fp32 = bias.float()
88
+
89
+ per_token_w = weight_fp32[expert_idx]
90
+ per_token_b = bias_fp32[expert_idx]
91
+
92
+ x = input.float()
93
+ m = mul.float()
94
+
95
+ x2 = x * x
96
+ x3 = x2 * x
97
+
98
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
99
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
100
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
101
+
102
+ return (poly * m).to(orig_dtype)
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Triton kernel implementation
107
+ # ---------------------------------------------------------------------------
108
+ if HAS_TRITON:
109
+ # --- Autotune configurations ---
110
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
111
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
112
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
113
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
114
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
115
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
116
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
117
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
118
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
119
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
120
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
121
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
122
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
123
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
124
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
125
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
126
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
127
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
128
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
129
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
130
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
131
+ ]
132
+
133
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
134
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
135
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
136
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
137
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
138
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
139
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
140
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
141
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
142
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
143
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
144
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
145
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
146
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
147
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
148
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
149
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
150
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
151
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
152
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
153
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
154
+ ]
155
+
156
+ @triton.autotune(
157
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
158
+ key=["D"],
159
+ )
160
+ @triton.jit
161
+ def _grouped_polynorm_fwd_kernel(
162
+ input_ptr,
163
+ mul_ptr,
164
+ weight_ptr,
165
+ bias_ptr,
166
+ offsets_ptr,
167
+ output_ptr,
168
+ N,
169
+ D,
170
+ num_experts,
171
+ eps,
172
+ expert_offset,
173
+ stride_input_row,
174
+ stride_mul_row,
175
+ stride_out_row,
176
+ BLOCK_D: tl.constexpr,
177
+ ):
178
+ """Forward kernel: one program per row."""
179
+ row = tl.program_id(0)
180
+ if row >= N:
181
+ return
182
+
183
+ # Binary search for expert index (12 iters covers up to 4096 experts)
184
+ lo = 0
185
+ hi = num_experts
186
+ for _ in range(12):
187
+ if lo < hi:
188
+ mid = (lo + hi) // 2
189
+ if tl.load(offsets_ptr + mid) <= row:
190
+ lo = mid + 1
191
+ else:
192
+ hi = mid
193
+ eidx = lo + expert_offset
194
+
195
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
196
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
197
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
198
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
199
+
200
+ input_row_ptr = input_ptr + row * stride_input_row
201
+ mul_row_ptr = mul_ptr + row * stride_mul_row
202
+ out_row_ptr = output_ptr + row * stride_out_row
203
+
204
+ D_float = D.to(tl.float32)
205
+
206
+ # --- Single-tile path ---
207
+ if D <= BLOCK_D:
208
+ d_offs = tl.arange(0, BLOCK_D)
209
+ mask = d_offs < D
210
+
211
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
212
+ other=0.0).to(tl.float32)
213
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
214
+ other=0.0).to(tl.float32)
215
+
216
+ x2 = x * x
217
+ x3 = x2 * x
218
+
219
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
220
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
221
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
222
+
223
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
224
+ w0_inv = w0 * inv_rms_x3
225
+ w1_inv = w1 * inv_rms_x2
226
+ w2_inv = w2 * inv_rms_x
227
+
228
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
229
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
230
+ else:
231
+ # --- Multi-tile: two-pass approach ---
232
+ sum_x2 = tl.zeros((), dtype=tl.float32)
233
+ sum_x4 = tl.zeros((), dtype=tl.float32)
234
+ sum_x6 = tl.zeros((), dtype=tl.float32)
235
+
236
+ for d_start in range(0, D, BLOCK_D):
237
+ d_offs = d_start + tl.arange(0, BLOCK_D)
238
+ mask = d_offs < D
239
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
240
+ other=0.0).to(tl.float32)
241
+ x2 = x * x
242
+ sum_x2 += tl.sum(x2)
243
+ sum_x4 += tl.sum(x2 * x2)
244
+ sum_x6 += tl.sum(x2 * x2 * x2)
245
+
246
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
247
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
248
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
249
+
250
+ # Pre-multiply scalar weight * inv_rms
251
+ w0_inv = w0 * inv_rms_x3
252
+ w1_inv = w1 * inv_rms_x2
253
+ w2_inv = w2 * inv_rms_x
254
+
255
+ for d_start in range(0, D, BLOCK_D):
256
+ d_offs = d_start + tl.arange(0, BLOCK_D)
257
+ mask = d_offs < D
258
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
259
+ other=0.0).to(tl.float32)
260
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
261
+ other=0.0).to(tl.float32)
262
+ x2 = x * x
263
+ x3 = x2 * x
264
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
265
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
266
+
267
+ @triton.autotune(
268
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
269
+ key=["D"],
270
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
271
+ )
272
+ @triton.jit
273
+ def _grouped_polynorm_bwd_kernel(
274
+ grad_out_ptr,
275
+ input_ptr,
276
+ mul_ptr,
277
+ weight_ptr,
278
+ bias_ptr,
279
+ offsets_ptr,
280
+ grad_input_ptr,
281
+ grad_mul_ptr,
282
+ grad_weight_ptr,
283
+ grad_bias_ptr,
284
+ N,
285
+ D,
286
+ num_experts,
287
+ eps,
288
+ expert_offset,
289
+ stride_row,
290
+ BLOCK_D: tl.constexpr,
291
+ ):
292
+ """Backward kernel: one program per row, 2-pass approach.
293
+
294
+        Pass 1: RMS stats + dot products + bias grad
+        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+        """
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        lo = 0
+        hi = num_experts
+        for _ in range(12):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_row
+        mul_row_ptr = mul_ptr + row * stride_row
+        grad_out_row_ptr = grad_out_ptr + row * stride_row
+        grad_input_row_ptr = grad_input_ptr + row * stride_row
+        grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            # Compute RMS stats (x4 inlined to reduce register pressure)
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            dpoly = go * m
+
+            # Dot products for coefficients and weight grads
+            sum_dpoly_x = tl.sum(dpoly * x)
+            sum_dpoly_x2 = tl.sum(dpoly * x2)
+            sum_dpoly_x3 = tl.sum(dpoly * x3)
+            grad_b_acc = tl.sum(dpoly)
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # grad_mul
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+            # grad_input (in-place accumulation to reduce register pressure)
+            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+        else:
+            # --- Multi-tile: 2-pass ---
+            # Pass 1: RMS stats + dot products + bias grad
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+            grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+                dpoly = go * m
+
+                sum_x2 += tl.sum(x2)
+                sum_x4 += tl.sum(x2 * x2)
+                sum_x6 += tl.sum(x2 * x2 * x2)
+                sum_dpoly_x += tl.sum(dpoly * x)
+                sum_dpoly_x2 += tl.sum(dpoly * x2)
+                sum_dpoly_x3 += tl.sum(dpoly * x3)
+                grad_b_acc += tl.sum(dpoly)
+
+            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+            # Pass 2: grad_input + grad_mul
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                            other=0.0).to(tl.float32)
+                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                             other=0.0).to(tl.float32)
+
+                x2 = x * x
+                x3 = x2 * x
+
+                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+                tl.store(grad_mul_row_ptr + d_offs,
+                         go * (poly + b_val),
+                         mask=mask)
+
+                dpoly = go * m
+                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+                g += (2.0 * x * inv_rms_x2 *
+                      (w1 * dpoly - x2 * coeff_x2))
+                g += (3.0 * x2 * inv_rms_x3 *
+                      (w0 * dpoly - x3 * coeff_x3))
+
+                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+    class _GroupedPolyNormFn(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+            N, D = input.shape
+            input = input.contiguous()
+            mul = mul.contiguous()
+            output = torch.empty_like(input)
+
+            num_experts = offsets.shape[0]
+            assert num_experts <= 4096, (
+                f"Supports at most 4096 experts, got {num_experts}.")
+
+            _grouped_polynorm_fwd_kernel[(N,)](
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                output,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_input_row=input.stride(0),
+                stride_mul_row=mul.stride(0),
+                stride_out_row=output.stride(0),
+            )
+
+            ctx.save_for_backward(input, mul, weight, bias, offsets)
+            ctx.eps = eps
+            ctx.expert_offset = expert_offset
+            return output
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            input, mul, weight, bias, offsets = ctx.saved_tensors
+            eps = ctx.eps
+            expert_offset = ctx.expert_offset
+            N, D = input.shape
+
+            grad_output = grad_output.contiguous()
+            grad_input = torch.empty_like(input)
+            grad_mul = torch.empty_like(mul)
+            grad_weight = torch.zeros(weight.shape[0],
+                                      3,
+                                      device=weight.device,
+                                      dtype=torch.float32)
+            grad_bias = torch.zeros(bias.shape[0],
+                                    device=bias.device,
+                                    dtype=torch.float32)
+
+            num_experts = offsets.shape[0]
+
+            _grouped_polynorm_bwd_kernel[(N,)](
+                grad_output,
+                input,
+                mul,
+                weight,
+                bias,
+                offsets,
+                grad_input,
+                grad_mul,
+                grad_weight,
+                grad_bias,
+                N,
+                D,
+                num_experts,
+                eps,
+                expert_offset,
+                stride_row=input.stride(0),
+            )
+
+            grad_weight = grad_weight.to(weight.dtype)
+            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+    def grouped_fused_mul_poly_norm(
+        input: Tensor,
+        mul: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        offsets: Tensor,
+        eps: float = 1e-6,
+        expert_offset: int = 0,
+    ) -> Tensor:
+        """Triton-accelerated Grouped FusedMulPolyNorm.
+
+        Args:
+            input: (total_tokens, D) - concatenated tokens for all experts
+            mul: (total_tokens, D) - gate values to multiply with
+            weight: (num_experts, 3) - per-expert polynomial weights
+            bias: (num_experts, 1) - per-expert polynomial bias
+            offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+            eps: numerical stability epsilon
+            expert_offset: offset to add to expert index
+
+        Returns:
+            (total_tokens, D) - output tensor
+        """
+        return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+                                        expert_offset)
+
+else:
+
+    def grouped_fused_mul_poly_norm(
+        input: Tensor,
+        mul: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        offsets: Tensor,
+        eps: float = 1e-6,
+        expert_offset: int = 0,
+    ) -> Tensor:
+        raise RuntimeError(
+            "Triton is not available. Install triton to use "
+            "grouped_fused_mul_poly_norm.")
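Both kernels above resolve each row's expert with the same 12-iteration binary search over the cumulative `offsets` array (2^12 = 4096, which is why the wrapper asserts `num_experts <= 4096`). The same mapping can be sketched in plain Python; `expert_for_row` is a hypothetical helper name for illustration, not part of the package:

```python
import bisect

def expert_for_row(row: int, offsets: list[int], expert_offset: int = 0) -> int:
    """Mirror of the in-kernel search: find the first expert whose
    cumulative token count exceeds `row` (offsets[i] = tokens in experts 0..i)."""
    lo, hi = 0, len(offsets)
    for _ in range(12):  # 2**12 = 4096 >= num_experts, so the search converges
        if lo < hi:
            mid = (lo + hi) // 2
            if offsets[mid] <= row:
                lo = mid + 1
            else:
                hi = mid
    return lo + expert_offset

# offsets = cumsum of tokens per expert: expert 0 owns rows [0, 3),
# expert 1 owns [3, 8), expert 2 owns [8, 9).
offsets = [3, 8, 9]
assert [expert_for_row(r, offsets) for r in range(9)] == [0, 0, 0, 1, 1, 1, 1, 1, 2]
# Identical to the standard-library binary search the reference uses via bucketize:
assert all(expert_for_row(r, offsets) == bisect.bisect_right(offsets, r)
           for r in range(9))
```

This is the same `right=True` bucketing that `torch.bucketize` performs in the PyTorch reference, just executed per program inside the kernel so no index tensor has to be materialized.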
build/torch29-cxx11-cu126-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch29-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0a64085eb92e8f49acb50abc5fc2c50e0dc3e3fd84ae29d7ea8bf27518c34af3
+oid sha256:0c5304ceac9171f76c03792dfb9b7e8299ba2f2885983c39031546fde8f61f8b
 size 10756320
build/torch29-cxx11-cu126-x86_64-linux/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
 
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_18b7543_dirty::{op_name}"
+    return f"_activation_0e6f27f_dirty::{op_name}"
build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py ADDED
@@ -0,0 +1,583 @@
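For a single row, the PolyNorm formula this kernel implements (`poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias`, `output = poly * mul`) can be checked in plain Python with no torch or triton; `rms_norm` and `poly_norm_row` here are illustrative helpers, not package APIs. The second form below mirrors the kernel's factored computation, which pre-multiplies each weight by the corresponding inverse RMS instead of materializing the normalized vectors:

```python
import math

def rms_norm(xs, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), over one row represented as a plain list
    inv = 1.0 / math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v * inv for v in xs]

def poly_norm_row(x, mul, w, b, eps=1e-6):
    # poly = w0*rms_norm(x^3) + w1*rms_norm(x^2) + w2*rms_norm(x) + b
    # out  = poly * mul
    n3 = rms_norm([v ** 3 for v in x], eps)
    n2 = rms_norm([v ** 2 for v in x], eps)
    n1 = rms_norm(x, eps)
    return [m * (w[0] * a + w[1] * c + w[2] * d + b)
            for m, a, c, d in zip(mul, n3, n2, n1)]

x = [1.0, -2.0, 0.5, 3.0]
mul = [0.5, 1.0, -1.0, 2.0]
w, b = (0.3, 0.2, 0.5), 0.1
out = poly_norm_row(x, mul, w, b)

# Kernel-style factored form:
# poly_i = x_i^3*(w0*inv_rms3) + x_i^2*(w1*inv_rms2) + x_i*(w2*inv_rms1) + b
D = len(x)
inv = lambda xs: 1.0 / math.sqrt(sum(v * v for v in xs) / D + 1e-6)
w0i = w[0] * inv([v ** 3 for v in x])
w1i = w[1] * inv([v ** 2 for v in x])
w2i = w[2] * inv(x)
out2 = [m * (v ** 3 * w0i + v ** 2 * w1i + v * w2i + b) for v, m in zip(x, mul)]
assert all(abs(a - c) < 1e-9 for a, c in zip(out, out2))
```

The factoring is why the kernel only needs the scalar sums of x^2, x^4, and x^6 per row: each inverse RMS is a single scalar, applied as one FMA per element.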
build/torch29-cxx11-cu128-x86_64-linux/__init__.py CHANGED
@@ -2,6 +2,7 @@ import torch
 
 from . import layers, parallel_style
 from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
 from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
 from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
 
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
 __all__ = [
     "poly_norm",
     "fused_mul_poly_norm",
+    "grouped_fused_mul_poly_norm",
     "rms_norm",
     "fused_add_rms_norm",
     "layers",
build/torch29-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a996dcbd533b29a6de849fb4c83b58f5b818688b1c89ae8609805d09b500bc13
+oid sha256:dfa89588a5e7e74b3a903912190b97004e308dd8fcb87832c2798d99733591f2
 size 15804336
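As a sanity check on the `grad_input` expression in `_grouped_polynorm_bwd_kernel`, the analytic formula can be compared against central finite differences in plain Python for one row. The variable names below follow the kernel (`dpoly`, `coeff_x*` as `c1..c3`); the harness itself is an illustrative sketch, not part of the package:

```python
import math

D, eps = 4, 1e-6
x = [0.7, -1.3, 0.4, 2.1]
m = [0.5, 1.0, -0.8, 1.5]
go = [0.3, -0.2, 0.9, 0.1]
w, b = (0.3, 0.2, 0.5), 0.1

def inv_rms(xs, p):
    # 1 / sqrt(mean((x^p)^2) + eps): p=1,2,3 gives inv_rms_x, inv_rms_x2, inv_rms_x3
    return 1.0 / math.sqrt(sum(v ** (2 * p) for v in xs) / D + eps)

def loss(xs):
    # sum(go * output) for one row, following the forward formula
    r1, r2, r3 = inv_rms(xs, 1), inv_rms(xs, 2), inv_rms(xs, 3)
    out = [mi * (w[0] * v ** 3 * r3 + w[1] * v ** 2 * r2 + w[2] * v * r1 + b)
           for v, mi in zip(xs, m)]
    return sum(g * o for g, o in zip(go, out))

# Analytic grad_input, transcribed from the kernel.
r1, r2, r3 = inv_rms(x, 1), inv_rms(x, 2), inv_rms(x, 3)
dpoly = [g * mi for g, mi in zip(go, m)]
s1 = sum(d * v for d, v in zip(dpoly, x))
s2 = sum(d * v ** 2 for d, v in zip(dpoly, x))
s3 = sum(d * v ** 3 for d, v in zip(dpoly, x))
c1 = w[2] * s1 * r1 * r1 / D  # coeff_x
c2 = w[1] * s2 * r2 * r2 / D  # coeff_x2
c3 = w[0] * s3 * r3 * r3 / D  # coeff_x3
grad = [r1 * (w[2] * d - v * c1)
        + 2.0 * v * r2 * (w[1] * d - v ** 2 * c2)
        + 3.0 * v ** 2 * r3 * (w[0] * d - v ** 3 * c3)
        for v, d in zip(x, dpoly)]

# Central finite differences should agree closely.
h = 1e-5
for j in range(D):
    xp, xm = x.copy(), x.copy()
    xp[j] += h
    xm[j] -= h
    assert abs((loss(xp) - loss(xm)) / (2 * h) - grad[j]) < 1e-5
```

The three terms correspond to the linear, quadratic, and cubic RMS-normed branches; each `coeff_*` carries the derivative of its inverse-RMS factor through the row-wise mean.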