Instructions to use Motif-Technologies/activation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/activation with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("Motif-Technologies/activation") - Notebooks
- Google Colab
- Kaggle
refactor: remove Triton kernels, add hidden_clamp to unscored ops
Browse files- Remove all Triton kernel code (fwd/bwd kernels, autotune configs,
triton import) — replaced by CUDA kernels in grouped_poly_norm.cu
- Add hidden_clamp parameter to unscored C++ ops (forward/backward)
so both scored and unscored paths support clamping
- Update register_fake, autograd Function, and dispatch for unscored ops
- Replace HAS_TRITON with _has_cuda_ops in tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
activation/grouped_poly_norm.cu
CHANGED
|
@@ -609,8 +609,9 @@ std::tuple<torch::Tensor, torch::Tensor>
|
|
| 609 |
grouped_poly_norm_forward(
|
| 610 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 611 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 612 |
-
const torch::Tensor &offsets, double eps, int64_t expert_offset
|
| 613 |
-
|
|
|
|
| 614 |
}
|
| 615 |
|
| 616 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
|
@@ -618,11 +619,12 @@ grouped_poly_norm_backward(
|
|
| 618 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 619 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 620 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 621 |
-
const torch::Tensor &inv_rms, double eps, int64_t expert_offset
|
|
|
|
| 622 |
const int64_t N = input.size(0);
|
| 623 |
auto [ig, mg, wg, bg, _] = _bwd_impl(
|
| 624 |
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 625 |
-
nullptr, nullptr, N, eps, expert_offset,
|
| 626 |
return {ig, mg, wg, bg};
|
| 627 |
}
|
| 628 |
|
|
|
|
| 609 |
grouped_poly_norm_forward(
|
| 610 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 611 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 612 |
+
const torch::Tensor &offsets, double eps, int64_t expert_offset,
|
| 613 |
+
double hidden_clamp) {
|
| 614 |
+
return _fwd_impl(input, mul, weight, bias, offsets, nullptr, eps, expert_offset, hidden_clamp);
|
| 615 |
}
|
| 616 |
|
| 617 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
|
|
|
| 619 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 620 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 621 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 622 |
+
const torch::Tensor &inv_rms, double eps, int64_t expert_offset,
|
| 623 |
+
double hidden_clamp) {
|
| 624 |
const int64_t N = input.size(0);
|
| 625 |
auto [ig, mg, wg, bg, _] = _bwd_impl(
|
| 626 |
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 627 |
+
nullptr, nullptr, N, eps, expert_offset, hidden_clamp);
|
| 628 |
return {ig, mg, wg, bg};
|
| 629 |
}
|
| 630 |
|
tests/test_fused_mul_grouped_poly_norm.py
CHANGED
|
@@ -2,11 +2,11 @@ import pytest
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from grouped_poly_norm import (
|
| 5 |
-
|
| 6 |
fused_mul_grouped_poly_norm_ref,
|
| 7 |
)
|
| 8 |
|
| 9 |
-
if
|
| 10 |
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 11 |
|
| 12 |
from .utils import assert_close
|
|
@@ -95,7 +95,7 @@ def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0,
|
|
| 95 |
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 96 |
|
| 97 |
|
| 98 |
-
@pytest.mark.skipif(not
|
| 99 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 100 |
@pytest.mark.parametrize("d", D)
|
| 101 |
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
|
|
@@ -134,7 +134,7 @@ def test_fused_mul_grouped_poly_norm_forward(
|
|
| 134 |
assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
|
| 135 |
|
| 136 |
|
| 137 |
-
@pytest.mark.skipif(not
|
| 138 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 139 |
@pytest.mark.parametrize("d", D)
|
| 140 |
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
|
|
@@ -173,7 +173,7 @@ def test_fused_mul_grouped_poly_norm_backward(
|
|
| 173 |
assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
|
| 174 |
|
| 175 |
|
| 176 |
-
@pytest.mark.skipif(not
|
| 177 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 178 |
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
|
| 179 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
@@ -236,7 +236,7 @@ def test_fused_mul_grouped_poly_norm_zero_token_experts(
|
|
| 236 |
f"but got max={b_grad_tri[wi].abs().max().item()}")
|
| 237 |
|
| 238 |
|
| 239 |
-
@pytest.mark.skipif(not
|
| 240 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 241 |
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
|
| 242 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
@@ -265,7 +265,7 @@ def test_fused_mul_grouped_poly_norm_no_nan_inf(
|
|
| 265 |
# ---------------------------------------------------------------------------
|
| 266 |
# Scores tests
|
| 267 |
# ---------------------------------------------------------------------------
|
| 268 |
-
@pytest.mark.skipif(not
|
| 269 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 270 |
@pytest.mark.parametrize("d", D)
|
| 271 |
@pytest.mark.parametrize("num_experts", [8, 48])
|
|
@@ -289,7 +289,7 @@ def test_fused_mul_grouped_poly_norm_scores_forward(
|
|
| 289 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
| 290 |
|
| 291 |
|
| 292 |
-
@pytest.mark.skipif(not
|
| 293 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 294 |
@pytest.mark.parametrize("d", D)
|
| 295 |
@pytest.mark.parametrize("num_experts", [8, 48])
|
|
@@ -326,7 +326,7 @@ def test_fused_mul_grouped_poly_norm_scores_backward(
|
|
| 326 |
CLAMP_VALUES = [10.0, 1.0, 0.5]
|
| 327 |
|
| 328 |
|
| 329 |
-
@pytest.mark.skipif(not
|
| 330 |
@pytest.mark.parametrize("num_tokens", [4096])
|
| 331 |
@pytest.mark.parametrize("d", [256, 1280])
|
| 332 |
@pytest.mark.parametrize("num_experts", [8])
|
|
@@ -353,7 +353,7 @@ def test_fused_mul_grouped_poly_norm_hidden_clamp_forward(
|
|
| 353 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
| 354 |
|
| 355 |
|
| 356 |
-
@pytest.mark.skipif(not
|
| 357 |
@pytest.mark.parametrize("num_tokens", [4096])
|
| 358 |
@pytest.mark.parametrize("d", [256, 1280])
|
| 359 |
@pytest.mark.parametrize("num_experts", [8])
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from grouped_poly_norm import (
|
| 5 |
+
_has_cuda_ops,
|
| 6 |
fused_mul_grouped_poly_norm_ref,
|
| 7 |
)
|
| 8 |
|
| 9 |
+
if _has_cuda_ops:
|
| 10 |
from grouped_poly_norm import fused_mul_grouped_poly_norm
|
| 11 |
|
| 12 |
from .utils import assert_close
|
|
|
|
| 95 |
return grads + (s.grad,) if s is not None else grads + (None,)
|
| 96 |
|
| 97 |
|
| 98 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 99 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 100 |
@pytest.mark.parametrize("d", D)
|
| 101 |
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
|
|
|
|
| 134 |
assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
|
| 135 |
|
| 136 |
|
| 137 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 138 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 139 |
@pytest.mark.parametrize("d", D)
|
| 140 |
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
|
|
|
|
| 173 |
assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
|
| 174 |
|
| 175 |
|
| 176 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 177 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 178 |
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
|
| 179 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
|
|
| 236 |
f"but got max={b_grad_tri[wi].abs().max().item()}")
|
| 237 |
|
| 238 |
|
| 239 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 240 |
@pytest.mark.parametrize("dtype", DTYPES)
|
| 241 |
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
|
| 242 |
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
|
|
|
| 265 |
# ---------------------------------------------------------------------------
|
| 266 |
# Scores tests
|
| 267 |
# ---------------------------------------------------------------------------
|
| 268 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 269 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 270 |
@pytest.mark.parametrize("d", D)
|
| 271 |
@pytest.mark.parametrize("num_experts", [8, 48])
|
|
|
|
| 289 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
| 290 |
|
| 291 |
|
| 292 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 293 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
| 294 |
@pytest.mark.parametrize("d", D)
|
| 295 |
@pytest.mark.parametrize("num_experts", [8, 48])
|
|
|
|
| 326 |
CLAMP_VALUES = [10.0, 1.0, 0.5]
|
| 327 |
|
| 328 |
|
| 329 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 330 |
@pytest.mark.parametrize("num_tokens", [4096])
|
| 331 |
@pytest.mark.parametrize("d", [256, 1280])
|
| 332 |
@pytest.mark.parametrize("num_experts", [8])
|
|
|
|
| 353 |
assert_close(out_ref, out_tri, atol=atol, rtol=rtol)
|
| 354 |
|
| 355 |
|
| 356 |
+
@pytest.mark.skipif(not _has_cuda_ops, reason="CUDA ops not available")
|
| 357 |
@pytest.mark.parametrize("num_tokens", [4096])
|
| 358 |
@pytest.mark.parametrize("d", [256, 1280])
|
| 359 |
@pytest.mark.parametrize("num_experts", [8])
|
torch-ext/activation/grouped_poly_norm.py
CHANGED
|
@@ -1,49 +1,26 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
Fuses the entire PolyNorm computation into
|
| 4 |
-
|
| 5 |
|
| 6 |
PolyNorm formula (per row):
|
| 7 |
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
-
output = poly * mul
|
|
|
|
| 9 |
|
| 10 |
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 11 |
|
| 12 |
-
|
| 13 |
-
-
|
| 14 |
-
|
| 15 |
-
-
|
| 16 |
-
|
| 17 |
-
-
|
| 18 |
-
enables overlapping memory loads with computation across loop iterations.
|
| 19 |
-
- In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
|
| 20 |
-
launches (torch.arange + torch.bucketize) per forward/backward call.
|
| 21 |
-
- Backward 2-pass optimization: pass 1 merges RMS statistics computation
|
| 22 |
-
with dot product accumulation, pass 2 computes gradients. This reduces
|
| 23 |
-
memory traffic compared to a naive 3-pass approach.
|
| 24 |
-
|
| 25 |
-
Forward kernel: one program per row, tiles over D dimension.
|
| 26 |
-
- Computes x, x^2, x^3 in registers
|
| 27 |
-
- Computes three RMS norms in a single pass (shared variance reduction)
|
| 28 |
-
- Applies polynomial weights + bias + mul in-place
|
| 29 |
-
|
| 30 |
-
Backward kernel: one program per row, tiles over D dimension.
|
| 31 |
-
- Recomputes forward intermediates from saved inputs (activation recomputation)
|
| 32 |
-
- 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
|
| 33 |
-
- Weight/bias gradients use tl.atomic_add for cross-row accumulation
|
| 34 |
"""
|
| 35 |
|
| 36 |
import torch
|
| 37 |
from torch import Tensor
|
| 38 |
|
| 39 |
-
try:
|
| 40 |
-
import triton
|
| 41 |
-
import triton.language as tl
|
| 42 |
-
|
| 43 |
-
HAS_TRITON = True
|
| 44 |
-
except ImportError:
|
| 45 |
-
HAS_TRITON = False
|
| 46 |
-
|
| 47 |
# Try to load CUDA ops at module level
|
| 48 |
_ops = None
|
| 49 |
try:
|
|
@@ -61,14 +38,15 @@ _has_cuda_ops = _ops is not None and hasattr(_ops, 'grouped_poly_norm_forward')
|
|
| 61 |
if _has_cuda_ops:
|
| 62 |
try:
|
| 63 |
@torch.library.register_fake("_activation::grouped_poly_norm_forward")
|
| 64 |
-
def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset
|
|
|
|
| 65 |
return (torch.empty_like(input),
|
| 66 |
torch.empty(input.shape[0], 3, dtype=torch.float32,
|
| 67 |
device=input.device))
|
| 68 |
|
| 69 |
@torch.library.register_fake("_activation::grouped_poly_norm_backward")
|
| 70 |
def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 71 |
-
eps, expert_offset):
|
| 72 |
return (torch.empty_like(input),
|
| 73 |
torch.empty_like(mul),
|
| 74 |
torch.empty_like(weight),
|
|
@@ -164,383 +142,32 @@ def fused_mul_grouped_poly_norm_ref(
|
|
| 164 |
|
| 165 |
|
| 166 |
# ---------------------------------------------------------------------------
|
| 167 |
-
#
|
| 168 |
# ---------------------------------------------------------------------------
|
| 169 |
-
if
|
| 170 |
-
# --- Autotune configurations ---
|
| 171 |
-
_GROUPED_POLYNORM_FWD_CONFIGS = [
|
| 172 |
-
triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
|
| 173 |
-
triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
|
| 174 |
-
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
|
| 175 |
-
triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
|
| 176 |
-
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
|
| 177 |
-
triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
|
| 178 |
-
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
|
| 179 |
-
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
|
| 180 |
-
triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
|
| 181 |
-
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
|
| 182 |
-
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
|
| 183 |
-
triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
|
| 184 |
-
triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
|
| 185 |
-
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
|
| 186 |
-
triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
|
| 187 |
-
triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
|
| 188 |
-
triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
|
| 189 |
-
triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
|
| 190 |
-
triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
|
| 191 |
-
triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
|
| 192 |
-
]
|
| 193 |
-
|
| 194 |
-
_GROUPED_POLYNORM_BWD_CONFIGS = [
|
| 195 |
-
# Low-warp configs for high SM occupancy (latency hiding)
|
| 196 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=2, num_stages=1),
|
| 197 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=4, num_stages=1),
|
| 198 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=2, num_stages=1),
|
| 199 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=4, num_stages=1),
|
| 200 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=2, num_stages=1),
|
| 201 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=4, num_stages=1),
|
| 202 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=2, num_stages=1),
|
| 203 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=4, num_stages=1),
|
| 204 |
-
# Medium-warp configs
|
| 205 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 1}, num_warps=8, num_stages=1),
|
| 206 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 2}, num_warps=8, num_stages=1),
|
| 207 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 4}, num_warps=8, num_stages=1),
|
| 208 |
-
triton.Config({"BLOCK_D": 2048, "BLOCK_N": 8}, num_warps=8, num_stages=1),
|
| 209 |
-
# Multi-tile configs (BLOCK_D=1024 for D=1280 -> 2 tiles, no mask waste)
|
| 210 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 1}, num_warps=2, num_stages=2),
|
| 211 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 1}, num_warps=4, num_stages=2),
|
| 212 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 2}, num_warps=2, num_stages=2),
|
| 213 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 2}, num_warps=4, num_stages=2),
|
| 214 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 4}, num_warps=4, num_stages=2),
|
| 215 |
-
triton.Config({"BLOCK_D": 1024, "BLOCK_N": 8}, num_warps=4, num_stages=2),
|
| 216 |
-
]
|
| 217 |
-
|
| 218 |
-
@triton.autotune(
|
| 219 |
-
configs=_GROUPED_POLYNORM_FWD_CONFIGS,
|
| 220 |
-
key=["D"],
|
| 221 |
-
)
|
| 222 |
-
@triton.jit
|
| 223 |
-
def _grouped_polynorm_fwd_kernel(
|
| 224 |
-
input_ptr,
|
| 225 |
-
mul_ptr,
|
| 226 |
-
weight_ptr,
|
| 227 |
-
bias_ptr,
|
| 228 |
-
offsets_ptr,
|
| 229 |
-
output_ptr,
|
| 230 |
-
inv_rms_ptr,
|
| 231 |
-
N,
|
| 232 |
-
D,
|
| 233 |
-
num_experts,
|
| 234 |
-
eps,
|
| 235 |
-
expert_offset,
|
| 236 |
-
stride_input_row,
|
| 237 |
-
stride_mul_row,
|
| 238 |
-
stride_out_row,
|
| 239 |
-
BLOCK_D: tl.constexpr,
|
| 240 |
-
):
|
| 241 |
-
"""Forward kernel: one program per row. Saves inv_rms for backward."""
|
| 242 |
-
row = tl.program_id(0)
|
| 243 |
-
if row >= N:
|
| 244 |
-
return
|
| 245 |
-
|
| 246 |
-
# Binary search for expert index (12 iters covers up to 4096 experts)
|
| 247 |
-
lo = 0
|
| 248 |
-
hi = num_experts
|
| 249 |
-
for _ in range(12):
|
| 250 |
-
if lo < hi:
|
| 251 |
-
mid = (lo + hi) // 2
|
| 252 |
-
if tl.load(offsets_ptr + mid) <= row:
|
| 253 |
-
lo = mid + 1
|
| 254 |
-
else:
|
| 255 |
-
hi = mid
|
| 256 |
-
eidx = lo + expert_offset
|
| 257 |
-
|
| 258 |
-
w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
|
| 259 |
-
w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
|
| 260 |
-
w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
|
| 261 |
-
b = tl.load(bias_ptr + eidx).to(tl.float32)
|
| 262 |
-
|
| 263 |
-
input_row_ptr = input_ptr + row * stride_input_row
|
| 264 |
-
mul_row_ptr = mul_ptr + row * stride_mul_row
|
| 265 |
-
out_row_ptr = output_ptr + row * stride_out_row
|
| 266 |
-
|
| 267 |
-
D_float = D.to(tl.float32)
|
| 268 |
-
|
| 269 |
-
# --- Single-tile path ---
|
| 270 |
-
if D <= BLOCK_D:
|
| 271 |
-
d_offs = tl.arange(0, BLOCK_D)
|
| 272 |
-
mask = d_offs < D
|
| 273 |
-
|
| 274 |
-
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 275 |
-
other=0.0).to(tl.float32)
|
| 276 |
-
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 277 |
-
other=0.0).to(tl.float32)
|
| 278 |
-
|
| 279 |
-
x2 = x * x
|
| 280 |
-
x3 = x2 * x
|
| 281 |
-
|
| 282 |
-
inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
|
| 283 |
-
inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
|
| 284 |
-
inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x3 * x3) / D_float + eps)
|
| 285 |
-
|
| 286 |
-
# Save inv_rms for backward
|
| 287 |
-
tl.store(inv_rms_ptr + row * 3 + 0, inv_rms_x)
|
| 288 |
-
tl.store(inv_rms_ptr + row * 3 + 1, inv_rms_x2)
|
| 289 |
-
tl.store(inv_rms_ptr + row * 3 + 2, inv_rms_x3)
|
| 290 |
-
|
| 291 |
-
# Pre-multiply scalar weight * inv_rms to save 1 FMA per element
|
| 292 |
-
w0_inv = w0 * inv_rms_x3
|
| 293 |
-
w1_inv = w1 * inv_rms_x2
|
| 294 |
-
w2_inv = w2 * inv_rms_x
|
| 295 |
-
|
| 296 |
-
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 297 |
-
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 298 |
-
else:
|
| 299 |
-
# --- Multi-tile: two-pass approach ---
|
| 300 |
-
sum_x2 = tl.zeros((), dtype=tl.float32)
|
| 301 |
-
sum_x4 = tl.zeros((), dtype=tl.float32)
|
| 302 |
-
sum_x6 = tl.zeros((), dtype=tl.float32)
|
| 303 |
-
|
| 304 |
-
for d_start in range(0, D, BLOCK_D):
|
| 305 |
-
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 306 |
-
mask = d_offs < D
|
| 307 |
-
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 308 |
-
other=0.0).to(tl.float32)
|
| 309 |
-
x2 = x * x
|
| 310 |
-
x3 = x2 * x
|
| 311 |
-
sum_x2 += tl.sum(x2)
|
| 312 |
-
sum_x4 += tl.sum(x2 * x2)
|
| 313 |
-
sum_x6 += tl.sum(x3 * x3)
|
| 314 |
-
|
| 315 |
-
inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
|
| 316 |
-
inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
|
| 317 |
-
inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
|
| 318 |
-
|
| 319 |
-
# Save inv_rms for backward
|
| 320 |
-
tl.store(inv_rms_ptr + row * 3 + 0, inv_rms_x)
|
| 321 |
-
tl.store(inv_rms_ptr + row * 3 + 1, inv_rms_x2)
|
| 322 |
-
tl.store(inv_rms_ptr + row * 3 + 2, inv_rms_x3)
|
| 323 |
-
|
| 324 |
-
# Pre-multiply scalar weight * inv_rms
|
| 325 |
-
w0_inv = w0 * inv_rms_x3
|
| 326 |
-
w1_inv = w1 * inv_rms_x2
|
| 327 |
-
w2_inv = w2 * inv_rms_x
|
| 328 |
-
|
| 329 |
-
for d_start in range(0, D, BLOCK_D):
|
| 330 |
-
d_offs = d_start + tl.arange(0, BLOCK_D)
|
| 331 |
-
mask = d_offs < D
|
| 332 |
-
x = tl.load(input_row_ptr + d_offs, mask=mask,
|
| 333 |
-
other=0.0).to(tl.float32)
|
| 334 |
-
m = tl.load(mul_row_ptr + d_offs, mask=mask,
|
| 335 |
-
other=0.0).to(tl.float32)
|
| 336 |
-
x2 = x * x
|
| 337 |
-
x3 = x2 * x
|
| 338 |
-
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
|
| 339 |
-
tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
|
| 340 |
-
|
| 341 |
-
@triton.jit
|
| 342 |
-
def _grouped_polynorm_bwd_kernel(
|
| 343 |
-
grad_out_ptr,
|
| 344 |
-
input_ptr,
|
| 345 |
-
mul_ptr,
|
| 346 |
-
weight_ptr,
|
| 347 |
-
bias_ptr,
|
| 348 |
-
offsets_ptr,
|
| 349 |
-
inv_rms_ptr,
|
| 350 |
-
grad_input_ptr,
|
| 351 |
-
grad_mul_ptr,
|
| 352 |
-
grad_weight_ptr,
|
| 353 |
-
grad_bias_ptr,
|
| 354 |
-
N,
|
| 355 |
-
D,
|
| 356 |
-
num_experts,
|
| 357 |
-
eps,
|
| 358 |
-
expert_offset,
|
| 359 |
-
stride_row,
|
| 360 |
-
BLOCK_D: tl.constexpr,
|
| 361 |
-
BLOCK_N: tl.constexpr,
|
| 362 |
-
):
|
| 363 |
-
"""Backward kernel: BLOCK_N rows per program. Loads saved inv_rms.
|
| 364 |
-
|
| 365 |
-
Each program processes BLOCK_N consecutive rows. Since MoE tokens
|
| 366 |
-
are sorted by expert, consecutive rows often share the same expert,
|
| 367 |
-
allowing weight/bias load reuse and amortized binary search.
|
| 368 |
-
"""
|
| 369 |
-
pid = tl.program_id(0)
|
| 370 |
-
row_start = pid * BLOCK_N
|
| 371 |
-
D_float = D.to(tl.float32)
|
| 372 |
-
d_offs = tl.arange(0, BLOCK_D)
|
| 373 |
-
d_mask = d_offs < D
|
| 374 |
-
|
| 375 |
-
for row_off in tl.static_range(BLOCK_N):
|
| 376 |
-
row = row_start + row_off
|
| 377 |
-
if row < N:
|
| 378 |
-
# Binary search for expert index
|
| 379 |
-
lo = 0
|
| 380 |
-
hi = num_experts
|
| 381 |
-
for _ in range(12):
|
| 382 |
-
if lo < hi:
|
| 383 |
-
mid = (lo + hi) // 2
|
| 384 |
-
if tl.load(offsets_ptr + mid) <= row:
|
| 385 |
-
lo = mid + 1
|
| 386 |
-
else:
|
| 387 |
-
hi = mid
|
| 388 |
-
eidx = lo + expert_offset
|
| 389 |
-
|
| 390 |
-
w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
|
| 391 |
-
w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
|
| 392 |
-
w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
|
| 393 |
-
b_val = tl.load(bias_ptr + eidx).to(tl.float32)
|
| 394 |
-
|
| 395 |
-
input_row_ptr = input_ptr + row * stride_row
|
| 396 |
-
mul_row_ptr = mul_ptr + row * stride_row
|
| 397 |
-
grad_out_row_ptr = grad_out_ptr + row * stride_row
|
| 398 |
-
grad_input_row_ptr = grad_input_ptr + row * stride_row
|
| 399 |
-
grad_mul_row_ptr = grad_mul_ptr + row * stride_row
|
| 400 |
-
|
| 401 |
-
# --- Single-tile path ---
|
| 402 |
-
if D <= BLOCK_D:
|
| 403 |
-
x = tl.load(input_row_ptr + d_offs, mask=d_mask,
|
| 404 |
-
other=0.0).to(tl.float32)
|
| 405 |
-
m = tl.load(mul_row_ptr + d_offs, mask=d_mask,
|
| 406 |
-
other=0.0).to(tl.float32)
|
| 407 |
-
go = tl.load(grad_out_row_ptr + d_offs, mask=d_mask,
|
| 408 |
-
other=0.0).to(tl.float32)
|
| 409 |
-
|
| 410 |
-
x2 = x * x
|
| 411 |
-
x3 = x2 * x
|
| 412 |
-
|
| 413 |
-
# Load saved inv_rms from forward
|
| 414 |
-
inv_rms_x = tl.load(inv_rms_ptr + row * 3 + 0)
|
| 415 |
-
inv_rms_x2 = tl.load(inv_rms_ptr + row * 3 + 1)
|
| 416 |
-
inv_rms_x3 = tl.load(inv_rms_ptr + row * 3 + 2)
|
| 417 |
-
|
| 418 |
-
w0_inv = w0 * inv_rms_x3
|
| 419 |
-
w1_inv = w1 * inv_rms_x2
|
| 420 |
-
w2_inv = w2 * inv_rms_x
|
| 421 |
-
|
| 422 |
-
dpoly = go * m
|
| 423 |
-
|
| 424 |
-
sum_dpoly_x = tl.sum(dpoly * x)
|
| 425 |
-
sum_dpoly_x2 = tl.sum(dpoly * x2)
|
| 426 |
-
sum_dpoly_x3 = tl.sum(dpoly * x3)
|
| 427 |
-
grad_b_acc = tl.sum(dpoly)
|
| 428 |
-
|
| 429 |
-
grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
|
| 430 |
-
grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
|
| 431 |
-
grad_w2_acc = inv_rms_x * sum_dpoly_x
|
| 432 |
-
|
| 433 |
-
coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
|
| 434 |
-
coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
|
| 435 |
-
coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
|
| 436 |
-
|
| 437 |
-
# grad_mul
|
| 438 |
-
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
|
| 439 |
-
tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val),
|
| 440 |
-
mask=d_mask)
|
| 441 |
-
|
| 442 |
-
# grad_input
|
| 443 |
-
g = inv_rms_x * (w2 * dpoly - x * coeff_x)
|
| 444 |
-
g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
|
| 445 |
-
g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
|
| 446 |
-
tl.store(grad_input_row_ptr + d_offs, g, mask=d_mask)
|
| 447 |
-
|
| 448 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
|
| 449 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
|
| 450 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
|
| 451 |
-
tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 452 |
-
else:
|
| 453 |
-
# --- Multi-tile: dot products pass ---
|
| 454 |
-
# Load saved inv_rms from forward
|
| 455 |
-
inv_rms_x = tl.load(inv_rms_ptr + row * 3 + 0)
|
| 456 |
-
inv_rms_x2 = tl.load(inv_rms_ptr + row * 3 + 1)
|
| 457 |
-
inv_rms_x3 = tl.load(inv_rms_ptr + row * 3 + 2)
|
| 458 |
-
|
| 459 |
-
sum_dpoly_x = tl.zeros((), dtype=tl.float32)
|
| 460 |
-
sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
|
| 461 |
-
sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
|
| 462 |
-
grad_b_acc = tl.zeros((), dtype=tl.float32)
|
| 463 |
-
|
| 464 |
-
for d_start in range(0, D, BLOCK_D):
|
| 465 |
-
tile_offs = d_start + d_offs
|
| 466 |
-
tile_mask = tile_offs < D
|
| 467 |
-
x = tl.load(input_row_ptr + tile_offs, mask=tile_mask,
|
| 468 |
-
other=0.0).to(tl.float32)
|
| 469 |
-
m = tl.load(mul_row_ptr + tile_offs, mask=tile_mask,
|
| 470 |
-
other=0.0).to(tl.float32)
|
| 471 |
-
go = tl.load(grad_out_row_ptr + tile_offs,
|
| 472 |
-
mask=tile_mask, other=0.0).to(tl.float32)
|
| 473 |
-
|
| 474 |
-
x2 = x * x
|
| 475 |
-
x3 = x2 * x
|
| 476 |
-
dpoly = go * m
|
| 477 |
-
|
| 478 |
-
sum_dpoly_x += tl.sum(dpoly * x)
|
| 479 |
-
sum_dpoly_x2 += tl.sum(dpoly * x2)
|
| 480 |
-
sum_dpoly_x3 += tl.sum(dpoly * x3)
|
| 481 |
-
grad_b_acc += tl.sum(dpoly)
|
| 482 |
-
|
| 483 |
-
w0_inv = w0 * inv_rms_x3
|
| 484 |
-
w1_inv = w1 * inv_rms_x2
|
| 485 |
-
w2_inv = w2 * inv_rms_x
|
| 486 |
-
|
| 487 |
-
grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
|
| 488 |
-
grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
|
| 489 |
-
grad_w2_acc = inv_rms_x * sum_dpoly_x
|
| 490 |
-
|
| 491 |
-
coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
|
| 492 |
-
coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
|
| 493 |
-
coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
|
| 494 |
-
|
| 495 |
-
for d_start in range(0, D, BLOCK_D):
|
| 496 |
-
tile_offs = d_start + d_offs
|
| 497 |
-
tile_mask = tile_offs < D
|
| 498 |
-
x = tl.load(input_row_ptr + tile_offs, mask=tile_mask,
|
| 499 |
-
other=0.0).to(tl.float32)
|
| 500 |
-
m = tl.load(mul_row_ptr + tile_offs, mask=tile_mask,
|
| 501 |
-
other=0.0).to(tl.float32)
|
| 502 |
-
go = tl.load(grad_out_row_ptr + tile_offs,
|
| 503 |
-
mask=tile_mask, other=0.0).to(tl.float32)
|
| 504 |
-
|
| 505 |
-
x2 = x * x
|
| 506 |
-
x3 = x2 * x
|
| 507 |
-
|
| 508 |
-
poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
|
| 509 |
-
tl.store(grad_mul_row_ptr + tile_offs,
|
| 510 |
-
go * (poly + b_val), mask=tile_mask)
|
| 511 |
-
|
| 512 |
-
dpoly = go * m
|
| 513 |
-
g = inv_rms_x * (w2 * dpoly - x * coeff_x)
|
| 514 |
-
g += (2.0 * x * inv_rms_x2 *
|
| 515 |
-
(w1 * dpoly - x2 * coeff_x2))
|
| 516 |
-
g += (3.0 * x2 * inv_rms_x3 *
|
| 517 |
-
(w0 * dpoly - x3 * coeff_x3))
|
| 518 |
-
tl.store(grad_input_row_ptr + tile_offs, g,
|
| 519 |
-
mask=tile_mask)
|
| 520 |
-
|
| 521 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
|
| 522 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
|
| 523 |
-
tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
|
| 524 |
-
tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
|
| 525 |
|
| 526 |
class _GroupedPolyNormFn(torch.autograd.Function):
|
| 527 |
"""Without scores — follows poly_norm.py pattern."""
|
| 528 |
|
| 529 |
@staticmethod
|
| 530 |
-
def forward(input, mul, weight, bias, offsets, eps, expert_offset
|
|
|
|
| 531 |
input = input.contiguous()
|
| 532 |
mul = mul.contiguous()
|
| 533 |
output, inv_rms = _ops.grouped_poly_norm_forward(
|
| 534 |
-
input, mul, weight, bias, offsets, eps, expert_offset
|
|
|
|
| 535 |
return output, inv_rms
|
| 536 |
|
| 537 |
@staticmethod
|
| 538 |
def setup_context(ctx, inputs, output):
|
| 539 |
-
input, mul, weight, bias, offsets, eps, expert_offset
|
|
|
|
| 540 |
_, inv_rms = output
|
| 541 |
ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms)
|
| 542 |
ctx.eps = eps
|
| 543 |
ctx.expert_offset = expert_offset
|
|
|
|
| 544 |
|
| 545 |
@staticmethod
|
| 546 |
def backward(ctx, grad_output, _grad_inv_rms):
|
|
@@ -548,8 +175,8 @@ if HAS_TRITON:
|
|
| 548 |
grad_output = grad_output.contiguous()
|
| 549 |
gi, gm, gw, gb = _ops.grouped_poly_norm_backward(
|
| 550 |
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 551 |
-
ctx.eps, ctx.expert_offset)
|
| 552 |
-
return gi, gm, gw, gb, None, None, None
|
| 553 |
|
| 554 |
class _GroupedPolyNormScoredFn(torch.autograd.Function):
|
| 555 |
"""With scores — same pattern, adds scores + hidden_clamp."""
|
|
@@ -622,7 +249,8 @@ if HAS_TRITON:
|
|
| 622 |
expert_offset, clamp_val)
|
| 623 |
else:
|
| 624 |
output, _ = _GroupedPolyNormFn.apply(
|
| 625 |
-
input, mul, weight, bias, offsets, eps, expert_offset
|
|
|
|
| 626 |
return output
|
| 627 |
|
| 628 |
else:
|
|
@@ -639,5 +267,5 @@ else:
|
|
| 639 |
hidden_clamp: float | None = None,
|
| 640 |
) -> Tensor:
|
| 641 |
raise RuntimeError(
|
| 642 |
-
"
|
| 643 |
-
"fused_mul_grouped_poly_norm.")
|
|
|
|
| 1 |
+
"""Grouped FusedMulPolyNorm for MoE — CUDA kernel with autograd wrappers.
|
| 2 |
|
| 3 |
+
Fuses the entire PolyNorm computation into CUDA kernels (fwd + bwd),
|
| 4 |
+
with optional scores multiplication and hidden_clamp fusion.
|
| 5 |
|
| 6 |
PolyNorm formula (per row):
|
| 7 |
poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
|
| 8 |
+
output = poly * mul * score
|
| 9 |
+
output = clamp(output, -hidden_clamp, hidden_clamp) (if enabled)
|
| 10 |
|
| 11 |
where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
|
| 12 |
|
| 13 |
+
CUDA kernel (activation/grouped_poly_norm.cu):
|
| 14 |
+
- Vectorized loads (width=8 for bf16/fp16, width=4 for fp32)
|
| 15 |
+
- In-kernel binary search for expert mapping
|
| 16 |
+
- 2-pass forward (RMS stats + output), 2-pass backward (dot products + grads)
|
| 17 |
+
- Scores and hidden_clamp fused in-kernel (no extra kernel launches)
|
| 18 |
+
- Weight/bias gradients via atomicAdd
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
import torch
|
| 22 |
from torch import Tensor
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
# Try to load CUDA ops at module level
|
| 25 |
_ops = None
|
| 26 |
try:
|
|
|
|
| 38 |
if _has_cuda_ops:
|
| 39 |
try:
|
| 40 |
@torch.library.register_fake("_activation::grouped_poly_norm_forward")
|
| 41 |
+
def _fwd_fake(input, mul, weight, bias, offsets, eps, expert_offset,
|
| 42 |
+
hidden_clamp):
|
| 43 |
return (torch.empty_like(input),
|
| 44 |
torch.empty(input.shape[0], 3, dtype=torch.float32,
|
| 45 |
device=input.device))
|
| 46 |
|
| 47 |
@torch.library.register_fake("_activation::grouped_poly_norm_backward")
|
| 48 |
def _bwd_fake(grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 49 |
+
eps, expert_offset, hidden_clamp):
|
| 50 |
return (torch.empty_like(input),
|
| 51 |
torch.empty_like(mul),
|
| 52 |
torch.empty_like(weight),
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
# ---------------------------------------------------------------------------
|
| 145 |
+
# CUDA kernel autograd functions
|
| 146 |
# ---------------------------------------------------------------------------
|
| 147 |
+
if _has_cuda_ops:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
class _GroupedPolyNormFn(torch.autograd.Function):
|
| 150 |
"""Without scores — follows poly_norm.py pattern."""
|
| 151 |
|
| 152 |
@staticmethod
|
| 153 |
+
def forward(input, mul, weight, bias, offsets, eps, expert_offset,
|
| 154 |
+
hidden_clamp):
|
| 155 |
input = input.contiguous()
|
| 156 |
mul = mul.contiguous()
|
| 157 |
output, inv_rms = _ops.grouped_poly_norm_forward(
|
| 158 |
+
input, mul, weight, bias, offsets, eps, expert_offset,
|
| 159 |
+
hidden_clamp)
|
| 160 |
return output, inv_rms
|
| 161 |
|
| 162 |
@staticmethod
|
| 163 |
def setup_context(ctx, inputs, output):
|
| 164 |
+
(input, mul, weight, bias, offsets, eps, expert_offset,
|
| 165 |
+
hidden_clamp) = inputs
|
| 166 |
_, inv_rms = output
|
| 167 |
ctx.save_for_backward(input, mul, weight, bias, offsets, inv_rms)
|
| 168 |
ctx.eps = eps
|
| 169 |
ctx.expert_offset = expert_offset
|
| 170 |
+
ctx.hidden_clamp = hidden_clamp
|
| 171 |
|
| 172 |
@staticmethod
|
| 173 |
def backward(ctx, grad_output, _grad_inv_rms):
|
|
|
|
| 175 |
grad_output = grad_output.contiguous()
|
| 176 |
gi, gm, gw, gb = _ops.grouped_poly_norm_backward(
|
| 177 |
grad_output, input, mul, weight, bias, offsets, inv_rms,
|
| 178 |
+
ctx.eps, ctx.expert_offset, ctx.hidden_clamp)
|
| 179 |
+
return gi, gm, gw, gb, None, None, None, None
|
| 180 |
|
| 181 |
class _GroupedPolyNormScoredFn(torch.autograd.Function):
|
| 182 |
"""With scores — same pattern, adds scores + hidden_clamp."""
|
|
|
|
| 249 |
expert_offset, clamp_val)
|
| 250 |
else:
|
| 251 |
output, _ = _GroupedPolyNormFn.apply(
|
| 252 |
+
input, mul, weight, bias, offsets, eps, expert_offset,
|
| 253 |
+
clamp_val)
|
| 254 |
return output
|
| 255 |
|
| 256 |
else:
|
|
|
|
| 267 |
hidden_clamp: float | None = None,
|
| 268 |
) -> Tensor:
|
| 269 |
raise RuntimeError(
|
| 270 |
+
"CUDA ops not available. Build with setup.py or kernel-builder "
|
| 271 |
+
"to use fused_mul_grouped_poly_norm.")
|
torch-ext/torch_binding.cpp
CHANGED
|
@@ -49,18 +49,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
| 49 |
ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
|
| 50 |
&fused_add_rms_norm_backward);
|
| 51 |
|
| 52 |
-
// grouped_poly_norm (without scores)
|
| 53 |
ops.def("grouped_poly_norm_forward("
|
| 54 |
"Tensor input, Tensor mul, Tensor weight, "
|
| 55 |
"Tensor bias, Tensor offsets, "
|
| 56 |
-
"float eps, int expert_offset) -> (Tensor, Tensor)");
|
| 57 |
ops.impl("grouped_poly_norm_forward", torch::kCUDA,
|
| 58 |
&grouped_poly_norm_forward);
|
| 59 |
|
| 60 |
ops.def("grouped_poly_norm_backward("
|
| 61 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 62 |
"Tensor bias, Tensor offsets, Tensor inv_rms, "
|
| 63 |
-
"float eps, int expert_offset) -> (Tensor, Tensor, Tensor, Tensor)");
|
| 64 |
ops.impl("grouped_poly_norm_backward", torch::kCUDA,
|
| 65 |
&grouped_poly_norm_backward);
|
| 66 |
|
|
|
|
| 49 |
ops.impl("fused_add_rms_norm_backward", torch::kCUDA,
|
| 50 |
&fused_add_rms_norm_backward);
|
| 51 |
|
| 52 |
+
// grouped_poly_norm (without scores, hidden_clamp < 0 = disabled)
|
| 53 |
ops.def("grouped_poly_norm_forward("
|
| 54 |
"Tensor input, Tensor mul, Tensor weight, "
|
| 55 |
"Tensor bias, Tensor offsets, "
|
| 56 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor)");
|
| 57 |
ops.impl("grouped_poly_norm_forward", torch::kCUDA,
|
| 58 |
&grouped_poly_norm_forward);
|
| 59 |
|
| 60 |
ops.def("grouped_poly_norm_backward("
|
| 61 |
"Tensor grad_output, Tensor input, Tensor mul, Tensor weight, "
|
| 62 |
"Tensor bias, Tensor offsets, Tensor inv_rms, "
|
| 63 |
+
"float eps, int expert_offset, float hidden_clamp) -> (Tensor, Tensor, Tensor, Tensor)");
|
| 64 |
ops.impl("grouped_poly_norm_backward", torch::kCUDA,
|
| 65 |
&grouped_poly_norm_backward);
|
| 66 |
|
torch-ext/torch_binding.h
CHANGED
|
@@ -36,19 +36,21 @@ std::tuple<torch::Tensor, torch::Tensor> fused_add_rms_norm_backward(
|
|
| 36 |
const torch::Tensor &input, const torch::Tensor &weight, double eps,
|
| 37 |
bool need_input_grad);
|
| 38 |
|
| 39 |
-
// Without scores
|
| 40 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 41 |
grouped_poly_norm_forward(
|
| 42 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 43 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 44 |
-
const torch::Tensor &offsets, double eps, int64_t expert_offset
|
|
|
|
| 45 |
|
| 46 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 47 |
grouped_poly_norm_backward(
|
| 48 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 49 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 50 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 51 |
-
const torch::Tensor &inv_rms, double eps, int64_t expert_offset
|
|
|
|
| 52 |
|
| 53 |
// With scores (hidden_clamp < 0 = disabled)
|
| 54 |
std::tuple<torch::Tensor, torch::Tensor>
|
|
|
|
| 36 |
const torch::Tensor &input, const torch::Tensor &weight, double eps,
|
| 37 |
bool need_input_grad);
|
| 38 |
|
| 39 |
+
// Without scores (hidden_clamp < 0 = disabled)
|
| 40 |
std::tuple<torch::Tensor, torch::Tensor>
|
| 41 |
grouped_poly_norm_forward(
|
| 42 |
const torch::Tensor &input, const torch::Tensor &mul,
|
| 43 |
const torch::Tensor &weight, const torch::Tensor &bias,
|
| 44 |
+
const torch::Tensor &offsets, double eps, int64_t expert_offset,
|
| 45 |
+
double hidden_clamp);
|
| 46 |
|
| 47 |
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 48 |
grouped_poly_norm_backward(
|
| 49 |
const torch::Tensor &grad_output, const torch::Tensor &input,
|
| 50 |
const torch::Tensor &mul, const torch::Tensor &weight,
|
| 51 |
const torch::Tensor &bias, const torch::Tensor &offsets,
|
| 52 |
+
const torch::Tensor &inv_rms, double eps, int64_t expert_offset,
|
| 53 |
+
double hidden_clamp);
|
| 54 |
|
| 55 |
// With scores (hidden_clamp < 0 = disabled)
|
| 56 |
std::tuple<torch::Tensor, torch::Tensor>
|