TaehyunKim Claude Opus 4.6 github-actions[bot] commited on
feat: add GroupedFusedMulPolyNorm Triton kernel for MoE models (#16)
Browse files* feat: add GroupedFusedMulPolyNorm Triton kernel for MoE models
Fuses the full PolyNorm computation into two Triton kernels (fwd + bwd)
with per-expert weights/bias and in-kernel binary search for expert mapping.
Includes benchmarks, tests, and README documentation with B200 results.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* Add built binary [skip-build]
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This view is limited to 50 files because it contains too many changes. Â
See raw diff
- README.md +129 -1
- benchmarks/cases/grouped_mul_poly.py +122 -0
- benchmarks/common/bench_framework.py +24 -4
- benchmarks/run_cases.py +138 -75
- build/torch210-cxx11-cu126-x86_64-linux/__init__.py +2 -0
- build/torch210-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch210-cxx11-cu128-x86_64-linux/__init__.py +2 -0
- build/torch210-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch210-cxx11-cu130-x86_64-linux/__init__.py +2 -0
- build/torch210-cxx11-cu130-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch210-cxx11-cu130-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch210-cxx11-rocm70-x86_64-linux/__init__.py +2 -0
- build/torch210-cxx11-rocm70-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch210-cxx11-rocm71-x86_64-linux/__init__.py +2 -0
- build/torch210-cxx11-rocm71-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +3 -3
- build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch28-cxx11-cu126-x86_64-linux/__init__.py +2 -0
- build/torch28-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch28-cxx11-cu128-x86_64-linux/__init__.py +2 -0
- build/torch28-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu128-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch28-cxx11-cu129-x86_64-linux/__init__.py +2 -0
- build/torch28-cxx11-cu129-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch28-cxx11-cu129-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch28-cxx11-rocm63-x86_64-linux/__init__.py +2 -0
- build/torch28-cxx11-rocm63-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch28-cxx11-rocm64-x86_64-linux/__init__.py +2 -0
- build/torch28-cxx11-rocm64-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +3 -3
- build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch29-cxx11-cu126-x86_64-linux/__init__.py +2 -0
- build/torch29-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
- build/torch29-cxx11-cu126-x86_64-linux/_ops.py +3 -3
- build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py +583 -0
- build/torch29-cxx11-cu128-x86_64-linux/__init__.py +2 -0
- build/torch29-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so} +1 -1
README.md
CHANGED
|
@@ -19,7 +19,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
|
|
| 19 |
```python
|
| 20 |
y = x + residual
|
| 21 |
hidden_state = rms_norm(y, weight, eps)
|
| 22 |
-
out = y + some_op(hidden_state)
|
| 23 |
```
|
| 24 |
|
| 25 |
- Fused as:
|
|
@@ -45,6 +45,22 @@ Activation is a python package that contains custom CUDA-based activation kernel
|
|
| 45 |
out = fused_mul_poly_norm(x, a, weight, bias, eps)
|
| 46 |
```
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
## Usage
|
| 49 |
|
| 50 |
```python
|
|
@@ -214,6 +230,118 @@ print(poly_norm(x))
|
|
| 214 |
|
| 215 |
</details>
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
## Pre-commit Hooks
|
| 218 |
|
| 219 |
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
|
|
|
| 19 |
```python
|
| 20 |
y = x + residual
|
| 21 |
hidden_state = rms_norm(y, weight, eps)
|
| 22 |
+
out = y + some_op(hidden_state)
|
| 23 |
```
|
| 24 |
|
| 25 |
- Fused as:
|
|
|
|
| 45 |
out = fused_mul_poly_norm(x, a, weight, bias, eps)
|
| 46 |
```
|
| 47 |
|
| 48 |
+
- **GroupedFusedMulPolyNorm** (Triton)
|
| 49 |
+
|
| 50 |
+
A Triton-accelerated grouped variant of FusedMulPolyNorm for **MoE (Mixture of Experts)** models. Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), with per-expert weights/bias and in-kernel binary search for expert mapping.
|
| 51 |
+
- Instead of:
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
for i, expert in enumerate(experts):
|
| 55 |
+
out[start:end] = fused_mul_poly_norm(x[start:end], mul[start:end], weight[i], bias[i], eps)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
- Fused as:
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps)
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
## Usage
|
| 65 |
|
| 66 |
```python
|
|
|
|
| 230 |
|
| 231 |
</details>
|
| 232 |
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
### GroupedFusedMulPolyNorm (Triton)
|
| 236 |
+
|
| 237 |
+
> [!NOTE]
|
| 238 |
+
> This kernel is implemented in Triton (JIT-compiled, no CUDA C++ build required).
|
| 239 |
+
> Benchmarks compare three variants: **Naive** (raw PyTorch reference), **Compiled** (`torch.compile`'d reference), and **Triton** (fused Triton kernel).
|
| 240 |
+
> Benchmark dimension: 1280, 384 experts.
|
| 241 |
+
|
| 242 |
+
#### B200 Results (bf16)
|
| 243 |
+
|
| 244 |
+
<details>
|
| 245 |
+
<summary>Forward Performance</summary>
|
| 246 |
+
|
| 247 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
|
| 248 |
+
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 249 |
+
| 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x |
|
| 250 |
+
| 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x |
|
| 251 |
+
| 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x |
|
| 252 |
+
| 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x |
|
| 253 |
+
| 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x |
|
| 254 |
+
| 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x |
|
| 255 |
+
| 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x |
|
| 256 |
+
| 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x |
|
| 257 |
+
| 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x |
|
| 258 |
+
| 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x |
|
| 259 |
+
|
| 260 |
+
</details>
|
| 261 |
+
|
| 262 |
+
<details>
|
| 263 |
+
<summary>Backward Performance</summary>
|
| 264 |
+
|
| 265 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
|
| 266 |
+
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 267 |
+
| 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x |
|
| 268 |
+
| 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x |
|
| 269 |
+
| 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x |
|
| 270 |
+
| 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x |
|
| 271 |
+
| 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x |
|
| 272 |
+
| 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x |
|
| 273 |
+
| 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x |
|
| 274 |
+
|
| 275 |
+
</details>
|
| 276 |
+
|
| 277 |
+
<details>
|
| 278 |
+
<summary>Forward + Backward Combined</summary>
|
| 279 |
+
|
| 280 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
|
| 281 |
+
|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
|
| 282 |
+
| 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x |
|
| 283 |
+
| 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x |
|
| 284 |
+
| 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x |
|
| 285 |
+
| 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x |
|
| 286 |
+
| 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x |
|
| 287 |
+
| 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x |
|
| 288 |
+
| 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x |
|
| 289 |
+
| 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x |
|
| 290 |
+
| 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x |
|
| 291 |
+
|
| 292 |
+
</details>
|
| 293 |
+
|
| 294 |
+
#### B200 Results (fp32)
|
| 295 |
+
|
| 296 |
+
<details>
|
| 297 |
+
<summary>Forward Performance</summary>
|
| 298 |
+
|
| 299 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
|
| 300 |
+
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 301 |
+
| 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x |
|
| 302 |
+
| 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x |
|
| 303 |
+
| 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x |
|
| 304 |
+
| 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x |
|
| 305 |
+
| 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x |
|
| 306 |
+
| 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x |
|
| 307 |
+
| 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x |
|
| 308 |
+
| 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x |
|
| 309 |
+
| 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x |
|
| 310 |
+
|
| 311 |
+
</details>
|
| 312 |
+
|
| 313 |
+
<details>
|
| 314 |
+
<summary>Backward Performance</summary>
|
| 315 |
+
|
| 316 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
|
| 317 |
+
|-----------|---------|-----------|--------------|------------|-----------------|
|
| 318 |
+
| 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x |
|
| 319 |
+
| 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x |
|
| 320 |
+
| 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x |
|
| 321 |
+
| 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x |
|
| 322 |
+
| 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x |
|
| 323 |
+
| 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x |
|
| 324 |
+
| 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x |
|
| 325 |
+
|
| 326 |
+
</details>
|
| 327 |
+
|
| 328 |
+
<details>
|
| 329 |
+
<summary>Forward + Backward Combined</summary>
|
| 330 |
+
|
| 331 |
+
| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
|
| 332 |
+
|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
|
| 333 |
+
| 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x |
|
| 334 |
+
| 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x |
|
| 335 |
+
| 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x |
|
| 336 |
+
| 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x |
|
| 337 |
+
| 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x |
|
| 338 |
+
| 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x |
|
| 339 |
+
| 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x |
|
| 340 |
+
| 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x |
|
| 341 |
+
| 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x |
|
| 342 |
+
|
| 343 |
+
</details>
|
| 344 |
+
|
| 345 |
## Pre-commit Hooks
|
| 346 |
|
| 347 |
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
benchmarks/cases/grouped_mul_poly.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch._functorch.config
|
| 3 |
+
from common.diff_engine import DiffCase
|
| 4 |
+
|
| 5 |
+
torch._functorch.config.donated_buffer = False
|
| 6 |
+
|
| 7 |
+
from grouped_poly_norm import (
|
| 8 |
+
grouped_fused_mul_poly_norm,
|
| 9 |
+
grouped_fused_mul_poly_norm_ref,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
NUM_EXPERTS = 384
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GroupedRefModule(torch.nn.Module):
|
| 16 |
+
"""Wraps the PyTorch reference for grouped FusedMulPolyNorm."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, weight, bias, offsets, eps, expert_offset=0):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.weight = torch.nn.Parameter(weight)
|
| 21 |
+
self.bias = torch.nn.Parameter(bias)
|
| 22 |
+
self.offsets = offsets
|
| 23 |
+
self.eps = eps
|
| 24 |
+
self.expert_offset = expert_offset
|
| 25 |
+
|
| 26 |
+
def forward(self, x, mul):
|
| 27 |
+
return grouped_fused_mul_poly_norm_ref(x, mul, self.weight, self.bias,
|
| 28 |
+
self.offsets, self.eps,
|
| 29 |
+
expert_offset=self.expert_offset)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class GroupedTritonModule(torch.nn.Module):
|
| 33 |
+
"""Wraps the Triton kernel for grouped FusedMulPolyNorm."""
|
| 34 |
+
|
| 35 |
+
def __init__(self, weight, bias, offsets, eps, expert_offset=0):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.weight = torch.nn.Parameter(weight)
|
| 38 |
+
self.bias = torch.nn.Parameter(bias)
|
| 39 |
+
self.offsets = offsets
|
| 40 |
+
self.eps = eps
|
| 41 |
+
self.expert_offset = expert_offset
|
| 42 |
+
|
| 43 |
+
def forward(self, x, mul):
|
| 44 |
+
return grouped_fused_mul_poly_norm(x, mul, self.weight, self.bias,
|
| 45 |
+
self.offsets, self.eps,
|
| 46 |
+
expert_offset=self.expert_offset)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GroupedMulPoly(DiffCase):
|
| 50 |
+
"""Benchmark case for Grouped FusedMulPolyNorm (MoE).
|
| 51 |
+
|
| 52 |
+
Maps the framework's (bs, sl, hidden) to grouped polynorm's
|
| 53 |
+
(total_tokens, D) where total_tokens = bs * sl.
|
| 54 |
+
Uses a fixed number of experts with uniform token distribution.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def build_inputs(self, bs, sl, hidden, dtype, eps):
|
| 58 |
+
total_tokens = bs * sl
|
| 59 |
+
num_experts = min(NUM_EXPERTS, total_tokens)
|
| 60 |
+
|
| 61 |
+
torch.manual_seed(42)
|
| 62 |
+
probs = torch.ones(num_experts) / num_experts
|
| 63 |
+
assignments = torch.multinomial(probs, total_tokens, replacement=True)
|
| 64 |
+
counts = torch.bincount(assignments, minlength=num_experts).tolist()
|
| 65 |
+
offsets = torch.cumsum(
|
| 66 |
+
torch.tensor(counts, dtype=torch.int32), dim=0)
|
| 67 |
+
|
| 68 |
+
return {
|
| 69 |
+
"x":
|
| 70 |
+
torch.randn(total_tokens, hidden, dtype=dtype,
|
| 71 |
+
requires_grad=True) * 0.5,
|
| 72 |
+
"mul":
|
| 73 |
+
torch.randn(total_tokens, hidden, dtype=dtype,
|
| 74 |
+
requires_grad=True) * 0.5,
|
| 75 |
+
"weight":
|
| 76 |
+
torch.ones(num_experts, 3, dtype=dtype) / 3 +
|
| 77 |
+
torch.randn(num_experts, 3, dtype=dtype) * 0.01,
|
| 78 |
+
"bias":
|
| 79 |
+
torch.randn(num_experts, 1, dtype=dtype) * 0.01,
|
| 80 |
+
"offsets":
|
| 81 |
+
offsets,
|
| 82 |
+
"dim":
|
| 83 |
+
hidden,
|
| 84 |
+
"eps":
|
| 85 |
+
eps,
|
| 86 |
+
"dtype":
|
| 87 |
+
dtype,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
def make_naive(self, I):
|
| 91 |
+
return GroupedRefModule(
|
| 92 |
+
I["weight"].detach().clone(),
|
| 93 |
+
I["bias"].detach().clone(),
|
| 94 |
+
I["offsets"],
|
| 95 |
+
I["eps"],
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def make_compiled(self, I):
|
| 99 |
+
m = GroupedRefModule(
|
| 100 |
+
I["weight"].detach().clone(),
|
| 101 |
+
I["bias"].detach().clone(),
|
| 102 |
+
I["offsets"],
|
| 103 |
+
I["eps"],
|
| 104 |
+
)
|
| 105 |
+
return torch.compile(m)
|
| 106 |
+
|
| 107 |
+
def make_cuda(self, I):
|
| 108 |
+
return GroupedTritonModule(
|
| 109 |
+
I["weight"].detach().clone(),
|
| 110 |
+
I["bias"].detach().clone(),
|
| 111 |
+
I["offsets"],
|
| 112 |
+
I["eps"],
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def forward(self, obj, I):
|
| 116 |
+
return obj(I["x"], I["mul"])
|
| 117 |
+
|
| 118 |
+
def grad_inputs(self, I):
|
| 119 |
+
return [I["x"], I["mul"]]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
CASE = GroupedMulPoly()
|
benchmarks/common/bench_framework.py
CHANGED
|
@@ -57,7 +57,12 @@ def make_fwd_benchmark_for_case(
|
|
| 57 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 58 |
if provider == "speedup":
|
| 59 |
return timings_ms["naive"][key] / timings_ms["cuda"][key]
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
run = lambda: case.forward(obj, I)
|
| 62 |
ms = triton.testing.do_bench(run)
|
| 63 |
timings_ms[provider][key] = ms
|
|
@@ -101,7 +106,12 @@ def make_fwd_benchmark_plot_for_case(
|
|
| 101 |
return 1.00
|
| 102 |
batch_size, seq_len, dim = parse_config_string(config)
|
| 103 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
run = lambda: case.forward(obj, I)
|
| 106 |
ms = triton.testing.do_bench(run)
|
| 107 |
timings_ms[provider][config] = ms
|
|
@@ -146,7 +156,12 @@ def make_bwd_benchmark_for_case(
|
|
| 146 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 147 |
if provider == "speedup":
|
| 148 |
return timings_ms["naive"][key] / timings_ms["cuda"][key]
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
y = case.forward(obj, I)
|
| 151 |
gin = list(case.grad_inputs(I)) + list(obj.parameters())
|
| 152 |
if isinstance(y, torch.Tensor):
|
|
@@ -201,7 +216,12 @@ def make_bwd_benchmark_plot_for_case(
|
|
| 201 |
return 1.00
|
| 202 |
batch_size, seq_len, dim = parse_config_string(config)
|
| 203 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
y = case.forward(obj, I)
|
| 206 |
gin = list(case.grad_inputs(I)) + list(obj.parameters())
|
| 207 |
if isinstance(y, torch.Tensor):
|
|
|
|
| 57 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 58 |
if provider == "speedup":
|
| 59 |
return timings_ms["naive"][key] / timings_ms["cuda"][key]
|
| 60 |
+
if provider == "naive":
|
| 61 |
+
obj = case.make_naive(I)
|
| 62 |
+
elif provider == "compiled" and hasattr(case, "make_compiled"):
|
| 63 |
+
obj = case.make_compiled(I)
|
| 64 |
+
else:
|
| 65 |
+
obj = case.make_cuda(I)
|
| 66 |
run = lambda: case.forward(obj, I)
|
| 67 |
ms = triton.testing.do_bench(run)
|
| 68 |
timings_ms[provider][key] = ms
|
|
|
|
| 106 |
return 1.00
|
| 107 |
batch_size, seq_len, dim = parse_config_string(config)
|
| 108 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 109 |
+
if provider == "naive":
|
| 110 |
+
obj = case.make_naive(I)
|
| 111 |
+
elif provider == "compiled" and hasattr(case, "make_compiled"):
|
| 112 |
+
obj = case.make_compiled(I)
|
| 113 |
+
else:
|
| 114 |
+
obj = case.make_cuda(I)
|
| 115 |
run = lambda: case.forward(obj, I)
|
| 116 |
ms = triton.testing.do_bench(run)
|
| 117 |
timings_ms[provider][config] = ms
|
|
|
|
| 156 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 157 |
if provider == "speedup":
|
| 158 |
return timings_ms["naive"][key] / timings_ms["cuda"][key]
|
| 159 |
+
if provider == "naive":
|
| 160 |
+
obj = case.make_naive(I)
|
| 161 |
+
elif provider == "compiled" and hasattr(case, "make_compiled"):
|
| 162 |
+
obj = case.make_compiled(I)
|
| 163 |
+
else:
|
| 164 |
+
obj = case.make_cuda(I)
|
| 165 |
y = case.forward(obj, I)
|
| 166 |
gin = list(case.grad_inputs(I)) + list(obj.parameters())
|
| 167 |
if isinstance(y, torch.Tensor):
|
|
|
|
| 216 |
return 1.00
|
| 217 |
batch_size, seq_len, dim = parse_config_string(config)
|
| 218 |
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
|
| 219 |
+
if provider == "naive":
|
| 220 |
+
obj = case.make_naive(I)
|
| 221 |
+
elif provider == "compiled" and hasattr(case, "make_compiled"):
|
| 222 |
+
obj = case.make_compiled(I)
|
| 223 |
+
else:
|
| 224 |
+
obj = case.make_cuda(I)
|
| 225 |
y = case.forward(obj, I)
|
| 226 |
gin = list(case.grad_inputs(I)) + list(obj.parameters())
|
| 227 |
if isinstance(y, torch.Tensor):
|
benchmarks/run_cases.py
CHANGED
|
@@ -23,12 +23,15 @@ def make_title_tag():
|
|
| 23 |
return f"[{dev_name} | torch {torch_ver}]"
|
| 24 |
|
| 25 |
|
| 26 |
-
def plot_result(r_path):
|
| 27 |
import matplotlib.pyplot as plt
|
| 28 |
import pandas as pd
|
| 29 |
df = pd.read_csv(r_path + ".csv")
|
|
|
|
|
|
|
|
|
|
| 30 |
plt.figure(figsize=(12, 6))
|
| 31 |
-
ax = df.plot(x="config", y=
|
| 32 |
ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
|
| 33 |
fontsize=14,
|
| 34 |
fontweight="bold")
|
|
@@ -44,9 +47,10 @@ def plot_result(r_path):
|
|
| 44 |
|
| 45 |
def main():
|
| 46 |
ap = argparse.ArgumentParser()
|
| 47 |
-
ap.add_argument(
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
ap.add_argument("--plot", action="store_true")
|
| 51 |
ap.add_argument(
|
| 52 |
"--save-path",
|
|
@@ -54,8 +58,25 @@ def main():
|
|
| 54 |
default="./configs/",
|
| 55 |
help="Path to save benchmark results",
|
| 56 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
args = ap.parse_args()
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
torch.set_default_device("cuda")
|
| 60 |
mod = importlib.import_module(f"cases.{args.case}")
|
| 61 |
case: DiffCase = mod.CASE
|
|
@@ -67,76 +88,118 @@ def main():
|
|
| 67 |
hidden_size=4096,
|
| 68 |
)
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
if __name__ == "__main__":
|
|
|
|
| 23 |
return f"[{dev_name} | torch {torch_ver}]"
|
| 24 |
|
| 25 |
|
| 26 |
+
def plot_result(r_path, columns=None):
|
| 27 |
import matplotlib.pyplot as plt
|
| 28 |
import pandas as pd
|
| 29 |
df = pd.read_csv(r_path + ".csv")
|
| 30 |
+
if columns is None:
|
| 31 |
+
columns = [c for c in ["Naive", "Compiled", "Cuda", "Triton"]
|
| 32 |
+
if c in df.columns]
|
| 33 |
plt.figure(figsize=(12, 6))
|
| 34 |
+
ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
|
| 35 |
ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
|
| 36 |
fontsize=14,
|
| 37 |
fontweight="bold")
|
|
|
|
| 47 |
|
| 48 |
def main():
|
| 49 |
ap = argparse.ArgumentParser()
|
| 50 |
+
ap.add_argument(
|
| 51 |
+
"--case",
|
| 52 |
+
choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"],
|
| 53 |
+
required=True)
|
| 54 |
ap.add_argument("--plot", action="store_true")
|
| 55 |
ap.add_argument(
|
| 56 |
"--save-path",
|
|
|
|
| 58 |
default="./configs/",
|
| 59 |
help="Path to save benchmark results",
|
| 60 |
)
|
| 61 |
+
ap.add_argument(
|
| 62 |
+
"--dtype",
|
| 63 |
+
choices=["fp16", "bf16", "fp32", "all"],
|
| 64 |
+
default="bf16",
|
| 65 |
+
help="Data type for benchmarking (default: bf16)",
|
| 66 |
+
)
|
| 67 |
args = ap.parse_args()
|
| 68 |
|
| 69 |
+
dtype_map = {
|
| 70 |
+
"fp16": torch.float16,
|
| 71 |
+
"bf16": torch.bfloat16,
|
| 72 |
+
"fp32": torch.float32,
|
| 73 |
+
}
|
| 74 |
+
if args.dtype == "all":
|
| 75 |
+
dtypes = [("fp16", torch.float16), ("bf16", torch.bfloat16),
|
| 76 |
+
("fp32", torch.float32)]
|
| 77 |
+
else:
|
| 78 |
+
dtypes = [(args.dtype, dtype_map[args.dtype])]
|
| 79 |
+
|
| 80 |
torch.set_default_device("cuda")
|
| 81 |
mod = importlib.import_module(f"cases.{args.case}")
|
| 82 |
case: DiffCase = mod.CASE
|
|
|
|
| 88 |
hidden_size=4096,
|
| 89 |
)
|
| 90 |
|
| 91 |
+
for dtype_name, dtype in dtypes:
|
| 92 |
+
print(f"\n{'=' * 60}")
|
| 93 |
+
print(f" Benchmarking dtype: {dtype_name} ({dtype})")
|
| 94 |
+
print(f"{'=' * 60}\n")
|
| 95 |
+
|
| 96 |
+
save_dir = os.path.join(args.save_path, args.case, dtype_name)
|
| 97 |
+
is_grouped = args.case == "grouped_mul_poly"
|
| 98 |
+
|
| 99 |
+
if args.plot:
|
| 100 |
+
batch_size_range = [1]
|
| 101 |
+
seq_length_range = [4096, 8192, 16384]
|
| 102 |
+
if is_grouped:
|
| 103 |
+
dim = [1280]
|
| 104 |
+
elif "poly" in args.case:
|
| 105 |
+
dim = [8192, 16384]
|
| 106 |
+
else:
|
| 107 |
+
dim = [2048, 4096]
|
| 108 |
+
configs = list(
|
| 109 |
+
itertools.product(batch_size_range, seq_length_range, dim))
|
| 110 |
+
|
| 111 |
+
if is_grouped:
|
| 112 |
+
plot_line_vals = ("naive", "compiled", "cuda")
|
| 113 |
+
plot_line_names = {
|
| 114 |
+
"naive": "Naive",
|
| 115 |
+
"compiled": "Compiled",
|
| 116 |
+
"cuda": "Triton",
|
| 117 |
+
}
|
| 118 |
+
else:
|
| 119 |
+
plot_line_vals = ("naive", "cuda")
|
| 120 |
+
plot_line_names = {
|
| 121 |
+
"naive": "Naive",
|
| 122 |
+
"cuda": "Cuda",
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
plot_name = f"plot_{args.case}-{dtype_name}-fwd-perf"
|
| 126 |
+
bench = make_fwd_benchmark_plot_for_case(
|
| 127 |
+
case=case,
|
| 128 |
+
configs=configs,
|
| 129 |
+
plot_name=plot_name,
|
| 130 |
+
dtype=dtype,
|
| 131 |
+
line_vals=plot_line_vals,
|
| 132 |
+
line_names=plot_line_names,
|
| 133 |
+
)
|
| 134 |
+
bench.run(print_data=True, save_path=save_dir)
|
| 135 |
+
plot_result(os.path.join(save_dir, plot_name))
|
| 136 |
+
|
| 137 |
+
plot_name = f"plot_{args.case}-{dtype_name}-bwd-perf"
|
| 138 |
+
bench = make_bwd_benchmark_plot_for_case(
|
| 139 |
+
case=case,
|
| 140 |
+
configs=configs,
|
| 141 |
+
plot_name=plot_name,
|
| 142 |
+
dtype=dtype,
|
| 143 |
+
line_vals=plot_line_vals,
|
| 144 |
+
line_names=plot_line_names,
|
| 145 |
+
)
|
| 146 |
+
bench.run(print_data=True, save_path=save_dir)
|
| 147 |
+
plot_result(os.path.join(save_dir, plot_name))
|
| 148 |
+
for f in glob.glob(os.path.join(save_dir, "*.html")) + \
|
| 149 |
+
glob.glob(os.path.join(save_dir, "*.csv")):
|
| 150 |
+
os.remove(f)
|
| 151 |
+
else:
|
| 152 |
+
batch_size_range = [2**i for i in range(0, 4, 1)]
|
| 153 |
+
seq_length_range = [2**i for i in range(10, 14, 1)]
|
| 154 |
+
if is_grouped:
|
| 155 |
+
dim = [1280]
|
| 156 |
+
elif "poly" in args.case:
|
| 157 |
+
dim = [8192, 16384]
|
| 158 |
+
else:
|
| 159 |
+
dim = [2048, 4096]
|
| 160 |
+
configs = list(
|
| 161 |
+
itertools.product(dim, batch_size_range, seq_length_range))
|
| 162 |
+
|
| 163 |
+
if is_grouped:
|
| 164 |
+
csv_line_vals = ("naive", "compiled", "cuda", "speedup")
|
| 165 |
+
csv_line_names = {
|
| 166 |
+
"naive": "Naive",
|
| 167 |
+
"compiled": "Compiled",
|
| 168 |
+
"cuda": "Triton",
|
| 169 |
+
"speedup": "SpeedUp",
|
| 170 |
+
}
|
| 171 |
+
else:
|
| 172 |
+
csv_line_vals = ("naive", "cuda", "speedup")
|
| 173 |
+
csv_line_names = {
|
| 174 |
+
"naive": "Naive",
|
| 175 |
+
"cuda": "Cuda",
|
| 176 |
+
"speedup": "SpeedUp",
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
bench = make_fwd_benchmark_for_case(
|
| 180 |
+
case=case,
|
| 181 |
+
configs=configs,
|
| 182 |
+
plot_name=f"{args.case}-{dtype_name}-fwd-perf",
|
| 183 |
+
dtype=dtype,
|
| 184 |
+
line_vals=csv_line_vals,
|
| 185 |
+
line_names=csv_line_names,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
bench.run(print_data=True, save_path=save_dir)
|
| 189 |
+
|
| 190 |
+
bench = make_bwd_benchmark_for_case(
|
| 191 |
+
case=case,
|
| 192 |
+
configs=configs,
|
| 193 |
+
plot_name=f"{args.case}-{dtype_name}-bwd-perf",
|
| 194 |
+
dtype=dtype,
|
| 195 |
+
line_vals=csv_line_vals,
|
| 196 |
+
line_names=csv_line_names,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
bench.run(print_data=True, save_path=save_dir)
|
| 200 |
+
for f in glob.glob(os.path.join(save_dir, "*.html")) + \
|
| 201 |
+
glob.glob(os.path.join(save_dir, "*.png")):
|
| 202 |
+
os.remove(f)
|
| 203 |
|
| 204 |
|
| 205 |
if __name__ == "__main__":
|
build/torch210-cxx11-cu126-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch210-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 10775296
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f31dfeac9b22c01a027f858b3d8beaf87eea9adf8dc45902f0e43d6c264fd985
|
| 3 |
size 10775296
|
build/torch210-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str) -> str:
    """Return *op_name* qualified with this extension's torch.ops namespace."""
    return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Vectorized PyTorch reference for Grouped FusedMulPolyNorm.

    Maps every token to its expert with a single ``torch.bucketize`` call,
    then evaluates the PolyNorm polynomial for all tokens at once in fp32.
    Suitable for testing/benchmarking and friendly to ``torch.compile``.

    Args:
        input: (total_tokens, D) concatenated tokens for all experts.
        mul: (total_tokens, D) gate values multiplied into the result.
        weight: (num_experts, 3) per-expert polynomial weights [x^3, x^2, x].
        bias: (num_experts, 1) per-expert polynomial bias.
        offsets: (num_experts,) inclusive cumsum of num_tokens_per_expert (int32).
        eps: numerical-stability epsilon inside each RMS norm.
        expert_offset: constant added to every computed expert index.

    Returns:
        (total_tokens, D) tensor in the dtype of ``input``.
    """
    out_dtype = input.dtype

    # Token p belongs to the first expert e with p < offsets[e]; with
    # right=True, bucketize returns exactly that index for an inclusive cumsum.
    positions = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token parameters, promoted to fp32 for the math below.
    w_tok = weight.float()[experts]   # (T, 3)
    b_tok = bias.float()[experts]     # (T, 1)

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    # poly = w0*rms(x^3) + w1*rms(x^2) + w2*rms(x) + b, accumulated
    # left-to-right to match the fused kernel's summation order.
    poly = w_tok[:, 0:1] * _rms_norm(x_cu, eps)
    poly = poly + w_tok[:, 1:2] * _rms_norm(x_sq, eps)
    poly = poly + w_tok[:, 2:3] * _rms_norm(x, eps)
    poly = poly + b_tok

    return (poly * gate).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
# Candidate (BLOCK_D, num_warps, num_stages) combinations tried by
# @triton.autotune for the forward kernel; keyed on D, so the winner is
# cached per hidden dimension.
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    # BLOCK_D=2048 entries use num_stages=1 — presumably the single-tile
    # path needs no pipelining; confirm against profiling results.
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Search space for the backward kernel; slightly different from the forward
# set (extra deep-pipeline entries, fewer 2048-wide ones).
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes, for row r with expert e:
        out[r] = (w[e,0]*rms(x^3) + w[e,1]*rms(x^2) + w[e,2]*rms(x) + b[e]) * mul[r]
    where rms(t) = t / sqrt(mean(t^2) + eps).  All math is done in fp32
    regardless of the input dtype.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts)
    # Invariant: offsets[lo-1] <= row < offsets[hi]; loop runs a fixed 12
    # iterations (no data-dependent trip count) and converges when lo == hi.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters: weight is laid out as (num_experts, 3).
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # When the whole row fits in one tile, x/x2/x3 stay live in registers
    # across the reduction and the store (no second pass over global memory).
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        # Masked lanes load 0.0, so they contribute nothing to the sums.
        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # sum(x^2k)/D is mean((x^k)^2); eps sits inside the sqrt, matching
        # the torch reference _rms_norm.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1 reduces the three even power sums needed by the RMS norms.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2 re-reads the row (activation recompute) and writes output.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # Output buffers written with atomics must be zeroed between autotune
    # trial runs, otherwise the benchmark repeats would keep accumulating.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,
    grad_bias_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

    Note: a single `stride_row` is shared by all five row-major tensors —
    the launcher guarantees they are contiguous with identical layout.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Same fixed-iteration binary search over the offsets cumsum as the
    # forward kernel; yields this row's owning expert.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        # Masked lanes load 0.0 and therefore drop out of every reduction.
        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dL/d(poly): the forward multiplied poly by m elementwise.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: dL/dw_k = <dpoly, rms_norm(x^p)> = inv_rms * <dpoly, x^p>
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: the "mean" term of the rms_norm
        # derivative (quotient rule), shared by every element of the row.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * poly (bias included)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure):
        # one term per power, each of the form inv_rms*(w*dpoly - t*coeff)
        # chained with the factor d(x^p)/dx = p*x^(p-1).
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into the per-expert parameter grads.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul (row data is re-read: recompute
        # beats stashing the whole row when D exceeds one tile)
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge around the fused grouped-PolyNorm Triton kernels.

    Saves only the (contiguous) inputs and parameters; the backward kernel
    recomputes all forward intermediates instead of storing them.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        N, D = input.shape
        # Kernels index rows with a single stride, so both operands must be
        # contiguous before their strides are captured below.
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations, which
        # only resolves up to 2^12 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        # One program per token row.
        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Saved tensors are the contiguous copies, so backward can share a
        # single row stride across all of them.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Parameter grads are accumulated across rows via tl.atomic_add in
        # fp32, then cast back to the parameter dtype below.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        # All five row-major tensors are contiguous here, so one stride_row
        # (== D) serves them all.
        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # Restore the (num_experts, 1) shape of the bias parameter.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # None for the non-tensor args (offsets, eps, expert_offset).
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Thin autograd entry point: all work happens inside the fused Triton
    forward/backward kernels driven by ``_GroupedPolyNormFn``.

    Args:
        input: (total_tokens, D) concatenated tokens for all experts.
        mul: (total_tokens, D) gate values to multiply with.
        weight: (num_experts, 3) per-expert polynomial weights.
        bias: (num_experts, 1) per-expert polynomial bias.
        offsets: (num_experts,) cumsum of num_tokens_per_expert (int32).
        eps: numerical stability epsilon.
        expert_offset: offset to add to the expert index.

    Returns:
        (total_tokens, D) output tensor.
    """
    args = (input, mul, weight, bias, offsets, eps, expert_offset)
    return _GroupedPolyNormFn.apply(*args)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when triton cannot be imported.

    Always raises; install triton to get the accelerated implementation.
    """
    raise RuntimeError(
        "Triton is not available. Install triton to use "
        "grouped_fused_mul_poly_norm.")
|
build/torch210-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch210-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15815392
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b8370d2e1f5561ae4b77ac8ae7b3a084e33a0d1952a8f5f9bf4700375313b35
|
| 3 |
size 15815392
|
build/torch210-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str) -> str:
    """Return *op_name* qualified with this extension's torch.ops namespace."""
    return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its owning expert with ``torch.bucketize`` over the
    cumulative ``offsets``, then evaluates the PolyNorm polynomial for all
    tokens at once in float32. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to the looked-up expert index

    Returns:
        (total_tokens, D) - output tensor in the input's dtype
    """
    out_dtype = input.dtype

    # right=True assigns token position p to expert e iff
    # offsets[e-1] <= p < offsets[e] (matching the Triton binary search).
    positions = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(positions, offsets, right=True) + expert_offset

    per_token_w = weight.float()[experts]   # (total_tokens, 3)
    per_token_b = bias.float()[experts]     # (total_tokens, 1)

    x = input.float()
    x2 = x * x
    x3 = x2 * x

    def rms(v):
        # Inlined per-row RMS normalization (same arithmetic as _rms_norm).
        return v / torch.sqrt(v.pow(2).mean(-1, keepdim=True) + eps)

    poly = (per_token_w[:, 0:1] * rms(x3) +
            per_token_w[:, 1:2] * rms(x2) +
            per_token_w[:, 2:3] * rms(x) + per_token_b)

    return (poly * mul.float()).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
# Candidate (BLOCK_D, num_warps, num_stages) combinations. @triton.autotune
# benchmarks these once per distinct hidden size D (key=["D"]) and caches the
# winner. A config with BLOCK_D >= D activates the kernels' single-tile fast
# path; num_stages > 1 software-pipelines loads with compute in the
# multi-tile loops.
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Backward candidate set: mirrors the forward list but drops two of the
# 2048-wide variants and adds deeper-pipeline options at 512/1024 —
# presumably because backward keeps more live operands per tile
# (x, mul, grad_out) — TODO confirm via occupancy measurements.
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    For the row's owning expert e, computes
        out = (w[e,0]*rms(x^3) + w[e,1]*rms(x^2) + w[e,2]*rms(x) + b[e]) * mul
    where rms(v) = v / sqrt(mean(v*v) + eps) over the row's D elements.
    All arithmetic is done in float32 regardless of the input dtype; the
    store casts back to the output tensor's dtype.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts)
    # over the cumulative `offsets`; the host wrapper asserts
    # num_experts <= 4096. Invariant: offsets[lo-1] <= row < offsets[lo]
    # on exit, matching torch.bucketize(..., right=True).
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # The whole row fits in one tile: load x once and reuse it across both
    # the reduction and the output computation (no second global read).
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # mean((x^k)^2) for k = 1, 2, 3; masked lanes contribute 0 to each sum.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1 reduces the three RMS statistics; pass 2 re-loads x and
        # writes the output (num_stages pipelines the loads).
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # Autotuning re-runs the kernel several times; the atomic accumulator
    # outputs must be zeroed between trials or the benchmark runs would
    # pollute the real gradients.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,
    grad_bias_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

    A single `stride_row` is shared by all five row-major tensors; the host
    wrapper guarantees they are contiguous with identical row strides.

    Gradient math (per rms term, k in {1, 2, 3}, x^k normalized):
        d out / d x_j through w * rms(x^k), contracted with dpoly = go * mul:
            k * x_j^(k-1) * inv_rms_k * (w * dpoly_j - x_j^k * coeff_k)
        with coeff_k = w * <dpoly, x^k> * inv_rms_k^2 / D,
    which is what the coeff_x / coeff_x2 / coeff_x3 expressions implement.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over cumulative offsets for the owning expert
    # (12 iterations cover up to 4096 experts; same scheme as forward).
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalars in fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # Row fits in one tile: all operands stay in registers, so stats and
    # gradients are produced in a single load of each tensor.
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dpoly = d loss / d poly, since output = poly * mul.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: d out / d w[k] = rms(x^k), contracted with dpoly.
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (see the docstring's gradient math).
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * full polynomial (including bias).
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure);
        # the 2x / 3x^2 factors are the chain rule through x^2 / x^3.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into the shared fp32 parameter-grad buffers.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (see the docstring's gradient math).
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        # Inputs are re-loaded (activation recomputation) instead of cached.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into the shared fp32 parameter-grad buffers.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge around the grouped PolyNorm Triton kernels.

    Forward saves the (contiguous) inputs; backward recomputes all
    intermediates from them instead of storing activations.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        """Launch the forward kernel: one program per token row.

        input/mul: (N, D); weight: (num_experts, 3); bias: (num_experts, 1);
        offsets: (num_experts,) cumulative token counts per expert.
        """
        N, D = input.shape
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations (2^12 = 4096).
        # NOTE(review): `assert` is stripped under `python -O`; raising
        # ValueError would be sturdier — confirm callers before changing the
        # exception type.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous tensors so backward's single shared stride is
        # valid for every row-major tensor it touches.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the backward kernel; returns grads for the 7 forward args.

        weight/bias grads accumulate in float32 via atomic adds and are cast
        back to the parameter dtypes afterwards.
        """
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # fp32 accumulators targeted by the kernel's tl.atomic_add calls.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is (num_experts, 1): restore the trailing singleton dimension.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # None for offsets, eps, expert_offset (non-differentiable inputs).
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Dispatches to the fused forward/backward kernels through the autograd
    function, so gradients flow to ``input``, ``mul``, ``weight`` and
    ``bias``.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    args = (input, mul, weight, bias, offsets, eps, expert_offset)
    return _GroupedPolyNormFn.apply(*args)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when Triton could not be imported; always raises."""
    msg = ("Triton is not available. Install triton to use "
           "grouped_fused_mul_poly_norm.")
    raise RuntimeError(msg)
|
build/torch210-cxx11-cu130-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch210-cxx11-cu130-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 13520952
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:edf3fca2079788750c4e0497012ba93c34c770aca4c9d4f22d03be4a86a2ce8c
|
| 3 |
size 13520952
|
build/torch210-cxx11-cu130-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps each token row to its expert via ``torch.bucketize`` over the
    cumulative token counts, then evaluates the PolyNorm polynomial for all
    tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to every bucketized expert index

    Returns:
        (total_tokens, D) - output tensor in the input's original dtype
    """
    out_dtype = input.dtype

    # Token i belongs to the first expert e with offsets[e] > i, hence
    # right=True on the cumulative boundaries.
    positions = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token parameter rows in fp32.
    w = weight.float()[experts]
    b = bias.float()[experts]

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    def _norm(t: Tensor) -> Tensor:
        # Row-wise RMS normalization, inlined so this reference is
        # self-contained.
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    poly = w[:, 0:1] * _norm(x_cu)
    poly = poly + w[:, 1:2] * _norm(x_sq)
    poly = poly + w[:, 2:3] * _norm(x)
    poly = poly + b

    return (poly * gate).to(out_dtype)
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
# Candidate launch configurations for the forward kernel. @triton.autotune
# (keyed on D) benchmarks these and caches the fastest per hidden dimension.
# num_stages=1 for the largest tiles: a single 2048-wide tile leaves no loop
# iterations to pipeline. (In the file these live under `if HAS_TRITON:`.)
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Candidate launch configurations for the backward kernel. The backward pass
# loads three tensors per tile (grad_out, input, mul), so the sweet spot can
# differ from the forward list; the sweep is correspondingly slightly
# different (e.g. includes 512/4/5, drops the widest 2048 variants).
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,        # (N, D) input tokens
    mul_ptr,          # (N, D) gate values
    weight_ptr,       # (num_experts, 3) polynomial weights, flat-indexed
    bias_ptr,         # (num_experts,) polynomial bias, flat-indexed
    offsets_ptr,      # (num_experts,) cumulative token counts per expert
    output_ptr,       # (N, D) output
    N,                # number of token rows
    D,                # hidden dimension
    num_experts,
    eps,              # numerical-stability epsilon added under the sqrt
    expert_offset,    # added to the bucketized expert index
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes, per row, with this row's expert parameters (w0, w1, w2, b):
        out = (w0 * rms(x^3) + w1 * rms(x^2) + w2 * rms(x) + b) * mul
    All math is done in fp32 regardless of the input dtype.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts,
    # matching the num_experts <= 4096 assert on the host side). Ends with
    # lo == first index whose cumulative offset exceeds `row`.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Scalar per-expert parameters, promoted to fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path: the whole row fits in one BLOCK_D tile, so x is
    # loaded once and reused for both the reduction and the output phase. ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # rms(x^k) denominators share the powers of x2:
        # mean(x^2), mean(x^4), mean(x^6). Masked lanes contribute 0.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1: accumulate the three power sums across tiles.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2: re-load x tile by tile and emit the gated polynomial.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # Autotune re-runs the kernel several times; the weight/bias gradient
    # buffers are accumulated via atomics, so they must be zeroed between
    # timing runs to avoid polluting the final result.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,      # (N, D) upstream gradient
    input_ptr,         # (N, D) saved forward input
    mul_ptr,           # (N, D) saved gate values
    weight_ptr,        # (num_experts, 3) flat-indexed
    bias_ptr,          # (num_experts,) flat-indexed
    offsets_ptr,       # (num_experts,) cumulative token counts
    grad_input_ptr,    # (N, D) out: d loss / d input
    grad_mul_ptr,      # (N, D) out: d loss / d mul
    grad_weight_ptr,   # (num_experts, 3) out, fp32, atomically accumulated
    grad_bias_ptr,     # (num_experts,) out, fp32, atomically accumulated
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,        # shared row stride: all five (N, D) tensors are
                       # contiguous with identical layout (host enforces this)
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Recomputes forward intermediates from the saved inputs rather than
    saving activations.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Same binary search as the forward kernel: find this row's expert.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dpoly = d loss / d poly (output = poly * m, so chain through m).
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: d out / d w_k = rms(x^k) dotted with dpoly.
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: the "mean-subtraction" term of the
        # rms-norm gradient, w_k * <dpoly, x^k> * inv_rms^2 / D.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul: output = (poly + bias) * m, so d out / d m is the full
        # polynomial including bias.
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure).
        # Each term is the rms-norm gradient of x^k, scaled by d(x^k)/dx
        # (1, 2x, 3x^2 respectively).
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into the per-expert parameter grads.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul (recompute x powers per tile)
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge for the grouped PolyNorm Triton kernels.

    Saves the (contiguous) inputs in forward and recomputes all forward
    intermediates inside the backward kernel.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        """Launch the forward kernel; returns (N, D) output in input's dtype."""
        N, D = input.shape
        input = input.contiguous()
        mul = mul.contiguous()
        # The kernels index weight/bias/offsets with flat element arithmetic
        # (eidx * 3 + k, offsets_ptr + mid), which is only valid for
        # contiguous storage. Enforce it here so non-contiguous views
        # (e.g. parameter slices) do not silently produce wrong results.
        weight = weight.contiguous()
        bias = bias.contiguous()
        offsets = offsets.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations (2^12).
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous versions so backward can reuse one shared
        # row stride for every (N, D) tensor.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the backward kernel; returns grads for the 4 tensor args."""
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Parameter grads accumulate across rows via tl.atomic_add, so they
        # are allocated zeroed and in fp32 (cast back to param dtype below).
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is (num_experts, 1); restore the trailing dim on its grad.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # No grads for offsets, eps, expert_offset.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Thin wrapper that dispatches to the autograd Function so both forward
    and backward run through the fused Triton kernels.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    return _GroupedPolyNormFn.apply(
        input, mul, weight, bias, offsets, eps, expert_offset
    )
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when Triton could not be imported.

    Mirrors the accelerated function's signature but always raises, so the
    failure mode is an explicit error at call time rather than an import-time
    crash.
    """
    raise RuntimeError(
        "Triton is not available. Install triton to use "
        "grouped_fused_mul_poly_norm.")
build/torch210-cxx11-rocm70-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch210-cxx11-rocm70-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2919488
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f8d5a173c51cb2dabe3554da743aed307c04b5d51c9d0d460a8fa5a821b5495
|
| 3 |
size 2919488
|
build/torch210-cxx11-rocm70-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
|
| 57 |
+
input: Tensor,
|
| 58 |
+
mul: Tensor,
|
| 59 |
+
weight: Tensor,
|
| 60 |
+
bias: Tensor,
|
| 61 |
+
offsets: Tensor,
|
| 62 |
+
eps: float = 1e-6,
|
| 63 |
+
expert_offset: int = 0,
|
| 64 |
+
) -> Tensor:
|
| 65 |
+
"""PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
|
| 66 |
+
|
| 67 |
+
Uses torch.bucketize to map tokens to experts, then computes PolyNorm
|
| 68 |
+
for all tokens at once. torch.compile friendly.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
input: (total_tokens, D) - concatenated tokens for all experts
|
| 72 |
+
mul: (total_tokens, D) - gate values to multiply with
|
| 73 |
+
weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
|
| 74 |
+
bias: (num_experts, 1) - per-expert polynomial bias
|
| 75 |
+
offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
|
| 76 |
+
eps: numerical stability epsilon
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
(total_tokens, D) - output tensor
|
| 80 |
+
"""
|
| 81 |
+
orig_dtype = input.dtype
|
| 82 |
+
|
| 83 |
+
token_positions = torch.arange(input.shape[0], device=input.device)
|
| 84 |
+
expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
|
| 85 |
+
|
| 86 |
+
weight_fp32 = weight.float()
|
| 87 |
+
bias_fp32 = bias.float()
|
| 88 |
+
|
| 89 |
+
per_token_w = weight_fp32[expert_idx]
|
| 90 |
+
per_token_b = bias_fp32[expert_idx]
|
| 91 |
+
|
| 92 |
+
x = input.float()
|
| 93 |
+
m = mul.float()
|
| 94 |
+
|
| 95 |
+
x2 = x * x
|
| 96 |
+
x3 = x2 * x
|
| 97 |
+
|
| 98 |
+
poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
|
| 99 |
+
per_token_w[:, 1:2] * _rms_norm(x2, eps) +
|
| 100 |
+
per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
|
| 101 |
+
|
| 102 |
+
return (poly * m).to(orig_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
_GROUPED_POLYNORM_FWD_CONFIGS = [
|
| 111 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 112 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 113 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 114 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 115 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 116 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 117 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 118 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 119 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 120 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 121 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 122 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 123 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 124 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 125 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 126 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 127 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
|
| 128 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 129 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 130 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
_GROUPED_POLYNORM_BWD_CONFIGS = [
|
| 134 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 135 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 136 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 137 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 138 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 139 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 140 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 141 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 142 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 143 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
|
| 144 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 145 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 146 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 147 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 148 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 149 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 150 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
|
| 151 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 152 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 153 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],  # re-tune whenever the hidden dimension changes
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,          # (N, D) concatenated per-expert tokens
        mul_ptr,            # (N, D) gate values multiplied into the result
        weight_ptr,         # (num_experts, 3) per-expert polynomial weights
        bias_ptr,           # per-expert scalar bias, indexed by expert id
        offsets_ptr,        # (num_experts,) ascending cumulative token counts
        output_ptr,         # (N, D) result buffer
        N,                  # number of rows (tokens)
        D,                  # hidden dimension
        num_experts,
        eps,                # epsilon inside each 1/sqrt for stability
        expert_offset,      # added to the searched index (expert sharding)
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        Computes, per row with that row's expert parameters (w0, w1, w2, b):
            out = (w0 * rms_norm(x^3) + w1 * rms_norm(x^2)
                   + w2 * rms_norm(x) + b) * mul
        All arithmetic is done in float32 regardless of input dtype.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts).
        # Finds the first i with offsets[i] > row; requires offsets ascending.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path: the whole row fits in one BLOCK_D tile, so
        # x stays in registers between the reduction and the output phase. ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # rms denominators for x, x^2, x^3 reuse powers of x2;
            # mean(x^(2k)) = sum / D.  Masked lanes contribute 0 to the sums.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate sum(x^2), sum(x^4), sum(x^6) over all tiles.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: reload each tile and emit poly(x) * mul.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        # Autotune re-runs the kernel; zero the atomically-accumulated
        # outputs between trials so benchmarking doesn't corrupt them.
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,       # (N, D) upstream gradient
        input_ptr,          # (N, D) saved forward input
        mul_ptr,            # (N, D) saved gate values
        weight_ptr,         # (num_experts, 3) per-expert weights
        bias_ptr,           # per-expert scalar bias
        offsets_ptr,        # (num_experts,) ascending cumulative token counts
        grad_input_ptr,     # (N, D) output: d loss / d input
        grad_mul_ptr,       # (N, D) output: d loss / d mul
        grad_weight_ptr,    # fp32 accumulator, updated via atomic_add
        grad_bias_ptr,      # fp32 accumulator, updated via atomic_add
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,         # shared row stride (all tensors made contiguous)
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        Forward intermediates are recomputed from the saved inputs rather
        than stored (activation recomputation).
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Same binary search as the forward kernel: first i with
        # offsets[i] > row (12 iterations covers up to 4096 experts).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # dpoly = d loss / d poly, since output = poly * m.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: d poly / d w_k = rms_norm(x^p) => dot with dpoly.
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input: the "mean correction" terms of
            # each rms_norm gradient (the 1/D part of d inv_rms / d x).
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul = go * full poly (bias included)
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            # Each term: d(w_k * rms_norm(x^p))/dx with chain factor p*x^(p-1).
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into per-expert parameter grads.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
    class _GroupedPolyNormFn(torch.autograd.Function):
        """Autograd wrapper launching the grouped PolyNorm Triton kernels."""

        @staticmethod
        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
            # input/mul: (N, D); weight: (num_experts, 3); bias: per-expert;
            # offsets: cumulative token counts consumed by the in-kernel
            # binary search.
            N, D = input.shape
            input = input.contiguous()
            mul = mul.contiguous()
            output = torch.empty_like(input)

            num_experts = offsets.shape[0]
            # The kernels run a fixed 12-iteration binary search: 2^12 = 4096.
            assert num_experts <= 4096, (
                f"Supports at most 4096 experts, got {num_experts}.")

            # One program per row.
            _grouped_polynorm_fwd_kernel[(N,)](
                input,
                mul,
                weight,
                bias,
                offsets,
                output,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_input_row=input.stride(0),
                stride_mul_row=mul.stride(0),
                stride_out_row=output.stride(0),
            )

            # Save the contiguous views so backward sees the same layout.
            ctx.save_for_backward(input, mul, weight, bias, offsets)
            ctx.eps = eps
            ctx.expert_offset = expert_offset
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, mul, weight, bias, offsets = ctx.saved_tensors
            eps = ctx.eps
            expert_offset = ctx.expert_offset
            N, D = input.shape

            grad_output = grad_output.contiguous()
            grad_input = torch.empty_like(input)
            grad_mul = torch.empty_like(mul)
            # fp32 zero-initialized accumulators: the kernel adds into them
            # with tl.atomic_add; cast back to the parameter dtype below.
            grad_weight = torch.zeros(weight.shape[0],
                                      3,
                                      device=weight.device,
                                      dtype=torch.float32)
            grad_bias = torch.zeros(bias.shape[0],
                                    device=bias.device,
                                    dtype=torch.float32)

            num_experts = offsets.shape[0]

            # Single stride_row shared by all row tensors — valid because
            # forward saved contiguous input/mul and grad_output was made
            # contiguous above.
            _grouped_polynorm_bwd_kernel[(N,)](
                grad_output,
                input,
                mul,
                weight,
                bias,
                offsets,
                grad_input,
                grad_mul,
                grad_weight,
                grad_bias,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_row=input.stride(0),
            )

            grad_weight = grad_weight.to(weight.dtype)
            # unsqueeze to match the (num_experts, 1) bias parameter shape.
            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

            # No gradients for offsets, eps, expert_offset.
            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
|
| 545 |
+
input: Tensor,
|
| 546 |
+
mul: Tensor,
|
| 547 |
+
weight: Tensor,
|
| 548 |
+
bias: Tensor,
|
| 549 |
+
offsets: Tensor,
|
| 550 |
+
eps: float = 1e-6,
|
| 551 |
+
expert_offset: int = 0,
|
| 552 |
+
) -> Tensor:
|
| 553 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm.
|
| 554 |
+
|
| 555 |
+
Args:
|
| 556 |
+
input: (total_tokens, D) - concatenated tokens for all experts
|
| 557 |
+
mul: (total_tokens, D) - gate values to multiply with
|
| 558 |
+
weight: (num_experts, 3) - per-expert polynomial weights
|
| 559 |
+
bias: (num_experts, 1) - per-expert polynomial bias
|
| 560 |
+
offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
|
| 561 |
+
eps: numerical stability epsilon
|
| 562 |
+
expert_offset: offset to add to expert index
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
(total_tokens, D) - output tensor
|
| 566 |
+
"""
|
| 567 |
+
return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
|
| 568 |
+
expert_offset)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Stub installed when Triton failed to import; mirrors the real
        signature so callers get a clear error at call time, not import time.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
|
build/torch210-cxx11-rocm71-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch210-cxx11-rocm71-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2911200
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28e6de9c3cd172e95284b0df0e15a2afb21ab7b89dd624e69b1361942095e8be
|
| 3 |
size 2911200
|
build/torch210-cxx11-rocm71-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert with ``torch.bucketize``, then evaluates
    PolyNorm for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to the bucketized expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    orig_dtype = input.dtype

    def rms_norm(t: Tensor) -> Tensor:
        # Per-row RMS normalization: t / sqrt(mean(t^2, dim=-1) + eps).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    # Token position -> owning expert: first offset strictly greater than
    # the position (right=True), matching the kernel's binary search.
    positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token parameters, promoted to float32 for stability.
    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    m = mul.float()
    x2 = x * x
    x3 = x2 * x

    poly = (per_token_w[:, 0:1] * rms_norm(x3) +
            per_token_w[:, 1:2] * rms_norm(x2) +
            per_token_w[:, 2:3] * rms_norm(x) + per_token_b)

    return (poly * m).to(orig_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
_GROUPED_POLYNORM_FWD_CONFIGS = [
|
| 111 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 112 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 113 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 114 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 115 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 116 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 117 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 118 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 119 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 120 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 121 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 122 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 123 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 124 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 125 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 126 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 127 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
|
| 128 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 129 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 130 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
_GROUPED_POLYNORM_BWD_CONFIGS = [
|
| 134 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 135 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 136 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 137 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 138 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 139 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 140 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 141 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 142 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 143 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
|
| 144 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 145 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 146 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 147 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 148 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 149 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 150 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
|
| 151 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 152 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 153 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
|
| 157 |
+
configs=_GROUPED_POLYNORM_FWD_CONFIGS,
|
| 158 |
+
key=["D"],
|
| 159 |
+
)
|
| 160 |
+
@triton.jit
|
| 161 |
+
def _grouped_polynorm_fwd_kernel(
|
| 162 |
+
input_ptr,
|
| 163 |
+
mul_ptr,
|
| 164 |
+
weight_ptr,
|
| 165 |
+
bias_ptr,
|
| 166 |
+
offsets_ptr,
|
| 167 |
+
output_ptr,
|
| 168 |
+
N,
|
| 169 |
+
D,
|
| 170 |
+
num_experts,
|
| 171 |
+
eps,
|
| 172 |
+
expert_offset,
|
| 173 |
+
stride_input_row,
|
| 174 |
+
stride_mul_row,
|
| 175 |
+
stride_out_row,
|
| 176 |
+
BLOCK_D: tl.constexpr,
|
| 177 |
+
):
|
| 178 |
+
"""Forward kernel: one program per row."""
|
| 179 |
+
row = tl.program_id(0)
|
| 180 |
+
if row >= N:
|
| 181 |
+
return
|
| 182 |
+
|
| 183 |
+
# Binary search for expert index (12 iters covers up to 4096 experts)
|
| 184 |
+
lo = 0
|
| 185 |
+
hi = num_experts
|
| 186 |
+
for _ in range(12):
|
| 187 |
+
if lo < hi:
|
| 188 |
+
mid = (lo + hi) // 2
|
| 189 |
+
if tl.load(offsets_ptr + mid) <= row:
|
| 190 |
+
lo = mid + 1
|
| 191 |
+
else:
|
| 192 |
+
hi = mid
|
| 193 |
+
eidx = lo + expert_offset
|
| 194 |
+
|
| 195 |
+
w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
|
| 196 |
+
w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
|
| 197 |
+
w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
|
| 198 |
+
b = tl.load(bias_ptr + eidx).to(tl.float32)
|
| 199 |
+
|
| 200 |
+
input_row_ptr = input_ptr + row * stride_input_row
|
| 201 |
+
mul_row_ptr = mul_ptr + row * stride_mul_row
|
| 202 |
+
out_row_ptr = output_ptr + row * stride_out_row
|
| 203 |
+
|
| 204 |
+
D_float = D.to(tl.float32)
|
| 205 |
+
|
| 206 |
+
# --- Single-tile path ---
|
| 207 |
+
if D <= BLOCK_D:
|
| 208 |
+
d_offs = tl.arange(0, BLOCK_D)
|
| 209 |
+
mask = d_offs < D
|
| 210 |
+
|
| 211 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 212 |
+
other=0.0).to(tl.float32)
|
| 213 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 214 |
+
other=0.0).to(tl.float32)
|
| 215 |
+
|
| 216 |
+
x2 = x * x
|
| 217 |
+
x3 = x2 * x
|
| 218 |
+
|
| 219 |
+
inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
|
| 220 |
+
inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
|
| 221 |
+
inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
|
| 222 |
+
|
| 223 |
+
# Pre-multiply scalar weight * inv_rms to save 1 FMA per element
|
| 224 |
+
w0_inv = w0 * inv_rms_x3
|
| 225 |
+
w1_inv = w1 * inv_rms_x2
|
| 226 |
+
w2_inv = w2 * inv_rms_x
|
| 227 |
+
|
| 228 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 229 |
+
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 230 |
+
else:
|
| 231 |
+
# --- Multi-tile: two-pass approach ---
|
| 232 |
+
sum_x2 = tl.zeros((), dtype=tl.float32)
|
| 233 |
+
sum_x4 = tl.zeros((), dtype=tl.float32)
|
| 234 |
+
sum_x6 = tl.zeros((), dtype=tl.float32)
|
| 235 |
+
|
| 236 |
+
for d_start in range(0, D, BLOCK_D):
|
| 237 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 238 |
+
mask = d_offs < D
|
| 239 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 240 |
+
other=0.0).to(tl.float32)
|
| 241 |
+
x2 = x * x
|
| 242 |
+
sum_x2 += tl.sum(x2)
|
| 243 |
+
sum_x4 += tl.sum(x2 * x2)
|
| 244 |
+
sum_x6 += tl.sum(x2 * x2 * x2)
|
| 245 |
+
|
| 246 |
+
inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
|
| 247 |
+
inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
|
| 248 |
+
inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
|
| 249 |
+
|
| 250 |
+
# Pre-multiply scalar weight * inv_rms
|
| 251 |
+
w0_inv = w0 * inv_rms_x3
|
| 252 |
+
w1_inv = w1 * inv_rms_x2
|
| 253 |
+
w2_inv = w2 * inv_rms_x
|
| 254 |
+
|
| 255 |
+
for d_start in range(0, D, BLOCK_D):
|
| 256 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 257 |
+
mask = d_offs < D
|
| 258 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 259 |
+
other=0.0).to(tl.float32)
|
| 260 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 261 |
+
other=0.0).to(tl.float32)
|
| 262 |
+
x2 = x * x
|
| 263 |
+
x3 = x2 * x
|
| 264 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 265 |
+
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # grad_weight/grad_bias are accumulated with tl.atomic_add, so autotune
    # must zero them between timing runs or trial launches would pollute
    # the final gradients.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,       # per-expert weights, indexed flat as eidx * 3 + k
    bias_ptr,         # per-expert bias, indexed flat as eidx
    offsets_ptr,      # ascending cumulative token counts per expert
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,  # float32 accumulator, written via tl.atomic_add
    grad_bias_ptr,    # float32 accumulator, written via tl.atomic_add
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,       # one shared row stride for all five (N, D) tensors
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over offsets: lo ends as the first index with
    # offsets[lo] > row, i.e. the expert owning this row. The fixed 12
    # iterations resolve up to 2^12 = 4096 experts; once lo == hi the
    # remaining iterations are no-ops.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32 for the math below.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        # other=0.0 keeps masked lanes from contributing to the reductions.
        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure).
        # rms(x^k) needs mean(x^{2k}); all three stats derive from x2.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # output = poly * m, hence dL/dpoly = grad_out * m.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: d poly / d w_k is the corresponding rms-normed power.
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: the reduction term of
        # d rms_norm(x^k)/dx carries inv_rms^2 / D here; the third
        # inv_rms factor is applied when g is assembled below.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = grad_out * (poly + bias)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure):
        # for each power k, k * x^(k-1) * inv_rms is the elementwise part,
        # minus the shared reduction term captured by coeff_*.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into the per-expert parameter gradients.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (same derivation as the single-tile
        # path earlier in this kernel: inv_rms^2 / D reduction terms).
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # grad_mul = grad_out * (poly + bias)
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge for the grouped PolyNorm Triton kernels.

    forward launches one kernel program per token row; backward
    recomputes the forward intermediates from the saved inputs, so
    nothing beyond the inputs themselves is stashed for backward.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        """Compute poly-norm(input) * mul with per-expert weight/bias.

        Args mirror grouped_fused_mul_poly_norm; returns (N, D) output.
        """
        N, D = input.shape
        # The kernels address every tensor densely: rows via a single row
        # stride, weight flat as eidx * 3 + k, bias/offsets flat as eidx.
        # Force contiguity on ALL of them (previously only input/mul were
        # guarded, so a non-contiguous weight/bias/offsets view would have
        # been read at wrong addresses).
        input = input.contiguous()
        mul = mul.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        offsets = offsets.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations, which
        # resolves at most 2^12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous versions so backward uses the same layout.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Parameter grads are accumulated in fp32 via tl.atomic_add
        # regardless of the parameter dtype, then cast back below.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            # All five (N, D) tensors are contiguous, so one stride works.
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is (num_experts, 1); the kernel accumulates (num_experts,).
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # One grad per forward arg; eps/expert_offset/offsets get None.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    # Delegate to the autograd Function so both fwd and bwd run in Triton.
    fn_args = (input, mul, weight, bias, offsets, eps, expert_offset)
    return _GroupedPolyNormFn.apply(*fn_args)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Stub installed when triton is missing; unconditionally raises."""
    raise RuntimeError("Triton is not available. Install triton to use "
                       "grouped_fused_mul_poly_norm.")
|
build/torch28-cxx11-cu126-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch28-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 10756352
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:51ac098828ee90af0d1d17ae75326f89777b1b2e7ef57e00035aed560c434a20
|
| 3 |
size 10756352
|
build/torch28-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon

    Returns:
        (total_tokens, D) - output tensor
    """
    out_dtype = input.dtype

    # Map every token row to its expert: right=True bucketize against the
    # cumulative counts, then apply the caller's expert shift.
    positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token parameters in fp32.
    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    x_sq = x * x
    x_cu = x_sq * x

    def _rn(t):
        # Per-row RMS normalization (inlined helper).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    poly = (per_token_w[:, 0:1] * _rn(x_cu) +
            per_token_w[:, 1:2] * _rn(x_sq) +
            per_token_w[:, 2:3] * _rn(x) + per_token_b)

    return (poly * mul.float()).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
# Autotune search space for the forward kernel: tile width BLOCK_D,
# warp count, and software-pipelining depth (num_stages), selected per
# hidden dimension D (see key=["D"] at the @triton.autotune site).
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    # Widest tiles use a single stage (single-tile path needs no pipelining).
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]
|
| 132 |
+
|
| 133 |
+
# Autotune search space for the backward kernel; same axes as the forward
# table but with deeper pipeline options (the 2-pass backward re-reads
# each row, so extra stages can help hide latency).
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,    # per-expert weights, indexed flat as eidx * 3 + k
    bias_ptr,      # per-expert bias, indexed flat as eidx
    offsets_ptr,   # ascending cumulative token counts per expert
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes out = (w0*rms_norm(x^3) + w1*rms_norm(x^2) + w2*rms_norm(x)
    + b) * mul for the expert owning this row.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts).
    # lo ends as the first index with offsets[lo] > row; once lo == hi the
    # remaining iterations are no-ops.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path: whole row fits in one tile, so loads happen
    # once and values stay in registers across reduction and store. ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        # other=0.0 keeps masked lanes from contributing to the sums.
        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # rms(x^k) needs mean(x^{2k}); all three derive from powers of x2.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1: accumulate the even-power sums needed for the RMS stats.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2: re-read the row, apply the polynomial, gate by mul.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # Autotuning re-runs the kernel; zero the atomic accumulators between
    # trials so benchmarking does not pollute the real gradients.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,   # (num_experts, 3) float32, accumulated via atomic_add
    grad_bias_ptr,     # (num_experts,) float32, accumulated via atomic_add
    N,                 # number of token rows
    D,                 # hidden dimension
    num_experts,
    eps,
    expert_offset,
    stride_row,        # shared row stride; all row tensors are contiguous (T, D)
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

    Recomputes forward intermediates from the saved inputs instead of
    storing them (activation recomputation). Per-expert weight/bias
    gradients are accumulated across rows with tl.atomic_add.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over the cumulative-token offsets to find this row's
    # expert. 12 fixed iterations cover up to 2^12 = 4096 experts
    # (enforced by the host-side assert in _GroupedPolyNormFn.forward).
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert polynomial coefficients (promoted to fp32 for accuracy).
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # When the whole row fits in one BLOCK_D tile, everything stays in
    # registers: one load per tensor instead of two passes over memory.
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dL/d(poly) per element: output = (poly + bias) * m.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: dL/dw_k = <dpoly, rms_norm(x^k)> = inv_rms * <dpoly, x^k>
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input — the "mean subtraction" part of the
        # rms_norm derivative: d(inv_rms)/dx_j contributes
        # -x^k_j * inv_rms^3 * <dpoly, x^k> / D for each power k.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * forward poly (bias added back in, since `poly`
        # here excludes it)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure);
        # the 2x / 3x^2 factors are the chain rule through x^2 / x^3.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (see single-tile path for derivation)
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd wrapper that launches the grouped PolyNorm Triton kernels.

    Forward saves only the (contiguous) inputs; the backward kernel
    recomputes the forward intermediates from them.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        N, D = input.shape
        # Kernels index rows with a single row stride, so force contiguity
        # before taking strides / launching.
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations, which
        # only resolves up to 2^12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        # One program per token row; BLOCK_D/num_warps/num_stages chosen
        # by @triton.autotune keyed on D.
        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous versions so backward can reuse input.stride(0)
        # as the common row stride for every row tensor.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Weight/bias grads accumulate across rows via tl.atomic_add, so
        # they must start zeroed and live in float32 regardless of the
        # parameter dtype (atomics + accuracy).
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        # Cast accumulators back to the parameter dtypes; bias grad is
        # reshaped (E,) -> (E, 1) to match the (num_experts, 1) parameter.
        grad_weight = grad_weight.to(weight.dtype)
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # None for the non-tensor args (offsets, eps, expert_offset).
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Thin public wrapper around the autograd Function.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    args = (input, mul, weight, bias, offsets, eps, expert_offset)
    return _GroupedPolyNormFn.apply(*args)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when Triton could not be imported.

    Mirrors the Triton-backed signature but always raises, so callers get
    an actionable error instead of an AttributeError.
    """
    message = ("Triton is not available. Install triton to use "
               "grouped_fused_mul_poly_norm.")
    raise RuntimeError(message)
|
build/torch28-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch28-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15804360
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7edb027993454d74da9632c7368edc2b0526b5f1ef33ae9e790d49bdf7285640
|
| 3 |
size 15804360
|
build/torch28-cxx11-cu128-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert with torch.bucketize, gathers that
    expert's polynomial coefficients, and evaluates the whole batch in
    fp32 before casting back. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    out_dtype = input.dtype

    # right=True picks the first offset strictly greater than the token
    # position, matching the kernel's binary search over cumulative counts.
    def _norm(t: Tensor) -> Tensor:
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(positions, offsets, right=True) + expert_offset

    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    poly = (per_token_w[:, 0:1] * _norm(x_cu) +
            per_token_w[:, 1:2] * _norm(x_sq) +
            per_token_w[:, 2:3] * _norm(x) + per_token_b)

    return (poly * gate).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
# Candidate launch configurations for @triton.autotune (keyed on D).
# BLOCK_D is the tile width over the hidden dimension; num_stages controls
# software pipelining of the multi-tile loops. The BLOCK_D=2048 entries use
# num_stages=1 — presumably to bound shared-memory/register use at the
# largest tile size (TODO confirm on target hardware).
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Backward search space: similar shape, but the backward kernel is heavier
# (more live accumulators), so the set differs slightly from the forward one.
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,    # (num_experts, 3) row-major: [w_x3, w_x2, w_x] per expert
    bias_ptr,      # (num_experts,) effectively — indexed by flat expert id
    offsets_ptr,   # (num_experts,) cumulative token counts, sorted ascending
    output_ptr,
    N,             # number of token rows
    D,             # hidden dimension
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes, per row with its expert's (w0, w1, w2, b):
        poly = w0 * rms_norm(x^3) + w1 * rms_norm(x^2) + w2 * rms_norm(x) + b
        out  = poly * m
    All arithmetic is performed in fp32 and stored back in the output dtype.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts)
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # Whole row fits in one tile: load once, reduce and write from registers.
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Three RMS statistics from one set of loads; other=0.0 keeps the
        # masked lanes out of the sums.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1 accumulates the three variance sums across tiles.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2 re-reads the input (activation recompute) and writes output.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # The kernel accumulates into grad_weight/grad_bias with atomic adds, so
    # autotuning must zero them between timing runs to keep results correct.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,   # (num_experts, 3) fp32 accumulator, pre-zeroed by caller
    grad_bias_ptr,     # (num_experts,) fp32 accumulator, pre-zeroed by caller
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,        # shared row stride: all row tensors assumed contiguous
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over the cumulative token offsets to find which expert
    # owns this row. 12 fixed iterations cover up to 4096 experts (checked
    # by an assert on the host side). Finds first index with offsets[i] > row.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32 for stable math.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # When the whole row fits in one tile, keep x/m/go in registers for the
    # entire reduction + gradient computation (no second read of the row).
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dL/dpoly: the forward output was poly * m.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: d(poly)/d(w_k) = rms_norm(x^k), so the grad is the
        # dot of dpoly with x^k scaled by the matching inverse RMS.
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: the mean-dependent correction term of
        # each rms_norm derivative (chain rule through 1/sqrt(mean(.)+eps)).
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * poly (forward recomputation; nothing was saved).
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure):
        # one term per power of x, with factors 1/2/3 from d(x^k)/dx.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation of per-expert parameter grads.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad (single sweep merges
        # both kinds of reductions to halve memory traffic vs. 3 passes).
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (mean-dependent rms_norm derivative term)
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul (re-reads the row; forward values are
        # recomputed from x rather than stored).
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge between PyTorch and the fused grouped PolyNorm kernels."""

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        """Launch the fused forward kernel, one Triton program per token row.

        Saves only the (contiguous) inputs for backward; intermediates are
        recomputed in the backward kernel (activation recomputation).
        """
        N, D = input.shape
        # Kernels address rows with a single row stride, so both row tensors
        # must be contiguous before launch.
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations, which only
        # covers 2**12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Launch the fused backward kernel and cast grads to parameter dtypes.

        grad_weight / grad_bias are accumulated in float32 by atomic adds
        inside the kernel (one add per row), then cast at the end.
        """
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # fp32 zero-initialized accumulators: the kernel only ever adds into
        # these, so they must start at zero.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            # All five row tensors are contiguous (see forward), so a single
            # shared row stride is sufficient.
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is documented as (num_experts, 1); restore the trailing dim.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # None for offsets (no grad) and the non-tensor args eps/expert_offset.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    # Route through the autograd.Function so both forward and backward use
    # the fused Triton kernels.
    return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
                                    expert_offset)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub installed when Triton is not importable.

    Keeps the public API importable but always raises, since the operation
    is only implemented as Triton kernels.
    """
    message = ("Triton is not available. Install triton to use "
               "grouped_fused_mul_poly_norm.")
    raise RuntimeError(message)
|
build/torch28-cxx11-cu129-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch28-cxx11-cu129-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15795640
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f06fb9594dcf9c0bc3a8af619fec3a541b775f0dad304a7978314d32fae8d244
|
| 3 |
size 15795640
|
build/torch28-cxx11-cu129-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon

    Returns:
        (total_tokens, D) - output tensor
    """
    out_dtype = input.dtype

    # right=True: token row r maps to the first expert whose cumulative
    # offset exceeds r (offsets[e-1] <= r < offsets[e]).
    positions = torch.arange(input.shape[0], device=input.device)
    expert_ids = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token expert parameters, promoted to fp32 for stability.
    per_token_w = weight.float()[expert_ids]
    per_token_b = bias.float()[expert_ids]

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    def _norm(t: Tensor) -> Tensor:
        # Per-row RMS normalization: t / sqrt(mean(t^2, dim=-1) + eps)
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    poly = (per_token_w[:, 0:1] * _norm(x_cu)
            + per_token_w[:, 1:2] * _norm(x_sq)
            + per_token_w[:, 2:3] * _norm(x)
            + per_token_b)

    return (poly * gate).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
# Candidate (BLOCK_D, num_warps, num_stages) combinations tried by
# @triton.autotune, keyed on the hidden dimension D. Larger BLOCK_D favors
# the single-tile fast path; num_stages > 1 pipelines the multi-tile loops.
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Backward search space differs slightly (deeper pipelining options for the
# heavier 2-pass kernel, no num_warps=4 at BLOCK_D=2048).
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,    # (num_experts, 3) flattened, row-major
    bias_ptr,      # (num_experts,) effective layout (one scalar per expert)
    offsets_ptr,   # cumulative tokens per expert, ascending
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes, per row:
        poly = w0 * rms_norm(x^3) + w1 * rms_norm(x^2) + w2 * rms_norm(x) + b
        out  = poly * mul
    All math is done in fp32 regardless of the input dtype.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts):
    # finds the first index whose cumulative offset exceeds this row.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters in fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # Whole row fits in one tile: reduce and write without re-reading x.
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # rms_norm(x^k) denominators: mean of x^(2k) (x4/x6 sums inlined).
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1 reduces the three power sums; pass 2 re-reads x to write out.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,
    grad_bias_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over offsets: `lo` ends as the count of offsets <= row,
    # i.e. the expert owning this row. Fixed 12 iterations covers up to
    # 2^12 = 4096 experts (matches the assert in the Python launcher).
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32 for stable math.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    # A single stride_row is shared by all five tensors — the launcher passes
    # contiguous (N, D) tensors, so their row strides coincide.
    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiplied weight * inv_rms, used to recompute poly below.
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dL/d(poly): upstream grad times the elementwise gate.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads: d(poly)/d(w_k) is the corresponding rms-normalized term.
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: per-row scalars from the
        # d(inv_rms)/dx correction term of each rms_norm derivative.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = grad_out * (poly + bias)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure)
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row accumulation into per-expert parameter grads.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd wrapper dispatching to the grouped PolyNorm Triton kernels."""

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        # input/mul: (N, D); weight: (num_experts, 3); bias: (num_experts, 1);
        # offsets: (num_experts,) cumulative token counts per expert.
        N, D = input.shape
        # The kernels index rows with a single stride; normalize layouts.
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations, so it can
        # only resolve up to 2^12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        # One program per row; each program tiles over D internally.
        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save (contiguous) inputs only; backward recomputes intermediates.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # fp32 accumulators: the kernel accumulates weight/bias grads across
        # rows with tl.atomic_add, so they must start zeroed and in fp32.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        # NOTE(review): one stride_row is shared by grad_output/input/mul and
        # both dense grads — valid here because all are contiguous (N, D).
        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is (num_experts, 1): restore the trailing singleton dim.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # No grads for offsets, eps, expert_offset.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Applies a per-expert PolyNorm to every token row of ``input`` and
    multiplies the result elementwise by ``mul``; the row-to-expert mapping
    is derived from ``offsets``.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    result = _GroupedPolyNormFn.apply(
        input, mul, weight, bias, offsets, eps, expert_offset)
    return result
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when Triton could not be imported: always raises."""
    message = ("Triton is not available. Install triton to use "
               "grouped_fused_mul_poly_norm.")
    raise RuntimeError(message)
|
build/torch28-cxx11-rocm63-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch28-cxx11-rocm63-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2788456
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:448ab4d8a859725a3200d95d2164a3fe261f67b20e834f4f7062485cf729cf88
|
| 3 |
size 2788456
|
build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps tokens to experts with torch.bucketize, then evaluates the PolyNorm
    polynomial for every token at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon

    Returns:
        (total_tokens, D) - output tensor
    """
    out_dtype = input.dtype

    def _normalize(t: Tensor) -> Tensor:
        # Row-wise RMS normalization: t / sqrt(mean(t^2) + eps).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    # right=True makes bucketize count offsets <= position, matching the
    # kernel's binary search over cumulative token counts.
    positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather each token's expert parameters (fp32 for the math).
    token_w = weight.float()[expert_idx]
    token_b = bias.float()[expert_idx]

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    poly = (token_w[:, 0:1] * _normalize(x_cu)
            + token_w[:, 1:2] * _normalize(x_sq)
            + token_w[:, 2:3] * _normalize(x)
            + token_b)

    return (poly * gate).to(out_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
# Autotune search space for the forward kernel. BLOCK_D is the tile width
# along the hidden dimension; num_stages > 1 enables software pipelining in
# the multi-tile loops (the BLOCK_D=2048 entries use a single stage).
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Autotune search space for the backward kernel. Same shape as the forward
# list, with a few deeper-pipeline (num_stages up to 5) variants since the
# backward does two passes over the row.
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    Computes, per row with expert parameters (w0, w1, w2, b):
        out = (w0 * rms_norm(x^3) + w1 * rms_norm(x^2)
               + w2 * rms_norm(x) + b) * mul
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts)
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32 for the reduction math.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path: whole row fits in one block, loaded once ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Three inverse-RMS scalars from one shared set of power sums
        # (masked lanes contribute 0 to each sum).
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1: accumulate the power sums needed for the three RMS stats.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2: reload each tile and write the gated polynomial.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
|
| 268 |
+
configs=_GROUPED_POLYNORM_BWD_CONFIGS,
|
| 269 |
+
key=["D"],
|
| 270 |
+
reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
|
| 271 |
+
)
|
| 272 |
+
@triton.jit
|
| 273 |
+
def _grouped_polynorm_bwd_kernel(
|
| 274 |
+
grad_out_ptr,
|
| 275 |
+
input_ptr,
|
| 276 |
+
mul_ptr,
|
| 277 |
+
weight_ptr,
|
| 278 |
+
bias_ptr,
|
| 279 |
+
offsets_ptr,
|
| 280 |
+
grad_input_ptr,
|
| 281 |
+
grad_mul_ptr,
|
| 282 |
+
grad_weight_ptr,
|
| 283 |
+
grad_bias_ptr,
|
| 284 |
+
N,
|
| 285 |
+
D,
|
| 286 |
+
num_experts,
|
| 287 |
+
eps,
|
| 288 |
+
expert_offset,
|
| 289 |
+
stride_row,
|
| 290 |
+
BLOCK_D: tl.constexpr,
|
| 291 |
+
):
|
| 292 |
+
"""Backward kernel: one program per row, 2-pass approach.
|
| 293 |
+
|
| 294 |
+
Pass 1: RMS stats + dot products + bias grad
|
| 295 |
+
Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
|
| 296 |
+
"""
|
| 297 |
+
row = tl.program_id(0)
|
| 298 |
+
if row >= N:
|
| 299 |
+
return
|
| 300 |
+
|
| 301 |
+
lo = 0
|
| 302 |
+
hi = num_experts
|
| 303 |
+
for _ in range(12):
|
| 304 |
+
if lo < hi:
|
| 305 |
+
mid = (lo + hi) // 2
|
| 306 |
+
if tl.load(offsets_ptr + mid) <= row:
|
| 307 |
+
lo = mid + 1
|
| 308 |
+
else:
|
| 309 |
+
hi = mid
|
| 310 |
+
eidx = lo + expert_offset
|
| 311 |
+
|
| 312 |
+
w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
|
| 313 |
+
w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
|
| 314 |
+
w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
|
| 315 |
+
b_val = tl.load(bias_ptr + eidx).to(tl.float32)
|
| 316 |
+
|
| 317 |
+
input_row_ptr = input_ptr + row * stride_row
|
| 318 |
+
mul_row_ptr = mul_ptr + row * stride_row
|
| 319 |
+
grad_out_row_ptr = grad_out_ptr + row * stride_row
|
| 320 |
+
grad_input_row_ptr = grad_input_ptr + row * stride_row
|
| 321 |
+
grad_mul_row_ptr = grad_mul_ptr + row * stride_row
|
| 322 |
+
|
| 323 |
+
D_float = D.to(tl.float32)
|
| 324 |
+
|
| 325 |
+
# --- Single-tile path ---
|
| 326 |
+
if D <= BLOCK_D:
|
| 327 |
+
d_offs = tl.arange(0, BLOCK_D)
|
| 328 |
+
mask = d_offs < D
|
| 329 |
+
|
| 330 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 331 |
+
other=0.0).to(tl.float32)
|
| 332 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 333 |
+
other=0.0).to(tl.float32)
|
| 334 |
+
go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
|
| 335 |
+
other=0.0).to(tl.float32)
|
| 336 |
+
|
| 337 |
+
x2 = x * x
|
| 338 |
+
x3 = x2 * x
|
| 339 |
+
|
| 340 |
+
# Compute RMS stats (x4 inlined to reduce register pressure)
|
| 341 |
+
inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
|
| 342 |
+
inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
|
| 343 |
+
inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
|
| 344 |
+
|
| 345 |
+
w0_inv = w0 * inv_rms_x3
|
| 346 |
+
w1_inv = w1 * inv_rms_x2
|
| 347 |
+
w2_inv = w2 * inv_rms_x
|
| 348 |
+
|
| 349 |
+
dpoly = go * m
|
| 350 |
+
|
| 351 |
+
# Dot products for coefficients and weight grads
|
| 352 |
+
sum_dpoly_x = tl.sum(dpoly * x)
|
| 353 |
+
sum_dpoly_x2 = tl.sum(dpoly * x2)
|
| 354 |
+
sum_dpoly_x3 = tl.sum(dpoly * x3)
|
| 355 |
+
grad_b_acc = tl.sum(dpoly)
|
| 356 |
+
|
| 357 |
+
# Weight grads
|
| 358 |
+
grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
|
| 359 |
+
grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
|
| 360 |
+
grad_w2_acc = inv_rms_x * sum_dpoly_x
|
| 361 |
+
|
| 362 |
+
# Coefficients for grad_input
|
| 363 |
+
coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
|
| 364 |
+
coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
|
| 365 |
+
coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
|
| 366 |
+
|
| 367 |
+
# grad_mul
|
| 368 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
|
| 369 |
+
tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
|
| 370 |
+
|
| 371 |
+
# grad_input (in-place accumulation to reduce register pressure)
|
| 372 |
+
g = inv_rms_x * (w2 * dpoly - x * coeff_x)
|
| 373 |
+
g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
|
| 374 |
+
g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
|
| 375 |
+
tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
|
| 376 |
+
|
| 377 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
|
| 378 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
|
| 379 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
|
| 380 |
+
tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 381 |
+
else:
|
| 382 |
+
# --- Multi-tile: 2-pass ---
|
| 383 |
+
# Pass 1: RMS stats + dot products + bias grad
|
| 384 |
+
sum_x2 = tl.zeros((), dtype=tl.float32)
|
| 385 |
+
sum_x4 = tl.zeros((), dtype=tl.float32)
|
| 386 |
+
sum_x6 = tl.zeros((), dtype=tl.float32)
|
| 387 |
+
sum_dpoly_x = tl.zeros((), dtype=tl.float32)
|
| 388 |
+
sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
|
| 389 |
+
sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
|
| 390 |
+
grad_b_acc = tl.zeros((), dtype=tl.float32)
|
| 391 |
+
|
| 392 |
+
for d_start in range(0, D, BLOCK_D):
|
| 393 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 394 |
+
mask = d_offs < D
|
| 395 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 396 |
+
other=0.0).to(tl.float32)
|
| 397 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 398 |
+
other=0.0).to(tl.float32)
|
| 399 |
+
go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
|
| 400 |
+
other=0.0).to(tl.float32)
|
| 401 |
+
|
| 402 |
+
x2 = x * x
|
| 403 |
+
x3 = x2 * x
|
| 404 |
+
dpoly = go * m
|
| 405 |
+
|
| 406 |
+
sum_x2 += tl.sum(x2)
|
| 407 |
+
sum_x4 += tl.sum(x2 * x2)
|
| 408 |
+
sum_x6 += tl.sum(x2 * x2 * x2)
|
| 409 |
+
sum_dpoly_x += tl.sum(dpoly * x)
|
| 410 |
+
sum_dpoly_x2 += tl.sum(dpoly * x2)
|
| 411 |
+
sum_dpoly_x3 += tl.sum(dpoly * x3)
|
| 412 |
+
grad_b_acc += tl.sum(dpoly)
|
| 413 |
+
|
| 414 |
+
inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
|
| 415 |
+
inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
|
| 416 |
+
inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
|
| 417 |
+
|
| 418 |
+
w0_inv = w0 * inv_rms_x3
|
| 419 |
+
w1_inv = w1 * inv_rms_x2
|
| 420 |
+
w2_inv = w2 * inv_rms_x
|
| 421 |
+
|
| 422 |
+
# Weight grads
|
| 423 |
+
grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
|
| 424 |
+
grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
|
| 425 |
+
grad_w2_acc = inv_rms_x * sum_dpoly_x
|
| 426 |
+
|
| 427 |
+
# Coefficients for grad_input
|
| 428 |
+
coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
|
| 429 |
+
coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
|
| 430 |
+
coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
|
| 431 |
+
|
| 432 |
+
# Pass 2: grad_input + grad_mul
|
| 433 |
+
for d_start in range(0, D, BLOCK_D):
|
| 434 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 435 |
+
mask = d_offs < D
|
| 436 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 437 |
+
other=0.0).to(tl.float32)
|
| 438 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 439 |
+
other=0.0).to(tl.float32)
|
| 440 |
+
go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
|
| 441 |
+
other=0.0).to(tl.float32)
|
| 442 |
+
|
| 443 |
+
x2 = x * x
|
| 444 |
+
x3 = x2 * x
|
| 445 |
+
|
| 446 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
|
| 447 |
+
tl.store(grad_mul_row_ptr + d_offs,
|
| 448 |
+
go * (poly + b_val),
|
| 449 |
+
mask=mask)
|
| 450 |
+
|
| 451 |
+
dpoly = go * m
|
| 452 |
+
g = inv_rms_x * (w2 * dpoly - x * coeff_x)
|
| 453 |
+
g += (2.0 * x * inv_rms_x2 *
|
| 454 |
+
(w1 * dpoly - x2 * coeff_x2))
|
| 455 |
+
g += (3.0 * x2 * inv_rms_x3 *
|
| 456 |
+
(w0 * dpoly - x3 * coeff_x3))
|
| 457 |
+
|
| 458 |
+
tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
|
| 459 |
+
|
| 460 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
|
| 461 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
|
| 462 |
+
tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
|
| 463 |
+
tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
|
| 466 |
+
|
| 467 |
+
@staticmethod
|
| 468 |
+
def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
|
| 469 |
+
N, D = input.shape
|
| 470 |
+
input = input.contiguous()
|
| 471 |
+
mul = mul.contiguous()
|
| 472 |
+
output = torch.empty_like(input)
|
| 473 |
+
|
| 474 |
+
num_experts = offsets.shape[0]
|
| 475 |
+
assert num_experts <= 4096, (
|
| 476 |
+
f"Supports at most 4096 experts, got {num_experts}.")
|
| 477 |
+
|
| 478 |
+
_grouped_polynorm_fwd_kernel[(N,)](
|
| 479 |
+
input,
|
| 480 |
+
mul,
|
| 481 |
+
weight,
|
| 482 |
+
bias,
|
| 483 |
+
offsets,
|
| 484 |
+
output,
|
| 485 |
+
N,
|
| 486 |
+
D,
|
| 487 |
+
num_experts,
|
| 488 |
+
eps,
|
| 489 |
+
expert_offset,
|
| 490 |
+
stride_input_row=input.stride(0),
|
| 491 |
+
stride_mul_row=mul.stride(0),
|
| 492 |
+
stride_out_row=output.stride(0),
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
ctx.save_for_backward(input, mul, weight, bias, offsets)
|
| 496 |
+
ctx.eps = eps
|
| 497 |
+
ctx.expert_offset = expert_offset
|
| 498 |
+
return output
|
| 499 |
+
|
| 500 |
+
@staticmethod
|
| 501 |
+
def backward(ctx, grad_output):
|
| 502 |
+
input, mul, weight, bias, offsets = ctx.saved_tensors
|
| 503 |
+
eps = ctx.eps
|
| 504 |
+
expert_offset = ctx.expert_offset
|
| 505 |
+
N, D = input.shape
|
| 506 |
+
|
| 507 |
+
grad_output = grad_output.contiguous()
|
| 508 |
+
grad_input = torch.empty_like(input)
|
| 509 |
+
grad_mul = torch.empty_like(mul)
|
| 510 |
+
grad_weight = torch.zeros(weight.shape[0],
|
| 511 |
+
3,
|
| 512 |
+
device=weight.device,
|
| 513 |
+
dtype=torch.float32)
|
| 514 |
+
grad_bias = torch.zeros(bias.shape[0],
|
| 515 |
+
device=bias.device,
|
| 516 |
+
dtype=torch.float32)
|
| 517 |
+
|
| 518 |
+
num_experts = offsets.shape[0]
|
| 519 |
+
|
| 520 |
+
_grouped_polynorm_bwd_kernel[(N,)](
|
| 521 |
+
grad_output,
|
| 522 |
+
input,
|
| 523 |
+
mul,
|
| 524 |
+
weight,
|
| 525 |
+
bias,
|
| 526 |
+
offsets,
|
| 527 |
+
grad_input,
|
| 528 |
+
grad_mul,
|
| 529 |
+
grad_weight,
|
| 530 |
+
grad_bias,
|
| 531 |
+
N,
|
| 532 |
+
D,
|
| 533 |
+
num_experts,
|
| 534 |
+
eps,
|
| 535 |
+
expert_offset,
|
| 536 |
+
stride_row=input.stride(0),
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
grad_weight = grad_weight.to(weight.dtype)
|
| 540 |
+
grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
|
| 541 |
+
|
| 542 |
+
return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
|
| 545 |
+
input: Tensor,
|
| 546 |
+
mul: Tensor,
|
| 547 |
+
weight: Tensor,
|
| 548 |
+
bias: Tensor,
|
| 549 |
+
offsets: Tensor,
|
| 550 |
+
eps: float = 1e-6,
|
| 551 |
+
expert_offset: int = 0,
|
| 552 |
+
) -> Tensor:
|
| 553 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm.
|
| 554 |
+
|
| 555 |
+
Args:
|
| 556 |
+
input: (total_tokens, D) - concatenated tokens for all experts
|
| 557 |
+
mul: (total_tokens, D) - gate values to multiply with
|
| 558 |
+
weight: (num_experts, 3) - per-expert polynomial weights
|
| 559 |
+
bias: (num_experts, 1) - per-expert polynomial bias
|
| 560 |
+
offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
|
| 561 |
+
eps: numerical stability epsilon
|
| 562 |
+
expert_offset: offset to add to expert index
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
(total_tokens, D) - output tensor
|
| 566 |
+
"""
|
| 567 |
+
return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
|
| 568 |
+
expert_offset)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
|
| 573 |
+
input: Tensor,
|
| 574 |
+
mul: Tensor,
|
| 575 |
+
weight: Tensor,
|
| 576 |
+
bias: Tensor,
|
| 577 |
+
offsets: Tensor,
|
| 578 |
+
eps: float = 1e-6,
|
| 579 |
+
expert_offset: int = 0,
|
| 580 |
+
) -> Tensor:
|
| 581 |
+
raise RuntimeError(
|
| 582 |
+
"Triton is not available. Install triton to use "
|
| 583 |
+
"grouped_fused_mul_poly_norm.")
|
build/torch28-cxx11-rocm64-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch28-cxx11-rocm64-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 2794152
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:388f461d91124b99544cbdbd4dc4d98f24c010d7a0dc2e9389648860a809a51a
|
| 3 |
size 2794152
|
build/torch28-cxx11-rocm64-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
|
| 57 |
+
input: Tensor,
|
| 58 |
+
mul: Tensor,
|
| 59 |
+
weight: Tensor,
|
| 60 |
+
bias: Tensor,
|
| 61 |
+
offsets: Tensor,
|
| 62 |
+
eps: float = 1e-6,
|
| 63 |
+
expert_offset: int = 0,
|
| 64 |
+
) -> Tensor:
|
| 65 |
+
"""PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
|
| 66 |
+
|
| 67 |
+
Uses torch.bucketize to map tokens to experts, then computes PolyNorm
|
| 68 |
+
for all tokens at once. torch.compile friendly.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
input: (total_tokens, D) - concatenated tokens for all experts
|
| 72 |
+
mul: (total_tokens, D) - gate values to multiply with
|
| 73 |
+
weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
|
| 74 |
+
bias: (num_experts, 1) - per-expert polynomial bias
|
| 75 |
+
offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
|
| 76 |
+
eps: numerical stability epsilon
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
(total_tokens, D) - output tensor
|
| 80 |
+
"""
|
| 81 |
+
orig_dtype = input.dtype
|
| 82 |
+
|
| 83 |
+
token_positions = torch.arange(input.shape[0], device=input.device)
|
| 84 |
+
expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
|
| 85 |
+
|
| 86 |
+
weight_fp32 = weight.float()
|
| 87 |
+
bias_fp32 = bias.float()
|
| 88 |
+
|
| 89 |
+
per_token_w = weight_fp32[expert_idx]
|
| 90 |
+
per_token_b = bias_fp32[expert_idx]
|
| 91 |
+
|
| 92 |
+
x = input.float()
|
| 93 |
+
m = mul.float()
|
| 94 |
+
|
| 95 |
+
x2 = x * x
|
| 96 |
+
x3 = x2 * x
|
| 97 |
+
|
| 98 |
+
poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
|
| 99 |
+
per_token_w[:, 1:2] * _rms_norm(x2, eps) +
|
| 100 |
+
per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
|
| 101 |
+
|
| 102 |
+
return (poly * m).to(orig_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
|
| 110 |
+
_GROUPED_POLYNORM_FWD_CONFIGS = [
|
| 111 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 112 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 113 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 114 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 115 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 116 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 117 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 118 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 119 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 120 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 121 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 122 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 123 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 124 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 125 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 126 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 127 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
|
| 128 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 129 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 130 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
_GROUPED_POLYNORM_BWD_CONFIGS = [
|
| 134 |
+
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 135 |
+
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 136 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 137 |
+
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 138 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 139 |
+
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 140 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 141 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 142 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 143 |
+
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
|
| 144 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 145 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 146 |
+
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 147 |
+
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 148 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 149 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 150 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
|
| 151 |
+
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 152 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 153 |
+
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 154 |
+
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
|
| 157 |
+
configs=_GROUPED_POLYNORM_FWD_CONFIGS,
|
| 158 |
+
key=["D"],
|
| 159 |
+
)
|
| 160 |
+
@triton.jit
|
| 161 |
+
def _grouped_polynorm_fwd_kernel(
|
| 162 |
+
input_ptr,
|
| 163 |
+
mul_ptr,
|
| 164 |
+
weight_ptr,
|
| 165 |
+
bias_ptr,
|
| 166 |
+
offsets_ptr,
|
| 167 |
+
output_ptr,
|
| 168 |
+
N,
|
| 169 |
+
D,
|
| 170 |
+
num_experts,
|
| 171 |
+
eps,
|
| 172 |
+
expert_offset,
|
| 173 |
+
stride_input_row,
|
| 174 |
+
stride_mul_row,
|
| 175 |
+
stride_out_row,
|
| 176 |
+
BLOCK_D: tl.constexpr,
|
| 177 |
+
):
|
| 178 |
+
"""Forward kernel: one program per row."""
|
| 179 |
+
row = tl.program_id(0)
|
| 180 |
+
if row >= N:
|
| 181 |
+
return
|
| 182 |
+
|
| 183 |
+
# Binary search for expert index (12 iters covers up to 4096 experts)
|
| 184 |
+
lo = 0
|
| 185 |
+
hi = num_experts
|
| 186 |
+
for _ in range(12):
|
| 187 |
+
if lo < hi:
|
| 188 |
+
mid = (lo + hi) // 2
|
| 189 |
+
if tl.load(offsets_ptr + mid) <= row:
|
| 190 |
+
lo = mid + 1
|
| 191 |
+
else:
|
| 192 |
+
hi = mid
|
| 193 |
+
eidx = lo + expert_offset
|
| 194 |
+
|
| 195 |
+
w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
|
| 196 |
+
w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
|
| 197 |
+
w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
|
| 198 |
+
b = tl.load(bias_ptr + eidx).to(tl.float32)
|
| 199 |
+
|
| 200 |
+
input_row_ptr = input_ptr + row * stride_input_row
|
| 201 |
+
mul_row_ptr = mul_ptr + row * stride_mul_row
|
| 202 |
+
out_row_ptr = output_ptr + row * stride_out_row
|
| 203 |
+
|
| 204 |
+
D_float = D.to(tl.float32)
|
| 205 |
+
|
| 206 |
+
# --- Single-tile path ---
|
| 207 |
+
if D <= BLOCK_D:
|
| 208 |
+
d_offs = tl.arange(0, BLOCK_D)
|
| 209 |
+
mask = d_offs < D
|
| 210 |
+
|
| 211 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 212 |
+
other=0.0).to(tl.float32)
|
| 213 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 214 |
+
other=0.0).to(tl.float32)
|
| 215 |
+
|
| 216 |
+
x2 = x * x
|
| 217 |
+
x3 = x2 * x
|
| 218 |
+
|
| 219 |
+
inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
|
| 220 |
+
inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
|
| 221 |
+
inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
|
| 222 |
+
|
| 223 |
+
# Pre-multiply scalar weight * inv_rms to save 1 FMA per element
|
| 224 |
+
w0_inv = w0 * inv_rms_x3
|
| 225 |
+
w1_inv = w1 * inv_rms_x2
|
| 226 |
+
w2_inv = w2 * inv_rms_x
|
| 227 |
+
|
| 228 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 229 |
+
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 230 |
+
else:
|
| 231 |
+
# --- Multi-tile: two-pass approach ---
|
| 232 |
+
sum_x2 = tl.zeros((), dtype=tl.float32)
|
| 233 |
+
sum_x4 = tl.zeros((), dtype=tl.float32)
|
| 234 |
+
sum_x6 = tl.zeros((), dtype=tl.float32)
|
| 235 |
+
|
| 236 |
+
for d_start in range(0, D, BLOCK_D):
|
| 237 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 238 |
+
mask = d_offs < D
|
| 239 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 240 |
+
other=0.0).to(tl.float32)
|
| 241 |
+
x2 = x * x
|
| 242 |
+
sum_x2 += tl.sum(x2)
|
| 243 |
+
sum_x4 += tl.sum(x2 * x2)
|
| 244 |
+
sum_x6 += tl.sum(x2 * x2 * x2)
|
| 245 |
+
|
| 246 |
+
inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
|
| 247 |
+
inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
|
| 248 |
+
inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
|
| 249 |
+
|
| 250 |
+
# Pre-multiply scalar weight * inv_rms
|
| 251 |
+
w0_inv = w0 * inv_rms_x3
|
| 252 |
+
w1_inv = w1 * inv_rms_x2
|
| 253 |
+
w2_inv = w2 * inv_rms_x
|
| 254 |
+
|
| 255 |
+
for d_start in range(0, D, BLOCK_D):
|
| 256 |
+
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 257 |
+
mask = d_offs < D
|
| 258 |
+
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 259 |
+
other=0.0).to(tl.float32)
|
| 260 |
+
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 261 |
+
other=0.0).to(tl.float32)
|
| 262 |
+
x2 = x * x
|
| 263 |
+
x3 = x2 * x
|
| 264 |
+
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 265 |
+
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,
    grad_bias_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row.

    Two-pass scheme:
      1) accumulate the RMS statistics, the dpoly dot products and the
         bias gradient for the row;
      2) emit grad_input / grad_mul element-wise, then flush the
         per-expert weight/bias gradients with tl.atomic_add
         (cross-row accumulation into fp32 buffers).
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Map the row to its expert via binary search over the cumulative
    # token offsets; 12 fixed iterations cover up to 4096 experts.
    left = 0
    right = num_experts
    for _ in range(12):
        if left < right:
            probe = (left + right) // 2
            if tl.load(offsets_ptr + probe) <= row:
                left = probe + 1
            else:
                right = probe
    eidx = left + expert_offset

    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    if D <= BLOCK_D:
        # --- Single-tile path: the whole row stays in registers ---
        cols = tl.arange(0, BLOCK_D)
        in_bounds = cols < D

        x = tl.load(input_row_ptr + cols, mask=in_bounds,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + cols, mask=in_bounds,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + cols, mask=in_bounds,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # RMS statistics (higher powers inlined to limit register use).
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Fold the scalar weights into the inverse-RMS factors.
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        dpoly = go * m

        # Row-level dot products feeding both weight grads and the
        # mean-correction coefficients below.
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = grad_out * poly(x)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + cols, go * (poly + b_val),
                 mask=in_bounds)

        # grad_input, accumulated in place to keep register pressure low.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + cols, g, mask=in_bounds)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile path ---
        # Pass 1: fused reduction of RMS stats, dot products, bias grad.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for tile in range(0, D, BLOCK_D):
            cols = tile + tl.arange(0, BLOCK_D)
            in_bounds = cols < D
            x = tl.load(input_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + cols, mask=in_bounds,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: element-wise grad_input and grad_mul.
        for tile in range(0, D, BLOCK_D):
            cols = tile + tl.arange(0, BLOCK_D)
            in_bounds = cols < D
            x = tl.load(input_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + cols, mask=in_bounds,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + cols, go * (poly + b_val),
                     mask=in_bounds)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + cols, g, mask=in_bounds)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge for the grouped fused-mul PolyNorm Triton kernels.

    forward launches one kernel program per token row; backward recomputes
    the forward intermediates from the saved inputs (activation
    recomputation) and accumulates per-expert weight/bias gradients in
    float32 via atomics.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        n_rows, dim = input.shape
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The in-kernel binary search runs a fixed 12 iterations,
        # which only covers 2**12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(n_rows,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            n_rows,
            dim,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        n_rows, dim = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Parameter grads accumulate in fp32 buffers (atomic_add targets);
        # they are cast back to the parameter dtype after the launch.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(n_rows,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            n_rows,
            dim,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias is (num_experts, 1): restore the trailing singleton dim.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # One gradient per forward arg; eps/expert_offset/offsets get None.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Thin functional wrapper over the autograd Function.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to the expert index when indexing
            weight/bias

    Returns:
        (total_tokens, D) - output tensor
    """
    return _GroupedPolyNormFn.apply(
        input, mul, weight, bias, offsets, eps, expert_offset)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
|
| 573 |
+
input: Tensor,
|
| 574 |
+
mul: Tensor,
|
| 575 |
+
weight: Tensor,
|
| 576 |
+
bias: Tensor,
|
| 577 |
+
offsets: Tensor,
|
| 578 |
+
eps: float = 1e-6,
|
| 579 |
+
expert_offset: int = 0,
|
| 580 |
+
) -> Tensor:
|
| 581 |
+
raise RuntimeError(
|
| 582 |
+
"Triton is not available. Install triton to use "
|
| 583 |
+
"grouped_fused_mul_poly_norm.")
|
build/torch29-cxx11-cu126-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch29-cxx11-cu126-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 10756320
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0c5304ceac9171f76c03792dfb9b7e8299ba2f2885983c39031546fde8f61f8b
|
| 3 |
size 10756320
|
build/torch29-cxx11-cu126-x86_64-linux/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _activation_0e6f27f_dirty
|
| 3 |
+
ops = torch.ops._activation_0e6f27f_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_activation_0e6f27f_dirty::{op_name}"
|
build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
|
| 2 |
+
|
| 3 |
+
Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
|
| 4 |
+
eliminating multiple intermediate tensors and kernel launches.
|
| 5 |
+
|
| 6 |
+
PolyNorm formula (per row):
|
| 7 |
+
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul
|
| 9 |
+
|
| 10 |
+
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
+
|
| 12 |
+
Performance optimizations:
|
| 13 |
+
- @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
|
| 14 |
+
hidden dimension.
|
| 15 |
+
- Single-tile specialization: when D <= BLOCK_D, all data stays in registers
|
| 16 |
+
across the reduction and output phases, eliminating redundant global reads.
|
| 17 |
+
- Multi-tile software pipelining: explicit num_stages in autotune configs
|
| 18 |
+
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
+
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
+
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
+
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
+
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
+
memory traffic compared to a naive 3-pass approach.
|
| 24 |
+
|
| 25 |
+
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
+
- Computes x, x^2, x^3 in registers
|
| 27 |
+
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
+
- Applies polynomial weights + bias + mul in-place
|
| 29 |
+
|
| 30 |
+
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
+
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
+
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
+
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import Tensor
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
import triton
|
| 41 |
+
import triton.language as tl
|
| 42 |
+
|
| 43 |
+
HAS_TRITON = True
|
| 44 |
+
except ImportError:
|
| 45 |
+
HAS_TRITON = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# PyTorch reference implementation (for testing and benchmarking)
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
def _rms_norm(x: Tensor, eps: float) -> Tensor:
|
| 52 |
+
"""Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
|
| 53 |
+
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then evaluates the
    polynomial for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to the bucketized expert index

    Returns:
        (total_tokens, D) - output tensor in input's original dtype
    """
    orig_dtype = input.dtype

    # Token i belongs to the first expert e with offsets[e] > i, which is
    # exactly bucketize(..., right=True) over the cumulative offsets.
    positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Compute in fp32 for accuracy, cast back at the end.
    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    gate = mul.float()

    x2 = x * x
    x3 = x2 * x

    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps)
            + per_token_w[:, 1:2] * _rms_norm(x2, eps)
            + per_token_w[:, 2:3] * _rms_norm(x, eps)
            + per_token_b)

    return (poly * gate).to(orig_dtype)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# Triton kernel implementation
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
if HAS_TRITON:
|
| 109 |
+
# --- Autotune configurations ---
# Sweep BLOCK_D x num_warps x num_stages; @triton.autotune keys on D, so
# one config is selected per hidden dimension. Larger BLOCK_D configs use
# num_stages=1 since a single tile leaves nothing to pipeline.
_GROUPED_POLYNORM_FWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
]

# Backward sweep: same shape as the forward list, with a few extra deep
# pipelining candidates (num_stages up to 5) for the 2-pass kernel.
_GROUPED_POLYNORM_BWD_CONFIGS = [
    triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
    triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
    triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
    triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
    triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
]
|
| 155 |
+
|
| 156 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    out[row] = (w0 * rms_norm(x^3) + w1 * rms_norm(x^2)
                + w2 * rms_norm(x) + b) * mul[row]
    with per-expert (w, b) located by binary search over `offsets_ptr`.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for the owning expert
    # (12 fixed iterations cover up to 4096 experts).
    left = 0
    right = num_experts
    for _ in range(12):
        if left < right:
            probe = (left + right) // 2
            if tl.load(offsets_ptr + probe) <= row:
                left = probe + 1
            else:
                right = probe
    eidx = left + expert_offset

    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    if D <= BLOCK_D:
        # --- Single-tile path: the whole row stays in registers ---
        cols = tl.arange(0, BLOCK_D)
        in_bounds = cols < D

        x = tl.load(input_row_ptr + cols, mask=in_bounds,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + cols, mask=in_bounds,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Fold scalar weight into inv_rms: saves one FMA per element.
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + cols, poly * m, mask=in_bounds)
    else:
        # --- Multi-tile path: pass 1 reduces the RMS statistics ---
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for tile in range(0, D, BLOCK_D):
            cols = tile + tl.arange(0, BLOCK_D)
            in_bounds = cols < D
            x = tl.load(input_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2: apply the polynomial and the gate, tile by tile.
        for tile in range(0, D, BLOCK_D):
            cols = tile + tl.arange(0, BLOCK_D)
            in_bounds = cols < D
            x = tl.load(input_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + cols, mask=in_bounds,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + cols, poly * m, mask=in_bounds)
|
| 266 |
+
|
| 267 |
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # grad_weight / grad_bias are accumulated with atomic_add below, so the
    # autotuner must zero them before each timing run or trial launches
    # would pollute the final accumulated gradients.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,      # (N, D) upstream gradient
    input_ptr,         # (N, D) forward input x
    mul_ptr,           # (N, D) forward gate values m
    weight_ptr,        # per-expert weights, addressed as eidx * 3 + {0,1,2}
    bias_ptr,          # per-expert scalar bias, addressed by eidx
    offsets_ptr,       # (num_experts,) cumulative token counts per expert
    grad_input_ptr,    # (N, D) out: d(loss)/d(input)
    grad_mul_ptr,      # (N, D) out: d(loss)/d(mul)
    grad_weight_ptr,   # fp32 accumulator, written via atomic_add
    grad_bias_ptr,     # fp32 accumulator, written via atomic_add
    N,                 # total number of rows (tokens)
    D,                 # hidden dimension
    num_experts,       # number of entries in offsets_ptr
    eps,               # numerical-stability epsilon inside each rsqrt
    expert_offset,     # constant added to the looked-up expert index
    stride_row,        # row stride shared by all (N, D) tensors
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Forward computes out = m * (x^3 * w0/rms(x^3) + x^2 * w1/rms(x^2)
    + x * w2/rms(x) + b) per row, with per-expert (w0, w1, w2, b).

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search in the cumulative-offsets array for the expert that owns
    # this row: offsets[e] is one past the expert's last row, so we want the
    # first offset > row.  A fixed 12 iterations (predicated rather than a
    # data-dependent `while`) covers up to 2^12 = 4096 experts, matching the
    # assertion in the Python-side forward.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert polynomial coefficients, promoted to fp32 for the math.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    # All five tensors share one row stride (the launcher passes
    # input.stride(0) after making everything contiguous).
    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path ---
    # Whole row fits in one block: load once, no second pass needed.
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure)
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-fold each weight with its matching 1/rms, as in the forward.
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dpoly = d(loss)/d(poly); the forward multiplies poly by the gate m.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: chain-rule term through each RMS
        # denominator (every element of the row contributes to the rms).
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * (full polynomial incl. bias)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure):
        # each power-term contributes k * x^(k-1) * inv_rms_k * (w * dpoly
        # - x^k * coeff) for k = 1, 2, 3.
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Many rows can share one expert, hence atomic accumulation.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            # other=0.0 on the masked loads makes the padding lanes
            # contribute nothing to any of these reductions.
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        # (re-loads x/m/go instead of caching them — trades bandwidth for
        # registers, same trade-off as the forward's multi-tile path)
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 464 |
+
|
| 465 |
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd bridge for the grouped PolyNorm Triton kernels.

    Launches one Triton program per token row in both directions.  The
    kernel arguments are positional and must stay in the exact order of
    the kernel signatures.
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        # input/mul: (N, D); weight/bias: per-expert parameters;
        # offsets: cumsum of tokens per expert (see the public wrapper).
        N, D = input.shape
        # Kernels address rows through a single row stride, so rows must
        # be densely packed.
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The kernels resolve a row's expert with a fixed 12-step binary
        # search, which covers at most 2^12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the *contiguous* tensors so backward can reuse one stride.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Parameter grads are accumulated in fp32 via atomic_add inside the
        # kernel (rows sharing an expert race on these), then cast back to
        # the parameter dtype below.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            # single stride shared by all (N, D) tensors — valid because
            # everything was made contiguous above / in forward
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # unsqueeze to match the (num_experts, 1) bias layout documented on
        # the public wrapper
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # Grads for (input, mul, weight, bias); None for the non-tensor
        # args (offsets, eps, expert_offset).
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
|
| 543 |
+
|
| 544 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Applies a per-expert PolyNorm to each token and multiplies the result
    with the corresponding gate values, dispatching through the custom
    autograd function so both forward and backward run as Triton kernels.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor
    """
    return _GroupedPolyNormFn.apply(
        input, mul, weight, bias, offsets, eps, expert_offset)
|
| 569 |
+
|
| 570 |
+
else:
|
| 571 |
+
|
| 572 |
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Fallback stub used when Triton is unavailable; always raises.

    Raises:
        RuntimeError: unconditionally, pointing the user at the missing
            triton dependency.
    """
    message = ("Triton is not available. Install triton to use "
               "grouped_fused_mul_poly_norm.")
    raise RuntimeError(message)
|
build/torch29-cxx11-cu128-x86_64-linux/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@ import torch
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
|
|
|
| 5 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 6 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 7 |
|
|
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
|
|
| 45 |
__all__ = [
|
| 46 |
"poly_norm",
|
| 47 |
"fused_mul_poly_norm",
|
|
|
|
| 48 |
"rms_norm",
|
| 49 |
"fused_add_rms_norm",
|
| 50 |
"layers",
|
|
|
|
| 2 |
|
| 3 |
from . import layers, parallel_style
|
| 4 |
from ._ops import ops
|
| 5 |
+
from .grouped_poly_norm import grouped_fused_mul_poly_norm
|
| 6 |
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
|
| 7 |
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
|
| 8 |
|
|
|
|
| 46 |
__all__ = [
|
| 47 |
"poly_norm",
|
| 48 |
"fused_mul_poly_norm",
|
| 49 |
+
"grouped_fused_mul_poly_norm",
|
| 50 |
"rms_norm",
|
| 51 |
"fused_add_rms_norm",
|
| 52 |
"layers",
|
build/torch29-cxx11-cu128-x86_64-linux/{_activation_18b7543_dirty.abi3.so → _activation_0e6f27f_dirty.abi3.so}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 15804336
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfa89588a5e7e74b3a903912190b97004e308dd8fcb87832c2798d99733591f2
|
| 3 |
size 15804336
|