diff --git a/README.md b/README.md index b6bbc514556a9a3dc0e55e3badf6babe628c99e8..dd173875e81233c3156bbb7baefe9c185990a709 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Activation is a python package that contains custom CUDA-based activation kernel ```python y = x + residual hidden_state = rms_norm(y, weight, eps) - out = y + some_op(hidden_state) + out = y + some_op(hidden_state) ``` - Fused as: @@ -45,6 +45,22 @@ Activation is a python package that contains custom CUDA-based activation kernel out = fused_mul_poly_norm(x, a, weight, bias, eps) ``` + - **GroupedFusedMulPolyNorm** (Triton) + + A Triton-accelerated grouped variant of FusedMulPolyNorm for **MoE (Mixture of Experts)** models. Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), with per-expert weights/bias and in-kernel binary search for expert mapping. + - Instead of: + + ```python + for i, expert in enumerate(experts): + out[start:end] = fused_mul_poly_norm(x[start:end], mul[start:end], weight[i], bias[i], eps) + ``` + + - Fused as: + + ```python + out = grouped_fused_mul_poly_norm(x, mul, weight, bias, offsets, eps) + ``` + ## Usage ```python @@ -214,6 +230,118 @@ print(poly_norm(x)) +--- + +### GroupedFusedMulPolyNorm (Triton) + +> [!NOTE] +> This kernel is implemented in Triton (JIT-compiled, no CUDA C++ build required). +> Benchmarks compare three variants: **Naive** (raw PyTorch reference), **Compiled** (`torch.compile`'d reference), and **Triton** (fused Triton kernel). +> Benchmark dimension: 1280, 384 experts. + +#### B200 Results (bf16) + +
+Forward Performance + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | +|-----------|---------|-----------|--------------|------------|-----------------| +| 1 | 1024 | 294.54 | 73.46 | 64.33 | 4.58x | +| 1 | 2048 | 373.50 | 94.88 | 65.26 | 5.72x | +| 1 | 4096 | 372.65 | 94.90 | 66.90 | 5.57x | +| 1 | 8192 | 486.98 | 102.33 | 72.71 | 6.70x | +| 2 | 4096 | 486.66 | 101.87 | 72.27 | 6.73x | +| 2 | 8192 | 950.62 | 106.96 | 90.06 | 10.56x | +| 4 | 4096 | 950.72 | 107.17 | 71.28 | 13.34x | +| 4 | 8192 | 1779.12 | 198.91 | 96.93 | 18.35x | +| 8 | 4096 | 1778.73 | 199.10 | 96.88 | 18.36x | +| 8 | 8192 | 3384.03 | 381.91 | 179.57 | 18.85x | + +
+ +
+Backward Performance + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | +|-----------|---------|-----------|--------------|------------|-----------------| +| 1 | 1024 | 1690.61 | 999.66 | 1017.66 | 1.66x | +| 1 | 8192 | 1680.39 | 906.43 | 906.41 | 1.85x | +| 2 | 8192 | 2466.73 | 870.74 | 862.78 | 2.86x | +| 4 | 4096 | 2466.04 | 942.62 | 945.68 | 2.61x | +| 4 | 8192 | 4543.10 | 941.01 | 908.30 | 5.00x | +| 8 | 4096 | 4542.91 | 814.73 | 900.01 | 5.05x | +| 8 | 8192 | 8599.41 | 956.81 | 955.07 | 9.00x | + +
+ +
+Forward + Backward Combined + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled | +|-----------|---------|-----------|--------------|------------|-----------------|-------------------| +| 1 | 1024 | 1985.15 | 1073.12 | 1081.99 | 1.83x | 0.99x | +| 1 | 4096 | 2085.10 | 974.32 | 960.73 | 2.17x | 1.01x | +| 1 | 8192 | 2167.37 | 1008.76 | 979.12 | 2.21x | 1.03x | +| 2 | 4096 | 2083.49 | 1001.03 | 965.30 | 2.16x | 1.04x | +| 2 | 8192 | 3417.35 | 977.70 | 952.84 | 3.59x | 1.03x | +| 4 | 4096 | 3416.76 | 1049.79 | 1016.97 | 3.36x | 1.03x | +| 4 | 8192 | 6322.22 | 1139.92 | 1005.23 | 6.29x | 1.13x | +| 8 | 4096 | 6321.64 | 1013.83 | 996.89 | 6.34x | 1.02x | +| 8 | 8192 | 11983.44 | 1338.71 | 1134.64 | 10.56x | 1.18x | + +
+ +#### B200 Results (fp32) + +
+Forward Performance + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | +|-----------|---------|-----------|--------------|------------|-----------------| +| 1 | 1024 | 318.05 | 83.29 | 64.24 | 4.95x | +| 1 | 2048 | 311.14 | 95.19 | 63.64 | 4.89x | +| 1 | 8192 | 401.78 | 101.61 | 68.21 | 5.89x | +| 2 | 4096 | 403.42 | 100.97 | 68.01 | 5.93x | +| 2 | 8192 | 803.31 | 130.51 | 68.21 | 11.78x | +| 4 | 4096 | 802.86 | 130.61 | 66.97 | 11.99x | +| 4 | 8192 | 1505.96 | 246.77 | 100.49 | 14.99x | +| 8 | 4096 | 1507.87 | 246.84 | 100.23 | 15.04x | +| 8 | 8192 | 2856.93 | 476.34 | 184.40 | 15.49x | + +
+ +
+Backward Performance + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | +|-----------|---------|-----------|--------------|------------|-----------------| +| 1 | 1024 | 1604.25 | 989.30 | 1114.12 | 1.44x | +| 1 | 8192 | 1996.40 | 1117.71 | 1115.47 | 1.79x | +| 2 | 8192 | 2353.87 | 1119.41 | 1118.57 | 2.10x | +| 4 | 4096 | 2358.47 | 1102.23 | 1125.16 | 2.10x | +| 4 | 8192 | 4346.92 | 1125.33 | 1135.36 | 3.83x | +| 8 | 4096 | 4347.47 | 1104.27 | 1119.63 | 3.88x | +| 8 | 8192 | 8226.50 | 1172.66 | 1197.68 | 6.87x | + +
+ +
+Forward + Backward Combined + +| batch_size | seq_len | Naive (us) | Compiled (us) | Triton (us) | Triton vs Naive | Triton vs Compiled | +|-----------|---------|-----------|--------------|------------|-----------------|-------------------| +| 1 | 1024 | 1922.30 | 1072.59 | 1178.36 | 1.63x | 0.91x | +| 1 | 4096 | 2367.77 | 1208.69 | 1192.07 | 1.99x | 1.01x | +| 1 | 8192 | 2398.19 | 1219.32 | 1183.69 | 2.03x | 1.03x | +| 2 | 4096 | 2401.39 | 1248.87 | 1154.72 | 2.08x | 1.08x | +| 2 | 8192 | 3157.18 | 1249.92 | 1186.77 | 2.66x | 1.05x | +| 4 | 4096 | 3161.33 | 1232.84 | 1192.13 | 2.65x | 1.03x | +| 4 | 8192 | 5852.88 | 1372.10 | 1235.86 | 4.74x | 1.11x | +| 8 | 4096 | 5855.34 | 1351.11 | 1219.85 | 4.80x | 1.11x | +| 8 | 8192 | 11083.43 | 1649.00 | 1382.07 | 8.02x | 1.19x | + +
+ ## Pre-commit Hooks This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits. diff --git a/benchmarks/cases/grouped_mul_poly.py b/benchmarks/cases/grouped_mul_poly.py new file mode 100644 index 0000000000000000000000000000000000000000..8c1268f5595ba6ccce9ea495c31a50b21d31314e --- /dev/null +++ b/benchmarks/cases/grouped_mul_poly.py @@ -0,0 +1,122 @@ +import torch +import torch._functorch.config +from common.diff_engine import DiffCase + +torch._functorch.config.donated_buffer = False + +from grouped_poly_norm import ( + grouped_fused_mul_poly_norm, + grouped_fused_mul_poly_norm_ref, +) + +NUM_EXPERTS = 384 + + +class GroupedRefModule(torch.nn.Module): + """Wraps the PyTorch reference for grouped FusedMulPolyNorm.""" + + def __init__(self, weight, bias, offsets, eps, expert_offset=0): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) + self.offsets = offsets + self.eps = eps + self.expert_offset = expert_offset + + def forward(self, x, mul): + return grouped_fused_mul_poly_norm_ref(x, mul, self.weight, self.bias, + self.offsets, self.eps, + expert_offset=self.expert_offset) + + +class GroupedTritonModule(torch.nn.Module): + """Wraps the Triton kernel for grouped FusedMulPolyNorm.""" + + def __init__(self, weight, bias, offsets, eps, expert_offset=0): + super().__init__() + self.weight = torch.nn.Parameter(weight) + self.bias = torch.nn.Parameter(bias) + self.offsets = offsets + self.eps = eps + self.expert_offset = expert_offset + + def forward(self, x, mul): + return grouped_fused_mul_poly_norm(x, mul, self.weight, self.bias, + self.offsets, self.eps, + expert_offset=self.expert_offset) + + +class GroupedMulPoly(DiffCase): + """Benchmark case for Grouped FusedMulPolyNorm (MoE). + + Maps the framework's (bs, sl, hidden) to grouped polynorm's + (total_tokens, D) where total_tokens = bs * sl. + Uses a fixed number of experts with uniform token distribution. 
+ """ + + def build_inputs(self, bs, sl, hidden, dtype, eps): + total_tokens = bs * sl + num_experts = min(NUM_EXPERTS, total_tokens) + + torch.manual_seed(42) + probs = torch.ones(num_experts) / num_experts + assignments = torch.multinomial(probs, total_tokens, replacement=True) + counts = torch.bincount(assignments, minlength=num_experts).tolist() + offsets = torch.cumsum( + torch.tensor(counts, dtype=torch.int32), dim=0) + + return { + "x": + torch.randn(total_tokens, hidden, dtype=dtype, + requires_grad=True) * 0.5, + "mul": + torch.randn(total_tokens, hidden, dtype=dtype, + requires_grad=True) * 0.5, + "weight": + torch.ones(num_experts, 3, dtype=dtype) / 3 + + torch.randn(num_experts, 3, dtype=dtype) * 0.01, + "bias": + torch.randn(num_experts, 1, dtype=dtype) * 0.01, + "offsets": + offsets, + "dim": + hidden, + "eps": + eps, + "dtype": + dtype, + } + + def make_naive(self, I): + return GroupedRefModule( + I["weight"].detach().clone(), + I["bias"].detach().clone(), + I["offsets"], + I["eps"], + ) + + def make_compiled(self, I): + m = GroupedRefModule( + I["weight"].detach().clone(), + I["bias"].detach().clone(), + I["offsets"], + I["eps"], + ) + return torch.compile(m) + + def make_cuda(self, I): + return GroupedTritonModule( + I["weight"].detach().clone(), + I["bias"].detach().clone(), + I["offsets"], + I["eps"], + ) + + def forward(self, obj, I): + return obj(I["x"], I["mul"]) + + def grad_inputs(self, I): + return [I["x"], I["mul"]] + + +CASE = GroupedMulPoly() diff --git a/benchmarks/common/bench_framework.py b/benchmarks/common/bench_framework.py index 49dfe3c7deb1cd6595a2d411a0e17615d8f99e3b..3837703033e37ac9493da568a6a02916ccb4d934 100644 --- a/benchmarks/common/bench_framework.py +++ b/benchmarks/common/bench_framework.py @@ -57,7 +57,12 @@ def make_fwd_benchmark_for_case( I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) if provider == "speedup": return timings_ms["naive"][key] / timings_ms["cuda"][key] - obj = case.make_naive(I) if 
provider == "naive" else case.make_cuda(I) + if provider == "naive": + obj = case.make_naive(I) + elif provider == "compiled" and hasattr(case, "make_compiled"): + obj = case.make_compiled(I) + else: + obj = case.make_cuda(I) run = lambda: case.forward(obj, I) ms = triton.testing.do_bench(run) timings_ms[provider][key] = ms @@ -101,7 +106,12 @@ def make_fwd_benchmark_plot_for_case( return 1.00 batch_size, seq_len, dim = parse_config_string(config) I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) - obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + if provider == "naive": + obj = case.make_naive(I) + elif provider == "compiled" and hasattr(case, "make_compiled"): + obj = case.make_compiled(I) + else: + obj = case.make_cuda(I) run = lambda: case.forward(obj, I) ms = triton.testing.do_bench(run) timings_ms[provider][config] = ms @@ -146,7 +156,12 @@ def make_bwd_benchmark_for_case( I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) if provider == "speedup": return timings_ms["naive"][key] / timings_ms["cuda"][key] - obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + if provider == "naive": + obj = case.make_naive(I) + elif provider == "compiled" and hasattr(case, "make_compiled"): + obj = case.make_compiled(I) + else: + obj = case.make_cuda(I) y = case.forward(obj, I) gin = list(case.grad_inputs(I)) + list(obj.parameters()) if isinstance(y, torch.Tensor): @@ -201,7 +216,12 @@ def make_bwd_benchmark_plot_for_case( return 1.00 batch_size, seq_len, dim = parse_config_string(config) I = case.build_inputs(batch_size, seq_len, dim, dtype, eps) - obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I) + if provider == "naive": + obj = case.make_naive(I) + elif provider == "compiled" and hasattr(case, "make_compiled"): + obj = case.make_compiled(I) + else: + obj = case.make_cuda(I) y = case.forward(obj, I) gin = list(case.grad_inputs(I)) + list(obj.parameters()) if isinstance(y, torch.Tensor): diff 
--git a/benchmarks/run_cases.py b/benchmarks/run_cases.py index e2e1f746ccf14aed82182c482dc39c05fe03816e..5c4a718b74cf4f1063d806839442d0d83f7e8fdb 100644 --- a/benchmarks/run_cases.py +++ b/benchmarks/run_cases.py @@ -23,12 +23,15 @@ def make_title_tag(): return f"[{dev_name} | torch {torch_ver}]" -def plot_result(r_path): +def plot_result(r_path, columns=None): import matplotlib.pyplot as plt import pandas as pd df = pd.read_csv(r_path + ".csv") + if columns is None: + columns = [c for c in ["Naive", "Compiled", "Cuda", "Triton"] + if c in df.columns] plt.figure(figsize=(12, 6)) - ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca()) + ax = df.plot(x="config", y=columns, kind="bar", ax=plt.gca()) ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(), fontsize=14, fontweight="bold") @@ -44,9 +47,10 @@ def plot_result(r_path): def main(): ap = argparse.ArgumentParser() - ap.add_argument("--case", - choices=["rms", "add_rms", "poly", "mul_poly"], - required=True) + ap.add_argument( + "--case", + choices=["rms", "add_rms", "poly", "mul_poly", "grouped_mul_poly"], + required=True) ap.add_argument("--plot", action="store_true") ap.add_argument( "--save-path", @@ -54,8 +58,25 @@ def main(): default="./configs/", help="Path to save benchmark results", ) + ap.add_argument( + "--dtype", + choices=["fp16", "bf16", "fp32", "all"], + default="bf16", + help="Data type for benchmarking (default: bf16)", + ) args = ap.parse_args() + dtype_map = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "fp32": torch.float32, + } + if args.dtype == "all": + dtypes = [("fp16", torch.float16), ("bf16", torch.bfloat16), + ("fp32", torch.float32)] + else: + dtypes = [(args.dtype, dtype_map[args.dtype])] + torch.set_default_device("cuda") mod = importlib.import_module(f"cases.{args.case}") case: DiffCase = mod.CASE @@ -67,76 +88,118 @@ def main(): hidden_size=4096, ) - save_dir = os.path.join(args.save_path, args.case) - if args.plot: - 
batch_size_range = [1] - seq_length_range = [4096, 8192, 16384] - dim = [8192, 16384] if "poly" in args.case else [2048, 4096] - configs = list( - itertools.product(batch_size_range, seq_length_range, dim)) - plot_name = f"plot_{args.case}-fwd-perf" - bench = make_fwd_benchmark_plot_for_case( - case=case, - configs=configs, - plot_name=plot_name, - line_names={ - "naive": "Naive", - "cuda": "Cuda", - }, - ) - bench.run(print_data=True, save_path=save_dir) - plot_result(os.path.join(save_dir, plot_name)) - - plot_name = f"plot_{args.case}-bwd-perf" - bench = make_bwd_benchmark_plot_for_case( - case=case, - configs=configs, - plot_name=plot_name, - line_names={ - "naive": "Naive", - "cuda": "Cuda", - }, - ) - bench.run(print_data=True, save_path=save_dir) - plot_result(os.path.join(save_dir, plot_name)) - for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob( - os.path.join(save_dir, "*.csv")): - os.remove(f) - else: - batch_size_range = [2**i for i in range(0, 4, 1)] - seq_length_range = [2**i for i in range(10, 14, 1)] - dim = [8192, 16384] if "poly" in args.case else [2048, 4096] - configs = list( - itertools.product(dim, batch_size_range, seq_length_range)) - - bench = make_fwd_benchmark_for_case( - case=case, - configs=configs, - plot_name=f"{args.case}-fwd-perf", - line_names={ - "naive": "Naive", - "cuda": "Cuda", - "speedup": "SpeedUp" - }, - ) - - bench.run(print_data=True, save_path=save_dir) - - bench = make_bwd_benchmark_for_case( - case=case, - configs=configs, - plot_name=f"{args.case}-bwd-perf", - line_names={ - "naive": "Naive", - "cuda": "Cuda", - "speedup": "SpeedUp" - }, - ) - - bench.run(print_data=True, save_path=save_dir) - for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob( - os.path.join(save_dir, "*.png")): - os.remove(f) + for dtype_name, dtype in dtypes: + print(f"\n{'=' * 60}") + print(f" Benchmarking dtype: {dtype_name} ({dtype})") + print(f"{'=' * 60}\n") + + save_dir = os.path.join(args.save_path, args.case, 
dtype_name) + is_grouped = args.case == "grouped_mul_poly" + + if args.plot: + batch_size_range = [1] + seq_length_range = [4096, 8192, 16384] + if is_grouped: + dim = [1280] + elif "poly" in args.case: + dim = [8192, 16384] + else: + dim = [2048, 4096] + configs = list( + itertools.product(batch_size_range, seq_length_range, dim)) + + if is_grouped: + plot_line_vals = ("naive", "compiled", "cuda") + plot_line_names = { + "naive": "Naive", + "compiled": "Compiled", + "cuda": "Triton", + } + else: + plot_line_vals = ("naive", "cuda") + plot_line_names = { + "naive": "Naive", + "cuda": "Cuda", + } + + plot_name = f"plot_{args.case}-{dtype_name}-fwd-perf" + bench = make_fwd_benchmark_plot_for_case( + case=case, + configs=configs, + plot_name=plot_name, + dtype=dtype, + line_vals=plot_line_vals, + line_names=plot_line_names, + ) + bench.run(print_data=True, save_path=save_dir) + plot_result(os.path.join(save_dir, plot_name)) + + plot_name = f"plot_{args.case}-{dtype_name}-bwd-perf" + bench = make_bwd_benchmark_plot_for_case( + case=case, + configs=configs, + plot_name=plot_name, + dtype=dtype, + line_vals=plot_line_vals, + line_names=plot_line_names, + ) + bench.run(print_data=True, save_path=save_dir) + plot_result(os.path.join(save_dir, plot_name)) + for f in glob.glob(os.path.join(save_dir, "*.html")) + \ + glob.glob(os.path.join(save_dir, "*.csv")): + os.remove(f) + else: + batch_size_range = [2**i for i in range(0, 4, 1)] + seq_length_range = [2**i for i in range(10, 14, 1)] + if is_grouped: + dim = [1280] + elif "poly" in args.case: + dim = [8192, 16384] + else: + dim = [2048, 4096] + configs = list( + itertools.product(dim, batch_size_range, seq_length_range)) + + if is_grouped: + csv_line_vals = ("naive", "compiled", "cuda", "speedup") + csv_line_names = { + "naive": "Naive", + "compiled": "Compiled", + "cuda": "Triton", + "speedup": "SpeedUp", + } + else: + csv_line_vals = ("naive", "cuda", "speedup") + csv_line_names = { + "naive": "Naive", + "cuda": "Cuda", 
+ "speedup": "SpeedUp", + } + + bench = make_fwd_benchmark_for_case( + case=case, + configs=configs, + plot_name=f"{args.case}-{dtype_name}-fwd-perf", + dtype=dtype, + line_vals=csv_line_vals, + line_names=csv_line_names, + ) + + bench.run(print_data=True, save_path=save_dir) + + bench = make_bwd_benchmark_for_case( + case=case, + configs=configs, + plot_name=f"{args.case}-{dtype_name}-bwd-perf", + dtype=dtype, + line_vals=csv_line_vals, + line_names=csv_line_names, + ) + + bench.run(print_data=True, save_path=save_dir) + for f in glob.glob(os.path.join(save_dir, "*.html")) + \ + glob.glob(os.path.join(save_dir, "*.png")): + os.remove(f) if __name__ == "__main__": diff --git a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..caa5f9e4c72751c1cc34a718812babba0149ff6f --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f31dfeac9b22c01a027f858b3d8beaf87eea9adf8dc45902f0e43d6c264fd985 +size 10775296 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index da939124912cd385bcb4d1e02878ecf9ffe53ad8..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:39a7e25002120a73ea83ac813276c0518086fae2236f528dadf96bac4876a270 -size 10775296 diff --git a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch210-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch210-cxx11-cu126-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..92030076050496a09036e47299e309717014fe79 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b8370d2e1f5561ae4b77ac8ae7b3a084e33a0d1952a8f5f9bf4700375313b35 +size 15815392 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index aed6525c923f5f2a1825b682c75d3ad4b1dfa245..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:078853c2db399a227822ea0c8e70c2e13bad41bfa370657dd19aa2efb3b503e9 -size 15815392 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers
+ - Computes three RMS norms in a single pass (shared variance reduction)
+ - Applies polynomial weights + bias + mul in a single fused pass
+
+Backward kernel: one program per row, tiles over D dimension.
+ - Recomputes forward intermediates from saved inputs (activation recomputation)
+ - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
+ - Weight/bias gradients use tl.atomic_add for cross-row accumulation
+"""
+
+import torch
+from torch import Tensor
+
+try:
+    import triton
+    import triton.language as tl
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+
+
+# ---------------------------------------------------------------------------
+# PyTorch reference implementation (for testing and benchmarking)
+# ---------------------------------------------------------------------------
+def _rms_norm(x: Tensor, eps: float) -> Tensor:
+    """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
+    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+def grouped_fused_mul_poly_norm_ref(
+    input: Tensor,
+    mul: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    offsets: Tensor,
+    eps: float = 1e-6,
+    expert_offset: int = 0,
+) -> Tensor:
+    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).
+
+    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
+    for all tokens at once. torch.compile friendly.
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1136e14de3dcfc6580e54891538ea24e6804e1e0 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edf3fca2079788750c4e0497012ba93c34c770aca4c9d4f22d03be4a86a2ce8c +size 13520952 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 56e86ea0e950d091490d62c3eadf538156cd6fb3..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:59e2c13071e1807a6225c5ad7a4a7eb04d46b1f177ae6344d199a9e7f14daf92 -size 13520952 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e0be0ec0d8372be979ea7721acfe48e28ca9aa30 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f8d5a173c51cb2dabe3554da743aed307c04b5d51c9d0d460a8fa5a821b5495 +size 2919488 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 89cd6d1522af7bfe8b53510cd92a7a9009a7b75d..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:45ff2b71abb33d840d92116980e519786ed06f1e337d681d0e3301dba241ff63 -size 2919488 diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm70-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch210-cxx11-rocm70-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4401763052dc74bcd65e72216e2dd7bc239d7f91 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28e6de9c3cd172e95284b0df0e15a2afb21ab7b89dd624e69b1361942095e8be +size 2911200 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 262fc74841f5f4af7bb17361e9028081723f235e..0000000000000000000000000000000000000000 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:af4db38e8d5ad56226f5a95a86c2b5fc726bd9d576d07df2f07d3f03c1b6b35b -size 2911200 diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py +++ b/build/torch210-cxx11-rocm71-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py b/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch210-cxx11-rocm71-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+
+    Args:
+        input: (total_tokens, D) - concatenated tokens for all experts
+        mul: (total_tokens, D) - gate values to multiply with
+        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
+        bias: (num_experts, 1) - per-expert polynomial bias
+        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
+        eps: numerical stability epsilon
+        expert_offset: offset to add to expert index
+    Returns:
+        (total_tokens, D) - output tensor
+    """
+    orig_dtype = input.dtype
+
+    token_positions = torch.arange(input.shape[0], device=input.device)
+    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset
+
+    weight_fp32 = weight.float()
+    bias_fp32 = bias.float()
+
+    per_token_w = weight_fp32[expert_idx]
+    per_token_b = bias_fp32[expert_idx]
+
+    x = input.float()
+    m = mul.float()
+
+    x2 = x * x
+    x3 = x2 * x
+
+    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
+            per_token_w[:, 1:2] * _rms_norm(x2, eps) +
+            per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)
+
+    return (poly * m).to(orig_dtype)
+
+
+# ---------------------------------------------------------------------------
+# Triton kernel implementation
+# ---------------------------------------------------------------------------
+if HAS_TRITON:
+    # --- Autotune configurations ---
+    _GROUPED_POLYNORM_FWD_CONFIGS = [
+        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
+        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
+        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
+        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
+        triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/__init__.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7eeea2956f47f94084a31c80fd20218dd4e9a078 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51ac098828ee90af0d1d17ae75326f89777b1b2e7ef57e00035aed560c434a20 +size 10756352 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index e5607fd3688771cb7bfea8abd9f2808ee6c8ca59..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0cc858ef4dce1d14e3a74f5f4a17a2d6c6a8c54cba436f938449c859dd84c3b1 -size 10756352 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/__init__.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c6a85ff536b76cbee70f5e6051d31b7962f47d22 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7edb027993454d74da9632c7368edc2b0526b5f1ef33ae9e790d49bdf7285640 +size 15804360 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index cbf9fb816fa3782bb93924b7dcf1dc8143c09f3a..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3f3e704c3833d0d41cf43549ddad15c36c67e4b80b0b2a7af5c6a9a2c488690b -size 15804360 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/__init__.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..44426f622e2e75950aa71623220360e652ef8249 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f06fb9594dcf9c0bc3a8af619fec3a541b775f0dad304a7978314d32fae8d244 +size 15795640 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 77b558a4079bd17e0c081a80c0ca9891f4807ebb..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f589ae5cb26c7fb85fac22ee789749c54897bf08b170dbf77b9d62eb98ee8b53 -size 15795640 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..6941562f2f9dcf7d6d1ceeb9d4a3ad67beb6be04 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:448ab4d8a859725a3200d95d2164a3fe261f67b20e834f4f7062485cf729cf88 +size 2788456 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index a2aba7df18f60f2562bb4ca5f6a5dd8a15253c75..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e753cc4a2aaa76aea3b821b245d4da638d7c94059a660f8233e17a54df379813 -size 2788456 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        output_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_input_row,
+        stride_mul_row,
+        stride_out_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Forward kernel: one program per row."""
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        # Binary search for expert index (13 iters covers up to 4096 experts)
+        lo = 0
+        hi = num_experts
+        for _ in range(13):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_input_row
+        mul_row_ptr = mul_ptr + row * stride_mul_row
+        out_row_ptr = output_ptr + row * stride_out_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+        else:
+            # --- Multi-tile: two-pass approach ---
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+
+        Pass 1: RMS stats + dot products + bias grad
+        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+        """
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        lo = 0
+        hi = num_experts
+        for _ in range(13):  # 13 iters resolve up to 4096 experts
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_row
+        mul_row_ptr = mul_ptr + row * stride_row
+        grad_out_row_ptr = grad_out_ptr + row * stride_row
+        grad_input_row_ptr = grad_input_ptr + row * stride_row
+        grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            # Compute RMS stats (x4 inlined to reduce register pressure)
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            dpoly = go * m
+
+            # Dot products for coefficients and weight grads
+            sum_dpoly_x = tl.sum(dpoly * x)
+            sum_dpoly_x2 = tl.sum(dpoly * x2)
+            sum_dpoly_x3 = tl.sum(dpoly * x3)
+            grad_b_acc = tl.sum(dpoly)
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8dc5986ed65e78ccae92752b2318a4bf1e7028d9 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:388f461d91124b99544cbdbd4dc4d98f24c010d7a0dc2e9389648860a809a51a +size 2794152 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index d9f497ccf3368602ac8c3250f734668e8e5dc228..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5213ce21365e97016b4a04bfa304e03c9fc1dc6780f62b3312fb65f96e8f6381 -size 2794152 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr,
+        bias_ptr,
+        offsets_ptr,
+        output_ptr,
+        N,
+        D,
+        num_experts,
+        eps,
+        expert_offset,
+        stride_input_row,
+        stride_mul_row,
+        stride_out_row,
+        BLOCK_D: tl.constexpr,
+    ):
+        """Forward kernel: one program per row."""
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        # Binary search for expert index (13 iters covers up to 4096 experts)
+        lo = 0
+        hi = num_experts
+        for _ in range(13):
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_input_row
+        mul_row_ptr = mul_ptr + row * stride_mul_row
+        out_row_ptr = output_ptr + row * stride_out_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
+            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
+        else:
+            # --- Multi-tile: two-pass approach ---
+            sum_x2 = tl.zeros((), dtype=tl.float32)
+            sum_x4 = tl.zeros((), dtype=tl.float32)
+            sum_x6 = tl.zeros((), dtype=tl.float32)
+
+            for d_start in range(0, D, BLOCK_D):
+                d_offs = d_start + tl.arange(0, BLOCK_D)
+                mask = d_offs < D
+                x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+
+        Pass 1: RMS stats + dot products + bias grad
+        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
+        """
+        row = tl.program_id(0)
+        if row >= N:
+            return
+
+        lo = 0
+        hi = num_experts
+        for _ in range(13):  # 13 iters resolve up to 4096 experts
+            if lo < hi:
+                mid = (lo + hi) // 2
+                if tl.load(offsets_ptr + mid) <= row:
+                    lo = mid + 1
+                else:
+                    hi = mid
+        eidx = lo + expert_offset
+
+        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
+        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
+        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
+        b_val = tl.load(bias_ptr + eidx).to(tl.float32)
+
+        input_row_ptr = input_ptr + row * stride_row
+        mul_row_ptr = mul_ptr + row * stride_row
+        grad_out_row_ptr = grad_out_ptr + row * stride_row
+        grad_input_row_ptr = grad_input_ptr + row * stride_row
+        grad_mul_row_ptr = grad_mul_ptr + row * stride_row
+
+        D_float = D.to(tl.float32)
+
+        # --- Single-tile path ---
+        if D <= BLOCK_D:
+            d_offs = tl.arange(0, BLOCK_D)
+            mask = d_offs < D
+
+            x = tl.load(input_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            m = tl.load(mul_row_ptr + d_offs, mask=mask,
+                        other=0.0).to(tl.float32)
+            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
+                         other=0.0).to(tl.float32)
+
+            x2 = x * x
+            x3 = x2 * x
+
+            # Compute RMS stats (x4 inlined to reduce register pressure)
+            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
+            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
+            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)
+
+            w0_inv = w0 * inv_rms_x3
+            w1_inv = w1 * inv_rms_x2
+            w2_inv = w2 * inv_rms_x
+
+            dpoly = go * m
+
+            # Dot products for coefficients and weight grads
+            sum_dpoly_x = tl.sum(dpoly * x)
+            sum_dpoly_x2 = tl.sum(dpoly * x2)
+            sum_dpoly_x3 = tl.sum(dpoly * x3)
+            grad_b_acc = tl.sum(dpoly)
+
+            # Weight grads
+            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
+            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
+            grad_w2_acc = inv_rms_x * sum_dpoly_x
+
+            # Coefficients for grad_input
+            coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/__init__.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..43c71923e7ca23a2f8334a6879b08e6912888e36 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c5304ceac9171f76c03792dfb9b7e8299ba2f2885983c39031546fde8f61f8b +size 10756320 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 487158f67784cbcd78bb767dbe5a0cead183d0de..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu126-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0a64085eb92e8f49acb50abc5fc2c50e0dc3e3fd84ae29d7ea8bf27518c34af3 -size 10756320 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch29-cxx11-cu126-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu126-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fc714ffdff549d768c4a141b224b70bf1f033b0f --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfa89588a5e7e74b3a903912190b97004e308dd8fcb87832c2798d99733591f2 +size 15804336 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index fbc074d8428b9b1f48c21282b6247a57472c4bff..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu128-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a996dcbd533b29a6de849fb4c83b58f5b818688b1c89ae8609805d09b500bc13 -size 15804336 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
"""Triton-accelerated Grouped FusedMulPolyNorm for MoE.

Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd),
eliminating multiple intermediate tensors and kernel launches.

PolyNorm formula (per row):

    poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias
    output = poly * mul

where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps)

Performance optimizations:
  - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per
    hidden dimension.
  - Single-tile specialization: when D <= BLOCK_D, all data stays in registers
    across the reduction and output phases, eliminating redundant global reads.
  - Multi-tile software pipelining: explicit num_stages in autotune configs
    enables overlapping memory loads with computation across loop iterations.
  - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel
    launches (torch.arange + torch.bucketize) per forward/backward call.
  - Backward 2-pass optimization: pass 1 merges RMS statistics computation
    with dot product accumulation, pass 2 computes gradients. This reduces
    memory traffic compared to a naive 3-pass approach.

Forward kernel: one program per row, tiles over D dimension.
  - Computes x, x^2, x^3 in registers
  - Computes three RMS norms in a single pass (shared variance reduction)
  - Applies polynomial weights + bias + mul in a single fused pass, writing
    the result once to a separate output tensor (inputs are not modified)

Backward kernel: one program per row, tiles over D dimension.
  - Recomputes forward intermediates from saved inputs (activation recomputation)
  - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads
  - Weight/bias gradients use tl.atomic_add for cross-row accumulation
"""

import torch
from torch import Tensor

try:
    import triton
    import triton.language as tl

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


# ---------------------------------------------------------------------------
# PyTorch reference implementation (for testing and benchmarking)
# ---------------------------------------------------------------------------
def _rms_norm(x: Tensor, eps: float) -> Tensor:
    """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)"""
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def grouped_fused_mul_poly_norm_ref(
    input: Tensor,
    mul: Tensor,
    weight: Tensor,
    bias: Tensor,
    offsets: Tensor,
    eps: float = 1e-6,
    expert_offset: int = 0,
) -> Tensor:
    """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass).

    Uses torch.bucketize to map tokens to experts, then computes PolyNorm
    for all tokens at once. torch.compile friendly.

    Args:
        input: (total_tokens, D) - concatenated tokens for all experts
        mul: (total_tokens, D) - gate values to multiply with
        weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x]
        bias: (num_experts, 1) - per-expert polynomial bias
        offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
        eps: numerical stability epsilon
        expert_offset: offset added to each expert index before indexing
            weight/bias (for expert-parallel sharding)

    Returns:
        (total_tokens, D) - output tensor
    """
    orig_dtype = input.dtype

    # right=True so a token at position == offsets[i] belongs to expert i+1,
    # matching the `<=` comparison used in the Triton binary search.
    token_positions = torch.arange(input.shape[0], device=input.device)
    expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset

    # Compute in fp32 for numerical parity with the Triton kernels.
    weight_fp32 = weight.float()
    bias_fp32 = bias.float()

    per_token_w = weight_fp32[expert_idx]
    per_token_b = bias_fp32[expert_idx]

    x = input.float()
    m = mul.float()

    x2 = x * x
    x3 = x2 * x

    poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) +
            per_token_w[:, 1:2] * _rms_norm(x2, eps) +
            per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b)

    return (poly * m).to(orig_dtype)


# ---------------------------------------------------------------------------
# Triton kernel implementation
# ---------------------------------------------------------------------------
if HAS_TRITON:
    # --- Autotune configurations ---
    _GROUPED_POLYNORM_FWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1),
    ]

    _GROUPED_POLYNORM_BWD_CONFIGS = [
        triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4),
        triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2),
        triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1),
    ]

    @triton.autotune(
        configs=_GROUPED_POLYNORM_FWD_CONFIGS,
        key=["D"],
    )
    @triton.jit
    def _grouped_polynorm_fwd_kernel(
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        output_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_input_row,
        stride_mul_row,
        stride_out_row,
        BLOCK_D: tl.constexpr,
    ):
        """Forward kernel: one program per row."""
        row = tl.program_id(0)
        if row >= N:
            return

        # Binary search for expert index (12 iters covers up to 4096 experts)
        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_input_row
        mul_row_ptr = mul_ptr + row * stride_mul_row
        out_row_ptr = output_ptr + row * stride_out_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            # Pre-multiply scalar weight * inv_rms to save 1 FMA per element
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
            tl.store(out_row_ptr + d_offs, poly * m, mask=mask)
        else:
            # --- Multi-tile: two-pass approach ---
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            # Pre-multiply scalar weight * inv_rms
            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                x2 = x * x
                x3 = x2 * x
                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b
                tl.store(out_row_ptr + d_offs, poly * m, mask=mask)

    @triton.autotune(
        configs=_GROUPED_POLYNORM_BWD_CONFIGS,
        key=["D"],
        reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"],
    )
    @triton.jit
    def _grouped_polynorm_bwd_kernel(
        grad_out_ptr,
        input_ptr,
        mul_ptr,
        weight_ptr,
        bias_ptr,
        offsets_ptr,
        grad_input_ptr,
        grad_mul_ptr,
        grad_weight_ptr,
        grad_bias_ptr,
        N,
        D,
        num_experts,
        eps,
        expert_offset,
        stride_row,
        BLOCK_D: tl.constexpr,
    ):
        """Backward kernel: one program per row, 2-pass approach.

        Pass 1: RMS stats + dot products + bias grad
        Pass 2: grad_input + grad_mul + weight grads (via atomic_add)
        """
        row = tl.program_id(0)
        if row >= N:
            return

        lo = 0
        hi = num_experts
        for _ in range(12):
            if lo < hi:
                mid = (lo + hi) // 2
                if tl.load(offsets_ptr + mid) <= row:
                    lo = mid + 1
                else:
                    hi = mid
        eidx = lo + expert_offset

        w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32)
        w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32)
        w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32)
        b_val = tl.load(bias_ptr + eidx).to(tl.float32)

        input_row_ptr = input_ptr + row * stride_row
        mul_row_ptr = mul_ptr + row * stride_row
        grad_out_row_ptr = grad_out_ptr + row * stride_row
        grad_input_row_ptr = grad_input_ptr + row * stride_row
        grad_mul_row_ptr = grad_mul_ptr + row * stride_row

        D_float = D.to(tl.float32)

        # --- Single-tile path ---
        if D <= BLOCK_D:
            d_offs = tl.arange(0, BLOCK_D)
            mask = d_offs < D

            x = tl.load(input_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            m = tl.load(mul_row_ptr + d_offs, mask=mask,
                        other=0.0).to(tl.float32)
            go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                         other=0.0).to(tl.float32)

            x2 = x * x
            x3 = x2 * x

            # Compute RMS stats (x4 inlined to reduce register pressure)
            inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            dpoly = go * m

            # Dot products for coefficients and weight grads
            sum_dpoly_x = tl.sum(dpoly * x)
            sum_dpoly_x2 = tl.sum(dpoly * x2)
            sum_dpoly_x3 = tl.sum(dpoly * x3)
            grad_b_acc = tl.sum(dpoly)

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # grad_mul
            poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
            tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask)

            # grad_input (in-place accumulation to reduce register pressure)
            g = inv_rms_x * (w2 * dpoly - x * coeff_x)
            g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2)
            g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3)
            tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)
        else:
            # --- Multi-tile: 2-pass ---
            # Pass 1: RMS stats + dot products + bias grad
            sum_x2 = tl.zeros((), dtype=tl.float32)
            sum_x4 = tl.zeros((), dtype=tl.float32)
            sum_x6 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x2 = tl.zeros((), dtype=tl.float32)
            sum_dpoly_x3 = tl.zeros((), dtype=tl.float32)
            grad_b_acc = tl.zeros((), dtype=tl.float32)

            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x
                dpoly = go * m

                sum_x2 += tl.sum(x2)
                sum_x4 += tl.sum(x2 * x2)
                sum_x6 += tl.sum(x2 * x2 * x2)
                sum_dpoly_x += tl.sum(dpoly * x)
                sum_dpoly_x2 += tl.sum(dpoly * x2)
                sum_dpoly_x3 += tl.sum(dpoly * x3)
                grad_b_acc += tl.sum(dpoly)

            inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps)
            inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps)
            inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps)

            w0_inv = w0 * inv_rms_x3
            w1_inv = w1 * inv_rms_x2
            w2_inv = w2 * inv_rms_x

            # Weight grads
            grad_w0_acc = inv_rms_x3 * sum_dpoly_x3
            grad_w1_acc = inv_rms_x2 * sum_dpoly_x2
            grad_w2_acc = inv_rms_x * sum_dpoly_x

            # Coefficients for grad_input
            coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float
            coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float
            coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float

            # Pass 2: grad_input + grad_mul
            for d_start in range(0, D, BLOCK_D):
                d_offs = d_start + tl.arange(0, BLOCK_D)
                mask = d_offs < D
                x = tl.load(input_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                m = tl.load(mul_row_ptr + d_offs, mask=mask,
                            other=0.0).to(tl.float32)
                go = tl.load(grad_out_row_ptr + d_offs, mask=mask,
                             other=0.0).to(tl.float32)

                x2 = x * x
                x3 = x2 * x

                poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv
                tl.store(grad_mul_row_ptr + d_offs,
                         go * (poly + b_val),
                         mask=mask)

                dpoly = go * m
                g = inv_rms_x * (w2 * dpoly - x * coeff_x)
                g += (2.0 * x * inv_rms_x2 *
                      (w1 * dpoly - x2 * coeff_x2))
                g += (3.0 * x2 * inv_rms_x3 *
                      (w0 * dpoly - x3 * coeff_x3))

                tl.store(grad_input_row_ptr + d_offs, g, mask=mask)

            tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc)
            tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc)
            tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc)

    class _GroupedPolyNormFn(torch.autograd.Function):
        """Autograd wrapper around the fused forward/backward Triton kernels."""

        @staticmethod
        def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset):
            N, D = input.shape
            input = input.contiguous()
            mul = mul.contiguous()
            # The kernels address weight/bias/offsets with raw stride-1
            # pointer arithmetic (weight_ptr + eidx*3 + k, bias_ptr + eidx,
            # offsets_ptr + mid), so they must be contiguous too — a
            # non-contiguous view would silently read the wrong elements.
            weight = weight.contiguous()
            bias = bias.contiguous()
            offsets = offsets.contiguous()
            output = torch.empty_like(input)

            num_experts = offsets.shape[0]
            # Binary search in the kernels runs a fixed 12 iterations.
            assert num_experts <= 4096, (
                f"Supports at most 4096 experts, got {num_experts}.")

            _grouped_polynorm_fwd_kernel[(N,)](
                input,
                mul,
                weight,
                bias,
                offsets,
                output,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_input_row=input.stride(0),
                stride_mul_row=mul.stride(0),
                stride_out_row=output.stride(0),
            )

            # Saved tensors are the contiguous versions, so backward can
            # reuse a single shared row stride.
            ctx.save_for_backward(input, mul, weight, bias, offsets)
            ctx.eps = eps
            ctx.expert_offset = expert_offset
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, mul, weight, bias, offsets = ctx.saved_tensors
            eps = ctx.eps
            expert_offset = ctx.expert_offset
            N, D = input.shape

            grad_output = grad_output.contiguous()
            grad_input = torch.empty_like(input)
            grad_mul = torch.empty_like(mul)
            # fp32 accumulators for atomic_add, cast back at the end.
            grad_weight = torch.zeros(weight.shape[0],
                                      3,
                                      device=weight.device,
                                      dtype=torch.float32)
            grad_bias = torch.zeros(bias.shape[0],
                                    device=bias.device,
                                    dtype=torch.float32)

            num_experts = offsets.shape[0]

            _grouped_polynorm_bwd_kernel[(N,)](
                grad_output,
                input,
                mul,
                weight,
                bias,
                offsets,
                grad_input,
                grad_mul,
                grad_weight,
                grad_bias,
                N,
                D,
                num_experts,
                eps,
                expert_offset,
                stride_row=input.stride(0),
            )

            grad_weight = grad_weight.to(weight.dtype)
            grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype)

            return grad_input, grad_mul, grad_weight, grad_bias, None, None, None

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Triton-accelerated Grouped FusedMulPolyNorm.

        Args:
            input: (total_tokens, D) - concatenated tokens for all experts
            mul: (total_tokens, D) - gate values to multiply with
            weight: (num_experts, 3) - per-expert polynomial weights
            bias: (num_experts, 1) - per-expert polynomial bias
            offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32)
            eps: numerical stability epsilon
            expert_offset: offset to add to expert index

        Returns:
            (total_tokens, D) - output tensor
        """
        return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps,
                                        expert_offset)

else:

    def grouped_fused_mul_poly_norm(
        input: Tensor,
        mul: Tensor,
        weight: Tensor,
        bias: Tensor,
        offsets: Tensor,
        eps: float = 1e-6,
        expert_offset: int = 0,
    ) -> Tensor:
        """Stub raised when Triton is not installed (import-time fallback)."""
        raise RuntimeError(
            "Triton is not available. Install triton to use "
            "grouped_fused_mul_poly_norm.")
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3293d543a0df5bb42f7431eb87a5664ff0765377 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb799a7da44bca518185f71eaf7b2a48a1e0365f41fef9298c356a63fe374d2a +size 13513992 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index a2087eb03968ce925fa43789851fb421b65b3811..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-cu130-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:330acceb49e763f83496e848fb3d55b72ff0da422daafc8d8f978f10d6c35cd2 -size 13513992 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bc553d81eccbb0529e8232185d4fbe549dca7f35 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81b092e60ccd1d42d1fae0be2ed9de420609e032009cd3dc84267651806d3a62 +size 2788640 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index c4ed96ce4da26c0944532ca07d697cca5d0a3ff3..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d0461a82c98d87be63686fa52fad04e2f0c1a57605734c730e7d97d1d2cab4fa -size 2788640 diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm63-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch29-cxx11-rocm63-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/__init__.py @@ -2,6 +2,7 @@ import torch from . 
import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..93e91975634bbb6d38d1dcc1886c24e28544b255 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_0e6f27f_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38f13670768414f29af3722affae76287a6a9c90f198919432e23c170183fd47 +size 2798440 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so b/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so deleted file mode 100755 index 9523fca633977cc7250bab7d685b1f9fa766de90..0000000000000000000000000000000000000000 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_activation_18b7543_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5b649f0e22da07195193300ef534bc309682f4fdb0ca4f0c5286e16cecd57e5b -size 2798440 diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py index 1e14ae6ea50287a7fbc0feee54855ae4ffc0be5e..91c4fcaab95d4cccfa1e53fea482ec1a2334f4ff 100644 --- a/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-rocm64-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_18b7543_dirty -ops = torch.ops._activation_18b7543_dirty +from . 
import _activation_0e6f27f_dirty +ops = torch.ops._activation_0e6f27f_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_18b7543_dirty::{op_name}" \ No newline at end of file + return f"_activation_0e6f27f_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py b/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/build/torch29-cxx11-rocm64-x86_64-linux/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. + - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. 
+ - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 
512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + 
weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, 
mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.") diff --git a/tests/test_grouped_fused_mul_poly_norm.py b/tests/test_grouped_fused_mul_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd72431e6d8070f954a44f145b35cd993b9a767 --- /dev/null +++ b/tests/test_grouped_fused_mul_poly_norm.py @@ -0,0 +1,249 @@ +import pytest +import torch + +from grouped_poly_norm import ( + HAS_TRITON, + grouped_fused_mul_poly_norm_ref, +) + +if HAS_TRITON: + from grouped_poly_norm import grouped_fused_mul_poly_norm + +from .utils import assert_close + +DTYPES = [torch.float, torch.bfloat16, torch.float16] +NUM_TOKENS = [4096, 8192] +D = [256, 1280] +NUM_EXPERTS_LIST = [8, 384] +EXPERT_OFFSETS = [0, 4] +SEEDS = [0] +# Triton kernels launch on the current CUDA device and do not +# auto-dispatch to the tensor's device like CUDA extensions. +# Only test on cuda:0 to avoid cross-device issues. 
+CUDA_DEVICES = ["cuda:0"] + + +def _counts_to_offsets(counts_list, device): + """Convert list of counts to cumsum offsets tensor.""" + return torch.cumsum( + torch.tensor(counts_list, device=device, dtype=torch.int32), dim=0) + + +def _make_inputs(total_tokens, hidden_dim, num_experts, dtype, device, + seed=42, expert_offset=0): + """Create deterministic test inputs with random token distribution.""" + torch.manual_seed(seed) + + probs = torch.ones(num_experts) / num_experts + assignments = torch.multinomial(probs, total_tokens, replacement=True) + counts = torch.bincount(assignments, minlength=num_experts).tolist() + + # Weight/bias must have expert_offset + num_experts rows + total_experts = expert_offset + num_experts + + # Scale inputs to avoid overflow in bf16 (x^3 can overflow for |x| > 40) + input_t = torch.randn(total_tokens, hidden_dim, device=device, + dtype=dtype) * 0.5 + mul_t = torch.randn(total_tokens, hidden_dim, device=device, + dtype=dtype) * 0.5 + weight = (torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 + + torch.randn(total_experts, 3, device=device, dtype=dtype) * 0.01) + bias = torch.randn(total_experts, 1, device=device, dtype=dtype) * 0.01 + offsets = _counts_to_offsets(counts, device) + + return input_t, mul_t, weight, bias, offsets + + +def _run_ref(input_t, mul_t, weight, bias, offsets, expert_offset=0): + """Run reference forward + backward, return output and grads.""" + inp = input_t.clone().detach().requires_grad_(True) + m = mul_t.clone().detach().requires_grad_(True) + w = weight.clone().detach().requires_grad_(True) + b = bias.clone().detach().requires_grad_(True) + + out = grouped_fused_mul_poly_norm_ref(inp, m, w, b, offsets, + expert_offset=expert_offset) + out.sum().backward() + + return out, inp.grad, m.grad, w.grad, b.grad + + +def _run_triton(input_t, mul_t, weight, bias, offsets, expert_offset=0): + """Run Triton forward + backward, return output and grads.""" + inp = 
input_t.clone().detach().requires_grad_(True) + m = mul_t.clone().detach().requires_grad_(True) + w = weight.clone().detach().requires_grad_(True) + b = bias.clone().detach().requires_grad_(True) + + out = grouped_fused_mul_poly_norm(inp, m, w, b, offsets, + expert_offset=expert_offset) + out.sum().backward() + + return out, inp.grad, m.grad, w.grad, b.grad + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_grouped_fused_mul_poly_norm_forward( + num_tokens: int, + d: int, + num_experts: int, + dtype: torch.dtype, + expert_offset: int, + seed: int, + device: str, +) -> None: + """Triton forward output should match PyTorch reference.""" + torch.set_default_device(device) + input_t, mul_t, weight, bias, offsets = _make_inputs( + num_tokens, d, num_experts, dtype, device, seed, + expert_offset=expert_offset) + + out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + + assert out_ref.shape == out_tri.shape == (num_tokens, d) + assert out_ref.dtype == out_tri.dtype == dtype + + if dtype == torch.float32: + assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4) + else: + assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS_LIST) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS) 
+@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_grouped_fused_mul_poly_norm_backward( + num_tokens: int, + d: int, + num_experts: int, + dtype: torch.dtype, + expert_offset: int, + seed: int, + device: str, +) -> None: + """Triton backward gradients should match PyTorch reference.""" + torch.set_default_device(device) + input_t, mul_t, weight, bias, offsets = _make_inputs( + num_tokens, d, num_experts, dtype, device, seed, + expert_offset=expert_offset) + + _, inp_grad_ref, mul_grad_ref, w_grad_ref, b_grad_ref = _run_ref( + input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset) + _, inp_grad_tri, mul_grad_tri, w_grad_tri, b_grad_tri = _run_triton( + input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset) + + if dtype == torch.float32: + atol, rtol = 1e-4, 1e-4 + else: + atol, rtol = 5e-2, 5e-2 + + assert_close(inp_grad_ref, inp_grad_tri, atol=atol, rtol=rtol) + assert_close(mul_grad_ref, mul_grad_tri, atol=atol, rtol=rtol) + assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol) + assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol) + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_grouped_fused_mul_poly_norm_zero_token_experts( + dtype: torch.dtype, + expert_offset: int, + device: str, +) -> None: + """Correctness when some experts receive 0 tokens.""" + torch.set_default_device(device) + counts = [100, 50, 0, 80, 30, 0, 60, 40] + total = sum(counts) + num_experts = 8 + total_experts = expert_offset + num_experts + hidden_dim = 256 + + torch.manual_seed(42) + input_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5 + mul_t = torch.randn(total, hidden_dim, device=device, dtype=dtype) * 0.5 + weight = torch.ones(total_experts, 3, device=device, dtype=dtype) / 3 + bias = 
torch.zeros(total_experts, 1, device=device, dtype=dtype) + offsets = _counts_to_offsets(counts, device) + + out_ref = grouped_fused_mul_poly_norm_ref(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + out_tri = grouped_fused_mul_poly_norm(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + + if dtype == torch.float32: + assert_close(out_ref, out_tri, atol=1e-4, rtol=1e-4) + else: + assert_close(out_ref, out_tri, atol=1e-2, rtol=1e-2) + + # Check backward with zero-token experts + _, _, _, w_grad_ref, b_grad_ref = _run_ref(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + _, _, _, w_grad_tri, b_grad_tri = _run_triton(input_t, mul_t, weight, bias, + offsets, + expert_offset=expert_offset) + + if dtype == torch.float32: + atol, rtol = 1e-3, 1e-3 + else: + atol, rtol = 5e-2, 5e-2 + + assert_close(w_grad_ref, w_grad_tri, atol=atol, rtol=rtol) + assert_close(b_grad_ref, b_grad_tri, atol=atol, rtol=rtol) + + # Verify zero-token experts have zero weight/bias gradients + for eidx in [2, 5]: + wi = eidx + expert_offset + assert w_grad_tri[wi].abs().max() == 0, ( + f"Expert {eidx} (weight idx {wi}) should have zero weight grad " + f"but got max={w_grad_tri[wi].abs().max().item()}") + assert b_grad_tri[wi].abs().max() == 0, ( + f"Expert {eidx} (weight idx {wi}) should have zero bias grad " + f"but got max={b_grad_tri[wi].abs().max().item()}") + + +@pytest.mark.skipif(not HAS_TRITON, reason="Triton not available") +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("expert_offset", EXPERT_OFFSETS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_grouped_fused_mul_poly_norm_no_nan_inf( + dtype: torch.dtype, + expert_offset: int, + device: str, +) -> None: + """Output and gradients should not contain NaN or Inf.""" + torch.set_default_device(device) + input_t, mul_t, weight, bias, offsets = _make_inputs( + 4096, 256, 8, dtype, device, expert_offset=expert_offset) + + out, inp_grad, 
mul_grad, w_grad, b_grad = _run_triton( + input_t, mul_t, weight, bias, offsets, expert_offset=expert_offset) + + assert not out.isnan().any(), "Output contains NaN" + assert not out.isinf().any(), "Output contains Inf" + + for name, grad in [("input", inp_grad), ("mul", mul_grad), + ("weight", w_grad), ("bias", b_grad)]: + assert not grad.isnan().any(), f"{name}_grad contains NaN" + assert not grad.isinf().any(), f"{name}_grad contains Inf" diff --git a/torch-ext/activation/__init__.py b/torch-ext/activation/__init__.py index 0f6f29ac2c688bd09afa41c5d1abd9942c4456d8..a71567edf5422a315e6f06eae6a5606756794820 100644 --- a/torch-ext/activation/__init__.py +++ b/torch-ext/activation/__init__.py @@ -2,6 +2,7 @@ import torch from . import layers, parallel_style from ._ops import ops +from .grouped_poly_norm import grouped_fused_mul_poly_norm from .poly_norm import FusedMulPolyNormFunction, PolyNormFunction from .rms_norm import FusedAddRMSNormFunction, RMSNormFunction @@ -45,6 +46,7 @@ def fused_add_rms_norm( __all__ = [ "poly_norm", "fused_mul_poly_norm", + "grouped_fused_mul_poly_norm", "rms_norm", "fused_add_rms_norm", "layers", diff --git a/torch-ext/activation/grouped_poly_norm.py b/torch-ext/activation/grouped_poly_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9184736c1c24584a94aa80c84fc18a95ab557a52 --- /dev/null +++ b/torch-ext/activation/grouped_poly_norm.py @@ -0,0 +1,583 @@ +"""Triton-accelerated Grouped FusedMulPolyNorm for MoE. + +Fuses the entire PolyNorm computation into two Triton kernels (fwd + bwd), +eliminating multiple intermediate tensors and kernel launches. + +PolyNorm formula (per row): + poly = w[0] * rms_norm(x^3) + w[1] * rms_norm(x^2) + w[2] * rms_norm(x) + bias + output = poly * mul + +where rms_norm(x) = x / sqrt(mean(x^2, dim=-1) + eps) + +Performance optimizations: + - @triton.autotune selects optimal BLOCK_D, num_warps, and num_stages per + hidden dimension. 
+ - Single-tile specialization: when D <= BLOCK_D, all data stays in registers + across the reduction and output phases, eliminating redundant global reads. + - Multi-tile software pipelining: explicit num_stages in autotune configs + enables overlapping memory loads with computation across loop iterations. + - In-kernel binary search for expert mapping: eliminates 2 PyTorch kernel + launches (torch.arange + torch.bucketize) per forward/backward call. + - Backward 2-pass optimization: pass 1 merges RMS statistics computation + with dot product accumulation, pass 2 computes gradients. This reduces + memory traffic compared to a naive 3-pass approach. + +Forward kernel: one program per row, tiles over D dimension. + - Computes x, x^2, x^3 in registers + - Computes three RMS norms in a single pass (shared variance reduction) + - Applies polynomial weights + bias + mul in-place + +Backward kernel: one program per row, tiles over D dimension. + - Recomputes forward intermediates from saved inputs (activation recomputation) + - 2-pass: (1) RMS stats + dot products + bias grad, (2) grad_input + grad_mul + weight grads + - Weight/bias gradients use tl.atomic_add for cross-row accumulation +""" + +import torch +from torch import Tensor + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + + +# --------------------------------------------------------------------------- +# PyTorch reference implementation (for testing and benchmarking) +# --------------------------------------------------------------------------- +def _rms_norm(x: Tensor, eps: float) -> Tensor: + """Per-row RMS normalization: x / sqrt(mean(x^2, dim=-1) + eps)""" + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) + + +def grouped_fused_mul_poly_norm_ref( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, +) -> Tensor: + """PyTorch reference for Grouped 
FusedMulPolyNorm (vectorized, single pass). + + Uses torch.bucketize to map tokens to experts, then computes PolyNorm + for all tokens at once. torch.compile friendly. + + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights [x^3, x^2, x] + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + + Returns: + (total_tokens, D) - output tensor + """ + orig_dtype = input.dtype + + token_positions = torch.arange(input.shape[0], device=input.device) + expert_idx = torch.bucketize(token_positions, offsets, right=True) + expert_offset + + weight_fp32 = weight.float() + bias_fp32 = bias.float() + + per_token_w = weight_fp32[expert_idx] + per_token_b = bias_fp32[expert_idx] + + x = input.float() + m = mul.float() + + x2 = x * x + x3 = x2 * x + + poly = (per_token_w[:, 0:1] * _rms_norm(x3, eps) + + per_token_w[:, 1:2] * _rms_norm(x2, eps) + + per_token_w[:, 2:3] * _rms_norm(x, eps) + per_token_b) + + return (poly * m).to(orig_dtype) + + +# --------------------------------------------------------------------------- +# Triton kernel implementation +# --------------------------------------------------------------------------- +if HAS_TRITON: + # --- Autotune configurations --- + _GROUPED_POLYNORM_FWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, 
num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=32, num_stages=1), + ] + + _GROUPED_POLYNORM_BWD_CONFIGS = [ + triton.Config({"BLOCK_D": 128}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_D": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 256}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=4, num_stages=5), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 512}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 512}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_D": 1024}, num_warps=8, num_stages=4), + triton.Config({"BLOCK_D": 1024}, num_warps=16, num_stages=2), + triton.Config({"BLOCK_D": 2048}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_D": 2048}, num_warps=16, num_stages=1), 
+ ] + + @triton.autotune( + configs=_GROUPED_POLYNORM_FWD_CONFIGS, + key=["D"], + ) + @triton.jit + def _grouped_polynorm_fwd_kernel( + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + output_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_input_row, + stride_mul_row, + stride_out_row, + BLOCK_D: tl.constexpr, + ): + """Forward kernel: one program per row.""" + row = tl.program_id(0) + if row >= N: + return + + # Binary search for expert index (12 iters covers up to 4096 experts) + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_input_row + mul_row_ptr = mul_ptr + row * stride_mul_row + out_row_ptr = output_ptr + row * stride_out_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + # Pre-multiply scalar weight * inv_rms to save 1 FMA per element + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + else: + # --- Multi-tile: two-pass approach --- + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), 
dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + # Pre-multiply scalar weight * inv_rms + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + x2 = x * x + x3 = x2 * x + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + b + tl.store(out_row_ptr + d_offs, poly * m, mask=mask) + + @triton.autotune( + configs=_GROUPED_POLYNORM_BWD_CONFIGS, + key=["D"], + reset_to_zero=["grad_weight_ptr", "grad_bias_ptr"], + ) + @triton.jit + def _grouped_polynorm_bwd_kernel( + grad_out_ptr, + input_ptr, + mul_ptr, + weight_ptr, + bias_ptr, + offsets_ptr, + grad_input_ptr, + grad_mul_ptr, + grad_weight_ptr, + grad_bias_ptr, + N, + D, + num_experts, + eps, + expert_offset, + stride_row, + BLOCK_D: tl.constexpr, + ): + """Backward kernel: one program per row, 2-pass approach. 
+ + Pass 1: RMS stats + dot products + bias grad + Pass 2: grad_input + grad_mul + weight grads (via atomic_add) + """ + row = tl.program_id(0) + if row >= N: + return + + lo = 0 + hi = num_experts + for _ in range(12): + if lo < hi: + mid = (lo + hi) // 2 + if tl.load(offsets_ptr + mid) <= row: + lo = mid + 1 + else: + hi = mid + eidx = lo + expert_offset + + w0 = tl.load(weight_ptr + eidx * 3 + 0).to(tl.float32) + w1 = tl.load(weight_ptr + eidx * 3 + 1).to(tl.float32) + w2 = tl.load(weight_ptr + eidx * 3 + 2).to(tl.float32) + b_val = tl.load(bias_ptr + eidx).to(tl.float32) + + input_row_ptr = input_ptr + row * stride_row + mul_row_ptr = mul_ptr + row * stride_row + grad_out_row_ptr = grad_out_ptr + row * stride_row + grad_input_row_ptr = grad_input_ptr + row * stride_row + grad_mul_row_ptr = grad_mul_ptr + row * stride_row + + D_float = D.to(tl.float32) + + # --- Single-tile path --- + if D <= BLOCK_D: + d_offs = tl.arange(0, BLOCK_D) + mask = d_offs < D + + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + # Compute RMS stats (x4 inlined to reduce register pressure) + inv_rms_x = 1.0 / tl.sqrt(tl.sum(x2) / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(tl.sum(x2 * x2) / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(tl.sum(x2 * x2 * x2) / D_float + eps) + + w0_inv = w0 * inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + dpoly = go * m + + # Dot products for coefficients and weight grads + sum_dpoly_x = tl.sum(dpoly * x) + sum_dpoly_x2 = tl.sum(dpoly * x2) + sum_dpoly_x3 = tl.sum(dpoly * x3) + grad_b_acc = tl.sum(dpoly) + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * 
inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # grad_mul + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, go * (poly + b_val), mask=mask) + + # grad_input (in-place accumulation to reduce register pressure) + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += 2.0 * x * inv_rms_x2 * (w1 * dpoly - x2 * coeff_x2) + g += 3.0 * x2 * inv_rms_x3 * (w0 * dpoly - x3 * coeff_x3) + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + else: + # --- Multi-tile: 2-pass --- + # Pass 1: RMS stats + dot products + bias grad + sum_x2 = tl.zeros((), dtype=tl.float32) + sum_x4 = tl.zeros((), dtype=tl.float32) + sum_x6 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x = tl.zeros((), dtype=tl.float32) + sum_dpoly_x2 = tl.zeros((), dtype=tl.float32) + sum_dpoly_x3 = tl.zeros((), dtype=tl.float32) + grad_b_acc = tl.zeros((), dtype=tl.float32) + + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + dpoly = go * m + + sum_x2 += tl.sum(x2) + sum_x4 += tl.sum(x2 * x2) + sum_x6 += tl.sum(x2 * x2 * x2) + sum_dpoly_x += tl.sum(dpoly * x) + sum_dpoly_x2 += tl.sum(dpoly * x2) + sum_dpoly_x3 += tl.sum(dpoly * x3) + grad_b_acc += tl.sum(dpoly) + + inv_rms_x = 1.0 / tl.sqrt(sum_x2 / D_float + eps) + inv_rms_x2 = 1.0 / tl.sqrt(sum_x4 / D_float + eps) + inv_rms_x3 = 1.0 / tl.sqrt(sum_x6 / D_float + eps) + + w0_inv = w0 * 
inv_rms_x3 + w1_inv = w1 * inv_rms_x2 + w2_inv = w2 * inv_rms_x + + # Weight grads + grad_w0_acc = inv_rms_x3 * sum_dpoly_x3 + grad_w1_acc = inv_rms_x2 * sum_dpoly_x2 + grad_w2_acc = inv_rms_x * sum_dpoly_x + + # Coefficients for grad_input + coeff_x = w2 * sum_dpoly_x * inv_rms_x * inv_rms_x / D_float + coeff_x2 = w1 * sum_dpoly_x2 * inv_rms_x2 * inv_rms_x2 / D_float + coeff_x3 = w0 * sum_dpoly_x3 * inv_rms_x3 * inv_rms_x3 / D_float + + # Pass 2: grad_input + grad_mul + for d_start in range(0, D, BLOCK_D): + d_offs = d_start + tl.arange(0, BLOCK_D) + mask = d_offs < D + x = tl.load(input_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + m = tl.load(mul_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + go = tl.load(grad_out_row_ptr + d_offs, mask=mask, + other=0.0).to(tl.float32) + + x2 = x * x + x3 = x2 * x + + poly = x3 * w0_inv + x2 * w1_inv + x * w2_inv + tl.store(grad_mul_row_ptr + d_offs, + go * (poly + b_val), + mask=mask) + + dpoly = go * m + g = inv_rms_x * (w2 * dpoly - x * coeff_x) + g += (2.0 * x * inv_rms_x2 * + (w1 * dpoly - x2 * coeff_x2)) + g += (3.0 * x2 * inv_rms_x3 * + (w0 * dpoly - x3 * coeff_x3)) + + tl.store(grad_input_row_ptr + d_offs, g, mask=mask) + + tl.atomic_add(grad_weight_ptr + eidx * 3 + 0, grad_w0_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 1, grad_w1_acc) + tl.atomic_add(grad_weight_ptr + eidx * 3 + 2, grad_w2_acc) + tl.atomic_add(grad_bias_ptr + eidx, grad_b_acc) + + class _GroupedPolyNormFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, mul, weight, bias, offsets, eps, expert_offset): + N, D = input.shape + input = input.contiguous() + mul = mul.contiguous() + output = torch.empty_like(input) + + num_experts = offsets.shape[0] + assert num_experts <= 4096, ( + f"Supports at most 4096 experts, got {num_experts}.") + + _grouped_polynorm_fwd_kernel[(N,)]( + input, + mul, + weight, + bias, + offsets, + output, + N, + D, + num_experts, + eps, + expert_offset, + 
stride_input_row=input.stride(0), + stride_mul_row=mul.stride(0), + stride_out_row=output.stride(0), + ) + + ctx.save_for_backward(input, mul, weight, bias, offsets) + ctx.eps = eps + ctx.expert_offset = expert_offset + return output + + @staticmethod + def backward(ctx, grad_output): + input, mul, weight, bias, offsets = ctx.saved_tensors + eps = ctx.eps + expert_offset = ctx.expert_offset + N, D = input.shape + + grad_output = grad_output.contiguous() + grad_input = torch.empty_like(input) + grad_mul = torch.empty_like(mul) + grad_weight = torch.zeros(weight.shape[0], + 3, + device=weight.device, + dtype=torch.float32) + grad_bias = torch.zeros(bias.shape[0], + device=bias.device, + dtype=torch.float32) + + num_experts = offsets.shape[0] + + _grouped_polynorm_bwd_kernel[(N,)]( + grad_output, + input, + mul, + weight, + bias, + offsets, + grad_input, + grad_mul, + grad_weight, + grad_bias, + N, + D, + num_experts, + eps, + expert_offset, + stride_row=input.stride(0), + ) + + grad_weight = grad_weight.to(weight.dtype) + grad_bias = grad_bias.unsqueeze(-1).to(bias.dtype) + + return grad_input, grad_mul, grad_weight, grad_bias, None, None, None + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + """Triton-accelerated Grouped FusedMulPolyNorm. 
+ + Args: + input: (total_tokens, D) - concatenated tokens for all experts + mul: (total_tokens, D) - gate values to multiply with + weight: (num_experts, 3) - per-expert polynomial weights + bias: (num_experts, 1) - per-expert polynomial bias + offsets: (num_experts,) - cumsum of num_tokens_per_expert (int32) + eps: numerical stability epsilon + expert_offset: offset to add to expert index + + Returns: + (total_tokens, D) - output tensor + """ + return _GroupedPolyNormFn.apply(input, mul, weight, bias, offsets, eps, + expert_offset) + +else: + + def grouped_fused_mul_poly_norm( + input: Tensor, + mul: Tensor, + weight: Tensor, + bias: Tensor, + offsets: Tensor, + eps: float = 1e-6, + expert_offset: int = 0, + ) -> Tensor: + raise RuntimeError( + "Triton is not available. Install triton to use " + "grouped_fused_mul_poly_norm.")