diff --git a/README.md b/README.md
index b6bbc514556a9a3dc0e55e3badf6babe628c99e8..dd173875e81233c3156bbb7baefe9c185990a709 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
```python
y = x + residual
hidden_state = rms_norm(y, weight, eps)
- out = y + some_op(hidden_state)
+ out = y + some_op(hidden_state)
```
- Fused as:
@@ -45,6 +45,22 @@ Activation is a python package that contains custom CUDA-based activation kernel
out = fused_mul_poly_norm(x, a, weight, bias, eps)
```
+ - **GroupedFusedMulPolyNorm** (Triton)
+
+ A Triton-accelerated grouped variant of FusedMulPolyNorm for **MoE (Mixture of Experts)** models. Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), with per-expert weights/bias and in-kernel binary search for expert mapping.
+ - Instead of:
+
+ ```python
+ for i, expert in enumerate(experts):
+ out[start:end] = fused_mul_poly_norm(x[start:end], mul[start:end], weight[i], bias[i], eps)
+ ```
+
+ - Fused as:
+
+ ```python
+ out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps)
+ ```
+
## Usage
```python
@@ -214,6 +230,118 @@ print(poly_norm(x))
+---
+
+### GroupedFusedMulPolyNorm (Triton)
+
+> [!NOTE]
+> This kernel is implemented in Triton (JIT-compiled, no CUDA C++ build required).
+> Benchmarks compare three variants: **Naive** (raw PyTorch reference), **Compiled** (`torch.compile`'d reference), and **Triton** (fused Triton kernel).
+> Benchmark configuration: hidden dimension 1280, with 384 experts.
+
+#### B200 Results (bf16)
+
+
+Forward Performance
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x |
+| 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x |
+| 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x |
+| 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x |
+| 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x |
+| 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x |
+| 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x |
+| 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x |
+| 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x |
+| 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x |
+
+
+
+
+Backward Performance
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x |
+| 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x |
+| 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x |
+| 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x |
+| 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x |
+| 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x |
+| 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x |
+
+
+
+
+Forward + Backward Combined
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
+|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
+| 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x |
+| 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x |
+| 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x |
+| 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x |
+| 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x |
+| 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x |
+| 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x |
+| 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x |
+| 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x |
+
+
+
+#### B200 Results (fp32)
+
+
+Forward Performance
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x |
+| 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x |
+| 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x |
+| 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x |
+| 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x |
+| 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x |
+| 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x |
+| 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x |
+| 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x |
+
+
+
+
+Backward Performance
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive |
+|-----------|---------|-----------|--------------|------------|-----------------|
+| 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x |
+| 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x |
+| 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x |
+| 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x |
+| 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x |
+| 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x |
+| 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x |
+
+
+
+
+Forward + Backward Combined
+
+| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled |
+|-----------|---------|-----------|--------------|------------|-----------------|-------------------|
+| 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x |
+| 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x |
+| 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x |
+| 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x |
+| 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x |
+| 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x |
+| 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x |
+| 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x |
+| 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x |
+
+
+
## Pre-commit Hooks
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
diff --git a/benchmarks/cases/grouped_mul_poly.py b/benchmarks/cases/grouped_mul_poly.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1268f5595ba6ccce9ea495c31a50b21d31314e
--- /dev/null
+++ b/benchmarks/cases/grouped_mul_poly.py
@@ -0,0 +1,122 @@
+import torch
+import torch._functorch.config
+from common.diff_engine import DiffCase
+
+torch._functorch.config.donated_buffer = False
+
+from grouped_poly_norm import (
+ grouped_fused_mul_poly_norm,
+ grouped_fused_mul_poly_norm_ref,
+)
+
+NUM_EXPERTS = 384
+
+
+class GroupedRefModule(torch.nn.Module):
+ """Wraps the PyTorch reference for grouped FusedMulPolyNorm."""
+
+ def __init__(self, weight, bias, offsets, eps, expert_offset=0):
+ super().__init__()
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias)
+ self.offsets = offsets
+ self.eps = eps
+ self.expert_offset = expert_offset
+
+ def forward(self, x, mul):
+ return grouped_fused_mul_poly_norm_ref(x, mul, self.weight, self.bias,
+ self.offsets, self.eps,
+ expert_offset=self.expert_offset)
+
+
+class GroupedTritonModule(torch.nn.Module):
+ """Wraps the Triton kernel for grouped FusedMulPolyNorm."""
+
+ def __init__(self, weight, bias, offsets, eps, expert_offset=0):
+ super().__init__()
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias)
+ self.offsets = offsets
+ self.eps = eps
+ self.expert_offset = expert_offset
+
+ def forward(self, x, mul):
+ return grouped_fused_mul_poly_norm(x, mul, self.weight, self.bias,
+ self.offsets, self.eps,
+ expert_offset=self.expert_offset)
+
+
+class GroupedMulPoly(DiffCase):
+ """Benchmark case for Grouped FusedMulPolyNorm (MoE).
+
+ Maps the framework's (bs, sl, hidden) to grouped polynorm's
+ (total_tokens, D) where total_tokens = bs * sl.
+ Uses a fixed number of experts with uniform token distribution.
+ """
+
+ def build_inputs(self, bs, sl, hidden, dtype, eps):
+ total_tokens = bs * sl
+ num_experts = min(NUM_EXPERTS, total_tokens)
+
+ torch.manual_seed(42)
+ probs = torch.ones(num_experts) / num_experts
+ assignments = torch.multinomial(probs, total_tokens, replacement=True)
+ counts = torch.bincount(assignments, minlength=num_experts).tolist()
+ offsets = torch.cumsum(
+ torch.tensor(counts, dtype=torch.int32), dim=0)
+
+ return {
+ "x":
+ torch.randn(total_tokens, hidden, dtype=dtype,
+ requires_grad=True) * 0.5,
+ "mul":
+ torch.randn(total_tokens, hidden, dtype=dtype,
+ requires_grad=True) * 0.5,
+ "weight":
+ torch.ones(num_experts, 3, dtype=dtype) / 3 +
+ torch.randn(num_experts, 3, dtype=dtype) * 0.01,
+ "bias":
+ torch.randn(num_experts, 1, dtype=dtype) * 0.01,
+ "offsets":
+ offsets,
+ "dim":
+ hidden,
+ "eps":
+ eps,
+ "dtype":
+ dtype,
+ }
+
+ def make_naive(self, I):
+ return GroupedRefModule(
+ I["weight"].detach().clone(),
+ I["bias"].detach().clone(),
+ I["offsets"],
+ I["eps"],
+ )
+
+ def make_compiled(self, I):
+ m = GroupedRefModule(
+ I["weight"].detach().clone(),
+ I["bias"].detach().clone(),
+ I["offsets"],
+ I["eps"],
+ )
+ return torch.compile(m)
+
+ def make_cuda(self, I):
+ return GroupedTritonModule(
+ I["weight"].detach().clone(),
+ I["bias"].detach().clone(),
+ I["offsets"],
+ I["eps"],
+ )
+
+ def forward(self, obj, I):
+ return obj(I["x"], I["mul"])
+
+ def grad_inputs(self, I):
+ return [I["x"], I["mul"]]
+
+
+CASE = GroupedMulPoly()
diff --git a/benchmarks/common/bench_framework.py b/benchmarks/common/bench_framework.py
index 49dfe3c7deb1cd6595a2d411a0e17615d8f99e3b..3837703033e37ac9493da568a6a02916ccb4d934 100644
--- a/benchmarks/common/bench_framework.py
+++ b/benchmarks/common/bench_framework.py
@@ -57,7 +57,12 @@ def make_fwd_benchmark_for_case(
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
if provider == "speedup":
return timings_ms["naive"][key] / timings_ms["cuda"][key]
- obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ if provider == "naive":
+ obj = case.make_naive(I)
+ elif provider == "compiled" and hasattr(case, "make_compiled"):
+ obj = case.make_compiled(I)
+ else:
+ obj = case.make_cuda(I)
run = lambda: case.forward(obj, I)
ms = triton.testing.do_bench(run)
timings_ms[provider][key] = ms
@@ -101,7 +106,12 @@ def make_fwd_benchmark_plot_for_case(
return 1.00
batch_size, seq_len, dim = parse_config_string(config)
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
- obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ if provider == "naive":
+ obj = case.make_naive(I)
+ elif provider == "compiled" and hasattr(case, "make_compiled"):
+ obj = case.make_compiled(I)
+ else:
+ obj = case.make_cuda(I)
run = lambda: case.forward(obj, I)
ms = triton.testing.do_bench(run)
timings_ms[provider][config] = ms
@@ -146,7 +156,12 @@ def make_bwd_benchmark_for_case(
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
if provider == "speedup":
return timings_ms["naive"][key] / timings_ms["cuda"][key]
- obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ if provider == "naive":
+ obj = case.make_naive(I)
+ elif provider == "compiled" and hasattr(case, "make_compiled"):
+ obj = case.make_compiled(I)
+ else:
+ obj = case.make_cuda(I)
y = case.forward(obj, I)
gin = list(case.grad_inputs(I)) + list(obj.parameters())
if isinstance(y, torch.Tensor):
@@ -201,7 +216,12 @@ def make_bwd_benchmark_plot_for_case(
return 1.00
batch_size, seq_len, dim = parse_config_string(config)
I = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
- obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
+ if provider == "naive":
+ obj = case.make_naive(I)
+ elif provider == "compiled" and hasattr(case, "make_compiled"):
+ obj = case.make_compiled(I)
+ else:
+ obj = case.make_cuda(I)
y = case.forward(obj, I)
gin = list(case.grad_inputs(I)) + list(obj.parameters())
if isinstance(y, torch.Tensor):
diff --git a/benchmarks/run_cases.py b/benchmarks/run_cases.py
index e2e1f746ccf14aed82182c482dc39c05fe03816e..5c4a718b74cf4f1063d806839442d0d83f7e8fdb 100644
--- a/benchmarks/run_cases.py
+++ b/benchmarks/run_cases.py
@@ -23,12 +23,15 @@ def make_title_tag():
return f"[{dev_name} | torch {torch_ver}]"
-def plot_result(r_path):
+def plot_result(r_path, columns=None):
import matplotlib.pyplot as plt
import pandas as pd
df = pd.read_csv(r_path + ".csv")
+ if columns is None:
+ columns = [c for c in ["Naive", "Compiled", "Cuda", "Triton"]
+ if c in df.columns]
plt.figure(figsize=(12, 6))
- ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca())
+ ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca())
ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
fontsize=14,
fontweight="bold")
@@ -44,9 +47,10 @@ def plot_result(r_path):
def main():
ap = argparse.ArgumentParser()
- ap.add_argument("--case",
- choices=["rms", "add_rms", "poly", "mul_poly"],
- required=True)
+ ap.add_argument(
+ "--case",
+ choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"],
+ required=True)
ap.add_argument("--plot", action="store_true")
ap.add_argument(
"--save-path",
@@ -54,8 +58,25 @@ def main():
default="./configs/",
help="Path to save benchmark results",
)
+ ap.add_argument(
+ "--dtype",
+ choices=["fp16", "bf16", "fp32", "all"],
+ default="bf16",
+ help="Data type for benchmarking (default: bf16)",
+ )
args = ap.parse_args()
+ dtype_map = {
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+ "fp32": torch.float32,
+ }
+ if args.dtype == "all":
+ dtypes = [("fp16", torch.float16), ("bf16", torch.bfloat16),
+ ("fp32", torch.float32)]
+ else:
+ dtypes = [(args.dtype, dtype_map[args.dtype])]
+
torch.set_default_device("cuda")
mod = importlib.import_module(f"cases.{args.case}")
case: DiffCase = mod.CASE
@@ -67,76 +88,118 @@ def main():
hidden_size=4096,
)
- save_dir = os.path.join(args.save_path, args.case)
- if args.plot:
- batch_size_range = [1]
- seq_length_range = [4096, 8192, 16384]
- dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
- configs = list(
- itertools.product(batch_size_range, seq_length_range, dim))
- plot_name = f"plot_{args.case}-fwd-perf"
- bench = make_fwd_benchmark_plot_for_case(
- case=case,
- configs=configs,
- plot_name=plot_name,
- line_names={
- "naive": "Naive",
- "cuda": "Cuda",
- },
- )
- bench.run(print_data=True, save_path=save_dir)
- plot_result(os.path.join(save_dir, plot_name))
-
- plot_name = f"plot_{args.case}-bwd-perf"
- bench = make_bwd_benchmark_plot_for_case(
- case=case,
- configs=configs,
- plot_name=plot_name,
- line_names={
- "naive": "Naive",
- "cuda": "Cuda",
- },
- )
- bench.run(print_data=True, save_path=save_dir)
- plot_result(os.path.join(save_dir, plot_name))
- for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
- os.path.join(save_dir, "*.csv")):
- os.remove(f)
- else:
- batch_size_range = [2**i for i in range(0, 4, 1)]
- seq_length_range = [2**i for i in range(10, 14, 1)]
- dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
- configs = list(
- itertools.product(dim, batch_size_range, seq_length_range))
-
- bench = make_fwd_benchmark_for_case(
- case=case,
- configs=configs,
- plot_name=f"{args.case}-fwd-perf",
- line_names={
- "naive": "Naive",
- "cuda": "Cuda",
- "speedup": "SpeedUp"
- },
- )
-
- bench.run(print_data=True, save_path=save_dir)
-
- bench = make_bwd_benchmark_for_case(
- case=case,
- configs=configs,
- plot_name=f"{args.case}-bwd-perf",
- line_names={
- "naive": "Naive",
- "cuda": "Cuda",
- "speedup": "SpeedUp"
- },
- )
-
- bench.run(print_data=True, save_path=save_dir)
- for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
- os.path.join(save_dir, "*.png")):
- os.remove(f)
+ for dtype_name, dtype in dtypes:
+ print(f"\n{'=' * 60}")
+ print(f" Benchmarking dtype: {dtype_name} ({dtype})")
+ print(f"{'=' * 60}\n")
+
+ save_dir = os.path.join(args.save_path, args.case, dtype_name)
+ is_grouped = args.case == "grouped_mul_poly"
+
+ if args.plot:
+ batch_size_range = [1]
+ seq_length_range = [4096, 8192, 16384]
+ if is_grouped:
+ dim = [1280]
+ elif "poly" in args.case:
+ dim = [8192, 16384]
+ else:
+ dim = [2048, 4096]
+ configs = list(
+ itertools.product(batch_size_range, seq_length_range, dim))
+
+ if is_grouped:
+ plot_line_vals = ("naive", "compiled", "cuda")
+ plot_line_names = {
+ "naive": "Naive",
+ "compiled": "Compiled",
+ "cuda": "Triton",
+ }
+ else:
+ plot_line_vals = ("naive", "cuda")
+ plot_line_names = {
+ "naive": "Naive",
+ "cuda": "Cuda",
+ }
+
+ plot_name = f"plot_{args.case}-{dtype_name}-fwd-perf"
+ bench = make_fwd_benchmark_plot_for_case(
+ case=case,
+ configs=configs,
+ plot_name=plot_name,
+ dtype=dtype,
+ line_vals=plot_line_vals,
+ line_names=plot_line_names,
+ )
+ bench.run(print_data=True, save_path=save_dir)
+ plot_result(os.path.join(save_dir, plot_name))
+
+ plot_name = f"plot_{args.case}-{dtype_name}-bwd-perf"
+ bench = make_bwd_benchmark_plot_for_case(
+ case=case,
+ configs=configs,
+ plot_name=plot_name,
+ dtype=dtype,
+ line_vals=plot_line_vals,
+ line_names=plot_line_names,
+ )
+ bench.run(print_data=True, save_path=save_dir)
+ plot_result(os.path.join(save_dir, plot_name))
+ for f in glob.glob(os.path.join(save_dir, "*.html")) + \
+ glob.glob(os.path.join(save_dir, "*.csv")):
+ os.remove(f)
+ else:
+ batch_size_range = [2**i for i in range(0, 4, 1)]
+ seq_length_range = [2**i for i in range(10, 14, 1)]
+ if is_grouped:
+ dim = [1280]
+ elif "poly" in args.case:
+ dim = [8192, 16384]
+ else:
+ dim = [2048, 4096]
+ configs = list(
+ itertools.product(dim, batch_size_range, seq_length_range))
+
+ if is_grouped:
+ csv_line_vals = ("naive", "compiled", "cuda", "speedup")
+ csv_line_names = {
+ "naive": "Naive",
+ "compiled": "Compiled",
+ "cuda": "Triton",
+ "speedup": "SpeedUp",
+ }
+ else:
+ csv_line_vals = ("naive", "cuda", "speedup")
+ csv_line_names = {
+ "naive": "Naive",
+ "cuda": "Cuda",
+ "speedup": "SpeedUp",
+ }
+
+ bench = make_fwd_benchmark_for_case(
+ case=case,
+ configs=configs,
+ plot_name=f"{args.case}-{dtype_name}-fwd-perf",
+ dtype=dtype,
+ line_vals=csv_line_vals,
+ line_names=csv_line_names,
+ )
+
+ bench.run(print_data=True, save_path=save_dir)
+
+ bench = make_bwd_benchmark_for_case(
+ case=case,
+ configs=configs,
+ plot_name=f"{args.case}-{dtype_name}-bwd-perf",
+ dtype=dtype,
+ line_vals=csv_line_vals,
+ line_names=csv_line_names,
+ )
+
+ bench.run(print_data=True, save_path=save_dir)
+ for f in glob.glob(os.path.join(save_dir, "*.html")) + \
+ glob.glob(os.path.join(save_dir, "*.png")):
+ os.remove(f)
if __name__ == "__main__":
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
+++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..caa5f9e4c72751c1cc34a718812babba0149ff6f
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f31dfeac9b22c01a027f858b3d8beaf87eea9adf8dc45902f0e43d6c264fd985
+size 10775296
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index da939124912cd385bcb4d1e02878ecf9ffe53ad8..0000000000000000000000000000000000000000
--- a/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:39a7e25002120a73ea83ac813276c0518086fae2236f528dadf96bac4876a270
-size 10775296
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
+++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon; expert_offset: offset added to each token's expert index when indexing into weight/bias
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        # Autotuning re-runs the kernel several times; grad buffers are
        # accumulated with atomic_add, so they must be zeroed between trials
        # or the benchmark runs would double-count into the real gradients.
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        All math is done in float32 regardless of the storage dtype.
        stride_row is shared by every row-major (N, D) tensor here because the
        host side makes them all contiguous.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search over the cumulative-token-count table to find this
        # row's expert. Fixed 12-iteration loop (2^12 = 4096 experts max) with
        # an internal `lo < hi` guard — equivalent to `while lo < hi` but with
        # a static trip count.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters; weight layout is (num_experts, 3).
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # The whole row fits in one tile: load once, keep everything in
        # registers, no second pass over global memory.
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0, so they contribute nothing to the sums.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # dL/d(poly) per element: output = poly * m.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: dL/dw_k = sum(dpoly * x^p) * inv_rms of x^p.
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input — the "mean correction" terms that
            # come from differentiating the 1/rms factor w.r.t. each x_i.
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul = go * poly (the full forward poly, bias included).
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Per-expert accumulation across rows must be atomic.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (mean-correction terms from the
            # derivative of each 1/rms factor).
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul (inputs are re-read because the
            # row does not fit in registers).
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Per-expert accumulation across rows must be atomic.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
+++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..92030076050496a09036e47299e309717014fe79
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b8370d2e1f5561ae4b77ac8ae7b3a084e33a0d1952a8f5f9bf4700375313b35
+size 15815392
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index aed6525c923f5f2a1825b682c75d3ad4b1dfa245..0000000000000000000000000000000000000000
--- a/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:078853c2db399a227822ea0c8e70c2e13bad41bfa370657dd19aa2efb3b503e9
-size 15815392
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    # Triton is optional: without it the module still imports, and the public
    # entry point is replaced by a stub that raises RuntimeError at call time.
    HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
def _rms_norm(x: Tensor, eps: float) -> Tensor:
    """Normalize each row of ``x`` by its root-mean-square (plus ``eps``)."""
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    return x / torch.sqrt(mean_sq + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert with torch.bucketize, gathers that
    expert's polynomial weights/bias, and evaluates
    ``w0*rms(x^3) + w1*rms(x^2) + w2*rms(x) + b`` times ``mul`` for all
    tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to the bucketized expert index

    Returns:
        (total_tokens, D) - output tensor
    """

    def _rn(v: Tensor) -> Tensor:
        # Per-row RMS normalization (inlined module helper).
        return v / torch.sqrt(v.pow(2).mean(-1, keepdim=True) + eps)

    out_dtype = input.dtype

    positions = torch.arange(input.shape[0], device=input.device)
    expert_ids = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather each token's expert parameters; math runs in fp32 throughout.
    per_token_w = weight.float()[expert_ids]
    per_token_b = bias.float()[expert_ids]

    x = input.float()
    gate = mul.float()
    x_sq = x * x
    x_cu = x_sq * x

    poly = (per_token_w[:, 0:1] * _rn(x_cu)
            + per_token_w[:, 1:2] * _rn(x_sq)
            + per_token_w[:, 2:3] * _rn(x)
            + per_token_b)

    return (poly * gate).to(out_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations searched by
    # @triton.autotune, keyed on the hidden dimension D. BLOCK_D also selects
    # the code path inside each kernel: rows with D <= BLOCK_D take the
    # single-tile (fully in-register) path, larger rows take the tiled loops.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Backward search space: similar shape, slightly different spread of
    # num_stages (the backward kernel is heavier on registers and reductions).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        Resolves the row's expert by binary search over `offsets_ptr`
        (cumulative token counts), then evaluates
        ``poly = w0*x^3/rms(x^3) + w1*x^2/rms(x^2) + w2*x/rms(x) + b`` and
        stores ``poly * mul``. All math is done in float32.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts).
        # Fixed trip count with an internal `lo < hi` guard — equivalent to
        # `while lo < hi` but with static loop bounds.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters; weight layout is (num_experts, 3).
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one tile: one load, everything stays in registers.
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0 so they contribute nothing to the sums.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # mean(x^2), mean(x^4), mean(x^6) give the RMS of x, x^2, x^3.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1 accumulates the three power sums for the RMS stats.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2 re-reads the input, applies the polynomial, and writes.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        # Autotuning re-runs the kernel several times; grad buffers are
        # accumulated with atomic_add, so they must be zeroed between trials
        # or the benchmark runs would double-count into the real gradients.
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        All math is done in float32 regardless of the storage dtype.
        stride_row is shared by every row-major (N, D) tensor here because the
        host side makes them all contiguous.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search over the cumulative-token-count table to find this
        # row's expert. Fixed 12-iteration loop (2^12 = 4096 experts max) with
        # an internal `lo < hi` guard — equivalent to `while lo < hi` but with
        # a static trip count.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters; weight layout is (num_experts, 3).
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one tile: load once, keep everything in registers,
        # no second pass over global memory.
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0 so they contribute nothing to the sums.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # dL/d(poly) per element: output = poly * m.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: dL/dw_k = sum(dpoly * x^p) * inv_rms of x^p.
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input — the "mean correction" terms that
            # come from differentiating the 1/rms factor w.r.t. each x_i.
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul = go * poly (the full forward poly, bias included).
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Per-expert accumulation across rows must be atomic.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (mean-correction terms from the
            # derivative of each 1/rms factor).
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul (inputs are re-read because the
            # row does not fit in registers).
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Per-expert accumulation across rows must be atomic.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
+++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..1136e14de3dcfc6580e54891538ea24e6804e1e0
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edf3fca2079788750c4e0497012ba93c34c770aca4c9d4f22d03be4a86a2ce8c
+size 13520952
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 56e86ea0e950d091490d62c3eadf538156cd6fb3..0000000000000000000000000000000000000000
--- a/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:59e2c13071e1807a6225c5ad7a4a7eb04d46b1f177ae6344d199a9e7f14daf92
-size 13520952
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+  - Applies polynomial weights + bias + mul in one fused store to the output tensor
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+  - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul; weight/bias grads (from pass-1 sums) flushed once per row via atomic_add
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias grad accumulation
+        Pass 2: grad_input + grad_mul; weight/bias grads (from pass-1 sums) flushed via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py
+++ b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..e0be0ec0d8372be979ea7721acfe48e28ca9aa30
--- /dev/null
+++ b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f8d5a173c51cb2dabe3554da743aed307c04b5d51c9d0d460a8fa5a821b5495
+size 2919488
diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 89cd6d1522af7bfe8b53510cd92a7a9009a7b75d..0000000000000000000000000000000000000000
--- a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:45ff2b71abb33d840d92116980e519786ed06f1e337d681d0e3301dba241ff63
-size 2919488
diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py
+++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+  - Applies polynomial weights + bias + mul in one fused store to the output tensor
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+  - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul; weight/bias grads (from pass-1 sums) flushed once per row via atomic_add
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias grad accumulation
+        Pass 2: grad_input + grad_mul; weight/bias grads (from pass-1 sums) flushed via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py
+++ b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..4401763052dc74bcd65e72216e2dd7bc239d7f91
--- /dev/null
+++ b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28e6de9c3cd172e95284b0df0e15a2afb21ab7b89dd624e69b1361942095e8be
+size 2911200
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 262fc74841f5f4af7bb17361e9028081723f235e..0000000000000000000000000000000000000000
--- a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:af4db38e8d5ad56226f5a95a86c2b5fc726bd9d576d07df2f07d3f03c1b6b35b
-size 2911200
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py
+++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to the bucketized expert index

    Returns:
        (total_tokens, D) - output tensor, in the dtype of `input`
    """
    orig_dtype = input.dtype

    # torch.bucketize is strict about value/boundary dtype agreement;
    # `offsets` is documented as int32 while torch.arange defaults to
    # int64, so align the positions to the offsets dtype explicitly.
    token_positions = torch.arange(input.shape[0],
                                   device=input.device,
                                   dtype=offsets.dtype)
    # right=True counts boundaries <= position, matching the kernel's
    # in-kernel binary search over cumulative token counts.
    expert_idx = torch.bucketize(token_positions, offsets,
                                 right=True) + expert_offset

    # Gather per-token parameters in fp32 for numerically stable math.
    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    m = mul.float()

    x2 = x * x
    x3 = x2 * x

    def _norm(t: Tensor) -> Tensor:
        # Per-row RMS normalization: t / sqrt(mean(t^2, dim=-1) + eps).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    poly = (per_token_w[:, 0:1] * _norm(x3) +
            per_token_w[:, 1:2] * _norm(x2) +
            per_token_w[:, 2:3] * _norm(x) + per_token_b)

    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations searched by
    # @triton.autotune, keyed on the hidden dimension D. Large BLOCK_D
    # values enable the kernels' single-tile fast path (D <= BLOCK_D);
    # smaller blocks with more stages favor pipelined multi-tile loops.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Separate sweep for the backward kernel, which holds more live state
    # (three RMS stats plus four dot-product accumulators per row).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        With this row's expert parameters (w0, w1, w2, b):

            out = (w0 * rms_norm(x^3) + w1 * rms_norm(x^2)
                   + w2 * rms_norm(x) + b) * mul

        Each rms_norm reduces over the row's D elements, so the three
        inverse-RMS scalars are computed once per row and folded into the
        weights before the elementwise output pass.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts).
        # Invariant: `lo` ends as the count of offsets <= row, i.e. the index
        # of the expert owning this row. The fixed trip count avoids a
        # data-dependent while loop.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32 for the math.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path: whole row fits in one block (one load) ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # mean(x^(2k)) = sum / D; masked lanes load 0 and don't
            # perturb the sums.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate the power sums feeding the three RMS terms.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: reload tiles and emit the weighted polynomial output.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        With dpoly = grad_out * mul and r_k = 1/sqrt(mean(x^(2k)) + eps),
        the forward is out = (sum_k w_k * x^k * r_k + b) * mul, so
        grad_input combines the direct term w_k * dpoly * r_k with the
        correction from d(r_k)/dx, which appears below as the coeff_*
        terms.  `reset_to_zero` keeps autotuning correct: the atomic
        grad accumulators are cleared between timed candidate runs.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search over cumulative token counts: `lo` ends as the
        # number of offsets <= row, i.e. the expert owning this row.
        # A fixed 12-iteration loop (covers 2^12 = 4096 experts) avoids a
        # data-dependent while loop.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32 for the math.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        # All five tensors share one row stride (the host-side wrapper
        # makes them contiguous before launch).
        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path: the whole row fits in one block, so every
        # value is loaded exactly once and stays in registers. ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure);
            # masked lanes load 0.0 and so do not perturb the sums.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Upstream gradient of the polynomial (before the mul gate).
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: d(out)/d(w_k) = x^k * r_k * mul, summed over row
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (the d(r_k)/dx correction terms)
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul: d(out)/d(mul) = poly + bias
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Rows of the same expert race on these slots, hence atomics.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul (tiles reloaded from global mem)
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                # d(out)/d(mul) = poly + bias
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into per-expert slots must be atomic.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..7eeea2956f47f94084a31c80fd20218dd4e9a078
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51ac098828ee90af0d1d17ae75326f89777b1b2e7ef57e00035aed560c434a20
+size 10756352
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index e5607fd3688771cb7bfea8abd9f2808ee6c8ca59..0000000000000000000000000000000000000000
--- a/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0cc858ef4dce1d14e3a74f5f4a17a2d6c6a8c54cba436f938449c859dd84c3b1
-size 10756352
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+ Pass 1: RMS stats + dot products + bias grad
+ Pass 2: grad_input + grad_mul; weight/bias grads (from pass-1 sums) then stored via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..c6a85ff536b76cbee70f5e6051d31b7962f47d22
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7edb027993454d74da9632c7368edc2b0526b5f1ef33ae9e790d49bdf7285640
+size 15804360
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index cbf9fb816fa3782bb93924b7dcf1dc8143c09f3a..0000000000000000000000000000000000000000
--- a/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3f3e704c3833d0d41cf43549ddad15c36c67e4b80b0b2a7af5c6a9a2c488690b
-size 15804360
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul, writing the result to a separate output tensor
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+ Pass 1: RMS stats + dot products + bias grad
+ Pass 2: grad_input + grad_mul; weight/bias grads (from pass-1 sums) then stored via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd glue that launches the grouped PolyNorm Triton kernels.

    Saves the (contiguous) inputs and lets the backward kernel recompute
    forward intermediates instead of storing them (activation
    recomputation).
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        # input/mul: (N, D); weight: (num_experts, 3); bias: (num_experts, 1);
        # offsets: (num_experts,) cumulative token counts.
        N, D = input.shape
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The kernels locate a row's expert with a fixed 12-step binary
        # search, which covers at most 2**12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        # One Triton program per row; each program tiles over D internally.
        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous views so backward can use a single row stride.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Weight/bias grads are accumulated across rows via fp32 atomics in
        # the kernel, so the buffers must start zeroed and be fp32
        # regardless of the parameter dtype.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias has shape (num_experts, 1); restore the trailing dim.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # One entry per forward arg; offsets/eps/expert_offset get None.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor

    Raises:
        ValueError: if the tensor shapes are inconsistent.
    """
    # Fail fast with a clear message instead of letting the kernel read out
    # of bounds or silently produce garbage on malformed inputs.
    if input.ndim != 2:
        raise ValueError(f"input must be 2D (tokens, D), got {input.ndim}D")
    if mul.shape != input.shape:
        raise ValueError(
            f"mul shape {tuple(mul.shape)} must match input shape "
            f"{tuple(input.shape)}")
    if weight.ndim != 2 or weight.shape[-1] != 3:
        raise ValueError(
            f"weight must have shape (num_experts, 3), got "
            f"{tuple(weight.shape)}")
    if offsets.ndim != 1:
        raise ValueError(f"offsets must be 1D, got {offsets.ndim}D")
    return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
                                    expert_offset)
+
else:

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        # Stub installed when `import triton` failed at module load; keeps
        # the module importable on hosts without Triton while giving a
        # clear error at call time.
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..44426f622e2e75950aa71623220360e652ef8249
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f06fb9594dcf9c0bc3a8af619fec3a541b775f0dad304a7978314d32fae8d244
+size 15795640
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 77b558a4079bd17e0c081a80c0ca9891f4807ebb..0000000000000000000000000000000000000000
--- a/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f589ae5cb26c7fb85fac22ee789749c54897bf08b170dbf77b9d62eb98ee8b53
-size 15795640
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: constant added to every expert index (e.g. when
            weight/bias hold a global expert table but offsets describe
            only a local shard)

    Returns:
        (total_tokens, D) - output tensor
    """
    orig_dtype = input.dtype

    # right=True maps token p to the first expert i with offsets[i] > p,
    # i.e. offsets[i-1] <= p < offsets[i] — same mapping as the kernel's
    # binary search.
    token_positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset

    # Compute in fp32 for numerical parity with the Triton kernel, casting
    # back to the caller's dtype at the end.
    weight_fp32 = weight.float()
    bias_fp32 = bias.float()

    per_token_w = weight_fp32[expert_idx]
    per_token_b = bias_fp32[expert_idx]

    x = input.float()
    m = mul.float()

    x2 = x * x
    x3 = x2 * x

    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
            per_token_w[:, 1:2] * _rms_norm(x2, eps) +
            per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)

    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations;
    # @triton.autotune benchmarks them once per distinct D (the autotune
    # key) and caches the fastest.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # NOTE(review): the backward list differs slightly from the forward one
    # (extra deep-pipeline configs at 512/1024, fewer 2048 variants) —
    # presumably tuned for the backward kernel's higher register pressure.
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_FWD_CONFIGS,
    key=["D"],
)
@triton.jit
def _grouped_polynorm_fwd_kernel(
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    output_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_input_row,
    stride_mul_row,
    stride_out_row,
    BLOCK_D: tl.constexpr,
):
    """Forward kernel: one program per row.

    out[row] = (w0 * rms_norm(x^3) + w1 * rms_norm(x^2) + w2 * rms_norm(x)
                + b) * mul[row], with per-expert (w, b) located via binary
    search over the cumulative token offsets.
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search for expert index (12 iters covers up to 4096 experts):
    # finds the smallest index with offsets[idx] > row, matching
    # torch.bucketize(..., right=True).
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b = tl.load(bias_ptr + eidx).to(tl.float32)

    input_row_ptr = input_ptr + row * stride_input_row
    mul_row_ptr = mul_ptr + row * stride_mul_row
    out_row_ptr = output_ptr + row * stride_out_row

    D_float = D.to(tl.float32)

    # --- Single-tile path: whole row fits in one block; data loaded once
    # stays in registers across the reduction and output phases. ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Three RMS reductions from one set of loads; masked lanes
        # contribute 0 to each sum.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
        tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
    else:
        # --- Multi-tile: two-pass approach ---
        # Pass 1: accumulate the power sums needed for the three RMS norms.
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        # Pre-multiply scalar weight * inv_rms
        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Pass 2: re-load x, apply the polynomial, gate by mul, store.
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            x2 = x * x
            x3 = x2 * x
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
@triton.autotune(
    configs=_GROUPED_POLYNORM_BWD_CONFIGS,
    key=["D"],
    # Autotuning replays the kernel several times; reset_to_zero re-zeroes
    # these atomic accumulators between trial runs so they are not
    # over-counted.
    reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
)
@triton.jit
def _grouped_polynorm_bwd_kernel(
    grad_out_ptr,
    input_ptr,
    mul_ptr,
    weight_ptr,
    bias_ptr,
    offsets_ptr,
    grad_input_ptr,
    grad_mul_ptr,
    grad_weight_ptr,
    grad_bias_ptr,
    N,
    D,
    num_experts,
    eps,
    expert_offset,
    stride_row,
    BLOCK_D: tl.constexpr,
):
    """Backward kernel: one program per row, 2-pass approach.

    Pass 1: RMS stats + dot products + bias grad
    Pass 2: grad_input + grad_mul (the per-expert weight/bias grads already
    reduced in pass 1 are then flushed cross-row via tl.atomic_add)
    """
    row = tl.program_id(0)
    if row >= N:
        return

    # Binary search over cumulative offsets for this row's expert: smallest
    # index with offsets[idx] > row (same mapping as the forward kernel and
    # torch.bucketize(..., right=True)). 12 iters cover 2**12 = 4096 experts.
    lo = 0
    hi = num_experts
    for _ in range(12):
        if lo < hi:
            mid = (lo + hi) // 2
            if tl.load(offsets_ptr + mid) <= row:
                lo = mid + 1
            else:
                hi = mid
    eidx = lo + expert_offset

    # Per-expert scalar parameters, promoted to fp32 for the math below.
    w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
    w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
    w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
    b_val = tl.load(bias_ptr + eidx).to(tl.float32)

    # The autograd wrapper makes all row-major tensors contiguous, so one
    # row stride serves input, mul, grad_out and both gradient outputs.
    input_row_ptr = input_ptr + row * stride_row
    mul_row_ptr = mul_ptr + row * stride_row
    grad_out_row_ptr = grad_out_ptr + row * stride_row
    grad_input_row_ptr = grad_input_ptr + row * stride_row
    grad_mul_row_ptr = grad_mul_ptr + row * stride_row

    D_float = D.to(tl.float32)

    # --- Single-tile path: the whole row fits in one block, so each value
    # is loaded once and reused for reductions and gradients. ---
    if D <= BLOCK_D:
        d_offs = tl.arange(0, BLOCK_D)
        mask = d_offs < D

        x = tl.load(input_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        m = tl.load(mul_row_ptr + d_offs, mask=mask,
                    other=0.0).to(tl.float32)
        go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                     other=0.0).to(tl.float32)

        x2 = x * x
        x3 = x2 * x

        # Compute RMS stats (x4 inlined to reduce register pressure);
        # masked lanes contribute 0 to the sums.
        inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # dL/dpoly: output = poly * m, so the polynomial grad is go * m.
        dpoly = go * m

        # Dot products for coefficients and weight grads
        sum_dpoly_x = tl.sum(dpoly * x)
        sum_dpoly_x2 = tl.sum(dpoly * x2)
        sum_dpoly_x3 = tl.sum(dpoly * x3)
        grad_b_acc = tl.sum(dpoly)

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input: correction terms that arise from
        # differentiating through each RMS denominator.
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # grad_mul = go * (poly + bias)
        poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
        tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

        # grad_input (in-place accumulation to reduce register pressure)
        g = inv_rms_x * (w2 * dpoly - x * coeff_x)
        g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
        g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
        tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row reduction of parameter grads into fp32 buffers.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
    else:
        # --- Multi-tile: 2-pass ---
        # Pass 1: RMS stats + dot products + bias grad
        sum_x2 = tl.zeros((), dtype=tl.float32)
        sum_x4 = tl.zeros((), dtype=tl.float32)
        sum_x6 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
        sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
        grad_b_acc = tl.zeros((), dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x
            dpoly = go * m

            sum_x2 += tl.sum(x2)
            sum_x4 += tl.sum(x2 * x2)
            sum_x6 += tl.sum(x2 * x2 * x2)
            sum_dpoly_x += tl.sum(dpoly * x)
            sum_dpoly_x2 += tl.sum(dpoly * x2)
            sum_dpoly_x3 += tl.sum(dpoly * x3)
            grad_b_acc += tl.sum(dpoly)

        inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
        inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
        inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

        w0_inv = w0 * inv_rms_x3
        w1_inv = w1 * inv_rms_x2
        w2_inv = w2 * inv_rms_x

        # Weight grads
        grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
        grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
        grad_w2_acc = inv_rms_x * sum_dpoly_x

        # Coefficients for grad_input (RMS-denominator correction terms)
        coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
        coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
        coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

        # Pass 2: grad_input + grad_mul
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + tl.arange(0, BLOCK_D)
            mask = d_offs < D
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs,
                     go * (poly + b_val),
                     mask=mask)

            dpoly = go * m
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += (2.0 * x * inv_rms_x2 *
                  (w1 * dpoly - x2 * coeff_x2))
            g += (3.0 * x2 * inv_rms_x3 *
                  (w0 * dpoly - x3 * coeff_x3))

            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

        # Cross-row reduction of parameter grads into fp32 buffers.
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
        tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
        tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
class _GroupedPolyNormFn(torch.autograd.Function):
    """Autograd glue that launches the grouped PolyNorm Triton kernels.

    Saves the (contiguous) inputs and lets the backward kernel recompute
    forward intermediates instead of storing them (activation
    recomputation).
    """

    @staticmethod
    def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
        # input/mul: (N, D); weight: (num_experts, 3); bias: (num_experts, 1);
        # offsets: (num_experts,) cumulative token counts.
        N, D = input.shape
        input = input.contiguous()
        mul = mul.contiguous()
        output = torch.empty_like(input)

        num_experts = offsets.shape[0]
        # The kernels locate a row's expert with a fixed 12-step binary
        # search, which covers at most 2**12 = 4096 experts.
        assert num_experts <= 4096, (
            f"Supports at most 4096 experts, got {num_experts}.")

        # One Triton program per row; each program tiles over D internally.
        _grouped_polynorm_fwd_kernel[(N,)](
            input,
            mul,
            weight,
            bias,
            offsets,
            output,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_input_row=input.stride(0),
            stride_mul_row=mul.stride(0),
            stride_out_row=output.stride(0),
        )

        # Save the contiguous views so backward can use a single row stride.
        ctx.save_for_backward(input, mul, weight, bias, offsets)
        ctx.eps = eps
        ctx.expert_offset = expert_offset
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, mul, weight, bias, offsets = ctx.saved_tensors
        eps = ctx.eps
        expert_offset = ctx.expert_offset
        N, D = input.shape

        grad_output = grad_output.contiguous()
        grad_input = torch.empty_like(input)
        grad_mul = torch.empty_like(mul)
        # Weight/bias grads are accumulated across rows via fp32 atomics in
        # the kernel, so the buffers must start zeroed and be fp32
        # regardless of the parameter dtype.
        grad_weight = torch.zeros(weight.shape[0],
                                  3,
                                  device=weight.device,
                                  dtype=torch.float32)
        grad_bias = torch.zeros(bias.shape[0],
                                device=bias.device,
                                dtype=torch.float32)

        num_experts = offsets.shape[0]

        _grouped_polynorm_bwd_kernel[(N,)](
            grad_output,
            input,
            mul,
            weight,
            bias,
            offsets,
            grad_input,
            grad_mul,
            grad_weight,
            grad_bias,
            N,
            D,
            num_experts,
            eps,
            expert_offset,
            stride_row=input.stride(0),
        )

        grad_weight = grad_weight.to(weight.dtype)
        # bias has shape (num_experts, 1); restore the trailing dim.
        grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

        # One entry per forward arg; offsets/eps/expert_offset get None.
        return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
def grouped_fused_mul_poly_norm(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """Triton-accelerated Grouped FusedMulPolyNorm.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset to add to expert index

    Returns:
        (total_tokens, D) - output tensor

    Raises:
        ValueError: if the tensor shapes are inconsistent.
    """
    # Fail fast with a clear message instead of letting the kernel read out
    # of bounds or silently produce garbage on malformed inputs.
    if input.ndim != 2:
        raise ValueError(f"input must be 2D (tokens, D), got {input.ndim}D")
    if mul.shape != input.shape:
        raise ValueError(
            f"mul shape {tuple(mul.shape)} must match input shape "
            f"{tuple(input.shape)}")
    if weight.ndim != 2 or weight.shape[-1] != 3:
        raise ValueError(
            f"weight must have shape (num_experts, 3), got "
            f"{tuple(weight.shape)}")
    if offsets.ndim != 1:
        raise ValueError(f"offsets must be 1D, got {offsets.ndim}D")
    return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
                                    expert_offset)
+
else:

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        # Stub installed when `import triton` failed at module load; keeps
        # the module importable on hosts without Triton while giving a
        # clear error at call time.
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..6941562f2f9dcf7d6d1ceeb9d4a3ad67beb6be04
--- /dev/null
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:448ab4d8a859725a3200d95d2164a3fe261f67b20e834f4f7062485cf729cf88
+size 2788456
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index a2aba7df18f60f2562bb4ca5f6a5dd8a15253c75..0000000000000000000000000000000000000000
--- a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e753cc4a2aaa76aea3b821b245d4da638d7c94059a660f8233e17a54df379813
-size 2788456
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert via torch.bucketize, then evaluates the
    PolyNorm polynomial for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: constant added to each expert index before indexing
            weight/bias (for when weight/bias hold a larger expert table)

    Returns:
        (total_tokens, D) - output tensor, cast back to input's dtype
    """
    out_dtype = input.dtype

    # Token i belongs to the first expert whose cumulative offset exceeds i.
    positions = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token fp32 parameters.
    w_rows = weight.float()[experts]
    b_rows = bias.float()[experts]

    def _norm(t: Tensor) -> Tensor:
        # Row-wise RMS normalization in fp32 (same formula as _rms_norm).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    x = input.float()
    x_sq = x * x
    x_cu = x_sq * x

    poly = (w_rows[:, 0:1] * _norm(x_cu) + w_rows[:, 1:2] * _norm(x_sq) +
            w_rows[:, 2:3] * _norm(x) + b_rows)

    return (poly * mul.float()).to(out_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations searched per
    # distinct hidden dimension D (the autotune key). Large BLOCK_D favors
    # the single-tile fast path in the kernels; num_stages > 1 enables
    # software pipelining of the multi-tile loops.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Backward search space: same shape as forward but with a few extra
    # deeper-pipeline variants (the backward kernel reads three row streams
    # per pass instead of two).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        For row r owned by expert e, writes:
            out = (w[e,0]*rms(x^3) + w[e,1]*rms(x^2) + w[e,2]*rms(x) + b[e]) * mul
        All arithmetic is carried out in fp32 regardless of input dtype.
        """
        row = tl.program_id(0)
        # Guard in case the launch grid is larger than N.
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        # Invariant on exit: lo == number of offsets <= row, which is the
        # owning expert since offsets is the cumsum of tokens per expert
        # (same result as torch.bucketize(row, offsets, right=True)).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one tile: load once, reduce and write straight
        # from registers (no second pass over global memory).
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # sum(x^2), sum(x^4), sum(x^6) give the RMS stats of x, x^2, x^3.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate the three variance statistics across tiles.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: re-load the row tile by tile and emit the output.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        Notes:
          - All row tensors share the single `stride_row`, i.e. the launcher
            is expected to pass contiguous tensors with identical layout.
          - grad_weight_ptr / grad_bias_ptr must point at zero-initialized
            fp32 buffers; `reset_to_zero` re-clears them between autotune
            trials so repeated runs don't double-accumulate.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Same binary search as the forward kernel: lo ends as the index of
        # the expert owning this row.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # dpoly = dL/d(poly), since output = poly * mul.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            # coeff_k = w * <dpoly, x^k> * inv_rms^2 / D: the projection term
            # from differentiating inv_rms through its mean-of-squares.
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            # Each term: k * x^(k-1) * inv_rms * (w*dpoly - x^k * coeff) is
            # the full derivative of w * rms_norm(x^k) w.r.t. x.
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into the shared per-expert buffers.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (see single-tile path for the math)
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
    class _GroupedPolyNormFn(torch.autograd.Function):
        """Autograd wrapper around the grouped PolyNorm Triton kernels.

        Saves only the (contiguous) inputs for backward; forward
        intermediates are recomputed in the backward kernel (activation
        recomputation).
        """

        @staticmethod
        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
            # input/mul: (N, D); weight: (num_experts', 3); bias: (num_experts', 1);
            # offsets: (num_experts,) cumsum of tokens per expert.
            N, D = input.shape
            input = input.contiguous()
            mul = mul.contiguous()
            output = torch.empty_like(input)

            num_experts = offsets.shape[0]
            # The kernels' binary search runs 12 iterations -> 2^12 experts max.
            assert num_experts <= 4096, (
                f"Supports at most 4096 experts, got {num_experts}.")

            # One program per row; BLOCK_D / num_warps chosen by autotune.
            _grouped_polynorm_fwd_kernel[(N,)](
                input,
                mul,
                weight,
                bias,
                offsets,
                output,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_input_row=input.stride(0),
                stride_mul_row=mul.stride(0),
                stride_out_row=output.stride(0),
            )

            ctx.save_for_backward(input, mul, weight, bias, offsets)
            ctx.eps = eps
            ctx.expert_offset = expert_offset
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, mul, weight, bias, offsets = ctx.saved_tensors
            eps = ctx.eps
            expert_offset = ctx.expert_offset
            N, D = input.shape

            grad_output = grad_output.contiguous()
            grad_input = torch.empty_like(input)
            grad_mul = torch.empty_like(mul)
            # fp32 zero-initialized accumulators for the kernel's atomic_add;
            # cast back to the parameter dtypes below.
            grad_weight = torch.zeros(weight.shape[0],
                                      3,
                                      device=weight.device,
                                      dtype=torch.float32)
            grad_bias = torch.zeros(bias.shape[0],
                                    device=bias.device,
                                    dtype=torch.float32)

            num_experts = offsets.shape[0]

            # Single stride_row: input/mul/grads are all contiguous (N, D).
            _grouped_polynorm_bwd_kernel[(N,)](
                grad_output,
                input,
                mul,
                weight,
                bias,
                offsets,
                grad_input,
                grad_mul,
                grad_weight,
                grad_bias,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_row=input.stride(0),
            )

            grad_weight = grad_weight.to(weight.dtype)
            # bias is (num_experts', 1): restore the trailing dim.
            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

            # No gradients for offsets, eps, expert_offset.
            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..8dc5986ed65e78ccae92752b2318a4bf1e7028d9
--- /dev/null
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:388f461d91124b99544cbdbd4dc4d98f24c010d7a0dc2e9389648860a809a51a
+size 2794152
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index d9f497ccf3368602ac8c3250f734668e8e5dc228..0000000000000000000000000000000000000000
--- a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5213ce21365e97016b4a04bfa304e03c9fc1dc6780f62b3312fb65f96e8f6381
-size 2794152
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert via torch.bucketize, then evaluates the
    PolyNorm polynomial for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: constant added to each expert index before indexing
            weight/bias (for when weight/bias hold a larger expert table)

    Returns:
        (total_tokens, D) - output tensor, cast back to input's dtype
    """
    out_dtype = input.dtype

    # Token i belongs to the first expert whose cumulative offset exceeds i.
    positions = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(positions, offsets, right=True) + expert_offset

    # Gather per-token fp32 parameters.
    w_rows = weight.float()[experts]
    b_rows = bias.float()[experts]

    def _norm(t: Tensor) -> Tensor:
        # Row-wise RMS normalization in fp32 (same formula as _rms_norm).
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    x = input.float()
    x_sq = x * x
    x_cu = x_sq * x

    poly = (w_rows[:, 0:1] * _norm(x_cu) + w_rows[:, 1:2] * _norm(x_sq) +
            w_rows[:, 2:3] * _norm(x) + b_rows)

    return (poly * mul.float()).to(out_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations searched per
    # distinct hidden dimension D (the autotune key). Large BLOCK_D favors
    # the single-tile fast path in the kernels; num_stages > 1 enables
    # software pipelining of the multi-tile loops.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Backward search space: same shape as forward but with a few extra
    # deeper-pipeline variants (the backward kernel reads three row streams
    # per pass instead of two).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        For row r owned by expert e, writes:
            out = (w[e,0]*rms(x^3) + w[e,1]*rms(x^2) + w[e,2]*rms(x) + b[e]) * mul
        All arithmetic is carried out in fp32 regardless of input dtype.
        """
        row = tl.program_id(0)
        # Guard in case the launch grid is larger than N.
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        # Invariant on exit: lo == number of offsets <= row, which is the
        # owning expert since offsets is the cumsum of tokens per expert
        # (same result as torch.bucketize(row, offsets, right=True)).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one tile: load once, reduce and write straight
        # from registers (no second pass over global memory).
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # sum(x^2), sum(x^4), sum(x^6) give the RMS stats of x, x^2, x^3.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate the three variance statistics across tiles.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: re-load the row tile by tile and emit the output.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+ Pass 1: RMS stats + dot products + bias grad
+ Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+            # grad_input (accumulate into a single register tile to limit register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
+++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..43c71923e7ca23a2f8334a6879b08e6912888e36
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c5304ceac9171f76c03792dfb9b7e8299ba2f2885983c39031546fde8f61f8b
+size 10756320
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 487158f67784cbcd78bb767dbe5a0cead183d0de..0000000000000000000000000000000000000000
--- a/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0a64085eb92e8f49acb50abc5fc2c50e0dc3e3fd84ae29d7ea8bf27518c34af3
-size 10756320
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
+++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+  - Applies polynomial weights + bias, then multiplies by mul before storing
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms to save one multiply per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS statistics, dot products, and bias-gradient accumulation
+        Pass 2: grad_input + grad_mul; weight/bias grads flushed per row via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+            # grad_input (accumulate into a single register tile to limit register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
+++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..fc714ffdff549d768c4a141b224b70bf1f033b0f
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfa89588a5e7e74b3a903912190b97004e308dd8fcb87832c2798d99733591f2
+size 15804336
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index fbc074d8428b9b1f48c21282b6247a57472c4bff..0000000000000000000000000000000000000000
--- a/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a996dcbd533b29a6de849fb4c83b58f5b818688b1c89ae8609805d09b500bc13
-size 15804336
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
+++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu128-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch29-cxx11-cu128-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias and the mul gating in the same kernel (output is written to a separate tensor, not in-place)
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon; expert_offset: offset added to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias/weight grad accumulators
+        Pass 2: grad_input + grad_mul; accumulators flushed via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+            # Compute RMS stats (x4/x6 recomputed from x2 rather than materialized, to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
+++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..3293d543a0df5bb42f7431eb87a5664ff0765377
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb799a7da44bca518185f71eaf7b2a48a1e0365f41fef9298c356a63fe374d2a
+size 13513992
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index a2087eb03968ce925fa43789851fb421b65b3811..0000000000000000000000000000000000000000
--- a/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:330acceb49e763f83496e848fb3d55b72ff0da422daafc8d8f978f10d6c35cd2
-size 13513992
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
+++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias and the mul gating in the same kernel (output is written to a separate tensor, not in-place)
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon; expert_offset: offset added to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+ # --- Autotune configurations ---
+ _GROUPED_POLYNORM_FWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
+ ]
+
+ _GROUPED_POLYNORM_BWD_CONFIGS = [
+ triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+ triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
+ triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
+ triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
+ triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
+ triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
+ ]
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_FWD_CONFIGS,
+ key=["D"],
+ )
+ @triton.jit
+ def _grouped_polynorm_fwd_kernel(
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ output_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row,
+ stride_mul_row,
+ stride_out_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Forward kernel: one program per row."""
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ # Binary search for expert index (12 iters covers up to 4096 experts)
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_input_row
+ mul_row_ptr = mul_ptr + row * stride_mul_row
+ out_row_ptr = output_ptr + row * stride_out_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+ else:
+ # --- Multi-tile: two-pass approach ---
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ # Pre-multiply scalar weight * inv_rms
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ x2 = x * x
+ x3 = x2 * x
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+ tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
+ @triton.autotune(
+ configs=_GROUPED_POLYNORM_BWD_CONFIGS,
+ key=["D"],
+ reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
+ )
+ @triton.jit
+ def _grouped_polynorm_bwd_kernel(
+ grad_out_ptr,
+ input_ptr,
+ mul_ptr,
+ weight_ptr,
+ bias_ptr,
+ offsets_ptr,
+ grad_input_ptr,
+ grad_mul_ptr,
+ grad_weight_ptr,
+ grad_bias_ptr,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row,
+ BLOCK_D: tl.constexpr,
+ ):
+ """Backward kernel: one program per row, 2-pass approach.
+
+        Pass 1: RMS stats + dot products + bias/weight grad accumulators
+        Pass 2: grad_input + grad_mul; accumulators flushed via atomic_add
+ """
+ row = tl.program_id(0)
+ if row >= N:
+ return
+
+ lo = 0
+ hi = num_experts
+ for _ in range(12):
+ if lo < hi:
+ mid = (lo + hi) // 2
+ if tl.load(offsets_ptr + mid) <= row:
+ lo = mid + 1
+ else:
+ hi = mid
+ eidx = lo + expert_offset
+
+ w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+ w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+ w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+ b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+ input_row_ptr = input_ptr + row * stride_row
+ mul_row_ptr = mul_ptr + row * stride_row
+ grad_out_row_ptr = grad_out_ptr + row * stride_row
+ grad_input_row_ptr = grad_input_ptr + row * stride_row
+ grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+ D_float = D.to(tl.float32)
+
+ # --- Single-tile path ---
+ if D <= BLOCK_D:
+ d_offs = tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ # Compute RMS stats (x4 inlined to reduce register pressure)
+ inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ dpoly = go * m
+
+ # Dot products for coefficients and weight grads
+ sum_dpoly_x = tl.sum(dpoly * x)
+ sum_dpoly_x2 = tl.sum(dpoly * x2)
+ sum_dpoly_x3 = tl.sum(dpoly * x3)
+ grad_b_acc = tl.sum(dpoly)
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # grad_mul
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)
+
+ # grad_input (in-place accumulation to reduce register pressure)
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
+ g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+ else:
+ # --- Multi-tile: 2-pass ---
+ # Pass 1: RMS stats + dot products + bias grad
+ sum_x2 = tl.zeros((), dtype=tl.float32)
+ sum_x4 = tl.zeros((), dtype=tl.float32)
+ sum_x6 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
+ sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
+ grad_b_acc = tl.zeros((), dtype=tl.float32)
+
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+ dpoly = go * m
+
+ sum_x2 += tl.sum(x2)
+ sum_x4 += tl.sum(x2 * x2)
+ sum_x6 += tl.sum(x2 * x2 * x2)
+ sum_dpoly_x += tl.sum(dpoly * x)
+ sum_dpoly_x2 += tl.sum(dpoly * x2)
+ sum_dpoly_x3 += tl.sum(dpoly * x3)
+ grad_b_acc += tl.sum(dpoly)
+
+ inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
+ inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
+ inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)
+
+ w0_inv = w0 * inv_rms_x3
+ w1_inv = w1 * inv_rms_x2
+ w2_inv = w2 * inv_rms_x
+
+ # Weight grads
+ grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+ grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+ grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+ # Coefficients for grad_input
+ coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
+ coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
+ coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float
+
+ # Pass 2: grad_input + grad_mul
+ for d_start in range(0, D, BLOCK_D):
+ d_offs = d_start + tl.arange(0, BLOCK_D)
+ mask = d_offs < D
+ x = tl.load(input_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ m = tl.load(mul_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+ go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+ other=0.0).to(tl.float32)
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
+ tl.store(grad_mul_row_ptr + d_offs,
+ go * (poly + b_val),
+ mask=mask)
+
+ dpoly = go * m
+ g = inv_rms_x * (w2 * dpoly - x * coeff_x)
+ g += (2.0 * x * inv_rms_x2 *
+ (w1 * dpoly - x2 * coeff_x2))
+ g += (3.0 * x2 * inv_rms_x3 *
+ (w0 * dpoly - x3 * coeff_x3))
+
+ tl.store(grad_input_row_ptr + d_offs, g, mask=mask)
+
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
+ tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
+ tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py
+++ b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..bc553d81eccbb0529e8232185d4fbe549dca7f35
--- /dev/null
+++ b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81b092e60ccd1d42d1fae0be2ed9de420609e032009cd3dc84267651806d3a62
+size 2788640
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index c4ed96ce4da26c0944532ca07d697cca5d0a3ff3..0000000000000000000000000000000000000000
--- a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d0461a82c98d87be63686fa52fad04e2f0c1a57605734c730e7d97d1d2cab4fa
-size 2788640
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py
+++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+) -> Tensor:
+ """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+ Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+ for all tokens at once. torch.compile friendly.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ orig_dtype = input.dtype
+
+ token_positions = torch.arange(input.shape[0], device=input.device)
+ expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+ weight_fp32 = weight.float()
+ bias_fp32 = bias.float()
+
+ per_token_w = weight_fp32[expert_idx]
+ per_token_b = bias_fp32[expert_idx]
+
+ x = input.float()
+ m = mul.float()
+
+ x2 = x * x
+ x3 = x2 * x
+
+ poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+ per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+ per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+ return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) combinations that
    # @triton.autotune benchmarks per hidden size (key=["D"]).  When
    # D <= BLOCK_D the kernels take their single-tile fast path; for larger
    # D, num_stages controls software pipelining of the tiled loops.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Separate search space for the backward kernel (heavier register use,
    # so the candidate warp/stage mix differs slightly from the forward).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        For the expert owning this row, computes (all math in fp32):
            out = (w0*rms_norm(x^3) + w1*rms_norm(x^2) + w2*rms_norm(x) + b) * mul
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts).
        # offsets holds each expert's cumulative end row, so `lo` ends up as
        # the number of experts whose range ends at or before `row` — the
        # same mapping as torch.bucketize(..., right=True) in the reference.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalars promoted to fp32.  Flat indexing assumes weight
        # (num_experts, 3) and bias (num_experts, 1) are contiguous —
        # NOTE(review): not enforced on the host side; confirm callers.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path: the whole row fits in one block, so inputs
        # stay in registers between the reductions and the store. ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # The three rms denominators only need even powers of x,
            # so they all build on x2 (masked lanes contribute 0).
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate sum(x^2), sum(x^4), sum(x^6) for the three
            # RMS denominators.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: re-load each tile, rebuild the powers, emit output.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS statistics, dpoly·x^k dot products, and the bias grad.
        Pass 2: grad_input and grad_mul stores.
        Weight/bias gradients (derived from the pass-1 sums) are then
        accumulated across rows with tl.atomic_add; `reset_to_zero` keeps
        those accumulators correct while autotune re-runs the kernel.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Same fixed 12-iteration binary search as the forward kernel
        # (covers up to 4096 experts).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        # A single row stride is shared by all five row tensors — the host
        # wrapper passes contiguous tensors of identical shape.
        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (the x^4/x^6 powers are inlined rather than
            # kept live, to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # out = poly * m, so dL/d(poly) = go * m.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: dL/dwk = sum(dpoly * x^k) * inv_rms(x^k).
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x
            # Coefficients for grad_input: the correction term each rms
            # denominator (mean of x^{2k}) contributes to d(poly)/dx.
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul = go * poly (including bias)
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure):
            # per power k, d(x^k)/dx direct term minus the rms correction.
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads: dL/dwk = sum(dpoly * x^k) * inv_rms(x^k).
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (rms-denominator correction terms)
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                # grad_mul = go * poly (including bias)
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                # grad_input: direct term minus rms correction, per power.
                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
+else:
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ raise RuntimeError(
+ "Triton is not available. Install triton to use "
+ "grouped_fused_mul_poly_norm.")
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
new file mode 100755
index 0000000000000000000000000000000000000000..93e91975634bbb6d38d1dcc1886c24e28544b255
--- /dev/null
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38f13670768414f29af3722affae76287a6a9c90f198919432e23c170183fd47
+size 2798440
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so
deleted file mode 100755
index 9523fca633977cc7250bab7d685b1f9fa766de90..0000000000000000000000000000000000000000
--- a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5b649f0e22da07195193300ef534bc309682f4fdb0ca4f0c5286e16cecd57e5b
-size 2798440
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py
index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644
--- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py
@@ -1,9 +1,9 @@
import torch
-from . import _activation_18b7543_dirty
-ops = torch.ops._activation_18b7543_dirty
+from . import _activation_0e6f27f_dirty
+ops = torch.ops._activation_0e6f27f_dirty
def add_op_namespace_prefix(op_name: str):
"""
Prefix op by namespace.
"""
- return f"_activation_18b7543_dirty::{op_name}"
\ No newline at end of file
+ return f"_activation_0e6f27f_dirty::{op_name}"
\ No newline at end of file
diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in a single fused pass, writing the result to the output tensor
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Maps every token to its expert with torch.bucketize, gathers that
    expert's polynomial weights/bias per token, and evaluates PolyNorm for
    all tokens at once.  torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: index shift used when weight/bias carry extra rows

    Returns:
        (total_tokens, D) - output tensor in the input dtype
    """
    out_dtype = input.dtype

    # Row i belongs to the first expert whose cumulative offset exceeds i.
    rows = torch.arange(input.shape[0], device=input.device)
    experts = torch.bucketize(rows, offsets, right=True) + expert_offset

    # Gather per-token expert parameters, promoted to fp32.
    w = weight.float()[experts]
    b = bias.float()[experts]

    x = input.float()
    gate = mul.float()

    x_sq = x * x
    x_cu = x_sq * x

    poly = (w[:, 0:1] * _rms_norm(x_cu, eps) +
            w[:, 1:2] * _rms_norm(x_sq, eps) +
            w[:, 2:3] * _rms_norm(x, eps) + b)

    return (poly * gate).to(out_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) tilings searched by
    # @triton.autotune, keyed on the hidden dimension D.  The BLOCK_D=2048
    # entries use num_stages=1: with a tile that large most rows are
    # single-tile, so there is no multi-iteration loop to pipeline.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # NOTE(review): the backward sweep differs slightly from the forward
    # one (extra num_stages=5 entry, fewer 2048 configs) -- presumably
    # tuned empirically for the heavier backward kernel; confirm with the
    # benchmark data before pruning.
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        Grid is (N,).  Program `row` binary-searches `offsets` for its
        owning expert, loads that expert's 3 weights + bias, and writes
        out[row] = (w0*rms(x^3) + w1*rms(x^2) + w2*rms(x) + b) * mul[row].
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        # `offsets` is the ascending cumsum of tokens-per-expert; the answer
        # is the first expert whose offset exceeds `row`.  The loop has a
        # fixed trip count so iterations after convergence are no-ops via
        # the `lo < hi` guard.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalars, promoted to fp32 for the whole computation.
        # NOTE(review): the eidx*3 addressing assumes `weight` is a
        # contiguous (num_experts, 3) tensor -- confirm callers never pass
        # a strided view.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one BLOCK_D tile: x stays in registers between
        # the reductions and the output computation (no second load).
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0 and therefore add nothing to the sums.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # rms(x^k) needs mean(x^(2k)); note sums divide by D, not BLOCK_D.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1 accumulates the three power sums for the RMS
            # denominators; pass 2 reloads x and writes the output.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        # Autotuning re-executes the kernel while timing candidate configs;
        # zeroing the atomic accumulators between trials keeps weight/bias
        # grads from being added multiple times.
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        Gradient derivation: with s_k = 1/sqrt(mean(x^(2k)) + eps) and
        dpoly = grad_out * mul, each term w * x^k * s_k contributes
            k * x^(k-1) * s_k * (w * dpoly - x^k * coeff_k)
        to grad_input, where coeff_k = w * sum(dpoly * x^k) * s_k^2 / D
        carries the derivative of s_k with respect to x.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Same fixed-trip binary search as the forward kernel: `offsets` is
        # the ascending cumsum of tokens-per-expert (up to 4096 experts).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalars in fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        # One shared row stride: forward() made all row tensors contiguous.
        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0 -> no contribution to any reduction.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Upstream gradient reaching poly (output = poly * mul).
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads: dL/dw_k = sum(dpoly * rms_norm(x^k))
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (see docstring derivation)
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul = grad_out * poly (bias re-added here)
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into the fp32 per-expert buffers.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (see docstring derivation)
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
else:

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Stub installed when Triton is missing; keeps the symbol importable
        but fails loudly on first use."""
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
diff --git a/tests/test_grouped_fused_mul_poly_norm.py b/tests/test_grouped_fused_mul_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dd72431e6d8070f954a44f145b35cd993b9a767
--- /dev/null
+++ b/tests/test_grouped_fused_mul_poly_norm.py
@@ -0,0 +1,249 @@
+import pytest
+import torch
+
+from grouped_poly_norm import (
+ HAS_TRITON,
+ grouped_fused_mul_poly_norm_ref,
+)
+
+if HAS_TRITON:
+ from grouped_poly_norm import grouped_fused_mul_poly_norm
+
+from .utils import assert_close
+
# Parametrization grids for the tests below.
DTYPES = [torch.float, torch.bfloat16, torch.float16]
NUM_TOKENS = [4096, 8192]
D = [256, 1280]
NUM_EXPERTS_LIST = [8, 384]
EXPERT_OFFSETS = [0, 4]
SEEDS = [0]
# Triton kernels launch on the current CUDA device and do not
# auto-dispatch to the tensor's device like CUDA extensions.
# Only test on cuda:0 to avoid cross-device issues.
CUDA_DEVICES = ["cuda:0"]
+
+
+def _counts_to_offsets(counts_list, device):
+ """Convert list of counts to cumsum offsets tensor."""
+ return torch.cumsum(
+ torch.tensor(counts_list, device=device, dtype=torch.int32), dim=0)
+
+
def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device,
                 seed=42, expert_offset=0):
    """Build seeded random inputs plus a random token-to-expert split.

    RNG call order (multinomial, input, mul, weight, bias) is fixed so a
    given seed always reproduces the same tensors.
    """
    torch.manual_seed(seed)

    # Uniform random expert assignment -> per-expert token counts.
    uniform = torch.ones(num_experts) / num_experts
    chosen = torch.multinomial(uniform, total_tokens, replacement=True)
    counts = torch.bincount(chosen, minlength=num_experts).tolist()

    # weight/bias carry expert_offset extra leading rows.
    n_rows = expert_offset + num_experts

    # The 0.5 scale keeps x^3 inside bf16 range (overflow for |x| > 40).
    input_t = torch.randn(total_tokens, hidden_dim, device=device,
                          dtype=dtype) * 0.5
    mul_t = torch.randn(total_tokens, hidden_dim, device=device,
                        dtype=dtype) * 0.5
    weight = (torch.ones(n_rows, 3, device=device, dtype=dtype) / 3
              + torch.randn(n_rows, 3, device=device, dtype=dtype) * 0.01)
    bias = torch.randn(n_rows, 1, device=device, dtype=dtype) * 0.01
    offsets = _counts_to_offsets(counts, device)

    return input_t, mul_t, weight, bias, offsets
+
+
def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0):
    """Reference forward + backward; returns (out, d_input, d_mul, d_weight, d_bias)."""
    leaves = [t.clone().detach().requires_grad_(True)
              for t in (input_t, mul_t, weight, bias)]
    inp, m, w, b = leaves

    out = grouped_fused_mul_poly_norm_ref(inp, m, w, b, offsets,
                                          expert_offset=expert_offset)
    out.sum().backward()

    return out, inp.grad, m.grad, w.grad, b.grad
+
+
def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0):
    """Triton forward + backward; returns (out, d_input, d_mul, d_weight, d_bias)."""
    leaves = [t.clone().detach().requires_grad_(True)
              for t in (input_t, mul_t, weight, bias)]
    inp, m, w, b = leaves

    out = grouped_fused_mul_poly_norm(inp, m, w, b, offsets,
                                      expert_offset=expert_offset)
    out.sum().backward()

    return out, inp.grad, m.grad, w.grad, b.grad
+
+
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_grouped_fused_mul_poly_norm_forward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """Triton forward output should match PyTorch reference."""
    # Default device so helper tensors created without an explicit device
    # (e.g. multinomial probs inside _make_inputs) land on the test device.
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)

    out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets,
                                              expert_offset=expert_offset)
    out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias,
                                          offsets,
                                          expert_offset=expert_offset)

    assert out_ref.shape == out_tri.shape == (num_tokens, d)
    assert out_ref.dtype == out_tri.dtype == dtype

    # Looser tolerances for half-precision dtypes.
    if dtype == torch.float32:
        assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
    else:
        assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)
+
+
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_grouped_fused_mul_poly_norm_backward(
    num_tokens: int,
    d: int,
    num_experts: int,
    dtype: torch.dtype,
    expert_offset: int,
    seed: int,
    device: str,
) -> None:
    """Triton backward gradients should match PyTorch reference."""
    torch.set_default_device(device)
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        num_tokens, d, num_experts, dtype, device, seed,
        expert_offset=expert_offset)

    # Both paths run forward + sum().backward() on cloned leaf tensors.
    _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref = _run_ref(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)
    _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri = _run_triton(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)

    # Gradients accumulate over many tokens, so half precision gets wider
    # tolerances than the forward check.
    if dtype == torch.float32:
        atol, rtol = 1e-4, 1e-4
    else:
        atol, rtol = 5e-2, 5e-2

    assert_close(inp_grad_ref, inp_grad_tri, atol=atol, rtol=rtol)
    assert_close(mul_grad_ref, mul_grad_tri, atol=atol, rtol=rtol)
    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)
+
+
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_grouped_fused_mul_poly_norm_zero_token_experts(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Correctness when some experts receive 0 tokens."""
    torch.set_default_device(device)
    # Experts 2 and 5 receive zero tokens (repeated offsets in the cumsum).
    counts = [100, 50, 0, 80, 30, 0, 60, 40]
    total = sum(counts)
    num_experts = 8
    total_experts = expert_offset + num_experts
    hidden_dim = 256

    torch.manual_seed(42)
    input_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    mul_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5
    weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3
    bias = torch.zeros(total_experts, 1, device=device, dtype=dtype)
    offsets = _counts_to_offsets(counts, device)

    out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias,
                                              offsets,
                                              expert_offset=expert_offset)
    out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias,
                                          offsets,
                                          expert_offset=expert_offset)

    if dtype == torch.float32:
        assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4)
    else:
        assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2)

    # Check backward with zero-token experts
    _, _, _, w_grad_ref, b_grad_ref = _run_ref(input_t, mul_t, weight, bias,
                                               offsets,
                                               expert_offset=expert_offset)
    _, _, _, w_grad_tri, b_grad_tri = _run_triton(input_t, mul_t, weight, bias,
                                                  offsets,
                                                  expert_offset=expert_offset)

    if dtype == torch.float32:
        atol, rtol = 1e-3, 1e-3
    else:
        atol, rtol = 5e-2, 5e-2

    assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol)
    assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol)

    # Verify zero-token experts have zero weight/bias gradients
    # (no row maps to them, so the atomic accumulators stay at zero).
    for eidx in [2, 5]:
        wi = eidx + expert_offset
        assert w_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero weight grad "
            f"but got max={w_grad_tri[wi].abs().max().item()}")
        assert b_grad_tri[wi].abs().max() == 0, (
            f"Expert {eidx} (weight idx {wi}) should have zero bias grad "
            f"but got max={b_grad_tri[wi].abs().max().item()}")
+
+
@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available")
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_grouped_fused_mul_poly_norm_no_nan_inf(
    dtype: torch.dtype,
    expert_offset: int,
    device: str,
) -> None:
    """Output and gradients should not contain NaN or Inf."""
    torch.set_default_device(device)
    # Default seed (42) from _make_inputs; modest sizes keep this fast.
    input_t, mul_t, weight, bias, offsets = _make_inputs(
        4096, 256, 8, dtype, device, expert_offset=expert_offset)

    out, inp_grad, mul_grad, w_grad, b_grad = _run_triton(
        input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset)

    assert not out.isnan().any(), "Output contains NaN"
    assert not out.isinf().any(), "Output contains Inf"

    for name, grad in [("input", inp_grad), ("mul", mul_grad),
                       ("weight", w_grad), ("bias", b_grad)]:
        assert not grad.isnan().any(), f"{name}_grad contains NaN"
        assert not grad.isinf().any(), f"{name}_grad contains Inf"
diff --git a/torch-ext/activation/__init__.py b/torch-ext/activation/__init__.py
index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644
--- a/torch-ext/activation/__init__.py
+++ b/torch-ext/activation/__init__.py
@@ -2,6 +2,7 @@ import torch
from . import layers, parallel_style
from ._ops import ops
+from .grouped_poly_norm import grouped_fused_mul_poly_norm
from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction
from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction
@@ -45,6 +46,7 @@ def fused_add_rms_norm(
__all__ = [
"poly_norm",
"fused_mul_poly_norm",
+ "grouped_fused_mul_poly_norm",
"rms_norm",
"fused_add_rms_norm",
"layers",
diff --git a/torch-ext/activation/grouped_poly_norm.py b/torch-ext/activation/grouped_poly_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52
--- /dev/null
+++ b/torch-ext/activation/grouped_poly_norm.py
@@ -0,0 +1,583 @@
+"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.
+
+Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
+eliminating multiple intermediate tensors and kernel launches.
+
+PolyNorm formula (per row):
+ poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
+ output = poly * mul
+
+where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)
+
+Performance optimizations:
+ - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
+ hidden dimension.
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
+ across the reduction and output phases, eliminating redundant global reads.
+ - Multi-tile software pipelining: explicit num_stages in autotune configs
+ enables overlapping memory loads with computation across loop iterations.
+ - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
+ launches (torch.arange + torch.bucketize) per forward/backward call.
+ - Backward 2-pass optimization: pass 1 merges RMS statistics computation
+ with dot product accumulation, pass 2 computes gradients. This reduces
+ memory traffic compared to a naive 3-pass approach.
+
+Forward kernel: one program per row, tiles over D dimension.
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in-place
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+ """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: constant added to each token's bucketized expert
            index before indexing into weight/bias

    Returns:
        (total_tokens, D) - output tensor
    """
    orig_dtype = input.dtype

    def _norm(t: Tensor) -> Tensor:
        # Per-row RMS normalization: t / sqrt(mean(t^2, dim=-1) + eps)
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    # Token p belongs to expert i iff offsets[i-1] <= p < offsets[i]
    # (bucketize with right=True yields exactly that index).
    token_positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset

    # Compute in fp32 for numerical stability; cast back at the end.
    per_token_w = weight.float()[expert_idx]
    per_token_b = bias.float()[expert_idx]

    x = input.float()
    m = mul.float()

    x2 = x * x
    x3 = x2 * x

    poly = (per_token_w[:, 0:1] * _norm(x3) +
            per_token_w[:, 1:2] * _norm(x2) +
            per_token_w[:, 2:3] * _norm(x) + per_token_b)

    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    # Candidate (BLOCK_D, num_warps, num_stages) tilings. @triton.autotune
    # benchmarks these per unique hidden dimension D (the autotune key)
    # and caches the fastest one for subsequent launches.
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    # Backward has its own candidate list (it reads grad_out/input/mul and
    # writes two gradient tensors per row, so the optimal tiling differs).
    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row.

        Launched on a 1-D grid of size N (one program per token row).
        weight_ptr is addressed as a dense (num_experts, 3) array and
        bias_ptr as a dense per-expert scalar array; offsets_ptr holds the
        cumulative per-expert token counts consumed by the binary search.
        All math runs in fp32; the final store casts to the output dtype.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        # Finds the first index whose cumulative offset exceeds `row`,
        # i.e. the expert that owns this token.
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32 once.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        # Whole row fits in one tile: x and m stay in registers for both
        # the variance reduction and the output computation.
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            # Masked lanes load 0.0, so they contribute nothing to the sums.
            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # mean((x^k)^2) = sum(x^(2k)) / D for k = 1, 2, 3.
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            # Pass 1: accumulate the three variance sums across tiles.
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Pass 2: reload x (and m) tile by tile and write the output.
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+
    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        # Autotuning re-runs the kernel several times; the atomic
        # accumulators must be zeroed between trials or the benchmarking
        # runs would corrupt the final gradients.
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)

        Recomputes x^2/x^3 and the RMS statistics from the saved inputs
        rather than storing them in forward (activation recomputation).
        A single `stride_row` is shared by all (N, D) operands — the host
        side launches this kernel with contiguous tensors only.
        """
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for the owning expert (same as forward:
        # 12 iterations cover up to 4096 experts).
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        # Per-expert scalar parameters, promoted to fp32.
        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # dL/d(poly) for this row: out = poly * m.
            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input: each carries the d(inv_rms)/dx
            # contribution of its power term, i.e. the
            # w_k * sum(dpoly * x^k) * inv_rms^2 / D factor of the
            # rms_norm derivative.
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            # Chain rule: d(x^2)/dx = 2x and d(x^3)/dx = 3x^2, hence the
            # 2.0*x and 3.0*x2 factors on the higher-power terms.
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into the per-expert parameter grads.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input (d(inv_rms)/dx contribution,
            # see single-tile path above).
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            # Cross-row accumulation into the per-expert parameter grads.
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
+
+ class _GroupedPolyNormFn(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
+ N, D = input.shape
+ input = input.contiguous()
+ mul = mul.contiguous()
+ output = torch.empty_like(input)
+
+ num_experts = offsets.shape[0]
+ assert num_experts <= 4096, (
+ f"Supports at most 4096 experts, got {num_experts}.")
+
+ _grouped_polynorm_fwd_kernel[(N,)](
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ output,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_input_row=input.stride(0),
+ stride_mul_row=mul.stride(0),
+ stride_out_row=output.stride(0),
+ )
+
+ ctx.save_for_backward(input, mul, weight, bias, offsets)
+ ctx.eps = eps
+ ctx.expert_offset = expert_offset
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, mul, weight, bias, offsets = ctx.saved_tensors
+ eps = ctx.eps
+ expert_offset = ctx.expert_offset
+ N, D = input.shape
+
+ grad_output = grad_output.contiguous()
+ grad_input = torch.empty_like(input)
+ grad_mul = torch.empty_like(mul)
+ grad_weight = torch.zeros(weight.shape[0],
+ 3,
+ device=weight.device,
+ dtype=torch.float32)
+ grad_bias = torch.zeros(bias.shape[0],
+ device=bias.device,
+ dtype=torch.float32)
+
+ num_experts = offsets.shape[0]
+
+ _grouped_polynorm_bwd_kernel[(N,)](
+ grad_output,
+ input,
+ mul,
+ weight,
+ bias,
+ offsets,
+ grad_input,
+ grad_mul,
+ grad_weight,
+ grad_bias,
+ N,
+ D,
+ num_experts,
+ eps,
+ expert_offset,
+ stride_row=input.stride(0),
+ )
+
+ grad_weight = grad_weight.to(weight.dtype)
+ grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)
+
+ return grad_input, grad_mul, grad_weight, grad_bias, None, None, None
+
+ def grouped_fused_mul_poly_norm(
+ input: Tensor,
+ mul: Tensor,
+ weight: Tensor,
+ bias: Tensor,
+ offsets: Tensor,
+ eps: float = 1e-6,
+ expert_offset: int = 0,
+ ) -> Tensor:
+ """Triton-accelerated Grouped FusedMulPolyNorm.
+
+ Args:
+ input: (total_tokens, D) - concatenated tokens for all experts
+ mul: (total_tokens, D) - gate values to multiply with
+ weight: (num_experts, 3) - per-expert polynomial weights
+ bias: (num_experts, 1) - per-expert polynomial bias
+ offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+ eps: numerical stability epsilon
+ expert_offset: offset to add to expert index
+
+ Returns:
+ (total_tokens, D) - output tensor
+ """
+ return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
+ expert_offset)
+
else:

    # Fallback defined when the `import triton` at module top failed:
    # same signature as the Triton-backed version, but always raises.
    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Import-failure stub; raises RuntimeError on every call."""
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")